itertools —- 为高效循环创建迭代器的函数


本模块实现一系列 iterator ,这些迭代器受到APL,Haskell和SML的启发。为了适用于Python,它们都被重新写过。

本模块标准化了一个快速、高效利用内存的核心工具集,这些工具本身或组合都很有用。它们一起形成了“迭代器代数”,这使得在纯Python中有可能创建简洁又高效的专用工具。

例如,SML有一个制表工具: tabulate(f),它可产生一个序列 f(0), f(1), ...。在Python中可以组合 map()count() 实现: map(f, count())

这些工具及其内置对应物也能很好地配合 operator 模块中的快速函数来使用。 例如,乘法运算符可以被映射到两个向量之间执行高效的点积: sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))

无穷迭代器:

迭代器

实参

结果

示例

count()

[start[, step]]

start, start+step, start+2*step, …

count(10) → 10 11 12 13 14 …

cycle()

p

p0, p1, … plast, p0, p1, …

cycle(‘ABCD’) → A B C D A B C D …

repeat()

elem [,n]

elem, elem, elem, … 重复无限次或n次

repeat(10, 3) → 10 10 10

根据最短输入序列长度停止的迭代器:

迭代器

实参

结果

示例

accumulate()

p [,func]

p0, p0+p1, p0+p1+p2, …

accumulate([1,2,3,4,5]) → 1 3 6 10 15

batched()

p, n

(p0, p1, …, p_n-1), …

batched(‘ABCDEFG’, n=3) → ABC DEF G

chain()

p, q, …

p0, p1, … plast, q0, q1, …

chain(‘ABC’, ‘DEF’) → A B C D E F

chain.from_iterable()

iterable — 可迭代对象

p0, p1, … plast, q0, q1, …

chain.from_iterable([‘ABC’, ‘DEF’]) → A B C D E F

compress()

data, selectors

(d[0] if s[0]), (d[1] if s[1]), …

compress(‘ABCDEF’, [1,0,1,0,1,1]) → A C E F

dropwhile()

predicate, seq

seq[n], seq[n+1], 从 predicate 未通过时开始

dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8

filterfalse()

predicate, seq

predicate(elem) 未通过的 seq 元素

filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8

groupby()

iterable[, key]

根据key(v)值分组的迭代器

groupby([‘A’,’B’,’DEF’], len) → (1, A B) (3, DEF)

islice()

seq, [start,] stop [, step]

seq[start:stop:step]中的元素

islice(‘ABCDEFG’, 2, None) → C D E F G

pairwise()

iterable — 可迭代对象

(p[0], p[1]), (p[1], p[2])

pairwise(‘ABCDEFG’) → AB BC CD DE EF FG

starmap()

func, seq

func(seq[0]), func(seq[1]), …

starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000

takewhile()

predicate, seq

seq[0], seq[1], 直到 predicate 未通过

takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4

tee()

it, n

it1, it2, … itn 将一个迭代器拆分为n个迭代器

zip_longest()

p, q, …

(p[0], q[0]), (p[1], q[1]), …

zip_longest(‘ABCD’, ‘xy’, fillvalue=’-‘) → Ax By C- D-

排列组合迭代器:

迭代器

实参

结果

product()

p, q, … [repeat=1]

笛卡尔积,相当于嵌套的for循环

permutations()

p[, r]

长度r元组,所有可能的排列,无重复元素

combinations()

p, r

长度r元组,有序,无重复元素

combinations_with_replacement()

p, r

长度r元组,有序,元素可重复

例子

结果

product(‘ABCD’, repeat=2)

AA AB AC AD BA BB BC BD CA CB CC CD DA DB DC DD

permutations(‘ABCD’, 2)

AB AC AD BA BC BD CA CB CD DA DB DC

combinations(‘ABCD’, 2)

AB AC AD BC BD CD

combinations_with_replacement(‘ABCD’, 2)

AA AB AC AD BB BC BD CC CD DD

Itertool 函数

The following functions all construct and return iterators. Some provide streams of infinite length, so they should only be accessed by functions or loops that truncate the stream.

itertools.accumulate(iterable[, function, *, initial=None])

创建一个返回累积汇总值或来自其他双目运算函数的累积结果的迭代器。

function 默认为加法运算。 function 应当接受两个参数,即一个累积汇总值和一个来自 iterable 的值。

如果提供了 initial 值,将从该值开始累积并且输出将比输入可迭代对象多一个元素。

大致相当于:

  1. def accumulate(iterable, function=operator.add, *, initial=None):
  2. 'Return running totals'
  3. # accumulate([1,2,3,4,5]) → 1 3 6 10 15
  4. # accumulate([1,2,3,4,5], initial=100) → 100 101 103 106 110 115
  5. # accumulate([1,2,3,4,5], operator.mul) → 1 2 6 24 120
  6. iterator = iter(iterable)
  7. total = initial
  8. if initial is None:
  9. try:
  10. total = next(iterator)
  11. except StopIteration:
  12. return
  13. yield total
  14. for element in iterator:
  15. total = function(total, element)
  16. yield total

