Basic TPU Usage
The core API for distributed TensorFlow training on TPUs is tf.distribute.TPUStrategy. With only a few lines of code you can run distributed training on a TPU, and the same code can easily be migrated to single-machine multi-GPU or multi-machine multi-GPU environments. The following shows how to instantiate a TPUStrategy:
```python
# Resolve the TPU address provided by the Colab runtime and build a cluster resolver
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
# Connect to the remote TPU cluster and initialize the TPU system
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
# Create the distribution strategy backed by the TPU
strategy = tf.distribute.experimental.TPUStrategy(resolver)
```
In the code above, we first instantiate a TPUClusterResolver from the TPU's IP address and port; we then connect to the TPU through the resolver and initialize the TPU system; finally, we instantiate the TPUStrategy.
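As a quick sanity check (not part of the original example), you can list the TPU devices the runtime sees and the number of replicas the strategy will distribute each batch across. This is a minimal sketch; on older TensorFlow 2.x releases the device-listing call may live under `tf.config.experimental.list_logical_devices` instead:

```python
# Sanity-check sketch: assumes the TPU initialization code above has already run
print("TPU devices:", tf.config.list_logical_devices('TPU'))
# Number of TPU cores the strategy synchronizes across (typically 8 on Colab)
print("Number of replicas:", strategy.num_replicas_in_sync)
```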
The following uses the Fashion MNIST classification task to demonstrate how to use a TPU. The source code for this section is available at https://github.com/huan/tensorflow-handbook-tpu .
Even more conveniently, you can open and run this example directly as a Jupyter notebook on Google Colab (recommended): https://colab.research.google.com/github/huan/tensorflow-handbook-tpu/blob/master/tensorflow-handbook-tpu-example.ipynb
```python
import tensorflow as tf
import numpy as np
import os

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# Add an empty colour-channel dimension: (28, 28) -> (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)

def create_model():
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Conv2D(64, (3, 3), input_shape=x_train.shape[1:]))
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
    model.add(tf.keras.layers.Activation('relu'))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(10))
    model.add(tf.keras.layers.Activation('softmax'))
    return model

# Connect to the TPU and create the distribution strategy
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)

# Build and compile the model inside the strategy scope so that its variables
# are created on the TPU devices
with strategy.scope():
    model = create_model()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=[tf.keras.metrics.sparse_categorical_accuracy])

model.fit(
    x_train.astype(np.float32), y_train.astype(np.float32),
    epochs=5,
    steps_per_epoch=60,
    validation_data=(x_test.astype(np.float32), y_test.astype(np.float32)),
    validation_freq=5
)
```
The output of running the program above is:
```
Epoch 1/5
60/60 [==========] - 1s 23ms/step - loss: 12.7235 - accuracy: 0.7156
Epoch 2/5
60/60 [==========] - 1s 11ms/step - loss: 0.7600 - accuracy: 0.8598
Epoch 3/5
60/60 [==========] - 1s 11ms/step - loss: 0.4443 - accuracy: 0.8830
Epoch 4/5
60/60 [==========] - 1s 11ms/step - loss: 0.3401 - accuracy: 0.8972
Epoch 5/5
60/60 [==========] - 4s 60ms/step - loss: 0.2867 - accuracy: 0.9072
10/10 [==========] - 2s 158ms/step
10/10 [==========] - 2s 158ms/step
val_loss: 0.3893 - val_sparse_categorical_accuracy: 0.8848
```
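As noted at the beginning of this section, the same Keras training code can be migrated to GPUs simply by swapping the strategy. The following is a minimal sketch of that migration (not part of the original example), reusing the create_model function defined above and assuming a single machine with one or more GPUs:

```python
# Migration sketch: MirroredStrategy handles single-machine multi-GPU training;
# for multi-machine setups, a MultiWorkerMirroredStrategy can be used instead.
# The model-building, compile, and fit code stays unchanged.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = create_model()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=[tf.keras.metrics.sparse_categorical_accuracy])

model.fit(x_train.astype(np.float32), y_train.astype(np.float32), epochs=5)
```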