使用 Relay Visualizer 可视化 Relay#

原作者: Chi-Wei Wang

Relay IR 模块可以包含很多运算。尽管单个运算通常很容易理解,但将它们放在一起可能会导致复杂的、难以阅读的 graph。随着优化传递(passes)的出现,情况可能会变得更糟。

这个实用程序将 IR 模块可视化为节点和边。它定义了一组接口,包括 parser、plotter(renderer)、graph、node 和 edges。 提供了默认 parser。用户可以实现自己的 renderer 来渲染 graph。

在这里,使用 renderer 在文本形式中渲染 graph。它是轻量级的、类似 AST 的可视化工具,灵感来自 clang ast-dump。下面将介绍如何通过接口类实现定制的 parser 和 renderer。

更多细节见:tvm.contrib.relay_viz

import tvm
from tvm import relay
from tvm.contrib.relay_viz import RelayVisualizer, DotPlotter, DotVizParser
from tvm.contrib.relay_viz.interface import (
    VizEdge,
    VizNode,
    VizParser,
)
from tvm.contrib.relay_viz.terminal import (
    TermGraph,
    TermPlotter,
    TermVizParser,
)

定义具有多个 GlobalVar 的 Relay IR 模块#

构建包含多个 GlobalVar 的示例 IR 模块。定义 add 函数,并在 main 函数中调用它。

创建 add 算子及其函数

data = relay.var("data")
bias = relay.var("bias")
add_op = relay.add(data, bias)
add_func = relay.Function([data, bias], add_op)

查看算子和函数:

print(f"算子:\n{add_op}")
print("="*20)
print(f"函数:\n{add_func}")
算子:
free_var %data;
free_var %bias;
add(%data, %bias)
====================
函数:
fn (%data, %bias) {
  add(%data, %bias)
}
add_gvar = relay.GlobalVar("AddFunc")
input0 = relay.var("input0")
input1 = relay.var("input1")
input2 = relay.var("input2")
add_01 = relay.Call(add_gvar, [input0, input1])
add_012 = relay.Call(add_gvar, [input2, add_01])
main_func = relay.Function([input0, input1, input2], add_012)
main_gvar = relay.GlobalVar("main")

mod = tvm.IRModule({main_gvar: main_func,
                    add_gvar: add_func})

在终端上使用 Relay Visualizer 渲染 graph#

终端是类似 clang AST-dump 的文本形式显示 Relay IR 模块。

看到 mainAddFunc 函数。AddFuncmain 函数中调用两次。

viz = RelayVisualizer(mod)
viz.render()
@main([Var(input0), Var(input1), Var(input2)])
`--Call 
   |--GlobalVar AddFunc
   |--Var(Input) name_hint: input2
   `--Call 
      |--GlobalVar AddFunc
      |--Var(Input) name_hint: input0
      `--Var(Input) name_hint: input1
@AddFunc([Var(data), Var(bias)])
`--Call 
   |--add 
   |--Var(Input) name_hint: data
   `--Var(Input) name_hint: bias

为感兴趣的 Relay 类型定制解析器#

有时想要强调感兴趣的信息,或者针对特定的用法以不同的方式分析事物。只要遵循接口,就可以提供定制的解析器。

这里演示如何自定义 relay.var 的解析器。

需要实现抽象接口 tvm.contrib.relay_viz.interface.VizParser

class YourAwesomeParser(VizParser):
    def __init__(self):
        self._delegate = TermVizParser()

    def get_node_edges(
        self,
        node: relay.Expr,
        relay_param: dict[str, tvm.runtime.NDArray],
        node_to_id: dict[relay.Expr, str],
    ) -> tuple[VizNode | None, list[VizEdge]]:

        if isinstance(node, relay.Var):
            node = VizNode(node_to_id[node], "AwesomeVar", f"name_hint {node.name_hint}")
            # no edge is introduced. So return an empty list.
            return node, []

        # delegate other types to the other parser.
        return self._delegate.get_node_edges(node, relay_param, node_to_id)

将解析器和感兴趣的渲染程序传递给可视化工具。

这里只是终端(terminal)渲染器。

viz = RelayVisualizer(mod, {}, TermPlotter(), YourAwesomeParser())
viz.render()
@main([Var(input0), Var(input1), Var(input2)])
`--Call 
   |--GlobalVar AddFunc
   |--AwesomeVar name_hint input2
   `--Call 
      |--GlobalVar AddFunc
      |--AwesomeVar name_hint input0
      `--AwesomeVar name_hint input1
