Trains a ResNet on the CIFAR10 dataset.

ResNet v1: Deep Residual Learning for Image Recognition

ResNet v2: Identity Mappings in Deep Residual Networks

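| Model | n | 200-epoch accuracy | Original paper accuracy | sec/epoch GTX1080Ti |
|---|---|---|---|---|
| ResNet20 v1 | 3 | 92.16 % | 91.25 % | 35 |
| ResNet32 v1 | 5 | 92.46 % | 92.49 % | 50 |
| ResNet44 v1 | 7 | 92.50 % | 92.83 % | 70 |
| ResNet56 v1 | 9 | 92.71 % | 93.03 % | 90 |
| ResNet110 v1 | 18 | 92.65 % | 93.39+-.16 % | 165 |
| ResNet164 v1 | 27 | - % | 94.07 % | - |
| ResNet1001 v1 | N/A | - % | 92.39 % | - |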
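| Model | n | 200-epoch accuracy | Original paper accuracy | sec/epoch GTX1080Ti |
|---|---|---|---|---|
| ResNet20 v2 | 2 | - % | - % | --- |
| ResNet32 v2 | N/A | NA % | NA % | NA |
| ResNet44 v2 | N/A | NA % | NA % | NA |
| ResNet56 v2 | 6 | 93.01 % | NA % | 100 |
| ResNet110 v2 | 12 | 93.15 % | 93.63 % | 180 |
| ResNet164 v2 | 18 | - % | 94.54 % | - |
| ResNet1001 v2 | 111 | - % | 95.08+-.14 % | - |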
  1. from __future__ import print_function
  2. import keras
  3. from keras.layers import Dense, Conv2D, BatchNormalization, Activation
  4. from keras.layers import AveragePooling2D, Input, Flatten
  5. from keras.optimizers import Adam
  6. from keras.callbacks import ModelCheckpoint, LearningRateScheduler
  7. from keras.callbacks import ReduceLROnPlateau
  8. from keras.preprocessing.image import ImageDataGenerator
  9. from keras.regularizers import l2
  10. from keras import backend as K
  11. from keras.models import Model
  12. from keras.datasets import cifar10
  13. import numpy as np
  14. import os
  15. # Training parameters
  16. batch_size = 32 # orig paper trained all networks with batch_size=128
  17. epochs = 200
  18. data_augmentation = True
  19. num_classes = 10
  20. # Subtracting pixel mean improves accuracy
  21. subtract_pixel_mean = True
  22. # Model parameter
  23. # ----------------------------------------------------------------------------
  24. # | | 200-epoch | Orig Paper| 200-epoch | Orig Paper| sec/epoch
  25. # Model | n | ResNet v1 | ResNet v1 | ResNet v2 | ResNet v2 | GTX1080Ti
  26. # |v1(v2)| %Accuracy | %Accuracy | %Accuracy | %Accuracy | v1 (v2)
  27. # ----------------------------------------------------------------------------
  28. # ResNet20 | 3 (2)| 92.16 | 91.25 | ----- | ----- | 35 (---)
  29. # ResNet32 | 5(NA)| 92.46 | 92.49 | NA | NA | 50 ( NA)
  30. # ResNet44 | 7(NA)| 92.50 | 92.83 | NA | NA | 70 ( NA)
  31. # ResNet56 | 9 (6)| 92.71 | 93.03 | 93.01 | NA | 90 (100)
  32. # ResNet110 |18(12)| 92.65 | 93.39+-.16| 93.15 | 93.63 | 165(180)
  33. # ResNet164 |27(18)| ----- | 94.07 | ----- | 94.54 | ---(---)
  34. # ResNet1001| (111)| ----- | 92.39 | ----- | 95.08+-.14| ---(---)
  35. # ---------------------------------------------------------------------------
  36. n = 3
  37. # Model version
  38. # Orig paper: version = 1 (ResNet v1), Improved ResNet: version = 2 (ResNet v2)
  39. version = 1
  40. # Computed depth from supplied model parameter n
  41. if version == 1:
  42. depth = n * 6 + 2
  43. elif version == 2:
  44. depth = n * 9 + 2
  45. # Model name, depth and version
  46. model_type = 'ResNet%dv%d' % (depth, version)
  47. # Load the CIFAR10 data.
  48. (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  49. # Input image dimensions.
  50. input_shape = x_train.shape[1:]
  51. # Normalize data.
  52. x_train = x_train.astype('float32') / 255
  53. x_test = x_test.astype('float32') / 255
  54. # If subtract pixel mean is enabled
  55. if subtract_pixel_mean:
  56. x_train_mean = np.mean(x_train, axis=0)
  57. x_train -= x_train_mean
  58. x_test -= x_train_mean
  59. print('x_train shape:', x_train.shape)
  60. print(x_train.shape[0], 'train samples')
  61. print(x_test.shape[0], 'test samples')
  62. print('y_train shape:', y_train.shape)
  63. # Convert class vectors to binary class matrices.
  64. y_train = keras.utils.to_categorical(y_train, num_classes)
  65. y_test = keras.utils.to_categorical(y_test, num_classes)
  66. def lr_schedule(epoch):
  67. """Learning Rate Schedule
  68. Learning rate is scheduled to be reduced after 80, 120, 160, 180 epochs.
  69. Called automatically every epoch as part of callbacks during training.
  70. # Arguments
  71. epoch (int): The number of epochs
  72. # Returns
  73. lr (float32): learning rate
  74. """
  75. lr = 1e-3
  76. if epoch > 180:
  77. lr *= 0.5e-3
  78. elif epoch > 160:
  79. lr *= 1e-3
  80. elif epoch > 120:
  81. lr *= 1e-2
  82. elif epoch > 80:
  83. lr *= 1e-1
  84. print('Learning rate: ', lr)
  85. return lr
  86. def resnet_layer(inputs,
  87. num_filters=16,
  88. kernel_size=3,
  89. strides=1,
  90. activation='relu',
  91. batch_normalization=True,
  92. conv_first=True):
  93. """2D Convolution-Batch Normalization-Activation stack builder
  94. # Arguments
  95. inputs (tensor): input tensor from input image or previous layer
  96. num_filters (int): Conv2D number of filters
  97. kernel_size (int): Conv2D square kernel dimensions
  98. strides (int): Conv2D square stride dimensions
  99. activation (string): activation name
  100. batch_normalization (bool): whether to include batch normalization
  101. conv_first (bool): conv-bn-activation (True) or
  102. bn-activation-conv (False)
  103. # Returns
  104. x (tensor): tensor as input to the next layer
  105. """
  106. conv = Conv2D(num_filters,
  107. kernel_size=kernel_size,
  108. strides=strides,
  109. padding='same',
  110. kernel_initializer='he_normal',
  111. kernel_regularizer=l2(1e-4))
  112. x = inputs
  113. if conv_first:
  114. x = conv(x)
  115. if batch_normalization:
  116. x = BatchNormalization()(x)
  117. if activation is not None:
  118. x = Activation(activation)(x)
  119. else:
  120. if batch_normalization:
  121. x = BatchNormalization()(x)
  122. if activation is not None:
  123. x = Activation(activation)(x)
  124. x = conv(x)
  125. return x
  126. def resnet_v1(input_shape, depth, num_classes=10):
  127. """ResNet Version 1 Model builder [a]
  128. Stacks of 2 x (3 x 3) Conv2D-BN-ReLU
  129. Last ReLU is after the shortcut connection.
  130. At the beginning of each stage, the feature map size is halved (downsampled)
  131. by a convolutional layer with strides=2, while the number of filters is
  132. doubled. Within each stage, the layers have the same number filters and the
  133. same number of filters.
  134. Features maps sizes:
  135. stage 0: 32x32, 16
  136. stage 1: 16x16, 32
  137. stage 2: 8x8, 64
  138. The Number of parameters is approx the same as Table 6 of [a]:
  139. ResNet20 0.27M
  140. ResNet32 0.46M
  141. ResNet44 0.66M
  142. ResNet56 0.85M
  143. ResNet110 1.7M
  144. # Arguments
  145. input_shape (tensor): shape of input image tensor
  146. depth (int): number of core convolutional layers
  147. num_classes (int): number of classes (CIFAR10 has 10)
  148. # Returns
  149. model (Model): Keras model instance
  150. """
  151. if (depth - 2) % 6 != 0:
  152. raise ValueError('depth should be 6n+2 (eg 20, 32, 44 in [a])')
  153. # Start model definition.
  154. num_filters = 16
  155. num_res_blocks = int((depth - 2) / 6)
  156. inputs = Input(shape=input_shape)
  157. x = resnet_layer(inputs=inputs)
  158. # Instantiate the stack of residual units
  159. for stack in range(3):
  160. for res_block in range(num_res_blocks):
  161. strides = 1
  162. if stack > 0 and res_block == 0: # first layer but not first stack
  163. strides = 2 # downsample
  164. y = resnet_layer(inputs=x,
  165. num_filters=num_filters,
  166. strides=strides)
  167. y = resnet_layer(inputs=y,
  168. num_filters=num_filters,
  169. activation=None)
  170. if stack > 0 and res_block == 0: # first layer but not first stack
  171. # linear projection residual shortcut connection to match
  172. # changed dims
  173. x = resnet_layer(inputs=x,
  174. num_filters=num_filters,
  175. kernel_size=1,
  176. strides=strides,
  177. activation=None,
  178. batch_normalization=False)
  179. x = keras.layers.add([x, y])
  180. x = Activation('relu')(x)
  181. num_filters *= 2
  182. # Add classifier on top.
  183. # v1 does not use BN after last shortcut connection-ReLU
  184. x = AveragePooling2D(pool_size=8)(x)
  185. y = Flatten()(x)
  186. outputs = Dense(num_classes,
  187. activation='softmax',
  188. kernel_initializer='he_normal')(y)
  189. # Instantiate model.
  190. model = Model(inputs=inputs, outputs=outputs)
  191. return model
  192. def resnet_v2(input_shape, depth, num_classes=10):
  193. """ResNet Version 2 Model builder [b]
  194. Stacks of (1 x 1)-(3 x 3)-(1 x 1) BN-ReLU-Conv2D or also known as
  195. bottleneck layer
  196. First shortcut connection per layer is 1 x 1 Conv2D.
  197. Second and onwards shortcut connection is identity.
  198. At the beginning of each stage, the feature map size is halved (downsampled)
  199. by a convolutional layer with strides=2, while the number of filter maps is
  200. doubled. Within each stage, the layers have the same number filters and the
  201. same filter map sizes.
  202. Features maps sizes:
  203. conv1 : 32x32, 16
  204. stage 0: 32x32, 64
  205. stage 1: 16x16, 128
  206. stage 2: 8x8, 256
  207. # Arguments
  208. input_shape (tensor): shape of input image tensor
  209. depth (int): number of core convolutional layers
  210. num_classes (int): number of classes (CIFAR10 has 10)
  211. # Returns
  212. model (Model): Keras model instance
  213. """
  214. if (depth - 2) % 9 != 0:
  215. raise ValueError('depth should be 9n+2 (eg 56 or 110 in [b])')
  216. # Start model definition.
  217. num_filters_in = 16
  218. num_res_blocks = int((depth - 2) / 9)
  219. inputs = Input(shape=input_shape)
  220. # v2 performs Conv2D with BN-ReLU on input before splitting into 2 paths
  221. x = resnet_layer(inputs=inputs,
  222. num_filters=num_filters_in,
  223. conv_first=True)
  224. # Instantiate the stack of residual units
  225. for stage in range(3):
  226. for res_block in range(num_res_blocks):
  227. activation = 'relu'
  228. batch_normalization = True
  229. strides = 1
  230. if stage == 0:
  231. num_filters_out = num_filters_in * 4
  232. if res_block == 0: # first layer and first stage
  233. activation = None
  234. batch_normalization = False
  235. else:
  236. num_filters_out = num_filters_in * 2
  237. if res_block == 0: # first layer but not first stage
  238. strides = 2 # downsample
  239. # bottleneck residual unit
  240. y = resnet_layer(inputs=x,
  241. num_filters=num_filters_in,
  242. kernel_size=1,
  243. strides=strides,
  244. activation=activation,
  245. batch_normalization=batch_normalization,
  246. conv_first=False)
  247. y = resnet_layer(inputs=y,
  248. num_filters=num_filters_in,
  249. conv_first=False)
  250. y = resnet_layer(inputs=y,
  251. num_filters=num_filters_out,
  252. kernel_size=1,
  253. conv_first=False)
  254. if res_block == 0:
  255. # linear projection residual shortcut connection to match
  256. # changed dims
  257. x = resnet_layer(inputs=x,
  258. num_filters=num_filters_out,
  259. kernel_size=1,
  260. strides=strides,
  261. activation=None,
  262. batch_normalization=False)
  263. x = keras.layers.add([x, y])
  264. num_filters_in = num_filters_out
  265. # Add classifier on top.
  266. # v2 has BN-ReLU before Pooling
  267. x = BatchNormalization()(x)
  268. x = Activation('relu')(x)
  269. x = AveragePooling2D(pool_size=8)(x)
  270. y = Flatten()(x)
  271. outputs = Dense(num_classes,
  272. activation='softmax',
  273. kernel_initializer='he_normal')(y)
  274. # Instantiate model.
  275. model = Model(inputs=inputs, outputs=outputs)
  276. return model
  277. if version == 2:
  278. model = resnet_v2(input_shape=input_shape, depth=depth)
  279. else:
  280. model = resnet_v1(input_shape=input_shape, depth=depth)
  281. model.compile(loss='categorical_crossentropy',
  282. optimizer=Adam(learning_rate=lr_schedule(0)),
  283. metrics=['accuracy'])
  284. model.summary()
  285. print(model_type)
  286. # Prepare model model saving directory.
  287. save_dir = os.path.join(os.getcwd(), 'saved_models')
  288. model_name = 'cifar10_%s_model.{epoch:03d}.h5' % model_type
  289. if not os.path.isdir(save_dir):
  290. os.makedirs(save_dir)
  291. filepath = os.path.join(save_dir, model_name)
  292. # Prepare callbacks for model saving and for learning rate adjustment.
  293. checkpoint = ModelCheckpoint(filepath=filepath,
  294. monitor='val_acc',
  295. verbose=1,
  296. save_best_only=True)
  297. lr_scheduler = LearningRateScheduler(lr_schedule)
  298. lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
  299. cooldown=0,
  300. patience=5,
  301. min_lr=0.5e-6)
  302. callbacks = [checkpoint, lr_reducer, lr_scheduler]
  303. # Run training, with or without data augmentation.
  304. if not data_augmentation:
  305. print('Not using data augmentation.')
  306. model.fit(x_train, y_train,
  307. batch_size=batch_size,
  308. epochs=epochs,
  309. validation_data=(x_test, y_test),
  310. shuffle=True,
  311. callbacks=callbacks)
  312. else:
  313. print('Using real-time data augmentation.')
  314. # This will do preprocessing and realtime data augmentation:
  315. datagen = ImageDataGenerator(
  316. # set input mean to 0 over the dataset
  317. featurewise_center=False,
  318. # set each sample mean to 0
  319. samplewise_center=False,
  320. # divide inputs by std of dataset
  321. featurewise_std_normalization=False,
  322. # divide each input by its std
  323. samplewise_std_normalization=False,
  324. # apply ZCA whitening
  325. zca_whitening=False,
  326. # epsilon for ZCA whitening
  327. zca_epsilon=1e-06,
  328. # randomly rotate images in the range (deg 0 to 180)
  329. rotation_range=0,
  330. # randomly shift images horizontally
  331. width_shift_range=0.1,
  332. # randomly shift images vertically
  333. height_shift_range=0.1,
  334. # set range for random shear
  335. shear_range=0.,
  336. # set range for random zoom
  337. zoom_range=0.,
  338. # set range for random channel shifts
  339. channel_shift_range=0.,
  340. # set mode for filling points outside the input boundaries
  341. fill_mode='nearest',
  342. # value used for fill_mode = "constant"
  343. cval=0.,
  344. # randomly flip images
  345. horizontal_flip=True,
  346. # randomly flip images
  347. vertical_flip=False,
  348. # set rescaling factor (applied before any other transformation)
  349. rescale=None,
  350. # set function that will be applied on each input
  351. preprocessing_function=None,
  352. # image data format, either "channels_first" or "channels_last"
  353. data_format=None,
  354. # fraction of images reserved for validation (strictly between 0 and 1)
  355. validation_split=0.0)
  356. # Compute quantities required for featurewise normalization
  357. # (std, mean, and principal components if ZCA whitening is applied).
  358. datagen.fit(x_train)
  359. # Fit the model on the batches generated by datagen.flow().
  360. model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
  361. validation_data=(x_test, y_test),
  362. epochs=epochs, verbose=1, workers=4,
  363. callbacks=callbacks)
  364. # Score trained model.
  365. scores = model.evaluate(x_test, y_test, verbose=1)
  366. print('Test loss:', scores[0])
  367. print('Test accuracy:', scores[1])