{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Mpn1ti5Urdsv" }, "source": [ "# Tensor Program Abstraction\n", "\n", "## Install Packages\n", "\n", "The following command installs the packaged version (`mlc-ai-nightly`) used in the MLC course." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "qXysoqn-vZuF" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in links: https://mlc.ai/wheels\n", "Requirement already satisfied: mlc-ai-nightly in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (0.9.dev1664+g1f3985de0)\n", "Requirement already satisfied: decorator in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (5.1.1)\n", "Requirement already satisfied: numpy in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (1.23.1)\n", "Requirement already satisfied: attrs in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (21.4.0)\n", "Requirement already satisfied: scipy in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (1.8.1)\n", "Requirement already satisfied: cloudpickle in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (2.1.0)\n", "Requirement already satisfied: synr==0.6.0 in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (0.6.0)\n", "Requirement already satisfied: psutil in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (5.9.1)\n", "Requirement already satisfied: tornado in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (6.1)\n" ] } ], "source": [ "!python3 -m pip install mlc-ai-nightly -f https://mlc.ai/wheels" ] }, { "cell_type": "markdown", "metadata": { "id": "BBIuE2jc1DaU" }, "source": [ "## Constructing a Tensor Program\n", "\n", "We start by constructing a tensor program that adds two vectors." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "vvfOgcu-YdaB" }, "outputs": 
[], "source": [ "import tvm\n", "from tvm.ir.module import IRModule\n", "from tvm.script import tir as T\n", "import numpy as np\n", "from IPython.display import Code" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "qCViJNUNYfTW" }, "outputs": [], "source": [ "@tvm.script.ir_module\n", "class MyModule:\n", "    @T.prim_func\n", "    def main(A: T.Buffer[128, \"float32\"],\n", "             B: T.Buffer[128, \"float32\"],\n", "             C: T.Buffer[128, \"float32\"]):\n", "        # extra annotations for the function\n", "        T.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", "        for i in range(128):\n", "            with T.block(\"C\"):\n", "                # declare a data-parallel iterator on the spatial domain\n", "                vi = T.axis.spatial(128, i)\n", "                C[vi] = A[vi] + B[vi]" ] }, { "cell_type": "markdown", "metadata": { "id": "4PJd0Pw8zVQD" }, "source": [ "TVMScript is a way to express tensor programs in the Python AST. Note that this code does not actually correspond to a Python program; it is a tensor program that can be used in the machine learning compilation (MLC) process. The language is designed to align with Python syntax, with additional structures that facilitate analysis and transformation." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QKsLAcDB8Npx", "outputId": "8534ef46-c656-4f36-961c-f6e59e04ad6d" }, "outputs": [ { "data": { "text/plain": [ "tvm.ir.module.IRModule" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(MyModule)" ] }, { "cell_type": "markdown", "metadata": { "id": "GpdPoa5q8Sj7" }, "source": [ "`MyModule` is an instance of the **IRModule** data structure, which holds a collection of tensor functions.\n", "\n", "We can use the `script` function to get a string representation of the IRModule. This is useful for inspecting the module at each step of a transformation." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "VXy-4v3Czax9", "outputId": "c933d1e0-42d5-4df2-ad9a-6eb997deb10c" }, "outputs": [ { "data": { "text/html": [ "
@tvm.script.ir_module\n",
              "class Module:\n",
              "    @tir.prim_func\n",
              "    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n",
              "        # function attr dict\n",
              "        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
              "        # body\n",
              "        # with tir.block(\"root\")\n",
              "        for i in tir.serial(128):\n",
              "            with tir.block(\"C\"):\n",
              "                vi = tir.axis.spatial(128, i)\n",
              "                tir.reads(A[vi], B[vi])\n",
              "                tir.writes(C[vi])\n",
              "                C[vi] = A[vi] + B[vi]\n",
              "    \n",
              "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{n+nd}{@tvm}\\PY{o}{.}\\PY{n}{script}\\PY{o}{.}\\PY{n}{ir\\PYZus{}module}\n", "\\PY{k}{class} \\PY{n+nc}{Module}\\PY{p}{:}\n", " \\PY{n+nd}{@tir}\\PY{o}{.}\\PY{n}{prim\\PYZus{}func}\n", " \\PY{k}{def} \\PY{n+nf}{main}\\PY{p}{(}\\PY{n}{A}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{C}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{k+kc}{None}\\PY{p}{:}\n", " \\PY{c+c1}{\\PYZsh{} function attr dict}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{func\\PYZus{}attr}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{global\\PYZus{}symbol}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{main}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{tir.noalias}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{k+kc}{True}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\n", " \\PY{c+c1}{\\PYZsh{} body}\n", " \\PY{c+c1}{\\PYZsh{} with tir.block(\\PYZdq{}root\\PYZdq{})}\n", " \\PY{k}{for} \\PY{n}{i} \\PY{o+ow}{in} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{serial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{with} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{block}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{C}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{vi} \\PY{o}{=} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{axis}\\PY{o}{.}\\PY{n}{spatial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{n}{i}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{reads}\\PY{p}{(}\\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{,} 
\\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{writes}\\PY{p}{(}\\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{=} \\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{+} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\n", " \n", "\\end{Verbatim}\n" ], "text/plain": [ "@tvm.script.ir_module\n", "class Module:\n", " @tir.prim_func\n", " def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n", " # function attr dict\n", " tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " # body\n", " # with tir.block(\"root\")\n", " for i in tir.serial(128):\n", " with tir.block(\"C\"):\n", " vi = tir.axis.spatial(128, i)\n", " tir.reads(A[vi], B[vi])\n", " tir.writes(C[vi])\n", " C[vi] = A[vi] + B[vi]\n", " " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Code(MyModule.script(), language=\"python\")" ] }, { "cell_type": "markdown", "metadata": { "id": "tOMSOuW6aIJg" }, "source": [ "### Build and Run\n", "\n", "At any point, we can turn an IRModule into runnable functions by calling the `build` function." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "oESoqN-xaTCf", "outputId": "58119fbd-b737-400d-bd94-776af0709501" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "rt_mod = tvm.build(MyModule, target=\"llvm\") # the module for CPU backends\n", "print(type(rt_mod))" ] }, { "cell_type": "markdown", "metadata": { "id": "Y2ZfGrH1z6SV" }, "source": [ "After the build, `rt_mod` contains a collection of runnable functions. Each function can be retrieved by its name." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "5I3GqwnRz-Ne" }, "outputs": [], "source": [ "func = rt_mod[\"main\"]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bngdW1eVl683", "outputId": "d5e0437c-bf16-4107-aa41-e11a4e1865ce" }, 
"outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "func" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "DKxo8uq_mNlp" }, "outputs": [], "source": [ "a = tvm.nd.array(np.arange(128, dtype=\"float32\"))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "1hAFAqv_mP8W" }, "outputs": [], "source": [ "b = tvm.nd.array(np.ones(128, dtype=\"float32\")) " ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "TseB1UBumivT" }, "outputs": [], "source": [ "c = tvm.nd.empty((128,), dtype=\"float32\") " ] }, { "cell_type": "markdown", "metadata": { "id": "p68xZ0_P0MPw" }, "source": [ "要调用该函数,可以在 tvm 运行时中创建三个 ndarray,然后调用生成的函数。" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "SMkcgO-L0Xr5" }, "outputs": [], "source": [ "func(a, b, c)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "AakkTpE50b6o", "outputId": "43971a60-2fbb-41bb-cc47-c496bfe5dda2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13.\n", " 14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27.\n", " 28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38. 39. 40. 41.\n", " 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55.\n", " 56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69.\n", " 70. 71. 72. 73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83.\n", " 84. 85. 86. 87. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97.\n", " 98. 99. 100. 101. 102. 103. 104. 105. 106. 107. 108. 109. 110. 111.\n", " 112. 113. 114. 115. 116. 117. 118. 119. 120. 121. 122. 123. 124. 125.\n", " 126. 127.]\n", "[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", " 1. 1. 
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", " 1. 1. 1. 1. 1. 1. 1. 1.]\n", "[ 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14.\n", " 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28.\n", " 29. 30. 31. 32. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42.\n", " 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55. 56.\n", " 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70.\n", " 71. 72. 73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84.\n", " 85. 86. 87. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97. 98.\n", " 99. 100. 101. 102. 103. 104. 105. 106. 107. 108. 109. 110. 111. 112.\n", " 113. 114. 115. 116. 117. 118. 119. 120. 121. 122. 123. 124. 125. 126.\n", " 127. 128.]\n" ] } ], "source": [ "print(a)\n", "print(b)\n", "print(c)" ] }, { "cell_type": "markdown", "metadata": { "id": "i_MIDZCOcmwp" }, "source": [ "## Transforming the Tensor Program\n", "\n", "Now let us start transforming the tensor program. A tensor program can be transformed using an auxiliary data structure called a schedule." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "xwyjwh51cjWI", "outputId": "d8a8687a-a722-4899-c181-9ca90c7d841e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "sch = tvm.tir.Schedule(MyModule)\n", "print(type(sch))" ] }, { "cell_type": "markdown", "metadata": { "id": "Dw7Fgw8o8HPm" }, "source": [ "Let us first try to split the loop:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "kNQf8D0ic4me", "outputId": "1cbfd7f9-a807-4571-b2e3-66ffe052037d" }, "outputs": [ { "data": { "text/html": [ "
@tvm.script.ir_module\n",
              "class Module:\n",
              "    @tir.prim_func\n",
              "    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n",
              "        # function attr dict\n",
              "        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
              "        # body\n",
              "        # with tir.block(\"root\")\n",
              "        for i_0, i_1, i_2 in tir.grid(8, 4, 4):\n",
              "            with tir.block(\"C\"):\n",
              "                vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n",
              "                tir.reads(A[vi], B[vi])\n",
              "                tir.writes(C[vi])\n",
              "                C[vi] = A[vi] + B[vi]\n",
              "    \n",
              "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{n+nd}{@tvm}\\PY{o}{.}\\PY{n}{script}\\PY{o}{.}\\PY{n}{ir\\PYZus{}module}\n", "\\PY{k}{class} \\PY{n+nc}{Module}\\PY{p}{:}\n", " \\PY{n+nd}{@tir}\\PY{o}{.}\\PY{n}{prim\\PYZus{}func}\n", " \\PY{k}{def} \\PY{n+nf}{main}\\PY{p}{(}\\PY{n}{A}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{C}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{k+kc}{None}\\PY{p}{:}\n", " \\PY{c+c1}{\\PYZsh{} function attr dict}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{func\\PYZus{}attr}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{global\\PYZus{}symbol}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{main}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{tir.noalias}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{k+kc}{True}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\n", " \\PY{c+c1}{\\PYZsh{} body}\n", " \\PY{c+c1}{\\PYZsh{} with tir.block(\\PYZdq{}root\\PYZdq{})}\n", " \\PY{k}{for} \\PY{n}{i\\PYZus{}0}\\PY{p}{,} \\PY{n}{i\\PYZus{}1}\\PY{p}{,} \\PY{n}{i\\PYZus{}2} \\PY{o+ow}{in} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{grid}\\PY{p}{(}\\PY{l+m+mi}{8}\\PY{p}{,} \\PY{l+m+mi}{4}\\PY{p}{,} \\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{with} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{block}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{C}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{vi} \\PY{o}{=} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{axis}\\PY{o}{.}\\PY{n}{spatial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{n}{i\\PYZus{}0} \\PY{o}{*} 
\\PY{l+m+mi}{16} \\PY{o}{+} \\PY{n}{i\\PYZus{}1} \\PY{o}{*} \\PY{l+m+mi}{4} \\PY{o}{+} \\PY{n}{i\\PYZus{}2}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{reads}\\PY{p}{(}\\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{writes}\\PY{p}{(}\\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{=} \\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{+} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\n", " \n", "\\end{Verbatim}\n" ], "text/plain": [ "@tvm.script.ir_module\n", "class Module:\n", " @tir.prim_func\n", " def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n", " # function attr dict\n", " tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " # body\n", " # with tir.block(\"root\")\n", " for i_0, i_1, i_2 in tir.grid(8, 4, 4):\n", " with tir.block(\"C\"):\n", " vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n", " tir.reads(A[vi], B[vi])\n", " tir.writes(C[vi])\n", " C[vi] = A[vi] + B[vi]\n", " " ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get the block by its name\n", "block_c = sch.get_block(\"C\")\n", "# Get the loops surrounding the block\n", "(i,) = sch.get_loops(block_c)\n", "# Split the loop into three nested loops; None lets the outer extent be inferred (128 / (4 * 4) = 8)\n", "i_0, i_1, i_2 = sch.split(i, factors=[None, 4, 4])\n", "Code(sch.mod.script(), language=\"python\")" ] }, { "cell_type": "markdown", "metadata": { "id": "nzrbvqBSdC-D" }, "source": [ "We can also reorder the loops. Here we move loop `i_2` outside of `i_1`." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "yJWBq7lRdDmn", "outputId": "1957cf67-f51c-4af1-c5ca-16c9718181a0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "@tvm.script.ir_module\n", "class Module:\n", " @tir.prim_func\n", " def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, 
\"float32\"]) -> None:\n", " # function attr dict\n", " tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " # body\n", " # with tir.block(\"root\")\n", " for i_0, i_2, i_1 in tir.grid(8, 4, 4):\n", " with tir.block(\"C\"):\n", " vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n", " tir.reads(A[vi], B[vi])\n", " tir.writes(C[vi])\n", " C[vi] = A[vi] + B[vi]\n", " \n" ] } ], "source": [ "sch.reorder(i_0, i_2, i_1)\n", "print(sch.mod.script())" ] }, { "cell_type": "markdown", "metadata": { "id": "UmUr6b_L07-8" }, "source": [ "最后,可以向程序生成器添加要对最内部循环进行向量化的提示。" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "u95zQFuldHs_", "outputId": "b2205442-ac70-405c-8f42-ddb23e46e012" }, "outputs": [ { "data": { "text/html": [ "
@tvm.script.ir_module\n",
              "class Module:\n",
              "    @tir.prim_func\n",
              "    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n",
              "        # function attr dict\n",
              "        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
              "        # body\n",
              "        # with tir.block(\"root\")\n",
              "        for i_0, i_2, i_1 in tir.grid(8, 4, 4):\n",
              "            with tir.block(\"C\"):\n",
              "                vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n",
              "                tir.reads(A[vi], B[vi])\n",
              "                tir.writes(C[vi])\n",
              "                C[vi] = A[vi] + B[vi]\n",
              "    \n",
              "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{n+nd}{@tvm}\\PY{o}{.}\\PY{n}{script}\\PY{o}{.}\\PY{n}{ir\\PYZus{}module}\n", "\\PY{k}{class} \\PY{n+nc}{Module}\\PY{p}{:}\n", " \\PY{n+nd}{@tir}\\PY{o}{.}\\PY{n}{prim\\PYZus{}func}\n", " \\PY{k}{def} \\PY{n+nf}{main}\\PY{p}{(}\\PY{n}{A}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{C}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{k+kc}{None}\\PY{p}{:}\n", " \\PY{c+c1}{\\PYZsh{} function attr dict}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{func\\PYZus{}attr}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{global\\PYZus{}symbol}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{main}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{tir.noalias}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{k+kc}{True}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\n", " \\PY{c+c1}{\\PYZsh{} body}\n", " \\PY{c+c1}{\\PYZsh{} with tir.block(\\PYZdq{}root\\PYZdq{})}\n", " \\PY{k}{for} \\PY{n}{i\\PYZus{}0}\\PY{p}{,} \\PY{n}{i\\PYZus{}2}\\PY{p}{,} \\PY{n}{i\\PYZus{}1} \\PY{o+ow}{in} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{grid}\\PY{p}{(}\\PY{l+m+mi}{8}\\PY{p}{,} \\PY{l+m+mi}{4}\\PY{p}{,} \\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{with} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{block}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{C}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{vi} \\PY{o}{=} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{axis}\\PY{o}{.}\\PY{n}{spatial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{n}{i\\PYZus{}0} \\PY{o}{*} 
\\PY{l+m+mi}{16} \\PY{o}{+} \\PY{n}{i\\PYZus{}1} \\PY{o}{*} \\PY{l+m+mi}{4} \\PY{o}{+} \\PY{n}{i\\PYZus{}2}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{reads}\\PY{p}{(}\\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{writes}\\PY{p}{(}\\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{=} \\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{+} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\n", " \n", "\\end{Verbatim}\n" ], "text/plain": [ "@tvm.script.ir_module\n", "class Module:\n", " @tir.prim_func\n", " def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n", " # function attr dict\n", " tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " # body\n", " # with tir.block(\"root\")\n", " for i_0, i_2, i_1 in tir.grid(8, 4, 4):\n", " with tir.block(\"C\"):\n", " vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n", " tir.reads(A[vi], B[vi])\n", " tir.writes(C[vi])\n", " C[vi] = A[vi] + B[vi]\n", " " ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Code(sch.mod.script(), language=\"python\")" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "m7NFPx9fy5Wy" }, "outputs": [ { "data": { "text/html": [ "
@tvm.script.ir_module\n",
              "class Module:\n",
              "    @tir.prim_func\n",
              "    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n",
              "        # function attr dict\n",
              "        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
              "        # body\n",
              "        # with tir.block(\"root\")\n",
              "        for i_0 in tir.parallel(8):\n",
              "            for i_2, i_1 in tir.grid(4, 4):\n",
              "                with tir.block(\"C\"):\n",
              "                    vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n",
              "                    tir.reads(A[vi], B[vi])\n",
              "                    tir.writes(C[vi])\n",
              "                    C[vi] = A[vi] + B[vi]\n",
              "    \n",
              "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{n+nd}{@tvm}\\PY{o}{.}\\PY{n}{script}\\PY{o}{.}\\PY{n}{ir\\PYZus{}module}\n", "\\PY{k}{class} \\PY{n+nc}{Module}\\PY{p}{:}\n", " \\PY{n+nd}{@tir}\\PY{o}{.}\\PY{n}{prim\\PYZus{}func}\n", " \\PY{k}{def} \\PY{n+nf}{main}\\PY{p}{(}\\PY{n}{A}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{C}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{k+kc}{None}\\PY{p}{:}\n", " \\PY{c+c1}{\\PYZsh{} function attr dict}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{func\\PYZus{}attr}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{global\\PYZus{}symbol}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{main}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{tir.noalias}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{k+kc}{True}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\n", " \\PY{c+c1}{\\PYZsh{} body}\n", " \\PY{c+c1}{\\PYZsh{} with tir.block(\\PYZdq{}root\\PYZdq{})}\n", " \\PY{k}{for} \\PY{n}{i\\PYZus{}0} \\PY{o+ow}{in} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{parallel}\\PY{p}{(}\\PY{l+m+mi}{8}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{for} \\PY{n}{i\\PYZus{}2}\\PY{p}{,} \\PY{n}{i\\PYZus{}1} \\PY{o+ow}{in} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{grid}\\PY{p}{(}\\PY{l+m+mi}{4}\\PY{p}{,} \\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{with} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{block}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{C}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{vi} \\PY{o}{=} 
\\PY{n}{tir}\\PY{o}{.}\\PY{n}{axis}\\PY{o}{.}\\PY{n}{spatial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{n}{i\\PYZus{}0} \\PY{o}{*} \\PY{l+m+mi}{16} \\PY{o}{+} \\PY{n}{i\\PYZus{}1} \\PY{o}{*} \\PY{l+m+mi}{4} \\PY{o}{+} \\PY{n}{i\\PYZus{}2}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{reads}\\PY{p}{(}\\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{writes}\\PY{p}{(}\\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{=} \\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{+} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\n", " \n", "\\end{Verbatim}\n" ], "text/plain": [ "@tvm.script.ir_module\n", "class Module:\n", " @tir.prim_func\n", " def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n", " # function attr dict\n", " tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " # body\n", " # with tir.block(\"root\")\n", " for i_0 in tir.parallel(8):\n", " for i_2, i_1 in tir.grid(4, 4):\n", " with tir.block(\"C\"):\n", " vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n", " tir.reads(A[vi], B[vi])\n", " tir.writes(C[vi])\n", " C[vi] = A[vi] + B[vi]\n", " " ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sch.parallel(i_0)\n", "Code(sch.mod.script(), language=\"python\")" ] }, { "cell_type": "markdown", "metadata": { "id": "OhGlqLTG_tNv" }, "source": [ "We can now build and run the transformed program." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "sCIYSDrI_wGq" }, "outputs": [], "source": [ "transformed_mod = tvm.build(sch.mod, target=\"llvm\") # The module for CPU backends.\n", "transformed_mod[\"main\"](a, b, c)" ] }, { "cell_type": "markdown", "metadata": { "id": "wj_01P4mAfu2" }, "source": [ "## Constructing a Tensor Program Using Tensor Expressions\n", "\n", "In the previous examples, we wrote the tensor program directly in TVMScript. In practice, it is often helpful to construct these functions programmatically from existing definitions. Tensor expression (TE) is an API that helps build expression-like array computations." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "TZAPcqbGAesY", "outputId": "ac4d6407-8dea-4c75-d166-e071ffee8783" }, "outputs": [ { "data": { "text/html": [ "
@tvm.script.ir_module\n",
              "class Module:\n",
              "    @tir.prim_func\n",
              "    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n",
              "        # function attr dict\n",
              "        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
              "        # body\n",
              "        # with tir.block(\"root\")\n",
              "        for i0 in tir.serial(128):\n",
              "            with tir.block(\"C\"):\n",
              "                i = tir.axis.spatial(128, i0)\n",
              "                tir.reads(A[i], B[i])\n",
              "                tir.writes(C[i])\n",
              "                C[i] = A[i] + B[i]\n",
              "    \n",
              "
\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{n+nd}{@tvm}\\PY{o}{.}\\PY{n}{script}\\PY{o}{.}\\PY{n}{ir\\PYZus{}module}\n", "\\PY{k}{class} \\PY{n+nc}{Module}\\PY{p}{:}\n", " \\PY{n+nd}{@tir}\\PY{o}{.}\\PY{n}{prim\\PYZus{}func}\n", " \\PY{k}{def} \\PY{n+nf}{main}\\PY{p}{(}\\PY{n}{A}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{C}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{k+kc}{None}\\PY{p}{:}\n", " \\PY{c+c1}{\\PYZsh{} function attr dict}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{func\\PYZus{}attr}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{global\\PYZus{}symbol}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{main}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{tir.noalias}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{k+kc}{True}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\n", " \\PY{c+c1}{\\PYZsh{} body}\n", " \\PY{c+c1}{\\PYZsh{} with tir.block(\\PYZdq{}root\\PYZdq{})}\n", " \\PY{k}{for} \\PY{n}{i0} \\PY{o+ow}{in} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{serial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{with} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{block}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{C}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{i} \\PY{o}{=} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{axis}\\PY{o}{.}\\PY{n}{spatial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{n}{i0}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{reads}\\PY{p}{(}\\PY{n}{A}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]}\\PY{p}{,} 
\\PY{n}{B}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{writes}\\PY{p}{(}\\PY{n}{C}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{C}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]} \\PY{o}{=} \\PY{n}{A}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]} \\PY{o}{+} \\PY{n}{B}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]}\n", " \n", "\\end{Verbatim}\n" ], "text/plain": [ "@tvm.script.ir_module\n", "class Module:\n", " @tir.prim_func\n", " def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n", " # function attr dict\n", " tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " # body\n", " # with tir.block(\"root\")\n", " for i0 in tir.serial(128):\n", " with tir.block(\"C\"):\n", " i = tir.axis.spatial(128, i0)\n", " tir.reads(A[i], B[i])\n", " tir.writes(C[i])\n", " C[i] = A[i] + B[i]\n", " " ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# namespace for tensor expression utility\n", "from tvm import te\n", "\n", "# declare the computation using the expression API\n", "A = te.placeholder((128, ), name=\"A\")\n", "B = te.placeholder((128, ), name=\"B\")\n", "C = te.compute((128,), lambda i: A[i] + B[i], name=\"C\")\n", "\n", "# create a function with the specified list of arguments. 
\n", "func = te.create_prim_func([A, B, C])\n", "# mark that the function name is main\n", "func = func.with_attr(\"global_symbol\", \"main\")\n", "ir_mod_from_te = IRModule({\"main\": func})\n", "\n", "Code(ir_mod_from_te.script(), language=\"python\")" ] }, { "cell_type": "markdown", "metadata": { "id": "GEqpO14Lf0Lq" }, "source": [ "## Transforming a Matrix Multiplication Program\n", "\n", "The examples above showed how to transform a vector addition. Now let us apply the same techniques to a slightly more complicated program: matrix multiplication. We first build the initial code using the tensor expression API." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "R9ExHE3BfYYv", "outputId": "0d82d527-8051-4bc3-c3a7-0cbb12d9bcce" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "@tvm.script.ir_module\n", "class Module:\n", " @tir.prim_func\n", " def main(A: tir.Buffer[(1024, 1024), \"float32\"], B: tir.Buffer[(1024, 1024), \"float32\"], C: tir.Buffer[(1024, 1024), \"float32\"]) -> None:\n", " # function attr dict\n", " tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " # body\n", " # with tir.block(\"root\")\n", " for i0, i1, i2 in tir.grid(1024, 1024, 1024):\n", " with tir.block(\"C\"):\n", " m, n, k = tir.axis.remap(\"SSR\", [i0, i1, i2])\n", " tir.reads(A[m, k], B[k, n])\n", " tir.writes(C[m, n])\n", " with tir.init():\n", " C[m, n] = tir.float32(0)\n", " C[m, n] = C[m, n] + A[m, k] * B[k, n]\n", " \n", "Baseline: 2.238068\n" ] } ], "source": [ "from tvm import te\n", "\n", "M = 1024\n", "K = 1024\n", "N = 1024\n", "\n", "# The default tensor type in tvm\n", "dtype = \"float32\"\n", "\n", "target = \"llvm\"\n", "dev = tvm.device(target, 0)\n", "\n", "# Algorithm\n", "k = te.reduce_axis((0, K), \"k\")\n", "A = te.placeholder((M, K), name=\"A\")\n", "B = te.placeholder((K, N), name=\"B\")\n", "C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), name=\"C\")\n", "\n", "# Default schedule\n", "func = te.create_prim_func([A, B, C])\n", "func = func.with_attr(\"global_symbol\", \"main\")\n", "ir_module = 
IRModule({\"main\": func})\n", "print(Code(ir_module.script(), language=\"python\"))\n", "\n", "\n", "func = tvm.build(ir_module, target=\"llvm\") # The module for CPU backends.\n", "\n", "a = tvm.nd.array(np.random.rand(M, K).astype(dtype), dev)\n", "b = tvm.nd.array(np.random.rand(K, N).astype(dtype), dev)\n", "c = tvm.nd.array(np.zeros((M, N), dtype=dtype), dev)\n", "func(a, b, c)\n", "\n", "evaluator = func.time_evaluator(func.entry_name, dev, number=1)\n", "print(\"Baseline: %f\" % evaluator(a, b, c).mean)" ] }, { "cell_type": "markdown", "metadata": { "id": "swj-gMz-1vBE" }, "source": [ "We can transform the loop access pattern to make it more cache friendly. Let us use the following schedule." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "W60q68KRgdNL", "outputId": "b49a101e-5148-4cf0-df88-0e112e741381" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "after transformation: 0.239306\n" ] } ], "source": [ "sch = tvm.tir.Schedule(ir_module)\n", "print(type(sch))\n", "block_c = sch.get_block(\"C\")\n", "# Get loops surrounding the block\n", "(y, x, k) = sch.get_loops(block_c)\n", "block_size = 32\n", "yo, yi = sch.split(y, [None, block_size])\n", "xo, xi = sch.split(x, [None, block_size])\n", "\n", "sch.reorder(yo, xo, k, yi, xi)\n", "Code(sch.mod.script(), language=\"python\")\n", "\n", "func = tvm.build(sch.mod, target=\"llvm\") # The module for CPU backends.\n", "\n", "c = tvm.nd.array(np.zeros((M, N), dtype=dtype), dev)\n", "func(a, b, c)\n", "\n", "evaluator = func.time_evaluator(func.entry_name, dev, number=1)\n", "print(\"after transformation: %f\" % evaluator(a, b, c).mean)" ] }, { "cell_type": "markdown", "metadata": { "id": "h1RQGOBjn4w_" }, "source": [ "Try changing the value of `block_size` to see what performance you can get (in the run recorded above, `block_size = 32` reduced the mean running time from about 2.24 s to about 0.24 s). In practice, we will leverage an automated system that searches over a set of possible transformations to find an optimal one." ] } ], "metadata": { "colab": { "collapsed_sections": [], "include_colab_link": true, "name": "2-tensor-program-abstraction.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3.10.4 ('mlc': conda)", 
"language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "vscode": { "interpreter": { "hash": "d8a760899c905ec5a15e0d212432af25d7f0b614c7ae634224dffa77837bb03c" } } }, "nbformat": 4, "nbformat_minor": 0 }