{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(sphx_glr_how_to_compile_models_from_pytorch.py)=\n",
    "# Compile PyTorch Models\n",
    "\n",
    "**Author**: [Alex Wong](https://github.com/alexwong/)\n",
    "\n",
    "This article is an introductory tutorial on deploying PyTorch models with Relay.\n",
    "\n",
    "To begin, PyTorch must be installed. TorchVision is also required, since we will use it as our model zoo.\n",
    "\n",
    "A quick way to install both is via pip:\n",
    "\n",
    "```bash\n",
    "pip install torch==1.7.0\n",
    "pip install torchvision==0.8.1\n",
    "```\n",
    "\n",
    "Alternatively, refer to the [official website](https://pytorch.org/get-started/locally/).\n",
    "\n",
    "PyTorch versions should be backwards compatible, but each should be used with the matching TorchVision version.\n",
    "\n",
    "Currently, TVM supports PyTorch 1.7 and 1.4. Other versions may be unstable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# PyTorch imports\n",
    "import torch\n",
    "import torchvision\n",
    "\n",
    "import set_env  # set up the TVM environment\n",
    "import tvm\n",
    "from tvm import relay\n",
    "from tvm.contrib.download import download_testdata"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load a pretrained PyTorch model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"resnet18\"\n",
    "model = getattr(torchvision.models, model_name)(pretrained=True)\n",
    "model = model.eval()\n",
    "\n",
    "# We obtain the TorchScripted model via tracing\n",
    "input_shape = [1, 3, 224, 224]\n",
    "input_data = torch.randn(input_shape)\n",
    "scripted_model = torch.jit.trace(model, input_data).eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load a test image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "img_url = \"https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true\"\n",
    "img_path = download_testdata(img_url, \"cat.png\", module=\"data\")\n",
    "img = Image.open(img_path).resize((224, 224))\n",
    "\n",
    "# Preprocess the image and convert to tensor\n",
    "from torchvision import transforms\n",
    "\n",
    "my_preprocess = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize(256),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ]\n",
    ")\n",
    "img = my_preprocess(img)\n",
    "img = np.expand_dims(img, 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import the graph to Relay\n",
    "\n",
    "Convert the PyTorch graph to a Relay graph. The `input_name` can be arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_name = \"input0\"\n",
    "shape_list = [(input_name, img.shape)]\n",
    "mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Relay Build\n",
    "\n",
    "Compile the graph for the llvm target with the given input specification:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.\n"
     ]
    }
   ],
   "source": [
    "target = tvm.target.Target(\"llvm\", host=\"llvm\")\n",
    "dev = tvm.cpu(0)\n",
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    lib = relay.build(mod, target=target, params=params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Execute the portable graph on TVM\n",
    "\n",
    "Now we can try deploying the compiled model on the target."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.contrib import graph_executor\n",
    "\n",
    "dtype = \"float32\"\n",
    "m = graph_executor.GraphModule(lib[\"default\"](dev))\n",
    "# Set inputs\n",
    "m.set_input(input_name, tvm.nd.array(img.astype(dtype)))\n",
    "# Execute\n",
    "m.run()\n",
    "# Get outputs\n",
    "tvm_output = m.get_output(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Look up the synset name\n",
    "\n",
    "Look up the predicted top-1 index among the 1000 class synsets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Relay top-1 id: 281, class name: tabby, tabby cat\n",
      "Torch top-1 id: 281, class name: tabby, tabby cat\n"
     ]
    }
   ],
   "source": [
    "synset_url = \"\".join(\n",
    "    [\n",
    "        \"https://raw.githubusercontent.com/Cadene/\",\n",
    "        \"pretrained-models.pytorch/master/data/\",\n",
    "        \"imagenet_synsets.txt\",\n",
    "    ]\n",
    ")\n",
    "synset_name = \"imagenet_synsets.txt\"\n",
    "synset_path = download_testdata(synset_url, synset_name, module=\"data\")\n",
    "with open(synset_path) as f:\n",
    "    synsets = f.readlines()\n",
    "\n",
    "synsets = [x.strip() for x in synsets]\n",
    "splits = [line.split(\" \") for line in synsets]\n",
    "key_to_classname = {spl[0]: \" \".join(spl[1:]) for spl in splits}\n",
    "\n",
    "class_url = \"\".join(\n",
    "    [\n",
    "        \"https://raw.githubusercontent.com/Cadene/\",\n",
    "        \"pretrained-models.pytorch/master/data/\",\n",
    "        \"imagenet_classes.txt\",\n",
    "    ]\n",
    ")\n",
    "class_name = \"imagenet_classes.txt\"\n",
    "class_path = download_testdata(class_url, class_name, module=\"data\")\n",
    "with open(class_path) as f:\n",
    "    class_id_to_key = f.readlines()\n",
    "\n",
    "class_id_to_key = [x.strip() for x in class_id_to_key]\n",
    "\n",
    "# Get top-1 result for TVM\n",
    "top1_tvm = np.argmax(tvm_output.numpy()[0])\n",
    "tvm_class_key = class_id_to_key[top1_tvm]\n",
    "\n",
    "# Convert input to PyTorch variable and get PyTorch result for comparison\n",
    "with torch.no_grad():\n",
    "    torch_img = torch.from_numpy(img)\n",
    "    output = model(torch_img)\n",
    "\n",
    "    # Get top-1 result for PyTorch\n",
    "    top1_torch = np.argmax(output.numpy())\n",
    "    torch_class_key = class_id_to_key[top1_torch]\n",
    "\n",
    "print(\"Relay top-1 id: {}, class name: {}\".format(top1_tvm, key_to_classname[tvm_class_key]))\n",
    "print(\"Torch top-1 id: {}, class name: {}\".format(top1_torch, key_to_classname[torch_class_key]))"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "6ee5142ba8a2589df39b0df03e82f50c3ae535c49aaf7d83abad1a0d572c7e37"
  },
  "kernelspec": {
   "display_name": "Python 3.10.4 ('tvm-test')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}