To compute a running minimum, set function to min(). For a running maximum, set function to max(). Or for a running product, set function to operator.mul(). To build an amortization table, accumulate the interest and apply payments:

  1. >>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
  2. >>> list(accumulate(data, max)) # running maximum
  3. [3, 4, 6, 6, 6, 9, 9, 9, 9, 9]
  4. >>> list(accumulate(data, operator.mul)) # running product
  5. [3, 12, 72, 144, 144, 1296, 0, 0, 0, 0]
  6. # Amortize a 5% loan of 1000 with 10 annual payments of 90
  7. >>> update = lambda balance, payment: round(balance * 1.05) - payment
  8. >>> list(accumulate(repeat(90, 10), update, initial=1_000))
  9. [1000, 960, 918, 874, 828, 779, 728, 674, 618, 559, 497]

参考一个类似函数 functools.reduce() ,它只返回一个最终累积值。

Added in version 3.2.

在 3.3 版本发生变更: 添加了可选的 function 形参。

在 3.8 版本发生变更: 添加了可选的 initial 形参。

itertools.batched(iterable, n, *, strict=False)

来自 iterable 的长度为 n 元组形式的批次数据。 最后一个批次可能短于 n

如果 strict 为真值,将在最终的批次短于 n 时引发 ValueError

循环处理输入可迭代对象并将数据积累为长度至多为 n 的元组。 输入将被惰性地消耗,能填满一个批次即可。 结果将在批次填满或输入可迭代对象被耗尽时产生:

  1. >>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet']
  2. >>> unflattened = list(batched(flattened_data, 2))
  3. >>> unflattened
  4. [('roses', 'red'), ('violets', 'blue'), ('sugar', 'sweet')]

大致相当于:

  1. def batched(iterable, n, *, strict=False):
  2. # batched('ABCDEFG', 3) → ABC DEF G
  3. if n < 1:
  4. raise ValueError('n must be at least one')
  5. iterator = iter(iterable)
  6. while batch := tuple(islice(iterator, n)):
  7. if strict and len(batch) != n:
  8. raise ValueError('batched(): incomplete batch')
  9. yield batch

Added in version 3.12.

在 3.13 版本发生变更: 增加了 strict 选项。

itertools.chain(*iterables)

Make an iterator that returns elements from the first iterable until it is exhausted, then proceeds to the next iterable, until all of the iterables are exhausted. This combines multiple data sources into a single iterator. Roughly equivalent to:

  1. def chain(*iterables):
  2. # chain('ABC', 'DEF') → A B C D E F
  3. for iterable in iterables:
  4. yield from iterable

classmethod chain.from_iterable(iterable)

构建类似 chain() 迭代器的另一个选择。从一个单独的可迭代参数中得到链式输入,该参数是延迟计算的。大致相当于:

  1. def from_iterable(iterables):
  2. # chain.from_iterable(['ABC', 'DEF']) → A B C D E F
  3. for iterable in iterables:
  4. yield from iterable

itertools.combinations(iterable, r)

返回由输入 iterable 中元素组成长度为 r 的子序列。

输出结果是 product() 的子序列其中只保留属于 iterable 的子序列的条目。 输出的长度由 math.comb() 给出,该函数在 0 ≤ r ≤ n 时为计算 n! / r! / (n - r)! 而在 r > n 时为 0。

组合元组是根据输入的 iterable 的顺序以词典排序方式发出的。 如果输入的 iterable 是已排序的,则输出的元组将按排序后的顺序产生。

元素是的唯一性是基于它们的位置,而不是它们的值。 如果输入的元素都是唯一的,则将每个组合中将不会有重复的值。

大致相当于:

  1. def combinations(iterable, r):
  2. # combinations('ABCD', 2) → AB AC AD BC BD CD
  3. # combinations(range(4), 3) → 012 013 023 123
  4. pool = tuple(iterable)
  5. n = len(pool)
  6. if r > n:
  7. return
  8. indices = list(range(r))
  9. yield tuple(pool[i] for i in indices)
  10. while True:
  11. for i in reversed(range(r)):
  12. if indices[i] != i + n - r:
  13. break
  14. else:
  15. return
  16. indices[i] += 1
  17. for j in range(i+1, r):
  18. indices[j] = indices[j-1] + 1
  19. yield tuple(pool[i] for i in indices)

itertools.combinations_with_replacement(iterable, r)

返回由输入 iterable 中元素组成的长度为 r 的子序列,允许每个元素可重复出现。

