分块矩阵乘法
导航
分块矩阵乘法#
原作者: Thierry Moreau
本教程概述了如何在 VTA 设计中使用 TVM 有效地映射矩阵乘法。建议先学习 简单的矩阵乘法 教程。
在本教程中,将演示 TVM 调度优化,将大型神经网络算子分解为较小的块,以在有限的硬件加速器资源内实现计算。
RPC 设置#
首先编程 Pynq 的 FPGA 并构建它的 RPC 运行时。
import os
import tvm
from tvm import te
import vta
import numpy as np
from tvm import rpc
from tvm.contrib import utils
from vta.testing import simulator
# Load VTA parameters from the 3rdparty/vta-hw/config/vta_config.json file
env = vta.get_env()

# Read the Pynq RPC host IP address and port number from the OS environment,
# falling back to the defaults below when the variables are not set.
host = os.environ.get("VTA_RPC_HOST", "192.168.2.99")
port = int(os.environ.get("VTA_RPC_PORT", "9091"))

# Configure both the bitstream and the runtime system on the Pynq
# to match the VTA configuration specified by the vta_config.json file.
# NOTE: only the "pynq", "sim" and "tsim" targets are handled here; for any
# other env.TARGET the name `remote` is never bound.
if env.TARGET == "pynq":
    # Make sure that TVM was compiled with RPC=1
    assert tvm.runtime.enabled("rpc")
    remote = rpc.connect(host, port)

    # Reconfigure the JIT runtime
    vta.reconfig_runtime(remote)

    # Program the FPGA with a pre-compiled VTA bitstream.
    # You can program the FPGA with your own custom bitstream
    # by passing the path to the bitstream file instead of None.
    vta.program_fpga(remote, bitstream=None)

# In simulation mode, host the RPC server locally.
elif env.TARGET in ["sim", "tsim"]:
    remote = rpc.LocalSession()
声明计算#
作为第一步,需要描述矩阵乘法的计算。将矩阵乘法定义为全连接层中的计算,由其 batch size、输入通道和输出通道定义。它们必须是 VTA 张量形状的整数倍:BATCH、BLOCK_IN 和 BLOCK_OUT。
在矩阵乘法之后添加额外的算子,对输出进行移位(shifting)和裁剪(clipping),以模拟定点矩阵乘法及其后的修正线性激活(ReLU)。全连接层的 TVM 数据流图描述如下:
此计算被故意设置得太大,以至于不能一次全部放入 VTA 的 on-chip buffer。因此,在调度阶段,将依靠计算分块(computation blocking)策略将计算分解为可管理的块。
# Fully connected layer dimensions: 1024 x 1024
batch_size = 1
in_channels = 1024
out_channels = 1024

# Each dimension must be a whole number of VTA tensor tiles.
assert batch_size % env.BATCH == 0
assert in_channels % env.BLOCK_IN == 0
assert out_channels % env.BLOCK_OUT == 0

# Tiled 4-D shapes: the two leading entries count tiles, the two trailing
# entries are the extents of a single VTA tensor tile.
data_shape = (
    batch_size // env.BATCH,
    in_channels // env.BLOCK_IN,
    env.BATCH,
    env.BLOCK_IN,
)
weight_shape = (
    out_channels // env.BLOCK_OUT,
    in_channels // env.BLOCK_IN,
    env.BLOCK_OUT,
    env.BLOCK_IN,
)
output_shape = (
    batch_size // env.BATCH,
    out_channels // env.BLOCK_OUT,
    env.BATCH,
    env.BLOCK_OUT,
)

# Operation count for the layer (presumably one multiply plus one add per
# input/output channel pair and batch element).
num_ops = 2 * batch_size * in_channels * out_channels

