分块矩阵乘法#

原作者: Thierry Moreau

本教程概述了如何在 VTA 设计中使用 TVM 有效地映射矩阵乘法。建议先学习 简单的矩阵乘法 教程。

在本教程中,将演示 TVM 调度优化,将大型神经网络算子分解为较小的块,以在有限的硬件加速器资源内实现计算。

RPC 设置#

首先对 Pynq 的 FPGA 进行编程,并构建其 RPC 运行时。

import os
import tvm
from tvm import te
import vta
import numpy as np
from tvm import rpc
from tvm.contrib import utils
from vta.testing import simulator

# Load VTA parameters from the 3rdparty/vta-hw/config/vta_config.json file
env = vta.get_env()

# Read the Pynq RPC host IP address and port number from the OS environment,
# falling back to the defaults below when the variables are not set
host = os.environ.get("VTA_RPC_HOST", "192.168.2.99")
port = int(os.environ.get("VTA_RPC_PORT", "9091"))

# We configure both the bitstream and the runtime system on the Pynq
# to match the VTA configuration specified by the vta_config.json file.
if env.TARGET == "pynq":

    # Make sure that TVM was compiled with RPC=1
    assert tvm.runtime.enabled("rpc")
    remote = rpc.connect(host, port)

    # Reconfigure the JIT runtime
    vta.reconfig_runtime(remote)

    # Program the FPGA with a pre-compiled VTA bitstream.
    # You can program the FPGA with your own custom bitstream
    # by passing the path to the bitstream file instead of None.
    vta.program_fpga(remote, bitstream=None)

# In simulation mode, host the RPC server locally.
elif env.TARGET in ["sim", "tsim"]:
    remote = rpc.LocalSession()

# Any other target would leave `remote` unbound and only surface later as a
# NameError, so report the unsupported configuration right here instead.
else:
    raise RuntimeError(
        f"Unsupported VTA target {env.TARGET!r}: this tutorial handles 'pynq', 'sim' and 'tsim'"
    )

声明计算#

作为第一步,需要描述矩阵乘法的计算。将矩阵乘法定义为全连接层中的计算,由其 batch size、输入通道和输出通道定义。它们必须分别是 VTA 张量形状 BATCH、BLOCK_IN 和 BLOCK_OUT 的整数倍。

在矩阵乘法之后添加额外的算子,对输出进行移位(shifting)和裁剪(clipping),以模拟定点矩阵乘法,其后接修正线性激活(rectified linear activation)。全连接层的 TVM 数据流图如下:

../../../../../_images/fc_dataflow.png

此计算被故意设置得太大,以至于不能一次全部放入 VTA 的 on-chip buffer。因此,在调度阶段,将依靠计算分块(blocking)策略将计算分解为可管理的块。

# Problem size of the fully connected layer: one input row with 1024
# input channels, producing 1024 output channels
batch_size = 1
in_channels = 1024
out_channels = 1024

# Every dimension has to divide evenly into VTA tensor tiles
assert batch_size % env.BATCH == 0
assert in_channels % env.BLOCK_IN == 0
assert out_channels % env.BLOCK_OUT == 0

# Tiled 4-D shapes: (outer rows, outer columns, tile rows, tile columns)
data_shape = (
    batch_size // env.BATCH,
    in_channels // env.BLOCK_IN,
    env.BATCH,
    env.BLOCK_IN,
)
weight_shape = (
    out_channels // env.BLOCK_OUT,
    in_channels // env.BLOCK_IN,
    env.BLOCK_OUT,
    env.BLOCK_IN,
)
output_shape = (
    batch_size // env.BATCH,
    out_channels // env.BLOCK_OUT,
    env.BATCH,
    env.BLOCK_OUT,
)
# Operation count of the layer (presumably multiply and add counted
# separately, hence the factor of 2 -- confirm)
num_ops = in_channels * out_channels * batch_size * 2

# Placeholders for the two inputs
data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
weight = te.placeholder(weight_shape, name="weight", dtype=env.wgt_dtype)

