Schedule Primitives in TVM#
Author: Ziheng Jiang
TVM is a domain specific language for efficient kernel construction.
In this tutorial, we will show you how to schedule a computation with the various primitives provided by TVM.
import tvm
from tvm import te
import numpy as np
There often exist several methods to compute the same result; however, different methods lead to different locality and performance. So TVM asks the user to describe how to execute the computation, and this description is called a Schedule.
A Schedule is a set of transformations that rewrite the computation loops in the program.
# declare some variables for use later
n = te.var("n")
m = te.var("m")
A schedule can be created from a list of ops. By default, the schedule computes the tensors serially in row-major order.
# declare a matrix element-wise multiply
A = te.placeholder((m, n), name="A")
B = te.placeholder((m, n), name="B")
C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")
s = te.create_schedule([C.op])
lower transforms the computation from its definition into a real callable function. With the argument simple_mode=True, it returns a readable C-like statement; we use it here to print the schedule result.
tvm.lower(s, [A, B, C], simple_mode=True).show()
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle, C: T.handle) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
n = T.var("int32")
stride = T.var("int32")
stride_1 = T.var("int32")
stride_2 = T.var("int32")
stride_3 = T.var("int32")
stride_4 = T.var("int32")
stride_5 = T.var("int32")
A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
C_1 = T.match_buffer(C, [stride_2 * m], dtype="float32", type="auto")
T.preflattened_buffer(A_1, [m, n], dtype="float32", data=A_1.data, strides=[stride, stride_3], type="auto")
T.preflattened_buffer(B_1, [m, n], dtype="float32", data=B_1.data, strides=[stride_1, stride_4], type="auto")
T.preflattened_buffer(C_1, [m, n], dtype="float32", data=C_1.data, strides=[stride_2, stride_5], type="auto")
# body
for i, j in T.grid(m, n):
C_1[i * stride_2 + j * stride_5] = A_1[i * stride + j * stride_3] * B_1[i * stride_1 + j * stride_4]
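Besides printing the lowered IR, the same schedule can be compiled and executed. The following is a minimal sketch, not part of the original tutorial, assuming TVM was built with LLVM support; the function name "ewise_mul" and the concrete (64, 64) shape are arbitrary illustrative choices.
# Minimal sketch (assumes an LLVM-enabled TVM build): compile the element-wise
# multiply schedule above and check the result against NumPy.
fmul = tvm.build(s, [A, B, C], target="llvm", name="ewise_mul")  # "ewise_mul" is an arbitrary name
dev = tvm.cpu(0)
a = tvm.nd.array(np.random.rand(64, 64).astype("float32"), dev)  # concrete shape chosen for illustration
b = tvm.nd.array(np.random.rand(64, 64).astype("float32"), dev)
c = tvm.nd.array(np.zeros((64, 64), dtype="float32"), dev)
fmul(a, b, c)
np.testing.assert_allclose(c.numpy(), a.numpy() * b.numpy(), rtol=1e-5)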
One schedule is composed of multiple stages, and one Stage represents the schedule for one operation.
Below we provide various methods to schedule every stage.
split#
split can split a specified axis into two axes by factor.
m = te.var("m")
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] * 2, name="B")
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=32)
tvm.lower(s, [A, B]).show()
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
stride = T.var("int32")
stride_1 = T.var("int32")
A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
T.preflattened_buffer(A_1, [m], dtype="float32", data=A_1.data, strides=[stride], type="auto")
T.preflattened_buffer(B_1, [m], dtype="float32", data=B_1.data, strides=[stride_1], type="auto")
# body
for i_outer, i_inner in T.grid(m // 32, 32):
cse_var_1: T.int32 = i_outer * 32 + i_inner
B_1[cse_var_1 * stride_1] = A_1[cse_var_1 * stride] * T.float32(2)
for i_outer, i_inner in T.grid((m % 32 + 31) // 32, 32):
if m // 32 * 32 + i_outer * 32 + i_inner < m:
B_1[(m // 32 * 32 + i_outer * 32 + i_inner) * stride_1] = A_1[(m // 32 * 32 + i_outer * 32 + i_inner) * stride] * T.float32(2)
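Note that the two loop nests above come from splitting a symbolic extent m: the second nest handles the tail iterations when m is not a multiple of 32. As a small illustrative sketch (new names, same imports), splitting a constant extent that is divisible by the factor yields a single exact loop nest with no boundary check.
# Illustrative sketch: with a constant extent divisible by the factor,
# split produces a single exact loop nest and no tail loop.
A1 = te.placeholder((1024,), name="A1")
B1 = te.compute((1024,), lambda i: A1[i] * 2, name="B1")
s1 = te.create_schedule(B1.op)
xo1, xi1 = s1[B1].split(B1.op.axis[0], factor=32)
tvm.lower(s1, [A1, B1], simple_mode=True).show()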
You can also split an axis by nparts, which splits the axis contrary to factor: it fixes the extent of the outer loop instead of the inner one.
m = te.var("m")
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i], name="B")
s = te.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], nparts=32)
tvm.lower(s, [A, B], simple_mode=True).show()
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
stride = T.var("int32")
stride_1 = T.var("int32")
A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
T.preflattened_buffer(A_1, [m], dtype="float32", data=A_1.data, strides=[stride], type="auto")
T.preflattened_buffer(B_1, [m], dtype="float32", data=B_1.data, strides=[stride_1], type="auto")
# body
for i_outer, i_inner in T.grid(32, (m + 31) // 32):
if T.likely(i_inner + i_outer * ((m + 31) // 32) < m, dtype="bool"):
B_1[(i_inner + i_outer * ((m + 31) // 32)) * stride_1] = A_1[(i_inner + i_outer * ((m + 31) // 32)) * stride]
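To make the contrast concrete, here is a small sketch (illustrative names, fixed length 128) that applies both kinds of split to the same axis: with factor the inner loop gets the given extent, with nparts the outer loop does.
# Illustrative sketch: factor fixes the inner extent, nparts fixes the outer extent.
Ax = te.placeholder((128,), name="Ax")
Bx = te.compute((128,), lambda i: Ax[i], name="Bx")
sf = te.create_schedule(Bx.op)
sf[Bx].split(Bx.op.axis[0], factor=32)   # loops: outer 128 // 32 = 4, inner 32
tvm.lower(sf, [Ax, Bx], simple_mode=True).show()
sn = te.create_schedule(Bx.op)
sn[Bx].split(Bx.op.axis[0], nparts=32)   # loops: outer 32, inner 128 // 32 = 4
tvm.lower(sn, [Ax, Bx], simple_mode=True).show()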
tile#
tile helps you execute the computation tile by tile over two axes.
A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j], name="B")
s = te.create_schedule(B.op)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
tvm.lower(s, [A, B], simple_mode=True).show()
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
n = T.var("int32")
stride = T.var("int32")
stride_1 = T.var("int32")
stride_2 = T.var("int32")
stride_3 = T.var("int32")
A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
T.preflattened_buffer(A_1, [m, n], dtype="float32", data=A_1.data, strides=[stride, stride_2], type="auto")
T.preflattened_buffer(B_1, [m, n], dtype="float32", data=B_1.data, strides=[stride_1, stride_3], type="auto")
# body
for i_outer, j_outer, i_inner in T.grid((m + 9) // 10, (n + 4) // 5, 10):
if T.likely(i_outer * 10 + i_inner < m, dtype="bool"):
for j_inner in T.serial(5):
if T.likely(j_outer * 5 + j_inner < n, dtype="bool"):
cse_var_2: T.int32 = j_outer * 5 + j_inner
cse_var_1: T.int32 = i_outer * 10 + i_inner
B_1[cse_var_1 * stride_1 + cse_var_2 * stride_3] = A_1[cse_var_1 * stride + cse_var_2 * stride_2]
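tile is essentially shorthand for splitting both axes and then reordering the resulting loops. The sketch below (new names for clarity; doing this manually is not needed in practice) reproduces the same loop nest with split and reorder.
# Illustrative sketch: tile(x, y, 10, 5) behaves like two splits plus a reorder.
At = te.placeholder((m, n), name="At")
Bt = te.compute((m, n), lambda i, j: At[i, j], name="Bt")
st = te.create_schedule(Bt.op)
xo_t, xi_t = st[Bt].split(Bt.op.axis[0], factor=10)
yo_t, yi_t = st[Bt].split(Bt.op.axis[1], factor=5)
st[Bt].reorder(xo_t, yo_t, xi_t, yi_t)
tvm.lower(st, [At, Bt], simple_mode=True).show()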
fuse#
fuse can fuse two consecutive axes of one computation.
A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j], name="B")
s = te.create_schedule(B.op)
# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then fuse (i.inner, j.inner) into one axis: (i.inner.j.inner.fused)
fused = s[B].fuse(xi, yi)
tvm.lower(s, [A, B], simple_mode=True).show()
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
n = T.var("int32")
stride = T.var("int32")
stride_1 = T.var("int32")
stride_2 = T.var("int32")
stride_3 = T.var("int32")
A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
T.preflattened_buffer(A_1, [m, n], dtype="float32", data=A_1.data, strides=[stride, stride_2], type="auto")
T.preflattened_buffer(B_1, [m, n], dtype="float32", data=B_1.data, strides=[stride_1, stride_3], type="auto")
# body
for i_outer, j_outer, i_inner_j_inner_fused in T.grid((m + 9) // 10, (n + 4) // 5, 50):
if T.likely(i_outer * 10 + i_inner_j_inner_fused // 5 < m, dtype="bool"):
if T.likely(j_outer * 5 + i_inner_j_inner_fused % 5 < n, dtype="bool"):
cse_var_2: T.int32 = j_outer * 5 + i_inner_j_inner_fused % 5
cse_var_1: T.int32 = i_outer * 10 + i_inner_j_inner_fused // 5
B_1[cse_var_1 * stride_1 + cse_var_2 * stride_3] = A_1[cse_var_1 * stride + cse_var_2 * stride_2]
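The fused axis is an ordinary iteration variable, so it can be passed to other primitives. A small illustrative sketch, continuing the schedule above, that parallelizes the fused axis on CPU:
# Illustrative sketch: the fused axis can be handed to further primitives,
# e.g. parallelized across CPU threads.
s[B].parallel(fused)
tvm.lower(s, [A, B], simple_mode=True).show()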
reorder#
reorder can reorder the axes in the specified order.
A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j], name="B")
s = te.create_schedule(B.op)
# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then reorder the axes: (i.inner, j.outer, i.outer, j.inner)
s[B].reorder(xi, yo, xo, yi)
tvm.lower(s, [A, B], simple_mode=True).show()
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
n = T.var("int32")
stride = T.var("int32")
stride_1 = T.var("int32")
stride_2 = T.var("int32")
stride_3 = T.var("int32")
A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
T.preflattened_buffer(A_1, [m, n], dtype="float32", data=A_1.data, strides=[stride, stride_2], type="auto")
T.preflattened_buffer(B_1, [m, n], dtype="float32", data=B_1.data, strides=[stride_1, stride_3], type="auto")
# body
for i_inner, j_outer, i_outer in T.grid(10, (n + 4) // 5, (m + 9) // 10):
if T.likely(i_outer * 10 + i_inner < m, dtype="bool"):
for j_inner in T.serial(5):
if T.likely(j_outer * 5 + j_inner < n, dtype="bool"):
cse_var_2: T.int32 = j_outer * 5 + j_inner
cse_var_1: T.int32 = i_outer * 10 + i_inner
B_1[cse_var_1 * stride_1 + cse_var_2 * stride_3] = A_1[cse_var_1 * stride + cse_var_2 * stride_2]
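Loop order also affects memory locality. As an illustrative sketch (new names), swapping the two axes of a row-major traversal makes the inner loop stride across rows, which typically hurts cache behavior compared with the default (i, j) order.
# Illustrative sketch: iterate j in the outer loop and i in the inner loop.
Ar = te.placeholder((m, n), name="Ar")
Br = te.compute((m, n), lambda i, j: Ar[i, j], name="Br")
sr = te.create_schedule(Br.op)
sr[Br].reorder(Br.op.axis[1], Br.op.axis[0])  # loop order becomes (j, i)
tvm.lower(sr, [Ar, Br], simple_mode=True).show()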
bind#
bind can bind a specified axis with a thread axis, which is often used in GPU programming.
A = te.placeholder((n,), name="A")
B = te.compute(A.shape, lambda i: A[i] * 2, name="B")
s = te.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
tvm.lower(s, [A, B], simple_mode=True).show()
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
# var definition
threadIdx_x = T.env_thread("threadIdx.x")
blockIdx_x = T.env_thread("blockIdx.x")
n = T.var("int32")
stride = T.var("int32")
stride_1 = T.var("int32")
A_1 = T.match_buffer(A, [stride * n], dtype="float32", type="auto")
B_1 = T.match_buffer(B, [stride_1 * n], dtype="float32", type="auto")
T.preflattened_buffer(A_1, [n], dtype="float32", data=A_1.data, strides=[stride], type="auto")
T.preflattened_buffer(B_1, [n], dtype="float32", data=B_1.data, strides=[stride_1], type="auto")
# body
T.launch_thread(blockIdx_x, (n + 63) // 64)
T.launch_thread(threadIdx_x, 64)
if T.likely(blockIdx_x * 64 + threadIdx_x < n, dtype="bool"):
B_1[(blockIdx_x * 64 + threadIdx_x) * stride_1] = A_1[(blockIdx_x * 64 + threadIdx_x) * stride] * T.float32(2)
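The bound schedule can then be compiled into an actual GPU kernel. Below is a minimal sketch, assuming TVM was built with CUDA support and a GPU is available; the function name "mul2" and the vector length 1024 are illustrative choices.
# Minimal sketch (assumes a CUDA-enabled TVM build and an available GPU).
fgpu = tvm.build(s, [A, B], target="cuda", name="mul2")  # "mul2" is an arbitrary name
dev = tvm.cuda(0)
a = tvm.nd.array(np.random.rand(1024).astype("float32"), dev)  # length 1024 chosen for illustration
b = tvm.nd.array(np.zeros(1024, dtype="float32"), dev)
fgpu(a, b)
np.testing.assert_allclose(b.numpy(), a.numpy() * 2, rtol=1e-5)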
compute_at#
For a schedule that consists of multiple operators, TVM will compute tensors at the root separately by default.
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")
s = te.create_schedule(C.op)
tvm.lower(s, [A, B, C], simple_mode=True).show()
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle, C: T.handle) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
stride = T.var("int32")
stride_1 = T.var("int32")
stride_2 = T.var("int32")
A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
C_1 = T.match_buffer(C, [stride_2 * m], dtype="float32", type="auto")
T.preflattened_buffer(A_1, [m], dtype="float32", data=A_1.data, strides=[stride], type="auto")
T.preflattened_buffer(B_1, [m], dtype="float32", data=B_1.data, strides=[stride_1], type="auto")
T.preflattened_buffer(C_1, [m], dtype="float32", data=C_1.data, strides=[stride_2], type="auto")
# body
for i in T.serial(m):
B_1[i * stride_1] = A_1[i * stride] + T.float32(1)
for i in T.serial(m):
C_1[i * stride_2] = B_1[i * stride_1] * T.float32(2)
compute_at can move the computation of B into the first axis of the computation of C.
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")
s = te.create_schedule(C.op)
s[B].compute_at(s[C], C.op.axis[0])
tvm.lower(s, [A, B, C], simple_mode=True).show()
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle, C: T.handle) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
stride = T.var("int32")
stride_1 = T.var("int32")
stride_2 = T.var("int32")
A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
C_1 = T.match_buffer(C, [stride_2 * m], dtype="float32", type="auto")
T.preflattened_buffer(A_1, [m], dtype="float32", data=A_1.data, strides=[stride], type="auto")
T.preflattened_buffer(B_1, [m], dtype="float32", data=B_1.data, strides=[stride_1], type="auto")
T.preflattened_buffer(C_1, [m], dtype="float32", data=C_1.data, strides=[stride_2], type="auto")
# body
for i in T.serial(m):
B_1[i * stride_1] = A_1[i * stride] + T.float32(1)
C_1[i * stride_2] = B_1[i * stride_1] * T.float32(2)
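In practice compute_at is often combined with split, so that B is produced per tile of C rather than per element. An illustrative sketch (new names) that computes B at the outer loop of a split C:
# Illustrative sketch: compute B at the outer axis of a split C, so one
# 32-element chunk of B is produced right before the matching chunk of C.
Ac = te.placeholder((m,), name="Ac")
Bc = te.compute((m,), lambda i: Ac[i] + 1, name="Bc")
Cc = te.compute((m,), lambda i: Bc[i] * 2, name="Cc")
sc = te.create_schedule(Cc.op)
xo_c, xi_c = sc[Cc].split(Cc.op.axis[0], factor=32)
sc[Bc].compute_at(sc[Cc], xo_c)
tvm.lower(sc, [Ac, Bc, Cc], simple_mode=True).show()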
compute_inline#
compute_inline can mark one stage as inline; the body of the computation is then expanded and inserted at the position where the tensor is required.
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")
s = te.create_schedule(C.op)
s[B].compute_inline()
tvm.lower(s, [A, B, C], simple_mode=True).show()
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle, C: T.handle) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
stride = T.var("int32")
stride_1 = T.var("int32")
stride_2 = T.var("int32")
A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
C_1 = T.match_buffer(C, [stride_2 * m], dtype="float32", type="auto")
T.preflattened_buffer(A_1, [m], dtype="float32", data=A_1.data, strides=[stride], type="auto")
T.preflattened_buffer(B_1, [m], dtype="float32", data=B_1.data, strides=[stride_1], type="auto")
T.preflattened_buffer(C_1, [m], dtype="float32", data=C_1.data, strides=[stride_2], type="auto")
# body
for i in T.serial(m):
C_1[i * stride_2] = (A_1[i * stride] + T.float32(1)) * T.float32(2)
compute_root#
compute_root can move the computation of one stage back to the root.
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")
s = te.create_schedule(C.op)
s[B].compute_at(s[C], C.op.axis[0])
s[B].compute_root()
tvm.lower(s, [A, B, C], simple_mode=True).show()
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.handle, B: T.handle, C: T.handle) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
m = T.var("int32")
stride = T.var("int32")
stride_1 = T.var("int32")
stride_2 = T.var("int32")
A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
C_1 = T.match_buffer(C, [stride_2 * m], dtype="float32", type="auto")
T.preflattened_buffer(A_1, [m], dtype="float32", data=A_1.data, strides=[stride], type="auto")
T.preflattened_buffer(B_1, [m], dtype="float32", data=B_1.data, strides=[stride_1], type="auto")
T.preflattened_buffer(C_1, [m], dtype="float32", data=C_1.data, strides=[stride_2], type="auto")
# body
for i in T.serial(m):
B_1[i * stride_1] = A_1[i * stride] + T.float32(1)
for i in T.serial(m):
C_1[i * stride_2] = B_1[i * stride_1] * T.float32(2)
Summary#
This tutorial provides an introduction to schedule primitives in TVM, which permit users to schedule the computation easily and flexibly.
In order to get a well-performing kernel implementation, the general workflow often is (a code sketch follows this list):
Describe your computation via a series of operations.
Try to schedule the computation with primitives.
Compile and run to see the performance difference.
Adjust your schedule according to the running result.
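A minimal end-to-end sketch of that workflow, assuming an LLVM-enabled TVM build; the computation, the split factor of 64, and the name "affine" are illustrative choices, not part of the original tutorial.
# Minimal workflow sketch (assumes an LLVM-enabled TVM build).
# 1. describe the computation
nn = 1 << 20
Aw = te.placeholder((nn,), name="Aw")
Bw = te.compute((nn,), lambda i: Aw[i] * 2 + 1, name="Bw")
# 2. schedule it with primitives
sw = te.create_schedule(Bw.op)
xo_w, xi_w = sw[Bw].split(Bw.op.axis[0], factor=64)
sw[Bw].parallel(xo_w)
# 3. compile, run, and verify
func = tvm.build(sw, [Aw, Bw], target="llvm", name="affine")  # "affine" is an arbitrary name
dev = tvm.cpu(0)
a = tvm.nd.array(np.random.rand(nn).astype("float32"), dev)
b = tvm.nd.array(np.zeros(nn, dtype="float32"), dev)
func(a, b)
np.testing.assert_allclose(b.numpy(), a.numpy() * 2 + 1, rtol=1e-5)
# 4. measure the execution time to guide further schedule adjustments
evaluator = func.time_evaluator(func.entry_name, dev, number=10)
print("time: %f s" % evaluator(a, b).mean)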