2.5. TensorIR 练习
导航
2.5. TensorIR 练习
import IPython
import numpy as np
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
如何编写 TensorIR
首先编写逐元素加法(两个 4×4 矩阵相加)
# init data
# Use an explicit int64 dtype: np.arange's default integer type is platform-
# dependent (int32 on Windows with NumPy < 2), while the TensorIR kernel
# declares int64 buffers, so the arrays handed to tvm.nd.array must be int64.
a = np.arange(16, dtype=np.int64).reshape(4, 4)
b = np.arange(16, 0, -1, dtype=np.int64).reshape(4, 4)
# numpy version: the reference result every other implementation is checked against
c_np = a + b
c_np
array([[16, 16, 16, 16],
[16, 16, 16, 16],
[16, 16, 16, 16],
[16, 16, 16, 16]])
# low-level numpy version
def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):
    """Write the element-wise sum ``a + b`` into ``c`` using explicit scalar loops.

    This mirrors the loop nest of the TensorIR kernel: one scalar addition
    per output element. ``a``, ``b`` and ``c`` are expected to be 2-D arrays
    of the same shape; the loop bounds are taken from ``c``.

    Returns None; the result is stored in ``c`` in place.
    """
    rows, cols = c.shape
    for i in range(rows):
        for j in range(cols):
            c[i, j] = a[i, j] + b[i, j]
# Preallocate a 4x4 int64 output and fill it with the scalar-loop implementation.
out_shape = (4, 4)
c_lnumpy = np.empty(out_shape, dtype="int64")
lnumpy_add(a, b, c_lnumpy)
c_lnumpy
array([[16, 16, 16, 16],
[16, 16, 16, 16],
[16, 16, 16, 16],
[16, 16, 16, 16]])
@tvm.script.ir_module
class MyAdd:
    # Element-wise addition of two 4x4 int64 buffers: C = A + B.
    @T.prim_func
    def add(
        A: T.Buffer[(4, 4), "int64"],
        B: T.Buffer[(4, 4), "int64"],
        C: T.Buffer[(4, 4), "int64"],
    ):
        # Export the function under the symbol "add" (looked up as rt_lib["add"]).
        T.func_attr({"global_symbol": "add"})
        for row, col in T.grid(4, 4):
            with T.block("C"):
                # Bind each loop variable to a spatial block axis of extent 4.
                v_row = T.axis.spatial(4, row)
                v_col = T.axis.spatial(4, col)
                C[v_row, v_col] = A[v_row, v_col] + B[v_row, v_col]
# Compile the IRModule for the host CPU and run its "add" function.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
# Output buffer matching the kernel's (4, 4) int64 buffer C.
c_tvm = tvm.nd.array(np.empty((4, 4), dtype=np.int64))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
# Integer results should match exactly; a relative tolerance could hide
# off-by-one mismatches on large values.
np.testing.assert_array_equal(c_tvm.numpy(), c_np)
广播加法
# init data
# Explicit int64 dtype: np.arange's default integer type is platform-dependent,
# but the broadcast kernel declares int64 buffers.
a = np.arange(16, dtype=np.int64).reshape(4, 4)
b = np.arange(4, 0, -1, dtype=np.int64)  # 1-D, shape (4,); broadcast across rows
# numpy version: b is broadcast along the first axis of a
c_np = a + b
c_np
array([[ 4, 4, 4, 4],
[ 8, 8, 8, 8],
[12, 12, 12, 12],
[16, 16, 16, 16]])
@tvm.script.ir_module
class MyAdd:
    # Broadcast addition: C[i, j] = A[i, j] + B[j], with B a 1-D buffer of length 4.
    @T.prim_func
    def add(A: T.Buffer[(4, 4), "int64"],
            B: T.Buffer[(4,), "int64"],  # (4,) is a 1-tuple; plain (4) would be the int 4
            C: T.Buffer[(4, 4), "int64"]):
        T.func_attr({"global_symbol": "add"})
        for i, j in T.grid(4, 4):
            with T.block("C"):
                vi, vj = T.axis.remap("SS", [i, j])
                # B is indexed only by the column axis, which implements the broadcast.
                C[vi, vj] = A[vi, vj] + B[vj]
