RandomSampler
class paddle.io.RandomSampler
( data_source=None, replacement=False, num_samples=None, generator=None ) [源代码]
随机迭代样本,产生重排下标,如果 replacement = False
,则会采样整个数据集;如果 replacement = True
,则会按照 num_samples
指定的样本数采集。
参数
data_source (Dataset) - 此参数必须是
paddle.io.Dataset
或paddle.io.IterableDataset
的一个子类实例或实现了__len__
的Python对象,用于生成样本下标。默认值为None。replacement (bool) - 如果为
False
则会采样整个数据集,如果为True
则会按num_samples
指定的样本数采集。默认值为False
。num_samples (int) - 如果
replacement
设置为True
则按此参数采集对应的样本数。默认值为None。generator (Generator) - 指定采样
data_source
的采样器。默认值为None。
返回
RandomSampler, 返回随机采样下标的采样器
代码示例
from paddle.io import Dataset, RandomSampler
class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
image = np.random.random([784]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
return image, label
def __len__(self):
return self.num_samples
sampler = RandomSampler(data_source=RandomDataset(100))
for index in sampler:
print(index)