6-5 Model Training Using TPU

Training a model on a TPU in Google Colab requires only six additional lines of code.

In the Colab notebook, select TPU under Edit -> Notebook Settings -> Hardware Accelerator.

Note: the following code only executes on Colab.
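The training code in section 3 reads the TPU address from the `COLAB_TPU_ADDR` environment variable, so it raises a `KeyError` on any runtime where that variable is not set. The following is a minimal sketch of a guard and is not part of the original notebook. The helper names `tpu_address` and `make_strategy` are hypothetical, and falling back to TensorFlow's default strategy is an assumption about what is useful on a non-TPU runtime.

```python
import os


def tpu_address(env=None):
    """Return the gRPC address of the Colab TPU, or None if no TPU is attached."""
    env = os.environ if env is None else env
    addr = env.get("COLAB_TPU_ADDR")
    return "grpc://" + addr if addr else None


def make_strategy(env=None):
    """Return a TPUStrategy when a TPU address is available, else the default strategy."""
    import tensorflow as tf  # imported lazily so tpu_address() works without TensorFlow

    address = tpu_address(env)
    if address is None:
        # Assumption: no TPU runtime, so fall back to the default strategy.
        return tf.distribute.get_strategy()
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=address)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.experimental.TPUStrategy(resolver)
```

With such a helper, `strategy = make_strategy()` could stand in for the explicit resolver setup, and the model would still be built inside `with strategy.scope():`.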

You can use the following Colab notebook for testing (tf_TPU, in Chinese):

https://colab.research.google.com/drive/1XCIhATyE1R7lq6uwFlYlRsUr5d9_-r1s

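    # Colab-only magic that selects TensorFlow 2.x
    %tensorflow_version 2.x
    import tensorflow as tf
    print(tf.__version__)
    # Import only the Keras submodules used below instead of a star import
    from tensorflow.keras import datasets, preprocessing, models, layers, optimizers, losses, metrics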

1. Data Preparation

    MAX_LEN = 300       # pad or truncate every newswire to 300 tokens
    BATCH_SIZE = 32

    (x_train, y_train), (x_test, y_test) = datasets.reuters.load_data()
    x_train = preprocessing.sequence.pad_sequences(x_train, maxlen=MAX_LEN)
    x_test = preprocessing.sequence.pad_sequences(x_test, maxlen=MAX_LEN)

    MAX_WORDS = x_train.max() + 1   # vocabulary size (largest word index + 1)
    CAT_NUM = y_train.max() + 1     # number of topic classes

    ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)) \
        .shuffle(buffer_size=1000).batch(BATCH_SIZE) \
        .prefetch(tf.data.experimental.AUTOTUNE).cache()

    ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)) \
        .shuffle(buffer_size=1000).batch(BATCH_SIZE) \
        .prefetch(tf.data.experimental.AUTOTUNE).cache()
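The `pad_sequences` calls above rely on the Keras defaults, which pad with zeros at the front and also truncate from the front, so the last `MAX_LEN` tokens of a long newswire are kept. The sketch below mimics that default behavior for a single sequence in plain Python. `pad_pre` is a hypothetical helper written for illustration and is not the Keras implementation.

```python
def pad_pre(seq, maxlen, value=0):
    """Mimic pad_sequences defaults for one sequence: pad and truncate at the front."""
    seq = list(seq)[-maxlen:]                    # keep the last `maxlen` tokens
    return [value] * (maxlen - len(seq)) + seq   # left-pad with `value`


short = pad_pre([11, 12, 13], maxlen=5)
long = pad_pre([1, 2, 3, 4, 5, 6, 7], maxlen=5)
print(short)  # [0, 0, 11, 12, 13]
print(long)   # [3, 4, 5, 6, 7]
```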

2. Model Definition

    tf.keras.backend.clear_session()

    def create_model():
        model = models.Sequential()
        model.add(layers.Embedding(MAX_WORDS, 7, input_length=MAX_LEN))
        model.add(layers.Conv1D(filters=64, kernel_size=5, activation="relu"))
        model.add(layers.MaxPool1D(2))
        model.add(layers.Conv1D(filters=32, kernel_size=3, activation="relu"))
        model.add(layers.MaxPool1D(2))
        model.add(layers.Flatten())
        model.add(layers.Dense(CAT_NUM, activation="softmax"))
        return model

    def compile_model(model):
        # The last layer already applies softmax, so the loss receives
        # probabilities rather than logits (from_logits=False).
        model.compile(optimizer=optimizers.Nadam(),
                      loss=losses.SparseCategoricalCrossentropy(from_logits=False),
                      metrics=[metrics.SparseCategoricalAccuracy(),
                               metrics.SparseTopKCategoricalAccuracy(5)])
        return model
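Because the final `Dense` layer uses a softmax activation, the `from_logits` setting of the loss should agree with it. The sketch below, in plain Python and independent of Keras, shows what happens in the idealized case where a perfectly confident and correct softmax output is treated as logits: softmax is effectively applied a second time, and the loss can no longer approach zero. The helper functions are illustrative and only approximate what the Keras loss computes.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, label):
    """Cross-entropy of one sample given class probabilities."""
    return -math.log(probs[label])


NUM_CLASSES = 46  # matches the 46-way Dense output of the model
label = 0
# A perfectly confident, correct prediction coming out of the softmax layer.
probs = [1.0 if i == label else 0.0 for i in range(NUM_CLASSES)]

# Probabilities used directly (the from_logits=False case).
loss_probs = cross_entropy(probs, label)

# Probabilities treated as logits (the from_logits=True case): softmax runs again.
loss_as_logits = cross_entropy(softmax(probs), label)

print(round(loss_as_logits, 4))  # 2.8653
```

Under these assumptions the loss has a floor of ln(e + 45) - 1, about 2.87, for 46 classes. The training log later in this section levels off near 3.26, which is consistent with a run that combined softmax outputs with `from_logits=True`.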

3. Model Training

    # The six additional lines mentioned above
    import os
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)  # tf.distribute.TPUStrategy in TF 2.3+
    with strategy.scope():
        # The model's variables must be created inside the strategy scope
        model = create_model()
        model.summary()
        model = compile_model(model)

    WARNING:tensorflow:TPU system 10.26.134.242:8470 has already been initialized. Reinitializing the TPU can cause previously created variables on TPU to be lost.
    INFO:tensorflow:Initializing the TPU system: 10.26.134.242:8470
    INFO:tensorflow:Clearing out eager caches
    INFO:tensorflow:Finished initializing TPU system.
    INFO:tensorflow:Found TPU system:
    INFO:tensorflow:*** Num TPU Cores: 8
    INFO:tensorflow:*** Num TPU Workers: 1
    INFO:tensorflow:*** Num TPU Cores Per Worker: 8
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
    Model: "sequential"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    embedding (Embedding)        (None, 300, 7)            216874
    _________________________________________________________________
    conv1d (Conv1D)              (None, 296, 64)           2304
    _________________________________________________________________
    max_pooling1d (MaxPooling1D) (None, 148, 64)           0
    _________________________________________________________________
    conv1d_1 (Conv1D)            (None, 146, 32)           6176
    _________________________________________________________________
    max_pooling1d_1 (MaxPooling1 (None, 73, 32)            0
    _________________________________________________________________
    flatten (Flatten)            (None, 2336)              0
    _________________________________________________________________
    dense (Dense)                (None, 46)                107502
    =================================================================
    Total params: 332,856
    Trainable params: 332,856
    Non-trainable params: 0
    _________________________________________________________________

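The output shapes and parameter counts in the summary can be reproduced by hand from the layer definitions. The sketch below does the arithmetic in plain Python. The vocabulary size is not printed anywhere in this section, so it is inferred here from the Embedding row (216874 / 7), and `conv1d` is a hypothetical helper for a stride-1 convolution with "valid" padding, which is the Keras default.

```python
MAX_LEN = 300
EMBED_DIM = 7
VOCAB = 216874 // EMBED_DIM   # vocabulary size implied by the Embedding row
NUM_CLASSES = 46


def conv1d(length, in_ch, filters, kernel):
    """Return (output_length, param_count) for a 'valid' Conv1D with stride 1."""
    return length - kernel + 1, (kernel * in_ch + 1) * filters


embedding_params = VOCAB * EMBED_DIM
len1, conv1_params = conv1d(MAX_LEN, EMBED_DIM, 64, 5)   # (296, 2304)
len1 //= 2                                                # MaxPool1D(2) -> 148
len2, conv2_params = conv1d(len1, 64, 32, 3)              # (146, 6176)
len2 //= 2                                                # MaxPool1D(2) -> 73
flat = len2 * 32                                          # Flatten -> 2336
dense_params = (flat + 1) * NUM_CLASSES                   # weights + biases

total = embedding_params + conv1_params + conv2_params + dense_params
print(total)  # 332856
```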
    history = model.fit(ds_train, validation_data=ds_test, epochs=10)

    Train for 281 steps, validate for 71 steps
    Epoch 1/10
    281/281 [==============================] - 12s 43ms/step - loss: 3.4466 - sparse_categorical_accuracy: 0.4332 - sparse_top_k_categorical_accuracy: 0.7180 - val_loss: 3.3179 - val_sparse_categorical_accuracy: 0.5352 - val_sparse_top_k_categorical_accuracy: 0.7195
    Epoch 2/10
    281/281 [==============================] - 6s 20ms/step - loss: 3.3251 - sparse_categorical_accuracy: 0.5405 - sparse_top_k_categorical_accuracy: 0.7302 - val_loss: 3.3082 - val_sparse_categorical_accuracy: 0.5463 - val_sparse_top_k_categorical_accuracy: 0.7235
    Epoch 3/10
    281/281 [==============================] - 6s 20ms/step - loss: 3.2961 - sparse_categorical_accuracy: 0.5729 - sparse_top_k_categorical_accuracy: 0.7280 - val_loss: 3.3026 - val_sparse_categorical_accuracy: 0.5499 - val_sparse_top_k_categorical_accuracy: 0.7217
    Epoch 4/10
    281/281 [==============================] - 5s 19ms/step - loss: 3.2751 - sparse_categorical_accuracy: 0.5924 - sparse_top_k_categorical_accuracy: 0.7276 - val_loss: 3.2957 - val_sparse_categorical_accuracy: 0.5543 - val_sparse_top_k_categorical_accuracy: 0.7217
    Epoch 5/10
    281/281 [==============================] - 5s 19ms/step - loss: 3.2655 - sparse_categorical_accuracy: 0.6008 - sparse_top_k_categorical_accuracy: 0.7290 - val_loss: 3.3022 - val_sparse_categorical_accuracy: 0.5490 - val_sparse_top_k_categorical_accuracy: 0.7231
    Epoch 6/10
    281/281 [==============================] - 5s 19ms/step - loss: 3.2616 - sparse_categorical_accuracy: 0.6041 - sparse_top_k_categorical_accuracy: 0.7295 - val_loss: 3.3015 - val_sparse_categorical_accuracy: 0.5503 - val_sparse_top_k_categorical_accuracy: 0.7235
    Epoch 7/10
    281/281 [==============================] - 6s 21ms/step - loss: 3.2595 - sparse_categorical_accuracy: 0.6059 - sparse_top_k_categorical_accuracy: 0.7322 - val_loss: 3.3064 - val_sparse_categorical_accuracy: 0.5454 - val_sparse_top_k_categorical_accuracy: 0.7266
    Epoch 8/10
    281/281 [==============================] - 6s 21ms/step - loss: 3.2591 - sparse_categorical_accuracy: 0.6063 - sparse_top_k_categorical_accuracy: 0.7327 - val_loss: 3.3025 - val_sparse_categorical_accuracy: 0.5481 - val_sparse_top_k_categorical_accuracy: 0.7231
    Epoch 9/10
    281/281 [==============================] - 5s 19ms/step - loss: 3.2588 - sparse_categorical_accuracy: 0.6062 - sparse_top_k_categorical_accuracy: 0.7332 - val_loss: 3.2992 - val_sparse_categorical_accuracy: 0.5521 - val_sparse_top_k_categorical_accuracy: 0.7257
    Epoch 10/10
    281/281 [==============================] - 5s 18ms/step - loss: 3.2577 - sparse_categorical_accuracy: 0.6073 - sparse_top_k_categorical_accuracy: 0.7363 - val_loss: 3.2981 - val_sparse_categorical_accuracy: 0.5516 - val_sparse_top_k_categorical_accuracy: 0.7306
    CPU times: user 18.9 s, sys: 3.86 s, total: 22.7 s
    Wall time: 1min 1s
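The step counts in the first line of the log follow from the dataset sizes and the batch size. The sketch below reproduces them. The sample counts are an assumption, namely the sizes of the default Keras Reuters train/test split, since the section does not print them. The per-replica figure assumes that the batch size given to the dataset is the global batch, which the strategy divides among the 8 TPU cores reported during initialization.

```python
import math

BATCH_SIZE = 32
NUM_REPLICAS = 8        # "Num TPU Cores: 8" in the initialization log
TRAIN_SAMPLES = 8982    # assumed size of the default Reuters training split
TEST_SAMPLES = 2246     # assumed size of the default Reuters test split

# The last, partial batch still counts as a step, hence the ceiling.
train_steps = math.ceil(TRAIN_SAMPLES / BATCH_SIZE)
val_steps = math.ceil(TEST_SAMPLES / BATCH_SIZE)
per_replica_batch = BATCH_SIZE // NUM_REPLICAS

print(train_steps, val_steps, per_replica_batch)  # 281 71 4
```

Under that assumption each core processes only 4 samples per step, so a larger global batch size may make better use of the TPU.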

If you would like to discuss this content with the author, please leave a comment on the WeChat official account “Python与算法之美” (Elegance of Python and Algorithms). The author will try to reply as time permits.

You are also welcome to join the readers' group chat by replying 加群 (join group) to the WeChat official account.
