{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tensor Core 编程教程\n", "\n", "在课程中,我们使用 `tmm` 内在函数来演示 `Tensorize` 的进展。 在本教程中,我们将把 TensorIR 运用到 NVIDIA GPU 上的 Tensor Cores。 请注意,Tensor Cores仅在具有 Volta 或更新架构的 NVIDIA GPU 上受支持(例如,`V100`、`T4`、`RTX-20X0`、`A100`、`RTX-30X0`)。 不幸的是,Colab 提供的大多数 GPU 都太旧,无法支持 Tensor Core。 您可能需要为本教程准备自己的设备。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 什么是 Tensor Core?\n", "\n", "张量核心是具有 Volta 和更新架构的 GPU 上的可编程矩阵乘法和累加单元。 每个 Tensor Core 都提供了一个矩阵处理数组,它执行 `D = A * B + C` 运算,如果我们使用 `nvcuda::wmma`,则`A`、`B`、`C` 和 `D` 是 `16x16` 的矩阵。 其中,矩阵乘法输入 `A` 和 `B` 是 `fp16` 矩阵,而累加矩阵 `C` 和 `D` 可以是 `fp16` 或 `fp32` 矩阵。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![WMMA16x16x16.png]()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "CUDA 语言只能使用 `warp-level` 原语 `wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag)` 在 Tensor Core 上执行 `16x16x16` 半精度矩阵乘法。 在调用矩阵乘法之前,我们必须使用原始的 `wmma::load_matrix_sync` 显式地将数据从内存加载到寄存器中(类似于我们在第6章第2部分的 `tmm` 演示中所做的)。 NVCC 编译器将该原语转换为多个内存加载指令。 在运行时,每个线程从矩阵 `A` 加载 16 个元素或从 `B` 加载 16 个元素。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 准备工作" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tvm\n", "from tvm.script import tir as T\n", "from tvm import tir\n", "\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 编写 Matmul 的 Tensor IR 程序" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "@tvm.script.ir_module\n", "class MatmulModule:\n", " @T.prim_func\n", " def main(\n", " X: T.Buffer[(1024, 1024), \"float16\"],\n", " Y: T.Buffer[(1024, 1024), \"float16\"],\n", " Z: T.Buffer[(1024, 1024), \"float32\"],\n", " ) -> None:\n", " T.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " for i, j, k in T.grid(1024, 1024, 1024):\n", " with T.block(\"matmul\"):\n", " vi, vj, vk = T.axis.remap(\"SSR\", [i, j, k])\n", " with T.init():\n", " Z[vi, vj] = T.float32(0)\n", " Z[vi, vj] += T.cast(X[vi, vk], \"float32\") * T.cast(Y[vj, vk], \"float32\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "注意计算 `Z[vi, vj] += T.cast(X[vi, vk], \"float32\") * T.cast(Y[vj, vk], \"float32\")` 与常规表示有点不同。 由于 Tensor Cores 加载 `fp16` 的数据,但在 `fp32` 进行计算。 所以我们必须在计算之前将数据 `cast` 到`fp32`。\n", "\n", "![image11.png]()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 内存层级\n", "\n", "在传统的 GPU 调度中,我们有 “全局内存”、“共享内存”和“本地寄存器”的内存层级。 为了支持 Tensor Cores,我们引入了另外三个特殊的内存范围:`wmma.matrix_a`、`wmma.matrix_b` 和 `wmma.accumulator`(类似于在第 6 章第 2 部分的演示中的 `global.A_reg`、`global.B_reg` 和 `global. accumulator`)。 在硬件上,所有 `wmma` 的内存相关层级都存储在片上寄存器级别,与本地内存相同。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 注册 Tensor Intrinsic\n", "\n", "这里我们注册了所有的 Tensor Core intrinsics,包括`load_matrix_a`、`load_matrix_b`、`wmma_fill`(初始化`C = 0`)、`wmma_sync`(累加计算`C += A * B`)和 `store_matrix`。 在本教程中,我们不会解释如何编写 intrinsic ,而是关注如何将给定的 intrinsic 应用于张量化程序。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "@T.prim_func\n", "def wmma_load_a_desc(a: T.handle, c: T.handle) -> None:\n", " A = T.match_buffer(a, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"shared\")\n", " C = T.match_buffer(c, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"wmma.matrix_a\")\n", "\n", " with T.block(\"root\"):\n", " T.reads(A[0:16, 0:16])\n", " T.writes(C[0:16, 0:16])\n", " for i, j in T.grid(16, 16):\n", " with T.block(\"load\"):\n", " vii, vjj = T.axis.remap(\"SS\", [i, j])\n", " C[vii, vjj] = A[vii, vjj]\n", "\n", "\n", "@T.prim_func\n", "def wmma_load_a_impl(a: T.handle, c: T.handle) -> None:\n", " s1 = T.var(\"int32\")\n", " s0 = T.var(\"int32\")\n", " A = T.match_buffer(\n", " a,\n", " (16, 16),\n", " \"float16\",\n", " align=128,\n", " offset_factor=16,\n", " scope=\"shared\",\n", " strides=[s1, s0],\n", " )\n", " C = T.match_buffer(c, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"wmma.matrix_a\")\n", "\n", " with T.block(\"root\"):\n", " T.reads(A[0:16, 0:16])\n", " T.writes(C[0:16, 0:16])\n", " T.evaluate(\n", " T.tvm_load_matrix_sync(\n", " C.data,\n", " 16,\n", " 16,\n", " 16,\n", " C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),\n", " A.access_ptr(\"r\"),\n", " s1,\n", " \"row_major\",\n", " dtype=\"handle\",\n", " )\n", " )\n", "\n", "\n", "@T.prim_func\n", "def wmma_load_b_desc(a: T.handle, c: T.handle) -> None:\n", " A = T.match_buffer(a, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"shared\")\n", " C = T.match_buffer(c, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"wmma.matrix_b\")\n", "\n", " with T.block(\"root\"):\n", " T.reads(A[0:16, 0:16])\n", " T.writes(C[0:16, 0:16])\n", " for i, j in T.grid(16, 16):\n", " with T.block(\"load\"):\n", " vii, vjj = T.axis.remap(\"SS\", [i, j])\n", " C[vii, vjj] = A[vii, vjj]\n", "\n", "\n", "@T.prim_func\n", "def wmma_load_b_impl(a: T.handle, c: T.handle) -> None:\n", " s1 = T.var(\"int32\")\n", " s0 = T.var(\"int32\")\n", " A = T.match_buffer(\n", " a,\n", " (16, 16),\n", " \"float16\",\n", " align=128,\n", " offset_factor=16,\n", " scope=\"shared\",\n", " strides=[s1, s0],\n", " )\n", " C = T.match_buffer(c, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"wmma.matrix_b\")\n", "\n", " with T.block(\"root\"):\n", " T.reads(A[0:16, 0:16])\n", " T.writes(C[0:16, 0:16])\n", " T.evaluate(\n", " T.tvm_load_matrix_sync(\n", " C.data,\n", " 16,\n", " 16,\n", " 16,\n", " C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),\n", " A.access_ptr(\"r\"),\n", " s1,\n", " \"col_major\",\n", " dtype=\"handle\",\n", " )\n", " )\n", "\n", "\n", "@T.prim_func\n", "def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:\n", " A = T.match_buffer(a, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"wmma.matrix_a\")\n", " B = T.match_buffer(b, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"wmma.matrix_b\")\n", " C = T.match_buffer(\n", " c, (16, 16), \"float32\", align=128, offset_factor=16, scope=\"wmma.accumulator\"\n", " )\n", "\n", " with T.block(\"root\"):\n", " T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])\n", " T.writes(C[0:16, 0:16])\n", " for i, j, k in T.grid(16, 16, 16):\n", " with T.block(\"\"):\n", " vii, vjj, vkk = T.axis.remap(\"SSR\", [i, j, k])\n", " C[vii, vjj] += T.cast(A[vii, vkk], \"float32\") * T.cast(B[vjj, vkk], \"float32\")\n", "\n", "\n", "@T.prim_func\n", "def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:\n", " A = T.match_buffer(a, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"wmma.matrix_a\")\n", " B = T.match_buffer(b, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"wmma.matrix_b\")\n", " C = T.match_buffer(\n", " c, (16, 16), \"float32\", align=128, offset_factor=16, scope=\"wmma.accumulator\"\n", " )\n", "\n", " with T.block(\"root\"):\n", " T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])\n", " T.writes(C[0:16, 0:16])\n", " T.evaluate(\n", " T.tvm_mma_sync(\n", " C.data,\n", " C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),\n", " A.data,\n", " A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16),\n", " B.data,\n", " B.elem_offset // 256 + T.floordiv(T.floormod(B.elem_offset, 256), 16),\n", " C.data,\n", " C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),\n", " dtype=\"handle\",\n", " )\n", " )\n", "\n", "\n", "@T.prim_func\n", "def wmma_fill_desc(c: T.handle) -> None:\n", " C = T.match_buffer(\n", " c, (16, 16), \"float32\", align=128, offset_factor=16, scope=\"wmma.accumulator\"\n", " )\n", "\n", " with T.block(\"root\"):\n", " T.reads()\n", " T.writes(C[0:16, 0:16])\n", " for i, j in T.grid(16, 16):\n", " with T.block(\"init\"):\n", " vii, vjj = T.axis.remap(\"SS\", [i, j])\n", " C[vii, vjj] = T.float32(0)\n", "\n", "\n", "@T.prim_func\n", "def wmma_fill_impl(c: T.handle) -> None:\n", " C = T.match_buffer(\n", " c, (16, 16), \"float32\", align=128, offset_factor=16, scope=\"wmma.accumulator\"\n", " )\n", " with T.block(\"root\"):\n", " T.reads()\n", " T.writes(C[0:16, 0:16])\n", " T.evaluate(\n", " T.tvm_fill_fragment(\n", " C.data,\n", " 16,\n", " 16,\n", " 16,\n", " C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),\n", " T.float32(0),\n", " dtype=\"handle\",\n", " )\n", " )\n", "\n", "\n", "@T.prim_func\n", "def wmma_store_desc(a: T.handle, c: T.handle) -> None:\n", " A = T.match_buffer(\n", " a, (16, 16), \"float32\", align=128, offset_factor=16, scope=\"wmma.accumulator\"\n", " )\n", " C = T.match_buffer(c, (16, 16), \"float32\", align=128, offset_factor=16, scope=\"global\")\n", " with T.block(\"root\"):\n", " T.reads(A[0:16, 0:16])\n", " T.writes(C[0:16, 0:16])\n", " for i, j in T.grid(16, 16):\n", " with T.block(\"store\"):\n", " vii, vjj = T.axis.remap(\"SS\", [i, j])\n", " C[vii, vjj] = A[vii, vjj]\n", "\n", "\n", "@T.prim_func\n", "def wmma_store_impl(a: T.handle, c: T.handle) -> None:\n", " s1 = T.var(\"int32\")\n", " s0 = T.var(\"int32\")\n", " A = T.match_buffer(\n", " a, (16, 16), \"float32\", align=128, offset_factor=16, scope=\"wmma.accumulator\"\n", " )\n", " C = T.match_buffer(\n", " c,\n", " (16, 16),\n", " \"float32\",\n", " align=128,\n", " offset_factor=16,\n", " scope=\"global\",\n", " strides=[s1, s0],\n", " )\n", " with T.block(\"root\"):\n", " T.reads(A[0:16, 0:16])\n", " T.writes(C[0:16, 0:16])\n", " T.evaluate(\n", " T.tvm_store_matrix_sync(\n", " A.data,\n", " 16,\n", " 16,\n", " 16,\n", " A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16),\n", " C.access_ptr(\"w\"),\n", " s1,\n", " \"row_major\",\n", " dtype=\"handle\",\n", " )\n", " )\n", "\n", "\n", "try:\n", " # handle exception if we register multi times\n", " tir.TensorIntrin.register(\"wmma_load_a\", wmma_load_a_desc, wmma_load_a_impl)\n", " tir.TensorIntrin.register(\"wmma_load_b\", wmma_load_b_desc, wmma_load_b_impl)\n", " tir.TensorIntrin.register(\"wmma_sync\", wmma_sync_desc, wmma_sync_impl)\n", " tir.TensorIntrin.register(\"wmma_fill\", wmma_fill_desc, wmma_fill_impl)\n", " tir.TensorIntrin.register(\"wmma_store\", wmma_store_desc, wmma_store_impl)\n", "except ValueError:\n", " pass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Blockize 张量计算\n", "\n", "正如课程中所说,我们可以使用 TensorIR 来表示一组带有 `Block` 的张量化计算。 我们可以直接用 `Block` 编写一个 TensorIR 程序,也可以通过`blockize` 生成新的 `block`。 请记住,`wmma` 操作适用于 `16x16x16` 矩阵乘法,我们需要切分循环,而最里面的循环是 `16x16x16`。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(X: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Y: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Z: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0, j_0, k_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m64\u001b[39m, \u001b[38;5;28m64\u001b[39m, \u001b[38;5;28m64\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o, vj_o, vk_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_0, j_0, k_0])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], Y[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i_init, vj_i_init \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init])\n", " Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1, k_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i, vj_i, vk_i \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1, k_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i], X[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], Y[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i])\n", " Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m=\u001b[39;00m Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(X[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(Y[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "sch = tir.Schedule(MatmulModule)\n", "block = sch.get_block(\"matmul\")\n", "i, j, k = sch.get_loops(block)\n", "\n", "i, ii = sch.split(i, factors=[None, 16])\n", "j, ji = sch.split(j, factors=[None, 16])\n", "k, ki = sch.split(k, factors=[None, 16])\n", "sch.reorder(i, j, k, ii, ji, ki)\n", "wmma_sync = sch.blockize(loop=ii)\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 切分循环并绑定 threadIdx\n", "\n", "### Warp 指令\n", "请注意,所有 Tensor Core 指令都是 warp 指令,这意味着一个 warp 中的所有 32 个线程应该同时执行此指令。 使 `threadIdx.x` extent=32 是解决此问题的最简单方法之一。 然后我们可以将`threadIdx.x`绑定到任何循环**除了**那些直接或间接包含Tensor Core内在函数的循环。 另请注意,这不是唯一的解决方案。 我们唯一应该做的就是确保一个 warp 中的所有线程都可以同时调用 Tensor Core。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(X: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Y: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Z: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_0_j_0_0_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m64\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_0, k_0_1, i_0_2, j_0_2, k_0_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2)\n", " vk_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mreduce(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], Y[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i_init, vj_i_init \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init])\n", " Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1, k_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i, vj_i, vk_i \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1, k_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i], X[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], Y[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i])\n", " Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m=\u001b[39;00m Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(X[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(Y[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "i0, i1, i2 = sch.split(i, factors=[8, 4, 2])\n", "j0, j1, j2 = sch.split(j, factors=[8, 4, 2])\n", "k0, k1, k2 = sch.split(k, factors=[16, 2, 2])\n", "\n", "sch.reorder(i0, j0, i1, j1, k0, k1, i2, j2, k2)\n", "bx = sch.fuse(i0, j0)\n", "sch.bind(bx, \"blockIdx.x\")\n", "ty = sch.fuse(i1, j1)\n", "sch.bind(ty, \"threadIdx.y\")\n", "# We can't bind to `threadIdx.x` since we have warp-level operators under the loop\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 将`A`和`B`缓存到共享内存中\n", "\n", "与 Cuda Cores 的优化技巧类似,我们仍然需要将 `A` 和 `B` 缓存到共享内存中。 此外,还需要利用 cooperative fetching 技术。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(X: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Y: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Z: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " X_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_0_j_0_0_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m64\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared[v0, v1])\n", " X_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared[v0, v1])\n", " Y_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_1, i_0_2, j_0_2, k_0_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2)\n", " vk_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mreduce(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], Y_shared[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i_init, vj_i_init \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init])\n", " Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1, k_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i, vj_i, vk_i \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1, k_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i], X_shared[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], Y_shared[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i])\n", " Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m=\u001b[39;00m Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(X_shared[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(Y_shared[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "X_shared = sch.cache_read(wmma_sync, read_buffer_index=0, storage_scope=\"shared\")\n", "Y_shared = sch.cache_read(wmma_sync, read_buffer_index=1, storage_scope=\"shared\")\n", "\n", "\n", "def schedule_shared(block):\n", " sch.compute_at(block, k0)\n", " x, y = sch.get_loops(block)[-2:]\n", " fused = sch.fuse(x, y)\n", " x0, x1, x2, x3 = sch.split(fused, factors=[None, 16, 32, 8])\n", " sch.bind(x1, \"threadIdx.y\")\n", " # here we must bind threadIdx.x == 32 to satisfy the requirements of warp-level operation.\n", " sch.bind(x2, \"threadIdx.x\") \n", " sch.vectorize(x3)\n", "\n", "\n", "schedule_shared(X_shared)\n", "schedule_shared(Y_shared)\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 将输入输出数据缓存到特殊内存层级\n", "\n", "Tensor Cores 不能直接使用共享内存或本地内存中的数据。 我们必须将数据缓存到 `wmma.matrix_a`、`wmma.matrix_b` 并更新 `wmma.accumulator` 中的计算。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 缓存输入数据" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(X: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Y: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Z: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " X_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " X_shared_wmma_matrix_a \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared_wmma_matrix_b \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_0_j_0_0_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m64\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared[v0, v1])\n", " X_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared[v0, v1])\n", " Y_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0, ax1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m32\u001b[39m, \u001b[38;5;28m32\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared_wmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared_wmma_matrix_a[v0, v1])\n", " X_shared_wmma_matrix_a[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X_shared[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0, ax1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m32\u001b[39m, \u001b[38;5;28m32\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared_wmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y_shared[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared_wmma_matrix_b[v0, v1])\n", " Y_shared_wmma_matrix_b[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y_shared[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_2, j_0_2, k_0_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2)\n", " vk_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mreduce(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i_init, vj_i_init \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init])\n", " Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1, k_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i, vj_i, vk_i \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1, k_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i], X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i])\n", " Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m=\u001b[39;00m Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "X_local = sch.cache_read(wmma_sync, 0, storage_scope=\"wmma.matrix_a\")\n", "Y_local = sch.cache_read(wmma_sync, 1, storage_scope=\"wmma.matrix_b\")\n", "sch.compute_at(X_local, k1)\n", "sch.compute_at(Y_local, k1)\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 缓存输出数据" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(X: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Y: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Z: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " X_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " X_shared_wmma_matrix_a \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared_wmma_matrix_b \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Z_wmma_accumulator \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_0_j_0_0_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m64\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared[v0, v1])\n", " X_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared[v0, v1])\n", " Y_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0, ax1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m32\u001b[39m, \u001b[38;5;28m32\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared_wmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared_wmma_matrix_a[v0, v1])\n", " X_shared_wmma_matrix_a[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X_shared[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0, ax1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m32\u001b[39m, \u001b[38;5;28m32\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared_wmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y_shared[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared_wmma_matrix_b[v0, v1])\n", " Y_shared_wmma_matrix_b[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y_shared[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_2, j_0_2, k_0_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2)\n", " vk_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mreduce(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i_init, vj_i_init \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init])\n", " Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1, k_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i, vj_i, vk_i \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1, k_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i], X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i])\n", " Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m=\u001b[39;00m Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0, ax1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m32\u001b[39m, \u001b[38;5;28m32\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mZ_wmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[v0, v1])\n", " Z[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Z_wmma_accumulator[v0, v1]\n", " \n", "\n" ] } ], "source": [ "write_back_block = sch.cache_write(wmma_sync, 0, storage_scope=\"wmma.accumulator\")\n", "sch.reverse_compute_at(write_back_block, ty)\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 切分 Tensor Core 内存拷贝\n", "\n", "`wmma.load_matrix` 和 `wmma.store_matrix` 使用 `16x16` 矩阵执行内存复制。 然后我们对循环进行切分,以此来匹配 intrinsic。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(X: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Y: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Z: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " X_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " X_shared_wmma_matrix_a \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared_wmma_matrix_b \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Z_wmma_accumulator \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_0_j_0_0_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m64\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared[v0, v1])\n", " X_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared[v0, v1])\n", " Y_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0, ax0_1, ax1_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared_wmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_1)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared_wmma_matrix_a[v0, v1])\n", " X_shared_wmma_matrix_a[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X_shared[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0, ax0_1, ax1_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared_wmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_1)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y_shared[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared_wmma_matrix_b[v0, v1])\n", " Y_shared_wmma_matrix_b[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y_shared[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_2, j_0_2, k_0_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2)\n", " vk_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mreduce(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i_init, vj_i_init \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init])\n", " Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1, k_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i, vj_i, vk_i \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1, k_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i], X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i])\n", " Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m=\u001b[39;00m Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0, ax0_1, ax1_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mZ_wmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_1)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[v0, v1])\n", " Z[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Z_wmma_accumulator[v0, v1]\n", " \n", "\n" ] } ], "source": [ "def schedule_copy(block):\n", " x, y = sch.get_loops(block)[-2:]\n", " x0, x1 = sch.split(x, factors=[None, 16])\n", " y0, y1 = sch.split(y, factors=[None, 16])\n", " sch.reorder(x0, y0, x1, y1)\n", "\n", "schedule_copy(X_local)\n", "schedule_copy(Y_local)\n", "schedule_copy(write_back_block)\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tensorize\n", "\n", "tensorize 之前,我们需要先执行 `decompose_reduction`,因为 `wmma_sync` 和 `wmma_fill` 是两个 intrinsic,需要对 init block 和 update block 进行两次 tensorize\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(X: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Y: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Z: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " X_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " X_shared_wmma_matrix_a \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared_wmma_matrix_b \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Z_wmma_accumulator \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_0_j_0_0_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m64\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_2_init, j_0_2_init \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2_init)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2_init)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i_init, vj_i_init \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init])\n", " Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared[v0, v1])\n", " X_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared[v0, v1])\n", " Y_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0, ax0_1, ax1_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared_wmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_1)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared_wmma_matrix_a[v0, v1])\n", " X_shared_wmma_matrix_a[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X_shared[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0, ax0_1, ax1_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared_wmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_1)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y_shared[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared_wmma_matrix_b[v0, v1])\n", " Y_shared_wmma_matrix_b[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y_shared[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_2, j_0_2, k_0_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o_update\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2)\n", " vk_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mreduce(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1, k_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i, vj_i, vk_i \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1, k_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i], X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i])\n", " Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m=\u001b[39;00m Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0, ax0_1, ax1_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mZ_wmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_1)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[v0, v1])\n", " Z[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Z_wmma_accumulator[v0, v1]\n", " \n", "\n" ] } ], "source": [ "init = sch.decompose_reduction(wmma_sync, k0)\n", "sch.mod.show()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(X: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Y: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Z: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " s0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " s0_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " s0_2 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " s1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " s1_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " s1_2 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " X_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " X_shared_wmma_matrix_a \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared_wmma_matrix_b \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Z_wmma_accumulator \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_0_j_0_0_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m64\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_2_init, j_0_2_init \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2_init)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2_init)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " C \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mevaluate(T\u001b[38;5;129;01m.\u001b[39;00mtvm_fill_fragment(C\u001b[38;5;129;01m.\u001b[39;00mdata, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, C\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m C\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m16\u001b[39m, T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m), dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhandle\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared[v0, v1])\n", " X_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared[v0, v1])\n", " Y_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared_wmma.matrix_a_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0)\n", " v1_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared_wmma_matrix_a[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " A \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(X_shared[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, strides\u001b[38;5;129;01m=\u001b[39;00m[s1, s0], scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " C_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(X_shared_wmma_matrix_a[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mevaluate(T\u001b[38;5;129;01m.\u001b[39;00mtvm_load_matrix_sync(C_1\u001b[38;5;129;01m.\u001b[39;00mdata, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, C_1\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m C_1\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m16\u001b[39m, T\u001b[38;5;129;01m.\u001b[39;00mtvm_access_ptr(T\u001b[38;5;129;01m.\u001b[39;00mtype_annotation(dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m), A\u001b[38;5;129;01m.\u001b[39;00mdata, A\u001b[38;5;129;01m.\u001b[39;00melem_offset, s1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m1\u001b[39m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhandle\u001b[39m\u001b[38;5;124m\"\u001b[39m), s1, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrow_major\u001b[39m\u001b[38;5;124m\"\u001b[39m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhandle\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared_wmma.matrix_b_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0)\n", " v1_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y_shared[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared_wmma_matrix_b[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " A_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(Y_shared[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, strides\u001b[38;5;129;01m=\u001b[39;00m[s1_1, s0_1], scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " C_2 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(Y_shared_wmma_matrix_b[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mevaluate(T\u001b[38;5;129;01m.\u001b[39;00mtvm_load_matrix_sync(C_2\u001b[38;5;129;01m.\u001b[39;00mdata, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, C_2\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m C_2\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m16\u001b[39m, T\u001b[38;5;129;01m.\u001b[39;00mtvm_access_ptr(T\u001b[38;5;129;01m.\u001b[39;00mtype_annotation(dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m), A_1\u001b[38;5;129;01m.\u001b[39;00mdata, A_1\u001b[38;5;129;01m.\u001b[39;00melem_offset, s1_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m1\u001b[39m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhandle\u001b[39m\u001b[38;5;124m\"\u001b[39m), s1_1, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcol_major\u001b[39m\u001b[38;5;124m\"\u001b[39m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhandle\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_2, j_0_2, k_0_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o_update\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2)\n", " vk_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mreduce(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " A_2 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " B \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " C_3 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mevaluate(T\u001b[38;5;129;01m.\u001b[39;00mtvm_mma_sync(C_3\u001b[38;5;129;01m.\u001b[39;00mdata, C_3\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m C_3\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m16\u001b[39m, A_2\u001b[38;5;129;01m.\u001b[39;00mdata, A_2\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m A_2\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m16\u001b[39m, B\u001b[38;5;129;01m.\u001b[39;00mdata, B\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m B\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m16\u001b[39m, C_3\u001b[38;5;129;01m.\u001b[39;00mdata, C_3\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m C_3\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m16\u001b[39m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhandle\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mZ_wmma.accumulator_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0)\n", " v1_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " A_3 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(Z_wmma_accumulator[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " C_4 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(Z[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, strides\u001b[38;5;129;01m=\u001b[39;00m[s1_2, s0_2], offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mevaluate(T\u001b[38;5;129;01m.\u001b[39;00mtvm_store_matrix_sync(A_3\u001b[38;5;129;01m.\u001b[39;00mdata, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, A_3\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m A_3\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m16\u001b[39m, T\u001b[38;5;129;01m.\u001b[39;00mtvm_access_ptr(T\u001b[38;5;129;01m.\u001b[39;00mtype_annotation(dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m), C_4\u001b[38;5;129;01m.\u001b[39;00mdata, C_4\u001b[38;5;129;01m.\u001b[39;00melem_offset, s1_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m2\u001b[39m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhandle\u001b[39m\u001b[38;5;124m\"\u001b[39m), s1_2, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrow_major\u001b[39m\u001b[38;5;124m\"\u001b[39m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhandle\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n", " \n", "\n" ] } ], "source": [ "sch.tensorize(sch.get_loops(X_local)[-2], \"wmma_load_a\")\n", "sch.tensorize(sch.get_loops(Y_local)[-2], \"wmma_load_b\")\n", "sch.tensorize(init, \"wmma_fill\")\n", "sch.tensorize(wmma_sync, \"wmma_sync\")\n", "sch.tensorize(sch.get_loops(write_back_block)[-2], \"wmma_store\")\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 构建并评估结果" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Performance: 9443.610543 GFLOPS\n" ] } ], "source": [ "rt_mod = tvm.build(sch.mod, target=\"cuda\")\n", "\n", "dev = tvm.cuda()\n", "num_flop = 1024**3 * 2\n", "A_np = np.random.randn(1024, 1024).astype(\"float16\")\n", "B_np = np.random.randn(1024, 1024).astype(\"float16\")\n", "C_np = A_np.astype(\"float32\") @ (B_np.astype(\"float32\").T)\n", "\n", "A_nd = tvm.nd.array(A_np, dev)\n", "B_nd = tvm.nd.array(B_np, dev)\n", "C_nd = tvm.nd.array(np.empty((1024, 1024), dtype=\"float32\"), dev)\n", "\n", "rt_mod(A_nd, B_nd, C_nd)\n", "np.testing.assert_allclose(C_np, C_nd.numpy(), rtol=1e-3, atol=1e-3)\n", "\n", "evaluator = rt_mod.time_evaluator(\"main\", dev, number=10)\n", "print(\"Performance: %f GFLOPS\" % (num_flop / evaluator(A_nd, B_nd, C_nd).mean / 1e9))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 讨论\n", "\n", "请考虑如何使这个程序运行得更快?(极限性能在 50T 左右,而这个程序在 RTX-3080 上只有 23T)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 参考文献\n", "\n", "https://developer.nvidia.com/blog/programming-tensor-cores-cuda-9/\n", "\n", "https://tvm.apache.org/docs/how_to/optimize_operators/opt_conv_tensorcore.html" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.13 ('py38': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "28558e8daad512806f5c536a1a04c119185f99f65b79002708a12162d02a79c7" } } }, "nbformat": 4, "nbformat_minor": 2 }