PAPI 快速上手#

PAPI(Performance Application Programming Interface)是一个在多种平台上提供性能计数器访问的库。性能计数器能够提供程序某次运行期间处理器行为的精确底层信息,既包括总周期数(total cycle count)、缓存未命中(cache misses)和已执行指令数(instructions executed)等简单指标,也包括总 FLOPS 和 warp 占用率(occupancy)等更高级的信息。借助 PAPI,可以在性能剖析(profiling)时获取这些指标。

安装 PAPI#

PAPI 既可以通过包管理器安装(例如在 Ubuntu 上执行 apt-get install libpapi-dev),也可以从 PAPI 的官方源码仓库获取源码自行编译安装。

用 PAPI 构建 TVM#

要在 TVM 构建中启用 PAPI,需要在 config.cmake 中进行如下设置:

# Enable PAPI support so TVM can collect hardware performance counters.
set(USE_PAPI ON)

如果 PAPI 安装在非标准位置,可以像下面这样指定其 papi.pc 文件的路径:

# Alternative form: point USE_PAPI at PAPI's papi.pc file when PAPI is
# installed in a non-standard location.
set(USE_PAPI path/to/papi.pc)

在剖析时使用 PAPI#

如果 TVM 是启用 PAPI 构建的(见上文),就可以将 tvm.runtime.profiling.PAPIMetricCollector 传递给 tvm.runtime.GraphModule.profile()(或者像下例这样,通过 collectors 参数传递给 VirtualMachineProfiler.profile())来收集性能指标。示例如下:

import set_env
import numpy as np
import pytest
from tvm.runtime import profiler_vm
from tvm import relay
import tvm
from tvm.relay.testing import mlp


# Compile the MLP test workload for the Relay VM on the CPU and wrap it in
# the profiling VM. `target` and `dev` are module-level names that the later
# cells reuse, so they are defined here once.
target = "llvm"
dev = tvm.cpu()
# Presumably batch size 1, matching the (1, 1, 28, 28) input below -- TODO confirm.
mod, params = mlp.get_workload(1)

exe = relay.vm.compile(mod, target, params=params)
vm = profiler_vm.VirtualMachineProfiler(exe, dev)

# Random float32 input of shape (1, 1, 28, 28), allocated on `dev`.
data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"), device=dev)
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
# Profile "main" with a PAPI collector attached. The collector is keyed on
# `dev` -- the same device the VM was created on -- so the requested counter
# ("PAPI_FP_OPS", presumably floating-point operations; see the PAPI preset
# event list) is gathered for the device that actually executes the model.
report = vm.profile(
    [data],
    func_name="main",
    collectors=[tvm.runtime.profiling.PAPIMetricCollector({dev: ["PAPI_FP_OPS"]})],
)
print(report)

要正常运行上述示例,可能需要将 /proc/sys/kernel/perf_event_paranoid 设置为 2 或更小,或者以 root 身份运行:

# Lower perf_event_paranoid to 2 so PAPI can read performance counters
# without running the profiler as root. Only `tee` runs under sudo; its
# echo of the value is discarded.
echo 2 | sudo tee /proc/sys/kernel/perf_event_paranoid > /dev/null

VM#

# Profile a dynamic-shape element-wise add through the Relay VM profiler and
# check that the report lists the kernel, the allocation entries and the
# executor configuration.
dtype = "float32"
target = "llvm"
x = relay.var("x", shape=(relay.Any(), relay.Any()), dtype=dtype)
y = relay.var("y", shape=(relay.Any(), relay.Any()), dtype=dtype)
mod = tvm.IRModule()
mod["main"] = relay.Function([x, y], relay.add(x, y))
exe = relay.vm.compile(mod, target)
vm = profiler_vm.VirtualMachineProfiler(exe, dev)

# Build the input with the same dtype declared for x and y, so changing
# `dtype` above keeps the inputs consistent with the function signature.
data = np.random.rand(28, 28).astype(dtype)
report = vm.profile(data, data, func_name="main")

# Render the report once and reuse the text for all substring checks.
report_text = str(report)
assert "fused_add" in report_text
assert "Total" in report_text
assert "AllocTensorReg" in report_text
assert "AllocStorage" in report_text
assert report.configuration["Executor"] == "VM"
from io import StringIO
import csv

def read_csv(report):
    """Parse ``report.csv()`` into a column-oriented dict.

    The first CSV row is taken as the header; every following row
    contributes its values, in order, to the corresponding columns.

    Args:
        report: Any object whose ``csv()`` method returns CSV text
            (here, a TVM profiling report).

    Returns:
        A dict mapping each header name to the list of string values in
        that column. Empty CSV text yields ``{}``.

    Raises:
        IndexError: If a data row has more fields than the header row.
    """
    reader = csv.reader(StringIO(report.csv()), delimiter=",")
    # The first row holds the column names; an empty document has none.
    headers = next(reader, [])
    columns = [[] for _ in headers]
    for row in reader:
        for index, value in enumerate(row):
            # Index (rather than zip) so an over-long row fails loudly
            # instead of having its extra fields silently dropped.
            columns[index].append(value)
    return dict(zip(headers, columns))
# Parse the VM report's CSV export into {column name: [string values]}.
_csv = read_csv(report)

