构建图卷积网络#

原作者: Yulun Yao、Chien-Yu Lin

本文是一篇入门教程,介绍如何使用 Relay 构建图卷积网络(Graph Convolutional Network,简称 GCN)。本教程将在 Cora 数据集上演示 GCN。Cora 数据集是评测图神经网络(Graph Neural Networks,简称 GNN)以及支持 GNN 训练和推理的框架时常用的基准数据集。这里直接从 DGL 库加载数据集,以便在相同条件下与 DGL 的结果进行对比。

DGL 的安装方法请参阅 DGL 官方文档:https://docs.dgl.ai/install/index.html

用 PyTorch 后端在 DGL 中定义 GCN#

本节的部分代码复用自 DGL 的 GCN 示例(https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcn)。

import torch
from torch import nn
from torch.nn import functional as F
import dgl
import networkx as nx
from dgl.nn.pytorch import GraphConv


class GCN(nn.Module):
    """Multi-layer graph convolutional network built from DGL ``GraphConv`` layers.

    Layer layout: one ``n_infeat -> n_hidden`` layer, then ``n_layers - 1``
    further ``n_hidden -> n_hidden`` layers (all followed by ``activation``),
    then a final ``n_hidden -> n_classes`` layer without activation.

    Parameters
    ----------
    g : DGLGraph
        Graph the convolutions run on; stored and reused in every forward pass.
    n_infeat : int
        Input feature dimension per node.
    n_hidden : int
        Hidden feature dimension.
    n_classes : int
        Output dimension (number of classes).
    n_layers : int
        Number of activated layers before the output layer (at least one is
        always built).
    activation : callable
        Activation passed to every non-output ``GraphConv``, e.g. ``F.relu``.
    """

    def __init__(self, g, n_infeat, n_hidden, n_classes, n_layers, activation):
        super().__init__()
        self.g = g
        self.layers = nn.ModuleList()
        self.layers.append(GraphConv(n_infeat, n_hidden, activation=activation))
        for _ in range(n_layers - 1):
            self.layers.append(GraphConv(n_hidden, n_hidden, activation=activation))
        self.layers.append(GraphConv(n_hidden, n_classes))

    @staticmethod
    def _graph_first_api(version):
        """Return True if DGL ``version`` expects ``layer(graph, feat)``.

        The cutoff mirrors the original check ``dgl.__version__ > "0.3"``,
        i.e. releases after the 0.3 series use the graph-first order
        (NOTE(review): cutoff taken from the original code; confirm against
        the DGL 0.3.x changelog). Major/minor parts are compared as integers.
        A plain string comparison misorders versions: ``"0.3.1" > "0.3"`` and
        ``"0.10" < "0.3"`` are both wrong for this purpose.
        """
        import re

        numbers = []
        for part in str(version).split(".")[:2]:
            # Keep only the leading digits, e.g. "3post2" -> 3, "2+cu117" -> 2.
            match = re.match(r"\d+", part)
            numbers.append(int(match.group()) if match else 0)
        numbers += [0] * (2 - len(numbers))
        return tuple(numbers) >= (0, 4)

    def forward(self, features):
        """Apply every layer in order to ``features`` and return the final output."""
        # Handle the API change across DGL versions: the argument order of
        # GraphConv.forward changed from (feat, graph) to (graph, feat).
        # The version does not change between layers, so check it once.
        graph_first = self._graph_first_api(dgl.__version__)
        h = features
        for layer in self.layers:
            h = layer(self.g, h) if graph_first else layer(h, self.g)
        return h

定义加载数据集和评估准确性的函数#

你可以用你自己的数据集代替这一部分,这里我们从 DGL 加载数据:

from dgl.data import load_data
from collections import namedtuple


