数据加载和处理教程

译者:yportne13

作者Sasank Chilamkurthy

在解决机器学习问题的时候,人们花了大量精力准备数据。pytorch提供了许多工具来让载入数据更简单并尽量让你的代码的可读性更高。在这篇教程中,我们将从一个容易处理的数据集中学习如何加载和预处理/增强数据。

在运行这个教程前请先确保你已安装以下的包:

  • scikit-image: 图形接口以及变换
  • pandas: 便于处理csv文件
  1. from __future__ import print_function, division
  2. import os
  3. import torch
  4. import pandas as pd
  5. from skimage import io, transform
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. from torch.utils.data import Dataset, DataLoader
  9. from torchvision import transforms, utils
  10. # Ignore warnings
  11. import warnings
  12. warnings.filterwarnings("ignore")
  13. plt.ion() # interactive mode

我们要处理的是一个面部姿态的数据集。也就是按如下方式标注的人脸:

https://pytorch.org/tutorials/_images/landmarked_face2.png

每张脸标注了68个不同的特征点。

注意

这里下载数据集并把它放置在 ‘data/faces/’路径下。这个数据集实际上是对ImageNet中的人脸图像使用表现出色的DLIB姿势估计模型(dlib’s pose estimation) 生成的。

数据集是按如下规则打包成的csv文件:

  1. image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
  2. 0805personali01.jpg,27,83,27,98, ... 84,134
  3. 1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312

快速读取csv并将标注点数据写入(N,2)数组中,其中N是特征点的数量。

  1. landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')
  2. n = 65
  3. img_name = landmarks_frame.iloc[n, 0]
  4. landmarks = landmarks_frame.iloc[n, 1:].as_matrix()
  5. landmarks = landmarks.astype('float').reshape(-1, 2)
  6. print('Image name: {}'.format(img_name))
  7. print('Landmarks shape: {}'.format(landmarks.shape))
  8. print('First 4 Landmarks: {}'.format(landmarks[:4]))

输出:

  1. Image name: person-7.jpg
  2. Landmarks shape: (68, 2)
  3. First 4 Landmarks: [[32\. 65.]
  4. [33\. 76.]
  5. [34\. 86.]
  6. [34\. 97.]]

写一个简单的辅助函数来展示一张图片和它对应的标注点作为例子。

  1. def show_landmarks(image, landmarks):
  2. """Show image with landmarks"""
  3. plt.imshow(image)
  4. plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
  5. plt.pause(0.001) # pause a bit so that plots are updated
  6. plt.figure()
  7. show_landmarks(io.imread(os.path.join('data/faces/', img_name)),
  8. landmarks)
  9. plt.show()

https://pytorch.org/tutorials/_images/sphx_glr_data_loading_tutorial_001.png

数据集类 Dataset class

torch.utils.data.Dataset 是一个代表数据集的抽象类。你自定的数据集类应该继承自 Dataset 类并重新实现以下方法:

  • __len__ 实现 len(dataset) 返还数据集的尺寸。
  • __getitem__ 用来获取一些索引数据,例如 使用dataset[i] 获得第i个样本。

让我们来为我们的数据集创建一个类。我们将在 __init__ 中读取csv的文件内容,在 __getitem__中读取图片。这么做是为了节省内存空间。只有在需要用到图片的时候才读取它而不是一开始就把图片全部存进内存里。

我们的数据样本将按这样一个字典 {'image': image, 'landmarks': landmarks}组织。 我们的数据集类将添加一个可选参数 transform 以方便对样本进行预处理。下一节我们会看到什么时候需要用到 transform 参数。

  1. class FaceLandmarksDataset(Dataset):
  2. """Face Landmarks dataset."""
  3. def __init__(self, csv_file, root_dir, transform=None):
  4. """
  5. Args:
  6. csv_file (string): Path to the csv file with annotations.
  7. root_dir (string): Directory with all the images.
  8. transform (callable, optional): Optional transform to be applied
  9. on a sample.
  10. """
  11. self.landmarks_frame = pd.read_csv(csv_file)
  12. self.root_dir = root_dir
  13. self.transform = transform
  14. def __len__(self):
  15. return len(self.landmarks_frame)
  16. def __getitem__(self, idx):
  17. img_name = os.path.join(self.root_dir,
  18. self.landmarks_frame.iloc[idx, 0])
  19. image = io.imread(img_name)
  20. landmarks = self.landmarks_frame.iloc[idx, 1:].as_matrix()
  21. landmarks = landmarks.astype('float').reshape(-1, 2)
  22. sample = {'image': image, 'landmarks': landmarks}
  23. if self.transform:
  24. sample = self.transform(sample)
  25. return sample

让我们实例化这个类并创建几个数据。我们将会打印出前四个例子的尺寸并展示标注的特征点。

  1. face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
  2. root_dir='data/faces/')
  3. fig = plt.figure()
  4. for i in range(len(face_dataset)):
  5. sample = face_dataset[i]
  6. print(i, sample['image'].shape, sample['landmarks'].shape)
  7. ax = plt.subplot(1, 4, i + 1)
  8. plt.tight_layout()
  9. ax.set_title('Sample #{}'.format(i))
  10. ax.axis('off')
  11. show_landmarks(**sample)
  12. if i == 3:
  13. plt.show()
  14. break

https://pytorch.org/tutorials/_images/sphx_glr_data_loading_tutorial_002.png

输出:

  1. 0 (324, 215, 3) (68, 2)
  2. 1 (500, 333, 3) (68, 2)
  3. 2 (250, 258, 3) (68, 2)
  4. 3 (434, 290, 3) (68, 2)

转换 Transforms

通过上面的例子我们会发现图片并不是同样的尺寸。绝大多数神经网络都假定图片的尺寸相同。因此我们需要做一些预处理。让我们创建三个转换:

  • Rescale: 缩放图片
  • RandomCrop: 对图片进行随机裁剪。这是一种数据增强操作
  • ToTensor: 把 numpy 格式图片转为 torch 格式图片 (我们需要交换坐标轴).

我们会把它们写成可调用的类的形式而不是简单的函数,这样就不需要每次调用时传递一遍参数。我们只需要实现 __call__ 方法,必要的时候实现 __init__ 方法。我们可以这样调用这些转换:

  1. tsfm = Transform(params)
  2. transformed_sample = tsfm(sample)

观察下面这些转换是如何应用在图像和标签上的。

  1. class Rescale(object):
  2. """Rescale the image in a sample to a given size.
  3. Args:
  4. output_size (tuple or int): Desired output size. If tuple, output is
  5. matched to output_size. If int, smaller of image edges is matched
  6. to output_size keeping aspect ratio the same.
  7. """
  8. def __init__(self, output_size):
  9. assert isinstance(output_size, (int, tuple))
  10. self.output_size = output_size
  11. def __call__(self, sample):
  12. image, landmarks = sample['image'], sample['landmarks']
  13. h, w = image.shape[:2]
  14. if isinstance(self.output_size, int):
  15. if h > w:
  16. new_h, new_w = self.output_size * h / w, self.output_size
  17. else:
  18. new_h, new_w = self.output_size, self.output_size * w / h
  19. else:
  20. new_h, new_w = self.output_size
  21. new_h, new_w = int(new_h), int(new_w)
  22. img = transform.resize(image, (new_h, new_w))
  23. # h and w are swapped for landmarks because for images,
  24. # x and y axes are axis 1 and 0 respectively
  25. landmarks = landmarks * [new_w / w, new_h / h]
  26. return {'image': img, 'landmarks': landmarks}
  27. class RandomCrop(object):
  28. """Crop randomly the image in a sample.
  29. Args:
  30. output_size (tuple or int): Desired output size. If int, square crop
  31. is made.
  32. """
  33. def __init__(self, output_size):
  34. assert isinstance(output_size, (int, tuple))
  35. if isinstance(output_size, int):
  36. self.output_size = (output_size, output_size)
  37. else:
  38. assert len(output_size) == 2
  39. self.output_size = output_size
  40. def __call__(self, sample):
  41. image, landmarks = sample['image'], sample['landmarks']
  42. h, w = image.shape[:2]
  43. new_h, new_w = self.output_size
  44. top = np.random.randint(0, h - new_h)
  45. left = np.random.randint(0, w - new_w)
  46. image = image[top: top + new_h,
  47. left: left + new_w]
  48. landmarks = landmarks - [left, top]
  49. return {'image': image, 'landmarks': landmarks}
  50. class ToTensor(object):
  51. """Convert ndarrays in sample to Tensors."""
  52. def __call__(self, sample):
  53. image, landmarks = sample['image'], sample['landmarks']
  54. # swap color axis because
  55. # numpy image: H x W x C
  56. # torch image: C X H X W
  57. image = image.transpose((2, 0, 1))
  58. return {'image': torch.from_numpy(image),
  59. 'landmarks': torch.from_numpy(landmarks)}

组合转换 Compose transforms

接下来我们把这些转换应用到一个例子上。

我们想要把图像的短边调整为256,然后随机裁剪 (randomcrop) 为224大小的正方形。也就是说,我们打算组合一个 RescaleRandomCrop 的变换。 我们可以调用一个简单的类 torchvision.transforms.Compose 来实现这一操作。

  1. scale = Rescale(256)
  2. crop = RandomCrop(128)
  3. composed = transforms.Compose([Rescale(256),
  4. RandomCrop(224)])
  5. # Apply each of the above transforms on sample.
  6. fig = plt.figure()
  7. sample = face_dataset[65]
  8. for i, tsfrm in enumerate([scale, crop, composed]):
  9. transformed_sample = tsfrm(sample)
  10. ax = plt.subplot(1, 3, i + 1)
  11. plt.tight_layout()
  12. ax.set_title(type(tsfrm).__name__)
  13. show_landmarks(**transformed_sample)
  14. plt.show()

