random_split

class paddle.io. random_split ( dataset, lengths, generator=None ) [源代码]

给定子集合dataset的长度数组,随机切分出原数据集合的非重复子集合。

参数:

  • dataset (Dataset) - 此参数必须是 paddle.io.Datasetpaddle.io.IterableDataset 的一个子类实例或实现了 __len__ 的Python对象,用于生成样本下标。默认值为None。

  • lengths (list) - 总和为原数组长度的,子集合长度数组。

  • generator (Generator) - 指定采样 data_source 的采样器。默认值为None。

返回: list, 返回按给定长度数组描述随机分割的原数据集合的非重复子集合。

代码示例

  1. import paddle
  2. from paddle.io import random_split
  3. a_list = paddle.io.random_split(range(10), [3, 7])
  4. print(len(a_list))
  5. # 2
  6. for idx, v in enumerate(a_list[0]):
  7. print(idx, v)
  8. # output of the first subset
  9. # 0 1
  10. # 1 3
  11. # 2 9
  12. for idx, v in enumerate(a_list[1]):
  13. print(idx, v)
  14. # output of the second subset
  15. # 0 5
  16. # 1 7
  17. # 2 8
  18. # 3 6
  19. # 4 0
  20. # 5 2
  21. # 6 4