def load_dataset(dataset="cora"):
    """Load a citation dataset through DGL and normalise the graph's self-loops.

    Parameters
    ----------
    dataset : str
        Dataset name forwarded to ``dgl.data.load_data`` (e.g. ``"cora"``).

    Returns
    -------
    g : networkx graph
        ``data.graph`` with every pre-existing self-loop removed and then
        exactly one self-loop added per node.
    data :
        The dataset object returned by ``load_data``.
    """
    # load_data expects an argparse-style object exposing a `dataset` attribute.
    args = namedtuple("args", ["dataset"])
    data = load_data(args(dataset))

    g = data.graph
    # Drop any existing self-loops so no node ends up with more than one.
    # nx.selfloop_edges returns a lazy iterator over the graph's adjacency,
    # so materialise it before removing edges from that same graph.
    g.remove_edges_from(list(nx.selfloop_edges(g)))
    # Then add exactly one self-loop per node, so every node's own feature is
    # included once when aggregating over its neighbours.
    g.add_edges_from((node, node) for node in g.nodes)

    return g, data


def evaluate(data, logits):
    """Return the classification accuracy of ``logits`` on the test split.

    Parameters
    ----------
    data :
        Dataset object exposing ``test_mask`` and ``labels`` (one entry per node).
    logits : numpy.ndarray
        Per-node class scores, expected shape [num_nodes, num_classes]; the
        predicted class is the arg-max along axis 1.

    Returns
    -------
    The (mask-weighted) fraction of test nodes whose prediction equals the label.
    """
    mask = data.test_mask  # nodes held out from training
    correct = (logits.argmax(axis=1) == data.labels) * mask
    return correct.sum() / mask.sum()

加载数据并设置模型参数#

"""
Parameters
----------
dataset: str
    Name of dataset. You can choose from ['cora', 'citeseer', 'pubmed'].

num_layer: int
    number of hidden layers

num_hidden: int
    number of the hidden units in the hidden layer

infeat_dim: int
    dimension of the input features

num_classes: int
    dimension of model output (Number of classes)
"""
dataset = "cora"
g, data = load_dataset(dataset)

num_layers = 1
num_hidden = 16
infeat_dim = data.features.shape[1]
num_classes = data.num_labels
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
/media/workspace/anaconda3/envs/torchx/lib/python3.10/site-packages/dgl/data/utils.py:288: UserWarning: Property dataset.graph will be deprecated, please use dataset[0] instead.
  warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
/media/workspace/anaconda3/envs/torchx/lib/python3.10/site-packages/dgl/data/utils.py:288: UserWarning: Property dataset.feat will be deprecated, please use g.ndata['feat'] instead.
  warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
/media/workspace/anaconda3/envs/torchx/lib/python3.10/site-packages/dgl/data/utils.py:288: UserWarning: Property dataset.num_labels will be deprecated, please use dataset.num_classes instead.
  warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))

设定 DGL-PyTorch 模型并获得参考(golden)结果#

加载预训练好的权重(weights):

import env

from tvm.contrib.download import download_testdata
from dgl import DGLGraph

# Node features as a float32 tensor. torch.as_tensor accepts both a torch
# tensor and a NumPy array, and avoids the legacy FloatTensor constructor.
features = torch.as_tensor(data.features, dtype=torch.float32)
dgl_g = DGLGraph(g)

torch_model = GCN(dgl_g, infeat_dim, num_hidden, num_classes, num_layers, F.relu)

# Download the pretrained weights
model_url = "https://homes.cs.washington.edu/~cyulin/media/gnn_model/gcn_%s.torch" % (dataset)
model_path = download_testdata(model_url, "gcn_%s.pickle" % (dataset), module="gcn_model")

# Load the weights into the model.
# NOTE(security): torch.load unpickles the file, and unpickling can execute
# arbitrary code. The file comes from a remote URL, so only run this against a
# trusted source (newer PyTorch also offers `weights_only=True`).
# map_location="cpu" lets the checkpoint load on machines without a GPU; the
# model is compiled for the CPU `llvm` target anyway.
torch_model.load_state_dict(torch.load(model_path, map_location="cpu"))
/media/pc/data/4tb/lxw/books/tvm/xinetzone/src
/media/workspace/anaconda3/envs/torchx/lib/python3.10/site-packages/dgl/data/utils.py:288: UserWarning: Property dataset.feat will be deprecated, please use g.ndata['feat'] instead.
  warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
