DatasetFolder

class paddle.vision.datasets. DatasetFolder ( root, loader=None, extensions=None, transform=None, is_valid_file=None ) [源代码]

一种通用的数据加载方式,当输入以如下的格式存放时: root/class_a/1.ext root/class_a/2.ext root/class_a/3.ext

root/class_b/123.ext root/class_b/456.ext root/class_b/789.ext

参数:

  • root (str) - 根目录路径。

  • loader (callable,可选) - 可以加载数据路径的一个函数,如果该值没有设定,默认使用 cv2.imread 。默认值:None。

  • extensions (tuple[str],可选) - 允许的数据后缀列表,如果该值没有设定,默认使用 ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') 。默认值:None。

  • transform (callable,可选) - 图片数据的预处理,若为 None 即为不做预处理。默认值为None

  • is_valid_file (callable,可选) - 根据每条数据的路径来判断是否合法的一个函数。默认值:None。

代码示例

  1. import os
  2. import cv2
  3. import tempfile
  4. import shutil
  5. import numpy as np
  6. from paddle.vision.datasets import DatasetFolder
  7. def make_fake_dir():
  8. data_dir = tempfile.mkdtemp()
  9. for i in range(2):
  10. sub_dir = os.path.join(data_dir, 'class_' + str(i))
  11. if not os.path.exists(sub_dir):
  12. os.makedirs(sub_dir)
  13. for j in range(2):
  14. fake_img = (np.random.random((32, 32, 3)) * 255).astype('uint8')
  15. cv2.imwrite(os.path.join(sub_dir, str(j) + '.jpg'), fake_img)
  16. return data_dir
  17. temp_dir = make_fake_dir()
  18. # temp_dir is root dir
  19. # temp_dir/class_1/img1_1.jpg
  20. # temp_dir/class_2/img2_1.jpg
  21. data_folder = DatasetFolder(temp_dir)
  22. for items in data_folder:
  23. break
  24. shutil.rmtree(temp_dir)