IterableDataset

class paddle.io.IterableDataset [source]

An abstract class that describes the methods and behaviors of iterable-style datasets.

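Every iterable-style dataset must inherit from this base class. An iterable-style dataset can only return samples one after another, in sequence, much like a Python iterator. All iterable-style datasets must implement the following method: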

__iter__: yields the data samples one at a time.

Note

An iterable-style dataset does not need to implement __getitem__ or __len__, and these two methods must not be called on it.

See also: paddle.io.DataLoader

Code example

    import numpy as np
    from paddle.io import IterableDataset

    # Define a dataset that yields random samples.
    class RandomDataset(IterableDataset):
        def __init__(self, num_samples):
            self.num_samples = num_samples

        def __iter__(self):
            for i in range(self.num_samples):
                image = np.random.random([784]).astype('float32')
                # np.random.randint excludes the upper bound, so labels fall in [0, 8].
                label = np.random.randint(0, 9, (1, )).astype('int64')
                yield image, label

    dataset = RandomDataset(10)
    for img, lbl in dataset:
        print(img, lbl)

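When paddle.io.DataLoader is created with num_workers > 0, every worker subprocess iterates over the whole dataset and returns every sample, so the dataset is repeated num_workers times. To keep samples from being returned more than once, use one of the two approaches below. Both rely on paddle.io.get_worker_info to obtain information about each worker subprocess.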
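The duplication described above can be reproduced without Paddle. The following is a minimal sketch, not Paddle code: RangeStream and simulate_workers are hypothetical names introduced only for illustration, and the "workers" are plain loops in one process rather than real subprocesses. It assumes each worker iterates its own copy of the dataset from the beginning, which is the behavior the paragraph above describes.

```python
class RangeStream:
    """Hypothetical stand-in for an iterable-style dataset (not a Paddle class)."""

    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        # No worker-aware logic here: every caller sees the full range.
        return iter(range(self.start, self.end))


def simulate_workers(dataset, num_workers):
    """Collect what each simulated worker would yield from the dataset."""
    return [list(dataset) for _ in range(num_workers)]


per_worker = simulate_workers(RangeStream(2, 9), num_workers=2)
all_samples = [sample for worker in per_worker for sample in worker]

print(per_worker)        # both simulated workers produce the same seven values
print(len(all_samples))  # 14, i.e. 7 samples x 2 workers
```

Because __iter__ has no knowledge of which worker is calling it, the only way to avoid the repetition is to give each worker a different slice of the data, which is what the two approaches do.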

  1. Split the data across worker subprocesses inside __iter__

Code example 1

    import math
    import paddle
    import numpy as np
    from paddle.io import IterableDataset, DataLoader, get_worker_info

    class SplitedIterableDataset(IterableDataset):
        def __init__(self, start, end):
            self.start = start
            self.end = end

        def __iter__(self):
            worker_info = get_worker_info()
            if worker_info is None:
                # Single-process loading: iterate over the full range.
                iter_start = self.start
                iter_end = self.end
            else:
                # Multi-process loading: give each worker a contiguous slice.
                per_worker = int(
                    math.ceil((self.end - self.start) / float(
                        worker_info.num_workers)))
                worker_id = worker_info.id
                iter_start = self.start + worker_id * per_worker
                iter_end = min(iter_start + per_worker, self.end)

            for i in range(iter_start, iter_end):
                yield np.array([i])

    dataset = SplitedIterableDataset(start=2, end=9)
    dataloader = DataLoader(
        dataset,
        num_workers=2,
        batch_size=1,
        drop_last=True)

    for data in dataloader:
        print(data)
    # Each batch holds one value. With two workers, worker 0 yields 2, 3, 4, 5
    # and worker 1 yields 6, 7, 8, so every value in [2, 9) appears exactly once.

  2. Split the data across worker subprocesses with the per-worker initialization function worker_init_fn

Code example 2

    import math
    import paddle
    import numpy as np
    from paddle.io import IterableDataset, DataLoader, get_worker_info

    class RangeIterableDataset(IterableDataset):
        def __init__(self, start, end):
            self.start = start
            self.end = end

        def __iter__(self):
            for i in range(self.start, self.end):
                yield np.array([i])

    dataset = RangeIterableDataset(start=2, end=9)

    def worker_init_fn(worker_id):
        # Called in each worker subprocess; narrow the range of the dataset
        # obtained from worker_info to that worker's slice.
        worker_info = get_worker_info()
        dataset = worker_info.dataset
        start = dataset.start
        end = dataset.end
        num_per_worker = int(
            math.ceil((end - start) / float(worker_info.num_workers)))
        worker_id = worker_info.id
        dataset.start = start + worker_id * num_per_worker
        dataset.end = min(dataset.start + num_per_worker, end)

    dataloader = DataLoader(
        dataset,
        num_workers=2,
        batch_size=1,
        drop_last=True,
        worker_init_fn=worker_init_fn)

    for data in dataloader:
        print(data)
    # Each batch holds one value. With two workers, worker 0 yields 2, 3, 4, 5
    # and worker 1 yields 6, 7, 8, so every value in [2, 9) appears exactly once.
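Both examples above rely on the same ceiling-division arithmetic to decide which samples a worker handles. The sketch below isolates that arithmetic in plain Python so it can be checked without starting any subprocesses. shard_bounds and shards are hypothetical helper names used only for illustration; they are not part of paddle.io.

```python
import math


def shard_bounds(start, end, num_workers, worker_id):
    """Return the half-open range [lo, hi) assigned to one worker.

    Hypothetical helper mirroring the ceiling-division split in the examples.
    """
    per_worker = int(math.ceil((end - start) / float(num_workers)))
    lo = start + worker_id * per_worker
    hi = min(lo + per_worker, end)
    return lo, hi


def shards(start, end, num_workers):
    """List the samples each worker would yield."""
    result = []
    for worker_id in range(num_workers):
        lo, hi = shard_bounds(start, end, num_workers, worker_id)
        result.append(list(range(lo, hi)))  # empty when lo >= hi
    return result


print(shards(2, 9, 2))  # [[2, 3, 4, 5], [6, 7, 8]]
print(shards(2, 8, 2))  # [[2, 3, 4], [5, 6, 7]]
print(shards(2, 4, 3))  # [[2], [3], []]
```

The slices are contiguous and disjoint, and together they cover the full range. When the range does not divide evenly, the last worker receives fewer samples, and a worker may receive none at all if there are more workers than samples.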