# Reduction axes: `ic` walks the input-channel tiles, `ic_tns` walks the
# elements inside one input-channel tile.
ic = te.reduce_axis((0, in_channels // env.BLOCK_IN), name="ic")
ic_tns = te.reduce_axis((0, env.BLOCK_IN), name="ic_tns")

# Input placeholder tensors
data = te.placeholder(data_shape, dtype=env.inp_dtype, name="data")
weight = te.placeholder(weight_shape, dtype=env.wgt_dtype, name="weight")

# Identity copy stages for the two inputs
data_buf = te.compute(data_shape, lambda *i: data(*i), name="data_buf")
weight_buf = te.compute(weight_shape, lambda *i: weight(*i), name="weight_buf")

# Matrix multiply: operands are widened to the accumulator dtype and summed
# over both reduction axes.
res_gemm = te.compute(
    output_shape,
    lambda bo, co, bi, ci: te.sum(
        data_buf[bo, ic, bi, ic_tns].astype(env.acc_dtype)
        * weight_buf[co, ic, ci, ic_tns].astype(env.acc_dtype),
        axis=[ic, ic_tns],
    ),
    name="res_gem",
)

# Right shift by INP_WIDTH bits for fixed-point normalization
res_shr = te.compute(
    output_shape,
    lambda *i: res_gemm(*i) >> env.INP_WIDTH,
    name="res_shr",
)

# Clip to [0, inp_max]; inp_max is (1 << (INP_WIDTH - 1)) - 1,
# e.g. 127 when INP_WIDTH is 8.
inp_max = (1 << (env.INP_WIDTH - 1)) - 1
res_max = te.compute(output_shape, lambda *i: te.max(res_shr(*i), 0), name="res_max")
res_min = te.compute(output_shape, lambda *i: te.min(res_max(*i), inp_max), name="res_min")

# Cast back to the input data type before the result is sent back
res = te.compute(
    output_shape,
    lambda *i: res_min(*i).astype(env.inp_dtype),
    name="res",
)
调度计算#
查看一组必要的调度变换,以有效的方式将矩阵乘法映射到 VTA。这些包括:
分块计算(Computation blocking)
Lowering 到 VTA 硬件 intrinsics
# Create a TVM schedule rooted at the final output stage `res`
s = te.create_schedule(res.op)
# Print the lowered IR of the default (untransformed) schedule, with
# data, weight and res as the function arguments
print(tvm.lower(s, [data, weight, res], simple_mode=True))
@main = primfn(data_1: handle, weight_1: handle, res_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {data: Buffer(data_2: Pointer(int8), int8, [1024], []),
weight: Buffer(weight_2: Pointer(int8), int8, [1048576], []),
res: Buffer(res_2: Pointer(int8), int8, [1024], [])}
buffer_map = {data_1: data, weight_1: weight, res_1: res}
preflattened_buffer_map = {data_1: data_3: Buffer(data_2, int8, [1, 64, 1, 16], []), weight_1: weight_3: Buffer(weight_2, int8, [64, 64, 16, 16], []), res_1: res_3: Buffer(res_2, int8, [1, 64, 1, 16], [])} {
allocate(data_buf: Pointer(global int8), int8, [1024]), storage_scope = global;
allocate(weight_buf: Pointer(global int8), int8, [1048576]), storage_scope = global;
allocate(res_gem: Pointer(global int32), int32, [1024]), storage_scope = global {
for (i1: int32, 0, 64) {
for (i3: int32, 0, 16) {
let cse_var_1: int32 = ((i1*16) + i3)
data_buf_1: Buffer(data_buf, int8, [1024], [])[cse_var_1] = data[cse_var_1]
}
}
for (i0: int32, 0, 64) {
for (i1_1: int32, 0, 64) {
for (i2: int32, 0, 16) {
for (i3_1: int32, 0, 16) {
let cse_var_2: int32 = ((((i0*16384) + (i1_1*256)) + (i2*16)) + i3_1)
weight_buf_1: Buffer(weight_buf, int8, [1048576], [])[cse_var_2] = weight[cse_var_2]
}
}
}
}
for (co: int32, 0, 64) {
for (ci: int32, 0, 16) {
res_gem_1: Buffer(res_gem, int32, [1024], [])[((co*16) + ci)] = 0
for (ic: int32, 0, 64) {
for (ic_tns: int32, 0, 16) {
let cse_var_3: int32 = ((co*16) + ci)
res_gem_1[cse_var_3] = (res_gem_1[cse_var_3] + (cast(int32, data_buf_1[((ic*16) + ic_tns)])*cast(int32, weight_buf_1[((((co*16384) + (ic*256)) + (ci*16)) + ic_tns)])))
}
}
}
}
for (i1_2: int32, 0, 64) {
for (i3_2: int32, 0, 16) {
let cse_var_4: int32 = ((i1_2*16) + i3_2)
res_gem_2: Buffer(res_gem, int32, [1024], [])[cse_var_4] = @tir.shift_right(res_gem_1[cse_var_4], 8, dtype=int32)
}
}
for (i1_3: int32, 0, 64) {
for (i3_3: int32, 0, 16) {
let cse_var_5: int32 = ((i1_3*16) + i3_3)
res_gem_3: Buffer(res_gem, int32, [1024], [])[cse_var_5] = max(res_gem_2[cse_var_5], 0)
}
}
for (i1_4: int32, 0, 64) {
for (i3_4: int32, 0, 16) {
let cse_var_6: int32 = ((i1_4*16) + i3_4)
res_gem_4: Buffer(res_gem, int32, [1024], [])[cse_var_6] = min(res_gem_3[cse_var_6], 127)
}
}
for (i1_5: int32, 0, 64) {
for (i3_5: int32, 0, 16) {
let cse_var_7: int32 = ((i1_5*16) + i3_5)
res[cse_var_7] = cast(int8, res_gem_4[cse_var_7])
}
}
}
}
分块计算#
在默认情况下,矩阵乘法对于激活或权重来说太大了,无法一次性放入 VTA 的 on-chip buffer。将 (1, 1024)×(1024, 1024) 矩阵乘法分成更小的 (1, 256) × (256, 256) 矩阵乘法,这样中间张量就可以装进加速器的 on-chip SRAM 中。这种方法类似于在 CPU 和 GPU 上为提高缓存命中率(cache hit rate)而采用的分块技术。
沿着每个轴执行分块(batch 轴不受影响,因为正在执行单 batch 推理)。也保持最内侧的 tensorization 轴不变,以便 TVM 能够进行模式匹配的 tensorization。在下面的图表中展示了分块在计算调度上的结果:
循环分割(splitting)和重新排序(reordering)后的代码等价于下面的伪代码。忽略 batch 轴,因为在这个例子中只执行单 batch 推断:
// Pseudo-code for the blocked schedule. The batch axis is omitted because
// only single-batch inference is performed.
// A: (1, 1024) input, B: (1024, 1024) weight, C: (1, 1024) accumulator.
for (int oc_out = 0; oc_out < 4; ++oc_out) {
  // Initialization loop: zero the 16 x 16 outputs of this output block
  for (int oc_inn = 0; oc_inn < 16; ++oc_inn) {
    for (int oc_tns = 0; oc_tns < 16; ++oc_tns) {
      int j = (oc_out * 16 + oc_inn) * 16 + oc_tns;
      C[0][j] = 0;
    }
  }
  for (int ic_out = 0; ic_out < 4; ++ic_out) {
    // Block loop
    for (int oc_inn = 0; oc_inn < 16; ++oc_inn) {
      for (int ic_inn = 0; ic_inn < 16; ++ic_inn) {
        // Tensorization loop
        for (int oc_tns = 0; oc_tns < 16; ++oc_tns) {
          for (int ic_tns = 0; ic_tns < 16; ++ic_tns) {
            int i = (ic_out * 16 + ic_inn) * 16 + ic_tns;  // input channel (reduced)
            int j = (oc_out * 16 + oc_inn) * 16 + oc_tns;  // output channel
            // Accumulate into output element j; i is the reduction index.
            C[0][j] = C[0][j] + A[0][i] * B[j][i];
          }
        }
      }
    }
  }
}
# Define tiling sizes, expressed in multiples of the VTA tensor shape:
# a 256-channel block is 256 // BLOCK_IN (or BLOCK_OUT) tensor tiles.
# NOTE(review): 1 // env.BATCH evaluates to 0 whenever env.BATCH > 1, which
# would be passed as the split factor below — confirm behaviour for BATCH > 1.
b_block = 1 // env.BATCH
i_block = 256 // env.BLOCK_IN
o_block = 256 // env.BLOCK_OUT
# Tile the output tensor along the batch and output channel dimensions
# (since by default we are doing single batch inference, the split along
# the batch dimension has no effect)
b, oc, b_tns, oc_tns = s[res].op.axis
b_out, b_inn = s[res].split(b, b_block)
oc_out, oc_inn = s[res].split(oc, o_block)
s[res].reorder(b_out, oc_out, b_inn, oc_inn)
# Move intermediate computation into each output compute tile:
# all four intermediate stages are attached at the oc_out loop of `res`
s[res_gemm].compute_at(s[res], oc_out)
s[res_shr].compute_at(s[res], oc_out)
s[res_max].compute_at(s[res], oc_out)
s[res_min].compute_at(s[res], oc_out)
# Apply additional loop split along reduction axis (input channel).
# Note that b_inn, oc_inn, b_tns and oc_tns are rebound here to the axes of
# the res_gemm stage; from this point on they no longer refer to `res`.
b_inn, oc_inn, b_tns, oc_tns = s[res_gemm].op.axis
ic_out, ic_inn = s[res_gemm].split(ic, i_block)
# Reorder axes. We move the ic_out axis all the way out of the GEMM
# loop to block along the reduction axis
s[res_gemm].reorder(ic_out, b_inn, oc_inn, ic_inn, b_tns, oc_tns, ic_tns)
# Print the lowered IR of the schedule after blocking
print(tvm.lower(s, [data, weight, res], simple_mode=True))
@main = primfn(data_1: handle, weight_1: handle, res_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {data: Buffer(data_2: Pointer(int8), int8, [1024], []),
weight: Buffer(weight_2: Pointer(int8), int8, [1048576], []),
res: Buffer(res_2: Pointer(int8), int8, [1024], [])}
buffer_map = {data_1: data, weight_1: weight, res_1: res}
preflattened_buffer_map = {data_1: data_3: Buffer(data_2, int8, [1, 64, 1, 16], []), weight_1: weight_3: Buffer(weight_2, int8, [64, 64, 16, 16], []), res_1: res_3: Buffer(res_2, int8, [1, 64, 1, 16], [])} {
allocate(data_buf: Pointer(global int8), int8, [1024]), storage_scope = global;
allocate(weight_buf: Pointer(global int8), int8, [1048576]), storage_scope = global;
allocate(res_gem: Pointer(global int32), int32, [256]), storage_scope = global {
for (i1: int32, 0, 64) {
for (i3: int32, 0, 16) {
let cse_var_1: int32 = ((i1*16) + i3)
data_buf_1: Buffer(data_buf, int8, [1024], [])[cse_var_1] = data[cse_var_1]
}
}
for (i0: int32, 0, 64) {
for (i1_1: int32, 0, 64) {
for (i2: int32, 0, 16) {
for (i3_1: int32, 0, 16) {
let cse_var_2: int32 = ((((i0*16384) + (i1_1*256)) + (i2*16)) + i3_1)
weight_buf_1: Buffer(weight_buf, int8, [1048576], [])[cse_var_2] = weight[cse_var_2]
}
}
}
}
for (i1.outer: int32, 0, 4) {
for (co.init: int32, 0, 16) {
for (ci.init: int32, 0, 16) {
res_gem_1: Buffer(res_gem, int32, [256], [])[((co.init*16) + ci.init)] = 0
}
}
for (ic.outer: int32, 0, 4) {
for (co: int32, 0, 16) {
for (ic.inner: int32, 0, 16) {
for (ci: int32, 0, 16) {
for (ic_tns: int32, 0, 16) {
let cse_var_3: int32 = ((co*16) + ci)
res_gem_1[cse_var_3] = (res_gem_1[cse_var_3] + (cast(int32, data_buf_1[(((ic.outer*256) + (ic.inner*16)) + ic_tns)])*cast(int32, weight_buf_1[((((((i1.outer*262144) + (co*16384)) + (ic.outer*4096)) + (ic.inner*256)) + (ci*16)) + ic_tns)])))
}
}
}
}
}
for (i1_2: int32, 0, 16) {
for (i3_2: int32, 0, 16) {
let cse_var_4: int32 = ((i1_2*16) + i3_2)
res_gem_2: Buffer(res_gem, int32, [256], [])[cse_var_4] = @tir.shift_right(res_gem_1[cse_var_4], 8, dtype=int32)
}
}
for (i1_3: int32, 0, 16) {
for (i3_3: int32, 0, 16) {
let cse_var_5: int32 = ((i1_3*16) + i3_3)
res_gem_3: Buffer(res_gem, int32, [256], [])[cse_var_5] = max(res_gem_2[cse_var_5], 0)
}
}
for (i1_4: int32, 0, 16) {
for (i3_4: int32, 0, 16) {
let cse_var_6: int32 = ((i1_4*16) + i3_4)
res_gem_4: Buffer(res_gem, int32, [256], [])[cse_var_6] = min(res_gem_3[cse_var_6], 127)
}
}
for (i1.inner: int32, 0, 16) {
for (i3_5: int32, 0, 16) {
let cse_var_7: int32 = (i1.inner*16)
res[(((i1.outer*256) + cse_var_7) + i3_5)] = cast(int8, res_gem_4[(cse_var_7 + i3_5)])
}
}
}
}
}
lowering 复制到 DMA 传输#
接下来,将 buffer 作用域设置为相应的 on-chip VTA SRAM buffer。将 load 循环移动到矩阵乘法计算循环中,使加载的数据能够放入 on-chip SRAM buffer。最后,用 DMA 复制 pragma 对 load/store 循环的外层轴进行注解,以在 VTA 上执行批量内存传输。
# Set scope of SRAM buffers: the input copy uses the input scope, the weight
# copy the weight scope, and the GEMM/shift/clip stages the accumulator scope
s[data_buf].set_scope(env.inp_scope)
s[weight_buf].set_scope(env.wgt_scope)
s[res_gemm].set_scope(env.acc_scope)
s[res_shr].set_scope(env.acc_scope)
s[res_min].set_scope(env.acc_scope)
s[res_max].set_scope(env.acc_scope)
# Block data and weight cache reads: attach both copy stages at the ic_out
# loop of the GEMM stage
s[data_buf].compute_at(s[res_gemm], ic_out)
s[weight_buf].compute_at(s[res_gemm], ic_out)
# Use DMA copy pragma on DRAM->SRAM operations
# (annotated on the outermost axis of each copy stage)
s[data_buf].pragma(s[data_buf].op.axis[0], env.dma_copy)
s[weight_buf].pragma(s[weight_buf].op.axis[0], env.dma_copy)
# Use DMA copy pragma on SRAM->DRAM operation
# NOTE(review): s[res].op.axis[2] is the axis that was unpacked as `b_tns`
# when `res` was tiled, not the split-off `b_inn`; the lowered IR shows the
# store issued inside the i1.inner loop — confirm this is the intended axis.
s[res].pragma(s[res].op.axis[2], env.dma_copy)
Lowering 计算到 VTA Compute Intrinsics#
最后阶段是通过将矩阵乘法映射到张量 intrinsics,将 shift 映射到矢量 ALU,从而将计算循环 lowering 到 VTA 硬件 intrinsics。
# Apply tensorization over the batch tensor tile axis
# (`b_tns` here is the res_gemm axis bound in the blocking step), replacing
# the loops from b_tns inward with the env.gemm intrinsic
s[res_gemm].tensorize(b_tns, env.gemm)
# Add an ALU pragma over the shift and clipping operations
# (annotated on the outermost axis of each stage)
s[res_shr].pragma(s[res_shr].op.axis[0], env.alu)
s[res_min].pragma(s[res_min].op.axis[0], env.alu)
s[res_max].pragma(s[res_max].op.axis[0], env.alu)
# Let's look at the final lowered TVM schedule after lowering memory
# loads/stores down to DMA copy intrinsics, and the computation down to
# VTA compute intrinsics. vta.lower is used here instead of tvm.lower.
print(vta.lower(s, [data, weight, res], simple_mode=True))
@main = primfn(data_1: handle, weight_1: handle, res_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
buffers = {data: Buffer(data_2: Pointer(int8), int8, [1024], []),
weight: Buffer(weight_2: Pointer(int8), int8, [1048576], []),
res: Buffer(res_2: Pointer(int8), int8, [1024], [])}
buffer_map = {data_1: data, weight_1: weight, res_1: res}
preflattened_buffer_map = {data_1: data_3: Buffer(data_2, int8, [1, 64, 1, 16], []), weight_1: weight_3: Buffer(weight_2, int8, [64, 64, 16, 16], []), res_1: res_3: Buffer(res_2, int8, [1, 64, 1, 16], [])} {
@tir.vta.coproc_dep_push(3, 2, dtype=int32)
for (i1.outer: int32, 0, 4) {
attr [IterVar(vta: int32, (nullptr), "ThreadIndex", "vta")] "coproc_scope" = 2 {
@tir.vta.coproc_dep_pop(3, 2, dtype=int32)
attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_uop_scope" = "VTAPushGEMMOp" {
@tir.call_extern("VTAUopLoopBegin", 16, 1, 0, 0, dtype=int32)
@tir.vta.uop_push(0, 1, 0, 0, 0, 0, 0, 0, dtype=int32)
@tir.call_extern("VTAUopLoopEnd", dtype=int32)
}
@tir.vta.coproc_dep_push(2, 1, dtype=int32)
}
for (ic.outer: int32, 0, 4) {
let cse_var_1: int32 = (ic.outer*16)
{
attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_scope" = 1 {
@tir.vta.coproc_dep_pop(2, 1, dtype=int32)
@tir.call_extern("VTALoadBuffer2D", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), data_2, cse_var_1, 16, 1, 16, 0, 0, 0, 0, 0, 2, dtype=int32)
@tir.call_extern("VTALoadBuffer2D", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), weight_2, ((i1.outer*1024) + cse_var_1), 16, 16, 64, 0, 0, 0, 0, 0, 1, dtype=int32)
@tir.vta.coproc_dep_push(1, 2, dtype=int32)
}
attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_scope" = 2 {
@tir.vta.coproc_dep_pop(1, 2, dtype=int32)
attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_uop_scope" = "VTAPushGEMMOp" {
@tir.call_extern("VTAUopLoopBegin", 16, 1, 0, 16, dtype=int32)
@tir.call_extern("VTAUopLoopBegin", 16, 0, 1, 1, dtype=int32)
@tir.vta.uop_push(0, 0, 0, 0, 0, 0, 0, 0, dtype=int32)
@tir.call_extern("VTAUopLoopEnd", dtype=int32)
@tir.call_extern("VTAUopLoopEnd", dtype=int32)
}
@tir.vta.coproc_dep_push(2, 1, dtype=int32)
}
}
}
@tir.vta.coproc_dep_pop(2, 1, dtype=int32)
attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_scope" = 2 {
attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_uop_scope" = "VTAPushALUOp" {
@tir.call_extern("VTAUopLoopBegin", 16, 1, 1, 0, dtype=int32)
@tir.vta.uop_push(1, 0, 0, 0, 0, 3, 1, 8, dtype=int32)
@tir.call_extern("VTAUopLoopEnd", dtype=int32)
}
attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_uop_scope" = "VTAPushALUOp" {
@tir.call_extern("VTAUopLoopBegin", 16, 1, 1, 0, dtype=int32)
@tir.vta.uop_push(1, 0, 0, 0, 0, 1, 1, 0, dtype=int32)
@tir.call_extern("VTAUopLoopEnd", dtype=int32)
}
attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_uop_scope" = "VTAPushALUOp" {
@tir.call_extern("VTAUopLoopBegin", 16, 1, 1, 0, dtype=int32)
@tir.vta.uop_push(1, 0, 0, 0, 0, 0, 1, 127, dtype=int32)
@tir.call_extern("VTAUopLoopEnd", dtype=int32)
}
@tir.vta.coproc_dep_push(2, 3, dtype=int32)
}
attr [IterVar(vta, (nullptr), "ThreadIndex", "vta")] "coproc_scope" = 3 {
@tir.vta.coproc_dep_pop(2, 3, dtype=int32)
for (i1.inner: int32, 0, 16) {
@tir.call_extern("VTAStoreBuffer2D", @tir.tvm_thread_context(@tir.vta.command_handle(, dtype=handle), dtype=handle), i1.inner, 4, res_2, ((i1.outer*16) + i1.inner), 1, 1, 1, dtype=int32)
}
@tir.vta.coproc_dep_push(3, 2, dtype=int32)
}
}
@tir.vta.coproc_sync(, dtype=int32)
@tir.vta.coproc_dep_pop(3, 2, dtype=int32)
}
[21:37:07] /media/pc/data/4tb/lxw/books/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement required_alignment=256, provided_alignment=128
TVM 计算和验证#
在指定调度之后,可以将其编译为 TVM 函数。保存模块,这样就可以通过 RPC 发送它。运行该函数并对 numpy 实现进行验证,以确保其正确性。
# Compile the TVM module
my_gemm = vta.build(
    s, [data, weight, res], tvm.target.Target("ext_dev", host=env.target_host), name="my_gemm"
)

