tf.data :数据集的构建与预处理

很多时候,我们希望使用自己的数据集来训练模型。然而,面对一堆格式不一的原始数据文件,将其预处理并读入程序的过程往往十分繁琐,甚至比模型的设计还要耗费精力。比如,为了读入一批图像文件,我们可能需要纠结于python的各种图像处理包(比如 pillow ),自己设计Batch的生成方式,最后还可能在运行的效率上不尽如人意。为此,TensorFlow提供了 tf.data 这一模块,包括了一套灵活的数据集构建API,能够帮助我们快速、高效地构建数据输入的流水线,尤其适用于数据量巨大的场景。

数据集对象的建立

tf.data 的核心是 tf.data.Dataset 类,提供了对数据集的高层封装。tf.data.Dataset 由一系列的可迭代访问的元素(element)组成,每个元素包含一个或多个张量。比如说,对于一个由图像组成的数据集,每个元素可以是一个形状为 长×宽×通道数 的图片张量,也可以是由图片张量和图片标签张量组成的元组(Tuple)。

最基础的建立 tf.data.Dataset 的方法是使用 tf.data.Dataset.from_tensor_slices() ,适用于数据量较小(能够整个装进内存)的情况。具体而言,如果我们的数据集中的所有元素通过张量的第0维,拼接成一个大的张量(例如,前节的MNIST数据集的训练集即为一个 [60000, 28, 28, 1] 的张量,表示了60000张28*28的单通道灰度图像),那么我们提供一个这样的张量或者第0维大小相同的多个张量作为输入,即可按张量的第0维展开来构建数据集,数据集的元素数量为张量第0位的大小。具体示例如下:

  1. import tensorflow as tf
  2. import numpy as np
  3.  
  4. X = tf.constant([2013, 2014, 2015, 2016, 2017])
  5. Y = tf.constant([12000, 14000, 15000, 16500, 17500])
  6.  
  7. # 也可以使用NumPy数组,效果相同
  8. # X = np.array([2013, 2014, 2015, 2016, 2017])
  9. # Y = np.array([12000, 14000, 15000, 16500, 17500])
  10.  
  11. dataset = tf.data.Dataset.from_tensor_slices((X, Y))
  12.  
  13. for x, y in dataset:
  14. print(x.numpy(), y.numpy())

输出:

  1. 2013 12000
  2. 2014 14000
  3. 2015 15000
  4. 2016 16500
  5. 2017 17500

警告

当提供多个张量作为输入时,张量的第0维大小必须相同,且必须将多个张量作为元组(Tuple,即使用Python中的小括号)拼接并作为输入。

类似地,我们可以载入前章的MNIST数据集:

  1. import matplotlib.pyplot as plt
  2.  
  3. (train_data, train_label), (_, _) = tf.keras.datasets.mnist.load_data()
  4. train_data = np.expand_dims(train_data.astype(np.float32) / 255.0, axis=-1) # [60000, 28, 28, 1]
  5. mnist_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))
  6.  
  7. for image, label in mnist_dataset:
  8. plt.title(label.numpy())
  9. plt.imshow(image.numpy()[:, :, 0])
  10. plt.show()

输出

../../_images/mnist_1.png

提示

TensorFlow Datasets提供了一个基于 tf.data.Datasets 的开箱即用的数据集集合,相关内容可参考 TensorFlow Datasets 。例如,使用以下语句:

  1. import tensorflow_datasets as tfds
  2. dataset = tfds.load("mnist", split=tfds.Split.TRAIN)

即可快速载入MNIST数据集。

数据集对象的预处理

tf.data.Dataset 类为我们提供了多种数据集预处理方法。最常用的如:

  • Dataset.map(f) :对数据集中的每个元素应用函数 f ,得到一个新的数据集(这部分往往结合 tf.io 进行读写和解码文件, tf.image 进行图像处理);

  • Dataset.shuffle(buffer_size) :将数据集打乱(设定一个固定大小的缓冲区(Buffer),取出前 buffer_size 个元素放入,并从缓冲区中随机采样,采样后的数据用后续数据替换);

  • Dataset.batch(batch_size) :将数据集分成批次,即对每 batch_size 个元素,使用 tf.stack() 在第0维合并,成为一个元素。

  • Dataset.prefetch() :预取出数据集中的若干个元素