/media/workspace/anaconda3/envs/torchx/lib/python3.10/site-packages/dgl/heterograph.py:72: DGLWarning: Recommend creating graphs by `dgl.graph(data)` instead of `dgl.DGLGraph(data)`.
  dgl_warning('Recommend creating graphs by `dgl.graph(data)`'
<All keys matched successfully>

运行 DGL 模型并测试其准确性#

# Run the DGL-PyTorch model once in inference mode (train(False) is
# equivalent to eval()) without recording an autograd graph.
torch_model.train(False)
with torch.no_grad():
    logits_torch = torch_model(features)
print("Print the first five outputs from DGL-PyTorch execution\n", logits_torch[:5])

# `evaluate` works on NumPy arrays, so hand it the logits converted to one.
acc = evaluate(data, logits_torch.numpy())
print("Test accuracy of DGL results: {:.2%}".format(acc))
Print the first five outputs from DGL-PyTorch execution
 tensor([[ 0.2640, -1.0674,  0.0736,  0.7828, -0.7666, -0.0291, -0.1403],
        [ 0.2670, -0.9722,  0.0714,  0.6953, -0.6088, -0.0735, -0.1660],
        [ 0.2985, -0.9762,  0.1139,  0.5794, -0.5615, -0.0353, -0.1830],
        [ 0.2773, -1.2461,  0.0398,  0.9599, -1.0011,  0.0598, -0.1064],
        [ 0.3692, -1.0940,  0.0363,  0.6424, -0.6491,  0.0804, -0.1536]])
Test accuracy of DGL results: 5.30%
/media/workspace/anaconda3/envs/torchx/lib/python3.10/site-packages/dgl/data/utils.py:288: UserWarning: Property dataset.test_mask will be deprecated, please use g.ndata['test_mask'] instead.
  warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))
/media/workspace/anaconda3/envs/torchx/lib/python3.10/site-packages/dgl/data/utils.py:288: UserWarning: Property dataset.label will be deprecated, please use g.ndata['label'] instead.
  warnings.warn('Property {} will be deprecated, please use {} instead.'.format(old, new))

在 Relay 中定义图卷积层#

要在 TVM 上运行 GCN,首先需要实现图卷积层(Graph Convolution Layer)。实现时可以参考 DGL 中基于 MXNet 后端实现的 GraphConv 层。

该层定义如下运算,其中 A 为邻接矩阵,H 为输入节点特征,W 为权重矩阵。注意,这里应用了两次转置,目的是让邻接矩阵始终作为 sparse_dense 算子的右操作数。这只是一种临时方案,待 Relay 支持稀疏矩阵转置以及左侧稀疏算子后会进行更新。

\[\mbox{GraphConv}(A, H, W) = A * H * W = ((H * W)^t * A^t)^t = ((W^t * H^t) * A^t)^t\]
from tvm import relay
from tvm.contrib import graph_executor
import tvm
from tvm import te


def GraphConv(layer_name, input_dim, output_dim, adj, input, norm=None, bias=True, activation=None):
    """Build one graph-convolution layer as a Relay expression.

    Computes ``A * H * W`` as ``((W^t * H^t) * A^t)^t`` so that the sparse
    adjacency matrix is the right-hand operand of ``relay.nn.sparse_dense``.

    Parameters
    ----------
    layer_name : str
        Prefix of the layer's Relay variables: ``<layer_name>.weight`` and,
        if ``bias`` is set, ``<layer_name>.bias``.
    input_dim : int
        Input dimension per node feature.
    output_dim : int
        Output dimension per node feature.
    adj : namedtuple
        Graph representation (adjacency matrix) in sparse CSR format with
        fields (`data`, `indices`, `indptr`), where `data` has shape
        [num_nonzeros], `indices` has shape [num_nonzeros] and `indptr` has
        shape [num_nodes + 1].
    input : relay.Expr
        Input features of this layer with shape [num_nodes, input_dim].
    norm : relay.Expr, optional
        Per-node normalisation factors multiplied with the features before and
        after the convolution.
    bias : bool
        If true, add a bias variable ``<layer_name>.bias`` of shape [output_dim].
    activation : callable, optional
        Relay activation applied to the output, e.g.
        ``relay.nn.{relu, sigmoid, log_softmax, softmax, leaky_relu}``.

    Returns
    -------
    output : tvm.relay.Expr
        Output tensor of this layer with shape [num_nodes, output_dim].
    """
    if norm is not None:
        input = relay.multiply(input, norm)

    weight = relay.var(layer_name + ".weight", shape=(input_dim, output_dim))
    weight_t = relay.transpose(weight)
    # dense(x, w) computes x * w^t, so this is W^t * H^t: [output_dim, num_nodes].
    dense = relay.nn.dense(weight_t, input)
    # sparse_dense(x, A) computes x * A^t; the result stays [output_dim, num_nodes].
    output = relay.nn.sparse_dense(dense, adj)
    # Transpose back to the node-major layout [num_nodes, output_dim].
    output_t = relay.transpose(output)
    if norm is not None:
        output_t = relay.multiply(output_t, norm)
    if bias:
        # bias_add requires a 1-D bias whose length equals data.shape[axis];
        # here axis=-1 has length output_dim.
        _bias = relay.var(layer_name + ".bias", shape=(output_dim,))
        output_t = relay.nn.bias_add(output_t, _bias, axis=-1)
    if activation is not None:
        output_t = activation(output_t)
    return output_t

