pyTorch模型转onnx

pyTorch模型转换为onnx格式相对来说是比较简单的,因为pyTorch提供了torch.onnx模块用于实现这一转换过程。

基本实现

将pyTorch模型转换为onnx格式的基本流程和转换为TorchScript格式是一样的,这是因为无论转换为TorchScript还是onnx都需要对torch.nn.Module进行tracing操作以记录模型所有的运算。

torch.onnx.export 函数说明

两者的不同之处在于实现转换功能的函数,转换为onnx格式需要用函数torch.onnx.export[2]实现。函数的用法pyTorch官方有详细的介绍,不过我觉得还是有必要将重点用更简单的语言说一下。函数格式如下:

1
2
3
4
torch.onnx.export(model, args, f, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, \
input_names=None, output_names=None, operator_export_type=<OperatorExportTypes.ONNX: 0>, \
opset_version=None, do_constant_folding=True, dynamic_axes=None, keep_initializers_as_inputs=None,
custom_opsets=None, export_modules_as_functions=False)

看起来一堆参数,但是实际上必须的参数只有三个:

  • model: 要进行转换的模型,通常是加载好权重的torch.nn.Module实例。 多说一句,model也可以是torch.jit.ScriptModule,实际上torch.onnx.export转换时处理的就是torch.jit.ScriptModule实例,如果输入是torch.nn.Module实例,那么就会首先被转换为torch.jit.ScriptModule

  • args: 模型model的输入参数,可以是元组或者torch.Tensor

  • f: 简单理解是导出文件的名字,此时需要是一个包含文件名的字符串。

除了这三个必要的参数,还有一些非常重要、常用的参数:

  • input_names: 运算图输入节点的名字,类型是字符串列表
  • output_names : 运算图输出节点的名字,类型是字符串列表
  • dynamic_axes: 可以控制将导出的模型设置为支持输入动态维度
  • opset_version: 运算集版本

torch.onnx.export 函数使用

下边通过一个简单的例子说明一下torch.onnx.export的用法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torchvision

# 创建模型实例,对于自定义的模型需要加载权重
model = torchvision.models.resnet18(pretrained=True).cuda()

# 评估模式
model.eval()

# 创建一个示例输出,维度和forward函数的输入一致
dummy_input = torch.rand(1, 3, 224, 224, device='cuda')

# 执行转换并将结果保存为`resnet18.onnx`
torch.onnx.export(
model,
dummy_input,
"resnet18.onnx",
input_names = ["image"],
output_names = ["pred"],
dynamic_axes = {"image": {0: "batch"},
"label": {0: "batch"}},
opset_version=11
)

根据我们对于这个函数的认知,我们可以知道:这段代码最后得到了一个onnx模型,名字为resnet18.onnx,转换用到的操作集版本是11,resnet18.onnx的输入节点名字是image、输出节点名字是pred,输出和输出的batch size都是动态的,但每张输入图像的维度是固定的,那就是

为了验证我们转出来的onnx模型是不是跟我们预想的一致,还可以把onnx模型上传到https://netron.app 实现可视化查看。

完整的转换实例脚本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import argparse
import torch

# === import your model ===
from model import MyModel

parser = argparse.ArgumentParser("export model as onnx")
parser.add_argument('--checkpoint', type=str, required=True)
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument(
"--input-shape",
type=str,
default="224,224",
help="specify the input shape for inference")
parser.add_argument("--dynamic", action="store_true", help="whether the input shape should be dynamic")
parser.add_argument("--file-name", type=str, default="model.onnx", help="onnx file name")
parser.add_argument("--input-name", type=str, default="image", help="name of input node")
parser.add_argument("--output-name", type=str, default="pred", help="name of output node")
parser.add_argument("--opset", type=int, default=11, help="onnx opset version")

args = parser.parse_args()

model = MyModel().cuda()
checkpoint = torch.load(args.checkpoint)
model.load_state_dict(checkpoint)

model.eval()

input_shape = tuple(map(int, args.input_shape.split(",")))
dummy_input = torch.randn(args.batch_size, 6, *input_shape, device="cuda")

torch.onnx.export(
model,
dummy_input,
args.file_name,
input_names = [args.input_name],
output_names = [args.output_name],
dynamic_axes = {args.input_name: {0: "batch"},
args.output_name: {0: "batch"}} if args.dynamic else None,
opset_version=args.opset,
)

print("save torchscript as " + args.file_name)

参考

[1]. https://onnxruntime.ai/docs/get-started/with-python.html

[2]. https://pytorch.org/docs/stable/onnx.html#functions