{
  "cells": [
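    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# How to Use TVM Pass Infra\n",
        "\n",
        "**Original author**: [Zhi Chen](https://github.com/zhiics)\n",
        "\n",
        "As the number of optimization passes in Relay/tir grows, executing them by hand and maintaining their dependencies manually becomes intractable. TVM therefore introduces an infrastructure to manage optimization passes and make it applicable to the different layers of IR in the TVM stack.\n",
        "\n",
        "Optimizations of a Relay/tir program can be applied at various granularities, namely function level, via {py:class}`tvm.relay.transform.FunctionPass`/{py:class}`tvm.tir.transform.PrimFuncPass`, and module level, via {py:class}`tvm.transform.ModulePass`. Alternatively, users can rely on {py:class}`tvm.transform.Sequential` to apply a sequence of passes to a Relay/tir program, where the dependencies between passes can be resolved by the pass infra. For details about each type of pass, see {ref}`pass-infra`.\n",
        "\n",
        "This tutorial mainly demonstrates how developers can use the pass infra to perform a certain optimization and to create an optimization pipeline for a Relay program. The same approach can be used for tir as well."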
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tvm\n",
        "import tvm.relay as relay"
      ]
    },
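    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Create an Example Relay Program\n",
        "\n",
        "First, create a simple Relay program. The examples in this tutorial apply various optimizations to this program. Similarly, users can write a tir primitive function and apply tir passes to it."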
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def example():\n",
        "    shape = (1, 64, 54, 54)\n",
        "    c_data = np.empty(shape).astype(\"float32\")\n",
        "    c = relay.const(c_data)\n",
        "    weight = relay.var(\"weight\", shape=(64, 64, 3, 3))\n",
        "    x = relay.var(\"x\", relay.TensorType((1, 64, 56, 56), \"float32\"))\n",
        "    conv = relay.nn.conv2d(x, weight, kernel_size=(3, 3))\n",
        "    y = relay.add(c, c)\n",
        "    y = relay.multiply(y, relay.const(2, \"float32\"))\n",
        "    y = relay.add(conv, y)\n",
        "    z = relay.add(y, c)\n",
        "    z1 = relay.add(y, c)\n",
        "    z2 = relay.add(z, z1)\n",
        "    return relay.Function([x, weight], z2)"
      ]
    },
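    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Optimize the Program\n",
        "\n",
        "Now optimize the program. Relay provides many optimizations; a few of them are applied to this example program.\n",
        "\n",
        "There are multiple ways to optimize a Relay program. Examples of each are given below.\n",
        "\n",
        "### Manually Apply Optimization Passes\n",
        "\n",
        "Create a Relay module, which contains one or more Relay functions to optimize."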
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "def @main(%x: Tensor[(1, 64, 56, 56), float32], %weight: Tensor[(64, 64, 3, 3), float32]) {\n",
            "  %0 = add(meta[relay.Constant][0], meta[relay.Constant][0]);\n",
            "  %1 = nn.conv2d(%x, %weight, padding=[0, 0, 0, 0], kernel_size=[3, 3]);\n",
            "  %2 = multiply(%0, 2f);\n",
            "  %3 = add(%1, %2);\n",
            "  %4 = add(%3, meta[relay.Constant][0]);\n",
            "  %5 = add(%3, meta[relay.Constant][0]);\n",
            "  add(%4, %5)\n",
            "}\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "f = example()\n",
        "mod = tvm.IRModule.from_expr(f)\n",
        "print(mod)"
      ]
    },
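    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Constant folding can now be applied to the module.\n",
        "\n",
        "Here `fold_const` is a pass object created by calling `FoldConstant` without any parameters."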
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {},
      "outputs": [],
      "source": [
        "fold_const = relay.transform.FoldConstant()"
      ]
    },
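    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Then invoke the pass on the given module. Note that the constant folding pass works at the function level. That is, the optimization is applied to each function in the module, so users do not need to iterate over the functions manually to apply this pass."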
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %0 = nn.conv2d(%x, %weight, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %1 = add(%0, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %2 = add(%1, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %3 = add(%1, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  add(%2, %3) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "mod = fold_const(mod)\n",
        "# The updated program shows that the constants have been folded.\n",
        "print(mod)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "More optimizations can be applied in a similar manner. For instance, the common subexpression used by `z` and `z1` can be eliminated."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %0 = nn.conv2d(%x, %weight, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %1 = add(%0, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %2 = add(%1, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  add(%2, %2) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "mod = relay.transform.EliminateCommonSubexpr()(mod)\n",
        "print(mod)"
      ]
    },
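    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Some optimizations, such as fusion, are parametric as well. For example, opt level 0 does not allow operators to be fused together. Users can pass `fuse_opt_level` to enable fusion."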
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %0 = fn (%p03: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p12: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    nn.conv2d(%p03, %p12, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %1 = %0(%x, %weight) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %2 = fn (%p02: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p11: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    add(%p02, %p11) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %3 = %2(%1, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %4 = fn (%p01: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p1: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    add(%p01, %p1) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %5 = %4(%3, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %6 = fn (%p0: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    add(%p0, %p0) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %6(%5) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "mod = relay.transform.FuseOps(fuse_opt_level=0)(mod)\n",
        "\n",
        "# The optimized module now contains functions that each wrap a single primitive op.\n",
        "print(mod)"
      ]
    },
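    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Use Sequential to Apply a Sequence of Passes\n",
        "\n",
        "Applying passes one by one as above is tedious, and it may require users to have a good understanding of the dependencies between them. For example, fusion currently does not work well on let bindings. Therefore, if {py:func}`relay.transform.ToANormalForm` is applied before fusion, operators that were fusable cannot be fused, because this pass generates let bindings for each expression to canonicalize a Relay program.\n",
        "\n",
        "Relay therefore provides {py:class}`tvm.transform.Sequential` to spare developers from handling these issues explicitly: it lets them specify the passes required by each pass and packs them as a whole to execute. For example, the same passes can now be applied in the sequential style as shown below. {py:class}`tvm.transform.Sequential` is similar to [torch.nn.Sequential](https://pytorch.org/docs/stable/nn.html#torch.nn.Sequential) and [mxnet.gluon.Block](https://mxnet.apache.org/api/python/docs/_modules/mxnet/gluon/block.html).\n",
        "\n",
        "For example, `torch.nn.Sequential` contains a sequence of PyTorch modules that are added to build a network. It focuses on network layers. In contrast, {py:class}`tvm.transform.Sequential` in the pass infra works on optimization passes.\n",
        "\n",
        "The following executes some passes through {py:class}`tvm.transform.Sequential`:"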
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %4 = fn (%p0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %p2: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p3: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %3 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    add(%2, %3) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %4(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "f = example()\n",
        "mod = tvm.IRModule.from_expr(f)\n",
        "# Collect the passes of interest.\n",
        "seq = tvm.transform.Sequential(\n",
        "    [\n",
        "        relay.transform.FoldConstant(),\n",
        "        relay.transform.EliminateCommonSubexpr(),\n",
        "        relay.transform.FuseOps(fuse_opt_level=2),\n",
        "    ]\n",
        ")\n",
        "mod1 = seq(mod)\n",
        "print(mod1)"
      ]
    },
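    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The transformed Relay program still contains two identical addition operations. This is because ``EliminateCommonSubexpr`` was not actually performed: by default, only passes with an optimization level less than or equal to 2 are executed under {py:class}`tvm.transform.Sequential`. The pass infra, however, provides a configuration interface that lets users customize the optimization level they want to execute."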
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %3 = fn (%p0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %p2: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p3: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    add(%2, %2) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %3(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "with tvm.transform.PassContext(opt_level=3):\n",
        "    mod2 = seq(mod)\n",
        "print(mod2)"
      ]
    },
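    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now only one of the two identical additions is kept.\n",
        "\n",
        "In addition, users can selectively disable some passes using the `disabled_pass` config, which is similar to the `-fno-xxx` option used by general-purpose compilers such as Clang and GCC. For example, `EliminateCommonSubexpr` can be disabled as shown below. The printed module again shows two identical addition operations."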
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %4 = fn (%p0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %p2: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p3: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %3 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    add(%2, %3) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %4(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "with tvm.transform.PassContext(opt_level=3, disabled_pass=[\"EliminateCommonSubexpr\"]):\n",
        "    mod3 = seq(mod)\n",
        "print(mod3)"
      ]
    },
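    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Implement a Pass Using a Python Decorator\n",
        "\n",
        "The next example illustrates how to orchestrate a customized optimization pipeline through the pass infra using Python decorators. This functionality greatly eases the implementation of passes. For example, users can simply define a decorated class to do function-level optimizations, as the following example shows. `transform_function` wraps a class that replaces every constant `c` with a multiple of `c`. Later on, when the customized pass is invoked, each function in the given module is visited and each constant in the function is replaced."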
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %0 = multiply(3f /* ty=float32 */, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %1 = add(%0, %0) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %2 = multiply(3f /* ty=float32 */, 2f /* ty=float32 */) /* ty=float32 */;\n",
            "  %3 = nn.conv2d(%x, %weight, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %4 = multiply(%1, %2) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %5 = add(%3, %4) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %6 = add(%5, %0) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %7 = add(%5, %0) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "  add(%6, %7) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
        "@relay.transform.function_pass(opt_level=1)\n",
        "class CustomPipeline:\n",
        "    \"\"\"Simple test pass that replaces each constant with a multiple of itself.\"\"\"\n",
        "\n",
        "    def __init__(self, multiplier):\n",
        "        self.multiplier = multiplier\n",
        "\n",
        "    # The pass infra calls this method on each function in the module.\n",
        "    def transform_function(self, func, mod, ctx):\n",
        "        obj = self\n",
        "\n",
        "        class ReplaceConstant(tvm.relay.ExprMutator):\n",
        "            def visit_constant(self, c):\n",
        "                return relay.multiply(obj.multiplier, c)\n",
        "\n",
        "        return ReplaceConstant().visit(func)\n",
        "\n",
        "\n",
        "f = example()\n",
        "mod = tvm.IRModule.from_expr(f)\n",
        "custom_pass = CustomPipeline(multiplier=relay.const(3, \"float32\"))\n",
        "assert custom_pass.info.name == \"CustomPipeline\"\n",
        "mod3 = custom_pass(mod)\n",
        "print(mod3)"
      ]
    },
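    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Debug a Pass\n",
        "\n",
        "TVM provides users a plug-and-play style debugging pass: a special pass (``PrintIR``) that dumps the IR of the whole module after a certain pass is done. A slightly modified version of the sequential pass example is shown below, which enables an IR dump after the ``FoldConstant`` optimization."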
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "f = example()\n",
        "mod = tvm.IRModule.from_expr(f)\n",
        "seq = tvm.transform.Sequential(\n",
        "    [\n",
        "        relay.transform.FoldConstant(),\n",
        "        tvm.transform.PrintIR(),\n",
        "        relay.transform.EliminateCommonSubexpr(),\n",
        "        relay.transform.FuseOps(),\n",
        "    ]\n",
        ")"
      ]
    },
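    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "By inserting the ``PrintIR`` pass after ``FoldConstant``, the pass infra dumps the module IR when ``FoldConstant`` is done. Users can plug in this pass after any pass they want to debug to view the optimization effect.\n",
        "\n",
        "There is a more flexible debugging mechanism. One can implement a ``PassInstrument`` class to execute arbitrary code not only before and/or after each pass but also when entering/exiting ``PassContext``. See {ref}`pass_instrument_cpp_backend` for more details.\n",
        "\n",
        "Here the {py:func}`tvm.instrument.pass_instrument` decorator is used to implement a ``PassInstrument`` class that prints the IR before each pass executes:"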
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Running pass: {} The meta data of the pass - pass name: sequential, opt_level: 0, required passes: []\n",
            "\n",
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %3 = fn (%p0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %p2: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p3: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    add(%2, %2) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %3(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n",
            "Running pass: {} The meta data of the pass - pass name: FoldConstant, opt_level: 2, required passes: []\n",
            "\n",
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %3 = fn (%p0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %p2: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p3: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    add(%2, %2) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %3(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n",
            "Running pass: {} The meta data of the pass - pass name: InferType, opt_level: 0, required passes: []\n",
            "\n",
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %3 = fn (%p0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %p2: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p3: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    add(%2, %2) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %3(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n",
            "Running pass: {} The meta data of the pass - pass name: PrintIR, opt_level: 0, required passes: []\n",
            "\n",
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %3 = fn (%p0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %p2: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p3: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    add(%2, %2) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %3(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n",
            "Running pass: {} The meta data of the pass - pass name: InferType, opt_level: 0, required passes: []\n",
            "\n",
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %3 = fn (%p0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %p2: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p3: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    add(%2, %2) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %3(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n",
            "Running pass: {} The meta data of the pass - pass name: EliminateCommonSubexpr, opt_level: 3, required passes: [\n",
            "InferType, ]\n",
            "\n",
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %3 = fn (%p0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %p2: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p3: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    add(%2, %2) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %3(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n",
            "Running pass: {} The meta data of the pass - pass name: InferType, opt_level: 0, required passes: []\n",
            "\n",
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %3 = fn (%p0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %p2: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p3: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    add(%2, %2) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %3(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n",
            "Running pass: {} The meta data of the pass - pass name: InferType, opt_level: 0, required passes: []\n",
            "\n",
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %3 = fn (%p0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %p2: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p3: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    add(%2, %2) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %3(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n",
            "Running pass: {} The meta data of the pass - pass name: FuseOps, opt_level: 0, required passes: [\n",
            "InferType, ]\n",
            "\n",
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %3 = fn (%p0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %p2: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p3: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    add(%2, %2) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %3(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n",
            "Running pass: {} The meta data of the pass - pass name: InferType, opt_level: 0, required passes: []\n",
            "\n",
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %3 = fn (%p0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %p2: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p3: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    add(%2, %2) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %3(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n",
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %3 = fn (%p0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %p2: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p3: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    add(%2, %2) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %3(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "[09:34:18] /workspace/tvm/src/ir/transform.cc:655: PrintIR():\n",
            "#[version = \"0.0.5\"]\n",
            "def @main(%x: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "  %3 = fn (%p0: Tensor[(1, 64, 56, 56), float32] /* ty=Tensor[(1, 64, 56, 56), float32] */, %p1: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] */, %p2: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, %p3: Tensor[(1, 64, 54, 54), float32] /* ty=Tensor[(1, 64, 54, 54), float32] */, Primitive=1) -> Tensor[(1, 64, 54, 54), float32] {\n",
            "    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0], kernel_size=[3, 3]) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %1 = add(%0, %p2) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    %2 = add(%1, %p3) /* ty=Tensor[(1, 64, 54, 54), float32] */;\n",
            "    add(%2, %2) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "  } /* ty=fn (Tensor[(1, 64, 56, 56), float32], Tensor[(64, 64, 3, 3), float32], Tensor[(1, 64, 54, 54), float32], Tensor[(1, 64, 54, 54), float32]) -> Tensor[(1, 64, 54, 54), float32] */;\n",
            "  %3(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 64, 54, 54), float32] */, meta[relay.Constant][1] /* ty=Tensor[(1, 64, 54, 54), float32] */) /* ty=Tensor[(1, 64, 54, 54), float32] */\n",
            "}\n",
            "\n",
            "/* For debugging purposes the metadata section has been omitted.\n",
            " * If you would like to see the full metadata section you can set the \n",
            " * option to `True` when invoking `astext`. \n",
            " */\n"
          ]
        }
      ],
      "source": [
        "@tvm.instrument.pass_instrument\n",
        "class PrintIR:\n",
        "    \"\"\"Print the name of the pass, the IR, only before passes execute.\"\"\"\n",
        "\n",
        "    def run_before_pass(self, mod, info):\n",
        "        print(\"Running pass: {}\", info)\n",
        "        print(mod)\n",
        "\n",
        "\n",
        "with tvm.transform.PassContext(opt_level=3, instruments=[PrintIR()]):\n",
        "    with tvm.target.Target(\"llvm\"):\n",
        "        # Perform the optimizations.\n",
        "        mod = seq(mod)\n",
        "print(mod)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 小结\n",
        "\n",
        "本教程介绍了如何使用 pass infra 更方便地在 TVM 中编写和调用 pass。本文还讨论了调用 pass 的不同方法。使用 {py:class}`tvm.transform.Sequential` 可以很大程度上帮助用户简化处理多个优化传递及其依赖关系的工作。此外,还提供了示例来说明如何使用 ``PrintIR`` 和跟踪调试 pass。"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3.8.13 ('py38': conda)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13"
    },
    "vscode": {
      "interpreter": {
        "hash": "28558e8daad512806f5c536a1a04c119185f99f65b79002708a12162d02a79c7"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}