Tensor Core 编程教程#

在课程中,我们使用 tmm 内在函数来演示 Tensorize 的过程。 在本教程中,我们将把 TensorIR 运用到 NVIDIA GPU 上的 Tensor Cores。 请注意,Tensor Cores 仅在具有 Volta 或更新架构的 NVIDIA GPU 上受支持(例如,V100、T4、RTX-20X0、A100、RTX-30X0)。 不幸的是,Colab 提供的大多数 GPU 都太旧,无法支持 Tensor Core。 您可能需要为本教程准备自己的设备。

什么是 Tensor Core?#

张量核心(Tensor Core)是具有 Volta 和更新架构的 GPU 上的可编程矩阵乘法和累加单元。 每个 Tensor Core 都提供了一个矩阵处理数组,它执行 D = A * B + C 运算,如果我们使用 nvcuda::wmma,则 A、B、C 和 D 是 16x16 的矩阵。 其中,矩阵乘法输入 A 和 B 是 fp16 矩阵,而累加矩阵 C 和 D 可以是 fp16 或 fp32 矩阵。

WMMA16x16x16.png

CUDA 语言只能使用 warp-level 原语 wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag) 在 Tensor Core 上执行 16x16x16 半精度矩阵乘法。 在调用矩阵乘法之前,我们必须使用原语 wmma::load_matrix_sync 显式地将数据从内存加载到寄存器中(类似于我们在第 6 章第 2 部分的 tmm 演示中所做的)。 NVCC 编译器将该原语转换为多个内存加载指令。 在运行时,每个线程从矩阵 A 加载 16 个元素或从 B 加载 16 个元素。

准备工作#

import tvm
from tvm.script import tir as T
from tvm import tir

import numpy as np

编写 Matmul 的 Tensor IR 程序#

# TensorIR module for a 1024x1024x1024 matmul with float16 inputs X and Y
# and a float32 output Z.
@tvm.script.ir_module
class MatmulModule:
    @T.prim_func
    def main(
        X: T.Buffer[(1024, 1024), "float16"],
        Y: T.Buffer[(1024, 1024), "float16"],
        Z: T.Buffer[(1024, 1024), "float32"],
    ) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i, j, k in T.grid(1024, 1024, 1024):
            with T.block("matmul"):
                # "SSR": i and j are spatial axes, k is the reduction axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    Z[vi, vj] = T.float32(0)
                # Y is indexed as Y[vj, vk], so this computes Z = X @ Y^T.
                # Both float16 operands are cast to float32 before the
                # multiply-accumulate.
                Z[vi, vj] += T.cast(X[vi, vk], "float32") * T.cast(Y[vj, vk], "float32")

注意计算 Z[vi, vj] += T.cast(X[vi, vk], "float32") * T.cast(Y[vj, vk], "float32") 与常规表示有点不同。 由于 Tensor Cores 加载 fp16 的数据,但以 fp32 进行计算,所以我们必须在计算之前将数据 cast 成 fp32。

image11.png

内存层级#

在传统的 GPU 调度中,我们有“全局内存”、“共享内存”和“本地寄存器”的内存层级。 为了支持 Tensor Cores,我们引入了另外三个特殊的内存范围:wmma.matrix_a、wmma.matrix_b 和 wmma.accumulator(类似于在第 6 章第 2 部分的演示中的 global.A_reg、global.B_reg 和 global.accumulator)。 在硬件上,所有 wmma 的内存相关层级都存储在片上寄存器级别,与本地内存相同。

注册 Tensor Intrinsic#

这里我们注册了所有的 Tensor Core intrinsics,包括 load_matrix_a、load_matrix_b、wmma_fill(初始化 C = 0)、wmma_sync(累加计算 C += A * B)和 store_matrix。 在本教程中,我们不会解释如何编写 intrinsic,而是关注如何将给定的 intrinsic 应用于张量化程序。

# Description (pattern) half of the "wmma_load_a" intrinsic: an element-wise
# copy of a 16x16 float16 tile from a "shared"-scope buffer into a
# "wmma.matrix_a"-scope buffer.
@T.prim_func
def wmma_load_a_desc(a: T.handle, c: T.handle) -> None:
    # A is the source tile, C the destination tile.
    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a")

    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j in T.grid(16, 16):
            with T.block("load"):
                # Both axes are spatial; no transposition is applied.
                vii, vjj = T.axis.remap("SS", [i, j])
                C[vii, vjj] = A[vii, vjj]


# Implementation half of the "wmma_load_a" intrinsic: the 16x16 copy from the
# "shared"-scope tile into the "wmma.matrix_a"-scope tile is expressed as a
# single T.tvm_load_matrix_sync call with a "row_major" layout.
@T.prim_func
def wmma_load_a_impl(a: T.handle, c: T.handle) -> None:
    # Symbolic strides of the source buffer: s1 for dim 0, s0 for dim 1.
    s1 = T.var("int32")
    s0 = T.var("int32")
    A = T.match_buffer(
        a,
        (16, 16),
        "float16",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a")

    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_load_matrix_sync(
                C.data,
                # The three 16s are presumably the m, n, k tile shape -- TODO confirm.
                16,
                16,
                16,
                # Index derived from C's element offset (256 == 16 * 16);
                # presumably selects the destination fragment -- TODO confirm.
                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                A.access_ptr("r"),
                # Dim-0 stride of the source buffer.
                s1,
                "row_major",
                dtype="handle",
            )
        )


# Description (pattern) half of the "wmma_load_b" intrinsic: an element-wise
# copy of a 16x16 float16 tile from a "shared"-scope buffer into a
# "wmma.matrix_b"-scope buffer.
@T.prim_func
def wmma_load_b_desc(a: T.handle, c: T.handle) -> None:
    # A is the source tile, C the destination tile.
    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b")

    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j in T.grid(16, 16):
            with T.block("load"):
                # Both axes are spatial; the copy itself is not transposed.
                vii, vjj = T.axis.remap("SS", [i, j])
                C[vii, vjj] = A[vii, vjj]


# Implementation half of the "wmma_load_b" intrinsic: the 16x16 copy from the
# "shared"-scope tile into the "wmma.matrix_b"-scope tile is expressed as a
# single T.tvm_load_matrix_sync call with a "col_major" layout.
@T.prim_func
def wmma_load_b_impl(a: T.handle, c: T.handle) -> None:
    # Symbolic strides of the source buffer: s1 for dim 0, s0 for dim 1.
    s1 = T.var("int32")
    s0 = T.var("int32")
    A = T.match_buffer(
        a,
        (16, 16),
        "float16",
        align=128,
        offset_factor=16,
        scope="shared",
        strides=[s1, s0],
    )
    C = T.match_buffer(c, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b")

    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_load_matrix_sync(
                C.data,
                # The three 16s are presumably the m, n, k tile shape -- TODO confirm.
                16,
                16,
                16,
                # Index derived from C's element offset (256 == 16 * 16);
                # presumably selects the destination fragment -- TODO confirm.
                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                A.access_ptr("r"),
                # Dim-0 stride of the source buffer.
                s1,
                # Unlike the matrix_a load, this one uses "col_major";
                # presumably because the matmul reads its second operand as
                # B[j, k] -- TODO confirm.
                "col_major",
                dtype="handle",
            )
        )


# Description (pattern) half of the "wmma_sync" intrinsic: a 16x16x16
# multiply-accumulate C += A * B^T with float16 inputs in the
# "wmma.matrix_a" / "wmma.matrix_b" scopes and a float32 result in the
# "wmma.accumulator" scope.
@T.prim_func
def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a")
    B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b")
    C = T.match_buffer(
        c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator"
    )

    with T.block("root"):
        # C is both read and written because the update accumulates into it.
        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j, k in T.grid(16, 16, 16):
            with T.block(""):
                # "SSR": i and j are spatial axes, k is the reduction axis.
                vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                # B is indexed as B[vjj, vkk], the same access pattern the
                # matmul block uses for Y; operands are cast to float32 first.
                C[vii, vjj] += T.cast(A[vii, vkk], "float32") * T.cast(B[vjj, vkk], "float32")