# Reduction runs over the input channels: `ic` walks the tiles,
# `ic_tns` walks the elements inside one tile
ic = te.reduce_axis((0, in_channels // env.BLOCK_IN), name="ic")
ic_tns = te.reduce_axis((0, env.BLOCK_IN), name="ic_tns")

# Element-wise copies of both inputs; these stages are what the schedule
# later places in on-chip buffers
data_buf = te.compute(data_shape, lambda *i: data(*i), name="data_buf")
weight_buf = te.compute(weight_shape, lambda *i: weight(*i), name="weight_buf")

# Matrix multiply: operands are widened to the accumulator type before
# being multiplied and summed over both reduction axes
res_gemm = te.compute(
    output_shape,
    lambda bo, co, bi, ci: te.sum(
        data_buf[bo, ic, bi, ic_tns].astype(env.acc_dtype)
        * weight_buf[co, ic, ci, ic_tns].astype(env.acc_dtype),
        axis=[ic, ic_tns],
    ),
    name="res_gem",
)

# Fixed-point normalization: right-shift the accumulator by the input bit width
res_shr = te.compute(output_shape, lambda *i: res_gemm(*i) >> env.INP_WIDTH, name="res_shr")

# Clip to [0, inp_max], where inp_max is the largest signed value that
# fits in INP_WIDTH bits
inp_max = (1 << (env.INP_WIDTH - 1)) - 1
res_max = te.compute(output_shape, lambda *i: te.max(res_shr(*i), 0), name="res_max")
res_min = te.compute(output_shape, lambda *i: te.min(res_max(*i), inp_max), name="res_min")

# Narrow back to the input data type before the result is written out
res = te.compute(output_shape, lambda *i: res_min(*i).astype(env.inp_dtype), name="res")

调度计算#

下面介绍将矩阵乘法高效映射到 VTA 所需的一组调度变换,包括:

  • 分块计算(Computation blocking)

  • Lowering 到 VTA 硬件 intrinsics

# Start from the default schedule TVM creates for the result operation
s = te.create_schedule([res.op])
# Print the lowered form of this untransformed schedule for reference
print(tvm.lower(s, [data, weight, res], simple_mode=True))
@main = primfn(data_1: handle, weight_1: handle, res_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {data: Buffer(data_2: Pointer(int8), int8, [1024], []),
             weight: Buffer(weight_2: Pointer(int8), int8, [1048576], []),
             res: Buffer(res_2: Pointer(int8), int8, [1024], [])}
  buffer_map = {data_1: data, weight_1: weight, res_1: res}
  preflattened_buffer_map = {data_1: data_3: Buffer(data_2, int8, [1, 64, 1, 16], []), weight_1: weight_3: Buffer(weight_2, int8, [64, 64, 16, 16], []), res_1: res_3: Buffer(res_2, int8, [1, 64, 1, 16], [])} {
  allocate(data_buf: Pointer(global int8), int8, [1024]), storage_scope = global;
  allocate(weight_buf: Pointer(global int8), int8, [1048576]), storage_scope = global;
  allocate(res_gem: Pointer(global int32), int32, [1024]), storage_scope = global {
    for (i1: int32, 0, 64) {
      for (i3: int32, 0, 16) {
        let cse_var_1: int32 = ((i1*16) + i3)
        data_buf_1: Buffer(data_buf, int8, [1024], [])[cse_var_1] = data[cse_var_1]
      }
    }
    for (i0: int32, 0, 64) {
      for (i1_1: int32, 0, 64) {
        for (i2: int32, 0, 16) {
          for (i3_1: int32, 0, 16) {
            let cse_var_2: int32 = ((((i0*16384) + (i1_1*256)) + (i2*16)) + i3_1)
            weight_buf_1: Buffer(weight_buf, int8, [1048576], [])[cse_var_2] = weight[cse_var_2]
          }
        }
      }
    }
    for (co: int32, 0, 64) {
      for (ci: int32, 0, 16) {
        res_gem_1: Buffer(res_gem, int32, [1024], [])[((co*16) + ci)] = 0
        for (ic: int32, 0, 64) {
          for (ic_tns: int32, 0, 16) {
            let cse_var_3: int32 = ((co*16) + ci)
            res_gem_1[cse_var_3] = (res_gem_1[cse_var_3] + (cast(int32, data_buf_1[((ic*16) + ic_tns)])*cast(int32, weight_buf_1[((((co*16384) + (ic*256)) + (ci*16)) + ic_tns)])))
          }
        }
      }
    }
    for (i1_2: int32, 0, 64) {
      for (i3_2: int32, 0, 16) {
        let cse_var_4: int32 = ((i1_2*16) + i3_2)
        res_gem_2: Buffer(res_gem, int32, [1024], [])[cse_var_4] = @tir.shift_right(res_gem_1[cse_var_4], 8, dtype=int32)
      }
    }
    for (i1_3: int32, 0, 64) {
      for (i3_3: int32, 0, 16) {
        let cse_var_5: int32 = ((i1_3*16) + i3_3)
        res_gem_3: Buffer(res_gem, int32, [1024], [])[cse_var_5] = max(res_gem_2[cse_var_5], 0)
      }
    }
    for (i1_4: int32, 0, 64) {
      for (i3_4: int32, 0, 16) {
        let cse_var_6: int32 = ((i1_4*16) + i3_4)
        res_gem_4: Buffer(res_gem, int32, [1024], [])[cse_var_6] = min(res_gem_3[cse_var_6], 127)
      }
    }
    for (i1_5: int32, 0, 64) {
      for (i3_5: int32, 0, 16) {
        let cse_var_7: int32 = ((i1_5*16) + i3_5)
        res[cse_var_7] = cast(int8, res_gem_4[cse_var_7])
      }
    }
  }
}

分块计算#

默认情况下,该矩阵乘法的激活和权重都太大,无法一次性放入 VTA 的 on-chip buffer。将 (1, 1024) × (1024, 1024) 矩阵乘法分成更小的 (1, 256) × (256, 256) 矩阵乘法,这样中间张量就可以装进加速器的 on-chip SRAM 中。这种方法类似于在 CPU 和 GPU 上为提高缓存命中率(cache hit rate)而采用的分块技术。

沿着每个轴执行分块(batch 轴不受影响,因为正在执行单 batch 推理)。也保持最内侧的 tensorization 轴不变,以便 TVM 能够进行模式匹配的 tensorization。在下面的图表中展示了分块在计算调度上的结果:

../../../../../_images/blocking.png

循环分割(splitting)和重新排序(reordering)后的代码等价于下面的伪代码。忽略 batch 轴,因为在这个例子中只执行单 batch 推理:

for (int oc_out = 0; oc_out < 4; ++oc_out) {
  // Initialization loop: clear the accumulators of the current output block
  for (int oc_inn = 0; oc_inn < 16; ++oc_inn) {
   for (int oc_tns = 0; oc_tns < 16; ++oc_tns) {
    int j = (oc_out * 16 + oc_inn) * 16 + oc_tns;
    C[0][j] = 0;
   }
  }
  for (int ic_out = 0; ic_out < 4; ++ic_out) {
   // Block loop
   for (int oc_inn = 0; oc_inn < 16; ++oc_inn) {
    for (int ic_inn = 0; ic_inn < 16; ++ic_inn) {
     // Tensorization loop
     for (int oc_tns = 0; oc_tns < 16; ++oc_tns) {
      for (int ic_tns = 0; ic_tns < 16; ++ic_tns) {
       // i is the input-channel (reduction) index, j the output-channel index
       int i = (ic_out * 16 + ic_inn) * 16 + ic_tns;
       int j = (oc_out * 16 + oc_inn) * 16 + oc_tns;
       // Accumulate into the output element j that the init loop cleared
       C[0][j] = C[0][j] + A[0][i] * B[j][i];
      }
     }
    }
   }
  }
}
# Let's define tiling sizes (expressed in multiples of VTA tensor shape size).
# The input and output channel blocks each cover 256 channels.
# NOTE(review): `1 // env.BATCH` evaluates to 0 whenever env.BATCH > 1;
# this looks like it assumes env.BATCH == 1 -- confirm before changing BATCH.
b_block = 1 // env.BATCH
i_block = 256 // env.BLOCK_IN
o_block = 256 // env.BLOCK_OUT

# Tile the output tensor along the batch and output channel dimensions
# (since by default we are doing single batch inference, the split along
#  the batch dimension has no effect)
b, oc, b_tns, oc_tns = s[res].op.axis
b_out, b_inn = s[res].split(b, b_block)
oc_out, oc_inn = s[res].split(oc, o_block)
# Put both outer (block) loops outside both inner loops
s[res].reorder(b_out, oc_out, b_inn, oc_inn)

# Move intermediate computation into each output compute tile:
# all four stages are attached at the `oc_out` loop of the result stage
s[res_gemm].compute_at(s[res], oc_out)
s[res_shr].compute_at(s[res], oc_out)
s[res_max].compute_at(s[res], oc_out)
s[res_min].compute_at(s[res], oc_out)

# Apply additional loop split along reduction axis (input channel).
# Note that b_inn, oc_inn, b_tns and oc_tns are rebound here to the axes of
# the res_gemm stage, replacing the names obtained from s[res] above.
b_inn, oc_inn, b_tns, oc_tns = s[res_gemm].op.axis
ic_out, ic_inn = s[res_gemm].split(ic, i_block)

# Reorder axes. We move the ic_out axis all the way out of the GEMM
# loop to block along the reduction axis
s[res_gemm].reorder(ic_out, b_inn, oc_inn, ic_inn, b_tns, oc_tns, ic_tns)

# Let's look at the current TVM schedule after blocking
print(tvm.lower(s, [data, weight, res], simple_mode=True))
@main = primfn(data_1: handle, weight_1: handle, res_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {data: Buffer(data_2: Pointer(int8), int8, [1024], []),
             weight: Buffer(weight_2: Pointer(int8), int8, [1048576], []),
             res: Buffer(res_2: Pointer(int8), int8, [1024], [])}
  buffer_map = {data_1: data, weight_1: weight, res_1: res}
  preflattened_buffer_map = {data_1: data_3: Buffer(data_2, int8, [1, 64, 1, 16], []), weight_1: weight_3: Buffer(weight_2, int8, [64, 64, 16, 16], []), res_1: res_3: Buffer(res_2, int8, [1, 64, 1, 16], [])} {
  allocate(data_buf: Pointer(global int8), int8, [1024]), storage_scope = global;
  allocate(weight_buf: Pointer(global int8), int8, [1048576]), storage_scope = global;
  allocate(res_gem: Pointer(global int32), int32, [256]), storage_scope = global {
    for (i1: int32, 0, 64) {
      for (i3: int32, 0, 16) {
        let cse_var_1: int32 = ((i1*16) + i3)
        data_buf_1: Buffer(data_buf, int8, [1024], [])[cse_var_1] = data[cse_var_1]
      }
    }
    for (i0: int32, 0, 64) {
      for (i1_1: int32, 0, 64) {
        for (i2: int32, 0, 16) {
          for (i3_1: int32, 0, 16) {
            let cse_var_2: int32 = ((((i0*16384) + (i1_1*256)) + (i2*16)) + i3_1)
            weight_buf_1: Buffer(weight_buf, int8, [1048576], [])[cse_var_2] = weight[cse_var_2]
          }
        }
      }
    }
    for (i1.outer: int32, 0, 4) {
      for (co.init: int32, 0, 16) {
        for (ci.init: int32, 0, 16) {
          res_gem_1: Buffer(res_gem, int32, [256], [])[((co.init*16) + ci.init)] = 0
        }
      }
      for (ic.outer: int32, 0, 4) {
        for (co: int32, 0, 16) {
          for (ic.inner: int32, 0, 16) {
            for (ci: int32, 0, 16) {
              for (ic_tns: int32, 0, 16) {
                let cse_var_3: int32 = ((co*16) + ci)
                res_gem_1[cse_var_3] = (res_gem_1[cse_var_3] + (cast(int32, data_buf_1[(((ic.outer*256) + (ic.inner*16)) + ic_tns)])*cast(int32, weight_buf_1[((((((i1.outer*262144) + (co*16384)) + (ic.outer*4096)) + (ic.inner*256)) + (ci*16)) + ic_tns)])))
              }
            }
          }
        }
      }
      for (i1_2: int32, 0, 16) {
        for (i3_2: int32, 0, 16) {
          let cse_var_4: int32 = ((i1_2*16) + i3_2)
          res_gem_2: Buffer(res_gem, int32, [256], [])[cse_var_4] = @tir.shift_right(res_gem_1[cse_var_4], 8, dtype=int32)
        }
      }
      for (i1_3: int32, 0, 16) {
        for (i3_3: int32, 0, 16) {
          let cse_var_5: int32 = ((i1_3*16) + i3_3)
          res_gem_3: Buffer(res_gem, int32, [256], [])[cse_var_5] = max(res_gem_2[cse_var_5], 0)
        }
      }
      for (i1_4: int32, 0, 16) {
        for (i3_4: int32, 0, 16) {
          let cse_var_6: int32 = ((i1_4*16) + i3_4)
          res_gem_4: Buffer(res_gem, int32, [256], [])[cse_var_6] = min(res_gem_3[cse_var_6], 127)
        }
      }
      for (i1.inner: int32, 0, 16) {
        for (i3_5: int32, 0, 16) {
          let cse_var_7: int32 = (i1.inner*16)
          res[(((i1.outer*256) + cse_var_7) + i3_5)] = cast(int8, res_gem_4[(cse_var_7 + i3_5)])
        }
      }
    }
  }
}

将复制 lowering 为 DMA 传输#

接下来,将 buffer 作用域设置为相应的 on-chip VTA SRAM buffer。将 load 循环移动到矩阵乘法计算循环中,以使它们适合于 on-chip SRAM buffer。最后,用 DMA 复制实用程序对 load/store 循环外轴进行注解,以在 VTA 上执行批量内存传输。

# Assign buffer scopes: the input copy and the weight copy get their own
# scopes, every intermediate result stage shares env.acc_scope
s[data_buf].set_scope(env.inp_scope)
s[weight_buf].set_scope(env.wgt_scope)
for acc_tensor in (res_gemm, res_shr, res_min, res_max):
    s[acc_tensor].set_scope(env.acc_scope)

# Nest each cache read under the outer reduction loop of the GEMM stage and
# mark its outermost axis with the DMA copy pragma (DRAM->SRAM)
for cache_read in (data_buf, weight_buf):
    s[cache_read].compute_at(s[res_gemm], ic_out)
    s[cache_read].pragma(s[cache_read].op.axis[0], env.dma_copy)

# Mark the SRAM->DRAM write-back with the DMA copy pragma as well.
# s[res].op.axis[2] is the axis that was unpacked as `b_tns` when the
# result stage was tiled, so the copy covers one result tensor tile.
s[res].pragma(s[res].op.axis[2], env.dma_copy)

Lowering 计算到 VTA Compute Intrinsics#

最后阶段是通过将矩阵乘法映射到张量 intrinsics,将 shift 映射到矢量 ALU,从而将计算循环 lowering 到 VTA 硬件 intrinsics。

# Tensorize the GEMM stage at its batch tensor tile axis using env.gemm
s[res_gemm].tensorize(b_tns, env.gemm)

# Tag the outermost axis of the shift stage and of both clipping stages
# with the ALU pragma
for alu_tensor in (res_shr, res_min, res_max):
    s[alu_tensor].pragma(s[alu_tensor].op.axis[0], env.alu)

# Print the final schedule, now lowered through vta.lower: memory
# loads/stores become DMA copy intrinsics and the computation becomes
# VTA compute intrinsics.
print(vta.lower(s, [data, weight, res], simple_mode=True))
@main = primfn(data_1: handle, weight_1: handle, res_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {data: Buffer(data_2: Pointer(int8), int8, [1024], []),
             weight: Buffer(weight_2: Pointer(int8), int8, [1048576], []),
             res: Buffer(res_2: Pointer(int8), int8, [1024], [])}
  buffer_map = {data_1: data, weight_1: weight, res_1: res}
  preflattened_buffer_map = {data_1: data_3: Buffer(data_2, int8, [1, 64, 1, 16], []), weight_1: weight_3: Buffer(weight_2, int8, [64, 64, 16, 16], []), res_1: res_3: Buffer(res_2, int8, [1, 64, 1, 16], [])} {
  @tir.vta.coproc_dep_push(3, 2, dtype=int32)
  for (i1.outer: int32, 0, 4) {
    attr [IterVar(vta: int32, (nullptr), "ThreadIndex", "vta")] "coproc_scope" = 2 {
      @tir.vta.coproc_dep_pop(3, 2, dtype=int32)
      attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_uop_scope" = "VTAPushGEMMOp" {
        @tir.call_extern("VTAUopLoopBegin", 16, 1, 0, 0, dtype=int32)
        @tir.vta.uop_push(0, 1, 0, 0, 0, 0, 0, 0, dtype=int32)
        @tir.call_extern("VTAUopLoopEnd", dtype=int32)
      }
      @tir.vta.coproc_dep_push(2, 1, dtype=int32)
    }
    for (ic.outer: int32, 0, 4) {
      let cse_var_1: int32 = (ic.outer*16)
       {
        attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_scope" = 1 {
          @tir.vta.coproc_dep_pop(2, 1, dtype=int32)
          @tir.call_extern("VTALoadBuffer2D", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), data_2, cse_var_1, 16, 1, 16, 0, 0, 0, 0, 0, 2, dtype=int32)
          @tir.call_extern("VTALoadBuffer2D", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), weight_2, ((i1.outer*1024) + cse_var_1), 16, 16, 64, 0, 0, 0, 0, 0, 1, dtype=int32)
          @tir.vta.coproc_dep_push(1, 2, dtype=int32)
        }
        attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_scope" = 2 {
          @tir.vta.coproc_dep_pop(1, 2, dtype=int32)
          attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_uop_scope" = "VTAPushGEMMOp" {
            @tir.call_extern("VTAUopLoopBegin", 16, 1, 0, 16, dtype=int32)
            @tir.call_extern("VTAUopLoopBegin", 16, 0, 1, 1, dtype=int32)
            @tir.vta.uop_push(0, 0, 0, 0, 0, 0, 0, 0, dtype=int32)
            @tir.call_extern("VTAUopLoopEnd", dtype=int32)
            @tir.call_extern("VTAUopLoopEnd", dtype=int32)
          }
          @tir.vta.coproc_dep_push(2, 1, dtype=int32)
        }
      }
    }
    @tir.vta.coproc_dep_pop(2, 1, dtype=int32)
    attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_scope" = 2 {
      attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_uop_scope" = "VTAPushALUOp" {
        @tir.call_extern("VTAUopLoopBegin", 16, 1, 1, 0, dtype=int32)
        @tir.vta.uop_push(1, 0, 0, 0, 0, 3, 1, 8, dtype=int32)
        @tir.call_extern("VTAUopLoopEnd", dtype=int32)
      }
      attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_uop_scope" = "VTAPushALUOp" {
        @tir.call_extern("VTAUopLoopBegin", 16, 1, 1, 0, dtype=int32)
        @tir.vta.uop_push(1, 0, 0, 0, 0, 1, 1, 0, dtype=int32)
        @tir.call_extern("VTAUopLoopEnd", dtype=int32)
      }
      attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_uop_scope" = "VTAPushALUOp" {
        @tir.call_extern("VTAUopLoopBegin", 16, 1, 1, 0, dtype=int32)
        @tir.vta.uop_push(1, 0, 0, 0, 0, 0, 1, 127, dtype=int32)
        @tir.call_extern("VTAUopLoopEnd", dtype=int32)
      }
      @tir.vta.coproc_dep_push(2, 3, dtype=int32)
    }
    attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_scope" = 3 {
      @tir.vta.coproc_dep_pop(2, 3, dtype=int32)
      for (i1.inner: int32, 0, 16) {
        @tir.call_extern("VTAStoreBuffer2D", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), i1.inner, 4, res_2, ((i1.outer*16) + i1.inner), 1, 1, 1, dtype=int32)
      }
      @tir.vta.coproc_dep_push(3, 2, dtype=int32)
    }
  }
  @tir.vta.coproc_sync(, dtype=int32)
  @tir.vta.coproc_dep_pop(3, 2, dtype=int32)
}
[21:37:07] /media/pc/data/4tb/lxw/books/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=128