除此以外,还有 Dataset.repeat() (重复数据集的元素)、 Dataset.reduce() (与Map相对的聚合操作)、 ``Dataset.take()``()等,可参考 API文档 进一步了解。

以下以MNIST数据集进行示例。

使用 Dataset.map() 将所有图片旋转90度:

  1. def rot90(image, label):
  2. image = tf.image.rot90(image)
  3. return image, label
  4.  
  5. mnist_dataset = mnist_dataset.map(rot90)
  6.  
  7. for image, label in mnist_dataset:
  8. plt.title(label.numpy())
  9. plt.imshow(image.numpy()[:, :, 0])
  10. plt.show()
  11.  

输出

../../_images/mnist_1_rot90.png

使用 Dataset.batch() 将数据集划分批次,每个批次的大小为4:

  1. mnist_dataset = mnist_dataset.batch(4)
  2.  
  3. for images, labels in mnist_dataset: # image: [4, 28, 28, 1], labels: [4]
  4. fig, axs = plt.subplots(1, 4)
  5. for i in range(4):
  6. axs[i].set_title(labels.numpy()[i])
  7. axs[i].imshow(images.numpy()[i, :, :, 0])
  8. plt.show()

输出

../../_images/mnist_batch.png

使用 Dataset.shuffle() 将数据打散后再设置批次,缓存大小设置为10000:

  1. mnist_dataset = mnist_dataset.shuffle(buffer_size=10000).batch(4)
  2.  
  3. for images, labels in mnist_dataset:
  4. fig, axs = plt.subplots(1, 4)
  5. for i in range(4):
  6. axs[i].set_title(labels.numpy()[i])
  7. axs[i].imshow(images.numpy()[i, :, :, 0])
  8. plt.show()

输出

../../_images/mnist_shuffle_1.png第一次运行

../../_images/mnist_shuffle_2.png第二次运行

可见每次的数据都会被随机打散。

Dataset.shuffle() 时缓冲区大小 buffer_size 的设置

tf.data.Dataset 作为一个针对大规模数据设计的迭代器,本身无法方便地获得自身元素的数量或随机访问元素。因此,为了高效且较为充分地打散数据集,需要一些特定的方法。Dataset.shuffle() 采取了以下方法:

  • 设定一个固定大小为 buffer_size 的缓冲区(Buffer);

  • 初始化时,取出数据集中的前 buffer_size 个元素放入缓冲区;

  • 每次需要从数据集中取元素时,即从缓冲区中随机采样一个元素并取出,然后从后续的元素中取出一个放回到之前被取出的位置,以维持缓冲区的大小。

因此,缓冲区的大小需要根据数据集的特性和数据排列顺序特点来进行合理的设置。比如:

  • buffer_size 设置为1时,其实等价于没有进行任何打散;

  • 当数据集的标签顺序分布极为不均匀(例如二元分类时数据集前N个的标签为0,后N个的标签为1)时,较小的缓冲区大小会使得训练时取出的Batch数据很可能全为同一标签,从而影响训练效果。一般而言,数据集的顺序分布若较为随机,则缓冲区的大小可较小,否则则需要设置较大的缓冲区。

数据集元素的获取与使用

构建好数据并预处理后,我们需要从其中迭代获取数据以用于训练。tf.data.Dataset 是一个Python的可迭代对象,因此可以使用For循环迭代获取数据,即:

  1. dataset = tf.data.Dataset.from_tensor_slices((A, B, C, ...))
  2. for a, b, c, ... in dataset:
  3. # 对张量a, b, c等进行操作,例如送入模型进行训练

也可以使用 iter() 显式创建一个Python迭代器并使用 next() 获取下一个元素,即:

  1. dataset = tf.data.Dataset.from_tensor_slices((A, B, C, ...))
  2. it = iter(dataset)
  3. a_0, b_0, c_0, ... = next(it)
  4. a_1, b_1, c_1, ... = next(it)

Keras支持使用 tf.data.Dataset 直接作为输入。当调用 tf.keras.Modelfit()evaluate() 方法时,可以将参数中的输入数据 x 指定为一个元素格式为 (输入数据, 标签数据)Dataset ,并忽略掉参数中的标签数据 y 。例如,对于上述的MNIST数据集,常规的Keras训练方式是:

  1. model.fit(x=train_data, y=train_label, epochs=num_epochs, batch_size=batch_size)

