6-2 Three Ways of Training

There are three ways of training a model: using the pre-defined fit method, using the pre-defined train_on_batch method, or writing a customized training loop.

Note: the fit_generator method is no longer recommended in tf.keras, since its functionality has been merged into fit.

```python
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import *

# Print a time-stamped divider; the +8 shifts UTC to the UTC+8 time zone.
@tf.function
def printbar():
    ts = tf.timestamp()
    today_ts = ts % (24*60*60)

    hour = tf.cast(today_ts//3600 + 8, tf.int32) % tf.constant(24)
    minute = tf.cast((today_ts % 3600)//60, tf.int32)
    second = tf.cast(tf.floor(today_ts % 60), tf.int32)

    def timeformat(m):
        # Zero-pad single-digit values so the time reads hh:mm:ss.
        if tf.strings.length(tf.strings.format("{}", m)) == 1:
            return tf.strings.format("0{}", m)
        else:
            return tf.strings.format("{}", m)

    timestring = tf.strings.join([timeformat(hour), timeformat(minute),
                                  timeformat(second)], separator=":")
    tf.print("=========="*8, end="")
    tf.print(timestring)
```
```python
MAX_LEN = 300
BATCH_SIZE = 32
(x_train, y_train), (x_test, y_test) = datasets.reuters.load_data()
x_train = preprocessing.sequence.pad_sequences(x_train, maxlen=MAX_LEN)
x_test = preprocessing.sequence.pad_sequences(x_test, maxlen=MAX_LEN)

MAX_WORDS = x_train.max() + 1
CAT_NUM = y_train.max() + 1

ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)) \
    .shuffle(buffer_size=1000).batch(BATCH_SIZE) \
    .prefetch(tf.data.experimental.AUTOTUNE).cache()

ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)) \
    .shuffle(buffer_size=1000).batch(BATCH_SIZE) \
    .prefetch(tf.data.experimental.AUTOTUNE).cache()
```
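
As a quick sanity check of the pipeline (not part of the original script), you can take a single batch and inspect its shapes:

```python
# Minimal pipeline check: one batch of padded sequences and labels.
for x, y in ds_train.take(1):
    print(x.shape, y.shape)  # expected: (32, 300) (32,)
```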

1. Pre-defined fit method

This is a powerful method that supports training on data supplied as NumPy arrays, as a tf.data.Dataset, or as a Python generator.

This method also supports complicated control logic through proper configuration of callbacks.
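
For instance, early stopping and checkpointing can both be configured as callbacks. Below is a minimal sketch, not part of the original script; the path "./checkpoint" is a placeholder:

```python
# A sketch of callback configuration for fit; "./checkpoint" is a placeholder path.
callbacks_list = [
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=3),
    tf.keras.callbacks.ModelCheckpoint(filepath="./checkpoint", save_best_only=True),
]
# history = model.fit(ds_train, validation_data=ds_test,
#                     epochs=10, callbacks=callbacks_list)
```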

```python
tf.keras.backend.clear_session()

def create_model():
    model = models.Sequential()
    model.add(layers.Embedding(MAX_WORDS, 7, input_length=MAX_LEN))
    model.add(layers.Conv1D(filters=64, kernel_size=5, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Conv1D(filters=32, kernel_size=3, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Flatten())
    model.add(layers.Dense(CAT_NUM, activation="softmax"))
    return model

def compile_model(model):
    model.compile(optimizer=optimizers.Nadam(),
                  loss=losses.SparseCategoricalCrossentropy(),
                  metrics=[metrics.SparseCategoricalAccuracy(),
                           metrics.SparseTopKCategoricalAccuracy(5)])
    return model

model = create_model()
model.summary()
model = compile_model(model)
```
```
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
embedding (Embedding)        (None, 300, 7)            216874
_________________________________________________________________
conv1d (Conv1D)              (None, 296, 64)           2304
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 148, 64)           0
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 146, 32)           6176
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 73, 32)            0
_________________________________________________________________
flatten (Flatten)            (None, 2336)              0
_________________________________________________________________
dense (Dense)                (None, 46)                107502
=================================================================
Total params: 332,856
Trainable params: 332,856
Non-trainable params: 0
_________________________________________________________________
```
```python
history = model.fit(ds_train, validation_data=ds_test, epochs=10)
```
```
Train for 281 steps, validate for 71 steps
Epoch 1/10
281/281 [==============================] - 11s 37ms/step - loss: 2.0231 - sparse_categorical_accuracy: 0.4636 - sparse_top_k_categorical_accuracy: 0.7450 - val_loss: 1.7346 - val_sparse_categorical_accuracy: 0.5534 - val_sparse_top_k_categorical_accuracy: 0.7560
Epoch 2/10
281/281 [==============================] - 9s 31ms/step - loss: 1.5079 - sparse_categorical_accuracy: 0.6091 - sparse_top_k_categorical_accuracy: 0.7901 - val_loss: 1.5475 - val_sparse_categorical_accuracy: 0.6109 - val_sparse_top_k_categorical_accuracy: 0.7792
Epoch 3/10
281/281 [==============================] - 9s 33ms/step - loss: 1.2204 - sparse_categorical_accuracy: 0.6823 - sparse_top_k_categorical_accuracy: 0.8448 - val_loss: 1.5455 - val_sparse_categorical_accuracy: 0.6367 - val_sparse_top_k_categorical_accuracy: 0.8001
Epoch 4/10
281/281 [==============================] - 9s 33ms/step - loss: 0.9382 - sparse_categorical_accuracy: 0.7543 - sparse_top_k_categorical_accuracy: 0.9075 - val_loss: 1.6780 - val_sparse_categorical_accuracy: 0.6398 - val_sparse_top_k_categorical_accuracy: 0.8032
Epoch 5/10
281/281 [==============================] - 10s 34ms/step - loss: 0.6791 - sparse_categorical_accuracy: 0.8255 - sparse_top_k_categorical_accuracy: 0.9513 - val_loss: 1.9426 - val_sparse_categorical_accuracy: 0.6376 - val_sparse_top_k_categorical_accuracy: 0.7956
Epoch 6/10
281/281 [==============================] - 9s 33ms/step - loss: 0.5063 - sparse_categorical_accuracy: 0.8762 - sparse_top_k_categorical_accuracy: 0.9716 - val_loss: 2.2141 - val_sparse_categorical_accuracy: 0.6291 - val_sparse_top_k_categorical_accuracy: 0.7947
Epoch 7/10
281/281 [==============================] - 10s 37ms/step - loss: 0.4031 - sparse_categorical_accuracy: 0.9050 - sparse_top_k_categorical_accuracy: 0.9817 - val_loss: 2.4126 - val_sparse_categorical_accuracy: 0.6264 - val_sparse_top_k_categorical_accuracy: 0.7947
Epoch 8/10
281/281 [==============================] - 10s 35ms/step - loss: 0.3380 - sparse_categorical_accuracy: 0.9205 - sparse_top_k_categorical_accuracy: 0.9881 - val_loss: 2.5366 - val_sparse_categorical_accuracy: 0.6242 - val_sparse_top_k_categorical_accuracy: 0.7974
Epoch 9/10
281/281 [==============================] - 10s 36ms/step - loss: 0.2921 - sparse_categorical_accuracy: 0.9299 - sparse_top_k_categorical_accuracy: 0.9909 - val_loss: 2.6564 - val_sparse_categorical_accuracy: 0.6242 - val_sparse_top_k_categorical_accuracy: 0.7983
Epoch 10/10
281/281 [==============================] - 9s 30ms/step - loss: 0.2613 - sparse_categorical_accuracy: 0.9334 - sparse_top_k_categorical_accuracy: 0.9947 - val_loss: 2.7365 - val_sparse_categorical_accuracy: 0.6220 - val_sparse_top_k_categorical_accuracy: 0.8005
```
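
The returned history object records the per-epoch metrics in history.history, a dict mapping metric names to lists. A quick way to inspect it, using the pandas imported above:

```python
# Turn the training history into a DataFrame for inspection or plotting.
dfhistory = pd.DataFrame(history.history)
dfhistory.index = range(1, len(dfhistory) + 1)
dfhistory.index.name = "epoch"
print(dfhistory.tail())
```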

2. Pre-defined train_on_batch method

This pre-defined method allows fine-grained control of the training procedure on each batch without using callbacks, which makes it even more flexible than fit.
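
As a sketch of the kind of batch-level control this enables (illustrative only, to be run after the model below is compiled), an epoch could be cut short as soon as a batch loss falls below some bound:

```python
# Illustrative only: stop the epoch early once the batch loss
# (the first element of the returned metrics list) drops below 0.1,
# where 0.1 is an arbitrary threshold chosen for this sketch.
for x, y in ds_train:
    train_result = model.train_on_batch(x, y)
    if train_result[0] < 0.1:
        break
```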

```python
tf.keras.backend.clear_session()

def create_model():
    model = models.Sequential()
    model.add(layers.Embedding(MAX_WORDS, 7, input_length=MAX_LEN))
    model.add(layers.Conv1D(filters=64, kernel_size=5, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Conv1D(filters=32, kernel_size=3, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Flatten())
    model.add(layers.Dense(CAT_NUM, activation="softmax"))
    return model

def compile_model(model):
    model.compile(optimizer=optimizers.Nadam(),
                  loss=losses.SparseCategoricalCrossentropy(),
                  metrics=[metrics.SparseCategoricalAccuracy(),
                           metrics.SparseTopKCategoricalAccuracy(5)])
    return model

model = create_model()
model.summary()
model = compile_model(model)
```
```
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
embedding (Embedding)        (None, 300, 7)            216874
_________________________________________________________________
conv1d (Conv1D)              (None, 296, 64)           2304
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 148, 64)           0
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 146, 32)           6176
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 73, 32)            0
_________________________________________________________________
flatten (Flatten)            (None, 2336)              0
_________________________________________________________________
dense (Dense)                (None, 46)                107502
=================================================================
Total params: 332,856
Trainable params: 332,856
Non-trainable params: 0
_________________________________________________________________
```
```python
def train_model(model, ds_train, ds_valid, epochs):
    for epoch in tf.range(1, epochs + 1):
        model.reset_metrics()

        # Reduce the learning rate at the late stage of training.
        if epoch == 5:
            model.optimizer.lr.assign(model.optimizer.lr / 2.0)
            tf.print("Lowering optimizer Learning Rate...\n\n")

        for x, y in ds_train:
            train_result = model.train_on_batch(x, y)

        for x, y in ds_valid:
            valid_result = model.test_on_batch(x, y, reset_metrics=False)

        if epoch % 1 == 0:
            printbar()
            tf.print("epoch = ", epoch)
            print("train:", dict(zip(model.metrics_names, train_result)))
            print("valid:", dict(zip(model.metrics_names, valid_result)))
            print("")
```
```python
train_model(model, ds_train, ds_test, 10)
```
```
================================================================================13:09:19
epoch =  1
train: {'loss': 0.82411176, 'sparse_categorical_accuracy': 0.77272725, 'sparse_top_k_categorical_accuracy': 0.8636364}
valid: {'loss': 1.9265995, 'sparse_categorical_accuracy': 0.5743544, 'sparse_top_k_categorical_accuracy': 0.75779164}

================================================================================13:09:27
epoch =  2
train: {'loss': 0.6006621, 'sparse_categorical_accuracy': 0.90909094, 'sparse_top_k_categorical_accuracy': 0.95454544}
valid: {'loss': 1.844159, 'sparse_categorical_accuracy': 0.6126447, 'sparse_top_k_categorical_accuracy': 0.7920748}

================================================================================13:09:35
epoch =  3
train: {'loss': 0.36935613, 'sparse_categorical_accuracy': 0.90909094, 'sparse_top_k_categorical_accuracy': 0.95454544}
valid: {'loss': 2.163433, 'sparse_categorical_accuracy': 0.63312554, 'sparse_top_k_categorical_accuracy': 0.8045414}

================================================================================13:09:42
epoch =  4
train: {'loss': 0.2304088, 'sparse_categorical_accuracy': 0.90909094, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 2.8911984, 'sparse_categorical_accuracy': 0.6344613, 'sparse_top_k_categorical_accuracy': 0.7978629}

Lowering optimizer Learning Rate...

================================================================================13:09:51
epoch =  5
train: {'loss': 0.111194365, 'sparse_categorical_accuracy': 0.95454544, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 3.6431572, 'sparse_categorical_accuracy': 0.6295637, 'sparse_top_k_categorical_accuracy': 0.7978629}

================================================================================13:09:59
epoch =  6
train: {'loss': 0.07741702, 'sparse_categorical_accuracy': 0.95454544, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 4.074161, 'sparse_categorical_accuracy': 0.6255565, 'sparse_top_k_categorical_accuracy': 0.794301}

================================================================================13:10:07
epoch =  7
train: {'loss': 0.056113098, 'sparse_categorical_accuracy': 1.0, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 4.4461513, 'sparse_categorical_accuracy': 0.6273375, 'sparse_top_k_categorical_accuracy': 0.79652715}

================================================================================13:10:17
epoch =  8
train: {'loss': 0.043448802, 'sparse_categorical_accuracy': 1.0, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 4.7687583, 'sparse_categorical_accuracy': 0.6224399, 'sparse_top_k_categorical_accuracy': 0.79741764}

================================================================================13:10:26
epoch =  9
train: {'loss': 0.035002146, 'sparse_categorical_accuracy': 1.0, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 5.130505, 'sparse_categorical_accuracy': 0.6175423, 'sparse_top_k_categorical_accuracy': 0.794301}

================================================================================13:10:34
epoch =  10
train: {'loss': 0.028303564, 'sparse_categorical_accuracy': 1.0, 'sparse_top_k_categorical_accuracy': 1.0}
valid: {'loss': 5.4559293, 'sparse_categorical_accuracy': 0.6148709, 'sparse_top_k_categorical_accuracy': 0.7947462}
```

3. Customized Training Loop

A customized training loop does not require compiling the model at all: you compute the loss, obtain the gradients of the trainable variables with tf.GradientTape, and apply them through the optimizer yourself. This gives the highest flexibility.

```python
tf.keras.backend.clear_session()

def create_model():
    model = models.Sequential()
    model.add(layers.Embedding(MAX_WORDS, 7, input_length=MAX_LEN))
    model.add(layers.Conv1D(filters=64, kernel_size=5, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Conv1D(filters=32, kernel_size=3, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Flatten())
    model.add(layers.Dense(CAT_NUM, activation="softmax"))
    return model

model = create_model()
model.summary()
```
```python
optimizer = optimizers.Nadam()
loss_func = losses.SparseCategoricalCrossentropy()

train_loss = metrics.Mean(name='train_loss')
train_metric = metrics.SparseCategoricalAccuracy(name='train_accuracy')

valid_loss = metrics.Mean(name='valid_loss')
valid_metric = metrics.SparseCategoricalAccuracy(name='valid_accuracy')

@tf.function
def train_step(model, features, labels):
    with tf.GradientTape() as tape:
        predictions = model(features, training=True)
        loss = loss_func(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss.update_state(loss)
    train_metric.update_state(labels, predictions)

@tf.function
def valid_step(model, features, labels):
    predictions = model(features, training=False)
    batch_loss = loss_func(labels, predictions)
    valid_loss.update_state(batch_loss)
    valid_metric.update_state(labels, predictions)

def train_model(model, ds_train, ds_valid, epochs):
    for epoch in tf.range(1, epochs + 1):

        for features, labels in ds_train:
            train_step(model, features, labels)

        for features, labels in ds_valid:
            valid_step(model, features, labels)

        logs = 'Epoch={},Loss:{},Accuracy:{},Valid Loss:{},Valid Accuracy:{}'

        if epoch % 1 == 0:
            printbar()
            tf.print(tf.strings.format(logs,
                (epoch, train_loss.result(), train_metric.result(),
                 valid_loss.result(), valid_metric.result())))
            tf.print("")

        # Reset the accumulated metric states at the end of each epoch.
        train_loss.reset_states()
        valid_loss.reset_states()
        train_metric.reset_states()
        valid_metric.reset_states()

train_model(model, ds_train, ds_test, 10)
```
```
================================================================================13:12:03
Epoch=1,Loss:2.02051544,Accuracy:0.460253835,Valid Loss:1.75700927,Valid Accuracy:0.536954582

================================================================================13:12:09
Epoch=2,Loss:1.510795,Accuracy:0.610665798,Valid Loss:1.55349839,Valid Accuracy:0.616206586

================================================================================13:12:17
Epoch=3,Loss:1.19221532,Accuracy:0.696170092,Valid Loss:1.52315605,Valid Accuracy:0.651380241

================================================================================13:12:23
Epoch=4,Loss:0.90101546,Accuracy:0.766310394,Valid Loss:1.68327653,Valid Accuracy:0.648263574

================================================================================13:12:30
Epoch=5,Loss:0.655430496,Accuracy:0.831329346,Valid Loss:1.90872383,Valid Accuracy:0.641139805

================================================================================13:12:37
Epoch=6,Loss:0.492730737,Accuracy:0.877866864,Valid Loss:2.09966016,Valid Accuracy:0.63223511

================================================================================13:12:44
Epoch=7,Loss:0.391238362,Accuracy:0.904030263,Valid Loss:2.27431226,Valid Accuracy:0.625111282

================================================================================13:12:51
Epoch=8,Loss:0.327761739,Accuracy:0.922066331,Valid Loss:2.42568827,Valid Accuracy:0.617542326

================================================================================13:12:58
Epoch=9,Loss:0.285573095,Accuracy:0.930527747,Valid Loss:2.55942106,Valid Accuracy:0.612644672

================================================================================13:13:05
Epoch=10,Loss:0.255482465,Accuracy:0.936094403,Valid Loss:2.67789412,Valid Accuracy:0.612199485
```
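
Since the loop is fully under our control, it is easy to extend. For example, gradient clipping can be added directly inside the train step; the sketch below reuses the loss_func, optimizer, and metric objects defined above, and the clip norm of 5.0 is an arbitrary value chosen for illustration:

```python
# Variant of train_step with global-norm gradient clipping (illustrative sketch).
@tf.function
def clipped_train_step(model, features, labels):
    with tf.GradientTape() as tape:
        predictions = model(features, training=True)
        loss = loss_func(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    # Clip the global norm of all gradients before applying them;
    # 5.0 is an arbitrary bound for this sketch.
    gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss.update_state(loss)
    train_metric.update_state(labels, predictions)
```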

Please leave comments in the WeChat official account "Python与算法之美" (Elegance of Python and Algorithms) if you want to communicate with the author about the content. The author will try his best to reply, given the limited time available.

You are also welcome to join the group chat with the other readers by replying 加群 (join group) in the WeChat official account.
