ChainDataset

class paddle.io. ChainDataset [源代码]

将多个流式数据集级联的数据集。

用于级联的数据集须都是 paddle.io.IterableDataset 数据集,将各流式数据集按顺序级联为一个数据集。

参数:

  • datasets (list of IterableDataset) - 待级联的多个数据集。

返回:Dataset,级联后的流式数据集

代码示例

  1. import numpy as np
  2. import paddle
  3. from paddle.io import IterableDataset, ChainDataset
  4. # define a random dataset
  5. class RandomDataset(IterableDataset):
  6. def __init__(self, num_samples):
  7. self.num_samples = num_samples
  8. def __iter__(self):
  9. for i in range(10):
  10. image = np.random.random([32]).astype('float32')
  11. label = np.random.randint(0, 9, (1, )).astype('int64')
  12. yield image, label
  13. dataset = ChainDataset([RandomDataset(10), RandomDataset(10)])
  14. for image, label in iter(dataset):
  15. print(image, label)