@AddFunc([Var(data), Var(bias)])
`--Call 
   |--add 
   |--AwesomeVar name_hint data
   `--AwesomeVar name_hint bias

定制 Graph 和 Plotter#

除了解析器,还可以通过实现抽象类 tvm.contrib.relay_viz.interface.VizGraphtvm.contrib.relay_viz.interface.Plotter 来定制 graph 和渲染器。

这里,重写了 terminal.py 中定义的 TermGraph,以方便演示。在 AwesomeVar 上面添加了钩子,并让 TermPlotter 使用新类。

class AwesomeGraph(TermGraph):
    def node(self, viz_node):
        # add the node first
        super().node(viz_node)
        # if it's AwesomeVar, duplicate it.
        if viz_node.type_name == "AwesomeVar":
            duplicated_id = f"duplicated_{viz_node.identity}"
            duplicated_type = "double AwesomeVar"
            super().node(VizNode(duplicated_id, duplicated_type, ""))
            # connect the duplicated var to the original one
            super().edge(VizEdge(duplicated_id, viz_node.identity))


# override TermPlotter to use `AwesomeGraph` instead
class AwesomePlotter(TermPlotter):
    def create_graph(self, name):
        self._name_to_graph[name] = AwesomeGraph(name)
        return self._name_to_graph[name]


viz = RelayVisualizer(mod, {}, AwesomePlotter(), YourAwesomeParser())
viz.render()
@main([Var(input0), Var(input1), Var(input2)])
`--Call 
   |--GlobalVar AddFunc
   |--AwesomeVar name_hint input2
   |  `--double AwesomeVar 
   `--Call 
      |--GlobalVar AddFunc
      |--AwesomeVar name_hint input0
      |  `--double AwesomeVar 
      `--AwesomeVar name_hint input1
         `--double AwesomeVar 
@AddFunc([Var(data), Var(bias)])
`--Call 
   |--add 
   |--AwesomeVar name_hint data
   |  `--double AwesomeVar 
   `--AwesomeVar name_hint bias
      `--double AwesomeVar 

也可以渲染为:

from tvm.contrib import relay_viz
from tvm.relay.testing import resnet

mod, param = resnet.get_workload(num_layers=18)
# graphviz attributes
graph_attr = {"color": "red"}
node_attr = {"color": "blue"}
edge_attr = {"color": "black"}

# VizNode is passed to the callback.
# We want to color NCHW conv2d nodes. Also give Var a different shape.
def get_node_attr(node):
    if "nn.conv2d" in node.type_name and "NCHW" in node.detail:
        return {
            "fillcolor": "green",
            "style": "filled",
            "shape": "box",
        }
    if "Var" in node.type_name:
        return {"shape": "ellipse"}
    return {"shape": "box"}


# Create plotter and pass it to viz. Then render the graph.
dot_plotter = DotPlotter(
    graph_attr=graph_attr,
    node_attr=node_attr,
    edge_attr=edge_attr,
    get_node_attr=get_node_attr)

viz = RelayVisualizer(
    mod,
    relay_param=param,
    plotter=dot_plotter,
    parser=DotVizParser())
viz.render("hello")  # 保存到 PDF

小结#

本教程演示了 Relay Visualizer 及其定制的用法。

tvm.contrib.relay_viz.RelayVisualizer 由定义在 interface.py 中的接口组成。

它的目标是快速 look-then-fix 迭代。构造函数参数的目的是简单,而定制仍然可以通过一组接口类进行。