{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mpn1ti5Urdsv"
      },
      "source": [
        "# Tensor Program Abstraction\n",
        "\n",
        "## Install Packages\n",
        "\n",
        "The following command installs the packaged version for the MLC course."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "qXysoqn-vZuF"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Looking in links: https://mlc.ai/wheels\n",
            "Requirement already satisfied: mlc-ai-nightly in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (0.9.dev1664+g1f3985de0)\n",
            "Requirement already satisfied: decorator in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (5.1.1)\n",
            "Requirement already satisfied: numpy in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (1.23.1)\n",
            "Requirement already satisfied: attrs in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (21.4.0)\n",
            "Requirement already satisfied: scipy in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (1.8.1)\n",
            "Requirement already satisfied: cloudpickle in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (2.1.0)\n",
            "Requirement already satisfied: synr==0.6.0 in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (0.6.0)\n",
            "Requirement already satisfied: psutil in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (5.9.1)\n",
            "Requirement already satisfied: tornado in /media/pc/data/4tb/lxw/anaconda3/envs/mlc/lib/python3.10/site-packages (from mlc-ai-nightly) (6.1)\n"
          ]
        }
      ],
      "source": [
        "!python3 -m pip install mlc-ai-nightly -f https://mlc.ai/wheels"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BBIuE2jc1DaU"
      },
      "source": [
        "## Constructing a Tensor Program\n",
        "\n",
        "We start by constructing a tensor program that adds two vectors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "vvfOgcu-YdaB"
      },
      "outputs": [],
      "source": [
        "import tvm\n",
        "from tvm.ir.module import IRModule\n",
        "from tvm.script import tir as T\n",
        "import numpy as np\n",
        "from IPython.display import Code"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "qCViJNUNYfTW"
      },
      "outputs": [],
      "source": [
        "@tvm.script.ir_module\n",
        "class MyModule:\n",
        "    @T.prim_func\n",
        "    def main(A: T.Buffer[128, \"float32\"],\n",
        "             B: T.Buffer[128, \"float32\"],\n",
        "             C: T.Buffer[128, \"float32\"]):\n",
        "        # Extra annotations for the function\n",
        "        T.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
        "        for i in range(128):\n",
        "            with T.block(\"C\"):\n",
        "                # Declare a data-parallel iterator on the spatial domain\n",
        "                vi = T.axis.spatial(128, i)\n",
        "                C[vi] = A[vi] + B[vi]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4PJd0Pw8zVQD"
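      },
      "source": [
        "TVMScript is a way to express tensor programs using the Python AST. Note that this code does not actually correspond to a Python program; it is a tensor program that can be used in the MLC process. The language is designed to align with Python syntax, with additional structures that facilitate analysis and transformation."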
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QKsLAcDB8Npx",
        "outputId": "8534ef46-c656-4f36-961c-f6e59e04ad6d"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "tvm.ir.module.IRModule"
            ]
          },
          "execution_count": 4,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "type(MyModule)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GpdPoa5q8Sj7"
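      },
      "source": [
        "`MyModule` is an instance of the **IRModule** data structure, which holds a collection of tensor functions.\n",
        "\n",
        "The `script` function returns a string representation of the IRModule. It is very useful for inspecting the module at each step of a transformation."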
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VXy-4v3Czax9",
        "outputId": "c933d1e0-42d5-4df2-ad9a-6eb997deb10c"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<style>pre { line-height: 125%; }\n",
              "td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n",
              "span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n",
              "td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n",
              "span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n",
              ".output_html .hll { background-color: #ffffcc }\n",
              ".output_html { background: #f8f8f8; }\n",
              ".output_html .c { color: #3D7B7B; font-style: italic } /* Comment */\n",
              ".output_html .err { border: 1px solid #FF0000 } /* Error */\n",
              ".output_html .k { color: #008000; font-weight: bold } /* Keyword */\n",
              ".output_html .o { color: #666666 } /* Operator */\n",
              ".output_html .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */\n",
              ".output_html .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */\n",
              ".output_html .cp { color: #9C6500 } /* Comment.Preproc */\n",
              ".output_html .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */\n",
              ".output_html .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */\n",
              ".output_html .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */\n",
              ".output_html .gd { color: #A00000 } /* Generic.Deleted */\n",
              ".output_html .ge { font-style: italic } /* Generic.Emph */\n",
              ".output_html .gr { color: #E40000 } /* Generic.Error */\n",
              ".output_html .gh { color: #000080; font-weight: bold } /* Generic.Heading */\n",
              ".output_html .gi { color: #008400 } /* Generic.Inserted */\n",
              ".output_html .go { color: #717171 } /* Generic.Output */\n",
              ".output_html .gp { color: #000080; font-weight: bold } /* Generic.Prompt */\n",
              ".output_html .gs { font-weight: bold } /* Generic.Strong */\n",
              ".output_html .gu { color: #800080; font-weight: bold } /* Generic.Subheading */\n",
              ".output_html .gt { color: #0044DD } /* Generic.Traceback */\n",
              ".output_html .kc { color: #008000; font-weight: bold } /* Keyword.Constant */\n",
              ".output_html .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */\n",
              ".output_html .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */\n",
              ".output_html .kp { color: #008000 } /* Keyword.Pseudo */\n",
              ".output_html .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */\n",
              ".output_html .kt { color: #B00040 } /* Keyword.Type */\n",
              ".output_html .m { color: #666666 } /* Literal.Number */\n",
              ".output_html .s { color: #BA2121 } /* Literal.String */\n",
              ".output_html .na { color: #687822 } /* Name.Attribute */\n",
              ".output_html .nb { color: #008000 } /* Name.Builtin */\n",
              ".output_html .nc { color: #0000FF; font-weight: bold } /* Name.Class */\n",
              ".output_html .no { color: #880000 } /* Name.Constant */\n",
              ".output_html .nd { color: #AA22FF } /* Name.Decorator */\n",
              ".output_html .ni { color: #717171; font-weight: bold } /* Name.Entity */\n",
              ".output_html .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */\n",
              ".output_html .nf { color: #0000FF } /* Name.Function */\n",
              ".output_html .nl { color: #767600 } /* Name.Label */\n",
              ".output_html .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */\n",
              ".output_html .nt { color: #008000; font-weight: bold } /* Name.Tag */\n",
              ".output_html .nv { color: #19177C } /* Name.Variable */\n",
              ".output_html .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */\n",
              ".output_html .w { color: #bbbbbb } /* Text.Whitespace */\n",
              ".output_html .mb { color: #666666 } /* Literal.Number.Bin */\n",
              ".output_html .mf { color: #666666 } /* Literal.Number.Float */\n",
              ".output_html .mh { color: #666666 } /* Literal.Number.Hex */\n",
              ".output_html .mi { color: #666666 } /* Literal.Number.Integer */\n",
              ".output_html .mo { color: #666666 } /* Literal.Number.Oct */\n",
              ".output_html .sa { color: #BA2121 } /* Literal.String.Affix */\n",
              ".output_html .sb { color: #BA2121 } /* Literal.String.Backtick */\n",
              ".output_html .sc { color: #BA2121 } /* Literal.String.Char */\n",
              ".output_html .dl { color: #BA2121 } /* Literal.String.Delimiter */\n",
              ".output_html .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */\n",
              ".output_html .s2 { color: #BA2121 } /* Literal.String.Double */\n",
              ".output_html .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */\n",
              ".output_html .sh { color: #BA2121 } /* Literal.String.Heredoc */\n",
              ".output_html .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */\n",
              ".output_html .sx { color: #008000 } /* Literal.String.Other */\n",
              ".output_html .sr { color: #A45A77 } /* Literal.String.Regex */\n",
              ".output_html .s1 { color: #BA2121 } /* Literal.String.Single */\n",
              ".output_html .ss { color: #19177C } /* Literal.String.Symbol */\n",
              ".output_html .bp { color: #008000 } /* Name.Builtin.Pseudo */\n",
              ".output_html .fm { color: #0000FF } /* Name.Function.Magic */\n",
              ".output_html .vc { color: #19177C } /* Name.Variable.Class */\n",
              ".output_html .vg { color: #19177C } /* Name.Variable.Global */\n",
              ".output_html .vi { color: #19177C } /* Name.Variable.Instance */\n",
              ".output_html .vm { color: #19177C } /* Name.Variable.Magic */\n",
              ".output_html .il { color: #666666 } /* Literal.Number.Integer.Long */</style><div class=\"highlight\"><pre><span></span><span class=\"nd\">@tvm</span><span class=\"o\">.</span><span class=\"n\">script</span><span class=\"o\">.</span><span class=\"n\">ir_module</span>\n",
              "<span class=\"k\">class</span> <span class=\"nc\">Module</span><span class=\"p\">:</span>\n",
              "    <span class=\"nd\">@tir</span><span class=\"o\">.</span><span class=\"n\">prim_func</span>\n",
              "    <span class=\"k\">def</span> <span class=\"nf\">main</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">&quot;float32&quot;</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">&quot;float32&quot;</span><span class=\"p\">],</span> <span class=\"n\">C</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">&quot;float32&quot;</span><span class=\"p\">])</span> <span class=\"o\">-&gt;</span> <span class=\"kc\">None</span><span class=\"p\">:</span>\n",
              "        <span class=\"c1\"># function attr dict</span>\n",
              "        <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">func_attr</span><span class=\"p\">({</span><span class=\"s2\">&quot;global_symbol&quot;</span><span class=\"p\">:</span> <span class=\"s2\">&quot;main&quot;</span><span class=\"p\">,</span> <span class=\"s2\">&quot;tir.noalias&quot;</span><span class=\"p\">:</span> <span class=\"kc\">True</span><span class=\"p\">})</span>\n",
              "        <span class=\"c1\"># body</span>\n",
              "        <span class=\"c1\"># with tir.block(&quot;root&quot;)</span>\n",
              "        <span class=\"k\">for</span> <span class=\"n\">i</span> <span class=\"ow\">in</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">serial</span><span class=\"p\">(</span><span class=\"mi\">128</span><span class=\"p\">):</span>\n",
              "            <span class=\"k\">with</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">block</span><span class=\"p\">(</span><span class=\"s2\">&quot;C&quot;</span><span class=\"p\">):</span>\n",
              "                <span class=\"n\">vi</span> <span class=\"o\">=</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">axis</span><span class=\"o\">.</span><span class=\"n\">spatial</span><span class=\"p\">(</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"n\">i</span><span class=\"p\">)</span>\n",
              "                <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">reads</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">])</span>\n",
              "                <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">writes</span><span class=\"p\">(</span><span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">])</span>\n",
              "                <span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span> <span class=\"o\">=</span> <span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span> <span class=\"o\">+</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span>\n",
              "    \n",
              "</pre></div>\n"
            ],
            "text/latex": [
              "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n",
              "\\PY{n+nd}{@tvm}\\PY{o}{.}\\PY{n}{script}\\PY{o}{.}\\PY{n}{ir\\PYZus{}module}\n",
              "\\PY{k}{class} \\PY{n+nc}{Module}\\PY{p}{:}\n",
              "    \\PY{n+nd}{@tir}\\PY{o}{.}\\PY{n}{prim\\PYZus{}func}\n",
              "    \\PY{k}{def} \\PY{n+nf}{main}\\PY{p}{(}\\PY{n}{A}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{C}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{k+kc}{None}\\PY{p}{:}\n",
              "        \\PY{c+c1}{\\PYZsh{} function attr dict}\n",
              "        \\PY{n}{tir}\\PY{o}{.}\\PY{n}{func\\PYZus{}attr}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{global\\PYZus{}symbol}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{main}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{tir.noalias}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{k+kc}{True}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\n",
              "        \\PY{c+c1}{\\PYZsh{} body}\n",
              "        \\PY{c+c1}{\\PYZsh{} with tir.block(\\PYZdq{}root\\PYZdq{})}\n",
              "        \\PY{k}{for} \\PY{n}{i} \\PY{o+ow}{in} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{serial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{)}\\PY{p}{:}\n",
              "            \\PY{k}{with} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{block}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{C}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{:}\n",
              "                \\PY{n}{vi} \\PY{o}{=} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{axis}\\PY{o}{.}\\PY{n}{spatial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{n}{i}\\PY{p}{)}\n",
              "                \\PY{n}{tir}\\PY{o}{.}\\PY{n}{reads}\\PY{p}{(}\\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n",
              "                \\PY{n}{tir}\\PY{o}{.}\\PY{n}{writes}\\PY{p}{(}\\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n",
              "                \\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{=} \\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{+} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\n",
              "    \n",
              "\\end{Verbatim}\n"
            ],
            "text/plain": [
              "@tvm.script.ir_module\n",
              "class Module:\n",
              "    @tir.prim_func\n",
              "    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n",
              "        # function attr dict\n",
              "        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
              "        # body\n",
              "        # with tir.block(\"root\")\n",
              "        for i in tir.serial(128):\n",
              "            with tir.block(\"C\"):\n",
              "                vi = tir.axis.spatial(128, i)\n",
              "                tir.reads(A[vi], B[vi])\n",
              "                tir.writes(C[vi])\n",
              "                C[vi] = A[vi] + B[vi]\n",
              "    "
            ]
          },
          "execution_count": 5,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "Code(MyModule.script(), language=\"python\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tOMSOuW6aIJg"
      },
      "source": [
        "### Build and Run\n",
        "\n",
        "At any point, an IRModule can be turned into runnable functions by calling the `build` function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "oESoqN-xaTCf",
        "outputId": "58119fbd-b737-400d-bd94-776af0709501"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<class 'tvm.driver.build_module.OperatorModule'>\n"
          ]
        }
      ],
      "source": [
        "rt_mod = tvm.build(MyModule, target=\"llvm\")  # The module for CPU backends\n",
        "print(type(rt_mod))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y2ZfGrH1z6SV"
      },
      "source": [
        "After the build, `rt_mod` contains a collection of runnable functions. Each function can be retrieved by its name."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "5I3GqwnRz-Ne"
      },
      "outputs": [],
      "source": [
        "func = rt_mod[\"main\"]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bngdW1eVl683",
        "outputId": "d5e0437c-bf16-4107-aa41-e11a4e1865ce"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<tvm.runtime.packed_func.PackedFunc at 0x7fdd0e398980>"
            ]
          },
          "execution_count": 8,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "func"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "DKxo8uq_mNlp"
      },
      "outputs": [],
      "source": [
        "a = tvm.nd.array(np.arange(128, dtype=\"float32\"))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "1hAFAqv_mP8W"
      },
      "outputs": [],
      "source": [
        "b = tvm.nd.array(np.ones(128, dtype=\"float32\"))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "TseB1UBumivT"
      },
      "outputs": [],
      "source": [
        "c = tvm.nd.empty((128,), dtype=\"float32\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p68xZ0_P0MPw"
      },
      "source": [
        "To invoke the function, create three NDArrays in the TVM runtime and then call the generated function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "SMkcgO-L0Xr5"
      },
      "outputs": [],
      "source": [
        "func(a, b, c)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "AakkTpE50b6o",
        "outputId": "43971a60-2fbb-41bb-cc47-c496bfe5dda2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[  0.   1.   2.   3.   4.   5.   6.   7.   8.   9.  10.  11.  12.  13.\n",
            "  14.  15.  16.  17.  18.  19.  20.  21.  22.  23.  24.  25.  26.  27.\n",
            "  28.  29.  30.  31.  32.  33.  34.  35.  36.  37.  38.  39.  40.  41.\n",
            "  42.  43.  44.  45.  46.  47.  48.  49.  50.  51.  52.  53.  54.  55.\n",
            "  56.  57.  58.  59.  60.  61.  62.  63.  64.  65.  66.  67.  68.  69.\n",
            "  70.  71.  72.  73.  74.  75.  76.  77.  78.  79.  80.  81.  82.  83.\n",
            "  84.  85.  86.  87.  88.  89.  90.  91.  92.  93.  94.  95.  96.  97.\n",
            "  98.  99. 100. 101. 102. 103. 104. 105. 106. 107. 108. 109. 110. 111.\n",
            " 112. 113. 114. 115. 116. 117. 118. 119. 120. 121. 122. 123. 124. 125.\n",
            " 126. 127.]\n",
            "[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
            " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
            " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
            " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
            " 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.\n",
            " 1. 1. 1. 1. 1. 1. 1. 1.]\n",
            "[  1.   2.   3.   4.   5.   6.   7.   8.   9.  10.  11.  12.  13.  14.\n",
            "  15.  16.  17.  18.  19.  20.  21.  22.  23.  24.  25.  26.  27.  28.\n",
            "  29.  30.  31.  32.  33.  34.  35.  36.  37.  38.  39.  40.  41.  42.\n",
            "  43.  44.  45.  46.  47.  48.  49.  50.  51.  52.  53.  54.  55.  56.\n",
            "  57.  58.  59.  60.  61.  62.  63.  64.  65.  66.  67.  68.  69.  70.\n",
            "  71.  72.  73.  74.  75.  76.  77.  78.  79.  80.  81.  82.  83.  84.\n",
            "  85.  86.  87.  88.  89.  90.  91.  92.  93.  94.  95.  96.  97.  98.\n",
            "  99. 100. 101. 102. 103. 104. 105. 106. 107. 108. 109. 110. 111. 112.\n",
            " 113. 114. 115. 116. 117. 118. 119. 120. 121. 122. 123. 124. 125. 126.\n",
            " 127. 128.]\n"
          ]
        }
      ],
      "source": [
        "print(a)\n",
        "print(b)\n",
        "print(c)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i_MIDZCOcmwp"
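      },
      "source": [
        "## Transforming the Tensor Program\n",
        "\n",
        "Now let us start transforming the tensor program. A tensor program can be transformed using an auxiliary data structure called a schedule."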
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xwyjwh51cjWI",
        "outputId": "d8a8687a-a722-4899-c181-9ca90c7d841e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<class 'tvm.tir.schedule.schedule.Schedule'>\n"
          ]
        }
      ],
      "source": [
        "sch = tvm.tir.Schedule(MyModule)\n",
        "print(type(sch))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dw7Fgw8o8HPm"
      },
      "source": [
        "First, try splitting the loop:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "kNQf8D0ic4me",
        "outputId": "1cbfd7f9-a807-4571-b2e3-66ffe052037d"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<style>pre { line-height: 125%; }\n",
              "td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n",
              "span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n",
              "td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n",
              "span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n",
              ".output_html .hll { background-color: #ffffcc }\n",
              ".output_html { background: #f8f8f8; }\n",
              ".output_html .c { color: #3D7B7B; font-style: italic } /* Comment */\n",
              ".output_html .err { border: 1px solid #FF0000 } /* Error */\n",
              ".output_html .k { color: #008000; font-weight: bold } /* Keyword */\n",
              ".output_html .o { color: #666666 } /* Operator */\n",
              ".output_html .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */\n",
              ".output_html .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */\n",
              ".output_html .cp { color: #9C6500 } /* Comment.Preproc */\n",
              ".output_html .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */\n",
              ".output_html .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */\n",
              ".output_html .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */\n",
              ".output_html .gd { color: #A00000 } /* Generic.Deleted */\n",
              ".output_html .ge { font-style: italic } /* Generic.Emph */\n",
              ".output_html .gr { color: #E40000 } /* Generic.Error */\n",
              ".output_html .gh { color: #000080; font-weight: bold } /* Generic.Heading */\n",
              ".output_html .gi { color: #008400 } /* Generic.Inserted */\n",
              ".output_html .go { color: #717171 } /* Generic.Output */\n",
              ".output_html .gp { color: #000080; font-weight: bold } /* Generic.Prompt */\n",
              ".output_html .gs { font-weight: bold } /* Generic.Strong */\n",
              ".output_html .gu { color: #800080; font-weight: bold } /* Generic.Subheading */\n",
              ".output_html .gt { color: #0044DD } /* Generic.Traceback */\n",
              ".output_html .kc { color: #008000; font-weight: bold } /* Keyword.Constant */\n",
              ".output_html .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */\n",
              ".output_html .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */\n",
              ".output_html .kp { color: #008000 } /* Keyword.Pseudo */\n",
              ".output_html .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */\n",
              ".output_html .kt { color: #B00040 } /* Keyword.Type */\n",
              ".output_html .m { color: #666666 } /* Literal.Number */\n",
              ".output_html .s { color: #BA2121 } /* Literal.String */\n",
              ".output_html .na { color: #687822 } /* Name.Attribute */\n",
              ".output_html .nb { color: #008000 } /* Name.Builtin */\n",
              ".output_html .nc { color: #0000FF; font-weight: bold } /* Name.Class */\n",
              ".output_html .no { color: #880000 } /* Name.Constant */\n",
              ".output_html .nd { color: #AA22FF } /* Name.Decorator */\n",
              ".output_html .ni { color: #717171; font-weight: bold } /* Name.Entity */\n",
              ".output_html .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */\n",
              ".output_html .nf { color: #0000FF } /* Name.Function */\n",
              ".output_html .nl { color: #767600 } /* Name.Label */\n",
              ".output_html .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */\n",
              ".output_html .nt { color: #008000; font-weight: bold } /* Name.Tag */\n",
              ".output_html .nv { color: #19177C } /* Name.Variable */\n",
              ".output_html .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */\n",
              ".output_html .w { color: #bbbbbb } /* Text.Whitespace */\n",
              ".output_html .mb { color: #666666 } /* Literal.Number.Bin */\n",
              ".output_html .mf { color: #666666 } /* Literal.Number.Float */\n",
              ".output_html .mh { color: #666666 } /* Literal.Number.Hex */\n",
              ".output_html .mi { color: #666666 } /* Literal.Number.Integer */\n",
              ".output_html .mo { color: #666666 } /* Literal.Number.Oct */\n",
              ".output_html .sa { color: #BA2121 } /* Literal.String.Affix */\n",
              ".output_html .sb { color: #BA2121 } /* Literal.String.Backtick */\n",
              ".output_html .sc { color: #BA2121 } /* Literal.String.Char */\n",
              ".output_html .dl { color: #BA2121 } /* Literal.String.Delimiter */\n",
              ".output_html .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */\n",
              ".output_html .s2 { color: #BA2121 } /* Literal.String.Double */\n",
              ".output_html .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */\n",
              ".output_html .sh { color: #BA2121 } /* Literal.String.Heredoc */\n",
              ".output_html .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */\n",
              ".output_html .sx { color: #008000 } /* Literal.String.Other */\n",
              ".output_html .sr { color: #A45A77 } /* Literal.String.Regex */\n",
              ".output_html .s1 { color: #BA2121 } /* Literal.String.Single */\n",
              ".output_html .ss { color: #19177C } /* Literal.String.Symbol */\n",
              ".output_html .bp { color: #008000 } /* Name.Builtin.Pseudo */\n",
              ".output_html .fm { color: #0000FF } /* Name.Function.Magic */\n",
              ".output_html .vc { color: #19177C } /* Name.Variable.Class */\n",
              ".output_html .vg { color: #19177C } /* Name.Variable.Global */\n",
              ".output_html .vi { color: #19177C } /* Name.Variable.Instance */\n",
              ".output_html .vm { color: #19177C } /* Name.Variable.Magic */\n",
              ".output_html .il { color: #666666 } /* Literal.Number.Integer.Long */</style><div class=\"highlight\"><pre><span></span><span class=\"nd\">@tvm</span><span class=\"o\">.</span><span class=\"n\">script</span><span class=\"o\">.</span><span class=\"n\">ir_module</span>\n",
              "<span class=\"k\">class</span> <span class=\"nc\">Module</span><span class=\"p\">:</span>\n",
              "    <span class=\"nd\">@tir</span><span class=\"o\">.</span><span class=\"n\">prim_func</span>\n",
              "    <span class=\"k\">def</span> <span class=\"nf\">main</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">&quot;float32&quot;</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">&quot;float32&quot;</span><span class=\"p\">],</span> <span class=\"n\">C</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">&quot;float32&quot;</span><span class=\"p\">])</span> <span class=\"o\">-&gt;</span> <span class=\"kc\">None</span><span class=\"p\">:</span>\n",
              "        <span class=\"c1\"># function attr dict</span>\n",
              "        <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">func_attr</span><span class=\"p\">({</span><span class=\"s2\">&quot;global_symbol&quot;</span><span class=\"p\">:</span> <span class=\"s2\">&quot;main&quot;</span><span class=\"p\">,</span> <span class=\"s2\">&quot;tir.noalias&quot;</span><span class=\"p\">:</span> <span class=\"kc\">True</span><span class=\"p\">})</span>\n",
              "        <span class=\"c1\"># body</span>\n",
              "        <span class=\"c1\"># with tir.block(&quot;root&quot;)</span>\n",
              "        <span class=\"k\">for</span> <span class=\"n\">i_0</span><span class=\"p\">,</span> <span class=\"n\">i_1</span><span class=\"p\">,</span> <span class=\"n\">i_2</span> <span class=\"ow\">in</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">grid</span><span class=\"p\">(</span><span class=\"mi\">8</span><span class=\"p\">,</span> <span class=\"mi\">4</span><span class=\"p\">,</span> <span class=\"mi\">4</span><span class=\"p\">):</span>\n",
              "            <span class=\"k\">with</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">block</span><span class=\"p\">(</span><span class=\"s2\">&quot;C&quot;</span><span class=\"p\">):</span>\n",
              "                <span class=\"n\">vi</span> <span class=\"o\">=</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">axis</span><span class=\"o\">.</span><span class=\"n\">spatial</span><span class=\"p\">(</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"n\">i_0</span> <span class=\"o\">*</span> <span class=\"mi\">16</span> <span class=\"o\">+</span> <span class=\"n\">i_1</span> <span class=\"o\">*</span> <span class=\"mi\">4</span> <span class=\"o\">+</span> <span class=\"n\">i_2</span><span class=\"p\">)</span>\n",
              "                <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">reads</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">])</span>\n",
              "                <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">writes</span><span class=\"p\">(</span><span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">])</span>\n",
              "                <span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span> <span class=\"o\">=</span> <span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span> <span class=\"o\">+</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span>\n",
              "    \n",
              "</pre></div>\n"
            ],
            "text/latex": [
              "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n",
              "\\PY{n+nd}{@tvm}\\PY{o}{.}\\PY{n}{script}\\PY{o}{.}\\PY{n}{ir\\PYZus{}module}\n",
              "\\PY{k}{class} \\PY{n+nc}{Module}\\PY{p}{:}\n",
              "    \\PY{n+nd}{@tir}\\PY{o}{.}\\PY{n}{prim\\PYZus{}func}\n",
              "    \\PY{k}{def} \\PY{n+nf}{main}\\PY{p}{(}\\PY{n}{A}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{C}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{k+kc}{None}\\PY{p}{:}\n",
              "        \\PY{c+c1}{\\PYZsh{} function attr dict}\n",
              "        \\PY{n}{tir}\\PY{o}{.}\\PY{n}{func\\PYZus{}attr}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{global\\PYZus{}symbol}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{main}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{tir.noalias}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{k+kc}{True}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\n",
              "        \\PY{c+c1}{\\PYZsh{} body}\n",
              "        \\PY{c+c1}{\\PYZsh{} with tir.block(\\PYZdq{}root\\PYZdq{})}\n",
              "        \\PY{k}{for} \\PY{n}{i\\PYZus{}0}\\PY{p}{,} \\PY{n}{i\\PYZus{}1}\\PY{p}{,} \\PY{n}{i\\PYZus{}2} \\PY{o+ow}{in} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{grid}\\PY{p}{(}\\PY{l+m+mi}{8}\\PY{p}{,} \\PY{l+m+mi}{4}\\PY{p}{,} \\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{:}\n",
              "            \\PY{k}{with} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{block}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{C}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{:}\n",
              "                \\PY{n}{vi} \\PY{o}{=} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{axis}\\PY{o}{.}\\PY{n}{spatial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{n}{i\\PYZus{}0} \\PY{o}{*} \\PY{l+m+mi}{16} \\PY{o}{+} \\PY{n}{i\\PYZus{}1} \\PY{o}{*} \\PY{l+m+mi}{4} \\PY{o}{+} \\PY{n}{i\\PYZus{}2}\\PY{p}{)}\n",
              "                \\PY{n}{tir}\\PY{o}{.}\\PY{n}{reads}\\PY{p}{(}\\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n",
              "                \\PY{n}{tir}\\PY{o}{.}\\PY{n}{writes}\\PY{p}{(}\\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n",
              "                \\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{=} \\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{+} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\n",
              "    \n",
              "\\end{Verbatim}\n"
            ],
            "text/plain": [
              "@tvm.script.ir_module\n",
              "class Module:\n",
              "    @tir.prim_func\n",
              "    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n",
              "        # function attr dict\n",
              "        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
              "        # body\n",
              "        # with tir.block(\"root\")\n",
              "        for i_0, i_1, i_2 in tir.grid(8, 4, 4):\n",
              "            with tir.block(\"C\"):\n",
              "                vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n",
              "                tir.reads(A[vi], B[vi])\n",
              "                tir.writes(C[vi])\n",
              "                C[vi] = A[vi] + B[vi]\n",
              "    "
            ]
          },
          "execution_count": 15,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
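        "# Get the block by name\n",
        "block_c = sch.get_block(\"C\")\n",
        "# Get the loops surrounding the block\n",
        "(i,) = sch.get_loops(block_c)\n",
        "# Tile the loop nest: split i into three loops; the None factor is inferred (128 / (4 * 4) = 8)\n",
        "i_0, i_1, i_2 = sch.split(i, factors=[None, 4, 4])\n",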
        "Code(sch.mod.script(), language=\"python\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nzrbvqBSdC-D"
      },
      "source": [
        "We can also reorder the loops. Here we move loop `i_2` outside of `i_1`, so the nest order becomes `i_0, i_2, i_1`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yJWBq7lRdDmn",
        "outputId": "1957cf67-f51c-4af1-c5ca-16c9718181a0"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "@tvm.script.ir_module\n",
            "class Module:\n",
            "    @tir.prim_func\n",
            "    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n",
            "        # function attr dict\n",
            "        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
            "        # body\n",
            "        # with tir.block(\"root\")\n",
            "        for i_0, i_2, i_1 in tir.grid(8, 4, 4):\n",
            "            with tir.block(\"C\"):\n",
            "                vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n",
            "                tir.reads(A[vi], B[vi])\n",
            "                tir.writes(C[vi])\n",
            "                C[vi] = A[vi] + B[vi]\n",
            "    \n"
          ]
        }
      ],
      "source": [
        "sch.reorder(i_0, i_2, i_1)\n",
        "print(sch.mod.script())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UmUr6b_L07-8"
      },
      "source": [
        "Finally, we can add a hint to the program generator that we want to vectorize the innermost loop."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "u95zQFuldHs_",
        "outputId": "b2205442-ac70-405c-8f42-ddb23e46e012"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<style>pre { line-height: 125%; }\n",
              "td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n",
              "span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n",
              "td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n",
              "span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n",
              ".output_html .hll { background-color: #ffffcc }\n",
              ".output_html { background: #f8f8f8; }\n",
              ".output_html .c { color: #3D7B7B; font-style: italic } /* Comment */\n",
              ".output_html .err { border: 1px solid #FF0000 } /* Error */\n",
              ".output_html .k { color: #008000; font-weight: bold } /* Keyword */\n",
              ".output_html .o { color: #666666 } /* Operator */\n",
              ".output_html .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */\n",
              ".output_html .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */\n",
              ".output_html .cp { color: #9C6500 } /* Comment.Preproc */\n",
              ".output_html .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */\n",
              ".output_html .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */\n",
              ".output_html .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */\n",
              ".output_html .gd { color: #A00000 } /* Generic.Deleted */\n",
              ".output_html .ge { font-style: italic } /* Generic.Emph */\n",
              ".output_html .gr { color: #E40000 } /* Generic.Error */\n",
              ".output_html .gh { color: #000080; font-weight: bold } /* Generic.Heading */\n",
              ".output_html .gi { color: #008400 } /* Generic.Inserted */\n",
              ".output_html .go { color: #717171 } /* Generic.Output */\n",
              ".output_html .gp { color: #000080; font-weight: bold } /* Generic.Prompt */\n",
              ".output_html .gs { font-weight: bold } /* Generic.Strong */\n",
              ".output_html .gu { color: #800080; font-weight: bold } /* Generic.Subheading */\n",
              ".output_html .gt { color: #0044DD } /* Generic.Traceback */\n",
              ".output_html .kc { color: #008000; font-weight: bold } /* Keyword.Constant */\n",
              ".output_html .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */\n",
              ".output_html .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */\n",
              ".output_html .kp { color: #008000 } /* Keyword.Pseudo */\n",
              ".output_html .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */\n",
              ".output_html .kt { color: #B00040 } /* Keyword.Type */\n",
              ".output_html .m { color: #666666 } /* Literal.Number */\n",
              ".output_html .s { color: #BA2121 } /* Literal.String */\n",
              ".output_html .na { color: #687822 } /* Name.Attribute */\n",
              ".output_html .nb { color: #008000 } /* Name.Builtin */\n",
              ".output_html .nc { color: #0000FF; font-weight: bold } /* Name.Class */\n",
              ".output_html .no { color: #880000 } /* Name.Constant */\n",
              ".output_html .nd { color: #AA22FF } /* Name.Decorator */\n",
              ".output_html .ni { color: #717171; font-weight: bold } /* Name.Entity */\n",
              ".output_html .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */\n",
              ".output_html .nf { color: #0000FF } /* Name.Function */\n",
              ".output_html .nl { color: #767600 } /* Name.Label */\n",
              ".output_html .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */\n",
              ".output_html .nt { color: #008000; font-weight: bold } /* Name.Tag */\n",
              ".output_html .nv { color: #19177C } /* Name.Variable */\n",
              ".output_html .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */\n",
              ".output_html .w { color: #bbbbbb } /* Text.Whitespace */\n",
              ".output_html .mb { color: #666666 } /* Literal.Number.Bin */\n",
              ".output_html .mf { color: #666666 } /* Literal.Number.Float */\n",
              ".output_html .mh { color: #666666 } /* Literal.Number.Hex */\n",
              ".output_html .mi { color: #666666 } /* Literal.Number.Integer */\n",
              ".output_html .mo { color: #666666 } /* Literal.Number.Oct */\n",
              ".output_html .sa { color: #BA2121 } /* Literal.String.Affix */\n",
              ".output_html .sb { color: #BA2121 } /* Literal.String.Backtick */\n",
              ".output_html .sc { color: #BA2121 } /* Literal.String.Char */\n",
              ".output_html .dl { color: #BA2121 } /* Literal.String.Delimiter */\n",
              ".output_html .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */\n",
              ".output_html .s2 { color: #BA2121 } /* Literal.String.Double */\n",
              ".output_html .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */\n",
              ".output_html .sh { color: #BA2121 } /* Literal.String.Heredoc */\n",
              ".output_html .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */\n",
              ".output_html .sx { color: #008000 } /* Literal.String.Other */\n",
              ".output_html .sr { color: #A45A77 } /* Literal.String.Regex */\n",
              ".output_html .s1 { color: #BA2121 } /* Literal.String.Single */\n",
              ".output_html .ss { color: #19177C } /* Literal.String.Symbol */\n",
              ".output_html .bp { color: #008000 } /* Name.Builtin.Pseudo */\n",
              ".output_html .fm { color: #0000FF } /* Name.Function.Magic */\n",
              ".output_html .vc { color: #19177C } /* Name.Variable.Class */\n",
              ".output_html .vg { color: #19177C } /* Name.Variable.Global */\n",
              ".output_html .vi { color: #19177C } /* Name.Variable.Instance */\n",
              ".output_html .vm { color: #19177C } /* Name.Variable.Magic */\n",
              ".output_html .il { color: #666666 } /* Literal.Number.Integer.Long */</style><div class=\"highlight\"><pre><span></span><span class=\"nd\">@tvm</span><span class=\"o\">.</span><span class=\"n\">script</span><span class=\"o\">.</span><span class=\"n\">ir_module</span>\n",
              "<span class=\"k\">class</span> <span class=\"nc\">Module</span><span class=\"p\">:</span>\n",
              "    <span class=\"nd\">@tir</span><span class=\"o\">.</span><span class=\"n\">prim_func</span>\n",
              "    <span class=\"k\">def</span> <span class=\"nf\">main</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">&quot;float32&quot;</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">&quot;float32&quot;</span><span class=\"p\">],</span> <span class=\"n\">C</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">&quot;float32&quot;</span><span class=\"p\">])</span> <span class=\"o\">-&gt;</span> <span class=\"kc\">None</span><span class=\"p\">:</span>\n",
              "        <span class=\"c1\"># function attr dict</span>\n",
              "        <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">func_attr</span><span class=\"p\">({</span><span class=\"s2\">&quot;global_symbol&quot;</span><span class=\"p\">:</span> <span class=\"s2\">&quot;main&quot;</span><span class=\"p\">,</span> <span class=\"s2\">&quot;tir.noalias&quot;</span><span class=\"p\">:</span> <span class=\"kc\">True</span><span class=\"p\">})</span>\n",
              "        <span class=\"c1\"># body</span>\n",
              "        <span class=\"c1\"># with tir.block(&quot;root&quot;)</span>\n",
              "        <span class=\"k\">for</span> <span class=\"n\">i_0</span><span class=\"p\">,</span> <span class=\"n\">i_2</span><span class=\"p\">,</span> <span class=\"n\">i_1</span> <span class=\"ow\">in</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">grid</span><span class=\"p\">(</span><span class=\"mi\">8</span><span class=\"p\">,</span> <span class=\"mi\">4</span><span class=\"p\">,</span> <span class=\"mi\">4</span><span class=\"p\">):</span>\n",
              "            <span class=\"k\">with</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">block</span><span class=\"p\">(</span><span class=\"s2\">&quot;C&quot;</span><span class=\"p\">):</span>\n",
              "                <span class=\"n\">vi</span> <span class=\"o\">=</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">axis</span><span class=\"o\">.</span><span class=\"n\">spatial</span><span class=\"p\">(</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"n\">i_0</span> <span class=\"o\">*</span> <span class=\"mi\">16</span> <span class=\"o\">+</span> <span class=\"n\">i_1</span> <span class=\"o\">*</span> <span class=\"mi\">4</span> <span class=\"o\">+</span> <span class=\"n\">i_2</span><span class=\"p\">)</span>\n",
              "                <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">reads</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">])</span>\n",
              "                <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">writes</span><span class=\"p\">(</span><span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">])</span>\n",
              "                <span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span> <span class=\"o\">=</span> <span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span> <span class=\"o\">+</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span>\n",
              "    \n",
              "</pre></div>\n"
            ],
            "text/latex": [
              "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n",
              "\\PY{n+nd}{@tvm}\\PY{o}{.}\\PY{n}{script}\\PY{o}{.}\\PY{n}{ir\\PYZus{}module}\n",
              "\\PY{k}{class} \\PY{n+nc}{Module}\\PY{p}{:}\n",
              "    \\PY{n+nd}{@tir}\\PY{o}{.}\\PY{n}{prim\\PYZus{}func}\n",
              "    \\PY{k}{def} \\PY{n+nf}{main}\\PY{p}{(}\\PY{n}{A}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{C}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{k+kc}{None}\\PY{p}{:}\n",
              "        \\PY{c+c1}{\\PYZsh{} function attr dict}\n",
              "        \\PY{n}{tir}\\PY{o}{.}\\PY{n}{func\\PYZus{}attr}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{global\\PYZus{}symbol}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{main}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{tir.noalias}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{k+kc}{True}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\n",
              "        \\PY{c+c1}{\\PYZsh{} body}\n",
              "        \\PY{c+c1}{\\PYZsh{} with tir.block(\\PYZdq{}root\\PYZdq{})}\n",
              "        \\PY{k}{for} \\PY{n}{i\\PYZus{}0}\\PY{p}{,} \\PY{n}{i\\PYZus{}2}\\PY{p}{,} \\PY{n}{i\\PYZus{}1} \\PY{o+ow}{in} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{grid}\\PY{p}{(}\\PY{l+m+mi}{8}\\PY{p}{,} \\PY{l+m+mi}{4}\\PY{p}{,} \\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{:}\n",
              "            \\PY{k}{with} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{block}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{C}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{:}\n",
              "                \\PY{n}{vi} \\PY{o}{=} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{axis}\\PY{o}{.}\\PY{n}{spatial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{n}{i\\PYZus{}0} \\PY{o}{*} \\PY{l+m+mi}{16} \\PY{o}{+} \\PY{n}{i\\PYZus{}1} \\PY{o}{*} \\PY{l+m+mi}{4} \\PY{o}{+} \\PY{n}{i\\PYZus{}2}\\PY{p}{)}\n",
              "                \\PY{n}{tir}\\PY{o}{.}\\PY{n}{reads}\\PY{p}{(}\\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n",
              "                \\PY{n}{tir}\\PY{o}{.}\\PY{n}{writes}\\PY{p}{(}\\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n",
              "                \\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{=} \\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{+} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\n",
              "    \n",
              "\\end{Verbatim}\n"
            ],
            "text/plain": [
              "@tvm.script.ir_module\n",
              "class Module:\n",
              "    @tir.prim_func\n",
              "    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n",
              "        # function attr dict\n",
              "        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
              "        # body\n",
              "        # with tir.block(\"root\")\n",
              "        for i_0, i_2, i_1 in tir.grid(8, 4, 4):\n",
              "            with tir.block(\"C\"):\n",
              "                vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n",
              "                tir.reads(A[vi], B[vi])\n",
              "                tir.writes(C[vi])\n",
              "                C[vi] = A[vi] + B[vi]\n",
              "    "
            ]
          },
          "execution_count": 17,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "Code(sch.mod.script(), language=\"python\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "m7NFPx9fy5Wy"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<style>pre { line-height: 125%; }\n",
              "td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n",
              "span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n",
              "td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n",
              "span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n",
              ".output_html .hll { background-color: #ffffcc }\n",
              ".output_html { background: #f8f8f8; }\n",
              ".output_html .c { color: #3D7B7B; font-style: italic } /* Comment */\n",
              ".output_html .err { border: 1px solid #FF0000 } /* Error */\n",
              ".output_html .k { color: #008000; font-weight: bold } /* Keyword */\n",
              ".output_html .o { color: #666666 } /* Operator */\n",
              ".output_html .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */\n",
              ".output_html .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */\n",
              ".output_html .cp { color: #9C6500 } /* Comment.Preproc */\n",
              ".output_html .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */\n",
              ".output_html .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */\n",
              ".output_html .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */\n",
              ".output_html .gd { color: #A00000 } /* Generic.Deleted */\n",
              ".output_html .ge { font-style: italic } /* Generic.Emph */\n",
              ".output_html .gr { color: #E40000 } /* Generic.Error */\n",
              ".output_html .gh { color: #000080; font-weight: bold } /* Generic.Heading */\n",
              ".output_html .gi { color: #008400 } /* Generic.Inserted */\n",
              ".output_html .go { color: #717171 } /* Generic.Output */\n",
              ".output_html .gp { color: #000080; font-weight: bold } /* Generic.Prompt */\n",
              ".output_html .gs { font-weight: bold } /* Generic.Strong */\n",
              ".output_html .gu { color: #800080; font-weight: bold } /* Generic.Subheading */\n",
              ".output_html .gt { color: #0044DD } /* Generic.Traceback */\n",
              ".output_html .kc { color: #008000; font-weight: bold } /* Keyword.Constant */\n",
              ".output_html .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */\n",
              ".output_html .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */\n",
              ".output_html .kp { color: #008000 } /* Keyword.Pseudo */\n",
              ".output_html .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */\n",
              ".output_html .kt { color: #B00040 } /* Keyword.Type */\n",
              ".output_html .m { color: #666666 } /* Literal.Number */\n",
              ".output_html .s { color: #BA2121 } /* Literal.String */\n",
              ".output_html .na { color: #687822 } /* Name.Attribute */\n",
              ".output_html .nb { color: #008000 } /* Name.Builtin */\n",
              ".output_html .nc { color: #0000FF; font-weight: bold } /* Name.Class */\n",
              ".output_html .no { color: #880000 } /* Name.Constant */\n",
              ".output_html .nd { color: #AA22FF } /* Name.Decorator */\n",
              ".output_html .ni { color: #717171; font-weight: bold } /* Name.Entity */\n",
              ".output_html .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */\n",
              ".output_html .nf { color: #0000FF } /* Name.Function */\n",
              ".output_html .nl { color: #767600 } /* Name.Label */\n",
              ".output_html .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */\n",
              ".output_html .nt { color: #008000; font-weight: bold } /* Name.Tag */\n",
              ".output_html .nv { color: #19177C } /* Name.Variable */\n",
              ".output_html .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */\n",
              ".output_html .w { color: #bbbbbb } /* Text.Whitespace */\n",
              ".output_html .mb { color: #666666 } /* Literal.Number.Bin */\n",
              ".output_html .mf { color: #666666 } /* Literal.Number.Float */\n",
              ".output_html .mh { color: #666666 } /* Literal.Number.Hex */\n",
              ".output_html .mi { color: #666666 } /* Literal.Number.Integer */\n",
              ".output_html .mo { color: #666666 } /* Literal.Number.Oct */\n",
              ".output_html .sa { color: #BA2121 } /* Literal.String.Affix */\n",
              ".output_html .sb { color: #BA2121 } /* Literal.String.Backtick */\n",
              ".output_html .sc { color: #BA2121 } /* Literal.String.Char */\n",
              ".output_html .dl { color: #BA2121 } /* Literal.String.Delimiter */\n",
              ".output_html .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */\n",
              ".output_html .s2 { color: #BA2121 } /* Literal.String.Double */\n",
              ".output_html .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */\n",
              ".output_html .sh { color: #BA2121 } /* Literal.String.Heredoc */\n",
              ".output_html .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */\n",
              ".output_html .sx { color: #008000 } /* Literal.String.Other */\n",
              ".output_html .sr { color: #A45A77 } /* Literal.String.Regex */\n",
              ".output_html .s1 { color: #BA2121 } /* Literal.String.Single */\n",
              ".output_html .ss { color: #19177C } /* Literal.String.Symbol */\n",
              ".output_html .bp { color: #008000 } /* Name.Builtin.Pseudo */\n",
              ".output_html .fm { color: #0000FF } /* Name.Function.Magic */\n",
              ".output_html .vc { color: #19177C } /* Name.Variable.Class */\n",
              ".output_html .vg { color: #19177C } /* Name.Variable.Global */\n",
              ".output_html .vi { color: #19177C } /* Name.Variable.Instance */\n",
              ".output_html .vm { color: #19177C } /* Name.Variable.Magic */\n",
              ".output_html .il { color: #666666 } /* Literal.Number.Integer.Long */</style><div class=\"highlight\"><pre><span></span><span class=\"nd\">@tvm</span><span class=\"o\">.</span><span class=\"n\">script</span><span class=\"o\">.</span><span class=\"n\">ir_module</span>\n",
              "<span class=\"k\">class</span> <span class=\"nc\">Module</span><span class=\"p\">:</span>\n",
              "    <span class=\"nd\">@tir</span><span class=\"o\">.</span><span class=\"n\">prim_func</span>\n",
              "    <span class=\"k\">def</span> <span class=\"nf\">main</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">&quot;float32&quot;</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">&quot;float32&quot;</span><span class=\"p\">],</span> <span class=\"n\">C</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">&quot;float32&quot;</span><span class=\"p\">])</span> <span class=\"o\">-&gt;</span> <span class=\"kc\">None</span><span class=\"p\">:</span>\n",
              "        <span class=\"c1\"># function attr dict</span>\n",
              "        <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">func_attr</span><span class=\"p\">({</span><span class=\"s2\">&quot;global_symbol&quot;</span><span class=\"p\">:</span> <span class=\"s2\">&quot;main&quot;</span><span class=\"p\">,</span> <span class=\"s2\">&quot;tir.noalias&quot;</span><span class=\"p\">:</span> <span class=\"kc\">True</span><span class=\"p\">})</span>\n",
              "        <span class=\"c1\"># body</span>\n",
              "        <span class=\"c1\"># with tir.block(&quot;root&quot;)</span>\n",
              "        <span class=\"k\">for</span> <span class=\"n\">i_0</span> <span class=\"ow\">in</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">parallel</span><span class=\"p\">(</span><span class=\"mi\">8</span><span class=\"p\">):</span>\n",
              "            <span class=\"k\">for</span> <span class=\"n\">i_2</span><span class=\"p\">,</span> <span class=\"n\">i_1</span> <span class=\"ow\">in</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">grid</span><span class=\"p\">(</span><span class=\"mi\">4</span><span class=\"p\">,</span> <span class=\"mi\">4</span><span class=\"p\">):</span>\n",
              "                <span class=\"k\">with</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">block</span><span class=\"p\">(</span><span class=\"s2\">&quot;C&quot;</span><span class=\"p\">):</span>\n",
              "                    <span class=\"n\">vi</span> <span class=\"o\">=</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">axis</span><span class=\"o\">.</span><span class=\"n\">spatial</span><span class=\"p\">(</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"n\">i_0</span> <span class=\"o\">*</span> <span class=\"mi\">16</span> <span class=\"o\">+</span> <span class=\"n\">i_1</span> <span class=\"o\">*</span> <span class=\"mi\">4</span> <span class=\"o\">+</span> <span class=\"n\">i_2</span><span class=\"p\">)</span>\n",
              "                    <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">reads</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">])</span>\n",
              "                    <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">writes</span><span class=\"p\">(</span><span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">])</span>\n",
              "                    <span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span> <span class=\"o\">=</span> <span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span> <span class=\"o\">+</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">vi</span><span class=\"p\">]</span>\n",
              "    \n",
              "</pre></div>\n"
            ],
            "text/latex": [
              "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n",
              "\\PY{n+nd}{@tvm}\\PY{o}{.}\\PY{n}{script}\\PY{o}{.}\\PY{n}{ir\\PYZus{}module}\n",
              "\\PY{k}{class} \\PY{n+nc}{Module}\\PY{p}{:}\n",
              "    \\PY{n+nd}{@tir}\\PY{o}{.}\\PY{n}{prim\\PYZus{}func}\n",
              "    \\PY{k}{def} \\PY{n+nf}{main}\\PY{p}{(}\\PY{n}{A}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{C}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{k+kc}{None}\\PY{p}{:}\n",
              "        \\PY{c+c1}{\\PYZsh{} function attr dict}\n",
              "        \\PY{n}{tir}\\PY{o}{.}\\PY{n}{func\\PYZus{}attr}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{global\\PYZus{}symbol}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{main}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{tir.noalias}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{k+kc}{True}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\n",
              "        \\PY{c+c1}{\\PYZsh{} body}\n",
              "        \\PY{c+c1}{\\PYZsh{} with tir.block(\\PYZdq{}root\\PYZdq{})}\n",
              "        \\PY{k}{for} \\PY{n}{i\\PYZus{}0} \\PY{o+ow}{in} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{parallel}\\PY{p}{(}\\PY{l+m+mi}{8}\\PY{p}{)}\\PY{p}{:}\n",
              "            \\PY{k}{for} \\PY{n}{i\\PYZus{}2}\\PY{p}{,} \\PY{n}{i\\PYZus{}1} \\PY{o+ow}{in} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{grid}\\PY{p}{(}\\PY{l+m+mi}{4}\\PY{p}{,} \\PY{l+m+mi}{4}\\PY{p}{)}\\PY{p}{:}\n",
              "                \\PY{k}{with} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{block}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{C}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{:}\n",
              "                    \\PY{n}{vi} \\PY{o}{=} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{axis}\\PY{o}{.}\\PY{n}{spatial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{n}{i\\PYZus{}0} \\PY{o}{*} \\PY{l+m+mi}{16} \\PY{o}{+} \\PY{n}{i\\PYZus{}1} \\PY{o}{*} \\PY{l+m+mi}{4} \\PY{o}{+} \\PY{n}{i\\PYZus{}2}\\PY{p}{)}\n",
              "                    \\PY{n}{tir}\\PY{o}{.}\\PY{n}{reads}\\PY{p}{(}\\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n",
              "                    \\PY{n}{tir}\\PY{o}{.}\\PY{n}{writes}\\PY{p}{(}\\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\\PY{p}{)}\n",
              "                    \\PY{n}{C}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{=} \\PY{n}{A}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]} \\PY{o}{+} \\PY{n}{B}\\PY{p}{[}\\PY{n}{vi}\\PY{p}{]}\n",
              "    \n",
              "\\end{Verbatim}\n"
            ],
            "text/plain": [
              "@tvm.script.ir_module\n",
              "class Module:\n",
              "    @tir.prim_func\n",
              "    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n",
              "        # function attr dict\n",
              "        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
              "        # body\n",
              "        # with tir.block(\"root\")\n",
              "        for i_0 in tir.parallel(8):\n",
              "            for i_2, i_1 in tir.grid(4, 4):\n",
              "                with tir.block(\"C\"):\n",
              "                    vi = tir.axis.spatial(128, i_0 * 16 + i_1 * 4 + i_2)\n",
              "                    tir.reads(A[vi], B[vi])\n",
              "                    tir.writes(C[vi])\n",
              "                    C[vi] = A[vi] + B[vi]\n",
              "    "
            ]
          },
          "execution_count": 18,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "sch.parallel(i_0)\n",
        "Code(sch.mod.script(), language=\"python\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OhGlqLTG_tNv"
      },
      "source": [
        "可以构建并运行变换后的程序"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "sCIYSDrI_wGq"
      },
      "outputs": [],
      "source": [
        "transformed_mod = tvm.build(sch.mod, target=\"llvm\")  # The module for CPU backends.\n",
        "transformed_mod[\"main\"](a, b, c)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wj_01P4mAfu2"
      },
      "source": [
        "## 使用张量表达式构建张量程序\n",
        "\n",
        "在前面的例子中,直接使用 TVMScript 来构造张量程序。在实践中,从现有定义实用地构造这些函数通常是有帮助的。张量表达式是一个 API,它可以帮助构建一些类似表达式的数组计算。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "TZAPcqbGAesY",
        "outputId": "ac4d6407-8dea-4c75-d166-e071ffee8783"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<style>pre { line-height: 125%; }\n",
              "td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n",
              "span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }\n",
              "td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n",
              "span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }\n",
              ".output_html .hll { background-color: #ffffcc }\n",
              ".output_html { background: #f8f8f8; }\n",
              ".output_html .c { color: #3D7B7B; font-style: italic } /* Comment */\n",
              ".output_html .err { border: 1px solid #FF0000 } /* Error */\n",
              ".output_html .k { color: #008000; font-weight: bold } /* Keyword */\n",
              ".output_html .o { color: #666666 } /* Operator */\n",
              ".output_html .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */\n",
              ".output_html .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */\n",
              ".output_html .cp { color: #9C6500 } /* Comment.Preproc */\n",
              ".output_html .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */\n",
              ".output_html .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */\n",
              ".output_html .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */\n",
              ".output_html .gd { color: #A00000 } /* Generic.Deleted */\n",
              ".output_html .ge { font-style: italic } /* Generic.Emph */\n",
              ".output_html .gr { color: #E40000 } /* Generic.Error */\n",
              ".output_html .gh { color: #000080; font-weight: bold } /* Generic.Heading */\n",
              ".output_html .gi { color: #008400 } /* Generic.Inserted */\n",
              ".output_html .go { color: #717171 } /* Generic.Output */\n",
              ".output_html .gp { color: #000080; font-weight: bold } /* Generic.Prompt */\n",
              ".output_html .gs { font-weight: bold } /* Generic.Strong */\n",
              ".output_html .gu { color: #800080; font-weight: bold } /* Generic.Subheading */\n",
              ".output_html .gt { color: #0044DD } /* Generic.Traceback */\n",
              ".output_html .kc { color: #008000; font-weight: bold } /* Keyword.Constant */\n",
              ".output_html .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */\n",
              ".output_html .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */\n",
              ".output_html .kp { color: #008000 } /* Keyword.Pseudo */\n",
              ".output_html .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */\n",
              ".output_html .kt { color: #B00040 } /* Keyword.Type */\n",
              ".output_html .m { color: #666666 } /* Literal.Number */\n",
              ".output_html .s { color: #BA2121 } /* Literal.String */\n",
              ".output_html .na { color: #687822 } /* Name.Attribute */\n",
              ".output_html .nb { color: #008000 } /* Name.Builtin */\n",
              ".output_html .nc { color: #0000FF; font-weight: bold } /* Name.Class */\n",
              ".output_html .no { color: #880000 } /* Name.Constant */\n",
              ".output_html .nd { color: #AA22FF } /* Name.Decorator */\n",
              ".output_html .ni { color: #717171; font-weight: bold } /* Name.Entity */\n",
              ".output_html .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */\n",
              ".output_html .nf { color: #0000FF } /* Name.Function */\n",
              ".output_html .nl { color: #767600 } /* Name.Label */\n",
              ".output_html .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */\n",
              ".output_html .nt { color: #008000; font-weight: bold } /* Name.Tag */\n",
              ".output_html .nv { color: #19177C } /* Name.Variable */\n",
              ".output_html .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */\n",
              ".output_html .w { color: #bbbbbb } /* Text.Whitespace */\n",
              ".output_html .mb { color: #666666 } /* Literal.Number.Bin */\n",
              ".output_html .mf { color: #666666 } /* Literal.Number.Float */\n",
              ".output_html .mh { color: #666666 } /* Literal.Number.Hex */\n",
              ".output_html .mi { color: #666666 } /* Literal.Number.Integer */\n",
              ".output_html .mo { color: #666666 } /* Literal.Number.Oct */\n",
              ".output_html .sa { color: #BA2121 } /* Literal.String.Affix */\n",
              ".output_html .sb { color: #BA2121 } /* Literal.String.Backtick */\n",
              ".output_html .sc { color: #BA2121 } /* Literal.String.Char */\n",
              ".output_html .dl { color: #BA2121 } /* Literal.String.Delimiter */\n",
              ".output_html .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */\n",
              ".output_html .s2 { color: #BA2121 } /* Literal.String.Double */\n",
              ".output_html .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */\n",
              ".output_html .sh { color: #BA2121 } /* Literal.String.Heredoc */\n",
              ".output_html .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */\n",
              ".output_html .sx { color: #008000 } /* Literal.String.Other */\n",
              ".output_html .sr { color: #A45A77 } /* Literal.String.Regex */\n",
              ".output_html .s1 { color: #BA2121 } /* Literal.String.Single */\n",
              ".output_html .ss { color: #19177C } /* Literal.String.Symbol */\n",
              ".output_html .bp { color: #008000 } /* Name.Builtin.Pseudo */\n",
              ".output_html .fm { color: #0000FF } /* Name.Function.Magic */\n",
              ".output_html .vc { color: #19177C } /* Name.Variable.Class */\n",
              ".output_html .vg { color: #19177C } /* Name.Variable.Global */\n",
              ".output_html .vi { color: #19177C } /* Name.Variable.Instance */\n",
              ".output_html .vm { color: #19177C } /* Name.Variable.Magic */\n",
              ".output_html .il { color: #666666 } /* Literal.Number.Integer.Long */</style><div class=\"highlight\"><pre><span></span><span class=\"nd\">@tvm</span><span class=\"o\">.</span><span class=\"n\">script</span><span class=\"o\">.</span><span class=\"n\">ir_module</span>\n",
              "<span class=\"k\">class</span> <span class=\"nc\">Module</span><span class=\"p\">:</span>\n",
              "    <span class=\"nd\">@tir</span><span class=\"o\">.</span><span class=\"n\">prim_func</span>\n",
              "    <span class=\"k\">def</span> <span class=\"nf\">main</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">&quot;float32&quot;</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">&quot;float32&quot;</span><span class=\"p\">],</span> <span class=\"n\">C</span><span class=\"p\">:</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">Buffer</span><span class=\"p\">[</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"s2\">&quot;float32&quot;</span><span class=\"p\">])</span> <span class=\"o\">-&gt;</span> <span class=\"kc\">None</span><span class=\"p\">:</span>\n",
              "        <span class=\"c1\"># function attr dict</span>\n",
              "        <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">func_attr</span><span class=\"p\">({</span><span class=\"s2\">&quot;global_symbol&quot;</span><span class=\"p\">:</span> <span class=\"s2\">&quot;main&quot;</span><span class=\"p\">,</span> <span class=\"s2\">&quot;tir.noalias&quot;</span><span class=\"p\">:</span> <span class=\"kc\">True</span><span class=\"p\">})</span>\n",
              "        <span class=\"c1\"># body</span>\n",
              "        <span class=\"c1\"># with tir.block(&quot;root&quot;)</span>\n",
              "        <span class=\"k\">for</span> <span class=\"n\">i0</span> <span class=\"ow\">in</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">serial</span><span class=\"p\">(</span><span class=\"mi\">128</span><span class=\"p\">):</span>\n",
              "            <span class=\"k\">with</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">block</span><span class=\"p\">(</span><span class=\"s2\">&quot;C&quot;</span><span class=\"p\">):</span>\n",
              "                <span class=\"n\">i</span> <span class=\"o\">=</span> <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">axis</span><span class=\"o\">.</span><span class=\"n\">spatial</span><span class=\"p\">(</span><span class=\"mi\">128</span><span class=\"p\">,</span> <span class=\"n\">i0</span><span class=\"p\">)</span>\n",
              "                <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">reads</span><span class=\"p\">(</span><span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">i</span><span class=\"p\">],</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">i</span><span class=\"p\">])</span>\n",
              "                <span class=\"n\">tir</span><span class=\"o\">.</span><span class=\"n\">writes</span><span class=\"p\">(</span><span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">i</span><span class=\"p\">])</span>\n",
              "                <span class=\"n\">C</span><span class=\"p\">[</span><span class=\"n\">i</span><span class=\"p\">]</span> <span class=\"o\">=</span> <span class=\"n\">A</span><span class=\"p\">[</span><span class=\"n\">i</span><span class=\"p\">]</span> <span class=\"o\">+</span> <span class=\"n\">B</span><span class=\"p\">[</span><span class=\"n\">i</span><span class=\"p\">]</span>\n",
              "    \n",
              "</pre></div>\n"
            ],
            "text/latex": [
              "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n",
              "\\PY{n+nd}{@tvm}\\PY{o}{.}\\PY{n}{script}\\PY{o}{.}\\PY{n}{ir\\PYZus{}module}\n",
              "\\PY{k}{class} \\PY{n+nc}{Module}\\PY{p}{:}\n",
              "    \\PY{n+nd}{@tir}\\PY{o}{.}\\PY{n}{prim\\PYZus{}func}\n",
              "    \\PY{k}{def} \\PY{n+nf}{main}\\PY{p}{(}\\PY{n}{A}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{,} \\PY{n}{C}\\PY{p}{:} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{Buffer}\\PY{p}{[}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{float32}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{]}\\PY{p}{)} \\PY{o}{\\PYZhy{}}\\PY{o}{\\PYZgt{}} \\PY{k+kc}{None}\\PY{p}{:}\n",
              "        \\PY{c+c1}{\\PYZsh{} function attr dict}\n",
              "        \\PY{n}{tir}\\PY{o}{.}\\PY{n}{func\\PYZus{}attr}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{global\\PYZus{}symbol}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{main}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{,} \\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{tir.noalias}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{:} \\PY{k+kc}{True}\\PY{p}{\\PYZcb{}}\\PY{p}{)}\n",
              "        \\PY{c+c1}{\\PYZsh{} body}\n",
              "        \\PY{c+c1}{\\PYZsh{} with tir.block(\\PYZdq{}root\\PYZdq{})}\n",
              "        \\PY{k}{for} \\PY{n}{i0} \\PY{o+ow}{in} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{serial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{)}\\PY{p}{:}\n",
              "            \\PY{k}{with} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{block}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{C}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{:}\n",
              "                \\PY{n}{i} \\PY{o}{=} \\PY{n}{tir}\\PY{o}{.}\\PY{n}{axis}\\PY{o}{.}\\PY{n}{spatial}\\PY{p}{(}\\PY{l+m+mi}{128}\\PY{p}{,} \\PY{n}{i0}\\PY{p}{)}\n",
              "                \\PY{n}{tir}\\PY{o}{.}\\PY{n}{reads}\\PY{p}{(}\\PY{n}{A}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]}\\PY{p}{,} \\PY{n}{B}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]}\\PY{p}{)}\n",
              "                \\PY{n}{tir}\\PY{o}{.}\\PY{n}{writes}\\PY{p}{(}\\PY{n}{C}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]}\\PY{p}{)}\n",
              "                \\PY{n}{C}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]} \\PY{o}{=} \\PY{n}{A}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]} \\PY{o}{+} \\PY{n}{B}\\PY{p}{[}\\PY{n}{i}\\PY{p}{]}\n",
              "    \n",
              "\\end{Verbatim}\n"
            ],
            "text/plain": [
              "@tvm.script.ir_module\n",
              "class Module:\n",
              "    @tir.prim_func\n",
              "    def main(A: tir.Buffer[128, \"float32\"], B: tir.Buffer[128, \"float32\"], C: tir.Buffer[128, \"float32\"]) -> None:\n",
              "        # function attr dict\n",
              "        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
              "        # body\n",
              "        # with tir.block(\"root\")\n",
              "        for i0 in tir.serial(128):\n",
              "            with tir.block(\"C\"):\n",
              "                i = tir.axis.spatial(128, i0)\n",
              "                tir.reads(A[i], B[i])\n",
              "                tir.writes(C[i])\n",
              "                C[i] = A[i] + B[i]\n",
              "    "
            ]
          },
          "execution_count": 20,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# namespace for tensor expression utility\n",
        "from tvm import te\n",
        "\n",
        "# declare the computation using the expression API\n",
        "A = te.placeholder((128, ), name=\"A\")\n",
        "B = te.placeholder((128, ), name=\"B\")\n",
        "C = te.compute((128,), lambda i: A[i] + B[i], name=\"C\")\n",
        "\n",
        "# create a function with the specified list of arguments. \n",
        "func = te.create_prim_func([A, B, C])\n",
        "# mark that the function name is main\n",
        "func = func.with_attr(\"global_symbol\", \"main\")\n",
        "ir_mod_from_te = IRModule({\"main\": func})\n",
        "\n",
        "Code(ir_mod_from_te.script(), language=\"python\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GEqpO14Lf0Lq"
      },
      "source": [
        "## 变换矩阵乘法程序\n",
        "\n",
        "在上面的例子中,展示了如何变换向量加法。现在试着把它应用到稍微复杂一点的程序中(矩阵乘法)。首先尝试使用张量表达式 API 构建初始代码。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "R9ExHE3BfYYv",
        "outputId": "0d82d527-8051-4bc3-c3a7-0cbb12d9bcce"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "@tvm.script.ir_module\n",
            "class Module:\n",
            "    @tir.prim_func\n",
            "    def main(A: tir.Buffer[(1024, 1024), \"float32\"], B: tir.Buffer[(1024, 1024), \"float32\"], C: tir.Buffer[(1024, 1024), \"float32\"]) -> None:\n",
            "        # function attr dict\n",
            "        tir.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
            "        # body\n",
            "        # with tir.block(\"root\")\n",
            "        for i0, i1, i2 in tir.grid(1024, 1024, 1024):\n",
            "            with tir.block(\"C\"):\n",
            "                m, n, k = tir.axis.remap(\"SSR\", [i0, i1, i2])\n",
            "                tir.reads(A[m, k], B[k, n])\n",
            "                tir.writes(C[m, n])\n",
            "                with tir.init():\n",
            "                    C[m, n] = tir.float32(0)\n",
            "                C[m, n] = C[m, n] + A[m, k] * B[k, n]\n",
            "    \n",
            "Baseline: 2.238068\n"
          ]
        }
      ],
      "source": [
        "from tvm import te\n",
        "\n",
        "M = 1024\n",
        "K = 1024\n",
        "N = 1024\n",
        "\n",
        "# The default tensor type in tvm\n",
        "dtype = \"float32\"\n",
        "\n",
        "target = \"llvm\"\n",
        "dev = tvm.device(target, 0)\n",
        "\n",
        "# Algorithm\n",
        "k = te.reduce_axis((0, K), \"k\")\n",
        "A = te.placeholder((M, K), name=\"A\")\n",
        "B = te.placeholder((K, N), name=\"B\")\n",
        "C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), name=\"C\")\n",
        "\n",
        "# Default schedule\n",
        "func = te.create_prim_func([A, B, C])\n",
        "func = func.with_attr(\"global_symbol\", \"main\")\n",
        "ir_module = IRModule({\"main\": func})\n",
        "print(Code(ir_module.script(), language=\"python\"))\n",
        "\n",
        "\n",
        "func = tvm.build(ir_module, target=\"llvm\")  # The module for CPU backends.\n",
        "\n",
        "a = tvm.nd.array(np.random.rand(M, K).astype(dtype), dev)\n",
        "b = tvm.nd.array(np.random.rand(K, N).astype(dtype), dev)\n",
        "c = tvm.nd.array(np.zeros((M, N), dtype=dtype), dev)\n",
        "func(a, b, c)\n",
        "\n",
        "evaluator = func.time_evaluator(func.entry_name, dev, number=1)\n",
        "print(\"Baseline: %f\" % evaluator(a, b, c).mean)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "swj-gMz-1vBE"
      },
      "source": [
        "可以变换循环访问模式,使其对缓存更友好。让我们使用下面的调度。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "W60q68KRgdNL",
        "outputId": "b49a101e-5148-4cf0-df88-0e112e741381"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<class 'tvm.tir.schedule.schedule.Schedule'>\n",
            "after transformation: 0.239306\n"
          ]
        }
      ],
      "source": [
        "sch = tvm.tir.Schedule(ir_module)\n",
        "print(type(sch))\n",
        "block_c = sch.get_block(\"C\")\n",
        "# Get loops surronding the block\n",
        "(y, x, k) = sch.get_loops(block_c)\n",
        "block_size = 32\n",
        "yo, yi = sch.split(y, [None, block_size])\n",
        "xo, xi = sch.split(x, [None, block_size])\n",
        "\n",
        "sch.reorder(yo, xo, k, yi, xi)\n",
        "Code(sch.mod.script(), language=\"python\")\n",
        "\n",
        "func = tvm.build(sch.mod, target=\"llvm\")  # The module for CPU backends.\n",
        "\n",
        "c = tvm.nd.array(np.zeros((M, N), dtype=dtype), dev)\n",
        "func(a, b, c)\n",
        "\n",
        "evaluator = func.time_evaluator(func.entry_name, dev, number=1)\n",
        "print(\"after transformation: %f\" % evaluator(a, b, c).mean)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h1RQGOBjn4w_"
      },
      "source": [
        "试着改变 bn 的值,看看你能得到什么性能。在实践中,将利用自动系统来搜索一组可能的变换,以找到最优的变换。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "include_colab_link": true,
      "name": "2-tensor-program-abstraction.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3.10.4 ('mlc': conda)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.4"
    },
    "vscode": {
      "interpreter": {
        "hash": "d8a760899c905ec5a15e0d212432af25d7f0b614c7ae634224dffa77837bb03c"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}