Schedule Primitives in TVM#

Author: Ziheng Jiang

TVM is a domain-specific language for efficient kernel construction.

In this tutorial, we will show you how to schedule a computation with the various primitives provided by TVM.

import tvm
from tvm import te
import numpy as np

There are often several ways to compute the same result; however, different ways lead to different locality and performance. TVM therefore asks the user to describe how to execute the computation, and this description is called a Schedule.

A Schedule is a set of transformations that rewrite the loops of a computation in the program.

# declare some variables for use later
n = te.var("n")
m = te.var("m")

A schedule can be created from a list of ops. By default, the schedule computes the tensors serially in row-major order.

# declare a matrix element-wise multiplication
A = te.placeholder((m, n), name="A")
B = te.placeholder((m, n), name="B")
C = te.compute((m, n), lambda i, j: A[i, j] * B[i, j], name="C")

s = te.create_schedule([C.op])

lower transforms the computation from its definition into a real callable function. With the argument simple_mode=True, it returns readable C-like statements; we use it here to print the schedule result.

tvm.lower(s, [A, B, C], simple_mode=True).show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle) -> None:
        # function attr dict
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        m = T.var("int32")
        n = T.var("int32")
        stride = T.var("int32")
        stride_1 = T.var("int32")
        stride_2 = T.var("int32")
        stride_3 = T.var("int32")
        stride_4 = T.var("int32")
        stride_5 = T.var("int32")
        A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
        B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
        C_1 = T.match_buffer(C, [stride_2 * m], dtype="float32", type="auto")
        T.preflattened_buffer(A_1, [m, n], dtype="float32", data=A_1.data, strides=[stride, stride_3], type="auto")
        T.preflattened_buffer(B_1, [m, n], dtype="float32", data=B_1.data, strides=[stride_1, stride_4], type="auto")
        T.preflattened_buffer(C_1, [m, n], dtype="float32", data=C_1.data, strides=[stride_2, stride_5], type="auto")
        # body
        for i, j in T.grid(m, n):
            C_1[i * stride_2 + j * stride_5] = A_1[i * stride + j * stride_3] * B_1[i * stride_1 + j * stride_4]
    

Each schedule is composed of multiple Stages, and each Stage represents the schedule of one operation.

We provide various methods to schedule each stage below.
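
As a minimal sketch (an illustration added here, not part of the original tutorial): indexing the schedule with a tensor or its op returns the corresponding Stage, and the loop axes of that op are exposed through op.axis.

# illustration only: s and C are the schedule and tensor created above
stage = s[C]        # the Stage that schedules the computation of C
print(type(stage))  # tvm.te.schedule.Stage
print(C.op.axis)    # the iteration axes (i, j) of the compute op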

split#

split can split the specified axis into two axes by factor.

m = te.var("m")
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] * 2, name="B")

s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=32)
tvm.lower(s, [A, B]).show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle) -> None:
        # function attr dict
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        m = T.var("int32")
        stride = T.var("int32")
        stride_1 = T.var("int32")
        A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
        B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
        T.preflattened_buffer(A_1, [m], dtype="float32", data=A_1.data, strides=[stride], type="auto")
        T.preflattened_buffer(B_1, [m], dtype="float32", data=B_1.data, strides=[stride_1], type="auto")
        # body
        for i_outer, i_inner in T.grid(m // 32, 32):
            cse_var_1: T.int32 = i_outer * 32 + i_inner
            B_1[cse_var_1 * stride_1] = A_1[cse_var_1 * stride] * T.float32(2)
        for i_outer, i_inner in T.grid((m % 32 + 31) // 32, 32):
            if m // 32 * 32 + i_outer * 32 + i_inner < m:
                B_1[(m // 32 * 32 + i_outer * 32 + i_inner) * stride_1] = A_1[(m // 32 * 32 + i_outer * 32 + i_inner) * stride] * T.float32(2)
    
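
To confirm that split only changes the loop structure and not the result, here is a minimal sketch (an addition to the tutorial, assuming an LLVM-enabled TVM build) that compiles the split schedule and checks it against NumPy.

# assumption: TVM was built with LLVM support
func = tvm.build(s, [A, B], target="llvm", name="mul2")
dev = tvm.cpu(0)
a_np = np.random.uniform(size=1024).astype("float32")
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(np.zeros(1024, dtype="float32"), dev)
func(a, b)
np.testing.assert_allclose(b.numpy(), a_np * 2)  # same result as the unsplit schedule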

You can also split an axis by nparts, which splits the axis the other way around from factor: with factor the inner axis gets the given extent, while with nparts the outer axis does.

m = te.var("m")
A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i], name="B")

s = te.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], nparts=32)
tvm.lower(s, [A, B], simple_mode=True).show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle) -> None:
        # function attr dict
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        m = T.var("int32")
        stride = T.var("int32")
        stride_1 = T.var("int32")
        A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
        B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
        T.preflattened_buffer(A_1, [m], dtype="float32", data=A_1.data, strides=[stride], type="auto")
        T.preflattened_buffer(B_1, [m], dtype="float32", data=B_1.data, strides=[stride_1], type="auto")
        # body
        for i_outer, i_inner in T.grid(32, (m + 31) // 32):
            if T.likely(i_inner + i_outer * ((m + 31) // 32) < m, dtype="bool"):
                B_1[(i_inner + i_outer * ((m + 31) // 32)) * stride_1] = A_1[(i_inner + i_outer * ((m + 31) // 32)) * stride]
    
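
To see the difference between the two split modes more directly, here is a small sketch (added for illustration, using a fixed axis length of 128 so no tail loop is generated).

A = te.placeholder((128,), name="A")
B = te.compute((128,), lambda i: A[i], name="B")

s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=32)   # loops: outer extent 4, inner extent 32
tvm.lower(s, [A, B], simple_mode=True).show()

s = te.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], nparts=32)   # loops: outer extent 32, inner extent 4
tvm.lower(s, [A, B], simple_mode=True).show()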

tile#

tile helps you execute the computation tile by tile over two axes.

A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j], name="B")

s = te.create_schedule(B.op)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
tvm.lower(s, [A, B], simple_mode=True).show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle) -> None:
        # function attr dict
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        m = T.var("int32")
        n = T.var("int32")
        stride = T.var("int32")
        stride_1 = T.var("int32")
        stride_2 = T.var("int32")
        stride_3 = T.var("int32")
        A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
        B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
        T.preflattened_buffer(A_1, [m, n], dtype="float32", data=A_1.data, strides=[stride, stride_2], type="auto")
        T.preflattened_buffer(B_1, [m, n], dtype="float32", data=B_1.data, strides=[stride_1, stride_3], type="auto")
        # body
        for i_outer, j_outer, i_inner in T.grid((m + 9) // 10, (n + 4) // 5, 10):
            if T.likely(i_outer * 10 + i_inner < m, dtype="bool"):
                for j_inner in T.serial(5):
                    if T.likely(j_outer * 5 + j_inner < n, dtype="bool"):
                        cse_var_2: T.int32 = j_outer * 5 + j_inner
                        cse_var_1: T.int32 = i_outer * 10 + i_inner
                        B_1[cse_var_1 * stride_1 + cse_var_2 * stride_3] = A_1[cse_var_1 * stride + cse_var_2 * stride_2]
    
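
tile is essentially a shorthand: the same loop nest can be obtained by splitting both axes and then reordering the resulting axes (reorder is introduced below), as the following sketch shows (an illustrative addition, reusing A and B from above).

s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=10)
yo, yi = s[B].split(B.op.axis[1], factor=5)
s[B].reorder(xo, yo, xi, yi)
# this produces the same loop nest as tile(x_factor=10, y_factor=5) above
tvm.lower(s, [A, B], simple_mode=True).show()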

fuse#

fuse can fuse two consecutive axes of one computation.

A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j], name="B")

s = te.create_schedule(B.op)
# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then fuse (i.inner, j.inner) into one axis: (i.inner.j.inner.fused)
fused = s[B].fuse(xi, yi)
tvm.lower(s, [A, B], simple_mode=True).show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle) -> None:
        # function attr dict
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        m = T.var("int32")
        n = T.var("int32")
        stride = T.var("int32")
        stride_1 = T.var("int32")
        stride_2 = T.var("int32")
        stride_3 = T.var("int32")
        A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
        B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
        T.preflattened_buffer(A_1, [m, n], dtype="float32", data=A_1.data, strides=[stride, stride_2], type="auto")
        T.preflattened_buffer(B_1, [m, n], dtype="float32", data=B_1.data, strides=[stride_1, stride_3], type="auto")
        # body
        for i_outer, j_outer, i_inner_j_inner_fused in T.grid((m + 9) // 10, (n + 4) // 5, 50):
            if T.likely(i_outer * 10 + i_inner_j_inner_fused // 5 < m, dtype="bool"):
                if T.likely(j_outer * 5 + i_inner_j_inner_fused % 5 < n, dtype="bool"):
                    cse_var_2: T.int32 = j_outer * 5 + i_inner_j_inner_fused % 5
                    cse_var_1: T.int32 = i_outer * 10 + i_inner_j_inner_fused // 5
                    B_1[cse_var_1 * stride_1 + cse_var_2 * stride_3] = A_1[cse_var_1 * stride + cse_var_2 * stride_2]
    

reorder#

reorder can reorder the axes in the specified order.

A = te.placeholder((m, n), name="A")
B = te.compute((m, n), lambda i, j: A[i, j], name="B")

s = te.create_schedule(B.op)
# tile to four axes first: (i.outer, j.outer, i.inner, j.inner)
xo, yo, xi, yi = s[B].tile(B.op.axis[0], B.op.axis[1], x_factor=10, y_factor=5)
# then reorder the axes: (i.inner, j.outer, i.outer, j.inner)
s[B].reorder(xi, yo, xo, yi)
tvm.lower(s, [A, B], simple_mode=True).show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle) -> None:
        # function attr dict
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        m = T.var("int32")
        n = T.var("int32")
        stride = T.var("int32")
        stride_1 = T.var("int32")
        stride_2 = T.var("int32")
        stride_3 = T.var("int32")
        A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
        B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
        T.preflattened_buffer(A_1, [m, n], dtype="float32", data=A_1.data, strides=[stride, stride_2], type="auto")
        T.preflattened_buffer(B_1, [m, n], dtype="float32", data=B_1.data, strides=[stride_1, stride_3], type="auto")
        # body
        for i_inner, j_outer, i_outer in T.grid(10, (n + 4) // 5, (m + 9) // 10):
            if T.likely(i_outer * 10 + i_inner < m, dtype="bool"):
                for j_inner in T.serial(5):
                    if T.likely(j_outer * 5 + j_inner < n, dtype="bool"):
                        cse_var_2: T.int32 = j_outer * 5 + j_inner
                        cse_var_1: T.int32 = i_outer * 10 + i_inner
                        B_1[cse_var_1 * stride_1 + cse_var_2 * stride_3] = A_1[cse_var_1 * stride + cse_var_2 * stride_2]
    

bind#

bind can bind a specified axis with a thread axis, which is often used in GPU programming.

A = te.placeholder((n,), name="A")
B = te.compute(A.shape, lambda i: A[i] * 2, name="B")

s = te.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
tvm.lower(s, [A, B], simple_mode=True).show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle) -> None:
        # function attr dict
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        # var definition
        threadIdx_x = T.env_thread("threadIdx.x")
        blockIdx_x = T.env_thread("blockIdx.x")
        n = T.var("int32")
        stride = T.var("int32")
        stride_1 = T.var("int32")
        A_1 = T.match_buffer(A, [stride * n], dtype="float32", type="auto")
        B_1 = T.match_buffer(B, [stride_1 * n], dtype="float32", type="auto")
        T.preflattened_buffer(A_1, [n], dtype="float32", data=A_1.data, strides=[stride], type="auto")
        T.preflattened_buffer(B_1, [n], dtype="float32", data=B_1.data, strides=[stride_1], type="auto")
        # body
        T.launch_thread(blockIdx_x, (n + 63) // 64)
        T.launch_thread(threadIdx_x, 64)
        if T.likely(blockIdx_x * 64 + threadIdx_x < n, dtype="bool"):
            B_1[(blockIdx_x * 64 + threadIdx_x) * stride_1] = A_1[(blockIdx_x * 64 + threadIdx_x) * stride] * T.float32(2)
    
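
With the two thread axes bound, the kernel can be compiled for a GPU target. The following is a minimal sketch (an addition, assuming an NVIDIA GPU and a CUDA-enabled TVM build, neither of which the tutorial itself requires).

# assumptions: CUDA-enabled TVM build and an available NVIDIA GPU
fmul = tvm.build(s, [A, B], target="cuda", name="mul2")
dev = tvm.cuda(0)
a = tvm.nd.array(np.random.uniform(size=1024).astype("float32"), dev)
b = tvm.nd.array(np.zeros(1024, dtype="float32"), dev)
fmul(a, b)
np.testing.assert_allclose(b.numpy(), a.numpy() * 2)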

compute_at#

For a schedule that consists of multiple operators, TVM will compute tensors at the root separately by default.

A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")

s = te.create_schedule(C.op)
tvm.lower(s, [A, B, C], simple_mode=True).show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle) -> None:
        # function attr dict
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        m = T.var("int32")
        stride = T.var("int32")
        stride_1 = T.var("int32")
        stride_2 = T.var("int32")
        A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
        B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
        C_1 = T.match_buffer(C, [stride_2 * m], dtype="float32", type="auto")
        T.preflattened_buffer(A_1, [m], dtype="float32", data=A_1.data, strides=[stride], type="auto")
        T.preflattened_buffer(B_1, [m], dtype="float32", data=B_1.data, strides=[stride_1], type="auto")
        T.preflattened_buffer(C_1, [m], dtype="float32", data=C_1.data, strides=[stride_2], type="auto")
        # body
        for i in T.serial(m):
            B_1[i * stride_1] = A_1[i * stride] + T.float32(1)
        for i in T.serial(m):
            C_1[i * stride_2] = B_1[i * stride_1] * T.float32(2)
    

compute_at can move the computation of B into the first axis of the computation of C.

A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")

s = te.create_schedule(C.op)
s[B].compute_at(s[C], C.op.axis[0])
tvm.lower(s, [A, B, C], simple_mode=True).show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle) -> None:
        # function attr dict
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        m = T.var("int32")
        stride = T.var("int32")
        stride_1 = T.var("int32")
        stride_2 = T.var("int32")
        A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
        B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
        C_1 = T.match_buffer(C, [stride_2 * m], dtype="float32", type="auto")
        T.preflattened_buffer(A_1, [m], dtype="float32", data=A_1.data, strides=[stride], type="auto")
        T.preflattened_buffer(B_1, [m], dtype="float32", data=B_1.data, strides=[stride_1], type="auto")
        T.preflattened_buffer(C_1, [m], dtype="float32", data=C_1.data, strides=[stride_2], type="auto")
        # body
        for i in T.serial(m):
            B_1[i * stride_1] = A_1[i * stride] + T.float32(1)
            C_1[i * stride_2] = B_1[i * stride_1] * T.float32(2)
    
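
compute_at is not limited to the original axes: B can also be attached to an axis produced by another primitive. The sketch below (an illustrative addition) splits the axis of C first and computes B at the resulting outer axis, so B is computed block by block.

s = te.create_schedule(C.op)
xo, xi = s[C].split(C.op.axis[0], factor=32)
s[B].compute_at(s[C], xo)
tvm.lower(s, [A, B, C], simple_mode=True).show()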

compute_inline#

compute_inline can mark one stage as inline; the body of the computation will then be expanded and inserted at the address where the tensor is required.

A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")

s = te.create_schedule(C.op)
s[B].compute_inline()
tvm.lower(s, [A, B, C], simple_mode=True).show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle) -> None:
        # function attr dict
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        m = T.var("int32")
        stride = T.var("int32")
        stride_1 = T.var("int32")
        stride_2 = T.var("int32")
        A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
        B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
        C_1 = T.match_buffer(C, [stride_2 * m], dtype="float32", type="auto")
        T.preflattened_buffer(A_1, [m], dtype="float32", data=A_1.data, strides=[stride], type="auto")
        T.preflattened_buffer(B_1, [m], dtype="float32", data=B_1.data, strides=[stride_1], type="auto")
        T.preflattened_buffer(C_1, [m], dtype="float32", data=C_1.data, strides=[stride_2], type="auto")
        # body
        for i in T.serial(m):
            C_1[i * stride_2] = (A_1[i * stride] + T.float32(1)) * T.float32(2)
    

compute_root#

compute_root can move the computation of one stage back to the root.

A = te.placeholder((m,), name="A")
B = te.compute((m,), lambda i: A[i] + 1, name="B")
C = te.compute((m,), lambda i: B[i] * 2, name="C")

s = te.create_schedule(C.op)
s[B].compute_at(s[C], C.op.axis[0])
s[B].compute_root()
tvm.lower(s, [A, B, C], simple_mode=True).show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(A: T.handle, B: T.handle, C: T.handle) -> None:
        # function attr dict
        T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
        m = T.var("int32")
        stride = T.var("int32")
        stride_1 = T.var("int32")
        stride_2 = T.var("int32")
        A_1 = T.match_buffer(A, [stride * m], dtype="float32", type="auto")
        B_1 = T.match_buffer(B, [stride_1 * m], dtype="float32", type="auto")
        C_1 = T.match_buffer(C, [stride_2 * m], dtype="float32", type="auto")
        T.preflattened_buffer(A_1, [m], dtype="float32", data=A_1.data, strides=[stride], type="auto")
        T.preflattened_buffer(B_1, [m], dtype="float32", data=B_1.data, strides=[stride_1], type="auto")
        T.preflattened_buffer(C_1, [m], dtype="float32", data=C_1.data, strides=[stride_2], type="auto")
        # body
        for i in T.serial(m):
            B_1[i * stride_1] = A_1[i * stride] + T.float32(1)
        for i in T.serial(m):
            C_1[i * stride_2] = B_1[i * stride_1] * T.float32(2)
    

Summary#

This tutorial provides an introduction to schedule primitives in TVM, which permit users to schedule computations easily and flexibly.

To get a kernel implementation with good performance, the general workflow often is (a minimal end-to-end sketch follows the list):

  • Describe your computation via a series of operations.

  • Try to schedule the computation with primitives.

  • Compile and run to see the performance difference.

  • Adjust your schedule according to the running results.
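
The sketch below walks through this workflow once (an illustration added here, assuming an LLVM-enabled TVM build; the split factor of 64 is an arbitrary starting point):

# 1. describe the computation, 2. schedule it with primitives
n = 1 << 20
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] * 2, name="B")
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=64)

# 3. compile and run to measure performance (assumes LLVM support)
func = tvm.build(s, [A, B], target="llvm")
dev = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=n).astype("float32"), dev)
b = tvm.nd.array(np.zeros(n, dtype="float32"), dev)
print(func.time_evaluator(func.entry_name, dev, number=10)(a, b).mean)

# 4. adjust the schedule (e.g. try a different split factor) and repeat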