# Save the compiled object into a temporary directory, upload it through the
# RPC session and load it on the remote side
temp = utils.tempdir()
my_gemm.save(temp.relpath("gemm.o"))
remote.upload(temp.relpath("gemm.o"))
f = remote.load_module("gemm.o")

# Get the remote device context
ctx = remote.ext_dev(0)

# Initialize the data and weight arrays randomly with integers in [-128, 128)
# (np.random.randint excludes the upper bound)
data_np = np.random.randint(-128, 128, size=(batch_size, in_channels)).astype(data.dtype)
weight_np = np.random.randint(-128, 128, size=(out_channels, in_channels)).astype(weight.dtype)

# Apply packing to the data and weight arrays from a 2D to a 4D packed layout:
# split each dimension into (tile count, tile extent), then move both tile
# counts to the front
data_packed = data_np.reshape(
    batch_size // env.BATCH, env.BATCH, in_channels // env.BLOCK_IN, env.BLOCK_IN
).transpose((0, 2, 1, 3))
weight_packed = weight_np.reshape(
    out_channels // env.BLOCK_OUT, env.BLOCK_OUT, in_channels // env.BLOCK_IN, env.BLOCK_IN
).transpose((0, 2, 1, 3))

# Format the input/output arrays with tvm.nd.array to the DLPack standard
data_nd = tvm.nd.array(data_packed, ctx)
weight_nd = tvm.nd.array(weight_packed, ctx)
res_nd = tvm.nd.array(np.zeros(output_shape).astype(res.dtype), ctx)

