{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "(vta-mat-mult-opt)=\n", "# Matrix Multiply Blocking\n", "\n", "\n", "**Original author**: [Thierry Moreau](https://homes.cs.washington.edu/~moreau/)\n", "\n", "This tutorial provides an overview of how to use TVM to map matrix multiplication efficiently onto the VTA design. We recommend going through the {ref}`basic-mat-mult` tutorial first.\n", "\n", "In this tutorial, we demonstrate TVM schedule optimizations that break large neural network operators down into smaller blocks, so that the computation fits within limited hardware accelerator resources.\n", "\n", "## RPC Setup\n", "\n", "We start by programming the Pynq's FPGA and building its RPC runtime." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import os\n", "import tvm\n", "from tvm import te\n", "import vta\n", "import numpy as np\n", "from tvm import rpc\n", "from tvm.contrib import utils\n", "from vta.testing import simulator\n", "\n", "# Load VTA parameters from the 3rdparty/vta-hw/config/vta_config.json file\n", "env = vta.get_env()\n", "\n", "# We read the Pynq RPC host IP address and port number from the OS environment\n", "host = os.environ.get(\"VTA_RPC_HOST\", \"192.168.2.99\")\n", "port = int(os.environ.get(\"VTA_RPC_PORT\", \"9091\"))\n", "\n", "# We configure both the bitstream and the runtime system on the Pynq\n", "# to match the VTA configuration specified by the vta_config.json file.\n", "if env.TARGET == \"pynq\":\n", "\n", "    # Make sure that TVM was compiled with RPC=1\n", "    assert tvm.runtime.enabled(\"rpc\")\n", "    remote = rpc.connect(host, port)\n", "\n", "    # Reconfigure the JIT runtime\n", "    vta.reconfig_runtime(remote)\n", "\n", "    # Program the FPGA with a pre-compiled VTA bitstream.\n", "    # You can program the FPGA with your own custom bitstream\n", "    # by passing the path to the bitstream file instead of None.\n", "    vta.program_fpga(remote, bitstream=None)\n", "\n", "# In simulation mode, host the RPC server locally.\n", "elif env.TARGET in [\"sim\", \"tsim\"]:\n", "    remote = rpc.LocalSession()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Declaring the Computation\n", "\n", "As a first step, we need to describe the matrix multiplication computation. We define it as the computation found in a fully connected layer, specified by its batch size, input channels, and output channels. These must be integer multiples of the VTA tensor shape parameters `BATCH`, `BLOCK_IN`, and 
`BLOCK_OUT`, respectively.\n", "\n", "We add extra operators to the matrix multiplication that shift and clip the output, mimicking a fixed-point matrix multiplication followed by a rectified linear activation (ReLU). The TVM dataflow graph of the fully connected layer is shown below:\n", "\n", "```{image} images/fc_dataflow.png\n", ":align: center\n", "```\n", "\n", "This computation is intentionally too large to fit into VTA's on-chip buffers all at once. Therefore, in the scheduling phase we rely on computation blocking to break it down into manageable chunks." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Fully connected layer dimensions: 1024 x 1024\n", "batch_size = 1\n", "in_channels = 1024\n", "out_channels = 1024\n", "assert batch_size % env.BATCH == 0\n", "assert in_channels % env.BLOCK_IN == 0\n", "assert out_channels % env.BLOCK_OUT == 0\n", "\n", "# Let's derive the tiled input tensor shapes\n", "data_shape = (batch_size // env.BATCH, in_channels // env.BLOCK_IN, env.BATCH, env.BLOCK_IN)\n", "weight_shape = (\n", "    out_channels // env.BLOCK_OUT,\n", "    in_channels // env.BLOCK_IN,\n", "    env.BLOCK_OUT,\n", "    env.BLOCK_IN,\n", ")\n", "output_shape = (batch_size // env.BATCH, out_channels // env.BLOCK_OUT, env.BATCH, env.BLOCK_OUT)\n", "num_ops = in_channels * out_channels * batch_size * 2\n", "\n", "# Reduction axes\n", "ic = te.reduce_axis((0, in_channels // env.BLOCK_IN), name=\"ic\")\n", "ic_tns = te.reduce_axis((0, env.BLOCK_IN), name=\"ic_tns\")\n", "\n", "# Input placeholder tensors\n", "data = te.placeholder(data_shape, name=\"data\", dtype=env.inp_dtype)\n", "weight = te.placeholder(weight_shape, name=\"weight\", dtype=env.wgt_dtype)\n", "\n", "# Copy buffers\n", "data_buf = te.compute(data_shape, lambda *i: data(*i), \"data_buf\")\n", "weight_buf = te.compute(weight_shape, lambda *i: weight(*i), \"weight_buf\")\n", "\n", "# Declare matrix multiply computation\n", "res_gemm = te.compute(\n", "    output_shape,\n", "    lambda bo, co, bi, ci: te.sum(\n", "        data_buf[bo, ic, bi, ic_tns].astype(env.acc_dtype)\n", "        * weight_buf[co, ic, ci, ic_tns].astype(env.acc_dtype),\n", "        axis=[ic, ic_tns],\n", "    ),\n", "    name=\"res_gem\",\n", ")\n", "\n", "# Add shift stage for fixed-point 
normalization\n", "res_shr = te.compute(output_shape, lambda *i: res_gemm(*i) >> env.INP_WIDTH, name=\"res_shr\")\n", "\n", "# Apply clipping between (0, input max value)\n", "inp_max = (1 << (env.INP_WIDTH - 1)) - 1\n", "res_max = te.compute(output_shape, lambda *i: tvm.te.max(res_shr(*i), 0), \"res_max\")\n", "res_min = te.compute(output_shape, lambda *i: tvm.te.min(res_max(*i), inp_max), \"res_min\")\n", "\n", "# Apply typecast to input data type before sending results back\n", "res = te.compute(output_shape, lambda *i: res_min(*i).astype(env.inp_dtype), name=\"res\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Scheduling the Computation\n", "\n", "We now look at the set of schedule transformations needed to map the matrix multiplication onto VTA efficiently. These include:\n", "\n", "- Computation blocking\n", "- Lowering to VTA hardware intrinsics" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "@main = primfn(data_1: handle, weight_1: handle, res_1: handle) -> ()\n", " attr = {\"from_legacy_te_schedule\": True, \"global_symbol\": \"main\", \"tir.noalias\": True}\n", " buffers = {data: Buffer(data_2: Pointer(int8), int8, [1024], []),\n", " weight: Buffer(weight_2: Pointer(int8), int8, [1048576], []),\n", " res: Buffer(res_2: Pointer(int8), int8, [1024], [])}\n", " buffer_map = {data_1: data, weight_1: weight, res_1: res}\n", " preflattened_buffer_map = {data_1: data_3: Buffer(data_2, int8, [1, 64, 1, 16], []), weight_1: weight_3: Buffer(weight_2, int8, [64, 64, 16, 16], []), res_1: res_3: Buffer(res_2, int8, [1, 64, 1, 16], [])} {\n", " allocate(data_buf: Pointer(global int8), int8, [1024]), storage_scope = global;\n", " allocate(weight_buf: Pointer(global int8), int8, [1048576]), storage_scope = global;\n", " allocate(res_gem: Pointer(global int32), int32, [1024]), storage_scope = global {\n", " for (i1: int32, 0, 64) {\n", " for (i3: int32, 0, 16) {\n", " let cse_var_1: int32 = ((i1*16) + i3)\n", " data_buf_1: Buffer(data_buf, int8, [1024], 
[])[cse_var_1] = data[cse_var_1]\n", " }\n", " }\n", " for (i0: int32, 0, 64) {\n", " for (i1_1: int32, 0, 64) {\n", " for (i2: int32, 0, 16) {\n", " for (i3_1: int32, 0, 16) {\n", " let cse_var_2: int32 = ((((i0*16384) + (i1_1*256)) + (i2*16)) + i3_1)\n", " weight_buf_1: Buffer(weight_buf, int8, [1048576], [])[cse_var_2] = weight[cse_var_2]\n", " }\n", " }\n", " }\n", " }\n", " for (co: int32, 0, 64) {\n", " for (ci: int32, 0, 16) {\n", " res_gem_1: Buffer(res_gem, int32, [1024], [])[((co*16) + ci)] = 0\n", " for (ic: int32, 0, 64) {\n", " for (ic_tns: int32, 0, 16) {\n", " let cse_var_3: int32 = ((co*16) + ci)\n", " res_gem_1[cse_var_3] = (res_gem_1[cse_var_3] + (cast(int32, data_buf_1[((ic*16) + ic_tns)])*cast(int32, weight_buf_1[((((co*16384) + (ic*256)) + (ci*16)) + ic_tns)])))\n", " }\n", " }\n", " }\n", " }\n", " for (i1_2: int32, 0, 64) {\n", " for (i3_2: int32, 0, 16) {\n", " let cse_var_4: int32 = ((i1_2*16) + i3_2)\n", " res_gem_2: Buffer(res_gem, int32, [1024], [])[cse_var_4] = @tir.shift_right(res_gem_1[cse_var_4], 8, dtype=int32)\n", " }\n", " }\n", " for (i1_3: int32, 0, 64) {\n", " for (i3_3: int32, 0, 16) {\n", " let cse_var_5: int32 = ((i1_3*16) + i3_3)\n", " res_gem_3: Buffer(res_gem, int32, [1024], [])[cse_var_5] = max(res_gem_2[cse_var_5], 0)\n", " }\n", " }\n", " for (i1_4: int32, 0, 64) {\n", " for (i3_4: int32, 0, 16) {\n", " let cse_var_6: int32 = ((i1_4*16) + i3_4)\n", " res_gem_4: Buffer(res_gem, int32, [1024], [])[cse_var_6] = min(res_gem_3[cse_var_6], 127)\n", " }\n", " }\n", " for (i1_5: int32, 0, 64) {\n", " for (i3_5: int32, 0, 16) {\n", " let cse_var_7: int32 = ((i1_5*16) + i3_5)\n", " res[cse_var_7] = cast(int8, res_gem_4[cse_var_7])\n", " }\n", " }\n", " }\n", "}\n", "\n", "\n" ] } ], "source": [ "# Create TVM schedule\n", "s = te.create_schedule(res.op)\n", "# Let's look at the default TVM schedule\n", "print(tvm.lower(s, [data, weight, res], simple_mode=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 
Blocking the Computation\n", "\n", "By default, the matrix multiplication is too large for the activations or weights to fit into VTA's on-chip buffers all at once. We block the (1, 1024) × (1024, 1024) matrix multiplication into smaller (1, 256) × (256, 256) matrix multiplications so that the intermediate tensors fit in the accelerator's on-chip SRAM. This approach is similar to the blocking techniques applied on CPUs and GPUs to increase the cache hit rate.\n", "\n", "We block along each axis (the batch axis is left untouched since we are performing single-batch inference). We also leave the innermost tensorization axes as-is so that TVM can pattern-match them for tensorization. The diagram below shows the outcome of blocking on the computation schedule:\n", "\n", "```{image} images/blocking.png\n", ":align: center\n", ":width: 480px\n", "```\n", "\n", "````{admonition} After loop splitting and reordering, the code is equivalent to the following pseudocode. We ignore the batch axis since this example performs single-batch inference only:\n", ":class: alert alert-info\n", "```c\n", "for (int oc_out = 0; oc_out < 4; ++oc_out) {\n", "  // Initialization loop\n", "  for (int oc_inn = 0; oc_inn < 16; ++oc_inn) {\n", "    for (int oc_tns = 0; oc_tns < 16; ++oc_tns) {\n", "      int j = (oc_out * 16 + oc_inn) * 16 + oc_tns;\n", "      C[0][j] = 0;\n", "    }\n", "  }\n", "  for (int ic_out = 0; ic_out < 4; ++ic_out) {\n", "    // Block loop\n", "    for (int oc_inn = 0; oc_inn < 16; ++oc_inn) {\n", "      for (int ic_inn = 0; ic_inn < 16; ++ic_inn) {\n", "        // Tensorization loop\n", "        for (int oc_tns = 0; oc_tns < 16; ++oc_tns) {\n", "          for (int ic_tns = 0; ic_tns < 16; ++ic_tns) {\n", "            int i = (ic_out * 16 + ic_inn) * 16 + ic_tns;\n", "            int j = (oc_out * 16 + oc_inn) * 16 + oc_tns;\n", "            C[0][j] = C[0][j] + A[0][i] * B[j][i];\n", "          }\n", "        }\n", "      }\n", "    }\n", "  }\n", "}\n", "```\n", "````" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "@main = primfn(data_1: handle, weight_1: handle, res_1: handle) -> ()\n", " attr = {\"from_legacy_te_schedule\": True, \"global_symbol\": \"main\", \"tir.noalias\": True}\n", " buffers = {data: Buffer(data_2: Pointer(int8), int8, [1024], []),\n", " weight: Buffer(weight_2: Pointer(int8), int8, [1048576], []),\n", " res: Buffer(res_2: Pointer(int8), int8, [1024], [])}\n", " buffer_map = {data_1: data, weight_1: weight, res_1: res}\n", " 
preflattened_buffer_map = {data_1: data_3: Buffer(data_2, int8, [1, 64, 1, 16], []), weight_1: weight_3: Buffer(weight_2, int8, [64, 64, 16, 16], []), res_1: res_3: Buffer(res_2, int8, [1, 64, 1, 16], [])} {\n", " allocate(data_buf: Pointer(global int8), int8, [1024]), storage_scope = global;\n", " allocate(weight_buf: Pointer(global int8), int8, [1048576]), storage_scope = global;\n", " allocate(res_gem: Pointer(global int32), int32, [256]), storage_scope = global {\n", " for (i1: int32, 0, 64) {\n", " for (i3: int32, 0, 16) {\n", " let cse_var_1: int32 = ((i1*16) + i3)\n", " data_buf_1: Buffer(data_buf, int8, [1024], [])[cse_var_1] = data[cse_var_1]\n", " }\n", " }\n", " for (i0: int32, 0, 64) {\n", " for (i1_1: int32, 0, 64) {\n", " for (i2: int32, 0, 16) {\n", " for (i3_1: int32, 0, 16) {\n", " let cse_var_2: int32 = ((((i0*16384) + (i1_1*256)) + (i2*16)) + i3_1)\n", " weight_buf_1: Buffer(weight_buf, int8, [1048576], [])[cse_var_2] = weight[cse_var_2]\n", " }\n", " }\n", " }\n", " }\n", " for (i1.outer: int32, 0, 4) {\n", " for (co.init: int32, 0, 16) {\n", " for (ci.init: int32, 0, 16) {\n", " res_gem_1: Buffer(res_gem, int32, [256], [])[((co.init*16) + ci.init)] = 0\n", " }\n", " }\n", " for (ic.outer: int32, 0, 4) {\n", " for (co: int32, 0, 16) {\n", " for (ic.inner: int32, 0, 16) {\n", " for (ci: int32, 0, 16) {\n", " for (ic_tns: int32, 0, 16) {\n", " let cse_var_3: int32 = ((co*16) + ci)\n", " res_gem_1[cse_var_3] = (res_gem_1[cse_var_3] + (cast(int32, data_buf_1[(((ic.outer*256) + (ic.inner*16)) + ic_tns)])*cast(int32, weight_buf_1[((((((i1.outer*262144) + (co*16384)) + (ic.outer*4096)) + (ic.inner*256)) + (ci*16)) + ic_tns)])))\n", " }\n", " }\n", " }\n", " }\n", " }\n", " for (i1_2: int32, 0, 16) {\n", " for (i3_2: int32, 0, 16) {\n", " let cse_var_4: int32 = ((i1_2*16) + i3_2)\n", " res_gem_2: Buffer(res_gem, int32, [256], [])[cse_var_4] = @tir.shift_right(res_gem_1[cse_var_4], 8, dtype=int32)\n", " }\n", " }\n", " for (i1_3: int32, 0, 16) {\n", " 
for (i3_3: int32, 0, 16) {\n", " let cse_var_5: int32 = ((i1_3*16) + i3_3)\n", " res_gem_3: Buffer(res_gem, int32, [256], [])[cse_var_5] = max(res_gem_2[cse_var_5], 0)\n", " }\n", " }\n", " for (i1_4: int32, 0, 16) {\n", " for (i3_4: int32, 0, 16) {\n", " let cse_var_6: int32 = ((i1_4*16) + i3_4)\n", " res_gem_4: Buffer(res_gem, int32, [256], [])[cse_var_6] = min(res_gem_3[cse_var_6], 127)\n", " }\n", " }\n", " for (i1.inner: int32, 0, 16) {\n", " for (i3_5: int32, 0, 16) {\n", " let cse_var_7: int32 = (i1.inner*16)\n", " res[(((i1.outer*256) + cse_var_7) + i3_5)] = cast(int8, res_gem_4[(cse_var_7 + i3_5)])\n", " }\n", " }\n", " }\n", " }\n", "}\n", "\n", "\n" ] } ], "source": [ "# Let's define tiling sizes (expressed in multiples of VTA tensor shape size)\n", "b_block = 1 // env.BATCH\n", "i_block = 256 // env.BLOCK_IN\n", "o_block = 256 // env.BLOCK_OUT\n", "\n", "# Tile the output tensor along the batch and output channel dimensions\n", "# (since by default we are doing single batch inference, the split along\n", "# the batch dimension has no effect)\n", "b, oc, b_tns, oc_tns = s[res].op.axis\n", "b_out, b_inn = s[res].split(b, b_block)\n", "oc_out, oc_inn = s[res].split(oc, o_block)\n", "s[res].reorder(b_out, oc_out, b_inn, oc_inn)\n", "\n", "# Move intermediate computation into each output compute tile\n", "s[res_gemm].compute_at(s[res], oc_out)\n", "s[res_shr].compute_at(s[res], oc_out)\n", "s[res_max].compute_at(s[res], oc_out)\n", "s[res_min].compute_at(s[res], oc_out)\n", "\n", "# Apply additional loop split along reduction axis (input channel)\n", "b_inn, oc_inn, b_tns, oc_tns = s[res_gemm].op.axis\n", "ic_out, ic_inn = s[res_gemm].split(ic, i_block)\n", "\n", "# Reorder axes. 
We move the ic_out axis all the way out of the GEMM\n", "# loop to block along the reduction axis\n", "s[res_gemm].reorder(ic_out, b_inn, oc_inn, ic_inn, b_tns, oc_tns, ic_tns)\n", "\n", "# Let's look at the current TVM schedule after blocking\n", "print(tvm.lower(s, [data, weight, res], simple_mode=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Lowering Copies to DMA Transfers\n", "\n", "Next, we set the buffer scopes to the corresponding on-chip VTA SRAM buffers. We move the load loops into the matrix multiplication computation loop so that each loaded block fits in the on-chip SRAM buffers. Finally, we annotate the outer axes of the load/store loops with the DMA copy pragma so that VTA performs them as bulk memory transfers." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Set scope of SRAM buffers\n", "s[data_buf].set_scope(env.inp_scope)\n", "s[weight_buf].set_scope(env.wgt_scope)\n", "s[res_gemm].set_scope(env.acc_scope)\n", "s[res_shr].set_scope(env.acc_scope)\n", "s[res_min].set_scope(env.acc_scope)\n", "s[res_max].set_scope(env.acc_scope)\n", "\n", "# Block data and weight cache reads\n", "s[data_buf].compute_at(s[res_gemm], ic_out)\n", "s[weight_buf].compute_at(s[res_gemm], ic_out)\n", "\n", "# Use DMA copy pragma on DRAM->SRAM operations\n", "s[data_buf].pragma(s[data_buf].op.axis[0], env.dma_copy)\n", "s[weight_buf].pragma(s[weight_buf].op.axis[0], env.dma_copy)\n", "\n", "# Use DMA copy pragma on SRAM->DRAM operation\n", "# (this implies that these copies should be performed along b_inn,\n", "# or result axis 2)\n", "s[res].pragma(s[res].op.axis[2], env.dma_copy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Lowering Computation to VTA Compute Intrinsics\n", "\n", "The last phase lowers the computation loops down to VTA hardware intrinsics by mapping the matrix multiplication to tensor intrinsics, and the shift and clipping operations to the vector ALU." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "@main = primfn(data_1: handle, weight_1: handle, res_1: handle) -> ()\n", " attr = {\"from_legacy_te_schedule\": True, \"global_symbol\": \"main\", 
\"tir.noalias\": True}\n", " buffers = {data: Buffer(data_2: Pointer(int8), int8, [1024], []),\n", " weight: Buffer(weight_2: Pointer(int8), int8, [1048576], []),\n", " res: Buffer(res_2: Pointer(int8), int8, [1024], [])}\n", " buffer_map = {data_1: data, weight_1: weight, res_1: res}\n", " preflattened_buffer_map = {data_1: data_3: Buffer(data_2, int8, [1, 64, 1, 16], []), weight_1: weight_3: Buffer(weight_2, int8, [64, 64, 16, 16], []), res_1: res_3: Buffer(res_2, int8, [1, 64, 1, 16], [])} {\n", " @tir.vta.coproc_dep_push(3, 2, dtype=int32)\n", " for (i1.outer: int32, 0, 4) {\n", " attr [IterVar(vta: int32, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_scope\" = 2 {\n", " @tir.vta.coproc_dep_pop(3, 2, dtype=int32)\n", " attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_uop_scope\" = \"VTAPushGEMMOp\" {\n", " @tir.call_extern(\"VTAUopLoopBegin\", 16, 1, 0, 0, dtype=int32)\n", " @tir.vta.uop_push(0, 1, 0, 0, 0, 0, 0, 0, dtype=int32)\n", " @tir.call_extern(\"VTAUopLoopEnd\", dtype=int32)\n", " }\n", " @tir.vta.coproc_dep_push(2, 1, dtype=int32)\n", " }\n", " for (ic.outer: int32, 0, 4) {\n", " let cse_var_1: int32 = (ic.outer*16)\n", " {\n", " attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_scope\" = 1 {\n", " @tir.vta.coproc_dep_pop(2, 1, dtype=int32)\n", " @tir.call_extern(\"VTALoadBuffer2D\", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), data_2, cse_var_1, 16, 1, 16, 0, 0, 0, 0, 0, 2, dtype=int32)\n", " @tir.call_extern(\"VTALoadBuffer2D\", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), weight_2, ((i1.outer*1024) + cse_var_1), 16, 16, 64, 0, 0, 0, 0, 0, 1, dtype=int32)\n", " @tir.vta.coproc_dep_push(1, 2, dtype=int32)\n", " }\n", " attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_scope\" = 2 {\n", " @tir.vta.coproc_dep_pop(1, 2, dtype=int32)\n", " attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_uop_scope\" = \"VTAPushGEMMOp\" 
{\n", " @tir.call_extern(\"VTAUopLoopBegin\", 16, 1, 0, 16, dtype=int32)\n", " @tir.call_extern(\"VTAUopLoopBegin\", 16, 0, 1, 1, dtype=int32)\n", " @tir.vta.uop_push(0, 0, 0, 0, 0, 0, 0, 0, dtype=int32)\n", " @tir.call_extern(\"VTAUopLoopEnd\", dtype=int32)\n", " @tir.call_extern(\"VTAUopLoopEnd\", dtype=int32)\n", " }\n", " @tir.vta.coproc_dep_push(2, 1, dtype=int32)\n", " }\n", " }\n", " }\n", " @tir.vta.coproc_dep_pop(2, 1, dtype=int32)\n", " attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_scope\" = 2 {\n", " attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_uop_scope\" = \"VTAPushALUOp\" {\n", " @tir.call_extern(\"VTAUopLoopBegin\", 16, 1, 1, 0, dtype=int32)\n", " @tir.vta.uop_push(1, 0, 0, 0, 0, 3, 1, 8, dtype=int32)\n", " @tir.call_extern(\"VTAUopLoopEnd\", dtype=int32)\n", " }\n", " attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_uop_scope\" = \"VTAPushALUOp\" {\n", " @tir.call_extern(\"VTAUopLoopBegin\", 16, 1, 1, 0, dtype=int32)\n", " @tir.vta.uop_push(1, 0, 0, 0, 0, 1, 1, 0, dtype=int32)\n", " @tir.call_extern(\"VTAUopLoopEnd\", dtype=int32)\n", " }\n", " attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_uop_scope\" = \"VTAPushALUOp\" {\n", " @tir.call_extern(\"VTAUopLoopBegin\", 16, 1, 1, 0, dtype=int32)\n", " @tir.vta.uop_push(1, 0, 0, 0, 0, 0, 1, 127, dtype=int32)\n", " @tir.call_extern(\"VTAUopLoopEnd\", dtype=int32)\n", " }\n", " @tir.vta.coproc_dep_push(2, 3, dtype=int32)\n", " }\n", " attr [IterVar(vta, (nullptr), \"ThreadIndex\", \"vta\")] \"coproc_scope\" = 3 {\n", " @tir.vta.coproc_dep_pop(2, 3, dtype=int32)\n", " for (i1.inner: int32, 0, 16) {\n", " @tir.call_extern(\"VTAStoreBuffer2D\", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), i1.inner, 4, res_2, ((i1.outer*16) + i1.inner), 1, 1, 1, dtype=int32)\n", " }\n", " @tir.vta.coproc_dep_push(3, 2, dtype=int32)\n", " }\n", " }\n", " @tir.vta.coproc_sync(, dtype=int32)\n", " 
@tir.vta.coproc_dep_pop(3, 2, dtype=int32)\n", "}\n", "\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[21:37:07] /media/pc/data/4tb/lxw/books/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement required_alignment=256, provided_alignment=128\n" ] } ], "source": [ "# Apply tensorization over the batch tensor tile axis\n", "s[res_gemm].tensorize(b_tns, env.gemm)\n", "\n", "# Add an ALU pragma over the shift and clipping operations\n", "s[res_shr].pragma(s[res_shr].op.axis[0], env.alu)\n", "s[res_min].pragma(s[res_min].op.axis[0], env.alu)\n", "s[res_max].pragma(s[res_max].op.axis[0], env.alu)\n", "\n", "# Let's look at the final lowered TVM schedule after lowering memory\n", "# loads/stores down to DMA copy intrinsics, and the computation down to\n", "# VTA compute intrinsics.\n", "print(vta.lower(s, [data, weight, res], simple_mode=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TVM Compilation and Verification\n", "\n", "After specifying the schedule, we can compile it into a TVM function. We save the module so that it can be sent over RPC. We then run the function and verify its output against a numpy implementation to ensure correctness." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[21:37:07] /media/pc/data/4tb/lxw/books/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement required_alignment=256, provided_alignment=128\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Execution statistics:\n", "\tinp_load_nbytes : 4096\n", "\twgt_load_nbytes : 1048576\n", "\tacc_load_nbytes : 0\n", "\tuop_load_nbytes : 20\n", "\tout_store_nbytes: 1024\n", "\tgemm_counter : 4096\n", "\talu_counter : 192\n", "Successful blocked matrix multiply test!\n" ] } ], "source": [ "# Compile the TVM module\n", "my_gemm = vta.build(\n", "    s, [data, weight, res], tvm.target.Target(\"ext_dev\", host=env.target_host), name=\"my_gemm\"\n", ")\n", "temp = 
utils.tempdir()\n", "my_gemm.save(temp.relpath(\"gemm.o\"))\n", "remote.upload(temp.relpath(\"gemm.o\"))\n", "f = remote.load_module(\"gemm.o\")\n", "\n", "# Get the remote device context\n", "ctx = remote.ext_dev(0)\n", "\n", "# Initialize the data and weight arrays randomly in the int range of [-128, 128)\n", "data_np = np.random.randint(-128, 128, size=(batch_size, in_channels)).astype(data.dtype)\n", "weight_np = np.random.randint(-128, 128, size=(out_channels, in_channels)).astype(weight.dtype)\n", "\n", "# Apply packing to the data and weight arrays from a 2D to a 4D packed layout\n", "data_packed = data_np.reshape(\n", "    batch_size // env.BATCH, env.BATCH, in_channels // env.BLOCK_IN, env.BLOCK_IN\n", ").transpose((0, 2, 1, 3))\n", "weight_packed = weight_np.reshape(\n", "    out_channels // env.BLOCK_OUT, env.BLOCK_OUT, in_channels // env.BLOCK_IN, env.BLOCK_IN\n", ").transpose((0, 2, 1, 3))\n", "\n", "# Format the input/output arrays with tvm.nd.array to the DLPack standard\n", "data_nd = tvm.nd.array(data_packed, ctx)\n", "weight_nd = tvm.nd.array(weight_packed, ctx)\n", "res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx)\n", "\n", "# Clear stats\n", "if env.TARGET in [\"sim\", \"tsim\"]:\n", "    simulator.clear_stats()\n", "\n", "# Invoke the module to perform the computation\n", "f(data_nd, weight_nd, res_nd)\n", "\n", "# Verify against numpy implementation\n", "res_ref = np.dot(data_np.astype(env.acc_dtype), weight_np.T.astype(env.acc_dtype))\n", "res_ref = res_ref >> env.INP_WIDTH\n", "res_ref = np.clip(res_ref, 0, inp_max)\n", "res_ref = res_ref.astype(res.dtype)\n", "res_ref = res_ref.reshape(\n", "    batch_size // env.BATCH, env.BATCH, out_channels // env.BLOCK_OUT, env.BLOCK_OUT\n", ").transpose((0, 2, 1, 3))\n", "np.testing.assert_equal(res_ref, res_nd.numpy())\n", "\n", "# Print stats\n", "if env.TARGET in [\"sim\", \"tsim\"]:\n", "    sim_stats = simulator.stats()\n", "    print(\"Execution statistics:\")\n", "    for k, v in 
sim_stats.items():\n", "        print(\"\\t{:<16}: {:>16}\".format(k, v))\n", "\n", "print(\"Successful blocked matrix multiply test!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "This tutorial demonstrated how TVM scheduling primitives can achieve computation blocking for a matrix multiplication example. Blocking allows arbitrarily large computations to be mapped onto limited hardware accelerator resources." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 0 }