使用SavedModel完整导出模型

在部署模型时,我们的第一步往往是将训练好的整个模型完整导出为一系列标准格式的文件,然后即可在不同的平台上部署模型文件。这时,TensorFlow为我们提供了SavedModel这一格式。与前面介绍的Checkpoint不同,SavedModel包含了一个TensorFlow程序的完整信息: 不仅包含参数的权值,还包含计算的流程(即计算图) 。当模型导出为SavedModel文件时,无需建立模型的源代码即可再次运行模型,这使得SavedModel尤其适用于模型的分享和部署。后文的TensorFlow Serving(服务器端部署模型)、TensorFlow Lite(移动端部署模型)以及TensorFlow.js都会用到这一格式。

Keras模型均可方便地导出为SavedModel格式。不过需要注意的是,因为SavedModel基于计算图,所以对于使用继承 tf.keras.Model 类建立的Keras模型,其需要导出到SavedModel格式的方法(比如 call )都需要使用 @tf.function 修饰( @tf.function 的使用方式见 前文 )。然后,假设我们有一个名为 model 的Keras模型,使用下面的代码即可将模型导出为SavedModel:

  1. tf.saved_model.save(model, "保存的目标文件夹名称")

在需要载入SavedModel文件时,使用

  1. model = tf.saved_model.load("保存的目标文件夹名称")

即可。

提示

对于使用继承 tf.keras.Model 类建立的Keras模型 model ,使用SavedModel载入后将无法使用 model() 直接进行推断,而需要使用 model.call()

以下是一个简单的示例,将 前文MNIST手写体识别的模型 进行导出和导入。

导出模型到 saved/1 文件夹:

  1. import tensorflow as tf
  2. from zh.model.utils import MNISTLoader
  3.  
  4. num_epochs = 1
  5. batch_size = 50
  6. learning_rate = 0.001
  7.  
  8. model = tf.keras.models.Sequential([
  9. tf.keras.layers.Flatten(),
  10. tf.keras.layers.Dense(100, activation=tf.nn.relu),
  11. tf.keras.layers.Dense(10),
  12. tf.keras.layers.Softmax()
  13. ])
  14.  
  15. data_loader = MNISTLoader()
  16. model.compile(
  17. optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
  18. loss=tf.keras.losses.sparse_categorical_crossentropy,
  19. metrics=[tf.keras.metrics.sparse_categorical_accuracy]
  20. )
  21. model.fit(data_loader.train_data, data_loader.train_label, epochs=num_epochs, batch_size=batch_size)
  22. tf.saved_model.save(model, "saved/1")

saved/1 中的模型导入并测试性能:

  1. import tensorflow as tf
  2. from zh.model.utils import MNISTLoader
  3.  
  4. batch_size = 50
  5.  
  6. model = tf.saved_model.load("saved/1")
  7. data_loader = MNISTLoader()
  8. sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
  9. num_batches = int(data_loader.num_test_data // batch_size)
  10. for batch_index in range(num_batches):
  11. start_index, end_index = batch_index * batch_size, (batch_index + 1) * batch_size
  12. y_pred = model(data_loader.test_data[start_index: end_index])
  13. sparse_categorical_accuracy.update_state(y_true=data_loader.test_label[start_index: end_index], y_pred=y_pred)
  14. print("test accuracy: %f" % sparse_categorical_accuracy.result())

输出:

  1. test accuracy: 0.952000

使用继承 tf.keras.Model 类建立的Keras模型同样可以以相同方法导出,唯须注意 call 方法需要以 @tf.function 修饰,以转化为SavedModel支持的计算图,代码如下:

  1. class MLP(tf.keras.Model):
  2. def __init__(self):
  3. super().__init__()
  4. self.flatten = tf.keras.layers.Flatten()
  5. self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)
  6. self.dense2 = tf.keras.layers.Dense(units=10)
  7.  
  8. @tf.function
  9. def call(self, inputs): # [batch_size, 28, 28, 1]
  10. x = self.flatten(inputs) # [batch_size, 784]
  11. x = self.dense1(x) # [batch_size, 100]
  12. x = self.dense2(x) # [batch_size, 10]
  13. output = tf.nn.softmax(x)
  14. return output
  15.  
  16. model = MLP()
  17. ...

模型导入并测试性能的过程也相同,唯须注意模型推断时需要显式调用 call 方法,即使用:

  1. ...
  2. y_pred = model.call(data_loader.test_data[start_index: end_index])
  3. ...