{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Mpn1ti5Urdsv" }, "source": [ "# Tensor Program Abstraction\n", "\n", "## Install Packages\n", "\n", "The following command installs the packaged build used in the MLC course." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "qXysoqn-vZuF" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in links: https://mlc.ai/wheels\n", "Requirement already satisfied: mlc-ai-nightly in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (0.9.dev1664+g1f3985de0)\n", "Requirement already satisfied: decorator in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (5.1.1)\n", "Requirement already satisfied: numpy in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (1.23.1)\n", "Requirement already satisfied: attrs in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (21.4.0)\n", "Requirement already satisfied: scipy in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (1.8.1)\n", "Requirement already satisfied: cloudpickle in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (2.1.0)\n", "Requirement already satisfied: synr==0.6.0 in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (0.6.0)\n", "Requirement already satisfied: psutil in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (5.9.1)\n", "Requirement already satisfied: tornado in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (6.1)\n" ] } ], "source": [ "!python3 -m pip install mlc-ai-nightly -f https://mlc.ai/wheels" ] }, { "cell_type": "markdown", "metadata": { "id": "BBIuE2jc1DaU" }, "source": [ "## Constructing a Tensor Program\n", "\n", "Start by constructing a tensor program that adds two vectors." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "vvfOgcu-YdaB" }, "outputs": 
[], "source": [ "import tvm\n", "from tvm.ir.module import IRModule\n", "from tvm.script import tir as T\n", "import numpy as np\n", "from IPython.display import Code" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "qCViJNUNYfTW" }, "outputs": [], "source": [ "@tvm.script.ir_module\n", "class MyModule:\n", "    @T.prim_func\n", "    def main(A: T.Buffer[128, \"float32\"],\n", "             B: T.Buffer[128, \"float32\"],\n", "             C: T.Buffer[128, \"float32\"]):\n", "        # Extra annotations for the function\n", "        T.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", "        for i in range(128):\n", "            with T.block(\"C\"):\n", "                # Declare a data-parallel iterator on the spatial domain\n", "                vi = T.axis.spatial(128, i)\n", "                C[vi] = A[vi] + B[vi]" ] }, { "cell_type": "markdown", "metadata": { "id": "4PJd0Pw8zVQD" }, "source": [ "TVMScript is a way to express tensor programs in the Python AST. Note that this code does not actually correspond to a Python program; it is a tensor program that can be used in the MLC process. The language is designed to align with Python syntax, with additional structures that facilitate analysis and transformation." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QKsLAcDB8Npx", "outputId": "8534ef46-c656-4f36-961c-f6e59e04ad6d" }, "outputs": [ { "data": { "text/plain": [ "tvm.ir.module.IRModule" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(MyModule)" ] }, { "cell_type": "markdown", "metadata": { "id": "GpdPoa5q8Sj7" }, "source": [ "`MyModule` is an instance of the **IRModule** data structure, which holds a collection of tensor functions.\n", "\n", "The `script` function returns a string representation of the IRModule. It is quite useful for inspecting the module at each step of a transformation." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "VXy-4v3Czax9", "outputId": "c933d1e0-42d5-4df2-ad9a-6eb997deb10c" }, "outputs": [ { "data": { "text/plain": [ "@tvm.script.ir_module\n", "class Module:\n", "    @tir.prim_func\n", "    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n", "        # function attr dict\n", "        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", "        # body\n", "        # with tir.block(\"root\")\n", "        for i in tir.serial(128):\n", "            with tir.block(\"C\"):\n", "                vi = tir.axis.spatial(128, i)\n", "                tir.reads(A[vi], B[vi])\n", "                tir.writes(C[vi])\n", "                C[vi] = A[vi] + B[vi]\n", "    " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Code(MyModule.script(), language=\"python\")" ] }, { "cell_type": "markdown", "metadata": { "id": "tOMSOuW6aIJg" }, "source": [ "### Build and Run\n", "\n", "At any point, an IRModule can be turned into runnable functions by calling the `build` function." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "oESoqN-xaTCf", "outputId": "58119fbd-b737-400d-bd94-776af0709501" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<class 'tvm.driver.build_module.OperatorModule'>\n" ] } ], "source": [ "rt_mod = tvm.build(MyModule, target=\"llvm\")  # The module for CPU backends\n", "print(type(rt_mod))" ] }, { "cell_type": "markdown", "metadata": { "id": "Y2ZfGrH1z6SV" }, "source": [ "After the build, `rt_mod` contains a collection of runnable functions. Each function can be retrieved by name." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "5I3GqwnRz-Ne" }, "outputs": [], "source": [ "func = rt_mod[\"main\"]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bngdW1eVl683", "outputId": "d5e0437c-bf16-4107-aa41-e11a4e1865ce" }, "outputs": [ { "data": { "text/plain": [ "<tvm.runtime.packed_func.PackedFunc at 0x7fdd0e398980>" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "func" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "DKxo8uq_mNlp" }, "outputs": [], "source": [ "a = tvm.nd.array(np.arange(128, dtype=\"float32\"))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "1hAFAqv_mP8W" }, "outputs": [], "source": [ "b = tvm.nd.array(np.ones(128, dtype=\"float32\"))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "TseB1UBumivT" }, "outputs": [], "source": [ "c = tvm.nd.empty((128,), dtype=\"float32\")" ] }, { "cell_type": "markdown", "metadata": { "id": "p68xZ0_P0MPw" }, "source": [ "To invoke the function, create three NDArrays in the TVM runtime and then call the generated function." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "SMkcgO-L0Xr5" }, "outputs": [], "source": [ "func(a, b, c)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "AakkTpE50b6o", "outputId": "43971a60-2fbb-41bb-cc47-c496bfe5dda2" }, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13.\n", " 14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27.\n", " 28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38. 39. 40. 41.\n", " 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55.\n", " 56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69.\n", " 70. 71. 72. 73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83.\n", " 84. 85. 86. 87. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97.\n", " 98. 99. 100. 101. 102. 103. 104. 105. 106. 107. 108. 109. 110. 111.\n", " 112. 113. 114. 115. 116. 117. 118. 119. 120. 121. 122. 123. 124. 125.\n", " 126. 127.]\n", "[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n", " 1. 1. 1. 1. 1. 1. 1. 1.]\n", "[ 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14.\n", " 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28.\n", " 29. 30. 31. 32. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42.\n", " 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55. 56.\n", " 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70.\n", " 71. 72. 73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84.\n", " 85. 86. 87. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97. 98.\n", " 99. 100. 101. 102. 103. 104. 105. 106. 107. 108. 109. 110. 111. 112.\n", " 113. 114. 115. 116. 117. 118. 119. 120. 121. 122. 123. 124. 125. 126.\n", " 127. 
128.]\n" ] } ], "source": [ "print(a)\n", "print(b)\n", "print(c)" ] }, { "cell_type": "markdown", "metadata": { "id": "i_MIDZCOcmwp" }, "source": [ "## 变换张量程序\n", "\n", "现在开始变换张量程序。张量程序可以使用一种称为调度(schedule)的辅助数据结构进行变换。" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "xwyjwh51cjWI", "outputId": "d8a8687a-a722-4899-c181-9ca90c7d841e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<class 'tvm.tir.schedule.schedule.Schedule'>\n" ] } ], "source": [ "sch = tvm.tir.Schedule(MyModule)\n", "print(type(sch))" ] }, { "cell_type": "markdown", "metadata": { "id": "Dw7Fgw8o8HPm" }, "source": [ "先试着划分循环:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "kNQf8D0ic4me", "outputId": "1cbfd7f9-a807-4571-b2e3-66ffe052037d" }, "outputs": [ { "data": { "text/html": [ "<style>pre { line-height: 125%; }\n", "td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n", "span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n", "td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n", "span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n", ".output_html .hll { background-color: #ffffcc }\n", ".output_html { background: #f8f8f8; }\n", ".output_html .c { color: #3D7B7B; font-style: italic } /* Comment */\n", ".output_html .err { border: 1px solid #FF0000 } /* Error */\n", ".output_html .k { color: #008000; font-weight: bold } /* Keyword */\n", ".output_html .o { color: #666666 } /* Operator */\n", ".output_html .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */\n", ".output_html .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */\n", ".output_html .cp { color: #9C6500 } /* 
Comment.Preproc */\n", ".output_html .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */\n", ".output_html .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */\n", ".output_html .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */\n", ".output_html .gd { color: #A00000 } /* Generic.Deleted */\n", ".output_html .ge { font-style: italic } /* Generic.Emph */\n", ".output_html .gr { color: #E40000 } /* Generic.Error */\n", ".output_html .gh { color: #000080; font-weight: bold } /* Generic.Heading */\n", ".output_html .gi { color: #008400 } /* Generic.Inserted */\n", ".output_html .go { color: #717171 } /* Generic.Output */\n", ".output_html .gp { color: #000080; font-weight: bold } /* Generic.Prompt */\n", ".output_html .gs { font-weight: bold } /* Generic.Strong */\n", ".output_html .gu { color: #800080; font-weight: bold } /* Generic.Subheading */\n", ".output_html .gt { color: #0044DD } /* Generic.Traceback */\n", ".output_html .kc { color: #008000; font-weight: bold } /* Keyword.Constant */\n", ".output_html .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */\n", ".output_html .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */\n", ".output_html .kp { color: #008000 } /* Keyword.Pseudo */\n", ".output_html .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */\n", ".output_html .kt { color: #B00040 } /* Keyword.Type */\n", ".output_html .m { color: #666666 } /* Literal.Number */\n", ".output_html .s { color: #BA2121 } /* Literal.String */\n", ".output_html .na { color: #687822 } /* Name.Attribute */\n", ".output_html .nb { color: #008000 } /* Name.Builtin */\n", ".output_html .nc { color: #0000FF; font-weight: bold } /* Name.Class */\n", ".output_html .no { color: #880000 } /* Name.Constant */\n", ".output_html .nd { color: #AA22FF } /* Name.Decorator */\n", ".output_html .ni { color: #717171; font-weight: bold } /* Name.Entity */\n", ".output_html .ne { color: #CB3F38; font-weight: 
bold } /* Name.Exception */\n", ".output_html .nf { color: #0000FF } /* Name.Function */\n", ".output_html .nl { color: #767600 } /* Name.Label */\n", ".output_html .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */\n", ".output_html .nt { color: #008000; font-weight: bold } /* Name.Tag */\n", ".output_html .nv { color: #19177C } /* Name.Variable */\n", ".output_html .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */\n", ".output_html .w { color: #bbbbbb } /* Text.Whitespace */\n", ".output_html .mb { color: #666666 } /* Literal.Number.Bin */\n", ".output_html .mf { color: #666666 } /* Literal.Number.Float */\n", ".output_html .mh { color: #666666 } /* Literal.Number.Hex */\n", ".output_html .mi { color: #666666 } /* Literal.Number.Integer */\n", ".output_html .mo { color: #666666 } /* Literal.Number.Oct */\n", ".output_html .sa { color: #BA2121 } /* Literal.String.Affix */\n", ".output_html .sb { color: #BA2121 } /* Literal.String.Backtick */\n", ".output_html .sc { color: #BA2121 } /* Literal.String.Char */\n", ".output_html .dl { color: #BA2121 } /* Literal.String.Delimiter */\n", ".output_html .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */\n", ".output_html .s2 { color: #BA2121 } /* Literal.String.Double */\n", ".output_html .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */\n", ".output_html .sh { color: #BA2121 } /* Literal.String.Heredoc */\n", ".output_html .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */\n", ".output_html .sx { color: #008000 } /* Literal.String.Other */\n", ".output_html .sr { color: #A45A77 } /* Literal.String.Regex */\n", ".output_html .s1 { color: #BA2121 } /* Literal.String.Single */\n", ".output_html .ss { color: #19177C } /* Literal.String.Symbol */\n", ".output_html .bp { color: #008000 } /* Name.Builtin.Pseudo */\n", ".output_html .fm { color: #0000FF } /* Name.Function.Magic */\n", ".output_html .vc { color: #19177C } /* Name.Variable.Class 
*/\n", ".output_html .vg { color: #19177C } /* Name.Variable.Global */\n", ".output_html .vi { color: #19177C } /* Name.Variable.Instance */\n", ".output_html .vm { color: #19177C } /* Name.Variable.Magic */\n", ".output_html .il { color: #666666 } /* Literal.Number.Integer.Long */</style><div class=\"highlight\"><pre><span></span><span class=\"nd\">@tvm</span><span class=\"o\">.</span><span class=\"n\">script</span><span class=\"o\">.</span><span class=\"n\">ir_module</span>\n", "<span class=\"k\">class</span> <span class=\"nc\">Module</span><span class=\"p\">:</span>\n", " <span class=\"nd\">@tir</span><span class=\"o\">.</span><span class=\"n\">prim_func</span>\n", " <span class=\"k\">def</span> <span class=\"nf\">main</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">"float32"</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">"float32"</span><span class=\"p\">],</span> <span class=\"n\">C</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">"float32"</span><span class=\"p\">])</span> <span class=\"o\">-></span> <span class=\"kc\">None</span><span class=\"p\">:</span>\n", " <span class=\"c1\"># function attr dict</span>\n", " <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">func_attr</span><span class=\"p\">({</span><span class=\"s2\">"global_symbol"</span><span class=\"p\">:</span> <span class=\"s2\">"main"</span><span class=\"p\">,</span> 
<span class=\"s2\">"tir.noalias"</span><span class=\"p\">:</span> <span class=\"kc\">True</span><span class=\"p\">})</span>\n", " <span class=\"c1\"># body</span>\n", " <span class=\"c1\"># with tir.block("root")</span>\n", " <span class=\"k\">for</span> <span class=\"n\">i_0</span><span class=\"p\">,</span> <span class=\"n\">i_1</span><span class=\"p\">,</span> <span class=\"n\">i_2</span> <span class=\"ow\">in</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">grid</span><span class=\"p\">(</span><span class=\"mi\">8</span><span class=\"p\">,</span> <span class=\"mi\">4</span><span class=\"p\">,</span> <span class=\"mi\">4</span><span class=\"p\">):</span>\n", " <span class=\"k\">with</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">block</span><span class=\"p\">(</span><span class=\"s2\">"C"</span><span class=\"p\">):</span>\n", " <span class=\"n\">vi</span> <span class=\"o\">=</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">axis</span><span class=\"o\">.</span><span class=\"n\">spatial</span><span class=\"p\">(</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"n\">i_0</span> <span class=\"o\">*</span> <span class=\"mi\">16</span> <span class=\"o\">+</span> <span class=\"n\">i_1</span> <span class=\"o\">*</span> <span class=\"mi\">4</span> <span class=\"o\">+</span> <span class=\"n\">i_2</span><span class=\"p\">)</span>\n", " <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">reads</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">])</span>\n", " <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">writes</span><span class=\"p\">(</span><span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">vi</span><span 
class=\"p\">])</span>\n", " <span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span> <span class=\"o\">=</span> <span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span> <span class=\"o\">+</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span>\n", " \n", "</pre></div>\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{n+nd}{@tvm}\\PY{o}{.}\\PY{n}{script}\\PY{o}{.}\\PY{n}{ir\\PYZus{}module}\n", "\\PY{k}{class} \\PY{n+nc}{Module}\\PY{p}{:}\n", " \\PY{n+nd}{@tir}\\PY{o}{.}\\PY{n}{prim\\PYZus{}func}\n", " \\PY{k}{def} \\PY{n+nf}{main}\\PY{p}{(}\\PY{n}{A}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{C}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{k+kc}{None}\\PY{p}{:}\n", " \\PY{c+c1}{\\PYZsh{} function attr dict}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{func\\PYZus{}attr}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{global\\PYZus{}symbol}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{main}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{tir.noalias}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{k+kc}{True}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\n", " \\PY{c+c1}{\\PYZsh{} body}\n", " \\PY{c+c1}{\\PYZsh{} with tir.block(\\PYZdq{}root\\PYZdq{})}\n", " \\PY{k}{for} \\PY{n}{i\\PYZus{}0}\\PY{p}{,} \\PY{n}{i\\PYZus{}1}\\PY{p}{,} \\PY{n}{i\\PYZus{}2} \\PY{o+ow}{in} 
\\PY{n}{tir}\\PY{o}{.}\\PY{n}{grid}\\PY{p}{(}\\PY{l+m+mi}{8}\\PY{p}{,} \\PY{l+m+mi}{4}\\PY{p}{,} \\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{with} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{block}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{C}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{vi} \\PY{o}{=} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{axis}\\PY{o}{.}\\PY{n}{spatial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{n}{i\\PYZus{}0} \\PY{o}{*} \\PY{l+m+mi}{16} \\PY{o}{+} \\PY{n}{i\\PYZus{}1} \\PY{o}{*} \\PY{l+m+mi}{4} \\PY{o}{+} \\PY{n}{i\\PYZus{}2}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{reads}\\PY{p}{(}\\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{writes}\\PY{p}{(}\\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{=} \\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{+} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\n", " \n", "\\end{Verbatim}\n" ], "text/plain": [ "@tvm.script.ir_module\n", "class Module:\n", " @tir.prim_func\n", " def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n", " # function attr dict\n", " tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " # body\n", " # with tir.block(\"root\")\n", " for i_0, i_1, i_2 in tir.grid(8, 4, 4):\n", " with tir.block(\"C\"):\n", " vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n", " tir.reads(A[vi], B[vi])\n", " tir.writes(C[vi])\n", " C[vi] = A[vi] + B[vi]\n", " " ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get the block by its name\n", "block_c = sch.get_block(\"C\")\n", "# Get the loops surrounding the block\n", "(i,) = sch.get_loops(block_c)\n", "# Tile the loop nest: split i into three loops (outer extent inferred, then 4 and 4)\n", "i_0, i_1, i_2 = sch.split(i, factors=[None, 4, 4])\n", "Code(sch.mod.script(), language=\"Python\")" ] }, { "cell_type": "markdown", "metadata": { "id": "nzrbvqBSdC-D" }, "source": [ "We can also reorder the loops. Here we move loop i_2 outside of 
i_1." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "yJWBq7lRdDmn", "outputId": "1957cf67-f51c-4af1-c5ca-16c9718181a0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "@tvm.script.ir_module\n", "class Module:\n", " @tir.prim_func\n", " def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n", " # function attr dict\n", " tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " # body\n", " # with tir.block(\"root\")\n", " for i_0, i_2, i_1 in tir.grid(8, 4, 4):\n", " with tir.block(\"C\"):\n", " vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n", " tir.reads(A[vi], B[vi])\n", " tir.writes(C[vi])\n", " C[vi] = A[vi] + B[vi]\n", " \n" ] } ], "source": [ "sch.reorder(i_0, i_2, i_1)\n", "print(sch.mod.script())" ] }, { "cell_type": "markdown", "metadata": { "id": "UmUr6b_L07-8" }, "source": [ "Finally, we can add a hint to the program generator that the outermost loop, i_0, should be executed in parallel." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "u95zQFuldHs_", "outputId": "b2205442-ac70-405c-8f42-ddb23e46e012" }, "outputs": [ { "data": { "text/html": [ "<style>pre { line-height: 125%; }\n", "td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n", "span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n", "td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n", "span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n", ".output_html .hll { background-color: #ffffcc }\n", ".output_html { background: #f8f8f8; }\n", ".output_html .c { color: #3D7B7B; font-style: italic } /* Comment */\n", ".output_html .err { border: 1px solid #FF0000 } /* Error */\n", ".output_html .k { color: 
#008000; font-weight: bold } /* Keyword */\n", ".output_html .o { color: #666666 } /* Operator */\n", ".output_html .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */\n", ".output_html .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */\n", ".output_html .cp { color: #9C6500 } /* Comment.Preproc */\n", ".output_html .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */\n", ".output_html .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */\n", ".output_html .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */\n", ".output_html .gd { color: #A00000 } /* Generic.Deleted */\n", ".output_html .ge { font-style: italic } /* Generic.Emph */\n", ".output_html .gr { color: #E40000 } /* Generic.Error */\n", ".output_html .gh { color: #000080; font-weight: bold } /* Generic.Heading */\n", ".output_html .gi { color: #008400 } /* Generic.Inserted */\n", ".output_html .go { color: #717171 } /* Generic.Output */\n", ".output_html .gp { color: #000080; font-weight: bold } /* Generic.Prompt */\n", ".output_html .gs { font-weight: bold } /* Generic.Strong */\n", ".output_html .gu { color: #800080; font-weight: bold } /* Generic.Subheading */\n", ".output_html .gt { color: #0044DD } /* Generic.Traceback */\n", ".output_html .kc { color: #008000; font-weight: bold } /* Keyword.Constant */\n", ".output_html .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */\n", ".output_html .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */\n", ".output_html .kp { color: #008000 } /* Keyword.Pseudo */\n", ".output_html .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */\n", ".output_html .kt { color: #B00040 } /* Keyword.Type */\n", ".output_html .m { color: #666666 } /* Literal.Number */\n", ".output_html .s { color: #BA2121 } /* Literal.String */\n", ".output_html .na { color: #687822 } /* Name.Attribute */\n", ".output_html .nb { color: #008000 } /* Name.Builtin */\n", ".output_html .nc 
{ color: #0000FF; font-weight: bold } /* Name.Class */\n", ".output_html .no { color: #880000 } /* Name.Constant */\n", ".output_html .nd { color: #AA22FF } /* Name.Decorator */\n", ".output_html .ni { color: #717171; font-weight: bold } /* Name.Entity */\n", ".output_html .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */\n", ".output_html .nf { color: #0000FF } /* Name.Function */\n", ".output_html .nl { color: #767600 } /* Name.Label */\n", ".output_html .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */\n", ".output_html .nt { color: #008000; font-weight: bold } /* Name.Tag */\n", ".output_html .nv { color: #19177C } /* Name.Variable */\n", ".output_html .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */\n", ".output_html .w { color: #bbbbbb } /* Text.Whitespace */\n", ".output_html .mb { color: #666666 } /* Literal.Number.Bin */\n", ".output_html .mf { color: #666666 } /* Literal.Number.Float */\n", ".output_html .mh { color: #666666 } /* Literal.Number.Hex */\n", ".output_html .mi { color: #666666 } /* Literal.Number.Integer */\n", ".output_html .mo { color: #666666 } /* Literal.Number.Oct */\n", ".output_html .sa { color: #BA2121 } /* Literal.String.Affix */\n", ".output_html .sb { color: #BA2121 } /* Literal.String.Backtick */\n", ".output_html .sc { color: #BA2121 } /* Literal.String.Char */\n", ".output_html .dl { color: #BA2121 } /* Literal.String.Delimiter */\n", ".output_html .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */\n", ".output_html .s2 { color: #BA2121 } /* Literal.String.Double */\n", ".output_html .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */\n", ".output_html .sh { color: #BA2121 } /* Literal.String.Heredoc */\n", ".output_html .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */\n", ".output_html .sx { color: #008000 } /* Literal.String.Other */\n", ".output_html .sr { color: #A45A77 } /* Literal.String.Regex */\n", ".output_html .s1 { color: 
#BA2121 } /* Literal.String.Single */\n", ".output_html .ss { color: #19177C } /* Literal.String.Symbol */\n", ".output_html .bp { color: #008000 } /* Name.Builtin.Pseudo */\n", ".output_html .fm { color: #0000FF } /* Name.Function.Magic */\n", ".output_html .vc { color: #19177C } /* Name.Variable.Class */\n", ".output_html .vg { color: #19177C } /* Name.Variable.Global */\n", ".output_html .vi { color: #19177C } /* Name.Variable.Instance */\n", ".output_html .vm { color: #19177C } /* Name.Variable.Magic */\n", ".output_html .il { color: #666666 } /* Literal.Number.Integer.Long */</style><div class=\"highlight\"><pre><span></span><span class=\"nd\">@tvm</span><span class=\"o\">.</span><span class=\"n\">script</span><span class=\"o\">.</span><span class=\"n\">ir_module</span>\n", "<span class=\"k\">class</span> <span class=\"nc\">Module</span><span class=\"p\">:</span>\n", " <span class=\"nd\">@tir</span><span class=\"o\">.</span><span class=\"n\">prim_func</span>\n", " <span class=\"k\">def</span> <span class=\"nf\">main</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">"float32"</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">"float32"</span><span class=\"p\">],</span> <span class=\"n\">C</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">"float32"</span><span class=\"p\">])</span> <span class=\"o\">-></span> <span class=\"kc\">None</span><span 
class=\"p\">:</span>\n", " <span class=\"c1\"># function attr dict</span>\n", " <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">func_attr</span><span class=\"p\">({</span><span class=\"s2\">"global_symbol"</span><span class=\"p\">:</span> <span class=\"s2\">"main"</span><span class=\"p\">,</span> <span class=\"s2\">"tir.noalias"</span><span class=\"p\">:</span> <span class=\"kc\">True</span><span class=\"p\">})</span>\n", " <span class=\"c1\"># body</span>\n", " <span class=\"c1\"># with tir.block("root")</span>\n", " <span class=\"k\">for</span> <span class=\"n\">i_0</span><span class=\"p\">,</span> <span class=\"n\">i_2</span><span class=\"p\">,</span> <span class=\"n\">i_1</span> <span class=\"ow\">in</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">grid</span><span class=\"p\">(</span><span class=\"mi\">8</span><span class=\"p\">,</span> <span class=\"mi\">4</span><span class=\"p\">,</span> <span class=\"mi\">4</span><span class=\"p\">):</span>\n", " <span class=\"k\">with</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">block</span><span class=\"p\">(</span><span class=\"s2\">"C"</span><span class=\"p\">):</span>\n", " <span class=\"n\">vi</span> <span class=\"o\">=</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">axis</span><span class=\"o\">.</span><span class=\"n\">spatial</span><span class=\"p\">(</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"n\">i_0</span> <span class=\"o\">*</span> <span class=\"mi\">16</span> <span class=\"o\">+</span> <span class=\"n\">i_1</span> <span class=\"o\">*</span> <span class=\"mi\">4</span> <span class=\"o\">+</span> <span class=\"n\">i_2</span><span class=\"p\">)</span>\n", " <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">reads</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">vi</span><span 
class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">])</span>\n", " <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">writes</span><span class=\"p\">(</span><span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">])</span>\n", " <span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span> <span class=\"o\">=</span> <span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span> <span class=\"o\">+</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span>\n", " \n", "</pre></div>\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{n+nd}{@tvm}\\PY{o}{.}\\PY{n}{script}\\PY{o}{.}\\PY{n}{ir\\PYZus{}module}\n", "\\PY{k}{class} \\PY{n+nc}{Module}\\PY{p}{:}\n", " \\PY{n+nd}{@tir}\\PY{o}{.}\\PY{n}{prim\\PYZus{}func}\n", " \\PY{k}{def} \\PY{n+nf}{main}\\PY{p}{(}\\PY{n}{A}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{C}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{k+kc}{None}\\PY{p}{:}\n", " \\PY{c+c1}{\\PYZsh{} function attr dict}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{func\\PYZus{}attr}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{global\\PYZus{}symbol}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{main}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} 
\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{tir.noalias}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{k+kc}{True}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\n", " \\PY{c+c1}{\\PYZsh{} body}\n", " \\PY{c+c1}{\\PYZsh{} with tir.block(\\PYZdq{}root\\PYZdq{})}\n", " \\PY{k}{for} \\PY{n}{i\\PYZus{}0}\\PY{p}{,} \\PY{n}{i\\PYZus{}2}\\PY{p}{,} \\PY{n}{i\\PYZus{}1} \\PY{o+ow}{in} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{grid}\\PY{p}{(}\\PY{l+m+mi}{8}\\PY{p}{,} \\PY{l+m+mi}{4}\\PY{p}{,} \\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{with} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{block}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{C}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{vi} \\PY{o}{=} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{axis}\\PY{o}{.}\\PY{n}{spatial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{n}{i\\PYZus{}0} \\PY{o}{*} \\PY{l+m+mi}{16} \\PY{o}{+} \\PY{n}{i\\PYZus{}1} \\PY{o}{*} \\PY{l+m+mi}{4} \\PY{o}{+} \\PY{n}{i\\PYZus{}2}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{reads}\\PY{p}{(}\\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{writes}\\PY{p}{(}\\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{=} \\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{+} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\n", " \n", "\\end{Verbatim}\n" ], "text/plain": [ "@tvm.script.ir_module\n", "class Module:\n", " @tir.prim_func\n", " def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n", " # function attr dict\n", " tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " # body\n", " # with tir.block(\"root\")\n", " for i_0, i_2, i_1 in tir.grid(8, 4, 4):\n", " with tir.block(\"C\"):\n", " vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n", " tir.reads(A[vi], B[vi])\n", " tir.writes(C[vi])\n", " C[vi] = A[vi] + B[vi]\n", " " ] }, "execution_count": 17, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "Code(sch.mod.script(), language=\"python\")" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "m7NFPx9fy5Wy" }, "outputs": [ { "data": { "text/html": [ "<style>pre { line-height: 125%; }\n", "td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n", "span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n", "td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n", "span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n", ".output_html .hll { background-color: #ffffcc }\n", ".output_html { background: #f8f8f8; }\n", ".output_html .c { color: #3D7B7B; font-style: italic } /* Comment */\n", ".output_html .err { border: 1px solid #FF0000 } /* Error */\n", ".output_html .k { color: #008000; font-weight: bold } /* Keyword */\n", ".output_html .o { color: #666666 } /* Operator */\n", ".output_html .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */\n", ".output_html .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */\n", ".output_html .cp { color: #9C6500 } /* Comment.Preproc */\n", ".output_html .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */\n", ".output_html .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */\n", ".output_html .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */\n", ".output_html .gd { color: #A00000 } /* Generic.Deleted */\n", ".output_html .ge { font-style: italic } /* Generic.Emph */\n", ".output_html .gr { color: #E40000 } /* Generic.Error */\n", ".output_html .gh { color: #000080; font-weight: bold } /* Generic.Heading */\n", ".output_html .gi { color: #008400 } /* Generic.Inserted */\n", ".output_html .go { color: #717171 } /* Generic.Output */\n", ".output_html .gp { color: #000080; font-weight: bold } /* 
Generic.Prompt */\n", ".output_html .gs { font-weight: bold } /* Generic.Strong */\n", ".output_html .gu { color: #800080; font-weight: bold } /* Generic.Subheading */\n", ".output_html .gt { color: #0044DD } /* Generic.Traceback */\n", ".output_html .kc { color: #008000; font-weight: bold } /* Keyword.Constant */\n", ".output_html .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */\n", ".output_html .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */\n", ".output_html .kp { color: #008000 } /* Keyword.Pseudo */\n", ".output_html .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */\n", ".output_html .kt { color: #B00040 } /* Keyword.Type */\n", ".output_html .m { color: #666666 } /* Literal.Number */\n", ".output_html .s { color: #BA2121 } /* Literal.String */\n", ".output_html .na { color: #687822 } /* Name.Attribute */\n", ".output_html .nb { color: #008000 } /* Name.Builtin */\n", ".output_html .nc { color: #0000FF; font-weight: bold } /* Name.Class */\n", ".output_html .no { color: #880000 } /* Name.Constant */\n", ".output_html .nd { color: #AA22FF } /* Name.Decorator */\n", ".output_html .ni { color: #717171; font-weight: bold } /* Name.Entity */\n", ".output_html .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */\n", ".output_html .nf { color: #0000FF } /* Name.Function */\n", ".output_html .nl { color: #767600 } /* Name.Label */\n", ".output_html .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */\n", ".output_html .nt { color: #008000; font-weight: bold } /* Name.Tag */\n", ".output_html .nv { color: #19177C } /* Name.Variable */\n", ".output_html .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */\n", ".output_html .w { color: #bbbbbb } /* Text.Whitespace */\n", ".output_html .mb { color: #666666 } /* Literal.Number.Bin */\n", ".output_html .mf { color: #666666 } /* Literal.Number.Float */\n", ".output_html .mh { color: #666666 } /* Literal.Number.Hex */\n", ".output_html .mi { 
color: #666666 } /* Literal.Number.Integer */\n", ".output_html .mo { color: #666666 } /* Literal.Number.Oct */\n", ".output_html .sa { color: #BA2121 } /* Literal.String.Affix */\n", ".output_html .sb { color: #BA2121 } /* Literal.String.Backtick */\n", ".output_html .sc { color: #BA2121 } /* Literal.String.Char */\n", ".output_html .dl { color: #BA2121 } /* Literal.String.Delimiter */\n", ".output_html .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */\n", ".output_html .s2 { color: #BA2121 } /* Literal.String.Double */\n", ".output_html .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */\n", ".output_html .sh { color: #BA2121 } /* Literal.String.Heredoc */\n", ".output_html .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */\n", ".output_html .sx { color: #008000 } /* Literal.String.Other */\n", ".output_html .sr { color: #A45A77 } /* Literal.String.Regex */\n", ".output_html .s1 { color: #BA2121 } /* Literal.String.Single */\n", ".output_html .ss { color: #19177C } /* Literal.String.Symbol */\n", ".output_html .bp { color: #008000 } /* Name.Builtin.Pseudo */\n", ".output_html .fm { color: #0000FF } /* Name.Function.Magic */\n", ".output_html .vc { color: #19177C } /* Name.Variable.Class */\n", ".output_html .vg { color: #19177C } /* Name.Variable.Global */\n", ".output_html .vi { color: #19177C } /* Name.Variable.Instance */\n", ".output_html .vm { color: #19177C } /* Name.Variable.Magic */\n", ".output_html .il { color: #666666 } /* Literal.Number.Integer.Long */</style><div class=\"highlight\"><pre><span></span><span class=\"nd\">@tvm</span><span class=\"o\">.</span><span class=\"n\">script</span><span class=\"o\">.</span><span class=\"n\">ir_module</span>\n", "<span class=\"k\">class</span> <span class=\"nc\">Module</span><span class=\"p\">:</span>\n", " <span class=\"nd\">@tir</span><span class=\"o\">.</span><span class=\"n\">prim_func</span>\n", " <span class=\"k\">def</span> <span 
class=\"nf\">main</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">"float32"</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">"float32"</span><span class=\"p\">],</span> <span class=\"n\">C</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">"float32"</span><span class=\"p\">])</span> <span class=\"o\">-></span> <span class=\"kc\">None</span><span class=\"p\">:</span>\n", " <span class=\"c1\"># function attr dict</span>\n", " <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">func_attr</span><span class=\"p\">({</span><span class=\"s2\">"global_symbol"</span><span class=\"p\">:</span> <span class=\"s2\">"main"</span><span class=\"p\">,</span> <span class=\"s2\">"tir.noalias"</span><span class=\"p\">:</span> <span class=\"kc\">True</span><span class=\"p\">})</span>\n", " <span class=\"c1\"># body</span>\n", " <span class=\"c1\"># with tir.block("root")</span>\n", " <span class=\"k\">for</span> <span class=\"n\">i_0</span> <span class=\"ow\">in</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">parallel</span><span class=\"p\">(</span><span class=\"mi\">8</span><span class=\"p\">):</span>\n", " <span class=\"k\">for</span> <span class=\"n\">i_2</span><span class=\"p\">,</span> <span class=\"n\">i_1</span> <span class=\"ow\">in</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span 
class=\"n\">grid</span><span class=\"p\">(</span><span class=\"mi\">4</span><span class=\"p\">,</span> <span class=\"mi\">4</span><span class=\"p\">):</span>\n", " <span class=\"k\">with</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">block</span><span class=\"p\">(</span><span class=\"s2\">"C"</span><span class=\"p\">):</span>\n", " <span class=\"n\">vi</span> <span class=\"o\">=</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">axis</span><span class=\"o\">.</span><span class=\"n\">spatial</span><span class=\"p\">(</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"n\">i_0</span> <span class=\"o\">*</span> <span class=\"mi\">16</span> <span class=\"o\">+</span> <span class=\"n\">i_1</span> <span class=\"o\">*</span> <span class=\"mi\">4</span> <span class=\"o\">+</span> <span class=\"n\">i_2</span><span class=\"p\">)</span>\n", " <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">reads</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">])</span>\n", " <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">writes</span><span class=\"p\">(</span><span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">])</span>\n", " <span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span> <span class=\"o\">=</span> <span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span> <span class=\"o\">+</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span>\n", " \n", "</pre></div>\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", 
"\\PY{n+nd}{@tvm}\\PY{o}{.}\\PY{n}{script}\\PY{o}{.}\\PY{n}{ir\\PYZus{}module}\n", "\\PY{k}{class} \\PY{n+nc}{Module}\\PY{p}{:}\n", " \\PY{n+nd}{@tir}\\PY{o}{.}\\PY{n}{prim\\PYZus{}func}\n", " \\PY{k}{def} \\PY{n+nf}{main}\\PY{p}{(}\\PY{n}{A}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{C}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{k+kc}{None}\\PY{p}{:}\n", " \\PY{c+c1}{\\PYZsh{} function attr dict}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{func\\PYZus{}attr}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{global\\PYZus{}symbol}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{main}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{tir.noalias}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{k+kc}{True}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\n", " \\PY{c+c1}{\\PYZsh{} body}\n", " \\PY{c+c1}{\\PYZsh{} with tir.block(\\PYZdq{}root\\PYZdq{})}\n", " \\PY{k}{for} \\PY{n}{i\\PYZus{}0} \\PY{o+ow}{in} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{parallel}\\PY{p}{(}\\PY{l+m+mi}{8}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{for} \\PY{n}{i\\PYZus{}2}\\PY{p}{,} \\PY{n}{i\\PYZus{}1} \\PY{o+ow}{in} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{grid}\\PY{p}{(}\\PY{l+m+mi}{4}\\PY{p}{,} \\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{with} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{block}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{C}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{vi} \\PY{o}{=} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{axis}\\PY{o}{.}\\PY{n}{spatial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{,} 
\\PY{n}{i\\PYZus{}0} \\PY{o}{*} \\PY{l+m+mi}{16} \\PY{o}{+} \\PY{n}{i\\PYZus{}1} \\PY{o}{*} \\PY{l+m+mi}{4} \\PY{o}{+} \\PY{n}{i\\PYZus{}2}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{reads}\\PY{p}{(}\\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{writes}\\PY{p}{(}\\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{=} \\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{+} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\n", " \n", "\\end{Verbatim}\n" ], "text/plain": [ "@tvm.script.ir_module\n", "class Module:\n", " @tir.prim_func\n", " def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n", " # function attr dict\n", " tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " # body\n", " # with tir.block(\"root\")\n", " for i_0 in tir.parallel(8):\n", " for i_2, i_1 in tir.grid(4, 4):\n", " with tir.block(\"C\"):\n", " vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n", " tir.reads(A[vi], B[vi])\n", " tir.writes(C[vi])\n", " C[vi] = A[vi] + B[vi]\n", " " ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sch.parallel(i_0)\n", "Code(sch.mod.script(), language=\"python\")" ] }, { "cell_type": "markdown", "metadata": { "id": "OhGlqLTG_tNv" }, "source": [ "We can now build and run the transformed program." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "sCIYSDrI_wGq" }, "outputs": [], "source": [ "transformed_mod = tvm.build(sch.mod, target=\"llvm\") # Build the module for the CPU (LLVM) backend.\n", "transformed_mod[\"main\"](a, b, c)" ] }, { "cell_type": "markdown", "metadata": { "id": "wj_01P4mAfu2" }, "source": [ "## Constructing Tensor Programs with Tensor Expression\n", "\n", "In the previous example, we wrote the tensor program directly in TVMScript. In practice, it is often helpful to construct these functions programmatically from existing definitions. Tensor expression is an API that helps build some expression-like array computations." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": 
"https://localhost:8080/" }, "id": "TZAPcqbGAesY", "outputId": "ac4d6407-8dea-4c75-d166-e071ffee8783" }, "outputs": [ { "data": { "text/html": [ "<style>pre { line-height: 125%; }\n", "td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n", "span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n", "td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n", "span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n", ".output_html .hll { background-color: #ffffcc }\n", ".output_html { background: #f8f8f8; }\n", ".output_html .c { color: #3D7B7B; font-style: italic } /* Comment */\n", ".output_html .err { border: 1px solid #FF0000 } /* Error */\n", ".output_html .k { color: #008000; font-weight: bold } /* Keyword */\n", ".output_html .o { color: #666666 } /* Operator */\n", ".output_html .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */\n", ".output_html .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */\n", ".output_html .cp { color: #9C6500 } /* Comment.Preproc */\n", ".output_html .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */\n", ".output_html .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */\n", ".output_html .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */\n", ".output_html .gd { color: #A00000 } /* Generic.Deleted */\n", ".output_html .ge { font-style: italic } /* Generic.Emph */\n", ".output_html .gr { color: #E40000 } /* Generic.Error */\n", ".output_html .gh { color: #000080; font-weight: bold } /* Generic.Heading */\n", ".output_html .gi { color: #008400 } /* Generic.Inserted */\n", ".output_html .go { color: #717171 } /* Generic.Output */\n", ".output_html .gp { color: #000080; font-weight: bold } /* Generic.Prompt */\n", ".output_html .gs { font-weight: bold } /* 
Generic.Strong */\n", ".output_html .gu { color: #800080; font-weight: bold } /* Generic.Subheading */\n", ".output_html .gt { color: #0044DD } /* Generic.Traceback */\n", ".output_html .kc { color: #008000; font-weight: bold } /* Keyword.Constant */\n", ".output_html .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */\n", ".output_html .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */\n", ".output_html .kp { color: #008000 } /* Keyword.Pseudo */\n", ".output_html .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */\n", ".output_html .kt { color: #B00040 } /* Keyword.Type */\n", ".output_html .m { color: #666666 } /* Literal.Number */\n", ".output_html .s { color: #BA2121 } /* Literal.String */\n", ".output_html .na { color: #687822 } /* Name.Attribute */\n", ".output_html .nb { color: #008000 } /* Name.Builtin */\n", ".output_html .nc { color: #0000FF; font-weight: bold } /* Name.Class */\n", ".output_html .no { color: #880000 } /* Name.Constant */\n", ".output_html .nd { color: #AA22FF } /* Name.Decorator */\n", ".output_html .ni { color: #717171; font-weight: bold } /* Name.Entity */\n", ".output_html .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */\n", ".output_html .nf { color: #0000FF } /* Name.Function */\n", ".output_html .nl { color: #767600 } /* Name.Label */\n", ".output_html .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */\n", ".output_html .nt { color: #008000; font-weight: bold } /* Name.Tag */\n", ".output_html .nv { color: #19177C } /* Name.Variable */\n", ".output_html .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */\n", ".output_html .w { color: #bbbbbb } /* Text.Whitespace */\n", ".output_html .mb { color: #666666 } /* Literal.Number.Bin */\n", ".output_html .mf { color: #666666 } /* Literal.Number.Float */\n", ".output_html .mh { color: #666666 } /* Literal.Number.Hex */\n", ".output_html .mi { color: #666666 } /* Literal.Number.Integer */\n", ".output_html .mo 
{ color: #666666 } /* Literal.Number.Oct */\n", ".output_html .sa { color: #BA2121 } /* Literal.String.Affix */\n", ".output_html .sb { color: #BA2121 } /* Literal.String.Backtick */\n", ".output_html .sc { color: #BA2121 } /* Literal.String.Char */\n", ".output_html .dl { color: #BA2121 } /* Literal.String.Delimiter */\n", ".output_html .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */\n", ".output_html .s2 { color: #BA2121 } /* Literal.String.Double */\n", ".output_html .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */\n", ".output_html .sh { color: #BA2121 } /* Literal.String.Heredoc */\n", ".output_html .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */\n", ".output_html .sx { color: #008000 } /* Literal.String.Other */\n", ".output_html .sr { color: #A45A77 } /* Literal.String.Regex */\n", ".output_html .s1 { color: #BA2121 } /* Literal.String.Single */\n", ".output_html .ss { color: #19177C } /* Literal.String.Symbol */\n", ".output_html .bp { color: #008000 } /* Name.Builtin.Pseudo */\n", ".output_html .fm { color: #0000FF } /* Name.Function.Magic */\n", ".output_html .vc { color: #19177C } /* Name.Variable.Class */\n", ".output_html .vg { color: #19177C } /* Name.Variable.Global */\n", ".output_html .vi { color: #19177C } /* Name.Variable.Instance */\n", ".output_html .vm { color: #19177C } /* Name.Variable.Magic */\n", ".output_html .il { color: #666666 } /* Literal.Number.Integer.Long */</style><div class=\"highlight\"><pre><span></span><span class=\"nd\">@tvm</span><span class=\"o\">.</span><span class=\"n\">script</span><span class=\"o\">.</span><span class=\"n\">ir_module</span>\n", "<span class=\"k\">class</span> <span class=\"nc\">Module</span><span class=\"p\">:</span>\n", " <span class=\"nd\">@tir</span><span class=\"o\">.</span><span class=\"n\">prim_func</span>\n", " <span class=\"k\">def</span> <span class=\"nf\">main</span><span class=\"p\">(</span><span class=\"n\">A</span><span 
class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">"float32"</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">"float32"</span><span class=\"p\">],</span> <span class=\"n\">C</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">"float32"</span><span class=\"p\">])</span> <span class=\"o\">-></span> <span class=\"kc\">None</span><span class=\"p\">:</span>\n", " <span class=\"c1\"># function attr dict</span>\n", " <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">func_attr</span><span class=\"p\">({</span><span class=\"s2\">"global_symbol"</span><span class=\"p\">:</span> <span class=\"s2\">"main"</span><span class=\"p\">,</span> <span class=\"s2\">"tir.noalias"</span><span class=\"p\">:</span> <span class=\"kc\">True</span><span class=\"p\">})</span>\n", " <span class=\"c1\"># body</span>\n", " <span class=\"c1\"># with tir.block("root")</span>\n", " <span class=\"k\">for</span> <span class=\"n\">i0</span> <span class=\"ow\">in</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">serial</span><span class=\"p\">(</span><span class=\"mi\">128</span><span class=\"p\">):</span>\n", " <span class=\"k\">with</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">block</span><span class=\"p\">(</span><span class=\"s2\">"C"</span><span class=\"p\">):</span>\n", " <span class=\"n\">i</span> <span class=\"o\">=</span> <span class=\"n\">tir</span><span 
class=\"o\">.</span><span class=\"n\">axis</span><span class=\"o\">.</span><span class=\"n\">spatial</span><span class=\"p\">(</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"n\">i0</span><span class=\"p\">)</span>\n", " <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">reads</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">i</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">i</span><span class=\"p\">])</span>\n", " <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">writes</span><span class=\"p\">(</span><span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">i</span><span class=\"p\">])</span>\n", " <span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">i</span><span class=\"p\">]</span> <span class=\"o\">=</span> <span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">i</span><span class=\"p\">]</span> <span class=\"o\">+</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">i</span><span class=\"p\">]</span>\n", " \n", "</pre></div>\n" ], "text/latex": [ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", "\\PY{n+nd}{@tvm}\\PY{o}{.}\\PY{n}{script}\\PY{o}{.}\\PY{n}{ir\\PYZus{}module}\n", "\\PY{k}{class} \\PY{n+nc}{Module}\\PY{p}{:}\n", " \\PY{n+nd}{@tir}\\PY{o}{.}\\PY{n}{prim\\PYZus{}func}\n", " \\PY{k}{def} \\PY{n+nf}{main}\\PY{p}{(}\\PY{n}{A}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{C}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} 
\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{k+kc}{None}\\PY{p}{:}\n", " \\PY{c+c1}{\\PYZsh{} function attr dict}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{func\\PYZus{}attr}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{global\\PYZus{}symbol}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{main}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{tir.noalias}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{k+kc}{True}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\n", " \\PY{c+c1}{\\PYZsh{} body}\n", " \\PY{c+c1}{\\PYZsh{} with tir.block(\\PYZdq{}root\\PYZdq{})}\n", " \\PY{k}{for} \\PY{n}{i0} \\PY{o+ow}{in} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{serial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{k}{with} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{block}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{C}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{:}\n", " \\PY{n}{i} \\PY{o}{=} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{axis}\\PY{o}{.}\\PY{n}{spatial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{n}{i0}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{reads}\\PY{p}{(}\\PY{n}{A}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{tir}\\PY{o}{.}\\PY{n}{writes}\\PY{p}{(}\\PY{n}{C}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]}\\PY{p}{)}\n", " \\PY{n}{C}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]} \\PY{o}{=} \\PY{n}{A}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]} \\PY{o}{+} \\PY{n}{B}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]}\n", " \n", "\\end{Verbatim}\n" ], "text/plain": [ "@tvm.script.ir_module\n", "class Module:\n", " @tir.prim_func\n", " def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n", " # function attr dict\n", " tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " # body\n", " # with tir.block(\"root\")\n", " for i0 in tir.serial(128):\n", " with tir.block(\"C\"):\n", " i = tir.axis.spatial(128, i0)\n", " 
tir.reads(A[i], B[i])\n", " tir.writes(C[i])\n", " C[i] = A[i] + B[i]\n", " " ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# namespace for tensor expression utility\n", "from tvm import te\n", "\n", "# declare the computation using the expression API\n", "A = te.placeholder((128, ), name=\"A\")\n", "B = te.placeholder((128, ), name=\"B\")\n", "C = te.compute((128,), lambda i: A[i] + B[i], name=\"C\")\n", "\n", "# create a function with the specified list of arguments. \n", "func = te.create_prim_func([A, B, C])\n", "# mark that the function name is main\n", "func = func.with_attr(\"global_symbol\", \"main\")\n", "ir_mod_from_te = IRModule({\"main\": func})\n", "\n", "Code(ir_mod_from_te.script(), language=\"python\")" ] }, { "cell_type": "markdown", "metadata": { "id": "GEqpO14Lf0Lq" }, "source": [ "## Transforming a Matrix Multiplication Program\n", "\n", "The example above showed how to transform a vector addition. Now let's apply the same approach to a slightly more complex program: matrix multiplication. We first build the initial code with the tensor expression API." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "R9ExHE3BfYYv", "outputId": "0d82d527-8051-4bc3-c3a7-0cbb12d9bcce" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "@tvm.script.ir_module\n", "class Module:\n", " @tir.prim_func\n", " def main(A: tir.Buffer[(1024, 1024), \"float32\"], B: tir.Buffer[(1024, 1024), \"float32\"], C: tir.Buffer[(1024, 1024), \"float32\"]) -> None:\n", " # function attr dict\n", " tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " # body\n", " # with tir.block(\"root\")\n", " for i0, i1, i2 in tir.grid(1024, 1024, 1024):\n", " with tir.block(\"C\"):\n", " m, n, k = tir.axis.remap(\"SSR\", [i0, i1, i2])\n", " tir.reads(A[m, k], B[k, n])\n", " tir.writes(C[m, n])\n", " with tir.init():\n", " C[m, n] = tir.float32(0)\n", " C[m, n] = C[m, n] + A[m, k] * B[k, n]\n", " \n", "Baseline: 2.238068\n" ] } ], "source": [ "from tvm import te\n", "\n", "M = 1024\n", "K = 1024\n", "N = 1024\n", "\n", 
"# The default tensor type in tvm\n", "dtype = \"float32\"\n", "\n", "target = \"llvm\"\n", "dev = tvm.device(target, 0)\n", "\n", "# Algorithm\n", "k = te.reduce_axis((0, K), \"k\")\n", "A = te.placeholder((M, K), name=\"A\")\n", "B = te.placeholder((K, N), name=\"B\")\n", "C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), name=\"C\")\n", "\n", "# Default schedule\n", "func = te.create_prim_func([A, B, C])\n", "func = func.with_attr(\"global_symbol\", \"main\")\n", "ir_module = IRModule({\"main\": func})\n", "print(Code(ir_module.script(), language=\"python\"))\n", "\n", "\n", "func = tvm.build(ir_module, target=\"llvm\") # The module for CPU backends.\n", "\n", "a = tvm.nd.array(np.random.rand(M, K).astype(dtype), dev)\n", "b = tvm.nd.array(np.random.rand(K, N).astype(dtype), dev)\n", "c = tvm.nd.array(np.zeros((M, N), dtype=dtype), dev)\n", "func(a, b, c)\n", "\n", "evaluator = func.time_evaluator(func.entry_name, dev, number=1)\n", "print(\"Baseline: %f\" % evaluator(a, b, c).mean)" ] }, { "cell_type": "markdown", "metadata": { "id": "swj-gMz-1vBE" }, "source": [ "We can transform the loop access pattern to make it more cache-friendly. Let's use the following schedule." ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "W60q68KRgdNL", "outputId": "b49a101e-5148-4cf0-df88-0e112e741381" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<class 'tvm.tir.schedule.schedule.Schedule'>\n", "after transformation: 0.239306\n" ] } ], "source": [ "sch = tvm.tir.Schedule(ir_module)\n", "print(type(sch))\n", "block_c = sch.get_block(\"C\")\n", "# Get loops surrounding the block\n", "(y, x, k) = sch.get_loops(block_c)\n", "block_size = 32\n", "yo, yi = sch.split(y, [None, block_size])\n", "xo, xi = sch.split(x, [None, block_size])\n", "\n", "sch.reorder(yo, xo, k, yi, xi)\n", "Code(sch.mod.script(), language=\"python\")\n", "\n", "func = tvm.build(sch.mod, target=\"llvm\") # The module for CPU backends.\n", "\n", "c = 
tvm.nd.array(np.zeros((M, N), dtype=dtype), dev)\n", "func(a, b, c)\n", "\n", "evaluator = func.time_evaluator(func.entry_name, dev, number=1)\n", "print(\"after transformation: %f\" % evaluator(a, b, c).mean)" ] }, { "cell_type": "markdown", "metadata": { "id": "h1RQGOBjn4w_" }, "source": [ "Try changing the value of `block_size` and see what performance you get. In practice, an automated system searches over a set of possible transformations to find the best one." ] } ], "metadata": { "colab": { "collapsed_sections": [], "include_colab_link": true, "name": "2-tensor-program-abstraction.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3.10.4 ('mlc': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "vscode": { "interpreter": { "hash": "d8a760899c905ec5a15e0d212432af25d7f0b614c7ae634224dffa77837bb03c" } } }, "nbformat": 4, "nbformat_minor": 0 }