模型导出ONNX协议
一、简介
ONNX (Open Neural Network Exchange) 是针对机器学习所设计的开源文件格式,用于存储训练好的模型。它使得不同的人工智能框架可以采用相同格式存储模型并交互。通过ONNX格式,Paddle模型可以使用OpenVINO、ONNX Runtime等框架进行推理。
Paddle转ONNX协议由 paddle2onnx 实现,下面介绍如何将Paddle模型转换为ONNX模型并验证正确性。
本教程涉及的示例代码,可点击 IPython 获取, 除Paddle以外,还需安装以下依赖:
pip install paddle2onnx onnx onnxruntime // -i https://mirror.baidu.com/pypi/simple 如果网速不好,可以使用其他源下载
二、模型导出为ONNX协议
2.1 动态图导出ONNX协议
Paddle动态图模型转换为ONNX协议,首先会将Paddle的动态图 paddle.nn.Layer
转换为静态图, 详细原理可以参考 动态图转静态图 。然后依照ONNX的算子协议,将Paddle的算子一一映射为ONNX的算子。动态图转换ONNX调用 paddle.onnx.export()
接口即可实现,该接口通过 input_spec
参数为模型指定输入的形状和数据类型,支持 Tensor
或 InputSpec
,其中 InputSpec
支持动态的shape。
关于 paddle.onnx.export
接口更详细的使用方法,请参考 API 。
import paddle
from paddle import nn
from paddle.static import InputSpec
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear = nn.Linear(784, 10)
def forward(self, x):
return self._linear(x)
# export to ONNX
layer = LinearNet()
save_path = 'onnx.save/linear_net'
x_spec = InputSpec([None, 784], 'float32', 'x')
paddle.onnx.export(layer, save_path, input_spec=[x_spec])
2.2 静态图导出ONNX协议
Paddle 2.0以后将主推动态图组网方式,如果您的模型来自于旧版本的Paddle,使用静态图组网,请参考paddle2onnx的 使用文档 和 示例 。
三、ONNX模型的验证
ONNX官方工具包提供了API可验证模型的正确性,主要包括两个方面,一是算子是否符合对应版本的协议,二是网络结构是否完整。
# check by ONNX
import onnx
onnx_file = save_path + '.onnx'
onnx_model = onnx.load(onnx_file)
onnx.checker.check_model(onnx_model)
print('The model is checked!')
如果模型检查失败,请到 Paddle 或 paddle2onnx 提出Issue,我们会跟进相应的问题。
四、ONNXRuntime推理
本节介绍使用ONNXRuntime对已转换的Paddle模型进行推理,并与使用Paddle进行推理的结果进行对比。
import numpy as np
import onnxruntime
x = np.random.random((2, 784)).astype('float32')
# predict by ONNX Runtime
ort_sess = onnxruntime.InferenceSession(onnx_file)
ort_inputs = {ort_sess.get_inputs()[0].name: x}
ort_outs = ort_sess.run(None, ort_inputs)
print("Exported model has been predicted by ONNXRuntime!")
# predict by Paddle
layer.eval()
paddle_outs = layer(x)
# compare ONNX Runtime and Paddle results
np.testing.assert_allclose(ort_outs[0], paddle_outs.numpy(), rtol=1.0, atol=1e-05)
print("The difference of results between ONNXRuntime and Paddle looks good!")