在 CIFAR10 小型图像数据集上训练一个深度卷积神经网络。

在 25 轮迭代后 验证集准确率达到 75%,在 50 轮后达到 79%。(尽管目前仍然欠拟合)。

  1. from __future__ import print_function
  2. import keras
  3. from keras.datasets import cifar10
  4. from keras.preprocessing.image import ImageDataGenerator
  5. from keras.models import Sequential
  6. from keras.layers import Dense, Dropout, Activation, Flatten
  7. from keras.layers import Conv2D, MaxPooling2D
  8. import os
  9. batch_size = 32
  10. num_classes = 10
  11. epochs = 100
  12. data_augmentation = True
  13. num_predictions = 20
  14. save_dir = os.path.join(os.getcwd(), 'saved_models')
  15. model_name = 'keras_cifar10_trained_model.h5'
  16. # 数据,切分为训练和测试集。
  17. (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  18. print('x_train shape:', x_train.shape)
  19. print(x_train.shape[0], 'train samples')
  20. print(x_test.shape[0], 'test samples')
  21. # 将类向量转换为二进制类矩阵。
  22. y_train = keras.utils.to_categorical(y_train, num_classes)
  23. y_test = keras.utils.to_categorical(y_test, num_classes)
  24. model = Sequential()
  25. model.add(Conv2D(32, (3, 3), padding='same',
  26. input_shape=x_train.shape[1:]))
  27. model.add(Activation('relu'))
  28. model.add(Conv2D(32, (3, 3)))
  29. model.add(Activation('relu'))
  30. model.add(MaxPooling2D(pool_size=(2, 2)))
  31. model.add(Dropout(0.25))
  32. model.add(Conv2D(64, (3, 3), padding='same'))
  33. model.add(Activation('relu'))
  34. model.add(Conv2D(64, (3, 3)))
  35. model.add(Activation('relu'))
  36. model.add(MaxPooling2D(pool_size=(2, 2)))
  37. model.add(Dropout(0.25))
  38. model.add(Flatten())
  39. model.add(Dense(512))
  40. model.add(Activation('relu'))
  41. model.add(Dropout(0.5))
  42. model.add(Dense(num_classes))
  43. model.add(Activation('softmax'))
  44. # 初始化 RMSprop 优化器。
  45. opt = keras.optimizers.rmsprop(lr=0.0001, decay=1e-6)
  46. # 利用 RMSprop 来训练模型。
  47. model.compile(loss='categorical_crossentropy',
  48. optimizer=opt,
  49. metrics=['accuracy'])
  50. x_train = x_train.astype('float32')
  51. x_test = x_test.astype('float32')
  52. x_train /= 255
  53. x_test /= 255
  54. if not data_augmentation:
  55. print('Not using data augmentation.')
  56. model.fit(x_train, y_train,
  57. batch_size=batch_size,
  58. epochs=epochs,
  59. validation_data=(x_test, y_test),
  60. shuffle=True)
  61. else:
  62. print('Using real-time data augmentation.')
  63. # 这一步将进行数据处理和实时数据增益。data augmentation:
  64. datagen = ImageDataGenerator(
  65. featurewise_center=False, # 将整个数据集的均值设为0
  66. samplewise_center=False, # 将每个样本的均值设为0
  67. featurewise_std_normalization=False, # 将输入除以整个数据集的标准差
  68. samplewise_std_normalization=False, # 将输入除以其标准差
  69. zca_whitening=False, # 运用 ZCA 白化
  70. zca_epsilon=1e-06, # ZCA 白化的 epsilon值
  71. rotation_range=0, # 随机旋转图像范围 (角度, 0 to 180)
  72. # 随机水平移动图像 (总宽度的百分比)
  73. width_shift_range=0.1,
  74. # 随机垂直移动图像 (总高度的百分比)
  75. height_shift_range=0.1,
  76. shear_range=0., # 设置随机裁剪范围
  77. zoom_range=0., # 设置随机放大范围
  78. channel_shift_range=0., # 设置随机通道切换的范围
  79. # 设置填充输入边界之外的点的模式
  80. fill_mode='nearest',
  81. cval=0., # 在 fill_mode = "constant" 时使用的值
  82. horizontal_flip=True, # 随机水平翻转图像
  83. vertical_flip=False, # 随机垂直翻转图像
  84. # 设置缩放因子 (在其他转换之前使用)
  85. rescale=None,
  86. # 设置将应用于每一个输入的函数
  87. preprocessing_function=None,
  88. # 图像数据格式,"channels_first" 或 "channels_last" 之一
  89. data_format=None,
  90. # 保留用于验证的图像比例(严格在0和1之间)
  91. validation_split=0.0)
  92. # 计算特征标准化所需的计算量
  93. # (如果应用 ZCA 白化,则为 std,mean和主成分).
  94. datagen.fit(x_train)
  95. # 利用由 datagen.flow() 生成的批来训练模型
  96. model.fit_generator(datagen.flow(x_train, y_train,
  97. batch_size=batch_size),
  98. epochs=epochs,
  99. validation_data=(x_test, y_test),
  100. workers=4)
  101. # 保存模型和权重
  102. if not os.path.isdir(save_dir):
  103. os.makedirs(save_dir)
  104. model_path = os.path.join(save_dir, model_name)
  105. model.save(model_path)
  106. print('Saved trained model at %s ' % model_path)
  107. # 评估训练模型
  108. scores = model.evaluate(x_test, y_test, verbose=1)
  109. print('Test loss:', scores[0])
  110. print('Test accuracy:', scores[1])