作为一个高效好用的深度学习框架,pyTorch被广泛用于深度学习模型的搭建和训练,但是却几乎没有人会直接将pyTorch模型用于部署。对此,pyTorch官方也给出了自己的一种解决方案——TorchScript。
TorchScript[1]是一种从PyTorch代码创建可序列化、可优化的模型的方法,任何TorchScript程序都可以从Python进程中保存,并加载到没有Python依赖的进程中,从而可以实现模型的部署。
将pyTorch模型转换成TorchScript有两种方法:一种叫tracing,另一种叫scripting。通常这两种方法任意一种都可以,但在一些特定模型结构中需要将两者结合使用。
鉴于tracing对于我来说已经足够,本文只关注tracing方法。
tracing用法
tracing
方法的核心是用torch.jit.trace
记录一次模型推理中经过的所有运算记录,将这些记录整合成计算图。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
| import torch import torchvision
# 创建模型实例,对于自定义的模型需要加载权重 model = torchvision.models.resnet18(pretrained=True)
# 评估模式 model.eval()
# 创建一个示例输出,维度和forward函数的输入一致 dummy_input = torch.rand(1, 3, 224, 224)
# 用 torch.jit.trace 生成 torch.jit.ScriptModule with torch.no_grad(): traced_script_module = torch.jit.trace(model, dummy_input)
# 保存 TorchScript,习惯上后缀为.pt traced_script_module.save("traced_resnet.pt")
|
这样即可把pyTorch模型转换成TorchScript了。
创建一个完整的转换脚本
这里补充一个完整的转换脚本吧
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
| import argparse import torch
# === import your model === from model import MyModel
parser = argparse.ArgumentParser("export model as torchscript") parser.add_argument('--checkpoint', type=str, required=True) parser.add_argument("--batch-size", type=int, default=1) parser.add_argument( "--input-shape", type=str, default="224,224", help="specify the input shape for inference")
args = parser.parse_args()
model = MyModel() checkpoint = torch.load(args.checkpoint) model.load_state_dict(checkpoint)
model.eval()
input_shape = tuple(map(int, args.input_shape.split(","))) dummy_input = torch.randn(args.batch_size, 6, *input_shape)
with torch.no_grad(): traced_script_module = torch.jit.trace(model, dummy_input)
traced_script_module.save("model.pt") print("save torchscript as model.pt")
|
TorchScript加载
pyTorch模型转换成TorchScript后可以很方便地用python或者C++直接加载,如在python中可直接使用torch.jit.load
加载序列化后的模型。
torch.jit.load
函数原型如下:
1
| torch.jit.load(f, map_location=None, _extra_files=None, _restore_shapes=False)
|
f
:
一个文件流对象或者模型文件名字符串,如我们的model.pt
;
map_location
:
字符串(如cuda:0
)或torch.device
,用于将模型映射加载到指定的设备上。
这一点很重要,因为默认情况下torch.jit.load
会试图将模型加载到保存时所使用的设备上,如果这个设备不存在就会报错。比如保存时模型在cuda:1
上,而我们加载时候的设备只有一张卡即cuda:0
,那么就会出错。而通过指定map_location
我们可以重新分配用于加载模型的设备。
加载实现:
1 2 3 4
| # 使用保存时的设备加载 model = torch.jit.load("model.pt") # 使用1卡加载 model = torch.jit.load("model.pt", map_location="cuda:1")
|
在C++中:
1
| auto model = torch::jit::load('model.pt');
|
参考
[1]. https://pytorch.org/docs/master/jit.html
[2]. http://djl.ai/docs/pytorch/how_to_convert_your_model_to_torchscript.html