Relay 中的模式匹配#

在 TVM 中有很多地方,决定了 Relay 程序的纯数据流子图,并试图以某种方式变换它们,例如融合、量化、外部代码生成和设备特定的优化,如 bitpacking 和 VTA 使用的层切片。

今天,许多这样的方法都需要大量无聊的样板代码来实现,并要求用户从 visitor 和 AST 匹配的角度考虑问题。许多这样的变换可以很容易地用图重写来描述。为了构建 rewriter 或其他高级机制,首先需要模式语言来描述可以匹配的内容。

这样的语言不仅对构建 rewriter 有用,而且还为现有的 pass 提供了扩展点。例如,融合 pass 可以通过一组描述硬件能力的融合模式来参数化,量化通道可以采用一组模式来描述在给定平台上可以量化的算子。

在后端世界,可以使用相同的机制来构建更高级别的 API,使用自己的代码生成。这个 API 采用了一组描述你的硬件能力的模式和外部编译器,提供了相对平稳的开箱即用的异构体验。

模式示例#

有相当多的算子的属性值得匹配。下面将研究如何匹配树的属性,并扩展原型中未充分探索的一些用例。本节演示如何编写模式。建议查看 tests/python/relay/test_dataflow_pattern.py 了解更多用例。

备注

如果您无法找到与您想要的 Relay 节点匹配的对应模式节点,欢迎您提出 issue 或提交 PR 来添加它。

匹配两个 Ops 中的一个#

第一个例子是简单的例子,想要匹配带有单输入的算子或另一个单输入的算子:

def test_match_op_or():
    is_add_or_sub = is_op('add') | is_op('subtract')
    assert is_add_or_sub.match(relay.op.op.get("add"))
    assert is_add_or_sub.match(relay.op.op.get("subtract"))

使用属性匹配 Op#

下一个例子是 dense 运算,带有任何标记为 element-wise 的算子:

def test_no_match_attr():
    op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE})
    op_pat = op(wildcard(), wildcard())
    x = relay.var('x')
    y = relay.var('y')
    assert not op_pat.match(relay.op.nn.dense(x, y))

下面是另一个使用特定属性匹配 op 的例子:

def test_match_data_layout():
    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NHWC"})
    x = relay.var('x')
    y = relay.var('y')
    assert not is_conv2d.match(relay.op.nn.conv2d(x, y))

或者具有特定 kernel 大小的卷积:

def test_match_kernel_size():
    is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]})
    x = relay.var('x')
    y = relay.var('y')
    assert is_conv2d.match(relay.op.nn.conv2d(x, y, kernel_size=[3, 3]))

匹配 Optional Op#

下一个例子是用可选算子匹配一个模式。在这个模式中,可以匹配 conv2d+bias_add+relu graph 或 conv2d+bias_add graph。

def test_match_optional():
    conv_node = is_op('nn.conv2d')(wildcard(), wildcard())
    bias_node = is_op('nn.bias_add')(conv_node, wildcard())
    pat = bias_node.optional(lambda x: is_op('nn.relu')(x))

    x = relay.var('x')
    y = relay.var('y')
    z = relay.var('z')
    conv2d = relay.op.nn.conv2d(x, y)
    bias = relay.op.nn.bias_add(conv2d, z)
    assert pat.match(bias)
    relu = relay.op.nn.relu(bias)
    assert pat.match(relu)

匹配类型#

除了用属性来匹配 ops,还可以根据形状和数据类型,制作模式来匹配它们的类型。这里有一些例子:

def test_match_type():
    # Match any op with float32
    pat1 = has_dtype('float32')
    x = relay.var('x', shape=(10, 10), dtype='float32')
    assert pat1.match(x)

    # Match any op with shape (10, 10)
    pat2 = has_shape((10, 10))
    x = relay.var('x', shape=(10, 10), dtype='float32')
    assert pat2.match(x)

    # Match conv2d+relu with a certain shape
    conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
    pat3 = is_op('nn.relu')(conv2d).has_shape((1, 32, 28, 28))

    x = relay.var('x', shape=(1, 3, 28, 28), dtype='float32')
    w = relay.var('w', shape=(32, 3, 3, 3), dtype='float32')
    conv2d = relay.nn.conv2d(x, w, strides=(1, 1), padding=(1, 1))
    relu = relay.nn.relu(conv2d)
    assert pat3.match(relu)

匹配 Non-Call 节点#