TVM 计算和验证#

在指定调度之后,可以将其编译为 TVM 函数。保存模块,这样就可以通过 RPC 发送它。运行该函数并对 numpy 实现进行验证,以确保其正确性。

# Build the scheduled computation into a module named "my_gemm"
my_gemm = vta.build(
    s, [data, weight, res], tvm.target.Target("ext_dev", host=env.target_host), name="my_gemm"
)

# Save the module into a temporary directory, upload it through the RPC
# session and load it on the remote side
temp = utils.tempdir()
obj_path = temp.relpath("gemm.o")
my_gemm.save(obj_path)
remote.upload(obj_path)
f = remote.load_module("gemm.o")

# Remote device context
ctx = remote.ext_dev(0)

# Random inputs: integers drawn from [-128, 128), i.e. -128 through 127
data_np = np.random.randint(-128, 128, size=(batch_size, in_channels)).astype(data.dtype)
weight_np = np.random.randint(-128, 128, size=(out_channels, in_channels)).astype(weight.dtype)

# Pack the 2D arrays into the 4D tiled layout: split each dimension into
# (tile index, index inside the tile), then bring the two tile indices to the front
data_packed = data_np.reshape(
    batch_size // env.BATCH, env.BATCH, in_channels // env.BLOCK_IN, env.BLOCK_IN
).transpose((0, 2, 1, 3))
weight_packed = weight_np.reshape(
    out_channels // env.BLOCK_OUT, env.BLOCK_OUT, in_channels // env.BLOCK_IN, env.BLOCK_IN
).transpose((0, 2, 1, 3))

