DatasetFolder
class paddle.vision.datasets. DatasetFolder ( root, loader=None, extensions=None, transform=None, is_valid_file=None ) [源代码]
一种通用的数据加载方式,当输入以如下的格式存放时: root/class_a/1.ext root/class_a/2.ext root/class_a/3.ext
root/class_b/123.ext root/class_b/456.ext root/class_b/789.ext
参数:
root (str) - 根目录路径。
loader (callable,可选) - 可以加载数据路径的一个函数,如果该值没有设定,默认使用
cv2.imread
。默认值:None。extensions (tuple[str],可选) - 允许的数据后缀列表,如果该值没有设定,默认使用
('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
。默认值:None。transform (callable,可选) - 图片数据的预处理,若为
None
即为不做预处理。默认值为None
。is_valid_file (callable,可选) - 根据每条数据的路径来判断是否合法的一个函数。默认值:None。
代码示例:
import os
import cv2
import tempfile
import shutil
import numpy as np
from paddle.vision.datasets import DatasetFolder
def make_fake_dir():
data_dir = tempfile.mkdtemp()
for i in range(2):
sub_dir = os.path.join(data_dir, 'class_' + str(i))
if not os.path.exists(sub_dir):
os.makedirs(sub_dir)
for j in range(2):
fake_img = (np.random.random((32, 32, 3)) * 255).astype('uint8')
cv2.imwrite(os.path.join(sub_dir, str(j) + '.jpg'), fake_img)
return data_dir
temp_dir = make_fake_dir()
# temp_dir is root dir
# temp_dir/class_1/img1_1.jpg
# temp_dir/class_2/img2_1.jpg
data_folder = DatasetFolder(temp_dir)
for items in data_folder:
break
shutil.rmtree(temp_dir)