使用 TVM 部署框架预量化模型#

原作者: Masahiro Masuda

这是关于将深度学习框架量化的模型加载到 TVM 的教程。预量化模型导入是 TVM 中量化支持的一种。TVM 中量化的更多细节可以在这里找到。

这里,将演示如何加载和运行由 PyTorch、MXNet 和 TFLite 量化的模型。一旦加载,就可以在任何 TVM 支持的硬件上运行已编译的、量化的模型。

首先,一些必备的载入:

from PIL import Image
import numpy as np

import torch
from torchvision.models.quantization import mobilenet as qmobilenet

加载 TVM 库:

import set_env

import tvm
from tvm import relay
from tvm.contrib.download import download_testdata

运行演示程序的辅助函数:

def get_transform():
    import torchvision.transforms as transforms

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    )


def get_real_image(im_height, im_width):
    img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
    img_path = download_testdata(img_url, "cat.png", module="data")
    return Image.open(img_path).resize((im_height, im_width))


def get_imagenet_input():
    im = get_real_image(224, 224)
    preprocess = get_transform()
    pt_tensor = preprocess(im)
    return np.expand_dims(pt_tensor.numpy(), 0)


def get_synset():
    synset_url = "".join(
        [
            "https://gist.githubusercontent.com/zhreshold/",
            "4d0b62f3d01426887599d4f7ede23ee5/raw/",
            "596b27d23537e5a1b5751d2b0481ef172f58b539/",
            "imagenet1000_clsid_to_human.txt",
        ]
    )
    synset_name = "imagenet1000_clsid_to_human.txt"
    synset_path = download_testdata(synset_url, synset_name, module="data")
    with open(synset_path) as f:
        return eval(f.read())


def run_tvm_model(mod, params, input_name, inp, target="llvm"):
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target, params=params)

    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](tvm.device(target, 0)))

    runtime.set_input(input_name, inp)
    runtime.run()
    return runtime.get_output(0).numpy(), runtime

从标签到类名的映射,以验证下面模型的输出是合理的:

synset = get_synset()

大家最喜欢的猫的图像演示:

inp = get_imagenet_input()

部署已量化的 PyTorch 模型#

首先,演示如何使用 PyTorch 前端加载由 PyTorch 量化的深度学习模型。

请参阅 PyTorch 静态量化教程,了解它们的量化工作流程。

使用 quantize_model() 函数来量化 PyTorch 模型。简而言之,此函数采取浮点模型,并将其转换为 uint8。模型是逐通道量化的。

def quantize_model(model, inp):
    model.fuse_model()
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    torch.quantization.prepare(model, inplace=True)
    # Dummy calibration
    model(inp)
    torch.quantization.convert(model, inplace=True)

从 torchvision 加载量化准备,预训练的 Mobilenet v2 模型#

选择 mobilenet v2 是因为此模型是用量化感知训练训练的。其他模型需要完整的后训练校准。

qmodel = qmobilenet.mobilenet_v2(pretrained=True).eval()

量化,跟踪和运行 PyTorch Mobilenet v2 模型#

详细信息超出了本教程的范围。请参考 PyTorch 网站上的教程来学习 quantization 和 jit。

pt_inp = torch.from_numpy(inp)
quantize_model(qmodel, pt_inp)
script_module = torch.jit.trace(qmodel, pt_inp).eval()

with torch.no_grad():
    pt_result = script_module(pt_inp).numpy()
/media/workspace/anaconda3/envs/torchx/lib/python3.10/site-packages/torch/ao/quantization/observer.py:177: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  warnings.warn(
/media/workspace/anaconda3/envs/torchx/lib/python3.10/site-packages/torch/ao/quantization/observer.py:1124: UserWarning: must run observer before calling calculate_qparams.                                    Returning default scale and zero point 
  warnings.warn(

使用 PyTorch 前端将量化的 Mobilenet v2 转换为 Relay-QNN#

PyTorch 前端支持将量化的 PyTorch 模型转换为具有量化感知算子(quantization-aware operator)的等效 Relay 模块。称这种表示 Relay QNN dialect。

可以从前端打印输出,以查看量化模型是如何表示的。

将看到针对量化的运算符,如 qnn.quantizeqnn.dequantizeqnn.requantizeqnn.conv2d 等等。

input_name = "input"  # the input name can be be arbitrary for PyTorch frontend.
input_shapes = [(input_name, (1, 3, 224, 224))]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)
# print(mod['main']) # comment in to see the QNN IR dump

编译和运行 Relay 模块#

一旦获得了量化的 Relay 模块,其余的工作流程就像运行浮点模型一样。请参考其他教程了解更多细节。

在编译之前,量化特定的算子被 lower 到标准 Relay 算子序列。

target = "llvm"
tvm_result, rt_mod = run_tvm_model(mod, params, input_name, inp, target=target)
/media/pc/data/4tb/lxw/books/tvm/python/tvm/target/target.py:316: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  warnings.warn(
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.

计算输出标签#

应该看到打印出相同的标签。

pt_top3_labels = np.argsort(pt_result[0])[::-1][:3]
tvm_top3_labels = np.argsort(tvm_result[0])[::-1][:3]

print("PyTorch top3 labels:", [synset[label] for label in pt_top3_labels])
print("TVM top3 labels:", [synset[label] for label in tvm_top3_labels])
PyTorch top3 labels: ['tabby, tabby cat', 'tiger cat', 'Egyptian cat']
TVM top3 labels: ['tabby, tabby cat', 'tiger cat', 'Egyptian cat']

然而,由于数值上的差异,通常原始浮点输出不会是相同的。这里,打印从 mobilenet v2 的 1000 个输出中有多少个浮点输出值是相同的。

print("%d in 1000 raw floating outputs identical." % np.sum(tvm_result[0] == pt_result[0]))
207 in 1000 raw floating outputs identical.

性能度量#

在此,举例说明如何度量 TVM 编译模型的性能。

n_repeat = 100  # should be bigger to make the measurement more accurate
dev = tvm.cpu(0)
print(rt_mod.benchmark(dev, number=1, repeat=n_repeat))
Execution time summary:
 mean (ms)   median (ms)    max (ms)     min (ms)     std (ms)  
   5.9385       5.5819       9.1432       5.4313       0.8170   
               

备注

  • 由于度量是在 C++ 中完成的,所以没有 Python 的开销

  • 它包括几个 warm up 运行

  • 同样的方法可以用于远程设备(android 等)的配置。

警告

除非硬件对快速 8 bit 指令有特殊支持,否则量化模型不会比 FP32 模型更快。如果没有快速的 8 bit 指令,可 TVM 以在 16 bit 进行量化卷积,即使模型本身是 8 bit。

对于 x86,最好的性能可以在带有 AVX512 指令集的 CPU 上实现。在这种情况下,TVM 为给定的目标使用最快的可用 8 bit 指令。这包括对 VNNI 8 bit 点积指令(CascadeLake 或更新版本)的支持。

此外,以下对 CPU 性能的一般建议同样适用:

  • 将环境变量 TVM_NUM_THREADS 设置为物理核数

  • 为您的硬件选择最佳的目标,例如 "llvm -mcpu=skylake-avx512" "llvm -mcpu=cascadelake" (将来会有更多带有 AVX512 的 CPU)

Deploy a quantized MXNet Model#

TODO

Deploy a quantized TFLite Model#

TODO