How to use a stateful LSTM model, and a comparison of stateful vs. stateless LSTM performance

More documentation about the Keras LSTM model can be found in the Keras recurrent layers documentation.

The models are trained on an input/output pair, where the input is a generated uniformly distributed random sequence of length = input_len, and the output is a moving average of the input with window length = tsteps. Both input_len and tsteps are defined in the "editable parameters" section.
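
A quick standalone illustration of this input/output pair (the actual generation code appears in the full script below; the variable names and the series length of 10 are only for this example):

import numpy as np
import pandas as pd

tsteps = 2                                        # moving-average window length
x = pd.Series(np.random.uniform(-0.1, 0.1, 10))   # uniform random input
y = x.rolling(window=tsteps).mean()               # target: tsteps-point moving average

# The first tsteps - 1 targets are NaN because the window is incomplete there;
# the full script drops these points before training.
print(pd.DataFrame({'input': x, 'output': y}))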

A larger tsteps value means that the LSTM will need more memory to figure out the input-output relationship. This memory length is controlled by the lahead variable (more details below).

The rest of the parameters are:

  • input_len: the length of the generated input sequence
  • lahead: the input sequence length that the LSTM is trained on for each output point
  • batch_size, epochs: same parameters as in the model.fit(…) function

When lahead > 1, the model input is preprocessed to a "rolling window view" of the data, with the window length = lahead. This is similar to scikit-image's view_as_windows with window_shape being a single number; a sketch of the preprocessing is shown below.
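
This sketch mirrors the repeat-and-shift preprocessing the script performs further down; the toy values and lahead = 3 are made up for illustration:

import numpy as np
import pandas as pd

lahead = 3
series = pd.DataFrame([0.1, 0.2, 0.3, 0.4, 0.5])   # toy input series

# Repeat the single input column lahead times, then shift column i down by
# i rows, so row t holds the window [x[t], x[t-1], ..., x[t-lahead+1]].
windows = np.repeat(series.values, repeats=lahead, axis=1)
windows = pd.DataFrame(windows)
for i, c in enumerate(windows.columns):
    windows[c] = windows[c].shift(i)

print(windows)
# The first lahead - 1 rows contain NaN (incomplete windows); the script
# drops them before training.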

When lahead < tsteps, only the stateful LSTM converges, because its statefulness lets it carry information from beyond the lahead-long input window when fitting the n-point average. The stateless LSTM has no such memory: it is limited to its lahead window, which is not long enough to see the n-point average.

When lahead >= tsteps, both the stateful and stateless LSTM converge.

from __future__ import print_function
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense, LSTM

# ----------------------------------------------------------
# EDITABLE PARAMETERS
# Read the documentation in the script head for more details
# ----------------------------------------------------------

# length of input
input_len = 1000

# The window length of the moving average used to generate
# the output from the input in the input/output pair used
# to train the LSTM
# e.g. if tsteps=2 and input=[1, 2, 3, 4, 5],
#      then output=[1.5, 2.5, 3.5, 4.5]
tsteps = 2

# The input sequence length that the LSTM is trained on for each output point
lahead = 1

# training parameters passed to "model.fit(...)"
batch_size = 1
epochs = 10

# ------------
# MAIN PROGRAM
# ------------

print("*" * 33)
if lahead >= tsteps:
    print("STATELESS LSTM WILL ALSO CONVERGE")
else:
    print("STATELESS LSTM WILL NOT CONVERGE")
print("*" * 33)

np.random.seed(1986)

print('Generating Data...')


def gen_uniform_amp(amp=1, xn=10000):
    """Generates uniform random data between
    -amp and +amp
    and of length xn

    # Arguments
        amp: maximum/minimum range of uniform data
        xn: length of series
    """
    data_input = np.random.uniform(-1 * amp, +1 * amp, xn)
    data_input = pd.DataFrame(data_input)
    return data_input


# Since the output is a moving average of the input,
# the first few points of output will be NaN
# and will be dropped from the generated data
# before training the LSTM.
# Also, when lahead > 1,
# the preprocessing step later of "rolling window view"
# will also cause some points to be lost.
# For aesthetic reasons,
# in order to maintain generated data length = input_len after pre-processing,
# add a few points to account for the values that will be lost.
to_drop = max(tsteps - 1, lahead - 1)
data_input = gen_uniform_amp(amp=0.1, xn=input_len + to_drop)

# set the target to be a N-point average of the input
expected_output = data_input.rolling(window=tsteps, center=False).mean()

# when lahead > 1, need to convert the input to "rolling window view"
# https://docs.scipy.org/doc/numpy/reference/generated/numpy.repeat.html
if lahead > 1:
    data_input = np.repeat(data_input.values, repeats=lahead, axis=1)
    data_input = pd.DataFrame(data_input)
    for i, c in enumerate(data_input.columns):
        data_input[c] = data_input[c].shift(i)

# drop the nan
expected_output = expected_output[to_drop:]
data_input = data_input[to_drop:]

print('Input shape:', data_input.shape)
print('Output shape:', expected_output.shape)
print('Input head: ')
print(data_input.head())
print('Output head: ')
print(expected_output.head())
print('Input tail: ')
print(data_input.tail())
print('Output tail: ')
print(expected_output.tail())

print('Plotting input and expected output')
plt.plot(data_input[0][:10], '.')
plt.plot(expected_output[0][:10], '-')
plt.legend(['Input', 'Expected output'])
plt.title('Input')
plt.show()


def create_model(stateful):
    model = Sequential()
    model.add(LSTM(20,
                   input_shape=(lahead, 1),
                   batch_size=batch_size,
                   stateful=stateful))
    model.add(Dense(1))
    model.compile(loss='mse', optimizer='adam')
    return model


print('Creating Stateful Model...')
model_stateful = create_model(stateful=True)


# split train/test data
def split_data(x, y, ratio=0.8):
    to_train = int(input_len * ratio)
    # tweak to match with batch_size
    to_train -= to_train % batch_size

    x_train = x[:to_train]
    y_train = y[:to_train]
    x_test = x[to_train:]
    y_test = y[to_train:]

    # tweak to match with batch_size
    to_drop = x.shape[0] % batch_size
    if to_drop > 0:
        x_test = x_test[:-1 * to_drop]
        y_test = y_test[:-1 * to_drop]

    # some reshaping
    reshape_3 = lambda x: x.values.reshape((x.shape[0], x.shape[1], 1))
    x_train = reshape_3(x_train)
    x_test = reshape_3(x_test)

    reshape_2 = lambda x: x.values.reshape((x.shape[0], 1))
    y_train = reshape_2(y_train)
    y_test = reshape_2(y_test)

    return (x_train, y_train), (x_test, y_test)


(x_train, y_train), (x_test, y_test) = split_data(data_input, expected_output)
print('x_train.shape: ', x_train.shape)
print('y_train.shape: ', y_train.shape)
print('x_test.shape: ', x_test.shape)
print('y_test.shape: ', y_test.shape)

print('Training')
for i in range(epochs):
    print('Epoch', i + 1, '/', epochs)
    # Note that the last state for sample i in a batch will
    # be used as initial state for sample i in the next batch.
    # Thus we are simultaneously training on batch_size series with
    # lower resolution than the original series contained in data_input.
    # Each of these series are offset by one step and can be
    # extracted with data_input[i::batch_size].
    model_stateful.fit(x_train,
                       y_train,
                       batch_size=batch_size,
                       epochs=1,
                       verbose=1,
                       validation_data=(x_test, y_test),
                       shuffle=False)
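    # reset_states() clears the LSTM's carried hidden/cell state (the learned
    # weights are untouched), so each epoch starts from a fresh state instead
    # of continuing from wherever the previous pass over the series ended.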
    model_stateful.reset_states()

print('Predicting')
predicted_stateful = model_stateful.predict(x_test, batch_size=batch_size)

print('Creating Stateless Model...')
model_stateless = create_model(stateful=False)

print('Training')
model_stateless.fit(x_train,
                    y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, y_test),
                    shuffle=False)

print('Predicting')
predicted_stateless = model_stateless.predict(x_test, batch_size=batch_size)

# ----------------------------

print('Plotting Results')
plt.subplot(3, 1, 1)
plt.plot(y_test)
plt.title('Expected')
plt.subplot(3, 1, 2)
# drop the first "tsteps-1" because it is not possible to predict them
# since the "previous" timesteps to use do not exist
plt.plot((y_test - predicted_stateful).flatten()[tsteps - 1:])
plt.title('Stateful: Expected - Predicted')
plt.subplot(3, 1, 3)
plt.plot((y_test - predicted_stateless).flatten())
plt.title('Stateless: Expected - Predicted')
plt.show()
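
A closing note on using the stateful model after training: with stateful=True, the LSTM keeps its hidden/cell state between successive predict calls until reset_states() is called, so the model can be fed a few samples at a time in order. The sketch below is only illustrative; it reuses model_stateful, x_test and batch_size from the script above, and the loop bound of 20 is arbitrary:

# Feed the trained stateful model one window at a time; the carried LSTM state
# makes the order of the samples matter.
model_stateful.reset_states()               # start from a clean state
stepwise_predictions = []
for t in range(20):                         # illustrative: first 20 test points
    window = x_test[t:t + 1]                # shape (1, lahead, 1): one batch
    pred = model_stateful.predict(window, batch_size=batch_size)
    stepwise_predictions.append(float(pred[0, 0]))

# When switching to a different, unrelated series, clear the carried state
# again so it does not leak into the new sequence.
model_stateful.reset_states()
print(stepwise_predictions)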