This script demonstrates the use of a convolutional LSTM network.

This network is used to predict the next frame of an artificially generated movie which contains moving squares.

  1. from keras.models import Sequential
  2. from keras.layers.convolutional import Conv3D
  3. from keras.layers.convolutional_recurrent import ConvLSTM2D
  4. from keras.layers.normalization import BatchNormalization
  5. import numpy as np
  6. import pylab as plt
  7. # We create a layer which take as input movies of shape
  8. # (n_frames, width, height, channels) and returns a movie
  9. # of identical shape.
  10. seq = Sequential()
  11. seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
  12. input_shape=(None, 40, 40, 1),
  13. padding='same', return_sequences=True))
  14. seq.add(BatchNormalization())
  15. seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
  16. padding='same', return_sequences=True))
  17. seq.add(BatchNormalization())
  18. seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
  19. padding='same', return_sequences=True))
  20. seq.add(BatchNormalization())
  21. seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
  22. padding='same', return_sequences=True))
  23. seq.add(BatchNormalization())
  24. seq.add(Conv3D(filters=1, kernel_size=(3, 3, 3),
  25. activation='sigmoid',
  26. padding='same', data_format='channels_last'))
  27. seq.compile(loss='binary_crossentropy', optimizer='adadelta')
  28. # Artificial data generation:
  29. # Generate movies with 3 to 7 moving squares inside.
# The squares have sides of 4 or 6 pixels
# and move linearly over time.
  32. # For convenience we first create movies with bigger width and height (80x80)
  33. # and at the end we select a 40x40 window.
  34. def generate_movies(n_samples=1200, n_frames=15):
  35. row = 80
  36. col = 80
  37. noisy_movies = np.zeros((n_samples, n_frames, row, col, 1), dtype=np.float)
  38. shifted_movies = np.zeros((n_samples, n_frames, row, col, 1),
  39. dtype=np.float)
  40. for i in range(n_samples):
  41. # Add 3 to 7 moving squares
  42. n = np.random.randint(3, 8)
  43. for j in range(n):
  44. # Initial position
  45. xstart = np.random.randint(20, 60)
  46. ystart = np.random.randint(20, 60)
  47. # Direction of motion
  48. directionx = np.random.randint(0, 3) - 1
  49. directiony = np.random.randint(0, 3) - 1
  50. # Size of the square
  51. w = np.random.randint(2, 4)
  52. for t in range(n_frames):
  53. x_shift = xstart + directionx * t
  54. y_shift = ystart + directiony * t
  55. noisy_movies[i, t, x_shift - w: x_shift + w,
  56. y_shift - w: y_shift + w, 0] += 1
  57. # Make it more robust by adding noise.
  58. # The idea is that if during inference,
  59. # the value of the pixel is not exactly one,
  60. # we need to train the network to be robust and still
  61. # consider it as a pixel belonging to a square.
  62. if np.random.randint(0, 2):
  63. noise_f = (-1)**np.random.randint(0, 2)
  64. noisy_movies[i, t,
  65. x_shift - w - 1: x_shift + w + 1,
  66. y_shift - w - 1: y_shift + w + 1,
  67. 0] += noise_f * 0.1
  68. # Shift the ground truth by 1
  69. x_shift = xstart + directionx * (t + 1)
  70. y_shift = ystart + directiony * (t + 1)
  71. shifted_movies[i, t, x_shift - w: x_shift + w,
  72. y_shift - w: y_shift + w, 0] += 1
  73. # Cut to a 40x40 window
  74. noisy_movies = noisy_movies[::, ::, 20:60, 20:60, ::]
  75. shifted_movies = shifted_movies[::, ::, 20:60, 20:60, ::]
  76. noisy_movies[noisy_movies >= 1] = 1
  77. shifted_movies[shifted_movies >= 1] = 1
  78. return noisy_movies, shifted_movies
  79. # Train the network
  80. noisy_movies, shifted_movies = generate_movies(n_samples=1200)
  81. seq.fit(noisy_movies[:1000], shifted_movies[:1000], batch_size=10,
  82. epochs=300, validation_split=0.05)
  83. # Testing the network on one movie
  84. # feed it with the first 7 positions and then
  85. # predict the new positions
  86. which = 1004
  87. track = noisy_movies[which][:7, ::, ::, ::]
  88. for j in range(16):
  89. new_pos = seq.predict(track[np.newaxis, ::, ::, ::, ::])
  90. new = new_pos[::, -1, ::, ::, ::]
  91. track = np.concatenate((track, new), axis=0)
  92. # And then compare the predictions
  93. # to the ground truth
  94. track2 = noisy_movies[which][::, ::, ::, ::]
  95. for i in range(15):
  96. fig = plt.figure(figsize=(10, 5))
  97. ax = fig.add_subplot(121)
  98. if i >= 7:
  99. ax.text(1, 3, 'Predictions !', fontsize=20, color='w')
  100. else:
  101. ax.text(1, 3, 'Initial trajectory', fontsize=20)
  102. toplot = track[i, ::, ::, 0]
  103. plt.imshow(toplot)
  104. ax = fig.add_subplot(122)
  105. plt.text(1, 3, 'Ground truth', fontsize=20)
  106. toplot = track2[i, ::, ::, 0]
  107. if i >= 2:
  108. toplot = shifted_movies[which][i - 1, ::, ::, 0]
  109. plt.imshow(toplot)
  110. plt.savefig('%i_animate.png' % (i + 1))