准备 GraphConv 层中所需的参数#

import numpy as np
import networkx as nx


def prepare_params(g, data):
    """Build the constant NumPy inputs of the Relay GCN from the graph and dataset.

    Parameters
    ----------
    g : networkx graph
        Graph whose nodes are labelled ``0 .. N-1``; both the adjacency matrix
        and the degree vector below are indexed by these ids.
    data :
        Dataset object exposing ``features``.

    Returns
    -------
    dict
        ``"infeats"``: float32 node features;
        ``"g_data"`` / ``"indices"`` / ``"indptr"``: CSR arrays of the adjacency
        matrix (float32 / int32 / int32);
        ``"norm"``: float32 column vector of shape [N, 1] holding
        ``in_degree ** -0.5`` per node.
    """
    params = {}
    # Only float32 is supported as the feature dtype for now. np.asarray
    # accepts both a NumPy array and a CPU torch tensor.
    params["infeats"] = np.asarray(data.features).astype("float32")

    num_nodes = g.number_of_nodes()
    # Pin the row/column order to node ids so the adjacency matrix agrees with
    # `norm` (also indexed by node id), regardless of node insertion order.
    node_order = list(range(num_nodes))

    # Generate the adjacency matrix in CSR form. `to_scipy_sparse_matrix` is
    # deprecated and removed in NetworkX 3.0; prefer `to_scipy_sparse_array`
    # when this NetworkX provides it.
    to_sparse = getattr(nx, "to_scipy_sparse_array", None) or nx.to_scipy_sparse_matrix
    adjacency = to_sparse(g, nodelist=node_order, format="csr")
    params["g_data"] = adjacency.data.astype("float32")
    params["indices"] = adjacency.indices.astype("int32")
    params["indptr"] = adjacency.indptr.astype("int32")

    # Normalization w.r.t. node degrees: norm[i] = in_degree(i) ** -0.5.
    degs = np.array([g.in_degree[i] for i in node_order], dtype="float64")
    params["norm"] = np.power(degs, -0.5).astype("float32").reshape(-1, 1)

    return params


params = prepare_params(g, data)

# Sanity checks: the feature matrix is 2-D, all three CSR arrays are present,
# and the CSR row pointer has exactly one more entry than there are nodes.
assert params["infeats"].ndim == 2
assert all(params[key] is not None for key in ("g_data", "indices", "indptr"))
assert params["indptr"].shape[0] - 1 == params["infeats"].shape[0]
/tmp/ipykernel_523780/3618165654.py:12: DeprecationWarning: 

The scipy.sparse array containers will be used instead of matrices
in Networkx 3.0. Use `to_scipy_sparse_array` instead.
  adjacency = nx.to_scipy_sparse_matrix(g)

把层放在一起#

# Define input features, norms, adjacency matrix in Relay
infeats = relay.var("infeats", shape=data.features.shape)
norm = relay.Constant(tvm.nd.array(params["norm"]))
g_data = relay.Constant(tvm.nd.array(params["g_data"]))
indices = relay.Constant(tvm.nd.array(params["indices"]))
indptr = relay.Constant(tvm.nd.array(params["indptr"]))