有时可能还想匹配包含 Tuple 或 TupleGetItem 节点的模式。由于不是 call 节点,需要使用特定的模式节点来匹配它们:

def test_match_tuple():
    x = relay.var('x')
    y = relay.var('y')
    z = relay.var('z')
    tuple_pattern = is_tuple((wildcard(), wildcard(), wildcard()))
    assert tuple_pattern.match(relay.expr.Tuple((x,y,z)))

下一个例子是匹配 batch_norm -> get(0) -> relu。注意,您还可以使用 is_tuple_get_item(bn_node) 来匹配 TupleGetItem 节点和任何索引。

def test_match_tuple_get_item():
    bn_node = is_op('nn.batch_norm')(wildcard(), wildcard(), wildcard(), wildcard(), wildcard())
    tuple_get_item_node = is_tuple_get_item(bn_node, 0)
    pat = is_op('nn.relu')(tuple_get_item_node)

    x = relay.var('x', shape=(1, 8))
    gamma = relay.var("gamma", shape=(8,))
    beta = relay.var("beta", shape=(8,))
    moving_mean = relay.var("moving_mean", shape=(8,))
    moving_var = relay.var("moving_var", shape=(8,))
    bn_node = relay.nn.batch_norm(x, gamma, beta, moving_mean, moving_var)
    tuple_get_item_node = bn_node[0]
    out = relay.nn.relu(tuple_get_item_node)
    pat.match(out)

如果有跨越函数边界的模式,可能希望匹配函数本身

def test_match_func():
    x = relay.var("x")
    y = relay.var("y")
    wc1 = wildcard()
    wc2 = wildcard()
    func_pattern = FunctionPattern([wc1, wc2], wc1 + wc2)
    assert func_pattern.match(relay.Function([x, y], x + y))

下一个例子是匹配 constant 节点的值。这对于检查子图中的特定参数是否被绑定很有用。

def test_match_constant():
    conv2d = is_op('nn.conv2d')(wildcard(), is_constant())
    pattern = is_op('nn.bias_add')(conv2d, wildcard())

    x = relay.var('x', shape=(1, 3, 224, 224))
    w = relay.var('w', shape=(3, 3, 3, 3))
    b = relay.var('b', shape=(3, ))
    conv2d = relay.op.nn.conv2d(x, w)
    out = relay.op.nn.bias_add(conv2d, b)
    func = relay.Function([x, w, b], out)
    mod = tvm.IRModule.from_expr(func)

    # Two inputs of the conv2d in the graph are VarNode by default, so no match.
    assert not pattern.match(mod['main'].body)

    # The second input (weight) has been bind with constant values so it is now a constant node.
    mod["main"] = bind_params_by_name(mod["main"],
                                    {'w': tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))})
    assert pattern.match(mod['main'].body)

另一方面,如果需要将常数与特定值匹配,可以直接使用 is_expr。这对代数简化很有用。

def test_match_plus_zero():
    zero = (is_expr(relay.const(0)) | is_expr(relay.const(0.0)))
    pattern = wildcard() + zero

    x = relay.Var('x')
    y = x + relay.const(0)
    assert pattern.match(y)

下一个例子是将函数节点与特定属性匹配:

def test_match_function():
    pattern = wildcard().has_attr({"Composite": "add"})

    x = relay.var('x')
    y = relay.var('y')
    f = relay.Function([x, y], x + y).with_attr("Composite", "add")
    assert pattern.match(f)

Relay If 表达式,如果它的所有条件,真分支和假分支都匹配,就可以匹配:

def test_match_if():
    x = is_var("x")
    y = is_var("y")
    pat = is_if(is_op("less")(x, y), x, y)

    x = relay.var("x")
    y = relay.var("y")
    cond = x < y

    assert pat.match(relay.expr.If(cond, x, y))

如果 Relay Let 表达式的所有变量、值和 body 都匹配,那么它就可以被匹配:

def test_match_let():
    x = is_var("x")
    y = is_var("y")
    let_var = is_var("let")
    pat = is_let(let_var, is_op("less")(x, y), let_var)

    x = relay.var("x")
    y = relay.var("y")
    lv = relay.var("let")
    cond = x < y
    assert pat.match(relay.expr.Let(lv, cond, lv))

匹配 Diamond 和 Post-Dominator Graph#

下一个例子是在 diamond 的顶部匹配两个 inputs