输出是 product() 的子序列,其中仅保留也属于 iterable 的子序列的条目(可能有重复的元素)。 当 n > 0 时返回的子序列数量为 (n + r - 1)! / r! / (n - 1)!

组合元组是根据输入的 iterable 的顺序以词典排序方式发出的。 如果输入的 iterable 是已排序的,则输出的元组将按已排序的顺序产生。

元素的唯一性是基于它们的位置,而不是它们的值。 如果输入的元素都是唯一的,则生成的组合也将是唯一的。

大致相当于:

  1. def combinations_with_replacement(iterable, r):
  2. # combinations_with_replacement('ABC', 2) → AA AB AC BB BC CC
  3. pool = tuple(iterable)
  4. n = len(pool)
  5. if not n and r:
  6. return
  7. indices = [0] * r
  8. yield tuple(pool[i] for i in indices)
  9. while True:
  10. for i in reversed(range(r)):
  11. if indices[i] != n - 1:
  12. break
  13. else:
  14. return
  15. indices[i:] = [indices[i] + 1] * (r - i)
  16. yield tuple(pool[i] for i in indices)

Added in version 3.1.

itertools.compress(data, selectors)

创建一个迭代器,它返回来自 dataselectors 中对应元素为真值的元素。 当 dataselectors 可迭代对象被耗尽时将停止。 大致相当于:

  1. def compress(data, selectors):
  2. # compress('ABCDEF', [1,0,1,0,1,1]) → A C E F
  3. return (datum for datum, selector in zip(data, selectors) if selector)

Added in version 3.1.

itertools.count(start=0, step=1)

创建一个迭代器,它返回从 start 开始的均匀间隔的值。 可与 map() 配合使用以生成连续的数据点或与 zip() 配合使用以添加序列数字。 大致相当于:

  1. def count(start=0, step=1):
  2. # count(10) → 10 11 12 13 14 ...
  3. # count(2.5, 0.5) → 2.5 3.0 3.5 ...
  4. n = start
  5. while True:
  6. yield n
  7. n += step

当对浮点数计数时,替换为乘法代码有时会有更高的精度,例如: (start + step * i for i in count())

在 3.1 版本发生变更: 增加参数 step ,允许非整型。

itertools.cycle(iterable)

创建一个迭代器,它返回来自 iterable 中的元素并保存每个元素的拷贝。 当 iterable 耗尽时,返回来自已保存拷贝中的元素。 将无限重复进行。 大致相当于:

  1. def cycle(iterable):
  2. # cycle('ABCD') → A B C D A B C D A B C D ...
  3. saved = []
  4. for element in iterable:
  5. yield element
  6. saved.append(element)
  7. while saved:
  8. for element in saved:
  9. yield element

这个迭代工具可能需要很大的辅助存储(取决于 iterable 的长度)。

itertools.dropwhile(predicate, iterable)

创建一个迭代器,它将丢弃来自 iterablepredicate 为真值的元素然后返回每个元素。 大致相当于:

  1. def dropwhile(predicate, iterable):
  2. # dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8
  3. iterator = iter(iterable)
  4. for x in iterator:
  5. if not predicate(x):
  6. yield x
  7. break
  8. for x in iterator:
  9. yield x

请注意它将不产生 任何 输出直到 predicate 首次变为假值,所以此迭代工具可能具有很长的启动时间。

itertools.filterfalse(predicate, iterable)

创建一个迭代器,它过滤来自 iterable 的元素从而只返回其中 predicate 返回假值的元素。 如果 predicateNone,则返回本身为假值的条目。 大致相当于:

  1. def filterfalse(predicate, iterable):
  2. # filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8
  3. if predicate is None:
  4. predicate = bool
  5. for x in iterable:
  6. if not predicate(x):
  7. yield x

itertools.groupby(iterable, key=None)

创建一个迭代器,返回 iterable 中连续的键和组。key 是一个计算元素键值函数。如果未指定或为 Nonekey 缺省为恒等函数(identity function),返回元素不变。一般来说,iterable 需用同一个键值函数预先排序。

groupby() 操作类似于Unix中的 uniq。当每次 key 函数产生的键值改变时,迭代器会分组或生成一个新组(这就是为什么通常需要使用同一个键值函数先对数据进行排序)。这种行为与SQL的GROUP BY操作不同,SQL的操作会忽略输入的顺序将相同键值的元素分在同组中。

返回的组本身也是一个迭代器,它与 groupby() 共享底层的可迭代对象。因为源是共享的,当 groupby() 对象向后迭代时,前一个组将消失。因此如果稍后还需要返回结果,可保存为列表:

  1. groups = []
  2. uniquekeys = []
  3. data = sorted(data, key=keyfunc)
  4. for k, g in groupby(data, keyfunc):
  5. groups.append(list(g)) # 将 group 迭代器以列表形式保存
  6. uniquekeys.append(k)