https://pytorch.org/tutorials/_images/sphx_glr_data_loading_tutorial_003.png

迭代数据集 Iterating through the dataset

让我们把这些整合起来以创建一个带组合转换的数据集。 总结一下,每次这个数据集被采样时:

  • 及时地从文件中读取图片
  • 对读取的图片应用转换
  • 由于其中一步操作是随机的 (randomcrop) , 数据被增强了

我们可以像之前那样使用 for i in range 循环来对所有创建的数据集执行同样的操作。

  1. transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
  2. root_dir='data/faces/',
  3. transform=transforms.Compose([
  4. Rescale(256),
  5. RandomCrop(224),
  6. ToTensor()
  7. ]))
  8. for i in range(len(transformed_dataset)):
  9. sample = transformed_dataset[i]
  10. print(i, sample['image'].size(), sample['landmarks'].size())
  11. if i == 3:
  12. break

输出:

  1. 0 torch.Size([3, 224, 224]) torch.Size([68, 2])
  2. 1 torch.Size([3, 224, 224]) torch.Size([68, 2])
  3. 2 torch.Size([3, 224, 224]) torch.Size([68, 2])
  4. 3 torch.Size([3, 224, 224]) torch.Size([68, 2])

但是,对所有数据集简单的使用 for 循环牺牲了许多功能,尤其是:

  • 批处理数据(Batching the data)
  • 打乱数据(Shuffling the data)
  • 使用多线程 multiprocessing 并行加载数据。

torch.utils.data.DataLoader 这个迭代器提供了以上所有功能。 下面使用的参数必须是清楚的。 一个值得关注的参数是 collate_fn. 你可以通过 collate_fn 来决定如何对数据进行批处理。 但是绝大多数情况下默认值就能运行良好。

  1. dataloader = DataLoader(transformed_dataset, batch_size=4,
  2. shuffle=True, num_workers=4)
  3. # Helper function to show a batch
  4. def show_landmarks_batch(sample_batched):
  5. """Show image with landmarks for a batch of samples."""
  6. images_batch, landmarks_batch = \
  7. sample_batched['image'], sample_batched['landmarks']
  8. batch_size = len(images_batch)
  9. im_size = images_batch.size(2)
  10. grid = utils.make_grid(images_batch)
  11. plt.imshow(grid.numpy().transpose((1, 2, 0)))
  12. for i in range(batch_size):
  13. plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size,
  14. landmarks_batch[i, :, 1].numpy(),
  15. s=10, marker='.', c='r')
  16. plt.title('Batch from dataloader')
  17. for i_batch, sample_batched in enumerate(dataloader):
  18. print(i_batch, sample_batched['image'].size(),
  19. sample_batched['landmarks'].size())
  20. # observe 4th batch and stop.
  21. if i_batch == 3:
  22. plt.figure()
  23. show_landmarks_batch(sample_batched)
  24. plt.axis('off')
  25. plt.ioff()
  26. plt.show()
  27. break

https://pytorch.org/tutorials/_images/sphx_glr_data_loading_tutorial_004.png

输出:

  1. 0 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
  2. 1 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
  3. 2 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])
  4. 3 torch.Size([4, 3, 224, 224]) torch.Size([4, 68, 2])

后记: torchvision

在这篇教程中我们学习了如何构造和使用数据集类 (datasets), 转换 (transforms) 和数据加载器 (dataloader)。 torchvision 包提供了常用的数据集类 (datasets) 和转换 (transforms)。 你可能不需要自己构造这些类。 torchvision 中还有一个更常用的数据集类 ImageFolder. 它假定了数据集是以如下方式构造的:

  1. root/ants/xxx.png
  2. root/ants/xxy.jpeg
  3. root/ants/xxz.png
  4. .
  5. .
  6. .
  7. root/bees/123.jpg
  8. root/bees/nsdf3.png
  9. root/bees/asd932_.png

其中 ‘ants’, ‘bees’ 等是分类标签。 在 PIL.Image 中你也可以使用类似的转换 (transforms) 例如 RandomHorizontalFlip, Scale。利用这些你可以按如下的方式创建一个数据加载器 (dataloader) :

  1. import torch
  2. from torchvision import transforms, datasets
  3. data_transform = transforms.Compose([
  4. transforms.RandomSizedCrop(224),
  5. transforms.RandomHorizontalFlip(),
  6. transforms.ToTensor(),
  7. transforms.Normalize(mean=[0.485, 0.456, 0.406],
  8. std=[0.229, 0.224, 0.225])
  9. ])
  10. hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
  11. transform=data_transform)
  12. dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
  13. batch_size=4, shuffle=True,
  14. num_workers=4)

带训练部分的例程可以参考这里 Transfer Learning Tutorial.