# Compile the broadcast IRModule for the host CPU and run its "add" function.
rt_lib = tvm.build(MyAdd, target="llvm")
a_tvm = tvm.nd.array(a)
b_tvm = tvm.nd.array(b)
# Output buffer matching the kernel's (4, 4) int64 buffer C.
c_tvm = tvm.nd.array(np.empty((4, 4), dtype=np.int64))
rt_lib["add"](a_tvm, b_tvm, c_tvm)
# Integer results should match exactly rather than within a relative tolerance.
np.testing.assert_array_equal(c_tvm.numpy(), c_np)
2D 卷积
使用 NCHW 布局的卷积的数学定义:
\[
\text{Conv}[b, k, i, j] =
\sum_{d_i, d_j, q} A[b, q, \text{strides} \cdot i + d_i, \text{strides} \cdot j + d_j] \cdot W[k, q, d_i, d_j],
\]
其中,\(A\) 是输入张量,\(W\) 是权重张量,\(b\) 是批次索引,\(k\) 是输出通道,\(i\) 和 \(j\) 是输出图像在高度和宽度方向上的索引,\(d_i\) 和 \(d_j\) 是卷积核(权重)在高度和宽度方向上的索引,\(q\) 是输入通道,\(\text{strides}\) 是卷积核窗口滑动的步幅。
下面考虑简单的情况:stride=1,padding=0。
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
# stride=1, padding=0 -> "valid" convolution output size
OUT_H, OUT_W = H - K + 1, W - K + 1
# Explicit int64 dtype: these arrays are passed directly to the int64 TVM buffers.
data = np.arange(N * CI * H * W, dtype=np.int64).reshape(N, CI, H, W)
weight = np.arange(CO * CI * K * K, dtype=np.int64).reshape(CO, CI, K, K)
# torch version
import torch
# The torch reference is computed in floating point. float64 represents every
# integer below 2**53 exactly, giving far more headroom than float32 (2**24).
data_torch = torch.tensor(data, dtype=torch.float64)
weight_torch = torch.tensor(weight, dtype=torch.float64)
conv_torch = torch.nn.functional.conv2d(data_torch, weight_torch)
# Round before casting so any floating-point error cannot truncate toward zero.
conv_torch = np.rint(conv_torch.numpy()).astype(np.int64)
conv_torch
array([[[[ 474, 510, 546, 582, 618, 654],
[ 762, 798, 834, 870, 906, 942],
[1050, 1086, 1122, 1158, 1194, 1230],
[1338, 1374, 1410, 1446, 1482, 1518],
[1626, 1662, 1698, 1734, 1770, 1806],
[1914, 1950, 1986, 2022, 2058, 2094]],
[[1203, 1320, 1437, 1554, 1671, 1788],
[2139, 2256, 2373, 2490, 2607, 2724],
[3075, 3192, 3309, 3426, 3543, 3660],
[4011, 4128, 4245, 4362, 4479, 4596],
[4947, 5064, 5181, 5298, 5415, 5532],
[5883, 6000, 6117, 6234, 6351, 6468]]]])
@tvm.script.ir_module
class MyConv:
    # NCHW 2-D convolution, stride 1, no padding:
    #   C[n, co, h, w] = sum_{ci, k1, k2} A[n, ci, h + k1, w + k2] * B[co, ci, k1, k2]
    @T.prim_func
    def conv(A: T.Buffer[(1, 1, 8, 8), "int64"],   # (N, CI, H, W)
             B: T.Buffer[(2, 1, 3, 3), "int64"],   # (CO, CI, K, K)
             C: T.Buffer[(1, 2, 6, 6), "int64"]):  # (N, CO, OUT_H, OUT_W)
        T.func_attr({"global_symbol": "conv", "tir.noalias": True})
        for n, c, h, w, i, k1, k2 in T.grid(N, CO, OUT_H, OUT_W, CI, K, K):
            with T.block("C"):
                # n, c, h, w index the output (spatial); the input channel i and
                # the kernel offsets k1, k2 are summed over (reduction). Declaring
                # i as spatial only happened to work because CI == 1: with CI > 1,
                # T.init would re-zero C for every input channel.
                vn, vc, vh, vw, vi, vk1, vk2 = T.axis.remap(
                    "SSSSRRR", [n, c, h, w, i, k1, k2])
                with T.init():
                    C[vn, vc, vh, vw] = T.int64(0)
                C[vn, vc, vh, vw] = C[vn, vc, vh, vw] + A[vn, vi, vh + vk1, vw + vk2] * B[vc, vi, vk1, vk2]
# Compile MyConv for the host CPU and run its "conv" function.
rt_lib = tvm.build(MyConv, target="llvm")
data_tvm = tvm.nd.array(data)
weight_tvm = tvm.nd.array(weight)
# Output buffer matching the kernel's (N, CO, OUT_H, OUT_W) int64 buffer C.
conv_tvm = tvm.nd.array(np.empty((N, CO, OUT_H, OUT_W), dtype=np.int64))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
# The convolution is integer-valued, so compare exactly rather than with a tolerance.
np.testing.assert_array_equal(conv_tvm.numpy(), conv_torch)