def test_match_diamond():
    # Pattern
    is_conv2d = is_op('nn.conv2d')(is_var(), is_var())
    path1 = is_op('nn.relu')(is_conv2d)
    path2 = is_op('nn.leaky_relu')(is_conv2d)
    diamond = is_op('add')(path1, path2)

    # Expr
    inp = relay.var('input')
    weight = relay.var('weight')
    conv2d = relay.op.nn.conv2d(inp, weight)
    relu = relay.op.nn.relu(conv2d)
    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
    out = relu + leaky_relu

    # Check
    assert diamond.match(out)

最后一个例子是将 diamond 与 post-dominator 的关系相匹配。在模式语言中嵌入支配分析作为匹配类型,以允许未知拓扑的模式匹配。这很重要,因为希望能够使用语言来描述融合模式,比如 elementwise 运算后面跟着 conv2d:

def test_match_dom_diamond():
    # Pattern
    is_conv2d = is_op('nn.conv2d')(is_var(), is_var())
    reduction = is_op('add')(wildcard(), wildcard())
    diamond = dominates(is_conv2d, is_elemwise, reduction)

    # Expr
    inp = relay.var('input')
    weight = relay.var('weight')
    conv2d = relay.op.nn.conv2d(inp, weight)
    relu = relay.op.nn.relu(conv2d)
    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
    out = relu + leaky_relu

    # Check
    assert diamond.match(out)

模糊匹配模式#

上面的 Dominator 分析允许匹配 Relay AST 的子图,该子图不与一组模式节点精确地 1-to-1 对应。在其他一些地方,也支持这种模糊(”fuzzy”)匹配。

Tuple、Function 和具有任意数量输入的 Call 节点可以通过传递 None 作为参数值来匹配,即

tuple_pattern = is_tuple(None)
func_pattern = FunctionPattern(None, wildcard() + wildcard())
call_pattern = func_pattern(None)

这些模式通过限制参数的使用而不是参数的数量来匹配更通用的类模式。

此外,支持模糊体匹配(fuzzy bodies)函数,即受模式约束的函数体。模式 FunctionPattern([is_var(), is_var()], wildcard() + wildcard()]) 将匹配 relay.Function([x, y], x + y),但它也将匹配 relay.Function([x, y], x * x + y)。在第二种情况下,模式没有完美地约束函数体,因此产生的匹配是模糊的。

模式语言设计#

提出的模式语言被设计成 Relay IR 的镜像,并对常见场景提供额外的支持。模式语言的目标是提供类似正则表达式的功能来匹配数据流图并进行重写。

高层次的设计是引入模式语言,现在提出这种语言为

Pattern ::= expr
        | *
        | pattern(pattern1, ... patternN)
        | has_type(type)
        | has_dtype(type)
        | has_shape(shape)
        | has_attr(attrs)
        | is_var(name)
        | is_constant()
        | is_expr(expr)
        | is_op(op_name)
        | is_tuple()
        | is_tuple_get_item(pattern, index = None)
        | is_if(cond, tru, fls)
        | is_let(var, value, body)
        | pattern1 `|` pattern2
        | dominates(parent_pattern, path_pattern, child_pattern)
        | FunctionPattern(params, body)

然后,上述语言提供了匹配接口,可以选择子图,以及验证图是否匹配模式。

表达式模式#

匹配 literal 表达式。

通配符#

匹配任何表达式。

类型模式#

检查嵌套模式匹配的表达式是否具有特定的类型。

DType 模式#

检查嵌套模式匹配的表达式是否具有特定的数据类型。

Shape 模式#

检查与嵌套模式匹配的表达式是否具有特定的输出形状。

属性模式#

检查与模式匹配的算子是否具有具有特定值的属性。

变量模式#

检查表达式是否是 relay 变量,并可选地提供与变量名匹配的名称。

备用#

要么匹配第一种模式,要么匹配第二种模式。

Domination#

匹配子模式,找到父模式的匹配,确保子模式最终主导父模式(即,模式之外的节点没有使用父模式的输出),并且子模式和模式之间的任何节点都匹配路径模式。

函数模式#

用函数体和参数匹配函数

If 模式#

将 If 与条件、真分支和假分支匹配

Let 模式#

将 Let 与变量、值和 body 匹配

应用#

模式语言不仅提供模式匹配,还提供模式处理。这里将介绍两种模式处理方法并提供一些示例。

模式重写#

