编译 PyTorch 目标检测模型#

本文是使用 Relay VM 部署 PyTorch 目标检测模型的介绍性教程。

首先应该安装 PyTorch。TorchVision 也是必需的,因为将使用它作为模型动物园。

快速的解决方案是通过 pip 安装:

pip install torch torchvision

或者请参考 官方网站

PyTorch 版本应该向后兼容,但应该与正确的 TorchVision 版本一起使用。

目前,TVM 支持 PyTorch 1.7 和 1.4。其他版本可能不稳定。

import env
import tvm
from tvm import relay
from tvm import relay
from tvm.runtime.vm import VirtualMachine
from tvm.contrib.download import download_testdata

import numpy as np
from cv2 import cv2

# PyTorch imports
import torch
import torchvision

从 torchvision 加载预训练的 maskrcnn 并进行跟踪#

def do_trace(model, inp):
    model_trace = torch.jit.trace(model, inp)
    model_trace.eval()
    return model_trace


def dict_to_tuple(out_dict):
    if "masks" in out_dict.keys():
        return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
    return out_dict["boxes"], out_dict["scores"], out_dict["labels"]


class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inp):
        out = self.model(inp)
        return dict_to_tuple(out[0])
in_size = 300
input_shape = (1, 3, in_size, in_size)

model_func = torchvision.models.detection.maskrcnn_resnet50_fpn

model = TraceWrapper(model_func(pretrained=True))
model.eval()
inp = torch.rand(input_shape)

with torch.no_grad():
    out = model(inp)
    script_module = do_trace(model, inp)

下载测试图像并进行预处理#

img_url = (
    "https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/detection/street_small.jpg"
)
img_path = download_testdata(img_url, 
            "test_street_small.jpg", 
            module="data")

img = cv2.imread(img_path).astype("float32")
img = cv2.resize(img, (in_size, in_size))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.transpose(img / 255.0, [2, 0, 1])
img = np.expand_dims(img, axis=0)

导入 graph 到 Relay#

input_name = "input0"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(script_module, shape_list)

使用 Relay VM 编译#

备注

目前只支持 CPU target。对于 x86 target,由于在 torchvision rcnn 模型中存在较大的 dense 算子,因此强烈推荐使用 Intel MKL 和 Intel OpenMP 构建 TVM 以获得最佳性能。

# Add "-libs=mkl" to get best performance on x86 target.
# For x86 machine supports AVX512, the complete target is
# "llvm -mcpu=skylake-avx512 -libs=mkl"
target = "llvm"

with tvm.transform.PassContext(opt_level=3,
                               disabled_pass=["FoldScaleAxis"]):
    vm_exec = relay.vm.compile(mod, target=target, params=params)

使用 Relay VM 推理#

dev = tvm.cpu()
vm = VirtualMachine(vm_exec, dev)
vm.set_input("main", **{input_name: img})
tvm_res = vm.run()

获得得分大于 0.9 的 boxes#

score_threshold = 0.9
boxes = tvm_res[0].numpy().tolist()
valid_boxes = []
for i, score in enumerate(tvm_res[1].numpy().tolist()):
    if score > score_threshold:
        valid_boxes.append(boxes[i])
    else:
        break

print("Get {} valid boxes".format(len(valid_boxes)))
Get 9 valid boxes