Keras Sequential save方法(Jinpeng)

我们以keras模型训练和保存为例进行讲解,如下是keras官方的mnist模型训练样例。

源码地址:

  1. https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py

以上代码,是基于keras的Sequential构建了多层的卷积神经网络,并进行训练。

为了方便起见可使用如下命令拷贝到本地:

  1. curl -LO https://raw.githubusercontent.com/keras-team/keras/master/examples/mnist_cnn.py

然后,在最后加上如下一行代码(主要是对keras训练完毕的模型进行保存):

  1. model.save('mnist_cnn.h5')

在终端中执行mnist_cnn.py文件,如下:

  1. python mnist_cnn.py

警告

该过程需要连接网络获取 mnist.npz 文件(https://s3.amazonaws.com/img-datasets/mnist.npz),会被保存到 $HOME/.keras/datasets/ 。如果网络连接存在问题,可以通过其他方式获取 mnist.npz 后,直接保存到该目录即可。

执行过程会比较久,执行结束后,会在当前目录产生 mnist_cnn.h5 文件(HDF5格式),就是keras训练后的模型,其中已经包含了训练后的模型结构和权重等信息。

在服务器端,可以直接通过 keras.models.load_model("mnist_cnn.h5") 加载,然后进行推理;在移动设备需要将HDF5模型文件转换为TensorFlow Lite的格式,然后通过相应平台的Interpreter加载,然后进行推理。