assert "Hash" in _csv
# Ops should have a duration greater than zero.
# NOTE(review): only names that *start* with "fused" are checked; if kernel
# names carry a prefix (e.g. "tvmgen_default_fused_..."), this check matches
# nothing and passes vacuously -- confirm the VM's kernel naming.
assert all(
    float(dur) > 0
    for dur, name in zip(_csv["Duration (us)"], _csv["Name"])
    if name.startswith("fused")
)
# AllocTensor or AllocStorage may be cached, so their duration could be 0.
assert all(
    float(dur) >= 0
    for dur, name in zip(_csv["Duration (us)"], _csv["Name"])
    if not name.startswith("fused")
)

Graph Executor#

from tvm.contrib.debugger import debug_executor

# Graph-executor variant: build the same MLP with relay.build and profile it
# through the debug executor, whose report has one row per operator.
mod, params = mlp.get_workload(1)

exe = relay.build(mod, target, params=params)
gr = debug_executor.create(exe.get_graph_json(), exe.lib, dev)

data = np.random.rand(1, 1, 28, 28).astype("float32")
# The keyword must match the graph's input name.
report = gr.profile(data=data)
# The rendered report must mention the softmax kernel, the Total row,
# the Hash column and the "Graph" executor.
assert all(
    needle in str(report)
    for needle in ("fused_nn_softmax", "Total", "Hash", "Graph")
)
# Bare expression: displays the report when run as a notebook cell.
report
Name                                                 Duration (us)  Percent  Device  Count                                                    Argument Shapes              Hash  
tvmgen_default_fused_nn_dense_nn_bias_add_nn_relu            49.04    42.70    cpu0      1  float32[1, 784], float32[128, 784], float32[128], float32[1, 128]  35ac6d50e6e03a62  
tvmgen_default_fused_nn_dense_nn_bias_add_nn_relu_1           9.33     8.12    cpu0      1     float32[1, 128], float32[64, 128], float32[64], float32[1, 64]  7c89e1efbba1ce3b  
tvmgen_default_fused_nn_dense_nn_bias_add                     5.50     4.79    cpu0      1       float32[1, 64], float32[10, 64], float32[10], float32[1, 10]  8a679957c4723fed  
__nop                                                         1.08     0.94    cpu0      1                             float32[1, 1, 28, 28], float32[1, 784]  9efde5b782d81fa1  
tvmgen_default_fused_nn_softmax                               0.77     0.67    cpu0      1                                     float32[1, 10], float32[1, 10]  0cc19816e7a3c070  
----------                                                                                                                                                                       
Sum                                                          65.72    57.23              5                                                                                       
Total                                                       114.84             cpu0      1                                                                                       

Configuration
-------------
Number of threads: 24
Executor: Graph

算子#

from tvm.runtime.profiling import Report
from tvm.script import tir as T

# Host (llvm) kernel used by test_profile_function: element-wise
# C[i] = A[i] + B[i] over three 10-element float64 buffers.
# NOTE(review): despite the "axpy" name there is no scalar factor --
# this is a plain vector add.
@T.prim_func
def axpy_cpu(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Bind each opaque handle to a fixed-shape [10] float64 buffer.
    A = T.match_buffer(a, [10], "float64")
    B = T.match_buffer(b, [10], "float64")
    C = T.match_buffer(c, [10], "float64")
    # Plain loop with no thread binding.
    for i in range(10):
        C[i] = A[i] + B[i]


# CUDA kernel used by test_profile_function: element-wise
# C[i] = A[i] + B[i] over three 10-element float64 buffers, with the loop
# bound to threadIdx.x so each of the 10 threads handles one element.
# NOTE(review): no scalar factor despite the "axpy" name.
@T.prim_func
def axpy_gpu(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Bind each opaque handle to a fixed-shape [10] float64 buffer.
    A = T.match_buffer(a, [10], "float64")
    B = T.match_buffer(b, [10], "float64")
    C = T.match_buffer(c, [10], "float64")
    # Iterations [0, 10) are mapped onto threadIdx.x.
    for i in T.thread_binding(0, 10, "threadIdx.x"):
        C[i] = A[i] + B[i]
def test_profile_function(target, dev):
    """Profile a single compiled kernel with a PAPI metric collector.

    Builds the llvm (``axpy_cpu``) or cuda (``axpy_gpu``) kernel for
    ``target``, runs it once on ``dev`` through ``profile_function`` and
    checks that the requested metric is reported with a positive value.
    Other target kinds are skipped.

    Args:
        target: Target string or ``tvm.target.Target`` to compile for.
        dev: Device that holds the buffers and from which metrics are read.
    """
    target = tvm.target.Target(target)
    if str(target.kind) == "llvm":
        metric = "PAPI_FP_OPS"
        func = axpy_cpu
    elif str(target.kind) == "cuda":
        # The ":device=N" suffix selects which GPU the counter is read from
        # (assumed to be the CUDA device ordinal). Derive it from `dev`
        # instead of hard-coding 0, so the metric targets the same GPU the
        # kernel runs on and matches the collector key below.
        metric = (
            "cuda:::gpu__compute_memory_access_throughput.max.pct_of_peak_sustained_region"
            f":device={dev.device_id}"
        )
        func = axpy_gpu
    else:
        pytest.skip(f"Target {target.kind} not supported by this test")
    f = tvm.build(func, target=target)
    # np.ones/np.zeros default to float64, matching the kernels' buffers.
    a = tvm.nd.array(np.ones(10), device=dev)
    b = tvm.nd.array(np.ones(10), device=dev)
    c = tvm.nd.array(np.zeros(10), device=dev)
    report = tvm.runtime.profiling.profile_function(
        f, dev, [tvm.runtime.profiling.PAPIMetricCollector({dev: [metric]})]
    )(a, b, c)
    assert metric in report.keys()
    assert report[metric].value > 0