如果您想用另一个子图替换匹配的模式,您可以利用 rewrite 变换。下面是使用单个 batch_norm op 重写一系列算术算子的示例。构造函数参数 require_type 指示是否需要在回调之前运行 InferType。

class BatchnormCallback(DFPatternCallback):
    # A callback class to rewrite the matched pattern to a batch_norm op.
    def __init__(self, require_type=False):
        super().__init__(require_type)
        self.x = wildcard()
        self.var = wildcard()
        self.mean = wildcard()
        self.beta = wildcard()
        self.gamma = wildcard()
        self.eps = wildcard()

        self.pattern = self.gamma * (self.x - self.mean)/is_op("sqrt")(self.var + self.eps) + self.beta

    def callback(self, pre, post, node_map):
        x = node_map[self.x][0]
        var = node_map[self.var][0]
        mean = node_map[self.mean][0]
        beta = node_map[self.beta][0]
        gamma = node_map[self.gamma][0]
        eps = node_map[self.eps][0]
        return relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = eps.data.numpy().item())[0]

    # A graph of arithmetic operators that are functional equivalent to batch_norm.
    x = relay.var('x')
    var = relay.var('var')
    mean = relay.var('mean')
    beta = relay.var('beta')
    gamma = relay.var('gamma')
    BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta

    from tvm.relay.dataflow_pattern import rewrite
    out = rewrite(BatchnormCallback(), BN)
    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])

def callback(self, pre, post, node_map) 将在 rewriter 匹配 self.pattern 时被调用。node_map 是从模式节点映射到图中匹配节点的字典。

回调函数将在返回的模式上递归调用,直到模式停止变化。因此,如果 self.pattern 匹配回调返回的图的任何部分,rewriter 将循环运行。如果你想避免多次重写,你可以向构造函数传递 rewrite_once=True 参数。

模式分区#

如果您想对匹配的子图执行更复杂的处理,而您不满足于 rewrite,您可以考虑将匹配的子图划分到单独的 Relay 函数,并对该函数执行其他处理。这里使用 pattern.partition 为每个匹配的子图创建新的 Relay 函数。该功能类似于 TVM 中的 op 融合 pass:

# A pattern matching conv2d+relu.
pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))

# A graph.
x = relay.var('input')
w = relay.var('weight')
conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)
print('relu')
# free_var %x: Tensor[(1, 3, 224, 224), float32]
# free_var %w: Tensor[(3, 3, 3, 3), float32]
# %0 = nn.conv2d(%x, %w, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 3, 222, 222), float32] */;
# free_var %b: Tensor[(3), float32]
# nn.bias_add(%0, %b) /* ty=Tensor[(1, 3, 222, 222), float32] */

# After partition.
print(pattern.partition(relu))
# free_var %x: Tensor[(1, 3, 224, 224), float32]
# free_var %w: Tensor[(3, 3, 3, 3), float32]
# free_var %b: Tensor[(3), float32]
# %1 = fn (%FunctionVar_0_0, %FunctionVar_0_1,
#          %FunctionVar_0_2, PartitionedFromPattern="nn.conv2d_nn.bias_add_") {
#   %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
#   nn.bias_add(%0, %FunctionVar_0_2)
# };
# %1(%x, %w, %b)

注意,你也可以为创建的函数指定属性:

print(pattern.partition(relu, {'Composite': 'one_layer'}))
# free_var %x: Tensor[(1, 3, 224, 224), float32]
# free_var %w: Tensor[(3, 3, 3, 3), float32]
# free_var %b: Tensor[(3), float32]
# %1 = fn (%FunctionVar_0_0, %FunctionVar_0_1,
#          %FunctionVar_0_2, Composite="one_layer",
#                            PartitionedFromPattern="nn.conv2d_nn.bias_add_") {
#   %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
#   nn.bias_add(%0, %FunctionVar_0_2)
# };
# %1(%x, %w, %b)

如果需要使用模式语言无法指定的自定义检查函数,可以在分区时指定 check 函数。下面的例子是演示检查子图输入数据布局的案例:

def check(pre):
    conv = pre.args[0]
    return (conv.attrs.data_layout == "NCHW") and bool(conv.checked_type.shape[0] == 1)

pattern.partition(relu, check=check)

在这个例子中,检查匹配的子图的第一个参数(即 pre.args[0] )是否有数据布局 “NCHW” 以及它的批大小是否为 1。如果模式匹配的条件不能通过分析模式本身来验证,那么这个特性就很有用。