1-4 Example: Modeling Procedure for Temporal Sequences

COVID-19 has been spreading in China for over three months (translator's note: as of April 2020) and has significantly affected everyday life.

Its impacts can be seen in people's incomes, emotions, mental health, and even body weight.

So how long will this pandemic last, and when will we be free again?

This example predicts when the COVID-19 outbreak in China will end, using an RNN model built with TensorFlow 2.

(Figure 1)

1. Data Preparation

The dataset is extracted from “tushare”. The details of the data acquisition are described here (in Chinese).

(Figure 2)

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import models,layers,losses,metrics,callbacks
```
```python
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

df = pd.read_csv("../data/covid-19.csv",sep = "\t")
df.plot(x = "date",y = ["confirmed_num","cured_num","dead_num"],figsize=(10,6))
plt.xticks(rotation=60)
```

(Figure 3: cumulative confirmed, cured, and deceased cases over time)

```python
dfdata = df.set_index("date")
dfdiff = dfdata.diff(periods=1).dropna()
dfdiff = dfdiff.reset_index("date")

dfdiff.plot(x = "date",y = ["confirmed_num","cured_num","dead_num"],figsize=(10,6))
plt.xticks(rotation=60)
dfdiff = dfdiff.drop("date",axis = 1).astype("float32")
```

(Figure 4: daily increments of confirmed, cured, and deceased cases)

```python
# Use an eight-day window prior to the date under investigation as the input for prediction
WINDOW_SIZE = 8

def batch_dataset(dataset):
    dataset_batched = dataset.batch(WINDOW_SIZE,drop_remainder=True)
    return dataset_batched

ds_data = tf.data.Dataset.from_tensor_slices(tf.constant(dfdiff.values,dtype = tf.float32)) \
   .window(WINDOW_SIZE,shift=1).flat_map(batch_dataset)

ds_label = tf.data.Dataset.from_tensor_slices(
    tf.constant(dfdiff.values[WINDOW_SIZE:],dtype = tf.float32))

# We put all the data into one batch for better efficiency, since the data volume is small.
ds_train = tf.data.Dataset.zip((ds_data,ds_label)).batch(38).cache()
```
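As a quick sanity check (not part of the original text), one batch of ds_train can be inspected; with the original data, which spans several dozen days, the single batch should contain 38 eight-day windows and their corresponding next-day labels:

```python
# Illustrative sanity check: inspect the shapes of the single training batch
for x_batch, y_batch in ds_train.take(1):
    print(x_batch.shape)  # expected: (38, 8, 3) -- 38 windows, 8 days each, 3 features
    print(y_batch.shape)  # expected: (38, 3) -- daily increments of the day following each window
```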

2. Model Definition

There are usually three ways to build a model with the Keras APIs: sequential modeling with the Sequential() function, arbitrary-architecture modeling with the functional API, and customized modeling by subclassing the base class Model.
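For readers unfamiliar with these three styles, here is a minimal sketch of the same toy two-layer dense network written in each of them (illustrative only and unrelated to the COVID-19 model below; the names seq_model, func_model, and MyModel are made up for this sketch):

```python
# 1. Sequential(): stack layers in order
seq_model = models.Sequential([
    layers.Dense(8, activation="relu", input_shape=(4,)),
    layers.Dense(1),
])

# 2. Functional API: wire tensors explicitly, which allows arbitrary topologies
inputs = layers.Input(shape=(4,))
outputs = layers.Dense(1)(layers.Dense(8, activation="relu")(inputs))
func_model = models.Model(inputs=inputs, outputs=outputs)

# 3. Subclassing Model: full control via a custom class
class MyModel(models.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = layers.Dense(8, activation="relu")
        self.dense2 = layers.Dense(1)
    def call(self, x):
        return self.dense2(self.dense1(x))

sub_model = MyModel()
```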

Here we use the functional API for modeling.

```python
# We design the following Block because the daily increments of confirmed, discharged, and deceased cases are all greater than or equal to zero.
class Block(layers.Layer):
    def __init__(self, **kwargs):
        super(Block, self).__init__(**kwargs)

    def call(self, x_input, x):
        x_out = tf.maximum((1+x)*x_input[:,-1,:],0.0)
        return x_out

    def get_config(self):
        config = super(Block, self).get_config()
        return config
```
```python
tf.keras.backend.clear_session()

x_input = layers.Input(shape = (None,3),dtype = tf.float32)
x = layers.LSTM(3,return_sequences = True,input_shape=(None,3))(x_input)
x = layers.LSTM(3,return_sequences = True,input_shape=(None,3))(x)
x = layers.LSTM(3,return_sequences = True,input_shape=(None,3))(x)
x = layers.LSTM(3,input_shape=(None,3))(x)
x = layers.Dense(3)(x)

# The daily increments of confirmed, discharged, and deceased cases are all greater than or equal to zero, so we constrain the output with the Block defined above.
# x = tf.maximum((1+x)*x_input[:,-1,:],0.0)
x = Block()(x_input,x)
model = models.Model(inputs = [x_input],outputs = [x])
model.summary()
```
```
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, None, 3)]         0         
_________________________________________________________________
lstm (LSTM)                  (None, None, 3)           84        
_________________________________________________________________
lstm_1 (LSTM)                (None, None, 3)           84        
_________________________________________________________________
lstm_2 (LSTM)                (None, None, 3)           84        
_________________________________________________________________
lstm_3 (LSTM)                (None, 3)                 84        
_________________________________________________________________
dense (Dense)                (None, 3)                 12        
_________________________________________________________________
block (Block)                (None, 3)                 0         
=================================================================
Total params: 348
Trainable params: 348
Non-trainable params: 0
_________________________________________________________________
```
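To make the role of the Block layer concrete: the Dense output x is interpreted as a relative change applied to the last day of the input window, and any negative result is clipped to zero. A minimal numeric check (the values and the names x_input_demo and x_demo are made up for illustration):

```python
# Made-up numbers, only to illustrate the Block computation
x_input_demo = tf.constant([[[100.0, 30.0, 2.0]]])   # last day of a window, shape (1, 1, 3)
x_demo = tf.constant([[-0.1, 0.2, -1.5]])            # relative changes from the Dense layer
print(tf.maximum((1 + x_demo) * x_input_demo[:, -1, :], 0.0).numpy())
# [[90. 36.  0.]] -- the predicted increments can never become negative
```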

3. Model Training

There are three usual ways to train a model: the built-in function fit, the built-in function train_on_batch, and a customized training loop. Here we use the simplest one: the built-in function fit.

Note: Hyperparameter tuning for RNNs is more difficult than for other types of neural networks. We need to try several learning rates to achieve a satisfactory result.

```python
# Customized loss function: the squared error divided by the squared true value
class MSPE(losses.Loss):
    def call(self,y_true,y_pred):
        err_percent = (y_true - y_pred)**2/(tf.maximum(y_true**2,1e-7))
        mean_err_percent = tf.reduce_mean(err_percent)
        return mean_err_percent

    def get_config(self):
        config = super(MSPE, self).get_config()
        return config
```
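In other words, the loss measures the squared error relative to the squared true value, so series with very different magnitudes (confirmed vs. deceased counts) contribute on a comparable scale. A quick check with made-up values (y_true_demo and y_pred_demo are illustrative names, not from the original text):

```python
# Made-up values, only to illustrate how MSPE behaves
y_true_demo = tf.constant([[100.0, 30.0, 2.0]])
y_pred_demo = tf.constant([[110.0, 27.0, 2.0]])
print(MSPE(name = "MSPE")(y_true_demo, y_pred_demo).numpy())
# (0.01 + 0.01 + 0.0) / 3 ≈ 0.0067
```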
```python
import os
import datetime

optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
model.compile(optimizer=optimizer,loss=MSPE(name = "MSPE"))

stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join('data', 'autograph', stamp)

## We recommend using pathlib under Python3
# from pathlib import Path
# stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# logdir = str(Path('../data/autograph/' + stamp))

tb_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
# Halve the learning rate if the loss does not improve within 100 epochs
lr_callback = tf.keras.callbacks.ReduceLROnPlateau(monitor="loss",factor = 0.5, patience = 100)
# Stop training if the loss does not improve within 200 epochs
stop_callback = tf.keras.callbacks.EarlyStopping(monitor = "loss", patience= 200)
callbacks_list = [tb_callback,lr_callback,stop_callback]

history = model.fit(ds_train,epochs=500,callbacks = callbacks_list)
```
```
Epoch 371/500
1/1 [==============================] - 0s 61ms/step - loss: 0.1184
Epoch 372/500
1/1 [==============================] - 0s 64ms/step - loss: 0.1177
Epoch 373/500
1/1 [==============================] - 0s 56ms/step - loss: 0.1169
Epoch 374/500
1/1 [==============================] - 0s 50ms/step - loss: 0.1161
Epoch 375/500
1/1 [==============================] - 0s 55ms/step - loss: 0.1154
Epoch 376/500
1/1 [==============================] - 0s 55ms/step - loss: 0.1147
Epoch 377/500
1/1 [==============================] - 0s 62ms/step - loss: 0.1140
Epoch 378/500
1/1 [==============================] - 0s 93ms/step - loss: 0.1133
Epoch 379/500
1/1 [==============================] - 0s 85ms/step - loss: 0.1126
Epoch 380/500
1/1 [==============================] - 0s 68ms/step - loss: 0.1119
Epoch 381/500
1/1 [==============================] - 0s 52ms/step - loss: 0.1113
Epoch 382/500
1/1 [==============================] - 0s 54ms/step - loss: 0.1107
Epoch 383/500
1/1 [==============================] - 0s 55ms/step - loss: 0.1100
Epoch 384/500
1/1 [==============================] - 0s 56ms/step - loss: 0.1094
Epoch 385/500
1/1 [==============================] - 0s 54ms/step - loss: 0.1088
Epoch 386/500
1/1 [==============================] - 0s 74ms/step - loss: 0.1082
Epoch 387/500
1/1 [==============================] - 0s 60ms/step - loss: 0.1077
Epoch 388/500
1/1 [==============================] - 0s 52ms/step - loss: 0.1071
Epoch 389/500
1/1 [==============================] - 0s 52ms/step - loss: 0.1066
Epoch 390/500
1/1 [==============================] - 0s 56ms/step - loss: 0.1060
Epoch 391/500
1/1 [==============================] - 0s 61ms/step - loss: 0.1055
Epoch 392/500
1/1 [==============================] - 0s 60ms/step - loss: 0.1050
Epoch 393/500
1/1 [==============================] - 0s 59ms/step - loss: 0.1045
Epoch 394/500
1/1 [==============================] - 0s 65ms/step - loss: 0.1040
Epoch 395/500
1/1 [==============================] - 0s 58ms/step - loss: 0.1035
Epoch 396/500
1/1 [==============================] - 0s 52ms/step - loss: 0.1031
Epoch 397/500
1/1 [==============================] - 0s 58ms/step - loss: 0.1026
Epoch 398/500
1/1 [==============================] - 0s 60ms/step - loss: 0.1022
Epoch 399/500
1/1 [==============================] - 0s 57ms/step - loss: 0.1017
Epoch 400/500
1/1 [==============================] - 0s 63ms/step - loss: 0.1013
Epoch 401/500
1/1 [==============================] - 0s 59ms/step - loss: 0.1009
Epoch 402/500
1/1 [==============================] - 0s 53ms/step - loss: 0.1005
Epoch 403/500
1/1 [==============================] - 0s 56ms/step - loss: 0.1001
Epoch 404/500
1/1 [==============================] - 0s 55ms/step - loss: 0.0997
Epoch 405/500
1/1 [==============================] - 0s 58ms/step - loss: 0.0993
Epoch 406/500
1/1 [==============================] - 0s 53ms/step - loss: 0.0990
Epoch 407/500
1/1 [==============================] - 0s 59ms/step - loss: 0.0986
Epoch 408/500
1/1 [==============================] - 0s 63ms/step - loss: 0.0982
Epoch 409/500
1/1 [==============================] - 0s 67ms/step - loss: 0.0979
Epoch 410/500
1/1 [==============================] - 0s 55ms/step - loss: 0.0976
Epoch 411/500
1/1 [==============================] - 0s 54ms/step - loss: 0.0972
Epoch 412/500
1/1 [==============================] - 0s 55ms/step - loss: 0.0969
Epoch 413/500
1/1 [==============================] - 0s 55ms/step - loss: 0.0966
Epoch 414/500
1/1 [==============================] - 0s 59ms/step - loss: 0.0963
Epoch 415/500
1/1 [==============================] - 0s 60ms/step - loss: 0.0960
Epoch 416/500
1/1 [==============================] - 0s 62ms/step - loss: 0.0957
Epoch 417/500
1/1 [==============================] - 0s 69ms/step - loss: 0.0954
Epoch 418/500
1/1 [==============================] - 0s 60ms/step - loss: 0.0951
Epoch 419/500
1/1 [==============================] - 0s 50ms/step - loss: 0.0948
Epoch 420/500
1/1 [==============================] - 0s 56ms/step - loss: 0.0946
Epoch 421/500
1/1 [==============================] - 0s 57ms/step - loss: 0.0943
Epoch 422/500
1/1 [==============================] - 0s 55ms/step - loss: 0.0941
Epoch 423/500
1/1 [==============================] - 0s 62ms/step - loss: 0.0938
Epoch 424/500
1/1 [==============================] - 0s 60ms/step - loss: 0.0936
Epoch 425/500
1/1 [==============================] - 0s 100ms/step - loss: 0.0933
Epoch 426/500
1/1 [==============================] - 0s 68ms/step - loss: 0.0931
Epoch 427/500
1/1 [==============================] - 0s 60ms/step - loss: 0.0929
Epoch 428/500
1/1 [==============================] - 0s 50ms/step - loss: 0.0926
Epoch 429/500
1/1 [==============================] - 0s 55ms/step - loss: 0.0924
Epoch 430/500
1/1 [==============================] - 0s 57ms/step - loss: 0.0922
Epoch 431/500
1/1 [==============================] - 0s 75ms/step - loss: 0.0920
Epoch 432/500
1/1 [==============================] - 0s 57ms/step - loss: 0.0918
Epoch 433/500
1/1 [==============================] - 0s 77ms/step - loss: 0.0916
Epoch 434/500
1/1 [==============================] - 0s 50ms/step - loss: 0.0914
Epoch 435/500
1/1 [==============================] - 0s 56ms/step - loss: 0.0912
Epoch 436/500
1/1 [==============================] - 0s 60ms/step - loss: 0.0911
Epoch 437/500
1/1 [==============================] - 0s 55ms/step - loss: 0.0909
Epoch 438/500
1/1 [==============================] - 0s 57ms/step - loss: 0.0907
Epoch 439/500
1/1 [==============================] - 0s 59ms/step - loss: 0.0905
Epoch 440/500
1/1 [==============================] - 0s 60ms/step - loss: 0.0904
Epoch 441/500
1/1 [==============================] - 0s 68ms/step - loss: 0.0902
Epoch 442/500
1/1 [==============================] - 0s 73ms/step - loss: 0.0901
Epoch 443/500
1/1 [==============================] - 0s 50ms/step - loss: 0.0899
Epoch 444/500
1/1 [==============================] - 0s 58ms/step - loss: 0.0898
Epoch 445/500
1/1 [==============================] - 0s 56ms/step - loss: 0.0896
Epoch 446/500
1/1 [==============================] - 0s 52ms/step - loss: 0.0895
Epoch 447/500
1/1 [==============================] - 0s 60ms/step - loss: 0.0893
Epoch 448/500
1/1 [==============================] - 0s 64ms/step - loss: 0.0892
Epoch 449/500
1/1 [==============================] - 0s 70ms/step - loss: 0.0891
Epoch 450/500
1/1 [==============================] - 0s 57ms/step - loss: 0.0889
Epoch 451/500
1/1 [==============================] - 0s 53ms/step - loss: 0.0888
Epoch 452/500
1/1 [==============================] - 0s 51ms/step - loss: 0.0887
Epoch 453/500
1/1 [==============================] - 0s 55ms/step - loss: 0.0886
Epoch 454/500
1/1 [==============================] - 0s 58ms/step - loss: 0.0885
Epoch 455/500
1/1 [==============================] - 0s 55ms/step - loss: 0.0883
Epoch 456/500
1/1 [==============================] - 0s 71ms/step - loss: 0.0882
Epoch 457/500
1/1 [==============================] - 0s 50ms/step - loss: 0.0881
Epoch 458/500
1/1 [==============================] - 0s 56ms/step - loss: 0.0880
Epoch 459/500
1/1 [==============================] - 0s 55ms/step - loss: 0.0879
Epoch 460/500
1/1 [==============================] - 0s 57ms/step - loss: 0.0878
Epoch 461/500
1/1 [==============================] - 0s 56ms/step - loss: 0.0878
Epoch 462/500
1/1 [==============================] - 0s 55ms/step - loss: 0.0879
Epoch 463/500
1/1 [==============================] - 0s 60ms/step - loss: 0.0879
Epoch 464/500
1/1 [==============================] - 0s 68ms/step - loss: 0.0888
Epoch 465/500
1/1 [==============================] - 0s 62ms/step - loss: 0.0875
Epoch 466/500
1/1 [==============================] - 0s 55ms/step - loss: 0.0873
Epoch 467/500
1/1 [==============================] - 0s 49ms/step - loss: 0.0872
Epoch 468/500
1/1 [==============================] - 0s 56ms/step - loss: 0.0872
Epoch 469/500
1/1 [==============================] - 0s 55ms/step - loss: 0.0871
Epoch 470/500
1/1 [==============================] - 0s 55ms/step - loss: 0.0871
Epoch 471/500
1/1 [==============================] - 0s 59ms/step - loss: 0.0870
Epoch 472/500
1/1 [==============================] - 0s 68ms/step - loss: 0.0871
Epoch 473/500
1/1 [==============================] - 0s 57ms/step - loss: 0.0869
Epoch 474/500
1/1 [==============================] - 0s 61ms/step - loss: 0.0870
Epoch 475/500
1/1 [==============================] - 0s 47ms/step - loss: 0.0868
Epoch 476/500
1/1 [==============================] - 0s 55ms/step - loss: 0.0868
Epoch 477/500
1/1 [==============================] - 0s 62ms/step - loss: 0.0866
Epoch 478/500
1/1 [==============================] - 0s 58ms/step - loss: 0.0867
Epoch 479/500
1/1 [==============================] - 0s 60ms/step - loss: 0.0865
Epoch 480/500
1/1 [==============================] - 0s 65ms/step - loss: 0.0866
Epoch 481/500
1/1 [==============================] - 0s 58ms/step - loss: 0.0864
Epoch 482/500
1/1 [==============================] - 0s 57ms/step - loss: 0.0865
Epoch 483/500
1/1 [==============================] - 0s 53ms/step - loss: 0.0863
Epoch 484/500
1/1 [==============================] - 0s 56ms/step - loss: 0.0864
Epoch 485/500
1/1 [==============================] - 0s 56ms/step - loss: 0.0862
Epoch 486/500
1/1 [==============================] - 0s 55ms/step - loss: 0.0863
Epoch 487/500
1/1 [==============================] - 0s 52ms/step - loss: 0.0861
Epoch 488/500
1/1 [==============================] - 0s 68ms/step - loss: 0.0862
Epoch 489/500
1/1 [==============================] - 0s 62ms/step - loss: 0.0860
Epoch 490/500
1/1 [==============================] - 0s 57ms/step - loss: 0.0861
Epoch 491/500
1/1 [==============================] - 0s 51ms/step - loss: 0.0859
Epoch 492/500
1/1 [==============================] - 0s 54ms/step - loss: 0.0860
Epoch 493/500
1/1 [==============================] - 0s 51ms/step - loss: 0.0859
Epoch 494/500
1/1 [==============================] - 0s 54ms/step - loss: 0.0860
Epoch 495/500
1/1 [==============================] - 0s 50ms/step - loss: 0.0858
Epoch 496/500
1/1 [==============================] - 0s 69ms/step - loss: 0.0859
Epoch 497/500
1/1 [==============================] - 0s 63ms/step - loss: 0.0857
Epoch 498/500
1/1 [==============================] - 0s 56ms/step - loss: 0.0858
Epoch 499/500
1/1 [==============================] - 0s 54ms/step - loss: 0.0857
Epoch 500/500
1/1 [==============================] - 0s 57ms/step - loss: 0.0858
```
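Since a TensorBoard callback was registered above, the training curves can also be inspected interactively. In a notebook, something like the following should work (a sketch; adjust the path if you changed logdir):

```python
%load_ext tensorboard
%tensorboard --logdir ./data/autograph
```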

4. Model Evaluation

Model evaluation usually requires both a validation set and a test set. Since we have very little data in this case, we only visualize how the loss changed during the training iterations.

```python
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

import matplotlib.pyplot as plt

def plot_metric(history, metric):
    train_metrics = history.history[metric]
    epochs = range(1, len(train_metrics) + 1)
    plt.plot(epochs, train_metrics, 'bo--')
    plt.title('Training '+ metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend(["train_"+metric])
    plt.show()
```

```python
plot_metric(history,"loss")
```

(Figure 5: training loss over epochs)

5. Model Application

Here we predict when the COVID-19 outbreak will end, i.e. the date on which the daily increment of newly confirmed cases drops to 0.

```python
# "dfresult" is used to record the current and predicted data
dfresult = dfdiff[["confirmed_num","cured_num","dead_num"]].copy()
dfresult.tail()
```

(Figure 6: last rows of dfresult)

```python
# Predict the daily increments of newly confirmed cases for the next 100 days, and append the results to dfresult
for i in range(100):
    arr_predict = model.predict(tf.constant(tf.expand_dims(dfresult.values[-38:,:],axis = 0)))
    dfpredict = pd.DataFrame(tf.cast(tf.floor(arr_predict),tf.float32).numpy(),
                columns = dfresult.columns)
    dfresult = dfresult.append(dfpredict,ignore_index=True)
```
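The model predicts daily increments rather than totals. If cumulative counts are wanted, they can be reconstructed by adding the running sum of dfresult to the first row of the original cumulative data (an illustrative sketch, not part of the original text; first_day and dfcumulative are made-up names):

```python
# Illustrative: rebuild cumulative totals from the daily increments in dfresult
# dfdata holds the original cumulative counts indexed by date (see the data preparation step)
first_day = dfdata[["confirmed_num", "cured_num", "dead_num"]].iloc[0]
dfcumulative = dfresult.cumsum() + first_day
dfcumulative.tail()
```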
```python
dfresult.query("confirmed_num==0").head()

# The daily increment of newly confirmed cases drops to zero from Day 55. Since Day 45 corresponds to March 10, the daily increment of newly confirmed cases is predicted to reach zero around March 20.
# Note: this prediction is TOO optimistic
```

(Figure 7: result of dfresult.query("confirmed_num==0").head())

```python
dfresult.query("cured_num==0").head()

# The daily increment of discharged (cured) cases drops to zero on Day 164, roughly four months after March 10; in other words, all patients would be discharged around July 10.
# Note: this prediction is TOO pessimistic and problematic: the cumulative sum of the predicted daily increments of discharged cases exceeds the cumulative number of confirmed cases.
```

(Figure 8: result of dfresult.query("cured_num==0").head())

```python
dfresult.query("dead_num==0").head()

# The daily increment of deceased cases drops to zero from Day 60, which corresponds to March 25, 2020.
# Note: this prediction is relatively reasonable.
```

(Figure 9: result of dfresult.query("dead_num==0").head())

6. Model Saving

Saving the model in TensorFlow's native SavedModel format is recommended.

```python
model.save('../data/tf_model_savedmodel', save_format="tf")
print('export saved model.')
```

```python
model_loaded = tf.keras.models.load_model('../data/tf_model_savedmodel',compile=False)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model_loaded.compile(optimizer=optimizer,loss=MSPE(name = "MSPE"))
model_loaded.predict(ds_train)
```
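The reloaded model can be used for forecasting in the same way as the original one, for example on the most recent 38-day window of dfdiff (an illustrative usage sketch; last_window is a made-up name):

```python
# Illustrative usage of the reloaded model on the latest observed window of daily increments
last_window = tf.expand_dims(tf.constant(dfdiff.values[-38:, :], dtype=tf.float32), axis=0)
print(model_loaded.predict(last_window))
```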

Please leave comments in the WeChat official account “Python与算法之美” (Elegance of Python and Algorithms) if you want to discuss the content with the author. The author will try his best to reply given the limited time available.

You are also welcome to join the group chat with other readers by replying 加群 (join group) in the WeChat official account.
