Train a simple deep CNN on the CIFAR10 small images dataset using augmentation.

Using TensorFlow internal augmentation APIs: Keras's ImageDataGenerator is replaced with an AugmentLayer embedded in the model via a Lambda layer, which is faster on GPU.

Benchmark of ImageDataGenerator (IG) vs. AugmentLayer (AL), both using 2D augmentation:

(backend = TensorFlow-GPU, NVIDIA Tesla P100-SXM2)

Settings: horizontal_flip = True

Epoch no.    IG %Accuracy    IG Performance    AL %Accuracy    AL Performance
1            44.84           15 ms/step        45.54           358 us/step
2            52.34            8 ms/step        50.55           285 us/step
8            65.45            8 ms/step        65.59           281 us/step
25           76.74            8 ms/step        76.17           280 us/step
100          78.81            8 ms/step        78.70           285 us/step

Settings: rotation = 30.0

Epoch no.    IG %Accuracy    IG Performance    AL %Accuracy    AL Performance
1            43.46           15 ms/step        42.21           334 us/step
2            48.95           11 ms/step        48.06           282 us/step
8            63.59           11 ms/step        61.35           290 us/step
25           72.25           12 ms/step        71.08           287 us/step
100          76.35           11 ms/step        74.62           286 us/step

(Corner handling and rotation precision differ slightly between ImageDataGenerator and AugmentLayer.)
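
For reference, the ImageDataGenerator (IG) side of the benchmark corresponds to the
conventional CPU-side pipeline, roughly along these lines (a minimal sketch assuming
keras.preprocessing.image.ImageDataGenerator and an already-compiled model; the exact
IG settings used for the benchmark are not shown here):

    from keras.preprocessing.image import ImageDataGenerator

    datagen = ImageDataGenerator(horizontal_flip=True)  # or rotation_range=30.0
    model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
                        steps_per_epoch=x_train.shape[0] // batch_size,
                        epochs=epochs,
                        validation_data=(x_test, y_test))

The full script below implements the AugmentLayer (AL) variant, where augmentation is
performed by TF ops inside the model graph.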

from __future__ import print_function
import keras
from keras.datasets import cifar10
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, Lambda, MaxPooling2D
from keras import backend as K
import os

if K.backend() != 'tensorflow':
    raise RuntimeError('This example can only run with the '
                       'TensorFlow backend, '
                       'because it requires TF-native augmentation APIs')

import tensorflow as tf

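# Note: augment_2d below uses tf.contrib.image and tf.random_uniform, so this
# example requires TensorFlow 1.x (tf.contrib was removed in TensorFlow 2).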
def augment_2d(inputs, rotation=0, horizontal_flip=False, vertical_flip=False):
    """Apply additive augmentation on 2D data.

    # Arguments
      rotation: A float, the degree range for rotation (0 <= rotation < 180),
          e.g. 3 for random image rotation between (-3.0, 3.0).
      horizontal_flip: A boolean, whether to allow random horizontal flip,
          e.g. true for 50% possibility to flip image horizontally.
      vertical_flip: A boolean, whether to allow random vertical flip,
          e.g. true for 50% possibility to flip image vertically.

    # Returns
      input data after augmentation, whose shape is the same as its original.
    """
    if inputs.dtype != tf.float32:
        inputs = tf.image.convert_image_dtype(inputs, dtype=tf.float32)

    with tf.name_scope('augmentation'):
        shp = tf.shape(inputs)
        batch_size, height, width = shp[0], shp[1], shp[2]
        width = tf.cast(width, tf.float32)
        height = tf.cast(height, tf.float32)

        transforms = []
        # 8-parameter projective transforms, as used by tf.contrib.image.transform;
        # [1, 0, 0, 0, 1, 0, 0, 0] is the identity mapping.
        identity = tf.constant([1, 0, 0, 0, 1, 0, 0, 0], dtype=tf.float32)

        if rotation > 0:
            # One random rotation angle per sample, in (-rotation, rotation) degrees.
            angle_rad = rotation * 3.141592653589793 / 180.0
            angles = tf.random_uniform([batch_size], -angle_rad, angle_rad)
            f = tf.contrib.image.angles_to_projective_transforms(angles,
                                                                 height, width)
            transforms.append(f)

        if horizontal_flip:
            # Per-sample coin flip: mirror around the vertical axis (x -> width - x).
            coin = tf.less(tf.random_uniform([batch_size], 0, 1.0), 0.5)
            shape = [-1., 0., width, 0., 1., 0., 0., 0.]
            flip_transform = tf.convert_to_tensor(shape, dtype=tf.float32)
            flip = tf.tile(tf.expand_dims(flip_transform, 0), [batch_size, 1])
            noflip = tf.tile(tf.expand_dims(identity, 0), [batch_size, 1])
            transforms.append(tf.where(coin, flip, noflip))

        if vertical_flip:
            # Per-sample coin flip: mirror around the horizontal axis (y -> height - y).
            coin = tf.less(tf.random_uniform([batch_size], 0, 1.0), 0.5)
            shape = [1., 0., 0., 0., -1., height, 0., 0.]
            flip_transform = tf.convert_to_tensor(shape, dtype=tf.float32)
            flip = tf.tile(tf.expand_dims(flip_transform, 0), [batch_size, 1])
            noflip = tf.tile(tf.expand_dims(identity, 0), [batch_size, 1])
            transforms.append(tf.where(coin, flip, noflip))

    if transforms:
        # Compose all selected transforms and apply them to the batch in one pass.
        f = tf.contrib.image.compose_transforms(*transforms)
        inputs = tf.contrib.image.transform(inputs, f, interpolation='BILINEAR')
    return inputs

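# Optional sanity check (illustrative only, not part of the original example):
# in a TF 1.x session, augment_2d returns a tensor with the same shape as its input.
#   images = tf.placeholder(tf.float32, [None, 32, 32, 3])
#   augmented = augment_2d(images, rotation=8.0, horizontal_flip=True)
#   with tf.Session() as sess:
#       out = sess.run(augmented, feed_dict={images: batch})  # batch: (N, 32, 32, 3) floats
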
batch_size = 32
num_classes = 10
epochs = 100
num_predictions = 20
save_dir = '/tmp/saved_models'
model_name = 'keras_cifar10_trained_model.h5'

# The data, split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Lambda(augment_2d,
                 input_shape=x_train.shape[1:],
                 arguments={'rotation': 8.0, 'horizontal_flip': True}))
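# The Lambda layer above is the "AugmentLayer": augmentation runs inside the
# TF graph (on the GPU with a GPU backend) instead of in a Python generator.
# Note: as written, augment_2d has no training/inference switch, so the random
# flip/rotation is also applied when the model is evaluated on the test set.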
model.add(Conv2D(32, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

# initiate RMSprop optimizer
opt = keras.optimizers.rmsprop(lr=0.0001, decay=1e-6)

# Let's train the model using RMSprop
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
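# Pixel values are scaled to [0, 1] once, up front, on the CPU; the per-batch
# random augmentation happens inside the model at each training step.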
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(x_test, y_test),
          shuffle=True)

# Save model and weights
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
model.save(model_path)
print('Saved trained model at %s ' % model_path)

# Score trained model.
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
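
# Optional (not in the original example): inspect a few class predictions, using
# the otherwise-unused num_predictions setting. Note that these predictions also
# pass through the augmentation Lambda layer.
# predictions = model.predict(x_test[:num_predictions])
# print('Predicted classes:', predictions.argmax(axis=-1))
# print('True classes:     ', y_test[:num_predictions].argmax(axis=-1))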