ChainDataset
class paddle.io. ChainDataset [源代码]
将多个流式数据集级联的数据集。
用于级联的数据集须都是 paddle.io.IterableDataset
数据集,将各流式数据集按顺序级联为一个数据集。
参数:
- datasets (list of IterableDataset) - 待级联的多个数据集。
返回:Dataset,级联后的流式数据集
代码示例
import numpy as np
import paddle
from paddle.io import IterableDataset, ChainDataset
# define a random dataset
class RandomDataset(IterableDataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __iter__(self):
for i in range(10):
image = np.random.random([32]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
yield image, label
dataset = ChainDataset([RandomDataset(10), RandomDataset(10)])
for image, label in iter(dataset):
print(image, label)