{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Visualizing Relay with Relay Visualizer\n", "\n", "**Original author**: [Chi-Wei Wang](https://github.com/chiwwang)\n", "\n", "A Relay IR module can contain many operations. Each operation is usually easy to understand on its own, but combined they can form a complex graph that is hard to read. Optimization passes can make the graph even harder to follow.\n", "\n", "This utility visualizes an IR module as nodes and edges. It defines a set of interfaces: parser, plotter (renderer), graph, node, and edge.\n", "A default parser is provided. Users can implement their own renderer to render the graph.\n", "\n", "This tutorial uses a renderer that draws the graph as text. It is a lightweight, AST-like visualizer inspired by [clang ast-dump](https://clang.llvm.org/docs/IntroductionToTheClangAST.html). The sections below show how to implement a custom parser and renderer through the interface classes.\n", "\n", "For more details, see {py:mod}`tvm.contrib.relay_viz`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import tvm\n", "from tvm import relay\n", "from tvm.contrib.relay_viz import RelayVisualizer, DotPlotter, DotVizParser\n", "from tvm.contrib.relay_viz.interface import (\n", "    VizEdge,\n", "    VizNode,\n", "    VizParser,\n", ")\n", "from tvm.contrib.relay_viz.terminal import (\n", "    TermGraph,\n", "    TermPlotter,\n", "    TermVizParser,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define a Relay IR module with multiple `GlobalVar`s\n", "\n", "Build an example IR module that contains multiple `GlobalVar`s. Define an `add` function and call it from the `main` function.\n", "\n", "```{rubric} Create the add operator and its function\n", "```" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "data = relay.var(\"data\")\n", "bias = relay.var(\"bias\")\n", "add_op = relay.add(data, bias)\n", "add_func = relay.Function([data, bias], add_op)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Inspect the operator and the function:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Operator:\n", "free_var %data;\n", "free_var %bias;\n", "add(%data, %bias)\n", "====================\n", "Function:\n", "fn (%data, %bias) {\n", "  add(%data, %bias)\n", "}\n" ] } ], "source": [ "print(f\"Operator:\\n{add_op}\")\n", "print(\"=\"*20)\n", 
"print(f\"函数:\\n{add_func}\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [], "source": [ "add_gvar = relay.GlobalVar(\"AddFunc\")\n", "input0 = relay.var(\"input0\")\n", "input1 = relay.var(\"input1\")\n", "input2 = relay.var(\"input2\")\n", "add_01 = relay.Call(add_gvar, [input0, input1])\n", "add_012 = relay.Call(add_gvar, [input2, add_01])\n", "main_func = relay.Function([input0, input1, input2], add_012)\n", "main_gvar = relay.GlobalVar(\"main\")\n", "\n", "mod = tvm.IRModule({main_gvar: main_func,\n", " add_gvar: add_func})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 在终端上使用 Relay Visualizer 渲染 graph\n", "\n", "终端是类似 clang AST-dump 的文本形式显示 Relay IR 模块。\n", "\n", "看到 ``main`` 和 ``AddFunc`` 函数。``AddFunc`` 在 ``main`` 函数中调用两次。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "@main([Var(input0), Var(input1), Var(input2)])\n", "`--Call \n", " |--GlobalVar AddFunc\n", " |--Var(Input) name_hint: input2\n", " `--Call \n", " |--GlobalVar AddFunc\n", " |--Var(Input) name_hint: input0\n", " `--Var(Input) name_hint: input1\n", "@AddFunc([Var(data), Var(bias)])\n", "`--Call \n", " |--add \n", " |--Var(Input) name_hint: data\n", " `--Var(Input) name_hint: bias\n" ] } ], "source": [ "viz = RelayVisualizer(mod)\n", "viz.render()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 为感兴趣的 Relay 类型定制解析器\n", "\n", "有时想要强调感兴趣的信息,或者针对特定的用法以不同的方式分析事物。只要遵循接口,就可以提供定制的解析器。\n", "\n", "这里演示如何自定义 ``relay.var`` 的解析器。\n", "\n", "需要实现抽象接口 {py:class}`tvm.contrib.relay_viz.interface.VizParser`。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class YourAwesomeParser(VizParser):\n", " def __init__(self):\n", " self._delegate = TermVizParser()\n", "\n", " def get_node_edges(\n", " self,\n", " node: relay.Expr,\n", " relay_param: dict[str, 
tvm.runtime.NDArray],\n", "        node_to_id: dict[relay.Expr, str],\n", "    ) -> tuple[VizNode | None, list[VizEdge]]:\n", "\n", "        if isinstance(node, relay.Var):\n", "            node = VizNode(node_to_id[node], \"AwesomeVar\", f\"name_hint {node.name_hint}\")\n", "            # no edge is introduced. So return an empty list.\n", "            return node, []\n", "\n", "        # delegate other types to the other parser.\n", "        return self._delegate.get_node_edges(node, relay_param, node_to_id)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pass the parser and the renderer of your choice to the visualizer.\n", "\n", "Here we simply use the terminal renderer." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "@main([Var(input0), Var(input1), Var(input2)])\n", "`--Call \n", "   |--GlobalVar AddFunc\n", "   |--AwesomeVar name_hint input2\n", "   `--Call \n", "      |--GlobalVar AddFunc\n", "      |--AwesomeVar name_hint input0\n", "      `--AwesomeVar name_hint input1\n", "@AddFunc([Var(data), Var(bias)])\n", "`--Call \n", "   |--add \n", "   |--AwesomeVar name_hint data\n", "   `--AwesomeVar name_hint bias\n" ] } ], "source": [ "viz = RelayVisualizer(mod, {}, TermPlotter(), YourAwesomeParser())\n", "viz.render()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Customize the Graph and Plotter\n", "\n", "Besides the parser, you can also customize the graph and the renderer by implementing the abstract classes {py:class}`tvm.contrib.relay_viz.interface.VizGraph` and {py:class}`tvm.contrib.relay_viz.interface.Plotter`.\n", "\n", "For ease of demonstration, this example overrides ``TermGraph``, which is defined in ``terminal.py``. It adds a hook that duplicates each ``AwesomeVar`` node and makes ``TermPlotter`` use the new class." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "@main([Var(input0), Var(input1), Var(input2)])\n", "`--Call \n", "   |--GlobalVar AddFunc\n", "   |--AwesomeVar name_hint input2\n", "   |  `--double AwesomeVar \n", "   `--Call \n", "      |--GlobalVar AddFunc\n", "      |--AwesomeVar name_hint input0\n", "      |  `--double AwesomeVar \n", "      `--AwesomeVar 
name_hint input1\n", "         `--double AwesomeVar \n", "@AddFunc([Var(data), Var(bias)])\n", "`--Call \n", "   |--add \n", "   |--AwesomeVar name_hint data\n", "   |  `--double AwesomeVar \n", "   `--AwesomeVar name_hint bias\n", "      `--double AwesomeVar \n" ] } ], "source": [ "class AwesomeGraph(TermGraph):\n", "    def node(self, viz_node):\n", "        # add the node first\n", "        super().node(viz_node)\n", "        # if it's AwesomeVar, duplicate it.\n", "        if viz_node.type_name == \"AwesomeVar\":\n", "            duplicated_id = f\"duplicated_{viz_node.identity}\"\n", "            duplicated_type = \"double AwesomeVar\"\n", "            super().node(VizNode(duplicated_id, duplicated_type, \"\"))\n", "            # connect the duplicated var to the original one\n", "            super().edge(VizEdge(duplicated_id, viz_node.identity))\n", "\n", "\n", "# override TermPlotter to use `AwesomeGraph` instead\n", "class AwesomePlotter(TermPlotter):\n", "    def create_graph(self, name):\n", "        self._name_to_graph[name] = AwesomeGraph(name)\n", "        return self._name_to_graph[name]\n", "\n", "\n", "viz = RelayVisualizer(mod, {}, AwesomePlotter(), YourAwesomeParser())\n", "viz.render()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The graph can also be rendered with Graphviz through `DotPlotter`:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from tvm.contrib import relay_viz\n", "from tvm.relay.testing import resnet\n", "\n", "mod, param = resnet.get_workload(num_layers=18)\n", "# graphviz attributes\n", "graph_attr = {\"color\": \"red\"}\n", "node_attr = {\"color\": \"blue\"}\n", "edge_attr = {\"color\": \"black\"}\n", "\n", "# VizNode is passed to the callback.\n", "# We want to color NCHW conv2d nodes. 
Also give Var a different shape.\n", "def get_node_attr(node):\n", "    if \"nn.conv2d\" in node.type_name and \"NCHW\" in node.detail:\n", "        return {\n", "            \"fillcolor\": \"green\",\n", "            \"style\": \"filled\",\n", "            \"shape\": \"box\",\n", "        }\n", "    if \"Var\" in node.type_name:\n", "        return {\"shape\": \"ellipse\"}\n", "    return {\"shape\": \"box\"}\n", "\n", "\n", "# Create plotter and pass it to viz. Then render the graph.\n", "dot_plotter = DotPlotter(\n", "    graph_attr=graph_attr,\n", "    node_attr=node_attr,\n", "    edge_attr=edge_attr,\n", "    get_node_attr=get_node_attr)\n", "\n", "viz = RelayVisualizer(\n", "    mod,\n", "    relay_param=param,\n", "    plotter=dot_plotter,\n", "    parser=DotVizParser())\n", "viz.render(\"hello\")  # save to PDF" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "This tutorial demonstrates how to use Relay Visualizer and how to customize it.\n", "\n", "{py:class}`tvm.contrib.relay_viz.RelayVisualizer` is composed of the interfaces defined in ``interface.py``.\n", "\n", "It is aimed at quick look-then-fix iterations. The constructor arguments are intended to be simple, while customization is still possible through a set of interface classes." ] } ], "metadata": { "interpreter": { "hash": "aa67ff675248b5ab29dcd2f00c1422844307085c8ca7c8ce7eddecd21b9c2975" }, "kernelspec": { "display_name": "Python 3.10.4 ('mxnetx')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 0 }