This example demonstrates how to write custom layers for Keras.

We build a custom activation layer called 'Antirectifier', which modifies the shape of the tensor that passes through it. We need to specify two methods: compute_output_shape and call.
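To make that structure concrete before the full example, here is a bare-bones sketch of a shape-changing layer. The SquareConcat class is purely illustrative and not part of the original script; it concatenates each input with its element-wise square, so the feature dimension doubles.

from keras import layers
from keras import backend as K

class SquareConcat(layers.Layer):
    # Illustrative layer: concatenate inputs with their element-wise square.

    def compute_output_shape(self, input_shape):
        shape = list(input_shape)
        shape[-1] *= 2  # the last (feature) dimension doubles
        return tuple(shape)

    def call(self, inputs):
        return K.concatenate([inputs, K.square(inputs)], axis=-1)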

Note that the same result can also be achieved via a Lambda layer.
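For example, a Lambda-based version might look like the sketch below (the function names antirectifier and antirectifier_output_shape are illustrative); it uses the same backend primitives as the layer defined further down.

from keras import backend as K
from keras.layers import Lambda

def antirectifier(x):
    # center, L2-normalize, then concatenate positive and negative parts
    x -= K.mean(x, axis=1, keepdims=True)
    x = K.l2_normalize(x, axis=1)
    pos = K.relu(x)
    neg = K.relu(-x)
    return K.concatenate([pos, neg], axis=1)

def antirectifier_output_shape(input_shape):
    shape = list(input_shape)
    assert len(shape) == 2  # only valid for 2D tensors
    shape[-1] *= 2
    return tuple(shape)

# used inside a model, e.g.:
# model.add(Lambda(antirectifier, output_shape=antirectifier_output_shape))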

Because our custom layer is written with primitives from the Keras backend (K), our code can run both on TensorFlow and Theano.

from __future__ import print_function
import keras
from keras.models import Sequential
from keras import layers
from keras.datasets import mnist
from keras import backend as K


class Antirectifier(layers.Layer):
    '''This is the combination of a sample-wise
    L2 normalization with the concatenation of the
    positive part of the input with the negative part
    of the input. The result is a tensor of samples that are
    twice as large as the input samples.

    It can be used in place of a ReLU.

    # Input shape
        2D tensor of shape (samples, n)

    # Output shape
        2D tensor of shape (samples, 2*n)

    # Theoretical justification
        When applying ReLU, assuming that the distribution
        of the previous output is approximately centered around 0.,
        you are discarding half of your input. This is inefficient.

        Antirectifier returns all-positive outputs like ReLU,
        without discarding any data.

        Tests on MNIST show that Antirectifier can train networks
        with half as many parameters yet comparable
        classification accuracy to an equivalent ReLU-based network.
    '''

    def compute_output_shape(self, input_shape):
        shape = list(input_shape)
        assert len(shape) == 2  # only valid for 2D tensors
        shape[-1] *= 2
        return tuple(shape)

    def call(self, inputs):
        inputs -= K.mean(inputs, axis=1, keepdims=True)
        inputs = K.l2_normalize(inputs, axis=1)
        pos = K.relu(inputs)
        neg = K.relu(-inputs)
        return K.concatenate([pos, neg], axis=1)

# global parameters
batch_size = 128
num_classes = 10
epochs = 40

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# build the model
model = Sequential()
model.add(layers.Dense(256, input_shape=(784,)))
model.add(Antirectifier())
model.add(layers.Dropout(0.1))
model.add(layers.Dense(256))
model.add(Antirectifier())
model.add(layers.Dropout(0.1))
model.add(layers.Dense(num_classes))
model.add(layers.Activation('softmax'))

# compile the model
model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

# train the model
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))

# next, compare with an equivalent network
# with 2x bigger Dense layers and ReLU
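As a sketch of the comparison that the closing comment refers to, the baseline below reuses the data and global parameters defined above. It assumes "2x bigger" means 512-unit Dense layers with ReLU, so each hidden representation matches the 512-wide Antirectifier outputs; the reference_model name is illustrative.

# hypothetical ReLU baseline with 2x wider Dense layers (assumption: 512 units)
reference_model = Sequential()
reference_model.add(layers.Dense(512, activation='relu', input_shape=(784,)))
reference_model.add(layers.Dropout(0.1))
reference_model.add(layers.Dense(512, activation='relu'))
reference_model.add(layers.Dropout(0.1))
reference_model.add(layers.Dense(num_classes, activation='softmax'))

reference_model.compile(loss='categorical_crossentropy',
                        optimizer='rmsprop',
                        metrics=['accuracy'])

reference_model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, y_test))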