{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "(sphx_glr_how_to_compile_models_from_pytorch.py)=\n", "# Compile PyTorch Models\n", "\n", "**Author**: [Alex Wong](https://github.com/alexwong/)\n", "\n", "This article is an introductory tutorial on deploying PyTorch models with Relay.\n", "\n", "To begin, PyTorch must be installed. TorchVision is also required, since we use it as our model zoo.\n", "\n", "A quick solution is to install both via pip:\n", "\n", "```bash\n", "pip install torch==1.7.0\n", "pip install torchvision==0.8.1\n", "```\n", "\n", "Alternatively, refer to the [official site](https://pytorch.org/get-started/locally/).\n", "\n", "PyTorch versions should be backward compatible, but each should be used with the matching TorchVision version.\n", "\n", "Currently, TVM supports PyTorch 1.7 and 1.4. Other versions may be unstable." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# PyTorch imports\n", "import torch\n", "import torchvision\n", "\n", "import set_env # set up the TVM environment\n", "import tvm\n", "from tvm import relay\n", "from tvm.contrib.download import download_testdata" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load a pretrained PyTorch model" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "model_name = \"resnet18\"\n", "model = getattr(torchvision.models, model_name)(pretrained=True)\n", "model = model.eval()\n", "\n", "# We obtain the TorchScripted model via tracing\n", "input_shape = [1, 3, 224, 224]\n", "input_data = torch.randn(input_shape)\n", "scripted_model = torch.jit.trace(model, input_data).eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load a test image" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from PIL import Image\n", "\n", "img_url = \"https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true\"\n", "img_path = download_testdata(img_url, \"cat.png\", module=\"data\")\n", "img = Image.open(img_path).resize((224, 224))\n", "\n", "# Preprocess the image and convert to tensor\n", "from torchvision import transforms\n", "\n", "my_preprocess = transforms.Compose(\n", " [\n", " 
transforms.Resize(256),\n", " transforms.CenterCrop(224),\n", " transforms.ToTensor(),\n", " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", " ]\n", ")\n", "img = my_preprocess(img)\n", "img = np.expand_dims(img, 0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import the graph to Relay\n", "\n", "Convert the PyTorch graph to a Relay graph. The `input_name` can be arbitrary." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "input_name = \"input0\"\n", "shape_list = [(input_name, img.shape)]\n", "mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Relay build\n", "\n", "Compile the graph for the llvm target with the given input specification:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.\n" ] } ], "source": [ "target = tvm.target.Target(\"llvm\", host=\"llvm\")\n", "dev = tvm.cpu(0)\n", "with tvm.transform.PassContext(opt_level=3):\n", " lib = relay.build(mod, target=target, params=params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Execute the portable graph on TVM\n", "\n", "Now we can try deploying the compiled model on the target." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from tvm.contrib import graph_executor\n", "\n", "dtype = \"float32\"\n", "m = graph_executor.GraphModule(lib[\"default\"](dev))\n", "# Set inputs\n", "m.set_input(input_name, tvm.nd.array(img.astype(dtype)))\n", "# Execute\n", "m.run()\n", "# Get outputs\n", "tvm_output = m.get_output(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Look up the synset name\n", "\n", "Look up the top-1 predicted index in the 1000-class synset." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Relay top-1 id: 281, class 
name: tabby, tabby cat\n", "Torch top-1 id: 281, class name: tabby, tabby cat\n" ] } ], "source": [ "synset_url = \"\".join(\n", " [\n", " \"https://raw.githubusercontent.com/Cadene/\",\n", " \"pretrained-models.pytorch/master/data/\",\n", " \"imagenet_synsets.txt\",\n", " ]\n", ")\n", "synset_name = \"imagenet_synsets.txt\"\n", "synset_path = download_testdata(synset_url, synset_name, module=\"data\")\n", "with open(synset_path) as f:\n", " synsets = f.readlines()\n", "\n", "synsets = [x.strip() for x in synsets]\n", "splits = [line.split(\" \") for line in synsets]\n", "key_to_classname = {spl[0]: \" \".join(spl[1:]) for spl in splits}\n", "\n", "class_url = \"\".join(\n", " [\n", " \"https://raw.githubusercontent.com/Cadene/\",\n", " \"pretrained-models.pytorch/master/data/\",\n", " \"imagenet_classes.txt\",\n", " ]\n", ")\n", "class_name = \"imagenet_classes.txt\"\n", "class_path = download_testdata(class_url, class_name, module=\"data\")\n", "with open(class_path) as f:\n", " class_id_to_key = f.readlines()\n", "\n", "class_id_to_key = [x.strip() for x in class_id_to_key]\n", "\n", "# Get top-1 result for TVM\n", "top1_tvm = np.argmax(tvm_output.numpy()[0])\n", "tvm_class_key = class_id_to_key[top1_tvm]\n", "\n", "# Convert input to PyTorch variable and get PyTorch result for comparison\n", "with torch.no_grad():\n", " torch_img = torch.from_numpy(img)\n", " output = model(torch_img)\n", "\n", " # Get top-1 result for PyTorch\n", " top1_torch = np.argmax(output.numpy())\n", " torch_class_key = class_id_to_key[top1_torch]\n", "\n", "print(\"Relay top-1 id: {}, class name: {}\".format(top1_tvm, key_to_classname[tvm_class_key]))\n", "print(\"Torch top-1 id: {}, class name: {}\".format(top1_torch, key_to_classname[torch_class_key]))" ] } ], "metadata": { "interpreter": { "hash": "6ee5142ba8a2589df39b0df03e82f50c3ae535c49aaf7d83abad1a0d572c7e37" }, "kernelspec": { "display_name": "Python 3.10.4 ('tvm-test')", "language": "python", "name": "python3" }, 
"language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }