Python迭代器和生成器

学习时长:2.5小时 难度:中级 练习:4个

迭代器基础

什么是迭代器?

迭代器是一个可以记住遍历位置的对象,它提供了一种逐个访问集合元素的方法,而不需要暴露集合的内部表示。

可迭代对象


# Python中的常见可迭代对象
numbers = [1, 2, 3, 4, 5]  # 列表
chars = "Hello"            # 字符串
user_info = {"name": "张三", "age": 25}  # 字典

# 使用for循环遍历
for num in numbers:
    print(num)

# 使用iter()函数获取迭代器
iterator = iter(numbers)
print(next(iterator))  # 1
print(next(iterator))  # 2
                

迭代器协议

自定义迭代器

class CountDown:
    def __init__(self, start):
        self.start = start

    def __iter__(self):
        return self

    def __next__(self):
        if self.start <= 0:
            raise StopIteration
        self.start -= 1
        return self.start + 1

# 使用自定义迭代器
counter = CountDown(5)
for num in counter:
    print(num)  # 输出:5, 4, 3, 2, 1
                    
注意事项

迭代器是一次性的,用完就需要重新创建。不要尝试重复使用已经遍历完的迭代器。

生成器基础

什么是生成器?

生成器是一种特殊的迭代器,它使用yield语句来生成值。生成器函数在每次调用时会保存其执行状态。


# 生成器函数
def fibonacci(n):
    a, b = 0, 1
    for _ in range(n):
        yield a
        a, b = b, a + b

# 使用生成器
fib = fibonacci(5)
for num in fib:
    print(num)  # 输出:0, 1, 1, 2, 3

# 生成无限序列
def infinite_counter(start=0):
    while True:
        yield start
        start += 1
                

生成器表达式

列表推导式 vs 生成器表达式

# 列表推导式(一次性生成所有元素)
squares_list = [x**2 for x in range(1000000)]  # 占用大量内存

# 生成器表达式(按需生成元素)
squares_gen = (x**2 for x in range(1000000))   # 内存高效

# 使用生成器表达式
for square in squares_gen:
    if square > 100:
        break
    print(square)
                    

高级生成器特性

send()和yield表达式


def counter_with_step():
    count = 0
    step = 1
    while True:
        # yield可以接收send()发送的值
        new_step = yield count
        if new_step is not None:
            step = new_step
        count += step

# 使用send()
gen = counter_with_step()
print(next(gen))      # 0
print(gen.send(2))    # 2
print(next(gen))      # 4
                

生成器的其他方法

  • close(): 停止生成器
  • throw(): 向生成器抛出异常

实际应用场景

1. 大文件处理

使用生成器处理大型日志文件,避免一次性将整个文件加载到内存中:


def parse_log_file(file_path):
    """分析大型日志文件的生成器"""
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # 解析每行日志
            if '[ERROR]' in line:
                timestamp = line[0:19]
                error_msg = line[line.find('[ERROR]')+7:].strip()
                yield {
                    'timestamp': timestamp,
                    'level': 'ERROR',
                    'message': error_msg
                }

# 使用示例
for error_log in parse_log_file('app.log'):
    print(f"时间: {error_log['timestamp']}")
    print(f"错误信息: {error_log['message']}\n")
                    
2. 数据流处理管道

使用生成器创建数据处理管道,实现ETL(提取、转换、加载)过程:


def extract_data(csv_file):
    """从CSV文件提取数据"""
    with open(csv_file, 'r', encoding='utf-8') as f:
        for line in f:
            yield line.strip().split(',')

def transform_data(rows):
    """转换数据:计算学生平均分"""
    for row in rows:
        if len(row) >= 4:  # 假设格式:姓名,语文,数学,英语
            name = row[0]
            scores = [float(score) for score in row[1:]]
            average = sum(scores) / len(scores)
            yield (name, average)

def load_data(results):
    """加载数据:将结果写入新文件"""
    with open('results.txt', 'w', encoding='utf-8') as f:
        for name, avg in results:
            f.write(f"{name}: {avg:.2f}\n")
            yield f"{name}的平均分处理完成"

# 使用示例
data = extract_data('scores.csv')
transformed = transform_data(data)
for status in load_data(transformed):
    print(status)
                    
3. 内存优化

使用生成器处理大量图片文件,实现批量图片处理:


from PIL import Image
import os

def image_processor(directory):
    """处理目录中的所有图片"""
    for filename in os.listdir(directory):
        if filename.endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(directory, filename)
            try:
                with Image.open(image_path) as img:
                    # 调整图片大小
                    resized = img.resize((800, 600))
                    # 保存处理后的图片
                    new_path = os.path.join(directory, f'processed_{filename}')
                    resized.save(new_path)
                    yield f"已处理: {filename}"
            except Exception as e:
                yield f"处理失败 {filename}: {str(e)}"

# 使用示例
for status in image_processor('images/'):
    print(status)
                    
4. 实时数据处理

使用生成器实现实时数据监控和处理:


import time
from datetime import datetime

def sensor_data_simulator():
    """模拟传感器数据流"""
    while True:
        yield {
            'timestamp': datetime.now(),
            'temperature': random.uniform(20, 30),
            'humidity': random.uniform(40, 60)
        }
        time.sleep(1)  # 模拟每秒产生一个数据点

def alert_monitor(data_stream, temp_threshold=28):
    """监控温度异常"""
    for data in data_stream:
        if data['temperature'] > temp_threshold:
            yield {
                'time': data['timestamp'],
                'alert': f"温度过高: {data['temperature']:.1f}°C"
            }

def data_logger(alert_stream):
    """记录警报信息"""
    with open('temperature_alerts.log', 'a', encoding='utf-8') as f:
        for alert in alert_stream:
            log_entry = f"[{alert['time']}] {alert['alert']}\n"
            f.write(log_entry)
            yield "警报已记录"

# 使用示例
sensor_data = sensor_data_simulator()
alerts = alert_monitor(sensor_data)
for status in data_logger(alerts):
    print(status)
    if input("继续监控?(y/n): ").lower() != 'y':
        break
                    

练习与实践

练习1:自定义范围迭代器

实现一个类似于range()的自定义迭代器,支持步长和负数范围。


class MyRange:
    """
    实现一个支持步长的范围迭代器
    用法:
    for i in MyRange(1, 10, 2):
        print(i)  # 输出:1, 3, 5, 7, 9
    """
    def __init__(self, start, stop, step=1):
        # 在此编写你的代码
        pass
                    

练习2:文件读取生成器

创建一个生成器函数,可以分块读取大文件,每次产生指定大小的数据块。


def read_in_chunks(file_path, chunk_size=1024):
    """
    分块读取文件的生成器函数
    """
    # 在此编写你的代码
    pass
                    

练习3:数据流处理管道

使用生成器创建一个数据处理管道,包含多个转换步骤。


def number_generator(n):
    """生成1到n的数字"""
    pass

def filter_even(numbers):
    """过滤出偶数"""
    pass

def multiply_by_two(numbers):
    """将数字乘以2"""
    pass

# 使用示例
numbers = number_generator(10)
even_numbers = filter_even(numbers)
result = multiply_by_two(even_numbers)