使用 tf.data.Dataset 后,我们可以直接传入 Dataset

  1. model.fit(mnist_dataset, epochs=num_epochs)

由于已经通过 Dataset.batch() 方法划分了数据集的批次,所以这里也无需提供批次的大小。

实例:cats_vs_dogs图像分类

以下代码以猫狗图片二分类任务为示例,展示了使用 tf.data 结合 tf.iotf.image 建立 tf.data.Dataset 数据集,并进行训练和测试的完整过程。数据集可至 这里 下载。

  1. import tensorflow as tf
  2. import os
  3.  
  4. num_epochs = 10
  5. batch_size = 32
  6. learning_rate = 0.001
  7. data_dir = 'C:/datasets/cats_vs_dogs'
  8. train_cats_dir = data_dir + '/train/cats/'
  9. train_dogs_dir = data_dir + '/train/dogs/'
  10. test_cats_dir = data_dir + '/valid/cats/'
  11. test_dogs_dir = data_dir + '/valid/dogs/'
  12.  
  13. def _decode_and_resize(filename, label):
  14. image_string = tf.io.read_file(filename)
  15. image_decoded = tf.image.decode_jpeg(image_string)
  16. image_resized = tf.image.resize(image_decoded, [256, 256]) / 255.0
  17. return image_resized, label
  18.  
  19. if __name__ == '__main__':
  20. # 构建训练数据集
  21. train_cat_filenames = tf.constant([train_cats_dir + filename for filename in os.listdir(train_cats_dir)])
  22. train_dog_filenames = tf.constant([train_dogs_dir + filename for filename in os.listdir(train_dogs_dir)])
  23. train_filenames = tf.concat([train_cat_filenames, train_dog_filenames], axis=-1)
  24. train_labels = tf.concat([
  25. tf.zeros(train_cat_filenames.shape, dtype=tf.int32),
  26. tf.ones(train_dog_filenames.shape, dtype=tf.int32)],
  27. axis=-1)
  28.  
  29. train_dataset = tf.data.Dataset.from_tensor_slices((train_filenames, train_labels))
  30. train_dataset = train_dataset.map(_decode_and_resize)
  31. # 取出前buffer_size个数据放入buffer,并从其中随机采样,采样后的数据用后续数据替换
  32. train_dataset = train_dataset.shuffle(buffer_size=23000)
  33. train_dataset = train_dataset.batch(batch_size)
  34.  
  35. model = tf.keras.Sequential([
  36. tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(256, 256, 3)),
  37. tf.keras.layers.MaxPooling2D(),
  38. tf.keras.layers.Conv2D(32, 5, activation='relu'),
  39. tf.keras.layers.MaxPooling2D(),
  40. tf.keras.layers.Flatten(),
  41. tf.keras.layers.Dense(64, activation='relu'),
  42. tf.keras.layers.Dense(2, activation='softmax')
  43. ])
  44.  
  45. model.compile(
  46. optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
  47. loss=tf.keras.losses.sparse_categorical_crossentropy,
  48. metrics=[tf.keras.metrics.sparse_categorical_accuracy]
  49. )
  50.  
  51. model.fit(train_dataset, epochs=num_epochs)

使用以下代码进行测试:

  1. # 构建测试数据集
  2. test_cat_filenames = tf.constant([test_cats_dir + filename for filename in os.listdir(test_cats_dir)])
  3. test_dog_filenames = tf.constant([test_dogs_dir + filename for filename in os.listdir(test_dogs_dir)])
  4. test_filenames = tf.concat([test_cat_filenames, test_dog_filenames], axis=-1)
  5. test_labels = tf.concat([
  6. tf.zeros(test_cat_filenames.shape, dtype=tf.int32),
  7. tf.ones(test_dog_filenames.shape, dtype=tf.int32)],
  8. axis=-1)
  9.  
  10. test_dataset = tf.data.Dataset.from_tensor_slices((test_filenames, test_labels))
  11. test_dataset = test_dataset.map(_decode_and_resize)
  12. test_dataset = test_dataset.batch(batch_size)
  13.  
  14. print(model.metrics_names)
  15. print(model.evaluate(test_dataset))