groupby() 大致相当于:

  1. def groupby(iterable, key=None):
  2. # [k for k, g in groupby('AAAABBBCCDAABBB')] → A B C D A B
  3. # [list(g) for k, g in groupby('AAAABBBCCD')] → AAAA BBB CC D
  4. keyfunc = (lambda x: x) if key is None else key
  5. iterator = iter(iterable)
  6. exhausted = False
  7. def _grouper(target_key):
  8. nonlocal curr_value, curr_key, exhausted
  9. yield curr_value
  10. for curr_value in iterator:
  11. curr_key = keyfunc(curr_value)
  12. if curr_key != target_key:
  13. return
  14. yield curr_value
  15. exhausted = True
  16. try:
  17. curr_value = next(iterator)
  18. except StopIteration:
  19. return
  20. curr_key = keyfunc(curr_value)
  21. while not exhausted:
  22. target_key = curr_key
  23. curr_group = _grouper(target_key)
  24. yield curr_key, curr_group
  25. if curr_key == target_key:
  26. for _ in curr_group:
  27. pass

itertools.islice(iterable, stop)

itertools.islice(iterable, start, stop[, step])

创建一个迭代器,它返回 iterable 的选定元素。 效果与序列切片类似但不支持负的 start, stopstep 值。

如果 start 为零或为 None,迭代将从零开始。 在其他情况下,iterable 中的元素将被跳过直至到达 start

If stop is None, iteration continues until the input is exhausted, if at all. Otherwise, it stops at the specified position.

如果 stepNone,则步长默认为一。 元素将被逐一返回除非 step 被设为大于一的数,此情况将导致部分条目被跳过。

大致相当于:

  1. def islice(iterable, *args):
  2. # islice('ABCDEFG', 2) → A B
  3. # islice('ABCDEFG', 2, 4) → C D
  4. # islice('ABCDEFG', 2, None) → C D E F G
  5. # islice('ABCDEFG', 0, None, 2) → A C E G
  6. s = slice(*args)
  7. start = 0 if s.start is None else s.start
  8. stop = s.stop
  9. step = 1 if s.step is None else s.step
  10. if start < 0 or (stop is not None and stop < 0) or step <= 0:
  11. raise ValueError
  12. indices = count() if stop is None else range(max(start, stop))
  13. next_i = start
  14. for i, element in zip(indices, iterable):
  15. if i == next_i:
  16. yield element
  17. next_i += step

If the input is an iterator, then fully consuming the islice advances the input iterator by max(start, stop) steps regardless of the step value.

itertools.pairwise(iterable)

返回从输入 iterable 中获取的连续重叠对。

输出迭代器中 2 元组的数量将比输入的数量少一个。 如果输入可迭代对象中少于两个值则它将为空。

大致相当于:

  1. def pairwise(iterable):
  2. # pairwise('ABCDEFG') → AB BC CD DE EF FG
  3. iterator = iter(iterable)
  4. a = next(iterator, None)
  5. for b in iterator:
  6. yield a, b
  7. a = b

Added in version 3.10.

itertools.permutations(iterable, r=None)

根据 iterable 返回连续的 r 长度 元素的排列

如果 r 未指定或为 Noner 默认设置为 iterable 的长度,这种情况下,生成所有全长排列。

输出结果是 product() 的子序列并已过滤掉其中的重复元素。 输出的长度由 math.perm() 给出,它在 0 ≤ r ≤ n 时为计算 n! / (n - r)! 而在 r > n 时则为零。

排列元组是根据输入的 iterable 的顺序以词典排序的形式发出的。 如果输入的 iterable 是已排序的,则输出的元组将按已排序的顺序产生。

元素的唯一性是基于它们的位置,而不是它们的值。 如果输入的元素都是唯一的,则在排列中就不会有重复的元素。

大致相当于:

  1. def permutations(iterable, r=None):
  2. # permutations('ABCD', 2) → AB AC AD BA BC BD CA CB CD DA DB DC
  3. # permutations(range(3)) → 012 021 102 120 201 210
  4. pool = tuple(iterable)
  5. n = len(pool)
  6. r = n if r is None else r
  7. if r > n:
  8. return
  9. indices = list(range(n))
  10. cycles = list(range(n, n-r, -1))
  11. yield tuple(pool[i] for i in indices[:r])
  12. while n:
  13. for i in reversed(range(r)):
  14. cycles[i] -= 1
  15. if cycles[i] == 0:
  16. indices[i:] = indices[i+1:] + indices[i:i+1]
  17. cycles[i] = n - i
  18. else:
  19. j = cycles[i]
  20. indices[i], indices[-j] = indices[-j], indices[i]
  21. yield tuple(pool[i] for i in indices[:r])
  22. break
  23. else:
  24. return