# Clear stats (simulator targets only)
if env.TARGET in ["sim", "tsim"]:
    simulator.clear_stats()

# Invoke the module to perform the computation
f(data_nd, weight_nd, res_nd)

# Verify against a numpy implementation that mirrors the declared compute:
# accumulate in acc_dtype, shift, clip to [0, inp_max], cast, then pack
res_ref = np.dot(data_np.astype(env.acc_dtype), weight_np.T.astype(env.acc_dtype))
res_ref = res_ref >> env.INP_WIDTH
res_ref = np.clip(res_ref, 0, inp_max)
res_ref = res_ref.astype(res.dtype)
res_ref = res_ref.reshape(
    batch_size // env.BATCH, env.BATCH, out_channels // env.BLOCK_OUT, env.BLOCK_OUT
).transpose((0, 2, 1, 3))
np.testing.assert_equal(res_ref, res_nd.numpy())

# Print stats (simulator targets only)
if env.TARGET in ["sim", "tsim"]:
    sim_stats = simulator.stats()
    print("Execution statistics:")
    for k, v in sim_stats.items():
        print("\t{:<16}: {:>16}".format(k, v))

print("Successful blocked matrix multiply test!")
Execution statistics:
inp_load_nbytes : 4096
wgt_load_nbytes : 1048576
acc_load_nbytes : 0
uop_load_nbytes : 20
out_store_nbytes: 1024
gemm_counter : 4096
alu_counter : 192
Successful blocked matrix multiply test!
[21:37:07] /media/pc/data/4tb/lxw/books/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement required_alignment=256, provided_alignment=128
小结#
本教程演示了 TVM 调度原语如何为矩阵乘法示例实现分块计算。这允许将任意大的计算映射到有限的硬件加速器资源上。