{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Mpn1ti5Urdsv" }, "source": [ "# Tensor Program Abstraction\n", "\n", "## Install Packages\n", "\n", "Use the following command to install the packaged version for the MLC course." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "qXysoqn-vZuF" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in links: https://mlc.ai/wheels\n", "Requirement already satisfied: mlc-ai-nightly in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (0.9.dev1664+g1f3985de0)\n", "Requirement already satisfied: decorator in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (5.1.1)\n", "Requirement already satisfied: numpy in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (1.23.1)\n", "Requirement already satisfied: attrs in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (21.4.0)\n", "Requirement already satisfied: scipy in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (1.8.1)\n", "Requirement already satisfied: cloudpickle in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (2.1.0)\n", "Requirement already satisfied: synr==0.6.0 in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (0.6.0)\n", "Requirement already satisfied: psutil in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (5.9.1)\n", "Requirement already satisfied: tornado in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (6.1)\n" ] } ], "source": [ "!python3 -m pip install mlc-ai-nightly -f https://mlc.ai/wheels" ] }, { "cell_type": "markdown", "metadata": { "id": "BBIuE2jc1DaU" }, "source": [ "## Constructing a Tensor Program\n", "\n", "We begin by constructing a tensor program that adds two vectors." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "vvfOgcu-YdaB" }, "outputs": 
[], "source": [ "import tvm\n", "from tvm.ir.module import IRModule\n", "from tvm.script import tir as T\n", "import numpy as np\n", "from IPython.display import Code" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "qCViJNUNYfTW" }, "outputs": [], "source": [ "@tvm.script.ir_module\n", "class MyModule:\n", "    @T.prim_func\n", "    def main(A: T.Buffer[128, \"float32\"],\n", "             B: T.Buffer[128, \"float32\"],\n", "             C: T.Buffer[128, \"float32\"]):\n", "        # extra annotations for the function\n", "        T.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", "        for i in range(128):\n", "            with T.block(\"C\"):\n", "                # declare a data-parallel iterator on the spatial domain\n", "                vi = T.axis.spatial(128, i)\n", "                C[vi] = A[vi] + B[vi]" ] }, { "cell_type": "markdown", "metadata": { "id": "4PJd0Pw8zVQD" }, "source": [ "TVMScript is a way to express tensor programs in the Python AST. Note that this code does not correspond to an actual Python program; it is a tensor program that can be used in the MLC process. The language is designed to align with Python syntax, with additional structures that facilitate analysis and transformation." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QKsLAcDB8Npx", "outputId": "8534ef46-c656-4f36-961c-f6e59e04ad6d" }, "outputs": [ { "data": { "text/plain": [ "tvm.ir.module.IRModule" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(MyModule)" ] }, { "cell_type": "markdown", "metadata": { "id": "GpdPoa5q8Sj7" }, "source": [ "`MyModule` is an instance of the **IRModule** data structure, which holds a collection of tensor functions.\n", "\n", "The `script` function returns a string-based representation of the IRModule. It is useful for inspecting the module at each step of a transformation." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "VXy-4v3Czax9", "outputId": "c933d1e0-42d5-4df2-ad9a-6eb997deb10c" }, "outputs": [ { "data": { "text/html": [ "
@tvm.script.ir_module\n",
"class Module:\n",
"    @tir.prim_func\n",
"    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n",
"        # function attr dict\n",
"        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
"        # body\n",
"        # with tir.block(\"root\")\n",
"        for i in tir.serial(128):\n",
"            with tir.block(\"C\"):\n",
"                vi = tir.axis.spatial(128, i)\n",
"                tir.reads(A[vi], B[vi])\n",
"                tir.writes(C[vi])\n",
"                C[vi] = A[vi] + B[vi]\n",
" \n",
"
@tvm.script.ir_module\n",
"class Module:\n",
"    @tir.prim_func\n",
"    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n",
"        # function attr dict\n",
"        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
"        # body\n",
"        # with tir.block(\"root\")\n",
"        for i_0, i_1, i_2 in tir.grid(8, 4, 4):\n",
"            with tir.block(\"C\"):\n",
"                vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n",
"                tir.reads(A[vi], B[vi])\n",
"                tir.writes(C[vi])\n",
"                C[vi] = A[vi] + B[vi]\n",
" \n",
"
@tvm.script.ir_module\n",
"class Module:\n",
"    @tir.prim_func\n",
"    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n",
"        # function attr dict\n",
"        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
"        # body\n",
"        # with tir.block(\"root\")\n",
"        for i_0, i_2, i_1 in tir.grid(8, 4, 4):\n",
"            with tir.block(\"C\"):\n",
"                vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n",
"                tir.reads(A[vi], B[vi])\n",
"                tir.writes(C[vi])\n",
"                C[vi] = A[vi] + B[vi]\n",
" \n",
"
@tvm.script.ir_module\n",
"class Module:\n",
"    @tir.prim_func\n",
"    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n",
"        # function attr dict\n",
"        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
"        # body\n",
"        # with tir.block(\"root\")\n",
"        for i_0 in tir.parallel(8):\n",
"            for i_2, i_1 in tir.grid(4, 4):\n",
"                with tir.block(\"C\"):\n",
"                    vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n",
"                    tir.reads(A[vi], B[vi])\n",
"                    tir.writes(C[vi])\n",
"                    C[vi] = A[vi] + B[vi]\n",
" \n",
"
@tvm.script.ir_module\n",
"class Module:\n",
"    @tir.prim_func\n",
"    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n",
"        # function attr dict\n",
"        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
"        # body\n",
"        # with tir.block(\"root\")\n",
"        for i0 in tir.serial(128):\n",
"            with tir.block(\"C\"):\n",
"                i = tir.axis.spatial(128, i0)\n",
"                tir.reads(A[i], B[i])\n",
"                tir.writes(C[i])\n",
"                C[i] = A[i] + B[i]\n",
" \n",
"