tf.train.Checkpoint :变量的保存与恢复

警告

Checkpoint只保存模型的参数,不保存模型的计算过程,因此一般用于在具有模型源代码的时候恢复之前训练好的模型参数。如果需要导出模型(无需源代码也能运行模型),请参考 “部署”章节中的SavedModel

很多时候,我们希望在模型训练完成后能将训练好的参数(变量)保存起来。在需要使用模型的其他地方载入模型和参数,就能直接得到训练好的模型。可能你第一个想到的是用Python的序列化模块 pickle 存储 model.variables。但不幸的是,TensorFlow的变量类型 ResourceVariable 并不能被序列化。

好在TensorFlow提供了 tf.train.Checkpoint 这一强大的变量保存与恢复类,可以使用其 save()restore() 方法将TensorFlow中所有包含Checkpointable State的对象进行保存和恢复。具体而言,tf.keras.optimizertf.Variabletf.keras.Layer 或者 tf.keras.Model 实例都可以被保存。其使用方法非常简单,我们首先声明一个Checkpoint:

  1. checkpoint = tf.train.Checkpoint(model=model)

这里 tf.train.Checkpoint() 接受的初始化参数比较特殊,是一个 **kwargs 。具体而言,是一系列的键值对,键名可以随意取,值为需要保存的对象。例如,如果我们希望保存一个继承 tf.keras.Model 的模型实例 model 和一个继承 tf.train.Optimizer 的优化器 optimizer ,我们可以这样写:

  1. checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)

这里 myAwesomeModel 是我们为待保存的模型 model 所取的任意键名。注意,在恢复变量的时候,我们还将使用这一键名。

接下来,当模型训练完成需要保存的时候,使用:

  1. checkpoint.save(save_path_with_prefix)

就可以。 save_path_with_prefix 是保存文件的目录+前缀。

注解

例如,在源代码目录建立一个名为save的文件夹并调用一次 checkpoint.save('./save/model.ckpt') ,我们就可以在可以在save目录下发现名为 checkpointmodel.ckpt-1.indexmodel.ckpt-1.data-00000-of-00001 的三个文件,这些文件就记录了变量信息。checkpoint.save() 方法可以运行多次,每运行一次都会得到一个.index文件和.data文件,序号依次累加。

当在其他地方需要为模型重新载入之前保存的参数时,需要再次实例化一个checkpoint,同时保持键名的一致。再调用checkpoint的restore方法。就像下面这样:

  1. model_to_be_restored = MyModel() # 待恢复参数的同一模型
  2. checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored) # 键名保持为“myAwesomeModel”
  3. checkpoint.restore(save_path_with_prefix_and_index)

即可恢复模型变量。 save_path_with_prefix_and_index 是之前保存的文件的目录+前缀+编号。例如,调用 checkpoint.restore('./save/model.ckpt-1') 就可以载入前缀为 model.ckpt ,序号为1的文件来恢复模型。

当保存了多个文件时,我们往往想载入最近的一个。可以使用 tf.train.latest_checkpoint(save_path) 这个辅助函数返回目录下最近一次checkpoint的文件名。例如如果save目录下有 model.ckpt-1.indexmodel.ckpt-10.index 的10个保存文件, tf.train.latest_checkpoint('./save') 即返回 ./save/model.ckpt-10

总体而言,恢复与保存变量的典型代码框架如下:

  1. # train.py 模型训练阶段
  2.  
  3. model = MyModel()
  4. # 实例化Checkpoint,指定保存对象为model(如果需要保存Optimizer的参数也可加入)
  5. checkpoint = tf.train.Checkpoint(myModel=model)
  6. # ...(模型训练代码)
  7. # 模型训练完毕后将参数保存到文件(也可以在模型训练过程中每隔一段时间就保存一次)
  8. checkpoint.save('./save/model.ckpt')
  1. # test.py 模型使用阶段
  2.  
  3. model = MyModel()
  4. checkpoint = tf.train.Checkpoint(myModel=model) # 实例化Checkpoint,指定恢复对象为model
  5. checkpoint.restore(tf.train.latest_checkpoint('./save')) # 从文件恢复模型参数
  6. # 模型使用代码

注解

tf.train.Checkpoint 与以前版本常用的 tf.train.Saver 相比,强大之处在于其支持在Eager Execution下“延迟”恢复变量。具体而言,当调用了 checkpoint.restore() ,但模型中的变量还没有被建立的时候,Checkpoint可以等到变量被建立的时候再进行数值的恢复。Eager Execution下,模型中各个层的初始化和变量的建立是在模型第一次被调用的时候才进行的(好处在于可以根据输入的张量形状而自动确定变量形状,无需手动指定)。这意味着当模型刚刚被实例化的时候,其实里面还一个变量都没有,这时候使用以往的方式去恢复变量数值是一定会报错的。比如,你可以试试在train.py调用 tf.keras.Modelsave_weight() 方法保存model的参数,并在test.py中实例化model后立即调用 load_weight() 方法,就会出错,只有当调用了一遍model之后再运行 load_weight() 方法才能得到正确的结果。可见, tf.train.Checkpoint 在这种情况下可以给我们带来相当大的便利。另外, tf.train.Checkpoint 同时也支持Graph Execution模式。

