{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 使用 TVM 部署框架预量化模型\n", "\n", "**原作者**: [Masahiro Masuda](https://github.com/masahi)\n", "\n", "这是关于将深度学习框架量化的模型加载到 TVM 的教程。预量化模型导入是 TVM 中量化支持的一种。TVM 中量化的更多细节可以在[这里](https://discuss.tvm.apache.org/t/quantization-story/3920)找到。\n", "\n", "这里,将演示如何加载和运行由 PyTorch、MXNet 和 TFLite 量化的模型。一旦加载,就可以在任何 TVM 支持的硬件上运行已编译的、量化的模型。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "首先,一些必备的载入:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from PIL import Image\n", "import numpy as np\n", "\n", "import torch\n", "from torchvision.models.quantization import mobilenet as qmobilenet" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "加载 TVM 库:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import set_env\n", "\n", "import tvm\n", "from tvm import relay\n", "from tvm.contrib.download import download_testdata" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "运行演示程序的辅助函数:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def get_transform():\n", " import torchvision.transforms as transforms\n", "\n", " normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", " return transforms.Compose(\n", " [\n", " transforms.Resize(256),\n", " transforms.CenterCrop(224),\n", " transforms.ToTensor(),\n", " normalize,\n", " ]\n", " )\n", "\n", "\n", "def get_real_image(im_height, im_width):\n", " img_url = \"https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true\"\n", " img_path = download_testdata(img_url, \"cat.png\", module=\"data\")\n", " return Image.open(img_path).resize((im_height, im_width))\n", "\n", "\n", "def get_imagenet_input():\n", " im = get_real_image(224, 224)\n", " preprocess = get_transform()\n", " pt_tensor = preprocess(im)\n", " return np.expand_dims(pt_tensor.numpy(), 0)\n", "\n", "\n", "def get_synset():\n", " synset_url = \"\".join(\n", " [\n", " \"https://gist.githubusercontent.com/zhreshold/\",\n", " \"4d0b62f3d01426887599d4f7ede23ee5/raw/\",\n", " \"596b27d23537e5a1b5751d2b0481ef172f58b539/\",\n", " \"imagenet1000_clsid_to_human.txt\",\n", " ]\n", " )\n", " synset_name = \"imagenet1000_clsid_to_human.txt\"\n", " synset_path = download_testdata(synset_url, synset_name, module=\"data\")\n", " with open(synset_path) as f:\n", " return eval(f.read())\n", "\n", "\n", "def run_tvm_model(mod, params, input_name, inp, target=\"llvm\"):\n", " with tvm.transform.PassContext(opt_level=3):\n", " lib = relay.build(mod, target=target, params=params)\n", "\n", " runtime = tvm.contrib.graph_executor.GraphModule(lib[\"default\"](tvm.device(target, 0)))\n", "\n", " runtime.set_input(input_name, inp)\n", " runtime.run()\n", " return runtime.get_output(0).numpy(), runtime" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "从标签到类名的映射,以验证下面模型的输出是合理的:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [], "source": [ "synset = get_synset()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "大家最喜欢的猫的图像演示:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [], "source": [ "inp = get_imagenet_input()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 部署已量化的 PyTorch 模型\n", "\n", "首先,演示如何使用 PyTorch 前端加载由 PyTorch 量化的深度学习模型。\n", "\n", "请参阅 [PyTorch 静态量化教程](https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html),了解它们的量化工作流程。\n", "\n", "使用 {func}`quantize_model` 函数来量化 PyTorch 模型。简而言之,此函数采取浮点模型,并将其转换为 uint8。模型是逐通道量化的。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def quantize_model(model, inp):\n", " model.fuse_model()\n", " model.qconfig = torch.quantization.get_default_qconfig(\"fbgemm\")\n", " torch.quantization.prepare(model, inplace=True)\n", " # Dummy calibration\n", " model(inp)\n", " torch.quantization.convert(model, inplace=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 从 torchvision 加载量化准备,预训练的 Mobilenet v2 模型\n", "\n", "选择 mobilenet v2 是因为此模型是用量化感知训练训练的。其他模型需要完整的后训练校准。" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [], "source": [ "qmodel = qmobilenet.mobilenet_v2(pretrained=True).eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 量化,跟踪和运行 PyTorch Mobilenet v2 模型\n", "\n", "详细信息超出了本教程的范围。请参考 PyTorch 网站上的教程来学习 quantization 和 jit。" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/media/workspace/anaconda3/envs/torchx/lib/python3.10/site-packages/torch/ao/quantization/observer.py:177: UserWarning: Please use quant_min and quant_max to specify the range for observers. reduce_range will be deprecated in a future release of PyTorch.\n", " warnings.warn(\n", "/media/workspace/anaconda3/envs/torchx/lib/python3.10/site-packages/torch/ao/quantization/observer.py:1124: UserWarning: must run observer before calling calculate_qparams. Returning default scale and zero point \n", " warnings.warn(\n" ] } ], "source": [ "pt_inp = torch.from_numpy(inp)\n", "quantize_model(qmodel, pt_inp)\n", "script_module = torch.jit.trace(qmodel, pt_inp).eval()\n", "\n", "with torch.no_grad():\n", " pt_result = script_module(pt_inp).numpy()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 使用 PyTorch 前端将量化的 Mobilenet v2 转换为 Relay-QNN\n", "\n", "PyTorch 前端支持将量化的 PyTorch 模型转换为具有量化感知算子(quantization-aware operator)的等效 Relay 模块。称这种表示 Relay QNN dialect。\n", "\n", "可以从前端打印输出,以查看量化模型是如何表示的。\n", "\n", "将看到针对量化的运算符,如 `qnn.quantize`、`qnn.dequantize`、`qnn.requantize` 和 `qnn.conv2d` 等等。" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [], "source": [ "input_name = \"input\" # the input name can be be arbitrary for PyTorch frontend.\n", "input_shapes = [(input_name, (1, 3, 224, 224))]\n", "mod, params = relay.frontend.from_pytorch(script_module, input_shapes)\n", "# print(mod['main']) # comment in to see the QNN IR dump" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 编译和运行 Relay 模块\n", "\n", "一旦获得了量化的 Relay 模块,其余的工作流程就像运行浮点模型一样。请参考其他教程了解更多细节。\n", "\n", "在编译之前,量化特定的算子被 lower 到标准 Relay 算子序列。" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/media/pc/data/4tb/lxw/books/tvm/python/tvm/target/target.py:316: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.\n", " warnings.warn(\n", "One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.\n" ] } ], "source": [ "target = \"llvm\"\n", "tvm_result, rt_mod = run_tvm_model(mod, params, input_name, inp, target=target)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 计算输出标签\n", "\n", "应该看到打印出相同的标签。" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PyTorch top3 labels: ['tabby, tabby cat', 'tiger cat', 'Egyptian cat']\n", "TVM top3 labels: ['tabby, tabby cat', 'tiger cat', 'Egyptian cat']\n" ] } ], "source": [ "pt_top3_labels = np.argsort(pt_result[0])[::-1][:3]\n", "tvm_top3_labels = np.argsort(tvm_result[0])[::-1][:3]\n", "\n", "print(\"PyTorch top3 labels:\", [synset[label] for label in pt_top3_labels])\n", "print(\"TVM top3 labels:\", [synset[label] for label in tvm_top3_labels])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "然而,由于数值上的差异,通常原始浮点输出不会是相同的。这里,打印从 mobilenet v2 的 1000 个输出中有多少个浮点输出值是相同的。" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "207 in 1000 raw floating outputs identical.\n" ] } ], "source": [ "print(\"%d in 1000 raw floating outputs identical.\" % np.sum(tvm_result[0] == pt_result[0]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 性能度量\n", "\n", "在此,举例说明如何度量 TVM 编译模型的性能。" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Execution time summary:\n", " mean (ms) median (ms) max (ms) min (ms) std (ms) \n", " 5.9385 5.5819 9.1432 5.4313 0.8170 \n", " \n" ] } ], "source": [ "n_repeat = 100 # should be bigger to make the measurement more accurate\n", "dev = tvm.cpu(0)\n", "print(rt_mod.benchmark(dev, number=1, repeat=n_repeat))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{note}\n", ":class: alert alert-info\n", "\n", "* 由于度量是在 C++ 中完成的,所以没有 Python 的开销\n", "* 它包括几个 warm up 运行\n", "* 同样的方法可以用于远程设备(android 等)的配置。\n", "```\n", "\n", "```{warning}\n", ":class: alert alert-info\n", "\n", "除非硬件对快速 8 bit 指令有特殊支持,否则量化模型不会比 FP32 模型更快。如果没有快速的 8 bit 指令,可 TVM 以在 16 bit 进行量化卷积,即使模型本身是 8 bit。\n", "\n", "对于 x86,最好的性能可以在带有 AVX512 指令集的 CPU 上实现。在这种情况下,TVM 为给定的目标使用最快的可用 8 bit 指令。这包括对 VNNI 8 bit 点积指令(CascadeLake 或更新版本)的支持。\n", "\n", "此外,以下对 CPU 性能的一般建议同样适用:\n", "\n", "- 将环境变量 ``TVM_NUM_THREADS`` 设置为物理核数\n", "- 为您的硬件选择最佳的目标,例如 `\"llvm -mcpu=skylake-avx512\" ` 或 `\"llvm -mcpu=cascadelake\"` (将来会有更多带有 AVX512 的 CPU)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Deploy a quantized MXNet Model\n", "------------------------------\n", "TODO\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Deploy a quantized TFLite Model\n", "-------------------------------\n", "TODO\n", "\n" ] } ], "metadata": { "interpreter": { "hash": "7a45eadec1f9f49b0fdfd1bc7d360ac982412448ce738fa321afc640e3212175" }, "kernelspec": { "display_name": "Python 3.10.4 ('torchx')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 0 }