itertools.product(*iterables, repeat=1)

Cartesian product of the input iterables.

大致相当于生成器表达式中的嵌套循环。例如, product(A, B)((x,y) for x in A for y in B) 返回结果一样。

嵌套循环像里程表那样循环变动,每次迭代时将最右侧的元素向后迭代。这种模式形成了一种字典序,因此如果输入的可迭代对象是已排序的,笛卡尔积元组依次序发出。

要计算可迭代对象自身的笛卡尔积,将可选参数 repeat 设定为要重复的次数。例如,product(A, repeat=4)product(A, A, A, A) 是一样的。

该函数大致相当于下面的代码,只不过实际实现方案不会在内存中创建中间结果。

  1. def product(*iterables, repeat=1):
  2. # product('ABCD', 'xy') → Ax Ay Bx By Cx Cy Dx Dy
  3. # product(range(2), repeat=3) → 000 001 010 011 100 101 110 111
  4. if repeat < 0:
  5. raise ValueError('repeat argument cannot be negative')
  6. pools = [tuple(pool) for pool in iterables] * repeat
  7. result = [[]]
  8. for pool in pools:
  9. result = [x+[y] for x in result for y in pool]
  10. for prod in result:
  11. yield tuple(prod)

product() 运行之前,它会完全耗尽输入的可迭代对象,在内存中保留值的临时池以生成结果积。 相应地,它只适用于有限的输入。

itertools.repeat(object[, times])

创建一个持续地返回 object 的迭代器。 将会无限期地运行除非指定了 times 参数。

大致相当于:

  1. def repeat(object, times=None):
  2. # repeat(10, 3) → 10 10 10
  3. if times is None:
  4. while True:
  5. yield object
  6. else:
  7. for i in range(times):
  8. yield object