最后提供一个实例,以前章的 多层感知机模型 为例展示模型变量的保存和载入:

  1. import tensorflow as tf
  2. import numpy as np
  3. import argparse
  4. from zh.model.mnist.mlp import MLP
  5. from zh.model.utils import MNISTLoader
  6.  
  7. parser = argparse.ArgumentParser(description='Process some integers.')
  8. parser.add_argument('--mode', default='train', help='train or test')
  9. parser.add_argument('--num_epochs', default=1)
  10. parser.add_argument('--batch_size', default=50)
  11. parser.add_argument('--learning_rate', default=0.001)
  12. args = parser.parse_args()
  13. data_loader = MNISTLoader()
  14.  
  15.  
  16. def train():
  17. model = MLP()
  18. optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
  19. num_batches = int(data_loader.num_train_data // args.batch_size * args.num_epochs)
  20. checkpoint = tf.train.Checkpoint(myAwesomeModel=model) # 实例化Checkpoint,设置保存对象为model
  21. for batch_index in range(1, num_batches+1):
  22. X, y = data_loader.get_batch(args.batch_size)
  23. with tf.GradientTape() as tape:
  24. y_pred = model(X)
  25. loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
  26. loss = tf.reduce_mean(loss)
  27. print("batch %d: loss %f" % (batch_index, loss.numpy()))
  28. grads = tape.gradient(loss, model.variables)
  29. optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))
  30. if batch_index % 100 == 0: # 每隔100个Batch保存一次
  31. path = checkpoint.save('./save/model.ckpt') # 保存模型参数到文件
  32. print("model saved to %s" % path)
  33.  
  34.  
  35. def test():
  36. model_to_be_restored = MLP()
  37. # 实例化Checkpoint,设置恢复对象为新建立的模型model_to_be_restored
  38. checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)
  39. checkpoint.restore(tf.train.latest_checkpoint('./save')) # 从文件恢复模型参数
  40. y_pred = np.argmax(model_to_be_restored.predict(data_loader.test_data), axis=-1)
  41. print("test accuracy: %f" % (sum(y_pred == data_loader.test_label) / data_loader.num_test_data))
  42.  
  43.  
  44. if __name__ == '__main__':
  45. if args.mode == 'train':
  46. train()
  47. if args.mode == 'test':
  48. test()

在代码目录下建立save文件夹并运行代码进行训练后,save文件夹内将会存放每隔100个batch保存一次的模型变量数据。在命令行参数中加入 —mode=test 并再次运行代码,将直接使用最后一次保存的变量值恢复模型并在测试集上测试模型性能,可以直接获得95%左右的准确率。

使用 tf.train.CheckpointManager 删除旧的Checkpoint以及自定义文件编号

在模型的训练过程中,我们往往每隔一定步数保存一个Checkpoint并进行编号。不过很多时候我们会有这样的需求:

  • 在长时间的训练后,程序会保存大量的Checkpoint,但我们只想保留最后的几个Checkpoint;

  • Checkpoint默认从1开始编号,每次累加1,但我们可能希望使用别的编号方式(例如使用当前Batch的编号作为文件编号)。

这时,我们可以使用TensorFlow的 tf.train.CheckpointManager 来实现以上需求。具体而言,在定义Checkpoint后接着定义一个CheckpointManager:

  1. checkpoint = tf.train.Checkpoint(model=model)
  2. manager = tf.train.CheckpointManager(checkpoint, directory='./save', checkpoint_name='model.ckpt', max_to_keep=k)

此处, directory 参数为文件保存的路径, checkpoint_name 为文件名前缀(不提供则默认为 ckpt ), max_to_keep 为保留的Checkpoint数目。

在需要保存模型的时候,我们直接使用 manager.save() 即可。如果我们希望自行指定保存的Checkpoint的编号,则可以在保存时加入 checkpoint_number 参数。例如 manager.save(checkpoint_number=100)

以下提供一个实例,展示使用CheckpointManager限制仅保留最后三个Checkpoint文件,并使用batch的编号作为Checkpoint的文件编号。

  1. import tensorflow as tf
  2. import numpy as np
  3. import argparse
  4. from zh.model.mnist.mlp import MLP
  5. from zh.model.utils import MNISTLoader
  6.  
  7. parser = argparse.ArgumentParser(description='Process some integers.')
  8. parser.add_argument('--mode', default='train', help='train or test')
  9. parser.add_argument('--num_epochs', default=1)
  10. parser.add_argument('--batch_size', default=50)
  11. parser.add_argument('--learning_rate', default=0.001)
  12. args = parser.parse_args()
  13. data_loader = MNISTLoader()
  14.  
  15.  
  16. def train():
  17. model = MLP()
  18. optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
  19. num_batches = int(data_loader.num_train_data // args.batch_size * args.num_epochs)
  20. checkpoint = tf.train.Checkpoint(myAwesomeModel=model)
  21. # 使用tf.train.CheckpointManager管理Checkpoint
  22. manager = tf.train.CheckpointManager(checkpoint, directory='./save', max_to_keep=3)
  23. for batch_index in range(1, num_batches):
  24. X, y = data_loader.get_batch(args.batch_size)
  25. with tf.GradientTape() as tape:
  26. y_pred = model(X)
  27. loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
  28. loss = tf.reduce_mean(loss)
  29. print("batch %d: loss %f" % (batch_index, loss.numpy()))
  30. grads = tape.gradient(loss, model.variables)
  31. optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))
  32. if batch_index % 100 == 0:
  33. # 使用CheckpointManager保存模型参数到文件并自定义编号
  34. path = manager.save(checkpoint_number=batch_index)
  35. print("model saved to %s" % path)
  36.  
  37.  
  38. def test():
  39. model_to_be_restored = MLP()
  40. checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)
  41. checkpoint.restore(tf.train.latest_checkpoint('./save'))
  42. y_pred = np.argmax(model_to_be_restored.predict(data_loader.test_data), axis=-1)
  43. print("test accuracy: %f" % (sum(y_pred == data_loader.test_label) / data_loader.num_test_data))
  44.  
  45.  
  46. if __name__ == '__main__':
  47. if args.mode == 'train':
  48. train()
  49. if args.mode == 'test':
  50. test()