Adjacency = namedtuple("Adjacency", ["data", "indices", "indptr"])
adj = Adjacency(g_data, indices, indptr)

# Construct the GCN with the same layer layout as the PyTorch `GCN` model:
# infeat_dim -> num_hidden, then (num_layers - 1) x num_hidden -> num_hidden,
# then num_hidden -> num_classes. The PyTorch model always builds at least one
# hidden layer, hence max(num_layers, 1). ReLU follows every layer except the
# last one. With num_layers = 1 this is the classic 2-layer GCN.
# Layer names "layers.<i>" match the keys of the PyTorch state_dict.
layer_dims = [infeat_dim] + [num_hidden] * max(num_layers, 1) + [num_classes]
num_graph_layers = len(layer_dims) - 1
layers = []
layer_input = infeats
for layer_idx in range(num_graph_layers):
    is_output_layer = layer_idx == num_graph_layers - 1
    layer_input = GraphConv(
        layer_name="layers.%d" % layer_idx,
        input_dim=layer_dims[layer_idx],
        output_dim=layer_dims[layer_idx + 1],
        adj=adj,
        input=layer_input,
        norm=norm,
        activation=None if is_output_layer else relay.nn.relu,
    )
    layers.append(layer_input)

# Analyze free variables and generate Relay function
output = layers[-1]

使用 TVM 编译并运行#

从 PyTorch 模型导出权重到 Python Dict:

# Export the PyTorch weights into a plain dict of NumPy arrays. state_dict()
# builds a new mapping on every call, so fetch it only once.
model_params = {name: tensor.numpy() for name, tensor in torch_model.state_dict().items()}

# Copy every GraphConv weight/bias ("layers.<i>.weight" / "layers.<i>.bias")
# into `params`; these keys match the Relay variable names created by the
# Relay GraphConv, whatever the number of layers.
for param_name, param_value in model_params.items():
    if param_name.startswith("layers."):
        params[param_name] = param_value

# Set the TVM build target
target = "llvm"  # Currently only support `llvm` as target

func = relay.Function(relay.analysis.free_vars(output), output)
# Bind every named array in `params` (features, weights, biases) as a constant,
# so the compiled module needs no runtime inputs.
func = relay.build_module.bind_params_by_name(func, params)
mod = tvm.IRModule()
mod["main"] = func
# Build with Relay
with tvm.transform.PassContext(opt_level=0):  # Currently only support opt_level=0
    lib = relay.build(mod, target, params=params)

# Generate graph executor
dev = tvm.device(target, 0)
m = graph_executor.GraphModule(lib["default"](dev))
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.

运行 TVM 模型,测试准确性并通过 DGL 验证#

# Execute the compiled module. The features were bound as constants at build
# time, so no set_input call is needed.
m.run()
logits_tvm = m.get_output(0).numpy()
print("Print the first five outputs from TVM execution\n", logits_tvm[:5])

labels, test_mask = data.labels, data.test_mask

acc = evaluate(data, logits_tvm)
print("Test accuracy of TVM results: {:.2%}".format(acc))

import tvm.testing

# Cross-check: the TVM logits must match the DGL-PyTorch reference within 1e-3.
tvm.testing.assert_allclose(logits_torch, logits_tvm, atol=1e-3)
Print the first five outputs from TVM execution
 [[ 0.26396316 -1.067397    0.07361096  0.78283393 -0.7665647  -0.02912378
  -0.14030665]
 [ 0.2670483  -0.97222644  0.07140031  0.6953188  -0.60881317 -0.07351625
  -0.16601387]
 [ 0.29854178 -0.97619903  0.11394241  0.57936156 -0.5615169  -0.03528827
  -0.18298927]
 [ 0.2773209  -1.2461467   0.0398193   0.95992005 -1.0011221   0.059847
  -0.10642916]
 [ 0.3691777  -1.0940018   0.03631139  0.6423676  -0.6491406   0.08039594
  -0.1535899 ]]
Test accuracy of TVM results: 5.30%