repeat 的一个常见用途是向 mapzip 提供一个常量值的流:

  1. >>> list(map(pow, range(10), repeat(2)))
  2. [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

itertools.starmap(function, iterable)

创建一个迭代器,它使用从 iterable 获取的参数来计算 function。 当参数形参已被“预先 zip”为元组时可代替 map() 来使用。

map()starmap() 之间的区别类似于 function(a,b)function(*c) 之间的差异。 大致相当于:

  1. def starmap(function, iterable):
  2. # starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000
  3. for args in iterable:
  4. yield function(*args)

itertools.takewhile(predicate, iterable)

创建一个迭代器,它返回来自 iterablepredicate 为真值的元素。 大致相当于:

  1. def takewhile(predicate, iterable):
  2. # takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4
  3. for x in iterable:
  4. if not predicate(x):
  5. break
  6. yield x

请注意,第一个未能满足 predicate 条件的元素将从输入迭代器中消耗掉并且没有办法访问它。 当应用程序想在 takewhile 运行到耗尽后进一步消耗输入迭代器时这可能会造成问题。 要绕过这个问题,可以考虑改用 more-iterools before_and_after()

itertools.tee(iterable, n=2)

从一个可迭代对象中返回 n 个独立的迭代器。

大致相当于:

  1. def tee(iterable, n=2):
  2. if n < 0:
  3. raise ValueError
  4. if n == 0:
  5. return ()
  6. iterator = _tee(iterable)
  7. result = [iterator]
  8. for _ in range(n - 1):
  9. result.append(_tee(iterator))
  10. return tuple(result)
  11. class _tee:
  12. def __init__(self, iterable):
  13. it = iter(iterable)
  14. if isinstance(it, _tee):
  15. self.iterator = it.iterator
  16. self.link = it.link
  17. else:
  18. self.iterator = it
  19. self.link = [None, None]
  20. def __iter__(self):
  21. return self
  22. def __next__(self):
  23. link = self.link
  24. if link[1] is None:
  25. link[0] = next(self.iterator)
  26. link[1] = [None, None]
  27. value, self.link = link
  28. return value

When the input iterable is already a tee iterator object, all members of the return tuple are constructed as if they had been produced by the upstream tee() call. This “flattening step” allows nested tee() calls to share the same underlying data chain and to have a single update step rather than a chain of calls.

The flattening property makes tee iterators efficiently peekable:

  1. def lookahead(tee_iterator):
  2. "Return the next value without moving the input forward"
  3. [forked_iterator] = tee(tee_iterator, 1)
  4. return next(forked_iterator)
  1. >>> iterator = iter('abcdef')
  2. >>> [iterator] = tee(iterator, 1) # Make the input peekable
  3. >>> next(iterator) # Move the iterator forward
  4. 'a'
  5. >>> lookahead(iterator) # Check next value
  6. 'b'
  7. >>> next(iterator) # Continue moving forward
  8. 'b'

tee 迭代器不是线程安全的。 当同时使用由同一个 tee() 调用所返回的迭代器时可能引发 RuntimeError,即使原本的 iterable 是线程安全的。is threadsafe.

该迭代工具可能需要相当大的辅助存储空间(这取决于要保存多少临时数据)。通常,如果一个迭代器在另一个迭代器开始之前就要使用大部份或全部数据,使用 list() 会比 tee() 更快。

itertools.zip_longest(*iterables, fillvalue=None)

创建一个迭代器,它聚合了来自 iterables 中每一项的对应元素。

如果 iterables 中每一项的长度不同,则缺失的值将以 fillvalue 填充。 如果未指定,则 fillvalue 默认为 None

迭代将持续进行直至其中最长的可迭代对象被耗尽。

大致相当于:

  1. def zip_longest(*iterables, fillvalue=None):
  2. # zip_longest('ABCD', 'xy', fillvalue='-') → Ax By C- D-
  3. iterators = list(map(iter, iterables))
  4. num_active = len(iterators)
  5. if not num_active:
  6. return
  7. while True:
  8. values = []
  9. for i, iterator in enumerate(iterators):
  10. try:
  11. value = next(iterator)
  12. except StopIteration:
  13. num_active -= 1
  14. if not num_active:
  15. return
  16. iterators[i] = repeat(fillvalue)
  17. value = fillvalue
  18. values.append(value)
  19. yield tuple(values)

如果 iterables 中的每一项可能有无限长度,则 zip_longest() 函数应当用限制调用次数的代码进行包装(例如 islice()takewhile() 等)。

itertools 配方

本节将展示如何使用现有的 itertools 作为基础构件来创建扩展的工具集。

这些 itertools 专题的主要目的是教学。 各个专题显示了对单个工具的各种思维方式 — 例如,chain.from_iterable 被关联到展平的概念。 这些专题还给出了有关这些工具的组合方式的想法 — 例如,starmap()repeat() 应当如何一起工作。 这些专题还显示了 itertools 与 operatorcollections 模块以及内置迭代工具如 map(), filter(), reversed()enumerate() 相互配合的使用模式。

这些例程的次要目的是作为一个孵化器使用。 accumulate(), compress()pairwise() 等迭代工具最初就是作为例程引入的。 目前,sliding_window(), iter_index()sieve() 例程正在被测试以确定它们是否堪当大任。

基本上所有这些配方和许许多多其他配方都可以通过 Python Package Index 上的 more-itertools 项目来安装:

  1. python -m pip install more-itertools

许多例程提供了与底层工具集相当的高性能。 更好的内存效率是通过每次只处理一个元素而不是将整个可迭代对象放入内存来保证的。 代码量的精简是通过以 函数式风格 来链接工具来实现的。 运行的早速度是通过选择使用“矢量化”构件来取代会导致较大解释器开销的 for 循环和 生成器 来达成的。

  1. import collections
  2. import contextlib
  3. import functools
  4. import math
  5. import operator
  6. import random
  7. def take(n, iterable):
  8. "Return first n items of the iterable as a list."
  9. return list(islice(iterable, n))
  10. def prepend(value, iterable):
  11. "Prepend a single value in front of an iterable."
  12. # prepend(1, [2, 3, 4]) → 1 2 3 4
  13. return chain([value], iterable)
  14. def tabulate(function, start=0):
  15. "Return function(0), function(1), ..."
  16. return map(function, count(start))
  17. def repeatfunc(func, times=None, *args):
  18. "Repeat calls to func with specified arguments."
  19. if times is None:
  20. return starmap(func, repeat(args))
  21. return starmap(func, repeat(args, times))
  22. def flatten(list_of_lists):
  23. "Flatten one level of nesting."
  24. return chain.from_iterable(list_of_lists)
  25. def ncycles(iterable, n):
  26. "Returns the sequence elements n times."
  27. return chain.from_iterable(repeat(tuple(iterable), n))
  28. def tail(n, iterable):
  29. "Return an iterator over the last n items."
  30. # tail(3, 'ABCDEFG') → E F G
  31. return iter(collections.deque(iterable, maxlen=n))
  32. def consume(iterator, n=None):
  33. "Advance the iterator n-steps ahead. If n is None, consume entirely."
  34. # Use functions that consume iterators at C speed.
  35. if n is None:
  36. collections.deque(iterator, maxlen=0)
  37. else:
  38. next(islice(iterator, n, n), None)
  39. def nth(iterable, n, default=None):
  40. "Returns the nth item or a default value."
  41. return next(islice(iterable, n, None), default)
  42. def quantify(iterable, predicate=bool):
  43. "Given a predicate that returns True or False, count the True results."
  44. return sum(map(predicate, iterable))
  45. def first_true(iterable, default=False, predicate=None):
  46. "Returns the first true value or the *default* if there is no true value."
  47. # first_true([a,b,c], x) → a or b or c or x
  48. # first_true([a,b], x, f) → a if f(a) else b if f(b) else x
  49. return next(filter(predicate, iterable), default)
  50. def all_equal(iterable, key=None):
  51. "Returns True if all the elements are equal to each other."
  52. # all_equal('4٤௪౪໔', key=int) → True
  53. return len(take(2, groupby(iterable, key))) <= 1
  54. def unique_justseen(iterable, key=None):
  55. "Yield unique elements, preserving order. Remember only the element just seen."
  56. # unique_justseen('AAAABBBCCDAABBB') → A B C D A B
  57. # unique_justseen('ABBcCAD', str.casefold) → A B c A D
  58. if key is None:
  59. return map(operator.itemgetter(0), groupby(iterable))
  60. return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
  61. def unique_everseen(iterable, key=None):
  62. "Yield unique elements, preserving order. Remember all elements ever seen."
  63. # unique_everseen('AAAABBBCCDAABBB') → A B C D
  64. # unique_everseen('ABBcCAD', str.casefold) → A B c D
  65. seen = set()
  66. if key is None:
  67. for element in filterfalse(seen.__contains__, iterable):
  68. seen.add(element)
  69. yield element
  70. else:
  71. for element in iterable:
  72. k = key(element)
  73. if k not in seen:
  74. seen.add(k)
  75. yield element
  76. def unique(iterable, key=None, reverse=False):
  77. "Yield unique elements in sorted order. Supports unhashable inputs."
  78. # unique([[1, 2], [3, 4], [1, 2]]) → [1, 2] [3, 4]
  79. return unique_justseen(sorted(iterable, key=key, reverse=reverse), key=key)
  80. def sliding_window(iterable, n):
  81. "Collect data into overlapping fixed-length chunks or blocks."
  82. # sliding_window('ABCDEFG', 4) → ABCD BCDE CDEF DEFG
  83. iterator = iter(iterable)
  84. window = collections.deque(islice(iterator, n - 1), maxlen=n)
  85. for x in iterator:
  86. window.append(x)
  87. yield tuple(window)
  88. def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
  89. "Collect data into non-overlapping fixed-length chunks or blocks."
  90. # grouper('ABCDEFG', 3, fillvalue='x') → ABC DEF Gxx
  91. # grouper('ABCDEFG', 3, incomplete='strict') → ABC DEF ValueError
  92. # grouper('ABCDEFG', 3, incomplete='ignore') → ABC DEF
  93. iterators = [iter(iterable)] * n
  94. match incomplete:
  95. case 'fill':
  96. return zip_longest(*iterators, fillvalue=fillvalue)
  97. case 'strict':
  98. return zip(*iterators, strict=True)
  99. case 'ignore':
  100. return zip(*iterators)
  101. case _:
  102. raise ValueError('Expected fill, strict, or ignore')
  103. def roundrobin(*iterables):
  104. "Visit input iterables in a cycle until each is exhausted."
  105. # roundrobin('ABC', 'D', 'EF') → A D E B F C
  106. # Algorithm credited to George Sakkis
  107. iterators = map(iter, iterables)
  108. for num_active in range(len(iterables), 0, -1):
  109. iterators = cycle(islice(iterators, num_active))
  110. yield from map(next, iterators)
  111. def subslices(seq):
  112. "Return all contiguous non-empty subslices of a sequence."
  113. # subslices('ABCD') → A AB ABC ABCD B BC BCD C CD D
  114. slices = starmap(slice, combinations(range(len(seq) + 1), 2))
  115. return map(operator.getitem, repeat(seq), slices)
  116. def iter_index(iterable, value, start=0, stop=None):
  117. "Return indices where a value occurs in a sequence or iterable."
  118. # iter_index('AABCADEAF', 'A') → 0 1 4 7
  119. seq_index = getattr(iterable, 'index', None)
  120. if seq_index is None:
  121. iterator = islice(iterable, start, stop)
  122. for i, element in enumerate(iterator, start):
  123. if element is value or element == value:
  124. yield i
  125. else:
  126. stop = len(iterable) if stop is None else stop
  127. i = start
  128. with contextlib.suppress(ValueError):
  129. while True:
  130. yield (i := seq_index(value, i, stop))
  131. i += 1
  132. def iter_except(func, exception, first=None):
  133. "Convert a call-until-exception interface to an iterator interface."
  134. # iter_except(d.popitem, KeyError) → non-blocking dictionary iterator
  135. with contextlib.suppress(exception):
  136. if first is not None:
  137. yield first()
  138. while True:
  139. yield func()

下面的例程具有更数学化的风格:

  1. def powerset(iterable):
  2. "powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
  3. s = list(iterable)
  4. return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
  5. def sum_of_squares(iterable):
  6. "Add up the squares of the input values."
  7. # sum_of_squares([10, 20, 30]) → 1400
  8. return math.sumprod(*tee(iterable))
  9. def reshape(matrix, cols):
  10. "Reshape a 2-D matrix to have a given number of columns."
  11. # reshape([(0, 1), (2, 3), (4, 5)], 3) → (0, 1, 2), (3, 4, 5)
  12. return batched(chain.from_iterable(matrix), cols, strict=True)
  13. def transpose(matrix):
  14. "Swap the rows and columns of a 2-D matrix."
  15. # transpose([(1, 2, 3), (11, 22, 33)]) → (1, 11) (2, 22) (3, 33)
  16. return zip(*matrix, strict=True)
  17. def matmul(m1, m2):
  18. "Multiply two matrices."
  19. # matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) → (49, 80), (41, 60)
  20. n = len(m2[0])
  21. return batched(starmap(math.sumprod, product(m1, transpose(m2))), n)
  22. def convolve(signal, kernel):
  23. """Discrete linear convolution of two iterables.
  24. Equivalent to polynomial multiplication.
  25. Convolutions are mathematically commutative; however, the inputs are
  26. evaluated differently. The signal is consumed lazily and can be
  27. infinite. The kernel is fully consumed before the calculations begin.
  28. Article: https://betterexplained.com/articles/intuitive-convolution/
  29. Video: https://www.youtube.com/watch?v=KuXjwB4LzSA
  30. """
  31. # convolve([1, -1, -20], [1, -3]) → 1 -4 -17 60
  32. # convolve(data, [0.25, 0.25, 0.25, 0.25]) → Moving average (blur)
  33. # convolve(data, [1/2, 0, -1/2]) → 1st derivative estimate
  34. # convolve(data, [1, -2, 1]) → 2nd derivative estimate
  35. kernel = tuple(kernel)[::-1]
  36. n = len(kernel)
  37. padded_signal = chain(repeat(0, n-1), signal, repeat(0, n-1))
  38. windowed_signal = sliding_window(padded_signal, n)
  39. return map(math.sumprod, repeat(kernel), windowed_signal)
  40. def polynomial_from_roots(roots):
  41. """Compute a polynomial's coefficients from its roots.
  42. (x - 5) (x + 4) (x - 3) expands to: x³ -4x² -17x + 60
  43. """
  44. # polynomial_from_roots([5, -4, 3]) → [1, -4, -17, 60]
  45. factors = zip(repeat(1), map(operator.neg, roots))
  46. return list(functools.reduce(convolve, factors, [1]))
  47. def polynomial_eval(coefficients, x):
  48. """Evaluate a polynomial at a specific value.
  49. Computes with better numeric stability than Horner's method.
  50. """
  51. # Evaluate x³ -4x² -17x + 60 at x = 5
  52. # polynomial_eval([1, -4, -17, 60], x=5) → 0
  53. n = len(coefficients)
  54. if not n:
  55. return type(x)(0)
  56. powers = map(pow, repeat(x), reversed(range(n)))
  57. return math.sumprod(coefficients, powers)
  58. def polynomial_derivative(coefficients):
  59. """Compute the first derivative of a polynomial.
  60. f(x) = x³ -4x² -17x + 60
  61. f'(x) = 3x² -8x -17
  62. """
  63. # polynomial_derivative([1, -4, -17, 60]) → [3, -8, -17]
  64. n = len(coefficients)
  65. powers = reversed(range(1, n))
  66. return list(map(operator.mul, coefficients, powers))
  67. def sieve(n):
  68. "Primes less than n."
  69. # sieve(30) → 2 3 5 7 11 13 17 19 23 29
  70. if n > 2:
  71. yield 2
  72. data = bytearray((0, 1)) * (n // 2)
  73. for p in iter_index(data, 1, start=3, stop=math.isqrt(n) + 1):
  74. data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
  75. yield from iter_index(data, 1, start=3)
  76. def factor(n):
  77. "Prime factors of n."
  78. # factor(99) → 3 3 11
  79. # factor(1_000_000_000_000_007) → 47 59 360620266859
  80. # factor(1_000_000_000_000_403) → 1000000000000403
  81. for prime in sieve(math.isqrt(n) + 1):
  82. while not n % prime:
  83. yield prime
  84. n //= prime
  85. if n == 1:
  86. return
  87. if n > 1:
  88. yield n
  89. def totient(n):
  90. "Count of natural numbers up to n that are coprime to n."
  91. # https://mathworld.wolfram.com/TotientFunction.html
  92. # totient(12) → 4 because len([1, 5, 7, 11]) == 4
  93. for prime in set(factor(n)):
  94. n -= n // prime
  95. return n