Dataset

class paddle.io. Dataset [源代码]

概述Dataset的方法和行为的抽象类。

映射式(map-style)数据集需要继承这个基类,映射式数据集为可以通过一个键值索引并获取指定样本的数据集,所有映射式数据集须实现以下方法:

__getitem__: 根据给定索引获取数据集中指定样本,在 paddle.io.DataLoader 中需要使用此函数通过下标获取样本。

__len__: 返回数据集样本个数, paddle.io.BatchSampler 中需要样本个数生成下标序列。

paddle.io.DataLoader

代码示例

  1. import numpy as np
  2. from paddle.io import Dataset
  3. # define a random dataset
  4. class RandomDataset(Dataset):
  5. def __init__(self, num_samples):
  6. self.num_samples = num_samples
  7. def __getitem__(self, idx):
  8. image = np.random.random([784]).astype('float32')
  9. label = np.random.randint(0, 9, (1, )).astype('int64')
  10. return image, label
  11. def __len__(self):
  12. return self.num_samples
  13. dataset = RandomDataset(10)
  14. for i in range(len(dataset)):
  15. print(dataset[i])

使用本API的教程文档