# Implementation half of the "wmma_sync" intrinsic: the 16x16x16
# multiply-accumulate is expressed as a single T.tvm_mma_sync call.
@T.prim_func
def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_a")
    B = T.match_buffer(b, (16, 16), "float16", align=128, offset_factor=16, scope="wmma.matrix_b")
    C = T.match_buffer(
        c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator"
    )

    with T.block("root"):
        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        # Arguments are four (data, index) pairs in the order C, A, B, C.
        # C is passed both first and last; presumably as the destination and
        # as the accumulator input -- TODO confirm. Each index is derived from
        # the buffer's element offset (256 == 16 * 16).
        T.evaluate(
            T.tvm_mma_sync(
                C.data,
                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                A.data,
                A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16),
                B.data,
                B.elem_offset // 256 + T.floordiv(T.floormod(B.elem_offset, 256), 16),
                C.data,
                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                dtype="handle",
            )
        )


# Description (pattern) half of the "wmma_fill" intrinsic: sets every element
# of a 16x16 float32 tile in the "wmma.accumulator" scope to zero.
@T.prim_func
def wmma_fill_desc(c: T.handle) -> None:
    C = T.match_buffer(
        c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator"
    )

    with T.block("root"):
        # The fill reads nothing; it only writes C.
        T.reads()
        T.writes(C[0:16, 0:16])
        for i, j in T.grid(16, 16):
            with T.block("init"):
                vii, vjj = T.axis.remap("SS", [i, j])
                C[vii, vjj] = T.float32(0)


# Implementation half of the "wmma_fill" intrinsic: the zero-initialisation of
# the 16x16 "wmma.accumulator" tile is expressed as a single
# T.tvm_fill_fragment call with the value T.float32(0).
@T.prim_func
def wmma_fill_impl(c: T.handle) -> None:
    C = T.match_buffer(
        c, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator"
    )
    with T.block("root"):
        T.reads()
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_fill_fragment(
                C.data,
                # The three 16s are presumably the m, n, k tile shape -- TODO confirm.
                16,
                16,
                16,
                # Index derived from C's element offset (256 == 16 * 16);
                # presumably selects the fragment to fill -- TODO confirm.
                C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),
                T.float32(0),
                dtype="handle",
            )
        )


# Description (pattern) half of the "wmma_store" intrinsic: an element-wise
# copy of a 16x16 float32 tile from a "wmma.accumulator"-scope buffer into a
# "global"-scope buffer.
@T.prim_func
def wmma_store_desc(a: T.handle, c: T.handle) -> None:
    # A is the source (accumulator) tile, C the destination tile.
    A = T.match_buffer(
        a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator"
    )
    C = T.match_buffer(c, (16, 16), "float32", align=128, offset_factor=16, scope="global")
    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j in T.grid(16, 16):
            with T.block("store"):
                vii, vjj = T.axis.remap("SS", [i, j])
                C[vii, vjj] = A[vii, vjj]


