pyTorch模型转torchscript

作为一个高效好用的深度学习框架,pyTorch被广泛用于深度学习模型的搭建和训练,但是却几乎没有人会直接将pyTorch模型用于部署。对此,pyTorch官方也给出了自己的一种解决方案——TorchScript。

TorchScript[1]是一种从PyTorch代码创建可序列化、可优化的模型的方法,任何TorchScript程序都可以从Python进程中保存,并加载到没有Python依赖的进程中,从而可以实现模型的部署。

将pyTorch模型转换成TorchScript有两种方法:一种叫tracing,另一种叫scripting。通常这两种方法任意一种都可以,但在一些特定模型结构中需要将两者结合使用。

鉴于tracing对于我来说已经足够,本文只关注tracing方法。

tracing用法

tracing方法的核心是用torch.jit.trace记录一次模型推理中经过的所有运算记录,将这些记录整合成计算图。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import torchvision

# 创建模型实例,对于自定义的模型需要加载权重
model = torchvision.models.resnet18(pretrained=True)

# 评估模式
model.eval()

# 创建一个示例输出,维度和forward函数的输入一致
dummy_input = torch.rand(1, 3, 224, 224)

# 用 torch.jit.trace 生成 torch.jit.ScriptModule
with torch.no_grad():
traced_script_module = torch.jit.trace(model, dummy_input)

# 保存 TorchScript,习惯上后缀为.pt
traced_script_module.save("traced_resnet.pt")

这样即可把pyTorch模型转换成TorchScript了。

创建一个完整的转换脚本

这里补充一个完整的转换脚本吧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import argparse
import torch

# === import your model ===
from model import MyModel

parser = argparse.ArgumentParser("export model as torchscript")
parser.add_argument('--checkpoint', type=str, required=True)
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument(
"--input-shape",
type=str,
default="224,224",
help="specify the input shape for inference")

args = parser.parse_args()

model = MyModel()
checkpoint = torch.load(args.checkpoint)
model.load_state_dict(checkpoint)

model.eval()

input_shape = tuple(map(int, args.input_shape.split(",")))
dummy_input = torch.randn(args.batch_size, 6, *input_shape)

with torch.no_grad():
traced_script_module = torch.jit.trace(model, dummy_input)

traced_script_module.save("model.pt")
print("save torchscript as model.pt")

TorchScript加载

pyTorch模型转换成TorchScript后可以很方便地用python或者C++直接加载,如在python中可直接使用torch.jit.load加载序列化后的模型。

torch.jit.load函数原型如下:

1
torch.jit.load(f, map_location=None, _extra_files=None, _restore_shapes=False)
  • f: 一个文件流对象或者模型文件名字符串,如我们的model.pt;
  • map_location: 字符串(如cuda:0)或torch.device,用于将模型映射加载到指定的设备上。 这一点很重要,因为默认情况下torch.jit.load会试图将模型加载到保存时所使用的设备上,如果这个设备不存在就会报错。比如保存时模型在cuda:1上,而我们加载时候的设备只有一张卡即cuda:0,那么就会出错。而通过指定map_location我们可以重新分配用于加载模型的设备。

加载实现:

1
2
3
4
# 使用保存时的设备加载
model = torch.jit.load("model.pt")
# 使用1卡加载
model = torch.jit.load("model.pt", map_location="cuda:1")

在C++中:

1
auto model = torch::jit::load('model.pt');

参考

[1]. https://pytorch.org/docs/master/jit.html

[2]. http://djl.ai/docs/pytorch/how_to_convert_your_model_to_torchscript.html