编译 PyTorch 目标检测模型
导航
编译 PyTorch 目标检测模型#
本文是使用 Relay VM 部署 PyTorch 目标检测模型的介绍性教程。
首先应该安装 PyTorch。TorchVision 也是必需的,因为将使用它作为模型动物园。
快速的解决方案是通过 pip 安装:
pip install torch torchvision
或者请参考 官方网站。
PyTorch 版本应该向后兼容,但应该与正确的 TorchVision 版本一起使用。
目前,TVM 支持 PyTorch 1.7 和 1.4。其他版本可能不稳定。
import env
import tvm
from tvm import relay
from tvm import relay
from tvm.runtime.vm import VirtualMachine
from tvm.contrib.download import download_testdata
import numpy as np
from cv2 import cv2
# PyTorch imports
import torch
import torchvision
从 torchvision 加载预训练的 maskrcnn 并进行跟踪#
def do_trace(model, inp):
model_trace = torch.jit.trace(model, inp)
model_trace.eval()
return model_trace
def dict_to_tuple(out_dict):
if "masks" in out_dict.keys():
return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
return out_dict["boxes"], out_dict["scores"], out_dict["labels"]
class TraceWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, inp):
out = self.model(inp)
return dict_to_tuple(out[0])
in_size = 300
input_shape = (1, 3, in_size, in_size)
model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
model = TraceWrapper(model_func(pretrained=True))
model.eval()
inp = torch.rand(input_shape)
with torch.no_grad():
out = model(inp)
script_module = do_trace(model, inp)
下载测试图像并进行预处理#
img_url = (
"https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/detection/street_small.jpg"
)
img_path = download_testdata(img_url,
"test_street_small.jpg",
module="data")
img = cv2.imread(img_path).astype("float32")
img = cv2.resize(img, (in_size, in_size))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.transpose(img / 255.0, [2, 0, 1])
img = np.expand_dims(img, axis=0)
导入 graph 到 Relay#
input_name = "input0"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(script_module, shape_list)
使用 Relay VM 编译#
备注
目前只支持 CPU target。对于 x86 target,由于在 torchvision rcnn 模型中存在较大的 dense 算子,因此强烈推荐使用 Intel MKL 和 Intel OpenMP 构建 TVM 以获得最佳性能。
# Add "-libs=mkl" to get best performance on x86 target.
# For x86 machine supports AVX512, the complete target is
# "llvm -mcpu=skylake-avx512 -libs=mkl"
target = "llvm"
with tvm.transform.PassContext(opt_level=3,
disabled_pass=["FoldScaleAxis"]):
vm_exec = relay.vm.compile(mod, target=target, params=params)
使用 Relay VM 推理#
dev = tvm.cpu()
vm = VirtualMachine(vm_exec, dev)
vm.set_input("main", **{input_name: img})
tvm_res = vm.run()
获得得分大于 0.9 的 boxes#
score_threshold = 0.9
boxes = tvm_res[0].numpy().tolist()
valid_boxes = []
for i, score in enumerate(tvm_res[1].numpy().tolist()):
if score > score_threshold:
valid_boxes.append(boxes[i])
else:
break
print("Get {} valid boxes".format(len(valid_boxes)))
Get 9 valid boxes