# Implementation half of the "wmma_store" intrinsic: the 16x16 copy from the
# "wmma.accumulator"-scope tile into the "global"-scope tile is expressed as a
# single T.tvm_store_matrix_sync call with a "row_major" layout.
@T.prim_func
def wmma_store_impl(a: T.handle, c: T.handle) -> None:
    # Symbolic strides of the destination buffer: s1 for dim 0, s0 for dim 1.
    s1 = T.var("int32")
    s0 = T.var("int32")
    A = T.match_buffer(
        a, (16, 16), "float32", align=128, offset_factor=16, scope="wmma.accumulator"
    )
    C = T.match_buffer(
        c,
        (16, 16),
        "float32",
        align=128,
        offset_factor=16,
        scope="global",
        strides=[s1, s0],
    )
    with T.block("root"):
        T.reads(A[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        T.evaluate(
            T.tvm_store_matrix_sync(
                A.data,
                # The three 16s are presumably the m, n, k tile shape -- TODO confirm.
                16,
                16,
                16,
                # Index derived from A's element offset (256 == 16 * 16);
                # presumably selects the source fragment -- TODO confirm.
                A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16),
                C.access_ptr("w"),
                # Dim-0 stride of the destination buffer.
                s1,
                "row_major",
                dtype="handle",
            )
        )


# Register every Tensor Core intrinsic under the name later used by the
# schedule. Each registration has its own try/except: when a name is already
# registered (e.g. this cell is executed a second time) the ValueError for
# that name is ignored, and the remaining intrinsics are still attempted
# instead of being skipped together with it.
for _intrin_name, _intrin_desc, _intrin_impl in (
    ("wmma_load_a", wmma_load_a_desc, wmma_load_a_impl),
    ("wmma_load_b", wmma_load_b_desc, wmma_load_b_impl),
    ("wmma_sync", wmma_sync_desc, wmma_sync_impl),
    ("wmma_fill", wmma_fill_desc, wmma_fill_impl),
    ("wmma_store", wmma_store_desc, wmma_store_impl),
):
    try:
        tir.TensorIntrin.register(_intrin_name, _intrin_desc, _intrin_impl)
    except ValueError:
        # Raised when the intrinsic was registered by an earlier run; keep
        # the existing registration and move on to the next one.
        pass

Blockize 张量计算#

正如课程中所说,我们可以使用 TensorIR 来表示一组带有 Block 的张量化计算。 我们可以直接用 Block 编写一个 TensorIR 程序,也可以通过 blockize 生成新的 block。 请记住,wmma 操作适用于 16x16x16 矩阵乘法,因此我们需要切分循环,使最里面的循环是 16x16x16。

# Create a schedule for MatmulModule and fetch the three loops around the
# "matmul" block.
sch = tir.Schedule(MatmulModule)
block = sch.get_block("matmul")
i, j, k = sch.get_loops(block)

# Split each loop so that its inner part has extent 16.
i, ii = sch.split(i, factors=[None, 16])
j, ji = sch.split(j, factors=[None, 16])
k, ki = sch.split(k, factors=[None, 16])
# Move the three extent-16 loops innermost, then wrap the loop nest rooted at
# `ii` into a new outer block.
sch.reorder(i, j, k, ii, ji, ki)
wmma_sync = sch.blockize(loop=ii)
sch.mod.show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(X: T.Buffer[(1024, 1024), "float16"], Y: T.Buffer[(1024, 1024), "float16"], Z: T.Buffer[(1024, 1024), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i_0, j_0, k_0 in T.grid(64, 64, 64):
            with T.block("matmul_o"):
                vi_o, vj_o, vk_o = T.axis.remap("SSR", [i_0, j_0, k_0])
                T.reads(X[vi_o * 16 : vi_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16], Y[vj_o * 16 : vj_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16])
                T.writes(Z[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                with T.init():
                    for i_1, j_1 in T.grid(16, 16):
                        with T.block("matmul_init"):
                            vi_i_init, vj_i_init = T.axis.remap("SS", [i_1, j_1])
                            T.reads()
                            T.writes(Z[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init])
                            Z[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init] = T.float32(0)
                for i_1, j_1, k_1 in T.grid(16, 16, 16):
                    with T.block("matmul"):
                        vi_i, vj_i, vk_i = T.axis.remap("SSR", [i_1, j_1, k_1])
                        T.reads(Z[vi_o * 16 + vi_i, vj_o * 16 + vj_i], X[vi_o * 16 + vi_i, vk_o * 16 + vk_i], Y[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
                        T.writes(Z[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                        Z[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = Z[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + T.cast(X[vi_o * 16 + vi_i, vk_o * 16 + vk_i], "float32") * T.cast(Y[vj_o * 16 + vj_i, vk_o * 16 + vk_i], "float32")
    

切分循环并绑定 threadIdx#

Warp 指令#

请注意,所有 Tensor Core 指令都是 warp 指令,这意味着一个 warp 中的所有 32 个线程应该同时执行此指令。 使 threadIdx.x 的 extent=32 是解决此问题的最简单方法之一。 然后我们可以将 threadIdx.x 绑定到任何循环,但那些直接或间接包含 Tensor Core 内在函数的循环除外。 另请注意,这不是唯一的解决方案。 我们唯一应该做的就是确保一个 warp 中的所有线程都可以同时调用 Tensor Core。

# Split the outer tile loops: i and j into 8 x 4 x 2 pieces, k into 16 x 2 x 2.
i0, i1, i2 = sch.split(i, factors=[8, 4, 2])
j0, j1, j2 = sch.split(j, factors=[8, 4, 2])
k0, k1, k2 = sch.split(k, factors=[16, 2, 2])

sch.reorder(i0, j0, i1, j1, k0, k1, i2, j2, k2)
# The fused outermost i/j pieces are bound to blockIdx.x and the fused middle
# pieces to threadIdx.y.
bx = sch.fuse(i0, j0)
sch.bind(bx, "blockIdx.x")
ty = sch.fuse(i1, j1)
sch.bind(ty, "threadIdx.y")
# We can't bind to `threadIdx.x` since we have warp-level operators under the loop
sch.mod.show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(X: T.Buffer[(1024, 1024), "float16"], Y: T.Buffer[(1024, 1024), "float16"], Z: T.Buffer[(1024, 1024), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        for i_0_0_j_0_0_fused in T.thread_binding(64, thread="blockIdx.x"):
            for i_0_1_j_0_1_fused in T.thread_binding(16, thread="threadIdx.y"):
                for k_0_0, k_0_1, i_0_2, j_0_2, k_0_2 in T.grid(16, 2, 2, 2, 2):
                    with T.block("matmul_o"):
                        vi_o = T.axis.spatial(64, i_0_0_j_0_0_fused // 8 * 8 + i_0_1_j_0_1_fused // 4 * 2 + i_0_2)
                        vj_o = T.axis.spatial(64, i_0_0_j_0_0_fused % 8 * 8 + i_0_1_j_0_1_fused % 4 * 2 + j_0_2)
                        vk_o = T.axis.reduce(64, k_0_0 * 4 + k_0_1 * 2 + k_0_2)
                        T.reads(X[vi_o * 16 : vi_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16], Y[vj_o * 16 : vj_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16])
                        T.writes(Z[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                        with T.init():
                            for i_1, j_1 in T.grid(16, 16):
                                with T.block("matmul_init"):
                                    vi_i_init, vj_i_init = T.axis.remap("SS", [i_1, j_1])
                                    T.reads()
                                    T.writes(Z[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init])
                                    Z[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init] = T.float32(0)
                        for i_1, j_1, k_1 in T.grid(16, 16, 16):
                            with T.block("matmul"):
                                vi_i, vj_i, vk_i = T.axis.remap("SSR", [i_1, j_1, k_1])
                                T.reads(Z[vi_o * 16 + vi_i, vj_o * 16 + vj_i], X[vi_o * 16 + vi_i, vk_o * 16 + vk_i], Y[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
                                T.writes(Z[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                                Z[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = Z[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + T.cast(X[vi_o * 16 + vi_i, vk_o * 16 + vk_i], "float32") * T.cast(Y[vj_o * 16 + vj_i, vk_o * 16 + vk_i], "float32")
    

将 A 和 B 缓存到共享内存中#

与 CUDA Cores 的优化技巧类似,我们仍然需要将 A 和 B 缓存到共享内存中。 此外,还需要利用 cooperative fetching 技术。

# Stage read buffers 0 and 1 of the blockized matmul block in "shared" scope.
X_shared = sch.cache_read(wmma_sync, read_buffer_index=0, storage_scope="shared")
Y_shared = sch.cache_read(wmma_sync, read_buffer_index=1, storage_scope="shared")


def schedule_shared(block):
    """Schedule one shared-memory copy block on the module-level ``sch``.

    The block is computed under the module-level loop ``k0``. Its two
    innermost loops are fused and re-split with factors
    ``[None, 16, 32, 8]``: the extent-16 part is bound to ``threadIdx.y``,
    the extent-32 part to ``threadIdx.x`` and the extent-8 part is
    vectorized.
    """
    sch.compute_at(block, k0)
    row_loop, col_loop = sch.get_loops(block)[-2:]
    flat_loop = sch.fuse(row_loop, col_loop)
    _, ty_loop, tx_loop, vec_loop = sch.split(flat_loop, factors=[None, 16, 32, 8])
    sch.bind(ty_loop, "threadIdx.y")
    # threadIdx.x gets extent 32 here to satisfy the requirement of the
    # warp-level operations.
    sch.bind(tx_loop, "threadIdx.x")
    sch.vectorize(vec_loop)


# Apply the same shared-memory copy schedule to both cached inputs.
schedule_shared(X_shared)
schedule_shared(Y_shared)
sch.mod.show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(X: T.Buffer[(1024, 1024), "float16"], Y: T.Buffer[(1024, 1024), "float16"], Z: T.Buffer[(1024, 1024), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        X_shared = T.alloc_buffer([1024, 1024], dtype="float16", scope="shared")
        Y_shared = T.alloc_buffer([1024, 1024], dtype="float16", scope="shared")
        for i_0_0_j_0_0_fused in T.thread_binding(64, thread="blockIdx.x"):
            for i_0_1_j_0_1_fused in T.thread_binding(16, thread="threadIdx.y"):
                for k_0_0 in T.serial(16):
                    for ax0_ax1_fused_0 in T.serial(2):
                        for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.vectorized(8):
                                    with T.block("X_shared"):
                                        v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused // 8 * 128 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) // 64)
                                        v1 = T.axis.spatial(1024, k_0_0 * 64 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) % 64)
                                        T.reads(X[v0, v1])
                                        T.writes(X_shared[v0, v1])
                                        X_shared[v0, v1] = X[v0, v1]
                    for ax0_ax1_fused_0 in T.serial(2):
                        for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.vectorized(8):
                                    with T.block("Y_shared"):
                                        v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused % 8 * 128 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) // 64)
                                        v1 = T.axis.spatial(1024, k_0_0 * 64 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) % 64)
                                        T.reads(Y[v0, v1])
                                        T.writes(Y_shared[v0, v1])
                                        Y_shared[v0, v1] = Y[v0, v1]
                    for k_0_1, i_0_2, j_0_2, k_0_2 in T.grid(2, 2, 2, 2):
                        with T.block("matmul_o"):
                            vi_o = T.axis.spatial(64, i_0_0_j_0_0_fused // 8 * 8 + i_0_1_j_0_1_fused // 4 * 2 + i_0_2)
                            vj_o = T.axis.spatial(64, i_0_0_j_0_0_fused % 8 * 8 + i_0_1_j_0_1_fused % 4 * 2 + j_0_2)
                            vk_o = T.axis.reduce(64, k_0_0 * 4 + k_0_1 * 2 + k_0_2)
                            T.reads(X_shared[vi_o * 16 : vi_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16], Y_shared[vj_o * 16 : vj_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16])
                            T.writes(Z[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                            with T.init():
                                for i_1, j_1 in T.grid(16, 16):
                                    with T.block("matmul_init"):
                                        vi_i_init, vj_i_init = T.axis.remap("SS", [i_1, j_1])
                                        T.reads()
                                        T.writes(Z[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init])
                                        Z[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init] = T.float32(0)
                            for i_1, j_1, k_1 in T.grid(16, 16, 16):
                                with T.block("matmul"):
                                    vi_i, vj_i, vk_i = T.axis.remap("SSR", [i_1, j_1, k_1])
                                    T.reads(Z[vi_o * 16 + vi_i, vj_o * 16 + vj_i], X_shared[vi_o * 16 + vi_i, vk_o * 16 + vk_i], Y_shared[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
                                    T.writes(Z[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                                    Z[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = Z[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + T.cast(X_shared[vi_o * 16 + vi_i, vk_o * 16 + vk_i], "float32") * T.cast(Y_shared[vj_o * 16 + vj_i, vk_o * 16 + vk_i], "float32")
    

将输入输出数据缓存到特殊内存层级#

Tensor Cores 不能直接使用共享内存或本地内存中的数据。 我们必须将数据缓存到 wmma.matrix_a 和 wmma.matrix_b,并更新 wmma.accumulator 中的计算。

缓存输入数据#

# Add a second cache stage for read buffers 0 and 1 of the blockized matmul
# block, this time in the "wmma.matrix_a" / "wmma.matrix_b" scopes, and
# compute both copies under loop k1.
X_local = sch.cache_read(wmma_sync, 0, storage_scope="wmma.matrix_a")
Y_local = sch.cache_read(wmma_sync, 1, storage_scope="wmma.matrix_b")
sch.compute_at(X_local, k1)
sch.compute_at(Y_local, k1)
sch.mod.show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(X: T.Buffer[(1024, 1024), "float16"], Y: T.Buffer[(1024, 1024), "float16"], Z: T.Buffer[(1024, 1024), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        X_shared = T.alloc_buffer([1024, 1024], dtype="float16", scope="shared")
        Y_shared = T.alloc_buffer([1024, 1024], dtype="float16", scope="shared")
        X_shared_wmma_matrix_a = T.alloc_buffer([1024, 1024], dtype="float16", scope="wmma.matrix_a")
        Y_shared_wmma_matrix_b = T.alloc_buffer([1024, 1024], dtype="float16", scope="wmma.matrix_b")
        for i_0_0_j_0_0_fused in T.thread_binding(64, thread="blockIdx.x"):
            for i_0_1_j_0_1_fused in T.thread_binding(16, thread="threadIdx.y"):
                for k_0_0 in T.serial(16):
                    for ax0_ax1_fused_0 in T.serial(2):
                        for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.vectorized(8):
                                    with T.block("X_shared"):
                                        v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused // 8 * 128 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) // 64)
                                        v1 = T.axis.spatial(1024, k_0_0 * 64 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) % 64)
                                        T.reads(X[v0, v1])
                                        T.writes(X_shared[v0, v1])
                                        X_shared[v0, v1] = X[v0, v1]
                    for ax0_ax1_fused_0 in T.serial(2):
                        for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.vectorized(8):
                                    with T.block("Y_shared"):
                                        v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused % 8 * 128 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) // 64)
                                        v1 = T.axis.spatial(1024, k_0_0 * 64 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) % 64)
                                        T.reads(Y[v0, v1])
                                        T.writes(Y_shared[v0, v1])
                                        Y_shared[v0, v1] = Y[v0, v1]
                    for k_0_1 in T.serial(2):
                        for ax0, ax1 in T.grid(32, 32):
                            with T.block("X_shared_wmma.matrix_a"):
                                v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused // 8 * 128 + i_0_1_j_0_1_fused // 4 * 32 + ax0)
                                v1 = T.axis.spatial(1024, k_0_0 * 64 + k_0_1 * 32 + ax1)
                                T.reads(X_shared[v0, v1])
                                T.writes(X_shared_wmma_matrix_a[v0, v1])
                                X_shared_wmma_matrix_a[v0, v1] = X_shared[v0, v1]
                        for ax0, ax1 in T.grid(32, 32):
                            with T.block("Y_shared_wmma.matrix_b"):
                                v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused % 8 * 128 + i_0_1_j_0_1_fused % 4 * 32 + ax0)
                                v1 = T.axis.spatial(1024, k_0_0 * 64 + k_0_1 * 32 + ax1)
                                T.reads(Y_shared[v0, v1])
                                T.writes(Y_shared_wmma_matrix_b[v0, v1])
                                Y_shared_wmma_matrix_b[v0, v1] = Y_shared[v0, v1]
                        for i_0_2, j_0_2, k_0_2 in T.grid(2, 2, 2):
                            with T.block("matmul_o"):
                                vi_o = T.axis.spatial(64, i_0_0_j_0_0_fused // 8 * 8 + i_0_1_j_0_1_fused // 4 * 2 + i_0_2)
                                vj_o = T.axis.spatial(64, i_0_0_j_0_0_fused % 8 * 8 + i_0_1_j_0_1_fused % 4 * 2 + j_0_2)
                                vk_o = T.axis.reduce(64, k_0_0 * 4 + k_0_1 * 2 + k_0_2)
                                T.reads(X_shared_wmma_matrix_a[vi_o * 16 : vi_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16], Y_shared_wmma_matrix_b[vj_o * 16 : vj_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16])
                                T.writes(Z[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                                with T.init():
                                    for i_1, j_1 in T.grid(16, 16):
                                        with T.block("matmul_init"):
                                            vi_i_init, vj_i_init = T.axis.remap("SS", [i_1, j_1])
                                            T.reads()
                                            T.writes(Z[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init])
                                            Z[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init] = T.float32(0)
                                for i_1, j_1, k_1 in T.grid(16, 16, 16):
                                    with T.block("matmul"):
                                        vi_i, vj_i, vk_i = T.axis.remap("SSR", [i_1, j_1, k_1])
                                        T.reads(Z[vi_o * 16 + vi_i, vj_o * 16 + vj_i], X_shared_wmma_matrix_a[vi_o * 16 + vi_i, vk_o * 16 + vk_i], Y_shared_wmma_matrix_b[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
                                        T.writes(Z[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                                        Z[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = Z[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + T.cast(X_shared_wmma_matrix_a[vi_o * 16 + vi_i, vk_o * 16 + vk_i], "float32") * T.cast(Y_shared_wmma_matrix_b[vj_o * 16 + vj_i, vk_o * 16 + vk_i], "float32")
    

缓存输出数据#

# Make the blockized matmul block write buffer 0 through a
# "wmma.accumulator"-scope cache, and attach the resulting write-back block
# under the threadIdx.y loop `ty`.
write_back_block = sch.cache_write(wmma_sync, 0, storage_scope="wmma.accumulator")
sch.reverse_compute_at(write_back_block, ty)
sch.mod.show()
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(X: T.Buffer[(1024, 1024), "float16"], Y: T.Buffer[(1024, 1024), "float16"], Z: T.Buffer[(1024, 1024), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        X_shared = T.alloc_buffer([1024, 1024], dtype="float16", scope="shared")
        Y_shared = T.alloc_buffer([1024, 1024], dtype="float16", scope="shared")
        X_shared_wmma_matrix_a = T.alloc_buffer([1024, 1024], dtype="float16", scope="wmma.matrix_a")
        Y_shared_wmma_matrix_b = T.alloc_buffer([1024, 1024], dtype="float16", scope="wmma.matrix_b")
        Z_wmma_accumulator = T.alloc_buffer([1024, 1024], dtype="float32", scope="wmma.accumulator")
        for i_0_0_j_0_0_fused in T.thread_binding(64, thread="blockIdx.x"):
            for i_0_1_j_0_1_fused in T.thread_binding(16, thread="threadIdx.y"):
                for k_0_0 in T.serial(16):
                    for ax0_ax1_fused_0 in T.serial(2):
                        for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.vectorized(8):
                                    with T.block("X_shared"):
                                        v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused // 8 * 128 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) // 64)
                                        v1 = T.axis.spatial(1024, k_0_0 * 64 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) % 64)
                                        T.reads(X[v0, v1])
                                        T.writes(X_shared[v0, v1])
                                        X_shared[v0, v1] = X[v0, v1]
                    for ax0_ax1_fused_0 in T.serial(2):
                        for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.vectorized(8):
                                    with T.block("Y_shared"):
                                        v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused % 8 * 128 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) // 64)
                                        v1 = T.axis.spatial(1024, k_0_0 * 64 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) % 64)
                                        T.reads(Y[v0, v1])
                                        T.writes(Y_shared[v0, v1])
                                        Y_shared[v0, v1] = Y[v0, v1]
                    for k_0_1 in T.serial(2):
                        for ax0, ax1 in T.grid(32, 32):
                            with T.block("X_shared_wmma.matrix_a"):
                                v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused // 8 * 128 + i_0_1_j_0_1_fused // 4 * 32 + ax0)
                                v1 = T.axis.spatial(1024, k_0_0 * 64 + k_0_1 * 32 + ax1)
                                T.reads(X_shared[v0, v1])
                                T.writes(X_shared_wmma_matrix_a[v0, v1])
                                X_shared_wmma_matrix_a[v0, v1] = X_shared[v0, v1]
                        for ax0, ax1 in T.grid(32, 32):
                            with T.block("Y_shared_wmma.matrix_b"):
                                v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused % 8 * 128 + i_0_1_j_0_1_fused % 4 * 32 + ax0)
                                v1 = T.axis.spatial(1024, k_0_0 * 64 + k_0_1 * 32 + ax1)
                                T.reads(Y_shared[v0, v1])
                                T.writes(Y_shared_wmma_matrix_b[v0, v1])
                                Y_shared_wmma_matrix_b[v0, v1] = Y_shared[v0, v1]
                        for i_0_2, j_0_2, k_0_2 in T.grid(2, 2, 2):
                            with T.block("matmul_o"):
                                vi_o = T.axis.spatial(64, i_0_0_j_0_0_fused // 8 * 8 + i_0_1_j_0_1_fused // 4 * 2 + i_0_2)
                                vj_o = T.axis.spatial(64, i_0_0_j_0_0_fused % 8 * 8 + i_0_1_j_0_1_fused % 4 * 2 + j_0_2)
                                vk_o = T.axis.reduce(64, k_0_0 * 4 + k_0_1 * 2 + k_0_2)
                                T.reads(X_shared_wmma_matrix_a[vi_o * 16 : vi_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16], Y_shared_wmma_matrix_b[vj_o * 16 : vj_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16])
                                T.writes(Z_wmma_accumulator[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                                with T.init():
                                    for i_1, j_1 in T.grid(16, 16):
                                        with T.block("matmul_init"):
                                            vi_i_init, vj_i_init = T.axis.remap("SS", [i_1, j_1])
                                            T.reads()
                                            T.writes(Z_wmma_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init])
                                            Z_wmma_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init] = T.float32(0)
                                for i_1, j_1, k_1 in T.grid(16, 16, 16):
                                    with T.block("matmul"):
                                        vi_i, vj_i, vk_i = T.axis.remap("SSR", [i_1, j_1, k_1])
                                        T.reads(Z_wmma_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i], X_shared_wmma_matrix_a[vi_o * 16 + vi_i, vk_o * 16 + vk_i], Y_shared_wmma_matrix_b[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
                                        T.writes(Z_wmma_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                                        Z_wmma_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = Z_wmma_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + T.cast(X_shared_wmma_matrix_a[vi_o * 16 + vi_i, vk_o * 16 + vk_i], "float32") * T.cast(Y_shared_wmma_matrix_b[vj_o * 16 + vj_i, vk_o * 16 + vk_i], "float32")
                for ax0, ax1 in T.grid(32, 32):
                    with T.block("Z_wmma.accumulator"):
                        v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused // 8 * 128 + i_0_1_j_0_1_fused // 4 * 32 + ax0)
                        v1 = T.axis.spatial(1024, i_0_0_j_0_0_fused % 8 * 128 + i_0_1_j_0_1_fused % 4 * 32 + ax1)
                        T.reads(Z_wmma_accumulator[v0, v1])
                        T.writes(Z[v0, v1])
                        Z[v0, v1] = Z_wmma_accumulator[v0, v1]
    

切分 Tensor Core 内存拷贝#

wmma.load_matrix 和 wmma.store_matrix 使用 16x16 矩阵执行内存复制。 因此我们对循环进行切分,以此来匹配 intrinsic。

def schedule_copy(block):
    """Tile the two innermost loops of ``block`` with an inner extent of 16.

    Each of the last two loops around ``block`` is split into an
    (outer, 16) pair, and the four resulting loops are reordered so that
    both outer loops come before both inner loops. This leaves a 16x16
    inner loop nest, matching the tile size of the wmma copy intrinsics.

    Uses the module-level schedule ``sch`` and mutates it in place.
    """
    innermost_pair = sch.get_loops(block)[-2:]
    first, second = innermost_pair
    first_outer, first_inner = sch.split(first, factors=[None, 16])
    second_outer, second_inner = sch.split(second, factors=[None, 16])
    sch.reorder(first_outer, second_outer, first_inner, second_inner)

# Tile the two wmma load blocks and the write-back block, in that order,
# so each of them ends in a 16x16 inner loop nest.
for copy_block in (X_local, Y_local, write_back_block):
    schedule_copy(copy_block)
# Print the module after tiling.
sch.mod.show()
# Module printed by sch.mod.show() after schedule_copy: the three wmma copy blocks
# now iterate over T.grid(2, 2, 16, 16), i.e. 2x2 tiles of 16x16 elements each.
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(X: T.Buffer[(1024, 1024), "float16"], Y: T.Buffer[(1024, 1024), "float16"], Z: T.Buffer[(1024, 1024), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        X_shared = T.alloc_buffer([1024, 1024], dtype="float16", scope="shared")
        Y_shared = T.alloc_buffer([1024, 1024], dtype="float16", scope="shared")
        X_shared_wmma_matrix_a = T.alloc_buffer([1024, 1024], dtype="float16", scope="wmma.matrix_a")
        Y_shared_wmma_matrix_b = T.alloc_buffer([1024, 1024], dtype="float16", scope="wmma.matrix_b")
        Z_wmma_accumulator = T.alloc_buffer([1024, 1024], dtype="float32", scope="wmma.accumulator")
        for i_0_0_j_0_0_fused in T.thread_binding(64, thread="blockIdx.x"):
            for i_0_1_j_0_1_fused in T.thread_binding(16, thread="threadIdx.y"):
                for k_0_0 in T.serial(16):
                    # Stage a 128x64 tile of X into X_shared; the innermost loop is vectorized
                    # over 8 consecutive elements along the second axis.
                    for ax0_ax1_fused_0 in T.serial(2):
                        for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.vectorized(8):
                                    with T.block("X_shared"):
                                        v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused // 8 * 128 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) // 64)
                                        v1 = T.axis.spatial(1024, k_0_0 * 64 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) % 64)
                                        T.reads(X[v0, v1])
                                        T.writes(X_shared[v0, v1])
                                        X_shared[v0, v1] = X[v0, v1]
                    # Same staging pattern for a 128x64 tile of Y into Y_shared.
                    for ax0_ax1_fused_0 in T.serial(2):
                        for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.vectorized(8):
                                    with T.block("Y_shared"):
                                        v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused % 8 * 128 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) // 64)
                                        v1 = T.axis.spatial(1024, k_0_0 * 64 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) % 64)
                                        T.reads(Y[v0, v1])
                                        T.writes(Y_shared[v0, v1])
                                        Y_shared[v0, v1] = Y[v0, v1]
                    for k_0_1 in T.serial(2):
                        # Copy a 32x32 region of X_shared into the wmma.matrix_a buffer,
                        # iterated as 2x2 tiles of 16x16 (the result of schedule_copy).
                        for ax0_0, ax1_0, ax0_1, ax1_1 in T.grid(2, 2, 16, 16):
                            with T.block("X_shared_wmma.matrix_a"):
                                v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused // 8 * 128 + i_0_1_j_0_1_fused // 4 * 32 + ax0_0 * 16 + ax0_1)
                                v1 = T.axis.spatial(1024, k_0_0 * 64 + k_0_1 * 32 + ax1_0 * 16 + ax1_1)
                                T.reads(X_shared[v0, v1])
                                T.writes(X_shared_wmma_matrix_a[v0, v1])
                                X_shared_wmma_matrix_a[v0, v1] = X_shared[v0, v1]
                        # Same tiled copy of a 32x32 region of Y_shared into the wmma.matrix_b buffer.
                        for ax0_0, ax1_0, ax0_1, ax1_1 in T.grid(2, 2, 16, 16):
                            with T.block("Y_shared_wmma.matrix_b"):
                                v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused % 8 * 128 + i_0_1_j_0_1_fused % 4 * 32 + ax0_0 * 16 + ax0_1)
                                v1 = T.axis.spatial(1024, k_0_0 * 64 + k_0_1 * 32 + ax1_0 * 16 + ax1_1)
                                T.reads(Y_shared[v0, v1])
                                T.writes(Y_shared_wmma_matrix_b[v0, v1])
                                Y_shared_wmma_matrix_b[v0, v1] = Y_shared[v0, v1]
                        # 2x2x2 outer iterations; each "matmul_o" instance covers one 16x16x16 tile.
                        for i_0_2, j_0_2, k_0_2 in T.grid(2, 2, 2):
                            with T.block("matmul_o"):
                                vi_o = T.axis.spatial(64, i_0_0_j_0_0_fused // 8 * 8 + i_0_1_j_0_1_fused // 4 * 2 + i_0_2)
                                vj_o = T.axis.spatial(64, i_0_0_j_0_0_fused % 8 * 8 + i_0_1_j_0_1_fused % 4 * 2 + j_0_2)
                                vk_o = T.axis.reduce(64, k_0_0 * 4 + k_0_1 * 2 + k_0_2)
                                T.reads(X_shared_wmma_matrix_a[vi_o * 16 : vi_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16], Y_shared_wmma_matrix_b[vj_o * 16 : vj_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16])
                                T.writes(Z_wmma_accumulator[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                                # Zero-initialize the 16x16 accumulator tile.
                                with T.init():
                                    for i_1, j_1 in T.grid(16, 16):
                                        with T.block("matmul_init"):
                                            vi_i_init, vj_i_init = T.axis.remap("SS", [i_1, j_1])
                                            T.reads()
                                            T.writes(Z_wmma_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init])
                                            Z_wmma_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init] = T.float32(0)
                                # Accumulate X[i, k] * Y[j, k] into the tile, with both operands cast to float32.
                                for i_1, j_1, k_1 in T.grid(16, 16, 16):
                                    with T.block("matmul"):
                                        vi_i, vj_i, vk_i = T.axis.remap("SSR", [i_1, j_1, k_1])
                                        T.reads(Z_wmma_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i], X_shared_wmma_matrix_a[vi_o * 16 + vi_i, vk_o * 16 + vk_i], Y_shared_wmma_matrix_b[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
                                        T.writes(Z_wmma_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                                        Z_wmma_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = Z_wmma_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + T.cast(X_shared_wmma_matrix_a[vi_o * 16 + vi_i, vk_o * 16 + vk_i], "float32") * T.cast(Y_shared_wmma_matrix_b[vj_o * 16 + vj_i, vk_o * 16 + vk_i], "float32")
                # Write the 32x32 accumulator region back to Z, iterated as 2x2 tiles of 16x16.
                for ax0_0, ax1_0, ax0_1, ax1_1 in T.grid(2, 2, 16, 16):
                    with T.block("Z_wmma.accumulator"):
                        v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused // 8 * 128 + i_0_1_j_0_1_fused // 4 * 32 + ax0_0 * 16 + ax0_1)
                        v1 = T.axis.spatial(1024, i_0_0_j_0_0_fused % 8 * 128 + i_0_1_j_0_1_fused % 4 * 32 + ax1_0 * 16 + ax1_1)
                        T.reads(Z_wmma_accumulator[v0, v1])
                        T.writes(Z[v0, v1])
                        Z[v0, v1] = Z_wmma_accumulator[v0, v1]
    

Tensorize#

在 tensorize 之前,我们需要先执行 decompose_reduction,因为 wmma_sync 和 wmma_fill 是两个 intrinsic,需要对 init block 和 update block 分别进行 tensorize。

# Split the reduction block `wmma_sync` at loop `k0`: the zero-initialization becomes
# its own block (returned here as `init`), leaving a block that only accumulates.
init = sch.decompose_reduction(wmma_sync, k0)
# Print the module after the transformation.
sch.mod.show()
# Module printed by sch.mod.show() after decompose_reduction: the former T.init() body is now
# a separate "matmul_o_init" block placed before the k_0_0 loop, and the reduction block is
# named "matmul_o_update".
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(X: T.Buffer[(1024, 1024), "float16"], Y: T.Buffer[(1024, 1024), "float16"], Z: T.Buffer[(1024, 1024), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        X_shared = T.alloc_buffer([1024, 1024], dtype="float16", scope="shared")
        Y_shared = T.alloc_buffer([1024, 1024], dtype="float16", scope="shared")
        X_shared_wmma_matrix_a = T.alloc_buffer([1024, 1024], dtype="float16", scope="wmma.matrix_a")
        Y_shared_wmma_matrix_b = T.alloc_buffer([1024, 1024], dtype="float16", scope="wmma.matrix_b")
        Z_wmma_accumulator = T.alloc_buffer([1024, 1024], dtype="float32", scope="wmma.accumulator")
        for i_0_0_j_0_0_fused in T.thread_binding(64, thread="blockIdx.x"):
            for i_0_1_j_0_1_fused in T.thread_binding(16, thread="threadIdx.y"):
                # Zero-fill the 2x2 grid of 16x16 accumulator tiles for this
                # (blockIdx.x, threadIdx.y) iteration, before any reduction step runs.
                for i_0_2_init, j_0_2_init in T.grid(2, 2):
                    with T.block("matmul_o_init"):
                        vi_o = T.axis.spatial(64, i_0_0_j_0_0_fused // 8 * 8 + i_0_1_j_0_1_fused // 4 * 2 + i_0_2_init)
                        vj_o = T.axis.spatial(64, i_0_0_j_0_0_fused % 8 * 8 + i_0_1_j_0_1_fused % 4 * 2 + j_0_2_init)
                        T.reads()
                        T.writes(Z_wmma_accumulator[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                        for i_1, j_1 in T.grid(16, 16):
                            with T.block("matmul_init"):
                                vi_i_init, vj_i_init = T.axis.remap("SS", [i_1, j_1])
                                T.reads()
                                T.writes(Z_wmma_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init])
                                Z_wmma_accumulator[vi_o * 16 + vi_i_init, vj_o * 16 + vj_i_init] = T.float32(0)
                for k_0_0 in T.serial(16):
                    # Stage a 128x64 tile of X into X_shared; the innermost loop is vectorized
                    # over 8 consecutive elements along the second axis.
                    for ax0_ax1_fused_0 in T.serial(2):
                        for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.vectorized(8):
                                    with T.block("X_shared"):
                                        v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused // 8 * 128 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) // 64)
                                        v1 = T.axis.spatial(1024, k_0_0 * 64 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) % 64)
                                        T.reads(X[v0, v1])
                                        T.writes(X_shared[v0, v1])
                                        X_shared[v0, v1] = X[v0, v1]
                    # Same staging pattern for a 128x64 tile of Y into Y_shared.
                    for ax0_ax1_fused_0 in T.serial(2):
                        for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.vectorized(8):
                                    with T.block("Y_shared"):
                                        v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused % 8 * 128 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) // 64)
                                        v1 = T.axis.spatial(1024, k_0_0 * 64 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) % 64)
                                        T.reads(Y[v0, v1])
                                        T.writes(Y_shared[v0, v1])
                                        Y_shared[v0, v1] = Y[v0, v1]
                    for k_0_1 in T.serial(2):
                        # Copy a 32x32 region of X_shared into the wmma.matrix_a buffer as 2x2 tiles of 16x16.
                        for ax0_0, ax1_0, ax0_1, ax1_1 in T.grid(2, 2, 16, 16):
                            with T.block("X_shared_wmma.matrix_a"):
                                v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused // 8 * 128 + i_0_1_j_0_1_fused // 4 * 32 + ax0_0 * 16 + ax0_1)
                                v1 = T.axis.spatial(1024, k_0_0 * 64 + k_0_1 * 32 + ax1_0 * 16 + ax1_1)
                                T.reads(X_shared[v0, v1])
                                T.writes(X_shared_wmma_matrix_a[v0, v1])
                                X_shared_wmma_matrix_a[v0, v1] = X_shared[v0, v1]
                        # Same tiled copy of a 32x32 region of Y_shared into the wmma.matrix_b buffer.
                        for ax0_0, ax1_0, ax0_1, ax1_1 in T.grid(2, 2, 16, 16):
                            with T.block("Y_shared_wmma.matrix_b"):
                                v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused % 8 * 128 + i_0_1_j_0_1_fused % 4 * 32 + ax0_0 * 16 + ax0_1)
                                v1 = T.axis.spatial(1024, k_0_0 * 64 + k_0_1 * 32 + ax1_0 * 16 + ax1_1)
                                T.reads(Y_shared[v0, v1])
                                T.writes(Y_shared_wmma_matrix_b[v0, v1])
                                Y_shared_wmma_matrix_b[v0, v1] = Y_shared[v0, v1]
                        # Update-only block: the accumulator tile now appears in T.reads and
                        # there is no T.init() section any more.
                        for i_0_2, j_0_2, k_0_2 in T.grid(2, 2, 2):
                            with T.block("matmul_o_update"):
                                vi_o = T.axis.spatial(64, i_0_0_j_0_0_fused // 8 * 8 + i_0_1_j_0_1_fused // 4 * 2 + i_0_2)
                                vj_o = T.axis.spatial(64, i_0_0_j_0_0_fused % 8 * 8 + i_0_1_j_0_1_fused % 4 * 2 + j_0_2)
                                vk_o = T.axis.reduce(64, k_0_0 * 4 + k_0_1 * 2 + k_0_2)
                                T.reads(Z_wmma_accumulator[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16], X_shared_wmma_matrix_a[vi_o * 16 : vi_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16], Y_shared_wmma_matrix_b[vj_o * 16 : vj_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16])
                                T.writes(Z_wmma_accumulator[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                                for i_1, j_1, k_1 in T.grid(16, 16, 16):
                                    with T.block("matmul"):
                                        vi_i, vj_i, vk_i = T.axis.remap("SSR", [i_1, j_1, k_1])
                                        T.reads(Z_wmma_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i], X_shared_wmma_matrix_a[vi_o * 16 + vi_i, vk_o * 16 + vk_i], Y_shared_wmma_matrix_b[vj_o * 16 + vj_i, vk_o * 16 + vk_i])
                                        T.writes(Z_wmma_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i])
                                        Z_wmma_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = Z_wmma_accumulator[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + T.cast(X_shared_wmma_matrix_a[vi_o * 16 + vi_i, vk_o * 16 + vk_i], "float32") * T.cast(Y_shared_wmma_matrix_b[vj_o * 16 + vj_i, vk_o * 16 + vk_i], "float32")
                # Write the 32x32 accumulator region back to Z, iterated as 2x2 tiles of 16x16.
                for ax0_0, ax1_0, ax0_1, ax1_1 in T.grid(2, 2, 16, 16):
                    with T.block("Z_wmma.accumulator"):
                        v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused // 8 * 128 + i_0_1_j_0_1_fused // 4 * 32 + ax0_0 * 16 + ax0_1)
                        v1 = T.axis.spatial(1024, i_0_0_j_0_0_fused % 8 * 128 + i_0_1_j_0_1_fused % 4 * 32 + ax1_0 * 16 + ax1_1)
                        T.reads(Z_wmma_accumulator[v0, v1])
                        T.writes(Z[v0, v1])
                        Z[v0, v1] = Z_wmma_accumulator[v0, v1]
    
# Replace each 16x16 loop nest with the Tensor Core intrinsic registered under the given name.
# For the copy blocks, [-2] selects the second-innermost loop, i.e. the outer loop of the
# 16x16 inner tile produced by schedule_copy.
sch.tensorize(sch.get_loops(X_local)[-2], "wmma_load_a")
sch.tensorize(sch.get_loops(Y_local)[-2], "wmma_load_b")
# `init` is the block returned by decompose_reduction; it maps to the fill intrinsic.
sch.tensorize(init, "wmma_fill")
sch.tensorize(wmma_sync, "wmma_sync")
sch.tensorize(sch.get_loops(write_back_block)[-2], "wmma_store")
# Print the fully tensorized module.
sch.mod.show()
# Module printed by sch.mod.show() after tensorize: each 16x16 loop nest has been replaced by
# T.match_buffer views of the tile plus a single intrinsic call.
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(X: T.Buffer[(1024, 1024), "float16"], Y: T.Buffer[(1024, 1024), "float16"], Z: T.Buffer[(1024, 1024), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # Symbolic stride variables used by the strided T.match_buffer calls below.
        s0 = T.var("int32")
        s0_1 = T.var("int32")
        s0_2 = T.var("int32")
        s1 = T.var("int32")
        s1_1 = T.var("int32")
        s1_2 = T.var("int32")
        # body
        # with T.block("root")
        X_shared = T.alloc_buffer([1024, 1024], dtype="float16", scope="shared")
        Y_shared = T.alloc_buffer([1024, 1024], dtype="float16", scope="shared")
        X_shared_wmma_matrix_a = T.alloc_buffer([1024, 1024], dtype="float16", scope="wmma.matrix_a")
        Y_shared_wmma_matrix_b = T.alloc_buffer([1024, 1024], dtype="float16", scope="wmma.matrix_b")
        Z_wmma_accumulator = T.alloc_buffer([1024, 1024], dtype="float32", scope="wmma.accumulator")
        for i_0_0_j_0_0_fused in T.thread_binding(64, thread="blockIdx.x"):
            for i_0_1_j_0_1_fused in T.thread_binding(16, thread="threadIdx.y"):
                for i_0_2_init, j_0_2_init in T.grid(2, 2):
                    with T.block("matmul_o_init"):
                        vi_o = T.axis.spatial(64, i_0_0_j_0_0_fused // 8 * 8 + i_0_1_j_0_1_fused // 4 * 2 + i_0_2_init)
                        vj_o = T.axis.spatial(64, i_0_0_j_0_0_fused % 8 * 8 + i_0_1_j_0_1_fused % 4 * 2 + j_0_2_init)
                        T.reads()
                        T.writes(Z_wmma_accumulator[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                        # "wmma_fill": fill the 16x16 accumulator tile with float32 zero.
                        C = T.match_buffer(Z_wmma_accumulator[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16], [16, 16], dtype="float32", scope="wmma.accumulator", offset_factor=16)
                        T.evaluate(T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // 256 + C.elem_offset % 256 // 16, T.float32(0), dtype="handle"))
                for k_0_0 in T.serial(16):
                    # The shared-memory staging loops are not tensorized and stay as plain copies.
                    for ax0_ax1_fused_0 in T.serial(2):
                        for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.vectorized(8):
                                    with T.block("X_shared"):
                                        v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused // 8 * 128 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) // 64)
                                        v1 = T.axis.spatial(1024, k_0_0 * 64 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) % 64)
                                        T.reads(X[v0, v1])
                                        T.writes(X_shared[v0, v1])
                                        X_shared[v0, v1] = X[v0, v1]
                    for ax0_ax1_fused_0 in T.serial(2):
                        for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"):
                            for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                for ax0_ax1_fused_3 in T.vectorized(8):
                                    with T.block("Y_shared"):
                                        v0 = T.axis.spatial(1024, i_0_0_j_0_0_fused % 8 * 128 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) // 64)
                                        v1 = T.axis.spatial(1024, k_0_0 * 64 + (ax0_ax1_fused_0 * 4096 + ax0_ax1_fused_1 * 256 + ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) % 64)
                                        T.reads(Y[v0, v1])
                                        T.writes(Y_shared[v0, v1])
                                        Y_shared[v0, v1] = Y[v0, v1]
                    for k_0_1 in T.serial(2):
                        for ax0_0, ax1_0 in T.grid(2, 2):
                            with T.block("X_shared_wmma.matrix_a_o"):
                                v0_o = T.axis.spatial(64, i_0_0_j_0_0_fused // 8 * 8 + i_0_1_j_0_1_fused // 4 * 2 + ax0_0)
                                v1_o = T.axis.spatial(64, k_0_0 * 4 + k_0_1 * 2 + ax1_0)
                                T.reads(X_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                T.writes(X_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                # "wmma_load_a": load one 16x16 tile of X_shared into the matrix_a buffer, "row_major" layout.
                                A = T.match_buffer(X_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="float16", strides=[s1, s0], scope="shared", offset_factor=16)
                                C_1 = T.match_buffer(X_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="float16", scope="wmma.matrix_a", offset_factor=16)
                                T.evaluate(T.tvm_load_matrix_sync(C_1.data, 16, 16, 16, C_1.elem_offset // 256 + C_1.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float16"), A.data, A.elem_offset, s1 * 16, 1, dtype="handle"), s1, "row_major", dtype="handle"))
                        for ax0_0, ax1_0 in T.grid(2, 2):
                            with T.block("Y_shared_wmma.matrix_b_o"):
                                v0_o = T.axis.spatial(64, i_0_0_j_0_0_fused % 8 * 8 + i_0_1_j_0_1_fused % 4 * 2 + ax0_0)
                                v1_o = T.axis.spatial(64, k_0_0 * 4 + k_0_1 * 2 + ax1_0)
                                T.reads(Y_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                T.writes(Y_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                                # "wmma_load_b": load one 16x16 tile of Y_shared into the matrix_b buffer, "col_major" layout.
                                A_1 = T.match_buffer(Y_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="float16", strides=[s1_1, s0_1], scope="shared", offset_factor=16)
                                C_2 = T.match_buffer(Y_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="float16", scope="wmma.matrix_b", offset_factor=16)
                                T.evaluate(T.tvm_load_matrix_sync(C_2.data, 16, 16, 16, C_2.elem_offset // 256 + C_2.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float16"), A_1.data, A_1.elem_offset, s1_1 * 16, 1, dtype="handle"), s1_1, "col_major", dtype="handle"))
                        for i_0_2, j_0_2, k_0_2 in T.grid(2, 2, 2):
                            with T.block("matmul_o_update"):
                                vi_o = T.axis.spatial(64, i_0_0_j_0_0_fused // 8 * 8 + i_0_1_j_0_1_fused // 4 * 2 + i_0_2)
                                vj_o = T.axis.spatial(64, i_0_0_j_0_0_fused % 8 * 8 + i_0_1_j_0_1_fused % 4 * 2 + j_0_2)
                                vk_o = T.axis.reduce(64, k_0_0 * 4 + k_0_1 * 2 + k_0_2)
                                T.reads(Z_wmma_accumulator[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16], X_shared_wmma_matrix_a[vi_o * 16 : vi_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16], Y_shared_wmma_matrix_b[vj_o * 16 : vj_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16])
                                T.writes(Z_wmma_accumulator[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16])
                                # "wmma_sync": the accumulator tile C_3 is passed as both the first and the last
                                # operand of tvm_mma_sync, with A_2 and B in between (presumably C = A * B + C).
                                A_2 = T.match_buffer(X_shared_wmma_matrix_a[vi_o * 16 : vi_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16], [16, 16], dtype="float16", scope="wmma.matrix_a", offset_factor=16)
                                B = T.match_buffer(Y_shared_wmma_matrix_b[vj_o * 16 : vj_o * 16 + 16, vk_o * 16 : vk_o * 16 + 16], [16, 16], dtype="float16", scope="wmma.matrix_b", offset_factor=16)
                                C_3 = T.match_buffer(Z_wmma_accumulator[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16], [16, 16], dtype="float32", scope="wmma.accumulator", offset_factor=16)
                                T.evaluate(T.tvm_mma_sync(C_3.data, C_3.elem_offset // 256 + C_3.elem_offset % 256 // 16, A_2.data, A_2.elem_offset // 256 + A_2.elem_offset % 256 // 16, B.data, B.elem_offset // 256 + B.elem_offset % 256 // 16, C_3.data, C_3.elem_offset // 256 + C_3.elem_offset % 256 // 16, dtype="handle"))
                for ax0_0, ax1_0 in T.grid(2, 2):
                    with T.block("Z_wmma.accumulator_o"):
                        v0_o = T.axis.spatial(64, i_0_0_j_0_0_fused // 8 * 8 + i_0_1_j_0_1_fused // 4 * 2 + ax0_0)
                        v1_o = T.axis.spatial(64, i_0_0_j_0_0_fused % 8 * 8 + i_0_1_j_0_1_fused % 4 * 2 + ax1_0)
                        T.reads(Z_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                        T.writes(Z[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
                        # "wmma_store": store one 16x16 accumulator tile into Z, "row_major" layout.
                        A_3 = T.match_buffer(Z_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="float32", scope="wmma.accumulator", offset_factor=16)
                        C_4 = T.match_buffer(Z[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16], [16, 16], dtype="float32", strides=[s1_2, s0_2], offset_factor=16)
                        T.evaluate(T.tvm_store_matrix_sync(A_3.data, 16, 16, 16, A_3.elem_offset // 256 + A_3.elem_offset % 256 // 16, T.tvm_access_ptr(T.type_annotation(dtype="float32"), C_4.data, C_4.elem_offset, s1_2 * 16, 2, dtype="handle"), s1_2, "row_major", dtype="handle"))
    

构建并评估结果#

# Compile the scheduled module for the CUDA target.
rt_mod = tvm.build(sch.mod, target="cuda")

dev = tvm.cuda()
# One multiply and one add for every (i, j, k) of the 1024x1024x1024 matmul.
num_flop = 2 * 1024**3

# Random float16 inputs. The reference result is computed on the host in
# float32; B is transposed because the kernel indexes Y as Y[vj, vk].
A_np = np.random.randn(1024, 1024).astype("float16")
B_np = np.random.randn(1024, 1024).astype("float16")
C_np = np.matmul(A_np.astype("float32"), B_np.astype("float32").T)

# Device-side copies of the inputs and an uninitialized output buffer.
A_nd, B_nd = (tvm.nd.array(host_array, dev) for host_array in (A_np, B_np))
C_nd = tvm.nd.array(np.empty((1024, 1024), dtype="float32"), dev)

# Run once and compare against the host reference.
rt_mod(A_nd, B_nd, C_nd)
np.testing.assert_allclose(C_np, C_nd.numpy(), rtol=1e-3, atol=1e-3)

# Time "main" and report throughput from the mean measured time.
evaluator = rt_mod.time_evaluator("main", dev, number=10)
mean_seconds = evaluator(A_nd, B_nd, C_nd).mean
print("Performance: %f GFLOPS" % (num_flop / mean_seconds / 1e9))
Performance: 9443.610543 GFLOPS

讨论#

请考虑如何使这个程序运行得更快?(极限性能在 50 TFLOPS 左右,而这个程序在 RTX-3080 上只有约 23 TFLOPS;实际测得的数值会因设备而异,例如上面的示例输出约为 9.4 TFLOPS。)

参考文献#

https://developer.nvidia.com/blog/programming-tensor-cores-cuda-9/

https://tvm.apache.org/docs/how_to/optimize_operators/opt_conv_tensorcore.html