{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# 2D 卷积优化\n",
        "\n",
        "**原作者**: [Thierry Moreau](https://homes.cs.washington.edu/~moreau/)\n",
        "\n",
        "本教程提供了关于如何使用 TVM 映射二维卷积工作负载有效的 VTA 设计的概述。建议先学习 {ref}`vta-mat-mult-opt` 教程。\n",
        "\n",
        "二维卷积在大多数计算机视觉深度神经网络中占主导地位。在本教程中,将演示 TVM 调度优化,将 NCHW 布局中的 2D 卷积算子映射到 VTA。还引入了延迟隐藏(latency hiding)的概念,它允许最大化 VTA 的计算和内存资源利用。\n",
        "\n",
        "## RPC 设置\n",
        "\n",
        "首先编程 Pynq 的 FPGA 并构建它的 RPC 运行时。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import tvm\n",
        "import tvm.testing\n",
        "from tvm import te\n",
        "import vta\n",
        "import numpy as np\n",
        "\n",
        "from tvm import rpc\n",
        "from tvm.contrib import utils\n",
        "from vta.testing import simulator\n",
        "\n",
        "# Load VTA parameters from the 3rdparty/vta-hw/config/vta_config.json file\n",
        "env = vta.get_env()\n",
        "\n",
        "# We read the Pynq RPC host IP address and port number from the OS environment\n",
        "host = os.environ.get(\"VTA_RPC_HOST\", \"192.168.2.99\")\n",
        "port = int(os.environ.get(\"VTA_RPC_PORT\", \"9091\"))\n",
        "\n",
        "# We configure both the bitstream and the runtime system on the Pynq\n",
        "# to match the VTA configuration specified by the vta_config.json file.\n",
        "if env.TARGET == \"pynq\":\n",
        "    # Make sure that TVM was compiled with RPC=1\n",
        "    assert tvm.runtime.enabled(\"rpc\")\n",
        "    remote = rpc.connect(host, port)\n",
        "\n",
        "    # Reconfigure the JIT runtime\n",
        "    vta.reconfig_runtime(remote)\n",
        "\n",
        "    # Program the FPGA with a pre-compiled VTA bitstream.\n",
        "    # You can program the FPGA with your own custom bitstream\n",
        "    # by passing the path to the bitstream file instead of None.\n",
        "    vta.program_fpga(remote, bitstream=None)\n",
        "\n",
        "# In simulation mode, host the RPC server locally.\n",
        "elif env.TARGET in [\"sim\", \"tsim\"]:\n",
        "    remote = rpc.LocalSession()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 声明计算\n",
        "\n",
        "作为第一步,需要用 NCHW 格式描述 2D 卷积计算。\n",
        "\n",
        "通过 batch size、空间维度、输入通道、输出通道、核维度、核维度、填充维度和步长维度来定义二维卷积形状。\n",
        "\n",
        "选择 ResNet-18 架构的第 9 个卷积层的形状作为卷积 workload 参数。\n",
        "\n",
        "在 2D 卷积中添加了额外的算子,用于对输出进行移位和剪切,以模拟定点卷积之后的修正线性激活。将二维卷积层的 TVM 数据流图描述如下:\n",
        "\n",
        "```{image} images/conv2d_dataflow.png\n",
        ":align: center\n",
        "```\n",
        "\n",
        "这个计算被故意设置得太大,以至于不能一次全部放入 VTA 的 on-chip buffers。因此,在调度阶段,将依靠计算分块策略将计算分解为可管理的块。\n",
        "\n",
        "````{admonition} 空间填充\n",
        ":class: alert alert-info\n",
        "\n",
        "注意,需要导入 TOPI 库来对输入特征映射张量应用空间填充(Spatial padding)。空间填充有助于在 2D 卷积环境中分块,因为如果卷积核窗口大小大于 1,那么任何给定层的输入特征映射的相同 `(x, y)` 空间位置将被读取多次。在 CPU 和 GPU 上,当并行工作时,提高内存访问效率的一种方法是空间打包(spatial packing),这需要重新布局数据。VTA load DMA 引擎可以自动插入填充,这样原始的输入特征映射就不必在内存中重新打包。\n",
        "\n",
        "当数据从 DRAM load 到 VTA 的 SRAM 时,下面展示了 VTA 对动态空间填充的影响,随后是 2D 跨步和填充内存读取。\n",
        "\n",
        "```{image} images/padding.png\n",
        ":align: center\n",
        ":width: 480px\n",
        "```\n",
        "````"
      ]
    },
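    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The effect of the spatial padding described above can be pictured with a small numpy sketch. This is a host-side illustration only (on VTA the load DMA inserts the zero rows and columns on the fly); the tiny tile size here is hypothetical and not part of the schedule below.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Toy 4x4 feature-map tile in the (N, IC, H, W, n, ic) layout used below,\n",
        "# padded by 1 on each spatial side, mirroring topi.nn.pad(data, [0, 0, pad_h, pad_w, 0, 0]).\n",
        "tile = np.arange(16, dtype=\"int8\").reshape(1, 1, 4, 4, 1, 1)\n",
        "padded = np.pad(tile, [(0, 0), (0, 0), (1, 1), (1, 1), (0, 0), (0, 0)])\n",
        "print(padded.shape)              # (1, 1, 6, 6, 1, 1)\n",
        "print(padded[0, 0, :, :, 0, 0])  # zero border around the original 4x4 tile\n",
        "```"
      ]
    },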
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from tvm import topi\n",
        "\n",
        "# 2D convolution layer dimensions taken from ResNet-18 architecture\n",
        "# (9th convolutional layer)\n",
        "batch_size = 1\n",
        "height = 14\n",
        "width = 14\n",
        "in_channels = 256\n",
        "out_channels = 256\n",
        "kernel_h = 3\n",
        "kernel_w = 3\n",
        "pad_h = 1\n",
        "pad_w = 1\n",
        "stride_h = 1\n",
        "stride_w = 1\n",
        "assert batch_size % env.BATCH == 0\n",
        "assert in_channels % env.BLOCK_IN == 0\n",
        "assert out_channels % env.BLOCK_OUT == 0\n",
        "\n",
        "# Input feature map: (N, IC, H, W, n, ic)\n",
        "data_shape = (\n",
        "    batch_size // env.BATCH,\n",
        "    in_channels // env.BLOCK_IN,\n",
        "    height,\n",
        "    width,\n",
        "    env.BATCH,\n",
        "    env.BLOCK_IN,\n",
        ")\n",
        "# Kernel: (OC, IC, H, W, oc, ic)\n",
        "kernel_shape = (\n",
        "    out_channels // env.BLOCK_OUT,\n",
        "    in_channels // env.BLOCK_IN,\n",
        "    kernel_h,\n",
        "    kernel_w,\n",
        "    env.BLOCK_OUT,\n",
        "    env.BLOCK_IN,\n",
        ")\n",
        "# Derive output feature map dimensions\n",
        "fout_height = (height + 2 * pad_h - kernel_h) // stride_h + 1\n",
        "fout_width = (width + 2 * pad_w - kernel_w) // stride_w + 1\n",
        "# Output feature map: (N, OC, H, W, n, oc)\n",
        "output_shape = (\n",
        "    batch_size // env.BATCH,\n",
        "    out_channels // env.BLOCK_OUT,\n",
        "    fout_height,\n",
        "    fout_width,\n",
        "    env.BATCH,\n",
        "    env.BLOCK_OUT,\n",
        ")\n",
        "\n",
        "# Convolution reduction axes\n",
        "dy = te.reduce_axis((0, kernel_h), name=\"dy\")\n",
        "dx = te.reduce_axis((0, kernel_w), name=\"dx\")\n",
        "ic = te.reduce_axis((0, in_channels // env.BLOCK_IN), name=\"ic\")\n",
        "ic_tns = te.reduce_axis((0, env.BLOCK_IN), name=\"ic_tns\")\n",
        "\n",
        "# Input placeholder tensors\n",
        "data = te.placeholder(data_shape, name=\"data\", dtype=env.inp_dtype)\n",
        "kernel = te.placeholder(kernel_shape, name=\"kernel\", dtype=env.wgt_dtype)\n",
        "\n",
        "# Copy buffers:\n",
        "#   Apply spatial padding to input feature map\n",
        "data_buf = topi.nn.pad(data, [0, 0, pad_h, pad_w, 0, 0], name=\"data_buf\")\n",
        "kernel_buf = te.compute(kernel_shape, lambda *i: kernel(*i), \"kernel_buf\")\n",
        "\n",
        "# Declare 2D convolution\n",
        "res_conv = te.compute(\n",
        "    output_shape,\n",
        "    lambda bo, co, i, j, bi, ci: te.sum(\n",
        "        data_buf[bo, ic, i * stride_h + dy, j * stride_w + dx, bi, ic_tns].astype(env.acc_dtype)\n",
        "        * kernel_buf[co, ic, dy, dx, ci, ic_tns].astype(env.acc_dtype),\n",
        "        axis=[ic, dy, dx, ic_tns],\n",
        "    ),\n",
        "    name=\"res_conv\",\n",
        ")\n",
        "\n",
        "# Add shift stage for fix-point normalization\n",
        "res_shr = te.compute(output_shape, lambda *i: res_conv(*i) >> 8, name=\"res_shr\")\n",
        "\n",
        "# Apply clipping between (0, input max value)\n",
        "inp_max = (1 << (env.INP_WIDTH - 1)) - 1\n",
        "res_max = te.compute(output_shape, lambda *i: tvm.te.max(res_shr(*i), 0), \"res_max\")\n",
        "res_min = te.compute(output_shape, lambda *i: tvm.te.min(res_max(*i), inp_max), \"res_min\")\n",
        "\n",
        "# Result Tensor\n",
        "res = te.compute(output_shape, lambda *i: res_min(*i).astype(env.inp_dtype), name=\"res\")"
      ]
    },
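    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The shift and clipping stages declared above perform a simple fixed-point re-quantization followed by a rectified linear activation. A minimal numpy sketch of the same arithmetic, assuming the 32-bit accumulators and 8-bit input range used in this tutorial:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "acc = np.array([-70000, -3, 0, 300, 70000], dtype=\"int32\")  # example accumulator values\n",
        "inp_max = (1 << (8 - 1)) - 1                                # 127 for 8-bit activations\n",
        "out = np.clip(acc >> 8, 0, inp_max).astype(\"int8\")          # shift, ReLU, saturate\n",
        "print(out)  # -> 0, 0, 0, 1, 127\n",
        "```"
      ]
    },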
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 调度计算\n",
        "\n",
        "将看到一组必要的调度变换,以有效的方式将 2D 卷积映射到 VTA。这些包括:\n",
        "\n",
        "- 分块计算\n",
        "- 增加计算利用率(compute utilization)的虚拟线程(Virtual threading)\n",
        "- Lowering 到 VTA 硬件 intrinsics"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "@main = primfn(data_1: handle, kernel_1: handle, res_1: handle) -> ()\n",
            "  attr = {\"from_legacy_te_schedule\": True, \"global_symbol\": \"main\", \"tir.noalias\": True}\n",
            "  buffers = {data: Buffer(data_2: Pointer(int8), int8, [50176], []),\n",
            "             kernel: Buffer(kernel_2: Pointer(int8), int8, [589824], []),\n",
            "             res: Buffer(res_2: Pointer(int8), int8, [50176], [])}\n",
            "  buffer_map = {data_1: data, kernel_1: kernel, res_1: res}\n",
            "  preflattened_buffer_map = {data_1: data_3: Buffer(data_2, int8, [1, 16, 14, 14, 1, 16], []), kernel_1: kernel_3: Buffer(kernel_2, int8, [16, 16, 3, 3, 16, 16], []), res_1: res_3: Buffer(res_2, int8, [1, 16, 14, 14, 1, 16], [])} {\n",
            "  allocate(data_buf: Pointer(global int8), int8, [65536]), storage_scope = global;\n",
            "  allocate(kernel_buf: Pointer(global int8), int8, [589824]), storage_scope = global;\n",
            "  allocate(res_conv: Pointer(global int32), int32, [50176]), storage_scope = global {\n",
            "    for (i1: int32, 0, 16) {\n",
            "      for (i2: int32, 0, 16) {\n",
            "        for (i3: int32, 0, 16) {\n",
            "          for (i5: int32, 0, 16) {\n",
            "            let cse_var_1: int32 = (i3*16)\n",
            "            data_buf_1: Buffer(data_buf, int8, [65536], [])[((((i1*4096) + (i2*256)) + cse_var_1) + i5)] = @tir.if_then_else(((((1 <= i2) && (i2 < 15)) && (1 <= i3)) && (i3 < 15)), data[(((((i1*3136) + (i2*224)) + cse_var_1) + i5) - 240)], 0i8, dtype=int8)\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "    }\n",
            "    for (i0: int32, 0, 16) {\n",
            "      for (i1_1: int32, 0, 16) {\n",
            "        for (i2_1: int32, 0, 3) {\n",
            "          for (i3_1: int32, 0, 3) {\n",
            "            for (i4: int32, 0, 16) {\n",
            "              for (i5_1: int32, 0, 16) {\n",
            "                let cse_var_2: int32 = ((((((i0*36864) + (i1_1*2304)) + (i2_1*768)) + (i3_1*256)) + (i4*16)) + i5_1)\n",
            "                kernel_buf_1: Buffer(kernel_buf, int8, [589824], [])[cse_var_2] = kernel[cse_var_2]\n",
            "              }\n",
            "            }\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "    }\n",
            "    for (co: int32, 0, 16) {\n",
            "      for (i: int32, 0, 14) {\n",
            "        for (j: int32, 0, 14) {\n",
            "          for (ci: int32, 0, 16) {\n",
            "            res_conv_1: Buffer(res_conv, int32, [50176], [])[((((co*3136) + (i*224)) + (j*16)) + ci)] = 0\n",
            "            for (ic: int32, 0, 16) {\n",
            "              for (dy: int32, 0, 3) {\n",
            "                for (dx: int32, 0, 3) {\n",
            "                  for (ic_tns: int32, 0, 16) {\n",
            "                    let cse_var_4: int32 = (j*16)\n",
            "                    let cse_var_3: int32 = ((((co*3136) + (i*224)) + cse_var_4) + ci)\n",
            "                    res_conv_1[cse_var_3] = (res_conv_1[cse_var_3] + (cast(int32, data_buf_1[((((((ic*4096) + (i*256)) + (dy*256)) + cse_var_4) + (dx*16)) + ic_tns)])*cast(int32, kernel_buf_1[((((((co*36864) + (ic*2304)) + (dy*768)) + (dx*256)) + (ci*16)) + ic_tns)])))\n",
            "                  }\n",
            "                }\n",
            "              }\n",
            "            }\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "    }\n",
            "    for (i1_2: int32, 0, 16) {\n",
            "      for (i2_2: int32, 0, 14) {\n",
            "        for (i3_2: int32, 0, 14) {\n",
            "          for (i5_2: int32, 0, 16) {\n",
            "            let cse_var_5: int32 = ((((i1_2*3136) + (i2_2*224)) + (i3_2*16)) + i5_2)\n",
            "            res_conv_2: Buffer(res_conv, int32, [50176], [])[cse_var_5] = @tir.shift_right(res_conv_1[cse_var_5], 8, dtype=int32)\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "    }\n",
            "    for (i1_3: int32, 0, 16) {\n",
            "      for (i2_3: int32, 0, 14) {\n",
            "        for (i3_3: int32, 0, 14) {\n",
            "          for (i5_3: int32, 0, 16) {\n",
            "            let cse_var_6: int32 = ((((i1_3*3136) + (i2_3*224)) + (i3_3*16)) + i5_3)\n",
            "            res_conv_3: Buffer(res_conv, int32, [50176], [])[cse_var_6] = max(res_conv_2[cse_var_6], 0)\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "    }\n",
            "    for (i1_4: int32, 0, 16) {\n",
            "      for (i2_4: int32, 0, 14) {\n",
            "        for (i3_4: int32, 0, 14) {\n",
            "          for (i5_4: int32, 0, 16) {\n",
            "            let cse_var_7: int32 = ((((i1_4*3136) + (i2_4*224)) + (i3_4*16)) + i5_4)\n",
            "            res_conv_4: Buffer(res_conv, int32, [50176], [])[cse_var_7] = min(res_conv_3[cse_var_7], 127)\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "    }\n",
            "    for (i1_5: int32, 0, 16) {\n",
            "      for (i2_5: int32, 0, 14) {\n",
            "        for (i3_5: int32, 0, 14) {\n",
            "          for (i5_5: int32, 0, 16) {\n",
            "            let cse_var_8: int32 = ((((i1_5*3136) + (i2_5*224)) + (i3_5*16)) + i5_5)\n",
            "            res[cse_var_8] = cast(int8, res_conv_4[cse_var_8])\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "    }\n",
            "  }\n",
            "}\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "# Create TVM schedule\n",
        "s = te.create_schedule(res.op)\n",
        "# Let's look at the default TVM schedule\n",
        "print(tvm.lower(s, [data, kernel, res], simple_mode=True))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 分块计算\n",
        "\n",
        "默认情况下,2D 卷积太大,激活或卷积核权重无法同时适应 VTA 的 on-chip buffer。沿着输入通道、输出通道和高度空间维度应用分块。不沿宽度空间维度进行分块,因为它是 NCHW 布局中的最内层维度(因此,为了增加局部性,最好不要沿最内层维度进行分块)。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Let's define tiling sizes\n",
        "b_block = 1 // env.BATCH\n",
        "oc_block = 128 // env.BLOCK_OUT\n",
        "ic_block = 16 // env.BLOCK_IN\n",
        "h_block = 7\n",
        "w_block = 14\n",
        "\n",
        "# Tile the output tensor along the spatial and output channel dimensions\n",
        "# (since by default we are doing single batch inference, the split along\n",
        "#  the batch dimension has no effect)\n",
        "b, oc, y, x, b_tns, oc_tns = s[res].op.axis\n",
        "b_out, b_inn = s[res].split(b, factor=b_block)\n",
        "oc_out, oc_inn = s[res].split(oc, factor=oc_block)\n",
        "y_out, y_inn = s[res].split(y, factor=h_block)\n",
        "x_out, x_inn = s[res].split(x, factor=w_block)\n",
        "s[res].reorder(b_out, oc_out, y_out, x_out, b_inn, oc_inn, y_inn, x_inn, b_tns, oc_tns)\n",
        "\n",
        "# Move intermediate computation into each output compute tile\n",
        "s[res_conv].compute_at(s[res], x_out)\n",
        "s[res_shr].compute_at(s[res], x_out)\n",
        "s[res_max].compute_at(s[res], x_out)\n",
        "s[res_min].compute_at(s[res], x_out)\n",
        "\n",
        "# Apply additional loop split along reduction axis (input channel)\n",
        "b_inn, oc_inn, y_inn, x_inn, b_tns, oc_tns = s[res_conv].op.axis\n",
        "ic_out, ic_inn = s[res_conv].split(ic, factor=ic_block)\n",
        "\n",
        "# Reorder axes.\n",
        "# 1) Group the VTA tensor axes in the inner most position: b_tns, oc_tns, ic_tns\n",
        "#    to allow TVM to tensorize.\n",
        "# 2) We move the ic_out axis all the way out of the convolution loop to block\n",
        "#    along the reduction axis.\n",
        "# 3) Now we re-order the block axes: b_inn, oc_inn, y_inn, x_inn, ic_inn, dy, dx.\n",
        "#    VTA runtime/hardware requires us to write to a different output feature map\n",
        "#    location for every VTA tensor operation.\n",
        "#    This restriction requires us to order one of oc_inn, y_inn or x_inn right\n",
        "#    before b_tns, since they all affect output feature map indexing.\n",
        "#    Therefore, we choose to bring x_inn inside as shown below.\n",
        "s[res_conv].reorder(ic_out, b_inn, oc_inn, y_inn, ic_inn, dy, dx, x_inn, b_tns, oc_tns, ic_tns)"
      ]
    },
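    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why these blocking factors are reasonable, it helps to estimate the per-tile on-chip footprint they imply. The back-of-the-envelope sketch below is not part of the schedule; it assumes the `env` object exposes the `*_WIDTH` and `*_BUFF_SIZE` attributes of the standard VTA configuration, and it ignores the doubling introduced by virtual threading in the next step.\n",
        "\n",
        "```python\n",
        "# Rough per-tile footprint estimates implied by the blocking factors above.\n",
        "acc_elem = env.ACC_WIDTH // 8  # accumulator element size in bytes (int32 -> 4)\n",
        "inp_elem = env.INP_WIDTH // 8  # activation element size in bytes (int8 -> 1)\n",
        "wgt_elem = env.WGT_WIDTH // 8  # weight element size in bytes (int8 -> 1)\n",
        "\n",
        "# Accumulator tile: oc_block x h_block x w_block pixels, BATCH x BLOCK_OUT lanes each\n",
        "acc_tile = oc_block * h_block * w_block * env.BATCH * env.BLOCK_OUT * acc_elem\n",
        "# Weight tile per ic_out step: oc_block x ic_block kernels of kernel_h x kernel_w blocks\n",
        "wgt_tile = oc_block * ic_block * kernel_h * kernel_w * env.BLOCK_OUT * env.BLOCK_IN * wgt_elem\n",
        "# Input tile per ic_out step: h_block rows plus the kernel halo, at padded width\n",
        "inp_tile = ic_block * (h_block + kernel_h - 1) * (width + 2 * pad_w) * env.BATCH * env.BLOCK_IN * inp_elem\n",
        "\n",
        "print(f\"acc tile: {acc_tile} B vs on-chip accumulator buffer: {env.ACC_BUFF_SIZE} B\")\n",
        "print(f\"wgt tile: {wgt_tile} B vs on-chip weight buffer: {env.WGT_BUFF_SIZE} B\")\n",
        "print(f\"inp tile: {inp_tile} B vs on-chip input buffer: {env.INP_BUFF_SIZE} B\")\n",
        "```"
      ]
    },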
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 虚拟线程\n",
        "\n",
        "虚拟线程是一种在 VTA 硬件设计中增加任务级管道并行性的机制。换句话说,它通过隐藏内存访问延迟(hiding memory access latency)提高了计算资源的利用率。\n",
        "\n",
        "在下面的实现中,虚拟线程将工作分配给沿输出通道轴划分的两个线程。在下面的图中,展示了计算 2D 卷积时工作是如何分割的。\n",
        "\n",
        "```{image} images/virtual_threading.png\n",
        ":align: center\n",
        ":width: 480px\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "@main = primfn(data_1: handle, kernel_1: handle, res_1: handle) -> ()\n",
            "  attr = {\"from_legacy_te_schedule\": True, \"global_symbol\": \"main\", \"tir.noalias\": True}\n",
            "  buffers = {data: Buffer(data_2: Pointer(int8), int8, [50176], []),\n",
            "             kernel: Buffer(kernel_2: Pointer(int8), int8, [589824], []),\n",
            "             res: Buffer(res_2: Pointer(int8), int8, [50176], [])}\n",
            "  buffer_map = {data_1: data, kernel_1: kernel, res_1: res}\n",
            "  preflattened_buffer_map = {data_1: data_3: Buffer(data_2, int8, [1, 16, 14, 14, 1, 16], []), kernel_1: kernel_3: Buffer(kernel_2, int8, [16, 16, 3, 3, 16, 16], []), res_1: res_3: Buffer(res_2, int8, [1, 16, 14, 14, 1, 16], [])} {\n",
            "  allocate(data_buf: Pointer(global int8), int8, [65536]), storage_scope = global;\n",
            "  allocate(kernel_buf: Pointer(global int8), int8, [589824]), storage_scope = global;\n",
            "  allocate(res_conv: Pointer(global int32), int32, [25088]), storage_scope = global {\n",
            "    for (i1: int32, 0, 16) {\n",
            "      for (i2: int32, 0, 16) {\n",
            "        for (i3: int32, 0, 16) {\n",
            "          for (i5: int32, 0, 16) {\n",
            "            let cse_var_1: int32 = (i3*16)\n",
            "            data_buf_1: Buffer(data_buf, int8, [65536], [])[((((i1*4096) + (i2*256)) + cse_var_1) + i5)] = @tir.if_then_else(((((1 <= i2) && (i2 < 15)) && (1 <= i3)) && (i3 < 15)), data[(((((i1*3136) + (i2*224)) + cse_var_1) + i5) - 240)], 0i8, dtype=int8)\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "    }\n",
            "    for (i0: int32, 0, 16) {\n",
            "      for (i1_1: int32, 0, 16) {\n",
            "        for (i2_1: int32, 0, 3) {\n",
            "          for (i3_1: int32, 0, 3) {\n",
            "            for (i4: int32, 0, 16) {\n",
            "              for (i5_1: int32, 0, 16) {\n",
            "                let cse_var_2: int32 = ((((((i0*36864) + (i1_1*2304)) + (i2_1*768)) + (i3_1*256)) + (i4*16)) + i5_1)\n",
            "                kernel_buf_1: Buffer(kernel_buf, int8, [589824], [])[cse_var_2] = kernel[cse_var_2]\n",
            "              }\n",
            "            }\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "    }\n",
            "    for (i2.outer: int32, 0, 2) {\n",
            "      for (co.init: int32, 0, 8) {\n",
            "        for (i.init: int32, 0, 7) {\n",
            "          for (j.init: int32, 0, 14) {\n",
            "            for (ci.init: int32, 0, 16) {\n",
            "              let cse_var_3: int32 = ((((co.init*1568) + (i.init*224)) + (j.init*16)) + ci.init)\n",
            "               {\n",
            "                res_conv_1: Buffer(res_conv, int32, [157351936], [])[cse_var_3] = 0\n",
            "                res_conv_1[(cse_var_3 + 12544)] = 0\n",
            "              }\n",
            "            }\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "      for (ic.outer: int32, 0, 16) {\n",
            "        for (co: int32, 0, 8) {\n",
            "          for (i: int32, 0, 7) {\n",
            "            for (dy: int32, 0, 3) {\n",
            "              for (dx: int32, 0, 3) {\n",
            "                for (j: int32, 0, 14) {\n",
            "                  for (ci: int32, 0, 16) {\n",
            "                    for (ic_tns: int32, 0, 16) {\n",
            "                      let cse_var_8: int32 = (j*16)\n",
            "                      let cse_var_7: int32 = ((((co*1568) + (i*224)) + cse_var_8) + ci)\n",
            "                      let cse_var_6: int32 = (cse_var_7 + 12544)\n",
            "                      let cse_var_5: int32 = ((((((co*36864) + (ic.outer*2304)) + (dy*768)) + (dx*256)) + (ci*16)) + ic_tns)\n",
            "                      let cse_var_4: int32 = (((((((ic.outer*4096) + (i2.outer*1792)) + (i*256)) + (dy*256)) + cse_var_8) + (dx*16)) + ic_tns)\n",
            "                       {\n",
            "                        res_conv_1[cse_var_7] = (res_conv_1[cse_var_7] + (cast(int32, data_buf_1[cse_var_4])*cast(int32, kernel_buf_1[cse_var_5])))\n",
            "                        res_conv_1[cse_var_6] = (res_conv_1[cse_var_6] + (cast(int32, data_buf_1[cse_var_4])*cast(int32, kernel_buf_1[(cse_var_5 + 294912)])))\n",
            "                      }\n",
            "                    }\n",
            "                  }\n",
            "                }\n",
            "              }\n",
            "            }\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "      for (i1_2: int32, 0, 8) {\n",
            "        for (i2_2: int32, 0, 7) {\n",
            "          for (i3_2: int32, 0, 14) {\n",
            "            for (i5_2: int32, 0, 16) {\n",
            "              let cse_var_10: int32 = ((((i1_2*1568) + (i2_2*224)) + (i3_2*16)) + i5_2)\n",
            "              let cse_var_9: int32 = (cse_var_10 + 12544)\n",
            "               {\n",
            "                res_conv_2: Buffer(res_conv, int32, [157351936], [])[cse_var_10] = @tir.shift_right(res_conv_1[cse_var_10], 8, dtype=int32)\n",
            "                res_conv_2[cse_var_9] = @tir.shift_right(res_conv_1[cse_var_9], 8, dtype=int32)\n",
            "              }\n",
            "            }\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "      for (i1_3: int32, 0, 8) {\n",
            "        for (i2_3: int32, 0, 7) {\n",
            "          for (i3_3: int32, 0, 14) {\n",
            "            for (i5_3: int32, 0, 16) {\n",
            "              let cse_var_12: int32 = ((((i1_3*1568) + (i2_3*224)) + (i3_3*16)) + i5_3)\n",
            "              let cse_var_11: int32 = (cse_var_12 + 12544)\n",
            "               {\n",
            "                res_conv_3: Buffer(res_conv, int32, [157351936], [])[cse_var_12] = max(res_conv_2[cse_var_12], 0)\n",
            "                res_conv_3[cse_var_11] = max(res_conv_2[cse_var_11], 0)\n",
            "              }\n",
            "            }\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "      for (i1_4: int32, 0, 8) {\n",
            "        for (i2_4: int32, 0, 7) {\n",
            "          for (i3_4: int32, 0, 14) {\n",
            "            for (i5_4: int32, 0, 16) {\n",
            "              let cse_var_14: int32 = ((((i1_4*1568) + (i2_4*224)) + (i3_4*16)) + i5_4)\n",
            "              let cse_var_13: int32 = (cse_var_14 + 12544)\n",
            "               {\n",
            "                res_conv_4: Buffer(res_conv, int32, [157351936], [])[cse_var_14] = min(res_conv_3[cse_var_14], 127)\n",
            "                res_conv_4[cse_var_13] = min(res_conv_3[cse_var_13], 127)\n",
            "              }\n",
            "            }\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "      for (i1.inner: int32, 0, 8) {\n",
            "        for (i2.inner: int32, 0, 7) {\n",
            "          for (i3.inner: int32, 0, 14) {\n",
            "            for (i5_5: int32, 0, 16) {\n",
            "              let cse_var_18: int32 = (i2.inner*224)\n",
            "              let cse_var_17: int32 = (i3.inner*16)\n",
            "              let cse_var_16: int32 = ((((i1.inner*1568) + cse_var_18) + cse_var_17) + i5_5)\n",
            "              let cse_var_15: int32 = (((((i1.inner*3136) + (i2.outer*1568)) + cse_var_18) + cse_var_17) + i5_5)\n",
            "               {\n",
            "                res[cse_var_15] = cast(int8, res_conv_4[cse_var_16])\n",
            "                res[(cse_var_15 + 25088)] = cast(int8, res_conv_4[(cse_var_16 + 12544)])\n",
            "              }\n",
            "            }\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "    }\n",
            "  }\n",
            "}\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "# VTA only supports 2 virtual threads\n",
        "v_threads = 2\n",
        "\n",
        "# Perform virtual thread split along output channel outer axis\n",
        "_, tx = s[res].split(oc_out, factor=v_threads)\n",
        "s[res].reorder(tx, b_out)\n",
        "s[res].bind(tx, te.thread_axis(\"cthread\"))\n",
        "\n",
        "# Let's look at the current TVM schedule after blocking and virtual threading\n",
        "print(tvm.lower(s, [data, kernel, res], simple_mode=True))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Lowering Copies 到 DMA Transfers\n",
        "\n",
        "接下来,设置相应的 on-chip VTA SRAM buffers 的 buffers 作用域。将 load 循环移动到 2D 卷积计算循环,以 stage 内存加载,以便它们适合 on-chip SRAM buffers。最后,用 DMA 复制 pragma 注解了 load/store 循环外轴,以便在 VTA 上执行大容量内存传输。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Set scope of SRAM buffers\n",
        "s[data_buf].set_scope(env.inp_scope)\n",
        "s[kernel_buf].set_scope(env.wgt_scope)\n",
        "s[res_conv].set_scope(env.acc_scope)\n",
        "s[res_shr].set_scope(env.acc_scope)\n",
        "s[res_min].set_scope(env.acc_scope)\n",
        "s[res_max].set_scope(env.acc_scope)\n",
        "\n",
        "# Block data and kernel cache reads\n",
        "s[data_buf].compute_at(s[res_conv], ic_out)\n",
        "s[kernel_buf].compute_at(s[res_conv], ic_out)\n",
        "\n",
        "# Use DMA copy pragma on DRAM->SRAM operations\n",
        "s[data_buf].pragma(s[data_buf].op.axis[0], env.dma_copy)\n",
        "s[kernel_buf].pragma(s[kernel_buf].op.axis[0], env.dma_copy)\n",
        "\n",
        "# Use DMA copy pragma on SRAM->DRAM operation in each result block\n",
        "# (this implies that these copies should be performed along b_inn,\n",
        "# or result axis 4)\n",
        "s[res].pragma(s[res].op.axis[4], env.dma_copy)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Lowering 计算到 VTA 计算 Intrinsics\n",
        "\n",
        "最后阶段是通过将二维卷积映射为张量 intrinsics,并将位移和剪切计算映射为向量 ALU,从而将计算循环 lower 到 VTA 硬件 intrinsics。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "@main = primfn(data_1: handle, kernel_1: handle, res_1: handle) -> ()\n",
            "  attr = {\"from_legacy_te_schedule\": True, \"global_symbol\": \"main\", \"tir.noalias\": True}\n",
            "  buffers = {data: Buffer(data_2: Pointer(int8), int8, [50176], []),\n",
            "             kernel: Buffer(kernel_2: Pointer(int8), int8, [589824], []),\n",
            "             res: Buffer(res_2: Pointer(int8), int8, [50176], [])}\n",
            "  buffer_map = {data_1: data, kernel_1: kernel, res_1: res}\n",
            "  preflattened_buffer_map = {data_1: data_3: Buffer(data_2, int8, [1, 16, 14, 14, 1, 16], []), kernel_1: kernel_3: Buffer(kernel_2, int8, [16, 16, 3, 3, 16, 16], []), res_1: res_3: Buffer(res_2, int8, [1, 16, 14, 14, 1, 16], [])} {\n",
            "  @tir.vta.coproc_dep_push(3, 2, dtype=int32)\n",
            "  @tir.vta.coproc_dep_push(3, 2, dtype=int32)\n",
            "  for (i2.outer: int32, 0, 2) {\n",
            "    for (cthread.s: int32, 0, 2) {\n",
            "      attr [IterVar(vta: int32, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_scope\" = 2 {\n",
            "        @tir.vta.coproc_dep_pop(3, 2, dtype=int32)\n",
            "        attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_uop_scope\" = \"VTAPushGEMMOp\" {\n",
            "          @tir.call_extern(\"VTAUopLoopBegin\", 8, 98, 0, 0, dtype=int32)\n",
            "          @tir.call_extern(\"VTAUopLoopBegin\", 7, 14, 0, 0, dtype=int32)\n",
            "          for (j.init: int32, 0, 14) {\n",
            "            @tir.vta.uop_push(0, 1, ((cthread.s*784) + j.init), 0, 0, 0, 0, 0, dtype=int32)\n",
            "          }\n",
            "          @tir.call_extern(\"VTAUopLoopEnd\", dtype=int32)\n",
            "          @tir.call_extern(\"VTAUopLoopEnd\", dtype=int32)\n",
            "        }\n",
            "        @tir.vta.coproc_dep_push(2, 1, dtype=int32)\n",
            "      }\n",
            "    }\n",
            "    for (ic.outer: int32, 0, 16) {\n",
            "      let cse_var_6: int32 = (i2.outer*7)\n",
            "      let cse_var_5: int32 = (ic.outer*9)\n",
            "      let cse_var_4: int32 = max((1 - cse_var_6), 0)\n",
            "      let cse_var_3: int32 = max((cse_var_6 - 6), 0)\n",
            "      let cse_var_2: int32 = ((9 - cse_var_4) - cse_var_3)\n",
            "      let cse_var_1: int32 = ((((ic.outer*196) + (i2.outer*98)) + (cse_var_4*14)) - 14)\n",
            "       {\n",
            "        attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_scope\" = 1 {\n",
            "          @tir.vta.coproc_dep_pop(2, 1, dtype=int32)\n",
            "          @tir.call_extern(\"VTALoadBuffer2D\", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), data_2, cse_var_1, 14, cse_var_2, 14, 1, cse_var_4, 1, cse_var_3, 0, 2, dtype=int32)\n",
            "          @tir.call_extern(\"VTALoadBuffer2D\", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), kernel_2, cse_var_5, 9, 8, 144, 0, 0, 0, 0, 0, 1, dtype=int32)\n",
            "          @tir.vta.coproc_dep_push(1, 2, dtype=int32)\n",
            "        }\n",
            "        attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_scope\" = 1 {\n",
            "          @tir.vta.coproc_dep_pop(2, 1, dtype=int32)\n",
            "          @tir.call_extern(\"VTALoadBuffer2D\", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), data_2, cse_var_1, 14, cse_var_2, 14, 1, cse_var_4, 1, cse_var_3, 144, 2, dtype=int32)\n",
            "          @tir.call_extern(\"VTALoadBuffer2D\", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), kernel_2, (cse_var_5 + 1152), 9, 8, 144, 0, 0, 0, 0, 72, 1, dtype=int32)\n",
            "          @tir.vta.coproc_dep_push(1, 2, dtype=int32)\n",
            "        }\n",
            "        for (cthread.s_1: int32, 0, 2) {\n",
            "          attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_scope\" = 2 {\n",
            "            @tir.vta.coproc_dep_pop(1, 2, dtype=int32)\n",
            "            attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_uop_scope\" = \"VTAPushGEMMOp\" {\n",
            "              @tir.call_extern(\"VTAUopLoopBegin\", 8, 98, 0, 9, dtype=int32)\n",
            "              @tir.call_extern(\"VTAUopLoopBegin\", 7, 14, 16, 0, dtype=int32)\n",
            "              for (dy: int32, 0, 3) {\n",
            "                for (dx: int32, 0, 3) {\n",
            "                  for (j: int32, 0, 14) {\n",
            "                    @tir.vta.uop_push(0, 0, ((cthread.s_1*784) + j), ((((cthread.s_1*144) + (dy*16)) + j) + dx), (((cthread.s_1*72) + (dy*3)) + dx), 0, 0, 0, dtype=int32)\n",
            "                  }\n",
            "                }\n",
            "              }\n",
            "              @tir.call_extern(\"VTAUopLoopEnd\", dtype=int32)\n",
            "              @tir.call_extern(\"VTAUopLoopEnd\", dtype=int32)\n",
            "            }\n",
            "            @tir.vta.coproc_dep_push(2, 1, dtype=int32)\n",
            "          }\n",
            "        }\n",
            "      }\n",
            "    }\n",
            "    @tir.vta.coproc_dep_pop(2, 1, dtype=int32)\n",
            "    @tir.vta.coproc_dep_pop(2, 1, dtype=int32)\n",
            "    for (cthread.s_2: int32, 0, 2) {\n",
            "      let cse_var_7: int32 = (cthread.s_2*784)\n",
            "      attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_scope\" = 2 {\n",
            "        attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_uop_scope\" = \"VTAPushALUOp\" {\n",
            "          @tir.call_extern(\"VTAUopLoopBegin\", 784, 1, 1, 0, dtype=int32)\n",
            "          @tir.vta.uop_push(1, 0, cse_var_7, cse_var_7, 0, 3, 1, 8, dtype=int32)\n",
            "          @tir.call_extern(\"VTAUopLoopEnd\", dtype=int32)\n",
            "        }\n",
            "        attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_uop_scope\" = \"VTAPushALUOp\" {\n",
            "          @tir.call_extern(\"VTAUopLoopBegin\", 784, 1, 1, 0, dtype=int32)\n",
            "          @tir.vta.uop_push(1, 0, cse_var_7, cse_var_7, 0, 1, 1, 0, dtype=int32)\n",
            "          @tir.call_extern(\"VTAUopLoopEnd\", dtype=int32)\n",
            "        }\n",
            "        attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_uop_scope\" = \"VTAPushALUOp\" {\n",
            "          @tir.call_extern(\"VTAUopLoopBegin\", 784, 1, 1, 0, dtype=int32)\n",
            "          @tir.vta.uop_push(1, 0, cse_var_7, cse_var_7, 0, 0, 1, 127, dtype=int32)\n",
            "          @tir.call_extern(\"VTAUopLoopEnd\", dtype=int32)\n",
            "        }\n",
            "        @tir.vta.coproc_dep_push(2, 3, dtype=int32)\n",
            "      }\n",
            "    }\n",
            "    for (cthread.s_3: int32, 0, 2) {\n",
            "      attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_scope\" = 3 {\n",
            "        @tir.vta.coproc_dep_pop(2, 3, dtype=int32)\n",
            "        for (i1.inner: int32, 0, 8) {\n",
            "          for (i2.inner: int32, 0, 7) {\n",
            "            for (i3.inner: int32, 0, 14) {\n",
            "              let cse_var_8: int32 = (i2.inner*14)\n",
            "              @tir.call_extern(\"VTAStoreBuffer2D\", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), ((((cthread.s_3*784) + (i1.inner*98)) + cse_var_8) + i3.inner), 4, res_2, (((((cthread.s_3*1568) + (i1.inner*196)) + (i2.outer*98)) + cse_var_8) + i3.inner), 1, 1, 1, dtype=int32)\n",
            "            }\n",
            "          }\n",
            "        }\n",
            "        @tir.vta.coproc_dep_push(3, 2, dtype=int32)\n",
            "      }\n",
            "    }\n",
            "  }\n",
            "  @tir.vta.coproc_dep_pop(3, 2, dtype=int32)\n",
            "  @tir.vta.coproc_dep_pop(3, 2, dtype=int32)\n",
            "  @tir.vta.coproc_sync(, dtype=int32)\n",
            "}\n",
            "\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "[16:02:31] /media/pc/data/4tb/lxw/books/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=128\n"
          ]
        }
      ],
      "source": [
        "# Apply tensorization over the batch tensor tile axis\n",
        "s[res_conv].tensorize(b_tns, env.gemm)\n",
        "\n",
        "# Add an ALU pragma over the shift and clipping operations\n",
        "s[res_shr].pragma(s[res_shr].op.axis[0], env.alu)\n",
        "s[res_min].pragma(s[res_min].op.axis[0], env.alu)\n",
        "s[res_max].pragma(s[res_max].op.axis[0], env.alu)\n",
        "\n",
        "# Let's look at the final lowered TVM schedule after lowering memory\n",
        "# loads/stores down to DMA copy intrinsics, and the computation down to\n",
        "# VTA compute intrinsics.\n",
        "print(vta.lower(s, [data, kernel, res], simple_mode=True))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## TVM 计算和验证\n",
        "\n",
        "在指定调度之后,可以将其编译为 TVM 函数。保存模块,这样就可以通过 RPC 发送它。运行该函数并对 numpy 实现进行验证,以确保其正确性。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "[16:02:32] /media/pc/data/4tb/lxw/books/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=128\n",
            "/media/workspace/anaconda3/envs/mx/lib/python3.10/site-packages/tvm/driver/build_module.py:263: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Execution statistics:\n",
            "\tinp_load_nbytes :           114688\n",
            "\twgt_load_nbytes :          1179648\n",
            "\tacc_load_nbytes :                0\n",
            "\tuop_load_nbytes :             1144\n",
            "\tout_store_nbytes:            50176\n",
            "\tgemm_counter    :           451584\n",
            "\talu_counter     :             9408\n",
            "Successful 2D convolution test!\n"
          ]
        }
      ],
      "source": [
        "# This library facilitates 2D convolution testing\n",
        "from tvm.topi.testing import conv2d_nchw_python\n",
        "\n",
        "# Compile the TVM module\n",
        "with vta.build_config(disabled_pass={\"tir.CommonSubexprElimTIR\"}):\n",
        "    my_conv = vta.build(\n",
        "        s, [data, kernel, res], tvm.target.Target(\"ext_dev\", host=env.target_host), name=\"my_conv\"\n",
        "    )\n",
        "temp = utils.tempdir()\n",
        "my_conv.save(temp.relpath(\"conv2d.o\"))\n",
        "remote.upload(temp.relpath(\"conv2d.o\"))\n",
        "f = remote.load_module(\"conv2d.o\")\n",
        "\n",
        "# Get the remote device context\n",
        "ctx = remote.ext_dev(0)\n",
        "\n",
        "# Initialize the data and kernel arrays randomly in the int range\n",
        "# of (-128, 128] in NCHW layout\n",
        "data_np = np.random.randint(-128, 128, size=(batch_size, in_channels, height, width)).astype(\n",
        "    data.dtype\n",
        ")\n",
        "kernel_np = np.random.randint(\n",
        "    -128, 128, size=(out_channels, in_channels, kernel_h, kernel_w)\n",
        ").astype(kernel.dtype)\n",
        "\n",
        "# Apply packing to the data and kernel arrays from a 2D NCHW\n",
        "# to a 4D NCHWnc packed layout\n",
        "data_packed = data_np.reshape(\n",
        "    batch_size // env.BATCH, env.BATCH, in_channels // env.BLOCK_IN, env.BLOCK_IN, height, width\n",
        ").transpose((0, 2, 4, 5, 1, 3))\n",
        "\n",
        "kernel_packed = kernel_np.reshape(\n",
        "    out_channels // env.BLOCK_OUT,\n",
        "    env.BLOCK_OUT,\n",
        "    in_channels // env.BLOCK_IN,\n",
        "    env.BLOCK_IN,\n",
        "    kernel_h,\n",
        "    kernel_w,\n",
        ").transpose((0, 2, 4, 5, 1, 3))\n",
        "\n",
        "# Format the input/output arrays with tvm.nd.array to the DLPack standard\n",
        "data_nd = tvm.nd.array(data_packed, ctx)\n",
        "kernel_nd = tvm.nd.array(kernel_packed, ctx)\n",
        "res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx)\n",
        "\n",
        "# Clear stats\n",
        "if env.TARGET in [\"sim\", \"tsim\"]:\n",
        "    simulator.clear_stats()\n",
        "\n",
        "# Invoke the module to perform the computation\n",
        "f(data_nd, kernel_nd, res_nd)\n",
        "\n",
        "# Verify against numpy implementation\n",
        "res_ref = conv2d_nchw_python(\n",
        "    data_np.astype(env.acc_dtype),\n",
        "    kernel_np.astype(env.acc_dtype),\n",
        "    (stride_h, stride_w),\n",
        "    (pad_h, pad_w),\n",
        ").astype(env.acc_dtype)\n",
        "res_ref = res_ref >> env.INP_WIDTH\n",
        "res_ref = np.clip(res_ref, 0, inp_max)\n",
        "res_ref = res_ref.astype(res.dtype)\n",
        "res_ref = res_ref.reshape(\n",
        "    (\n",
        "        batch_size // env.BATCH,\n",
        "        env.BATCH,\n",
        "        out_channels // env.BLOCK_OUT,\n",
        "        env.BLOCK_OUT,\n",
        "        fout_height,\n",
        "        fout_width,\n",
        "    )\n",
        ").transpose((0, 2, 4, 5, 1, 3))\n",
        "tvm.testing.assert_allclose(res_ref, res_nd.numpy())\n",
        "\n",
        "# Print stats\n",
        "if env.TARGET in [\"sim\", \"tsim\"]:\n",
        "    sim_stats = simulator.stats()\n",
        "    print(\"Execution statistics:\")\n",
        "    for k, v in sim_stats.items():\n",
        "        print(\"\\t{:<16}: {:>16}\".format(k, v))\n",
        "\n",
        "print(\"Successful 2D convolution test!\")"
      ]
    },
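    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Beyond correctness, we may also want a rough performance number. Below is a sketch of how the compiled kernel could be timed with TVM's `time_evaluator`, reusing the module `f`, the context `ctx`, and the arrays defined above. On the simulator the wall-clock time is not meaningful, so this is mainly useful on the `pynq` target.\n",
        "\n",
        "```python\n",
        "# Time the kernel with TVM's built-in evaluator (a sketch, not run here).\n",
        "timer = f.time_evaluator(\"my_conv\", ctx, number=4)\n",
        "cost = timer(data_nd, kernel_nd, res_nd)\n",
        "\n",
        "# Each output element accumulates kernel_h * kernel_w * in_channels multiply-adds (2 ops each).\n",
        "num_ops = (\n",
        "    2 * batch_size * fout_height * fout_width\n",
        "    * out_channels * in_channels * kernel_h * kernel_w\n",
        ")\n",
        "print(\"Mean execution time: %.3f ms\" % (cost.mean * 1e3))\n",
        "print(\"Estimated throughput: %.2f GOPS\" % (num_ops / cost.mean / 1e9))\n",
        "```"
      ]
    },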
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 小结\n",
        "\n",
        "本教程演示如何使用 TVM 调度原语 lower 硬件加速器 intrinsics 的 2D 卷积,利用特定于硬件的优化,比如使用带虚拟线程的隐藏延迟。"
      ]
    }
  ],
  "metadata": {
    "interpreter": {
      "hash": "ee40e4cbda3c4716866f133b45765e0887afdbc9aa3bd872ab229f889d521355"
    },
    "kernelspec": {
      "display_name": "Python 3.10.4 ('mx')",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}