Optimizing Operators with Auto-scheduling
Authors: Lianmin Zheng, Chengfan Jia
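In this tutorial, we show how TVM's auto-scheduling feature can find a high-performing schedule without requiring a custom template.
Unlike the template-based AutoTVM, which relies on manual templates to define the search space, the auto-scheduler does not require any templates.
Users only need to write the computation declaration, without any schedule commands or templates. The auto-scheduler automatically generates a large search space and finds a good schedule within it.
We use matrix multiplication as the example in this tutorial.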
Note
This tutorial will not run as-is on Windows or recent versions of macOS. To get it to run, you need to wrap the body of this tutorial in an if __name__ == "__main__": block.
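The following is a minimal sketch of the wrapping described in the note. The function name main is a placeholder chosen here, and the body only reports the process start method instead of running the tutorial. The connection to worker processes is an assumption: Windows and recent macOS default to the spawn start method, under which a worker re-imports the script, so unguarded top-level code would run again in every worker.

```python
import multiprocessing


def main():
    # Stand-in for the tutorial body (workload definition, tuning, evaluation).
    # Here we only report how this platform starts worker processes.
    return multiprocessing.get_start_method()


if __name__ == "__main__":
    # Only the process that launched the script enters this block; a worker
    # that re-imports the module skips it.
    print("start method:", main())
```

With this layout, everything that triggers tuning lives inside main, and importing the file has no side effects.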
import numpy as np
import tvm
from tvm import te, auto_scheduler
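Defining the Matrix Multiplication
First, we define a matrix multiplication with a bias addition. Note that this uses standard operations from the TVM Tensor Expression language. The main difference is the register_workload decorator at the top of the function definition. The function should return a list of input/output tensors. From these tensors, the auto-scheduler can derive the whole computational graph.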
@auto_scheduler.register_workload  # Note the auto_scheduler decorator
def matmul_add(N, L, M, dtype):
    A = te.placeholder((N, L), name="A", dtype=dtype)
    B = te.placeholder((L, M), name="B", dtype=dtype)
    C = te.placeholder((N, M), name="C", dtype=dtype)

    k = te.reduce_axis((0, L), name="k")
    matmul = te.compute(
        (N, M),
        lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
        name="matmul",
        attrs={"layout_free_placeholders": [B]},  # enable automatic layout transform for tensor B
    )
    out = te.compute((N, M), lambda i, j: matmul[i, j] + C[i, j], name="out")

    return [A, B, C, out]
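To make the declared computation concrete, here is a plain-Python reference of what the two compute stages describe, out[i][j] = sum over k of A[i][k] * B[k][j], plus C[i][j]. The helper name matmul_add_reference and the small input matrices are made up for this sketch and are not part of TVM.

```python
def matmul_add_reference(a, b, c):
    """Compute out[i][j] = sum_k a[i][k] * b[k][j] + c[i][j] on nested lists."""
    n, l, m = len(a), len(b), len(b[0])
    out = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            acc = 0.0
            for k in range(l):  # reduction axis, like te.reduce_axis((0, L))
                acc += a[i][k] * b[k][j]
            out[i][j] = acc + c[i][j]  # bias addition of the "out" stage
    return out


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
c = [[0.5, 0.5], [0.5, 0.5]]
result = matmul_add_reference(a, b, c)
print(result)  # [[19.5, 22.5], [43.5, 50.5]]
```

The tensor-expression version states the same arithmetic but leaves the loop order open, which is the freedom the auto-scheduler explores.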
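Creating the Search Task
With the function defined, we can now create a task for the auto_scheduler to search. We specify the particular parameters for this matrix multiplication, in this case a multiplication of square matrices of size \(1024 \times 1024\). We then create a search task with N=L=M=1024 and dtype="float32".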
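Improving performance with a custom target
To let TVM take full advantage of a specific hardware platform, you need to specify your CPU's capabilities manually. For example:
Replace llvm below with llvm -mcpu=core-avx2 to enable AVX2.
Replace llvm below with llvm -mcpu=skylake-avx512 to enable AVX-512.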
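One way to keep this choice in a single place is a small lookup, sketched below. The three target strings are the ones named in the tip; the dictionary keys and the helper target_string are hypothetical labels introduced only for this illustration.

```python
# Target strings taken from the tip above; the keys are hypothetical labels.
TARGET_STRINGS = {
    "generic": "llvm",
    "avx2": "llvm -mcpu=core-avx2",
    "avx512": "llvm -mcpu=skylake-avx512",
}


def target_string(isa="generic"):
    """Return the LLVM target string for a hypothetical ISA label."""
    try:
        return TARGET_STRINGS[isa]
    except KeyError:
        raise ValueError(f"unknown ISA label: {isa!r}") from None


print(target_string("avx2"))  # llvm -mcpu=core-avx2
```

The returned string would then be passed to tvm.target.Target in place of the plain "llvm" string.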
target = tvm.target.Target("llvm")
N = L = M = 1024
task = tvm.auto_scheduler.SearchTask(func=matmul_add, args=(N, L, M, "float32"), target=target)
# Inspect the computational DAG
print("Computational DAG:")
print(task.compute_dag)
Computational DAG:
A = PLACEHOLDER [1024, 1024]
B = PLACEHOLDER [1024, 1024]
matmul(i, j) += (A[i, k]*B[k, j])
C = PLACEHOLDER [1024, 1024]
out(i, j) = (matmul[i, j] + C[i, j])
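Setting Parameters for Auto-scheduling
Next, we set the parameters for the auto-scheduler.
num_measure_trials is the number of measurement trials we can use during the search. We use only 10 trials in this tutorial for a fast demonstration. In practice, 1000 is a good value for the search to converge. You can run more trials according to your time budget.
In addition, we use RecordToFile to log measurement records into the file matmul.json. The measurement records can be used to query the best result in the history, resume the search, and do more analyses later.
See TuningOptions for more information about the parameters.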
log_file = "matmul.json"
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=10,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
    verbose=2,
)
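To relate num_measure_trials to a time budget, a rough estimate can be extrapolated from one short run. The run logged in this tutorial reports 15.21 s of measurement for 10 programs; the sketch below scales that figure linearly. This is an assumption made for illustration: it ignores search time and cost-model training, and per-trial time varies by machine and workload. Both helper names are hypothetical.

```python
def estimated_measurement_seconds(num_measure_trials, seconds_per_trial):
    """Linear extrapolation of measurement time (ignores search overhead)."""
    return num_measure_trials * seconds_per_trial


def trials_within_budget(budget_seconds, seconds_per_trial):
    """Largest trial count whose estimated measurement time fits the budget."""
    return int(budget_seconds // seconds_per_trial)


# 15.21 s for 10 trials, taken from the measurement log of this tutorial's run.
seconds_per_trial = 15.21 / 10

for trials in (10, 200, 1000):
    minutes = estimated_measurement_seconds(trials, seconds_per_trial) / 60
    print(f"{trials:5d} trials: about {minutes:.1f} min of measurement")

print(trials_within_budget(3600, seconds_per_trial), "trials fit in one hour")
```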
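Running the Search
Now we have all the inputs ready. Pretty simple, isn't it? We can kick off the search and let the auto-scheduler do its magic. After some measurement trials, we can load the best schedule from the log file and apply it.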
# Run auto-tuning (search)
task.tune(tune_option)
# Apply the best schedule
sch, args = task.apply_best(log_file)
----------------------------------------------------------------------
------------------------------ [ Search ]
----------------------------------------------------------------------
Generate Sketches #s: 3
Sample Initial Population #s: 2009 fail_ct: 1 Time elapsed: 0.74
GA Iter: 0 Max score: 0.9999 Min score: 0.9383 #Pop: 128 #M+: 0 #M-: 0
GA Iter: 4 Max score: 0.9999 Min score: 0.9878 #Pop: 128 #M+: 1383 #M-: 75
EvolutionarySearch #s: 128 Time elapsed: 2.38
----------------------------------------------------------------------
------------------------------ [ Measure ]
----------------------------------------------------------------------
Get 10 programs to measure:
..........**********
==================================================
No: 1 GFLOPS: 74.00 / 74.00 results: MeasureResult(cost:[0.0290], error_no:0, all_cost:0.80, Tstamp:1661835378.96)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@i.1@j.1@ (0,65536)
matmul auto_unroll: 512
for k.0 (0,1024)
for i.2 (0,4)
for i.3 (0,2)
vectorize j.3 (0,2)
matmul = ...
for i.2 (0,8)
vectorize j.2 (0,2)
out = ...
==================================================
No: 2 GFLOPS: 125.24 / 125.24 results: MeasureResult(cost:[0.0172], error_no:0, all_cost:0.94, Tstamp:1661835379.42)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@i.1@ (0,1024)
for j.1 (0,4)
for k.0 (0,32)
for i.2 (0,4)
for k.1 (0,32)
vectorize j.3 (0,64)
matmul = ...
for i.2 (0,4)
for j.2 (0,64)
out = ...
==================================================
No: 3 GFLOPS: 72.50 / 125.24 results: MeasureResult(cost:[0.0296], error_no:0, all_cost:1.31, Tstamp:1661835379.98)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@ (0,256)
matmul auto_unroll: 512
for i.1 (0,4)
for j.1 (0,16)
for k.0 (0,512)
for j.2 (0,16)
for k.1 (0,2)
for i.3 (0,2)
vectorize j.3 (0,2)
matmul = ...
for i.1 (0,8)
for j.1 (0,512)
out = ...
==================================================
No: 4 GFLOPS: 69.32 / 125.24 results: MeasureResult(cost:[0.0310], error_no:0, all_cost:2.62, Tstamp:1661835380.53)
==================================================
Placeholder: A, B, C
matmul auto_unroll: 512
parallel i.0@j.0@i.1@j.1@ (0,256)
for k.0 (0,16)
for i.2 (0,4)
for j.2 (0,8)
for k.1 (0,64)
for i.3 (0,4)
for j.3 (0,32)
matmul = ...
parallel i (0,1024)
for j (0,1024)
out = ...
==================================================
No: 5 GFLOPS: 39.57 / 125.24 results: MeasureResult(cost:[0.0543], error_no:0, all_cost:1.19, Tstamp:1661835381.00)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@ (0,128)
matmul auto_unroll: 16
for i.1 (0,2)
for j.1 (0,2)
for k.0 (0,32)
for i.2 (0,128)
for j.2 (0,4)
for k.1 (0,32)
for i.3 (0,4)
matmul = ...
for i.1 (0,1024)
for j.1 (0,8)
out = ...
==================================================
No: 6 GFLOPS: 31.00 / 125.24 results: MeasureResult(cost:[0.0693], error_no:0, all_cost:1.17, Tstamp:1661835381.47)
==================================================
Placeholder: A, B, C
parallel i.0 (0,512)
for j.0 (0,4)
for j.1 (0,64)
for k.0 (0,64)
for i.2 (0,2)
for j.2 (0,4)
for k.1 (0,16)
matmul = ...
parallel i (0,1024)
for j (0,1024)
out = ...
==================================================
No: 7 GFLOPS: 58.25 / 125.24 results: MeasureResult(cost:[0.0369], error_no:0, all_cost:0.78, Tstamp:1661835381.94)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@i.1@j.1@ (0,4096)
for k.0 (0,32)
for j.2 (0,128)
for k.1 (0,32)
vectorize j.3 (0,2)
matmul = ...
parallel i (0,1024)
for j (0,1024)
out = ...
==================================================
No: 8 GFLOPS: 22.04 / 125.24 results: MeasureResult(cost:[0.0975], error_no:0, all_cost:0.87, Tstamp:1661835382.58)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@i.1@j.1@ (0,2048)
for k.0 (0,256)
for i.2 (0,256)
for k.1 (0,4)
vectorize j.3 (0,2)
matmul = ...
parallel i (0,1024)
for j (0,1024)
out = ...
==================================================
No: 9 GFLOPS: 43.68 / 125.24 results: MeasureResult(cost:[0.0492], error_no:0, all_cost:7.61, Tstamp:1661835382.98)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@ (0,32)
matmul auto_unroll: 512
for i.1 (0,8)
for j.1 (0,8)
for k.0 (0,256)
for i.2 (0,2)
for j.2 (0,2)
for k.1 (0,4)
for i.3 (0,32)
vectorize j.3 (0,4)
matmul = ...
for i.1 (0,512)
for j.1 (0,64)
out = ...
==================================================
No: 10 GFLOPS: 24.96 / 125.24 results: MeasureResult(cost:[0.0861], error_no:0, all_cost:0.89, Tstamp:1661835383.52)
==================================================
Placeholder: A, B, C
parallel i.0@j.0@ (0,8)
matmul auto_unroll: 16
for j.1 (0,4)
for k.0 (0,16)
for i.2 (0,128)
for j.2 (0,64)
for k.1 (0,64)
for i.3 (0,2)
vectorize j.3 (0,2)
matmul = ...
for i.1 (0,256)
for j.1 (0,512)
out = ...
Time elapsed for measurement: 15.21 s
----------------------------------------------------------------------
------------------------------ [ Done ]
----------------------------------------------------------------------
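The GFLOPS column in the measurement log can be reproduced from the reported costs. The sketch below assumes the operation count is one multiply and one add per (i, j, k) triple plus one add per output element for the bias; that formula is an assumption, although it matches the logged figures up to rounding of the printed costs. The helper names are hypothetical.

```python
def matmul_add_flops(n, l, m):
    # One multiply and one add per (i, j, k), plus one bias add per output.
    return 2 * n * l * m + n * m


def gflops(flops, seconds):
    return flops / seconds / 1e9


# Costs in seconds, copied from entries "No: 1" to "No: 10" of the log above.
costs = [0.0290, 0.0172, 0.0296, 0.0310, 0.0543,
         0.0693, 0.0369, 0.0975, 0.0492, 0.0861]

flops = matmul_add_flops(1024, 1024, 1024)
best_index = min(range(len(costs)), key=costs.__getitem__)

for number, cost in enumerate(costs, start=1):
    marker = " <- lowest cost" if number - 1 == best_index else ""
    print(f"No: {number:2d}  cost: {cost:.4f} s  ~{gflops(flops, cost):6.2f} GFLOPS{marker}")
```

The entry with the lowest cost (No: 2 here) is the one with the highest throughput, which is what the second number in each "GFLOPS: x / y" pair tracks as the best so far.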
Inspecting the Optimized Schedule
We can lower the schedule to see the IR after auto-scheduling. The auto-scheduler correctly performs optimizations including multi-level tiling, layout transformation, parallelization, vectorization, unrolling, and operator fusion.
mod = tvm.lower(sch, args, simple_mode=True)
mod.show()
@tvm.script.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer[1048576, "float32"], B: T.Buffer[1048576, "float32"], C: T.Buffer[1048576, "float32"], out: T.Buffer[1048576, "float32"]) -> None:
# function attr dict
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
T.preflattened_buffer(A, [1024, 1024], dtype="float32", data=A.data)
T.preflattened_buffer(B, [1024, 1024], dtype="float32", data=B.data)
T.preflattened_buffer(C, [1024, 1024], dtype="float32", data=C.data)
T.preflattened_buffer(out, [1024, 1024], dtype="float32", data=out.data)
# body
auto_scheduler_layout_transform = T.allocate([1048576], "float32", "global")
for ax0_ax1_fused_ax2_fused in T.parallel(16):
for ax4, ax5, ax6, ax7 in T.grid(256, 8, 4, 8):
auto_scheduler_layout_transform[ax0_ax1_fused_ax2_fused * 65536 + ax4 * 256 + ax5 * 32 + ax6 * 8 + ax7] = B[ax4 * 4096 + ax6 * 1024 + ax0_ax1_fused_ax2_fused * 64 + ax5 * 8 + ax7]
for i_outer_outer_j_outer_outer_fused in T.parallel(512):
matmul = T.allocate([256], "float32", "global")
for i_outer_inner, j_outer_inner in T.grid(4, 2):
matmul[0:8] = T.broadcast(T.float32(0), 8)
matmul[64:72] = T.broadcast(T.float32(0), 8)
matmul[8:16] = T.broadcast(T.float32(0), 8)
matmul[72:80] = T.broadcast(T.float32(0), 8)
matmul[16:24] = T.broadcast(T.float32(0), 8)
matmul[80:88] = T.broadcast(T.float32(0), 8)
matmul[24:32] = T.broadcast(T.float32(0), 8)
matmul[88:96] = T.broadcast(T.float32(0), 8)
matmul[32:40] = T.broadcast(T.float32(0), 8)
matmul[96:104] = T.broadcast(T.float32(0), 8)
matmul[40:48] = T.broadcast(T.float32(0), 8)
matmul[104:112] = T.broadcast(T.float32(0), 8)
matmul[48:56] = T.broadcast(T.float32(0), 8)
matmul[112:120] = T.broadcast(T.float32(0), 8)
matmul[56:64] = T.broadcast(T.float32(0), 8)
matmul[120:128] = T.broadcast(T.float32(0), 8)
matmul[128:136] = T.broadcast(T.float32(0), 8)
matmul[192:200] = T.broadcast(T.float32(0), 8)
matmul[136:144] = T.broadcast(T.float32(0), 8)
matmul[200:208] = T.broadcast(T.float32(0), 8)
matmul[144:152] = T.broadcast(T.float32(0), 8)
matmul[208:216] = T.broadcast(T.float32(0), 8)
matmul[152:160] = T.broadcast(T.float32(0), 8)
matmul[216:224] = T.broadcast(T.float32(0), 8)
matmul[160:168] = T.broadcast(T.float32(0), 8)
matmul[224:232] = T.broadcast(T.float32(0), 8)
matmul[168:176] = T.broadcast(T.float32(0), 8)
matmul[232:240] = T.broadcast(T.float32(0), 8)
matmul[176:184] = T.broadcast(T.float32(0), 8)
matmul[240:248] = T.broadcast(T.float32(0), 8)
matmul[184:192] = T.broadcast(T.float32(0), 8)
matmul[248:256] = T.broadcast(T.float32(0), 8)
for k_outer in T.serial(256):
cse_var_48: T.int32 = i_outer_outer_j_outer_outer_fused % 8 * 131072 + j_outer_inner * 65536 + k_outer * 256
cse_var_47: T.int32 = i_outer_outer_j_outer_outer_fused // 8 * 16384 + i_outer_inner * 4096 + k_outer * 4
cse_var_46: T.int32 = cse_var_48 + 96
cse_var_45: T.int32 = cse_var_48 + 88
cse_var_44: T.int32 = cse_var_48 + 80
cse_var_43: T.int32 = cse_var_48 + 8
cse_var_42: T.int32 = cse_var_48 + 72
cse_var_41: T.int32 = cse_var_48 + 64
cse_var_40: T.int32 = cse_var_48 + 56
cse_var_39: T.int32 = cse_var_48 + 48
cse_var_38: T.int32 = cse_var_48 + 40
cse_var_37: T.int32 = cse_var_48 + 32
cse_var_36: T.int32 = cse_var_48 + 248
cse_var_35: T.int32 = cse_var_48 + 240
cse_var_34: T.int32 = cse_var_48 + 24
cse_var_33: T.int32 = cse_var_48 + 232
cse_var_32: T.int32 = cse_var_48 + 224
cse_var_31: T.int32 = cse_var_48 + 216
cse_var_30: T.int32 = cse_var_48 + 208
cse_var_29: T.int32 = cse_var_48 + 200
cse_var_28: T.int32 = cse_var_48 + 192
cse_var_27: T.int32 = cse_var_48 + 184
cse_var_26: T.int32 = cse_var_48 + 176
cse_var_25: T.int32 = cse_var_48 + 168
cse_var_24: T.int32 = cse_var_48 + 160
cse_var_23: T.int32 = cse_var_48 + 16
cse_var_22: T.int32 = cse_var_48 + 152
cse_var_21: T.int32 = cse_var_48 + 144
cse_var_20: T.int32 = cse_var_48 + 136
cse_var_19: T.int32 = cse_var_48 + 128
cse_var_18: T.int32 = cse_var_48 + 120
cse_var_17: T.int32 = cse_var_48 + 112
cse_var_16: T.int32 = cse_var_48 + 104
cse_var_15: T.int32 = cse_var_47 + 3075
cse_var_14: T.int32 = cse_var_47 + 3074
cse_var_13: T.int32 = cse_var_47 + 3073
cse_var_12: T.int32 = cse_var_47 + 3072
cse_var_11: T.int32 = cse_var_47 + 3
cse_var_10: T.int32 = cse_var_47 + 2051
cse_var_9: T.int32 = cse_var_47 + 2050
cse_var_8: T.int32 = cse_var_47 + 2049
cse_var_7: T.int32 = cse_var_47 + 2048
cse_var_6: T.int32 = cse_var_47 + 2
cse_var_5: T.int32 = cse_var_47 + 1027
cse_var_4: T.int32 = cse_var_47 + 1026
cse_var_3: T.int32 = cse_var_47 + 1025
cse_var_2: T.int32 = cse_var_47 + 1024
cse_var_1: T.int32 = cse_var_47 + 1
matmul[0:8] = matmul[0:8] + T.broadcast(A[cse_var_47], 8) * auto_scheduler_layout_transform[cse_var_48:cse_var_48 + 8]
matmul[64:72] = matmul[64:72] + T.broadcast(A[cse_var_2], 8) * auto_scheduler_layout_transform[cse_var_48:cse_var_48 + 8]
matmul[0:8] = matmul[0:8] + T.broadcast(A[cse_var_1], 8) * auto_scheduler_layout_transform[cse_var_43:cse_var_43 + 8]
matmul[64:72] = matmul[64:72] + T.broadcast(A[cse_var_3], 8) * auto_scheduler_layout_transform[cse_var_43:cse_var_43 + 8]
matmul[0:8] = matmul[0:8] + T.broadcast(A[cse_var_6], 8) * auto_scheduler_layout_transform[cse_var_23:cse_var_23 + 8]
matmul[64:72] = matmul[64:72] + T.broadcast(A[cse_var_4], 8) * auto_scheduler_layout_transform[cse_var_23:cse_var_23 + 8]
matmul[0:8] = matmul[0:8] + T.broadcast(A[cse_var_11], 8) * auto_scheduler_layout_transform[cse_var_34:cse_var_34 + 8]
matmul[64:72] = matmul[64:72] + T.broadcast(A[cse_var_5], 8) * auto_scheduler_layout_transform[cse_var_34:cse_var_34 + 8]
matmul[8:16] = matmul[8:16] + T.broadcast(A[cse_var_47], 8) * auto_scheduler_layout_transform[cse_var_37:cse_var_37 + 8]
matmul[72:80] = matmul[72:80] + T.broadcast(A[cse_var_2], 8) * auto_scheduler_layout_transform[cse_var_37:cse_var_37 + 8]
matmul[8:16] = matmul[8:16] + T.broadcast(A[cse_var_1], 8) * auto_scheduler_layout_transform[cse_var_38:cse_var_38 + 8]
matmul[72:80] = matmul[72:80] + T.broadcast(A[cse_var_3], 8) * auto_scheduler_layout_transform[cse_var_38:cse_var_38 + 8]
matmul[8:16] = matmul[8:16] + T.broadcast(A[cse_var_6], 8) * auto_scheduler_layout_transform[cse_var_39:cse_var_39 + 8]
matmul[72:80] = matmul[72:80] + T.broadcast(A[cse_var_4], 8) * auto_scheduler_layout_transform[cse_var_39:cse_var_39 + 8]
matmul[8:16] = matmul[8:16] + T.broadcast(A[cse_var_11], 8) * auto_scheduler_layout_transform[cse_var_40:cse_var_40 + 8]
matmul[72:80] = matmul[72:80] + T.broadcast(A[cse_var_5], 8) * auto_scheduler_layout_transform[cse_var_40:cse_var_40 + 8]
matmul[16:24] = matmul[16:24] + T.broadcast(A[cse_var_47], 8) * auto_scheduler_layout_transform[cse_var_41:cse_var_41 + 8]
matmul[80:88] = matmul[80:88] + T.broadcast(A[cse_var_2], 8) * auto_scheduler_layout_transform[cse_var_41:cse_var_41 + 8]
matmul[16:24] = matmul[16:24] + T.broadcast(A[cse_var_1], 8) * auto_scheduler_layout_transform[cse_var_42:cse_var_42 + 8]
matmul[80:88] = matmul[80:88] + T.broadcast(A[cse_var_3], 8) * auto_scheduler_layout_transform[cse_var_42:cse_var_42 + 8]
matmul[16:24] = matmul[16:24] + T.broadcast(A[cse_var_6], 8) * auto_scheduler_layout_transform[cse_var_44:cse_var_44 + 8]
matmul[80:88] = matmul[80:88] + T.broadcast(A[cse_var_4], 8) * auto_scheduler_layout_transform[cse_var_44:cse_var_44 + 8]
matmul[16:24] = matmul[16:24] + T.broadcast(A[cse_var_11], 8) * auto_scheduler_layout_transform[cse_var_45:cse_var_45 + 8]
matmul[80:88] = matmul[80:88] + T.broadcast(A[cse_var_5], 8) * auto_scheduler_layout_transform[cse_var_45:cse_var_45 + 8]
matmul[24:32] = matmul[24:32] + T.broadcast(A[cse_var_47], 8) * auto_scheduler_layout_transform[cse_var_46:cse_var_46 + 8]
matmul[88:96] = matmul[88:96] + T.broadcast(A[cse_var_2], 8) * auto_scheduler_layout_transform[cse_var_46:cse_var_46 + 8]
matmul[24:32] = matmul[24:32] + T.broadcast(A[cse_var_1], 8) * auto_scheduler_layout_transform[cse_var_16:cse_var_16 + 8]
matmul[88:96] = matmul[88:96] + T.broadcast(A[cse_var_3], 8) * auto_scheduler_layout_transform[cse_var_16:cse_var_16 + 8]
matmul[24:32] = matmul[24:32] + T.broadcast(A[cse_var_6], 8) * auto_scheduler_layout_transform[cse_var_17:cse_var_17 + 8]
matmul[88:96] = matmul[88:96] + T.broadcast(A[cse_var_4], 8) * auto_scheduler_layout_transform[cse_var_17:cse_var_17 + 8]
matmul[24:32] = matmul[24:32] + T.broadcast(A[cse_var_11], 8) * auto_scheduler_layout_transform[cse_var_18:cse_var_18 + 8]
matmul[88:96] = matmul[88:96] + T.broadcast(A[cse_var_5], 8) * auto_scheduler_layout_transform[cse_var_18:cse_var_18 + 8]
matmul[32:40] = matmul[32:40] + T.broadcast(A[cse_var_47], 8) * auto_scheduler_layout_transform[cse_var_19:cse_var_19 + 8]
matmul[96:104] = matmul[96:104] + T.broadcast(A[cse_var_2], 8) * auto_scheduler_layout_transform[cse_var_19:cse_var_19 + 8]
matmul[32:40] = matmul[32:40] + T.broadcast(A[cse_var_1], 8) * auto_scheduler_layout_transform[cse_var_20:cse_var_20 + 8]
matmul[96:104] = matmul[96:104] + T.broadcast(A[cse_var_3], 8) * auto_scheduler_layout_transform[cse_var_20:cse_var_20 + 8]
matmul[32:40] = matmul[32:40] + T.broadcast(A[cse_var_6], 8) * auto_scheduler_layout_transform[cse_var_21:cse_var_21 + 8]
matmul[96:104] = matmul[96:104] + T.broadcast(A[cse_var_4], 8) * auto_scheduler_layout_transform[cse_var_21:cse_var_21 + 8]
matmul[32:40] = matmul[32:40] + T.broadcast(A[cse_var_11], 8) * auto_scheduler_layout_transform[cse_var_22:cse_var_22 + 8]
matmul[96:104] = matmul[96:104] + T.broadcast(A[cse_var_5], 8) * auto_scheduler_layout_transform[cse_var_22:cse_var_22 + 8]
matmul[40:48] = matmul[40:48] + T.broadcast(A[cse_var_47], 8) * auto_scheduler_layout_transform[cse_var_24:cse_var_24 + 8]
matmul[104:112] = matmul[104:112] + T.broadcast(A[cse_var_2], 8) * auto_scheduler_layout_transform[cse_var_24:cse_var_24 + 8]
matmul[40:48] = matmul[40:48] + T.broadcast(A[cse_var_1], 8) * auto_scheduler_layout_transform[cse_var_25:cse_var_25 + 8]
matmul[104:112] = matmul[104:112] + T.broadcast(A[cse_var_3], 8) * auto_scheduler_layout_transform[cse_var_25:cse_var_25 + 8]
matmul[40:48] = matmul[40:48] + T.broadcast(A[cse_var_6], 8) * auto_scheduler_layout_transform[cse_var_26:cse_var_26 + 8]
matmul[104:112] = matmul[104:112] + T.broadcast(A[cse_var_4], 8) * auto_scheduler_layout_transform[cse_var_26:cse_var_26 + 8]
matmul[40:48] = matmul[40:48] + T.broadcast(A[cse_var_11], 8) * auto_scheduler_layout_transform[cse_var_27:cse_var_27 + 8]
matmul[104:112] = matmul[104:112] + T.broadcast(A[cse_var_5], 8) * auto_scheduler_layout_transform[cse_var_27:cse_var_27 + 8]
matmul[48:56] = matmul[48:56] + T.broadcast(A[cse_var_47], 8) * auto_scheduler_layout_transform[cse_var_28:cse_var_28 + 8]
matmul[112:120] = matmul[112:120] + T.broadcast(A[cse_var_2], 8) * auto_scheduler_layout_transform[cse_var_28:cse_var_28 + 8]
matmul[48:56] = matmul[48:56] + T.broadcast(A[cse_var_1], 8) * auto_scheduler_layout_transform[cse_var_29:cse_var_29 + 8]
matmul[112:120] = matmul[112:120] + T.broadcast(A[cse_var_3], 8) * auto_scheduler_layout_transform[cse_var_29:cse_var_29 + 8]
matmul[48:56] = matmul[48:56] + T.broadcast(A[cse_var_6], 8) * auto_scheduler_layout_transform[cse_var_30:cse_var_30 + 8]
matmul[112:120] = matmul[112:120] + T.broadcast(A[cse_var_4], 8) * auto_scheduler_layout_transform[cse_var_30:cse_var_30 + 8]
matmul[48:56] = matmul[48:56] + T.broadcast(A[cse_var_11], 8) * auto_scheduler_layout_transform[cse_var_31:cse_var_31 + 8]
matmul[112:120] = matmul[112:120] + T.broadcast(A[cse_var_5], 8) * auto_scheduler_layout_transform[cse_var_31:cse_var_31 + 8]
matmul[56:64] = matmul[56:64] + T.broadcast(A[cse_var_47], 8) * auto_scheduler_layout_transform[cse_var_32:cse_var_32 + 8]
matmul[120:128] = matmul[120:128] + T.broadcast(A[cse_var_2], 8) * auto_scheduler_layout_transform[cse_var_32:cse_var_32 + 8]
matmul[56:64] = matmul[56:64] + T.broadcast(A[cse_var_1], 8) * auto_scheduler_layout_transform[cse_var_33:cse_var_33 + 8]
matmul[120:128] = matmul[120:128] + T.broadcast(A[cse_var_3], 8) * auto_scheduler_layout_transform[cse_var_33:cse_var_33 + 8]
matmul[56:64] = matmul[56:64] + T.broadcast(A[cse_var_6], 8) * auto_scheduler_layout_transform[cse_var_35:cse_var_35 + 8]
matmul[120:128] = matmul[120:128] + T.broadcast(A[cse_var_4], 8) * auto_scheduler_layout_transform[cse_var_35:cse_var_35 + 8]
matmul[56:64] = matmul[56:64] + T.broadcast(A[cse_var_11], 8) * auto_scheduler_layout_transform[cse_var_36:cse_var_36 + 8]
matmul[120:128] = matmul[120:128] + T.broadcast(A[cse_var_5], 8) * auto_scheduler_layout_transform[cse_var_36:cse_var_36 + 8]
matmul[128:136] = matmul[128:136] + T.broadcast(A[cse_var_7], 8) * auto_scheduler_layout_transform[cse_var_48:cse_var_48 + 8]
matmul[192:200] = matmul[192:200] + T.broadcast(A[cse_var_12], 8) * auto_scheduler_layout_transform[cse_var_48:cse_var_48 + 8]
matmul[128:136] = matmul[128:136] + T.broadcast(A[cse_var_8], 8) * auto_scheduler_layout_transform[cse_var_43:cse_var_43 + 8]
matmul[192:200] = matmul[192:200] + T.broadcast(A[cse_var_13], 8) * auto_scheduler_layout_transform[cse_var_43:cse_var_43 + 8]
matmul[128:136] = matmul[128:136] + T.broadcast(A[cse_var_9], 8) * auto_scheduler_layout_transform[cse_var_23:cse_var_23 + 8]
matmul[192:200] = matmul[192:200] + T.broadcast(A[cse_var_14], 8) * auto_scheduler_layout_transform[cse_var_23:cse_var_23 + 8]
matmul[128:136] = matmul[128:136] + T.broadcast(A[cse_var_10], 8) * auto_scheduler_layout_transform[cse_var_34:cse_var_34 + 8]
matmul[192:200] = matmul[192:200] + T.broadcast(A[cse_var_15], 8) * auto_scheduler_layout_transform[cse_var_34:cse_var_34 + 8]
matmul[136:144] = matmul[136:144] + T.broadcast(A[cse_var_7], 8) * auto_scheduler_layout_transform[cse_var_37:cse_var_37 + 8]
matmul[200:208] = matmul[200:208] + T.broadcast(A[cse_var_12], 8) * auto_scheduler_layout_transform[cse_var_37:cse_var_37 + 8]
matmul[136:144] = matmul[136:144] + T.broadcast(A[cse_var_8], 8) * auto_scheduler_layout_transform[cse_var_38:cse_var_38 + 8]
matmul[200:208] = matmul[200:208] + T.broadcast(A[cse_var_13], 8) * auto_scheduler_layout_transform[cse_var_38:cse_var_38 + 8]
matmul[136:144] = matmul[136:144] + T.broadcast(A[cse_var_9], 8) * auto_scheduler_layout_transform[cse_var_39:cse_var_39 + 8]
matmul[200:208] = matmul[200:208] + T.broadcast(A[cse_var_14], 8) * auto_scheduler_layout_transform[cse_var_39:cse_var_39 + 8]
matmul[136:144] = matmul[136:144] + T.broadcast(A[cse_var_10], 8) * auto_scheduler_layout_transform[cse_var_40:cse_var_40 + 8]
matmul[200:208] = matmul[200:208] + T.broadcast(A[cse_var_15], 8) * auto_scheduler_layout_transform[cse_var_40:cse_var_40 + 8]
matmul[144:152] = matmul[144:152] + T.broadcast(A[cse_var_7], 8) * auto_scheduler_layout_transform[cse_var_41:cse_var_41 + 8]
matmul[208:216] = matmul[208:216] + T.broadcast(A[cse_var_12], 8) * auto_scheduler_layout_transform[cse_var_41:cse_var_41 + 8]
matmul[144:152] = matmul[144:152] + T.broadcast(A[cse_var_8], 8) * auto_scheduler_layout_transform[cse_var_42:cse_var_42 + 8]
matmul[208:216] = matmul[208:216] + T.broadcast(A[cse_var_13], 8) * auto_scheduler_layout_transform[cse_var_42:cse_var_42 + 8]
matmul[144:152] = matmul[144:152] + T.broadcast(A[cse_var_9], 8) * auto_scheduler_layout_transform[cse_var_44:cse_var_44 + 8]
matmul[208:216] = matmul[208:216] + T.broadcast(A[cse_var_14], 8) * auto_scheduler_layout_transform[cse_var_44:cse_var_44 + 8]
matmul[144:152] = matmul[144:152] + T.broadcast(A[cse_var_10], 8) * auto_scheduler_layout_transform[cse_var_45:cse_var_45 + 8]
matmul[208:216] = matmul[208:216] + T.broadcast(A[cse_var_15], 8) * auto_scheduler_layout_transform[cse_var_45:cse_var_45 + 8]
matmul[152:160] = matmul[152:160] + T.broadcast(A[cse_var_7], 8) * auto_scheduler_layout_transform[cse_var_46:cse_var_46 + 8]
matmul[216:224] = matmul[216:224] + T.broadcast(A[cse_var_12], 8) * auto_scheduler_layout_transform[cse_var_46:cse_var_46 + 8]
matmul[152:160] = matmul[152:160] + T.broadcast(A[cse_var_8], 8) * auto_scheduler_layout_transform[cse_var_16:cse_var_16 + 8]
matmul[216:224] = matmul[216:224] + T.broadcast(A[cse_var_13], 8) * auto_scheduler_layout_transform[cse_var_16:cse_var_16 + 8]
matmul[152:160] = matmul[152:160] + T.broadcast(A[cse_var_9], 8) * auto_scheduler_layout_transform[cse_var_17:cse_var_17 + 8]
matmul[216:224] = matmul[216:224] + T.broadcast(A[cse_var_14], 8) * auto_scheduler_layout_transform[cse_var_17:cse_var_17 + 8]
matmul[152:160] = matmul[152:160] + T.broadcast(A[cse_var_10], 8) * auto_scheduler_layout_transform[cse_var_18:cse_var_18 + 8]
matmul[216:224] = matmul[216:224] + T.broadcast(A[cse_var_15], 8) * auto_scheduler_layout_transform[cse_var_18:cse_var_18 + 8]
matmul[160:168] = matmul[160:168] + T.broadcast(A[cse_var_7], 8) * auto_scheduler_layout_transform[cse_var_19:cse_var_19 + 8]
matmul[224:232] = matmul[224:232] + T.broadcast(A[cse_var_12], 8) * auto_scheduler_layout_transform[cse_var_19:cse_var_19 + 8]
matmul[160:168] = matmul[160:168] + T.broadcast(A[cse_var_8], 8) * auto_scheduler_layout_transform[cse_var_20:cse_var_20 + 8]
matmul[224:232] = matmul[224:232] + T.broadcast(A[cse_var_13], 8) * auto_scheduler_layout_transform[cse_var_20:cse_var_20 + 8]
matmul[160:168] = matmul[160:168] + T.broadcast(A[cse_var_9], 8) * auto_scheduler_layout_transform[cse_var_21:cse_var_21 + 8]
matmul[224:232] = matmul[224:232] + T.broadcast(A[cse_var_14], 8) * auto_scheduler_layout_transform[cse_var_21:cse_var_21 + 8]
matmul[160:168] = matmul[160:168] + T.broadcast(A[cse_var_10], 8) * auto_scheduler_layout_transform[cse_var_22:cse_var_22 + 8]
matmul[224:232] = matmul[224:232] + T.broadcast(A[cse_var_15], 8) * auto_scheduler_layout_transform[cse_var_22:cse_var_22 + 8]
matmul[168:176] = matmul[168:176] + T.broadcast(A[cse_var_7], 8) * auto_scheduler_layout_transform[cse_var_24:cse_var_24 + 8]
matmul[232:240] = matmul[232:240] + T.broadcast(A[cse_var_12], 8) * auto_scheduler_layout_transform[cse_var_24:cse_var_24 + 8]
matmul[168:176] = matmul[168:176] + T.broadcast(A[cse_var_8], 8) * auto_scheduler_layout_transform[cse_var_25:cse_var_25 + 8]
matmul[232:240] = matmul[232:240] + T.broadcast(A[cse_var_13], 8) * auto_scheduler_layout_transform[cse_var_25:cse_var_25 + 8]
matmul[168:176] = matmul[168:176] + T.broadcast(A[cse_var_9], 8) * auto_scheduler_layout_transform[cse_var_26:cse_var_26 + 8]
matmul[232:240] = matmul[232:240] + T.broadcast(A[cse_var_14], 8) * auto_scheduler_layout_transform[cse_var_26:cse_var_26 + 8]
matmul[168:176] = matmul[168:176] + T.broadcast(A[cse_var_10], 8) * auto_scheduler_layout_transform[cse_var_27:cse_var_27 + 8]
matmul[232:240] = matmul[232:240] + T.broadcast(A[cse_var_15], 8) * auto_scheduler_layout_transform[cse_var_27:cse_var_27 + 8]
matmul[176:184] = matmul[176:184] + T.broadcast(A[cse_var_7], 8) * auto_scheduler_layout_transform[cse_var_28:cse_var_28 + 8]
matmul[240:248] = matmul[240:248] + T.broadcast(A[cse_var_12], 8) * auto_scheduler_layout_transform[cse_var_28:cse_var_28 + 8]
matmul[176:184] = matmul[176:184] + T.broadcast(A[cse_var_8], 8) * auto_scheduler_layout_transform[cse_var_29:cse_var_29 + 8]
matmul[240:248] = matmul[240:248] + T.broadcast(A[cse_var_13], 8) * auto_scheduler_layout_transform[cse_var_29:cse_var_29 + 8]
matmul[176:184] = matmul[176:184] + T.broadcast(A[cse_var_9], 8) * auto_scheduler_layout_transform[cse_var_30:cse_var_30 + 8]
matmul[240:248] = matmul[240:248] + T.broadcast(A[cse_var_14], 8) * auto_scheduler_layout_transform[cse_var_30:cse_var_30 + 8]
matmul[176:184] = matmul[176:184] + T.broadcast(A[cse_var_10], 8) * auto_scheduler_layout_transform[cse_var_31:cse_var_31 + 8]
matmul[240:248] = matmul[240:248] + T.broadcast(A[cse_var_15], 8) * auto_scheduler_layout_transform[cse_var_31:cse_var_31 + 8]
matmul[184:192] = matmul[184:192] + T.broadcast(A[cse_var_7], 8) * auto_scheduler_layout_transform[cse_var_32:cse_var_32 + 8]
matmul[248:256] = matmul[248:256] + T.broadcast(A[cse_var_12], 8) * auto_scheduler_layout_transform[cse_var_32:cse_var_32 + 8]
matmul[184:192] = matmul[184:192] + T.broadcast(A[cse_var_8], 8) * auto_scheduler_layout_transform[cse_var_33:cse_var_33 + 8]
matmul[248:256] = matmul[248:256] + T.broadcast(A[cse_var_13], 8) * auto_scheduler_layout_transform[cse_var_33:cse_var_33 + 8]
matmul[184:192] = matmul[184:192] + T.broadcast(A[cse_var_9], 8) * auto_scheduler_layout_transform[cse_var_35:cse_var_35 + 8]
matmul[248:256] = matmul[248:256] + T.broadcast(A[cse_var_14], 8) * auto_scheduler_layout_transform[cse_var_35:cse_var_35 + 8]
matmul[184:192] = matmul[184:192] + T.broadcast(A[cse_var_10], 8) * auto_scheduler_layout_transform[cse_var_36:cse_var_36 + 8]
matmul[248:256] = matmul[248:256] + T.broadcast(A[cse_var_15], 8) * auto_scheduler_layout_transform[cse_var_36:cse_var_36 + 8]
for i_inner, j_inner in T.grid(4, 64):
cse_var_49: T.int32 = i_outer_outer_j_outer_outer_fused // 8 * 16384 + i_outer_inner * 4096 + i_inner * 1024 + i_outer_outer_j_outer_outer_fused % 8 * 128 + j_outer_inner * 64 + j_inner
out[cse_var_49] = matmul[i_inner * 64 + j_inner] + C[cse_var_49]
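The nested outer and inner loops in the IR above come from tiling. As a simplified illustration, the sketch below applies a single level of tiling to a small integer matrix multiplication in plain Python and checks that the result is unchanged. The tile sizes and function names are arbitrary choices for this example; the auto-scheduler applies several tiling levels and picks the sizes by search.

```python
def matmul_naive(a, b, n):
    out = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for k in range(n):
                out[i][j] += a[i][k] * b[k][j]
    return out


def matmul_tiled(a, b, n, tile_i, tile_j, tile_k):
    """Same computation with each loop split into an outer and an inner loop."""
    assert n % tile_i == 0 and n % tile_j == 0 and n % tile_k == 0
    out = [[0] * n for _ in range(n)]
    for i_outer in range(n // tile_i):
        for j_outer in range(n // tile_j):
            for k_outer in range(n // tile_k):
                # The inner loops touch only one small tile of a, b and out.
                for i_inner in range(tile_i):
                    for k_inner in range(tile_k):
                        for j_inner in range(tile_j):
                            i = i_outer * tile_i + i_inner
                            j = j_outer * tile_j + j_inner
                            k = k_outer * tile_k + k_inner
                            out[i][j] += a[i][k] * b[k][j]
    return out


n = 8
a = [[i + j for j in range(n)] for i in range(n)]
b = [[(i * j) % 5 for j in range(n)] for i in range(n)]
same = matmul_tiled(a, b, n, 4, 2, 4) == matmul_naive(a, b, n)
print(same)  # True
```

Tiling changes only the order in which the same multiply-add operations run, so the result is identical while the working set of the inner loops shrinks.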
Checking Correctness and Evaluating Performance
We build the binary and check its correctness and performance.
func = tvm.build(sch, args, target)
a_np = np.random.uniform(size=(N, L)).astype(np.float32)
b_np = np.random.uniform(size=(L, M)).astype(np.float32)
c_np = np.random.uniform(size=(N, M)).astype(np.float32)
out_np = a_np.dot(b_np) + c_np
dev = tvm.cpu()
a_tvm = tvm.nd.array(a_np, device=dev)
b_tvm = tvm.nd.array(b_np, device=dev)
c_tvm = tvm.nd.array(c_np, device=dev)
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
func(a_tvm, b_tvm, c_tvm, out_tvm)
# Check results
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)
# Evaluate execution time.
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
print(
    "Execution time of this operator: %.3f ms"
    % (np.median(evaluator(a_tvm, b_tvm, c_tvm, out_tvm).results) * 1000)
)
Execution time of this operator: 11.307 ms
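The correctness check uses a relative tolerance because the tuned kernel and NumPy may accumulate the float32 products in a different order, so their results need not be bit-identical. The sketch below restates the criterion that np.testing.assert_allclose applies, |actual - desired| <= atol + rtol * |desired|, in plain Python. The helper all_close and the sample values are made up for illustration.

```python
def all_close(actual, desired, rtol=1e-3, atol=0.0):
    """True if |actual - desired| <= atol + rtol * |desired| for every pair."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(actual, desired))


desired = [256.0, 260.5, 251.25]
actual = [256.1, 260.4, 251.3]

print(all_close(actual, desired))   # True: every difference is within 0.1 %
print(all_close([257.0], [256.0]))  # False: a difference of 1.0 exceeds 0.256
```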
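Using the Record File
During the search, all measurement records are logged into the record file matmul.json. The measurement records can be used to re-apply search results, resume the search, and perform other analyses.
Here is an example where we load the best schedule from a file and print the equivalent Python schedule API. This can be used for debugging and for learning the behavior of the auto-scheduler.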
print("Equivalent python schedule:")
print(task.print_best(log_file))
Equivalent python schedule:
matmul_i, matmul_j, matmul_k = tuple(matmul.op.axis) + tuple(matmul.op.reduce_axis)
out_i, out_j = tuple(out.op.axis) + tuple(out.op.reduce_axis)
matmul_i_o_i, matmul_i_i = s[matmul].split(matmul_i, factor=2)
matmul_i_o_o_i, matmul_i_o_i = s[matmul].split(matmul_i_o_i, factor=2)
matmul_i_o_o_o, matmul_i_o_o_i = s[matmul].split(matmul_i_o_o_i, factor=4)
matmul_j_o_i, matmul_j_i = s[matmul].split(matmul_j, factor=8)
matmul_j_o_o_i, matmul_j_o_i = s[matmul].split(matmul_j_o_i, factor=8)
matmul_j_o_o_o, matmul_j_o_o_i = s[matmul].split(matmul_j_o_o_i, factor=2)
matmul_k_o, matmul_k_i = s[matmul].split(matmul_k, factor=4)
s[matmul].reorder(matmul_i_o_o_o, matmul_j_o_o_o, matmul_i_o_o_i, matmul_j_o_o_i, matmul_k_o, matmul_i_o_i, matmul_j_o_i, matmul_k_i, matmul_i_i, matmul_j_i)
out_i_o_i, out_i_i = s[out].split(out_i, factor=4)
out_i_o_o, out_i_o_i = s[out].split(out_i_o_i, factor=4)
out_j_o_i, out_j_i = s[out].split(out_j, factor=64)
out_j_o_o, out_j_o_i = s[out].split(out_j_o_i, factor=2)
s[out].reorder(out_i_o_o, out_j_o_o, out_i_o_i, out_j_o_i, out_i_i, out_j_i)
s[matmul].compute_at(s[out], out_j_o_i)
out_i_o_o_j_o_o_fused = s[out].fuse(out_i_o_o, out_j_o_o)
s[out].parallel(out_i_o_o_j_o_o_fused)
s[matmul].pragma(matmul_i_o_o_o, "auto_unroll_max_step", 512)
s[matmul].pragma(matmul_i_o_o_o, "unroll_explicit", True)
s[matmul].vectorize(matmul_j_i)
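The split factors in the printed schedule determine the loop extents seen in the lowered IR. The sketch below follows the chained splits by hand: each split takes the current outermost part and peels off an inner loop of the given factor. The helper split_extents is hypothetical and assumes every factor divides the extent evenly, which holds for the factors printed above.

```python
def split_extents(extent, factors):
    """Apply successive splits to the outermost loop; return extents outer to inner."""
    inner = []
    outer = extent
    for factor in factors:
        assert outer % factor == 0, "sketch assumes the factor divides the extent"
        inner.insert(0, factor)  # each new inner loop sits outside earlier ones
        outer //= factor
    return [outer] + inner


matmul_i = split_extents(1024, [2, 2, 4])  # i.0, i.1, i.2, i.3
matmul_k = split_extents(1024, [4])        # k.0, k.1
out_i = split_extents(1024, [4, 4])
out_j = split_extents(1024, [64, 2])

print("matmul i extents:", matmul_i)  # [64, 4, 2, 2]
print("matmul k extents:", matmul_k)  # [256, 4]
print("fused parallel extent:", out_i[0] * out_j[0])  # 512
```

The fused extent of 512 and the k extent of 256 correspond to T.parallel(512) and T.serial(256) in the lowered IR shown earlier.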
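A more complicated example is to resume the search. In this case, we need to create the search policy and cost model ourselves and restore their state from the log file. In the example below, we restore the state and run 5 more trials.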
def resume_search(task, log_file):
    print("Resume search:")
    cost_model = auto_scheduler.XGBModel()
    cost_model.update_from_file(log_file)
    search_policy = auto_scheduler.SketchPolicy(
        task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)]
    )
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)]
    )
    task.tune(tune_option, search_policy=search_policy)


resume_search(task, log_file)
Resume search:
----------------------------------------------------------------------
------------------------------ [ Call init-search callbacks ]
----------------------------------------------------------------------
SearchPolicy: Loaded 25 measurement records from matmul.json for ["matmul_add", 1024, 1024, 1024, "float32"]
----------------------------------------------------------------------
------------------------------ [ Search ]
----------------------------------------------------------------------
Generate Sketches #s: 3
Sample Initial Population #s: 2023 fail_ct: 0 Time elapsed: 0.60
GA Iter: 0 Max score: 0.9999 Min score: 0.9305 #Pop: 128 #M+: 0 #M-: 0
GA Iter: 4 Max score: 1.0000 Min score: 0.9863 #Pop: 128 #M+: 1372 #M-: 74
EvolutionarySearch #s: 128 Time elapsed: 2.43
----------------------------------------------------------------------
------------------------------ [ Measure ]
----------------------------------------------------------------------
Get 5 programs to measure:
.....*****
Time elapsed for measurement: 8.20 s
----------------------------------------------------------------------
------------------------------ [ Done ]
----------------------------------------------------------------------
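Final Notes and Summary
In this tutorial, we have shown how to use the TVM auto-scheduler to automatically optimize a matrix multiplication without specifying a search template. It concludes a series of examples that start from the Tensor Expression (TE) language and demonstrate how TVM can optimize computational operations.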