pyTorch模型转换为onnx格式相对来说是比较简单的,因为pyTorch提供了torch.onnx
模块用于实现这一转换过程。
基本实现
将pyTorch模型转换为onnx格式的基本流程和转换为TorchScript格式是一样的,这是因为无论转换为TorchScript还是onnx都需要对torch.nn.Module
进行tracing操作以记录模型所有的运算。
torch.onnx.export 函数说明
两者的不同之处在于实现转换功能的函数,转换为onnx格式需要用函数torch.onnx.export
[2]实现。函数的用法pyTorch官方有详细的介绍,不过我觉得还是有必要将重点用更简单的语言说一下。函数格式如下:
1 | torch.onnx.export(model, args, f, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, \ |
看起来一堆参数,但是实际上必须的参数只有三个:
model
: 要进行转换的模型,通常是加载好权重的torch.nn.Module
实例。 多说一句,model
也可以是torch.jit.ScriptModule
,实际上torch.onnx.export
转换时处理的就是torch.jit.ScriptModule
实例,如果输入是torch.nn.Module
实例,那么就会首先被转换为torch.jit.ScriptModule
。args
: 模型model
的输入参数,可以是元组或者torch.Tensor
。f
: 简单理解是导出文件的名字,此时需要是一个包含文件名的字符串。
除了这三个必要的参数,还有一些非常重要、常用的参数:
input_names
: 运算图输入节点的名字,类型是字符串列表output_names
: 运算图输出节点的名字,类型是字符串列表dynamic_axes
: 可以控制将导出的模型设置为支持输入动态维度opset_version
: 运算集版本
torch.onnx.export 函数使用
下边通过一个简单的例子说明一下torch.onnx.export
的用法。
1 | import torch |
根据我们对于这个函数的认知,我们可以知道:这段代码最后得到了一个onnx模型,名字为resnet18.onnx
,转换用到的操作集版本是11,resnet18.onnx
的输入节点名字是image
、输出节点名字是pred
,输出和输出的batch
size都是动态的,但每张输入图像的维度是固定的,那就是
为了验证我们转出来的onnx模型是不是跟我们预想的一致,还可以把onnx模型上传到https://netron.app 实现可视化查看。
完整的转换实例脚本
1 | import argparse |
参考
[1]. https://onnxruntime.ai/docs/get-started/with-python.html