# Wrap inputs and a zero-filled output with tvm.nd.array on the remote context
data_nd = tvm.nd.array(data_packed, ctx)
weight_nd = tvm.nd.array(weight_packed, ctx)
res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx)

# Simulator statistics are only available on the simulation targets
running_in_sim = env.TARGET in ("sim", "tsim")

# Reset the counters so they only reflect the run below
if running_in_sim:
    simulator.clear_stats()

# Run the computation
f(data_nd, weight_nd, res_nd)

# Reference result in numpy: matrix product in the accumulator type, the
# same shift and clipping as the TVM stages, then the same 4D packing
acc_ref = np.dot(data_np.astype(env.acc_dtype), weight_np.T.astype(env.acc_dtype))
res_ref = np.clip(acc_ref >> env.INP_WIDTH, 0, inp_max).astype(res.dtype)
res_ref = res_ref.reshape(
    batch_size // env.BATCH, env.BATCH, out_channels // env.BLOCK_OUT, env.BLOCK_OUT
).transpose((0, 2, 1, 3))
np.testing.assert_equal(res_ref, res_nd.numpy())

# Report the simulator counters
if running_in_sim:
    sim_stats = simulator.stats()
    print("Execution statistics:")
    for stat_name, stat_value in sim_stats.items():
        print("\t{:<16}: {:>16}".format(stat_name, stat_value))

print("Successful blocked matrix multiply test!")
Execution statistics:
	inp_load_nbytes :             4096
	wgt_load_nbytes :          1048576
	acc_load_nbytes :                0
	uop_load_nbytes :               20
	out_store_nbytes:             1024
	gemm_counter    :             4096
	alu_counter     :              192
Successful blocked matrix multiply test!
[21:37:07] /media/pc/data/4tb/lxw/books/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement  required_alignment=256, provided_alignment=128

小结#

本教程演示了 TVM 调度原语如何为矩阵乘法示例实现分块计算。这允许将任意大的计算映射到有限的硬件加速器资源上。