Basic TPU Usage

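The core API for distributed TensorFlow training on a TPU is tf.distribute.TPUStrategy. A few lines of code are enough to run distributed training on a TPU, and the same code can easily be moved to single-machine multi-GPU or multi-machine multi-GPU environments. The following shows how to instantiate TPUStrategy: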

    import os
    import tensorflow as tf

    # Locate the TPU by its gRPC address ("IP:port" from COLAB_TPU_ADDR)
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
    tf.config.experimental_connect_to_cluster(resolver)  # connect to the TPU
    tf.tpu.experimental.initialize_tpu_system(resolver)  # initialize it
    strategy = tf.distribute.TPUStrategy(resolver)

In the code above, we first instantiate TPUClusterResolver from the TPU's IP address and port; we then use the resolver to connect to the TPU and initialize it; finally, we instantiate TPUStrategy.
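Because the resolver is built from a plain string, a missing or malformed address only surfaces later as a confusing connection error. The sketch below is a hypothetical helper, tpu_grpc_address (not part of TensorFlow), that assumes the environment variable holds "IP:port" the way Colab's COLAB_TPU_ADDR does, checks it, and returns the grpc:// URL passed as tpu=... above.

```python
import os


def tpu_grpc_address(env_var: str = "COLAB_TPU_ADDR") -> str:
    """Build the gRPC URL for TPUClusterResolver(tpu=...).

    Hypothetical helper: assumes `env_var` holds "IP:port",
    as Colab's COLAB_TPU_ADDR does on a TPU runtime.
    """
    addr = os.environ.get(env_var)
    if not addr:
        # Typical cause: the notebook is not attached to a TPU runtime
        raise RuntimeError(f"{env_var} is not set; is a TPU runtime attached?")
    host, sep, port = addr.rpartition(":")
    if not sep or not host or not port.isdigit():
        raise ValueError(f"expected 'IP:port' in {env_var}, got {addr!r}")
    return "grpc://" + addr
```

With it, the first line of the setup code would read `resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_grpc_address())`, failing fast with a readable message when no TPU is attached.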

The following uses a Fashion MNIST classification task to show how to use a TPU. The source code for this section is available at https://github.com/huan/tensorflow-handbook-tpu .

It is even more convenient to open this example's Jupyter notebook directly in Google Colab and run it there (recommended): https://colab.research.google.com/github/huan/tensorflow-handbook-tpu/blob/master/tensorflow-handbook-tpu-example.ipynb

    import os

    import numpy as np
    import tensorflow as tf

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

    # Add a channel (color) dimension: (N, 28, 28) -> (N, 28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)


    def create_model():
        # A small CNN: one convolution block followed by a softmax classifier
        model = tf.keras.models.Sequential()

        model.add(tf.keras.layers.Conv2D(64, (3, 3), input_shape=x_train.shape[1:]))
        model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
        model.add(tf.keras.layers.Activation('relu'))

        model.add(tf.keras.layers.Flatten())
        model.add(tf.keras.layers.Dense(10))
        model.add(tf.keras.layers.Activation('softmax'))

        return model


    # Connect to the Colab TPU and build the distribution strategy
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

    # Model weights and optimizer state must be created inside the strategy scope
    with strategy.scope():
        model = create_model()
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            metrics=[tf.keras.metrics.sparse_categorical_accuracy])

    model.fit(
        x_train.astype(np.float32), y_train.astype(np.float32),
        epochs=5,
        steps_per_epoch=60,
        validation_data=(x_test.astype(np.float32), y_test.astype(np.float32)),
        validation_freq=5  # evaluate on the test set only after epoch 5
    )
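Inside strategy.scope(), model.fit runs synchronous data-parallel training: each step's global batch is split evenly across the TPU cores (a Colab TPU typically exposes 8), every replica computes gradients on its own shard, and the gradients are combined before one identical update is applied on all replicas. The plain-Python sketch below illustrates that idea on a one-parameter linear regression, combining gradients by averaging. It is a conceptual model under those assumptions, not TensorFlow's implementation, and every helper name in it is hypothetical.

```python
# Conceptual sketch of synchronous data parallelism (NOT TensorFlow internals):
# one global batch, N replicas, gradients averaged ("all-reduce"), one shared update.


def split_batch(batch, num_replicas):
    """Split a global batch into equal per-replica shards."""
    if len(batch) % num_replicas != 0:
        raise ValueError("global batch size must be divisible by num_replicas")
    size = len(batch) // num_replicas
    return [batch[i * size:(i + 1) * size] for i in range(num_replicas)]


def mse_grad(w, shard):
    """Gradient of mean((w * x - y) ** 2) with respect to w over one shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values):
    """Average the per-replica gradients so every replica sees the same value."""
    return sum(values) / len(values)


def sync_step(w, batch, num_replicas, lr=0.01):
    """One synchronous training step across `num_replicas` simulated replicas."""
    shards = split_batch(batch, num_replicas)
    grads = [mse_grad(w, s) for s in shards]  # run in parallel on real hardware
    g = all_reduce_mean(grads)                # combine across replicas
    return w - lr * g                         # identical update on every replica
```

With equal-sized shards, the average of the per-replica mean gradients equals the full-batch gradient, so `sync_step(0.0, batch, num_replicas=8)` produces the same weight as a single-device step on the whole batch, and the replicas stay in lockstep. GPU strategies such as tf.distribute.MirroredStrategy follow the same synchronous all-reduce pattern, which is part of why this TPU code carries over to multi-GPU setups with little change.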

The program above produces the following output:

    Epoch 1/5
    60/60 [==========] - 1s 23ms/step - loss: 12.7235 - accuracy: 0.7156
    Epoch 2/5
    60/60 [==========] - 1s 11ms/step - loss: 0.7600 - accuracy: 0.8598
    Epoch 3/5
    60/60 [==========] - 1s 11ms/step - loss: 0.4443 - accuracy: 0.8830
    Epoch 4/5
    60/60 [==========] - 1s 11ms/step - loss: 0.3401 - accuracy: 0.8972
    Epoch 5/5
    60/60 [==========] - 4s 60ms/step - loss: 0.2867 - accuracy: 0.9072
    10/10 [==========] - 2s 158ms/step
    val_loss: 0.3893 - val_sparse_categorical_accuracy: 0.8848
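Validation metrics appear only once in this log because the code passes validation_freq=5, so Keras evaluates on the test set only at the end of epoch 5. The hypothetical helper below is a small sketch of that documented rule (an integer N means "every N epochs", counting epochs from 1; a list names the exact epochs), which can be handy for predicting when val_* lines will show up.

```python
def validation_epochs(epochs, validation_freq=1):
    """Return the 1-indexed epochs on which validation would run.

    Hypothetical helper mirroring the documented meaning of Keras'
    `validation_freq`: an int N runs validation every N epochs,
    while a list/tuple/set lists the specific epochs to validate on.
    """
    if isinstance(validation_freq, int):
        if validation_freq < 1:
            raise ValueError("validation_freq must be >= 1")
        return [e for e in range(1, epochs + 1) if e % validation_freq == 0]
    wanted = set(validation_freq)
    return [e for e in range(1, epochs + 1) if e in wanted]
```

For the settings used above, `validation_epochs(5, 5)` gives `[5]`, matching the single val_loss line after epoch 5.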