Using External Libraries in Relay#

Original authors: Masahiro Masuda, Truman Tian

This is a short tutorial on how to use external libraries, such as cuDNN or cuBLAS, in Relay.

Relay internally uses TVM to generate target-specific code. For example, with the cuda backend, TVM generates cuda kernels for all layers in the user-provided network. But sometimes it is also helpful to incorporate external libraries developed by various vendors into Relay. Luckily, TVM has a mechanism to call into these libraries transparently. For Relay users, all we need to do is set an appropriate target string.
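
For example (both builds appear later in this tutorial), switching from TVM-generated CUDA kernels to cuDNN is just a change of the target string:

target = "cuda"               # TVM generates CUDA kernels for every layer
target = "cuda -libs=cudnn"   # supported ops (e.g. conv2d) are offloaded to cuDNN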

Before you can use an external library from Relay, TVM needs to be built with the library you want to use. For example, to use cuDNN, enable the USE_CUDNN option in cmake/config.cmake, and specify the cuDNN include and library directories if necessary.
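
The relevant part of cmake/config.cmake looks roughly like this (a sketch; whether extra include/library hints are needed depends on your system):

set(USE_CUDNN ON)
# If cmake cannot locate cuDNN automatically, also tell it where the
# cuDNN headers and libraries live on your system before building TVM.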

First, import Relay and TVM.

import tvm
from tvm import te
import numpy as np
from tvm.contrib import graph_executor as runtime
from tvm import relay
from tvm.relay import testing
import tvm.testing

Create a simple network#

Let's create a very simple network for demonstration. It consists of convolution, batch normalization, and ReLU activation.

out_channels = 16
batch_size = 1

data = relay.var("data", relay.TensorType((batch_size, 3, 224, 224), "float32"))
weight = relay.var("weight")
bn_gamma = relay.var("bn_gamma")
bn_beta = relay.var("bn_beta")
bn_mmean = relay.var("bn_mean")
bn_mvar = relay.var("bn_var")

simple_net = relay.nn.conv2d(
    data=data, weight=weight, kernel_size=(3, 3), channels=out_channels, padding=(1, 1)
)
simple_net = relay.nn.batch_norm(simple_net, bn_gamma, bn_beta, bn_mmean, bn_mvar)[0]
simple_net = relay.nn.relu(simple_net)
simple_net = relay.Function(relay.analysis.free_vars(simple_net), simple_net)

data_shape = (batch_size, 3, 224, 224)
net, params = testing.create_workload(simple_net)

Build and run with the cuda backend#

We build and run this network with the cuda backend, as usual. By setting the logging level to DEBUG, the result of Relay graph compilation will be dumped as pseudocode.

import logging

logging.basicConfig(level=logging.DEBUG)  # to dump TVM IR after fusion

target = "cuda"
lib = relay.build_module.build(net, target, params=params)

dev = tvm.device(target, 0)
data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
module = runtime.GraphModule(lib["default"](dev))
module.set_input("data", data)
module.run()
out_shape = (batch_size, out_channels, 224, 224)
out = module.get_output(0, tvm.nd.empty(out_shape))
out_cuda = out.numpy()
DEBUG:autotvm:Finish loading 825 records
INFO:te_compiler:Using injective.cpu for add based on highest priority (10)
/media/workspace/anaconda3/envs/mxnetx/lib/python3.10/site-packages/tvm/driver/build_module.py:263: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  warnings.warn(
INFO:te_compiler:Using injective.cpu for sqrt based on highest priority (10)
INFO:te_compiler:Using injective.cpu for divide based on highest priority (10)
INFO:te_compiler:Using injective.cpu for multiply based on highest priority (10)
INFO:te_compiler:Using injective.cpu for expand_dims based on highest priority (10)
INFO:te_compiler:Using injective.cpu for negative based on highest priority (10)
INFO:te_compiler:Using injective.cpu for multiply based on highest priority (10)
INFO:te_compiler:Using injective.cpu for add based on highest priority (10)
INFO:te_compiler:Using injective.cpu for expand_dims based on highest priority (10)
WARNING:autotvm:One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
DEBUG:autotvm:Cannot find tuning records for:
    target=cuda -keys=cuda,gpu -arch=sm_75 -max_num_threads=1024 -thread_warp_size=32
    key=('conv2d_nchw.cuda', ('TENSOR', (1, 3, 224, 224), 'float32'), ('TENSOR', (16, 3, 3, 3), 'float32'), (1, 1), (1, 1, 1, 1), (1, 1), 'float32')
TVM will apply a default schedule which may negatively impact performance.
INFO:te_compiler:Using conv2d_nchw.cuda for nn.conv2d based on highest priority (10)
INFO:te_compiler:Using injective.cuda for multiply based on highest priority (10)
INFO:te_compiler:Using injective.cuda for add based on highest priority (10)
INFO:te_compiler:Using injective.cuda for nn.relu based on highest priority (10)

The generated pseudocode should look like the following.

Tip

Note how bias add, batch normalization, and ReLU activation are fused into the convolution kernel.

TVM generates a single, fused kernel from this representation.

print(lib.ir_mod["main"])
fn (%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] */, %weight: Tensor[(16, 3, 3, 3), float32] /* ty=Tensor[(16, 3, 3, 3), float32] */, %bn_gamma: Tensor[(16), float32] /* ty=Tensor[(16), float32] */, %bn_beta: Tensor[(16), float32] /* ty=Tensor[(16), float32] */, %bn_mean: Tensor[(16), float32] /* ty=Tensor[(16), float32] */, %bn_var: Tensor[(16), float32] /* ty=Tensor[(16), float32] */) -> Tensor[(1, 16, 224, 224), float32] {
  %0 = nn.conv2d(%data, %weight, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
  %1 = nn.batch_norm(%0, %bn_gamma, %bn_beta, %bn_mean, %bn_var) /* ty=(Tensor[(1, 16, 224, 224), float32], Tensor[(16), float32], Tensor[(16), float32]) */;
  %2 = %1.0 /* ty=Tensor[(1, 16, 224, 224), float32] */;
  nn.relu(%2) /* ty=Tensor[(1, 16, 224, 224), float32] */
} /* ty=fn (Tensor[(1, 3, 224, 224), float32], Tensor[(16, 3, 3, 3), float32], Tensor[(16), float32], Tensor[(16), float32], Tensor[(16), float32], Tensor[(16), float32]) -> Tensor[(1, 16, 224, 224), float32] */
lib.function_metadata
{"tvmgen_default_fused_nn_conv2d_multiply_add_nn_relu": FunctionInfoNode(
workspace_sizes={cuda -keys=cuda,gpu -arch=sm_75 -max_num_threads=1024 -thread_warp_size=32: 768},
  io_sizes={cuda -keys=cuda,gpu -arch=sm_75 -max_num_threads=1024 -thread_warp_size=32: 3211264},
  constant_sizes={cuda -keys=cuda,gpu -arch=sm_75 -max_num_threads=1024 -thread_warp_size=32: 0},
  tir_primfuncs={cuda -keys=cuda,gpu -arch=sm_75 -max_num_threads=1024 -thread_warp_size=32: PrimFunc([placeholder, placeholder, placeholder, placeholder, T_relu]) attrs={"from_legacy_te_schedule": (bool)1, "global_symbol": "tvmgen_default_fused_nn_conv2d_multiply_add_nn_relu", "tir.noalias": (bool)1, "hash": "97c4f8c60220fadf"} {
  // attr [iter_var(blockIdx.z, , blockIdx.z)] thread_extent = 1
  allocate conv2d_nchw[float32 * 28], storage_scope = local
  allocate pad_temp.shared[float32 * 114], storage_scope = shared
  allocate placeholder.shared[float32 * 48], storage_scope = shared
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 224
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 2
  // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 4
  // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
  conv2d_nchw[0] = 0f
  conv2d_nchw[14] = 0f
  conv2d_nchw[2] = 0f
  conv2d_nchw[16] = 0f
  conv2d_nchw[4] = 0f
  conv2d_nchw[18] = 0f
  conv2d_nchw[6] = 0f
  conv2d_nchw[20] = 0f
  conv2d_nchw[8] = 0f
  conv2d_nchw[22] = 0f
  conv2d_nchw[10] = 0f
  conv2d_nchw[24] = 0f
  conv2d_nchw[12] = 0f
  conv2d_nchw[26] = 0f
  conv2d_nchw[1] = 0f
  conv2d_nchw[15] = 0f
  conv2d_nchw[3] = 0f
  conv2d_nchw[17] = 0f
  conv2d_nchw[5] = 0f
  conv2d_nchw[19] = 0f
  conv2d_nchw[7] = 0f
  conv2d_nchw[21] = 0f
  conv2d_nchw[9] = 0f
  conv2d_nchw[23] = 0f
  conv2d_nchw[11] = 0f
  conv2d_nchw[25] = 0f
  conv2d_nchw[13] = 0f
  conv2d_nchw[27] = 0f
  for (rc.outer, 0, 3) {
    // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 4
    // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
    // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
    if ((((threadIdx.z*29) + (threadIdx.x*2)) < 114)) {
      if ((threadIdx.x < 15)) {
        pad_temp.shared[((threadIdx.z*29) + (threadIdx.x*2))] = tir.if_then_else((((1 <= blockIdx.y) && (1 <= (((blockIdx.x*112) + (threadIdx.z*29)) + (threadIdx.x*2)))) && ((((blockIdx.x*112) + (threadIdx.z*29)) + (threadIdx.x*2)) < 225)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (blockIdx.x*112)) + (threadIdx.z*29)) + (threadIdx.x*2)) - 225)], 0f)
      }
    }
    if ((((threadIdx.z*29) + (threadIdx.x*2)) < 113)) {
      if ((threadIdx.x < 14)) {
        pad_temp.shared[(((threadIdx.z*29) + (threadIdx.x*2)) + 1)] = tir.if_then_else(((1 <= blockIdx.y) && ((((blockIdx.x*112) + (threadIdx.z*29)) + (threadIdx.x*2)) < 224)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (blockIdx.x*112)) + (threadIdx.z*29)) + (threadIdx.x*2)) - 224)], 0f)
      }
    }
    // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 4
    // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
    // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
    if (((floordiv(threadIdx.x, 12) + threadIdx.z) < 4)) {
      if ((threadIdx.x < 12)) {
        placeholder.shared[((threadIdx.z*12) + threadIdx.x)] = placeholder[((((threadIdx.z*108) + (floordiv(threadIdx.x, 3)*27)) + (rc.outer*9)) + floormod(threadIdx.x, 3))]
      }
    }
    conv2d_nchw[0] = (conv2d_nchw[0] + (pad_temp.shared[threadIdx.x]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[14] = (conv2d_nchw[14] + (pad_temp.shared[threadIdx.x]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[2] = (conv2d_nchw[2] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[16] = (conv2d_nchw[16] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[4] = (conv2d_nchw[4] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[18] = (conv2d_nchw[18] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[6] = (conv2d_nchw[6] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[20] = (conv2d_nchw[20] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[8] = (conv2d_nchw[8] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[22] = (conv2d_nchw[22] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[10] = (conv2d_nchw[10] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[24] = (conv2d_nchw[24] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[12] = (conv2d_nchw[12] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[26] = (conv2d_nchw[26] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[1] = (conv2d_nchw[1] + (pad_temp.shared[threadIdx.x]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[15] = (conv2d_nchw[15] + (pad_temp.shared[threadIdx.x]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[3] = (conv2d_nchw[3] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[17] = (conv2d_nchw[17] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[5] = (conv2d_nchw[5] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[19] = (conv2d_nchw[19] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[7] = (conv2d_nchw[7] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[21] = (conv2d_nchw[21] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[9] = (conv2d_nchw[9] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[23] = (conv2d_nchw[23] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[11] = (conv2d_nchw[11] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[25] = (conv2d_nchw[25] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[13] = (conv2d_nchw[13] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[27] = (conv2d_nchw[27] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[0] = (conv2d_nchw[0] + (pad_temp.shared[(threadIdx.x + 1)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[14] = (conv2d_nchw[14] + (pad_temp.shared[(threadIdx.x + 1)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[2] = (conv2d_nchw[2] + (pad_temp.shared[(threadIdx.x + 17)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[16] = (conv2d_nchw[16] + (pad_temp.shared[(threadIdx.x + 17)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[4] = (conv2d_nchw[4] + (pad_temp.shared[(threadIdx.x + 33)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[18] = (conv2d_nchw[18] + (pad_temp.shared[(threadIdx.x + 33)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[6] = (conv2d_nchw[6] + (pad_temp.shared[(threadIdx.x + 49)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[20] = (conv2d_nchw[20] + (pad_temp.shared[(threadIdx.x + 49)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[8] = (conv2d_nchw[8] + (pad_temp.shared[(threadIdx.x + 65)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[22] = (conv2d_nchw[22] + (pad_temp.shared[(threadIdx.x + 65)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[10] = (conv2d_nchw[10] + (pad_temp.shared[(threadIdx.x + 81)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[24] = (conv2d_nchw[24] + (pad_temp.shared[(threadIdx.x + 81)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[12] = (conv2d_nchw[12] + (pad_temp.shared[(threadIdx.x + 97)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[26] = (conv2d_nchw[26] + (pad_temp.shared[(threadIdx.x + 97)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[1] = (conv2d_nchw[1] + (pad_temp.shared[(threadIdx.x + 1)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[15] = (conv2d_nchw[15] + (pad_temp.shared[(threadIdx.x + 1)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[3] = (conv2d_nchw[3] + (pad_temp.shared[(threadIdx.x + 17)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[17] = (conv2d_nchw[17] + (pad_temp.shared[(threadIdx.x + 17)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[5] = (conv2d_nchw[5] + (pad_temp.shared[(threadIdx.x + 33)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[19] = (conv2d_nchw[19] + (pad_temp.shared[(threadIdx.x + 33)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[7] = (conv2d_nchw[7] + (pad_temp.shared[(threadIdx.x + 49)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[21] = (conv2d_nchw[21] + (pad_temp.shared[(threadIdx.x + 49)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[9] = (conv2d_nchw[9] + (pad_temp.shared[(threadIdx.x + 65)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[23] = (conv2d_nchw[23] + (pad_temp.shared[(threadIdx.x + 65)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[11] = (conv2d_nchw[11] + (pad_temp.shared[(threadIdx.x + 81)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[25] = (conv2d_nchw[25] + (pad_temp.shared[(threadIdx.x + 81)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[13] = (conv2d_nchw[13] + (pad_temp.shared[(threadIdx.x + 97)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[27] = (conv2d_nchw[27] + (pad_temp.shared[(threadIdx.x + 97)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[0] = (conv2d_nchw[0] + (pad_temp.shared[(threadIdx.x + 2)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[14] = (conv2d_nchw[14] + (pad_temp.shared[(threadIdx.x + 2)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[2] = (conv2d_nchw[2] + (pad_temp.shared[(threadIdx.x + 18)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[16] = (conv2d_nchw[16] + (pad_temp.shared[(threadIdx.x + 18)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[4] = (conv2d_nchw[4] + (pad_temp.shared[(threadIdx.x + 34)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[18] = (conv2d_nchw[18] + (pad_temp.shared[(threadIdx.x + 34)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[6] = (conv2d_nchw[6] + (pad_temp.shared[(threadIdx.x + 50)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[20] = (conv2d_nchw[20] + (pad_temp.shared[(threadIdx.x + 50)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[8] = (conv2d_nchw[8] + (pad_temp.shared[(threadIdx.x + 66)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[22] = (conv2d_nchw[22] + (pad_temp.shared[(threadIdx.x + 66)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[10] = (conv2d_nchw[10] + (pad_temp.shared[(threadIdx.x + 82)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[24] = (conv2d_nchw[24] + (pad_temp.shared[(threadIdx.x + 82)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[12] = (conv2d_nchw[12] + (pad_temp.shared[(threadIdx.x + 98)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[26] = (conv2d_nchw[26] + (pad_temp.shared[(threadIdx.x + 98)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[1] = (conv2d_nchw[1] + (pad_temp.shared[(threadIdx.x + 2)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[15] = (conv2d_nchw[15] + (pad_temp.shared[(threadIdx.x + 2)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[3] = (conv2d_nchw[3] + (pad_temp.shared[(threadIdx.x + 18)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[17] = (conv2d_nchw[17] + (pad_temp.shared[(threadIdx.x + 18)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[5] = (conv2d_nchw[5] + (pad_temp.shared[(threadIdx.x + 34)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[19] = (conv2d_nchw[19] + (pad_temp.shared[(threadIdx.x + 34)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[7] = (conv2d_nchw[7] + (pad_temp.shared[(threadIdx.x + 50)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[21] = (conv2d_nchw[21] + (pad_temp.shared[(threadIdx.x + 50)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[9] = (conv2d_nchw[9] + (pad_temp.shared[(threadIdx.x + 66)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[23] = (conv2d_nchw[23] + (pad_temp.shared[(threadIdx.x + 66)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[11] = (conv2d_nchw[11] + (pad_temp.shared[(threadIdx.x + 82)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[25] = (conv2d_nchw[25] + (pad_temp.shared[(threadIdx.x + 82)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[13] = (conv2d_nchw[13] + (pad_temp.shared[(threadIdx.x + 98)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[27] = (conv2d_nchw[27] + (pad_temp.shared[(threadIdx.x + 98)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 4
    // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
    // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
    if ((((threadIdx.z*29) + (threadIdx.x*2)) < 114)) {
      if ((threadIdx.x < 15)) {
        pad_temp.shared[((threadIdx.z*29) + (threadIdx.x*2))] = tir.if_then_else(((1 <= (((blockIdx.x*112) + (threadIdx.z*29)) + (threadIdx.x*2))) && ((((blockIdx.x*112) + (threadIdx.z*29)) + (threadIdx.x*2)) < 225)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (blockIdx.x*112)) + (threadIdx.z*29)) + (threadIdx.x*2)) - 1)], 0f)
      }
    }
    if ((((threadIdx.z*29) + (threadIdx.x*2)) < 113)) {
      if ((threadIdx.x < 14)) {
        pad_temp.shared[(((threadIdx.z*29) + (threadIdx.x*2)) + 1)] = tir.if_then_else(((((blockIdx.x*112) + (threadIdx.z*29)) + (threadIdx.x*2)) < 224), placeholder[(((((rc.outer*50176) + (blockIdx.y*224)) + (blockIdx.x*112)) + (threadIdx.z*29)) + (threadIdx.x*2))], 0f)
      }
    }
    // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 4
    // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
    // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
    if (((floordiv(threadIdx.x, 12) + threadIdx.z) < 4)) {
      if ((threadIdx.x < 12)) {
        placeholder.shared[((threadIdx.z*12) + threadIdx.x)] = placeholder[(((((threadIdx.z*108) + (floordiv(threadIdx.x, 3)*27)) + (rc.outer*9)) + floormod(threadIdx.x, 3)) + 3)]
      }
    }
    conv2d_nchw[0] = (conv2d_nchw[0] + (pad_temp.shared[threadIdx.x]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[14] = (conv2d_nchw[14] + (pad_temp.shared[threadIdx.x]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[2] = (conv2d_nchw[2] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[16] = (conv2d_nchw[16] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[4] = (conv2d_nchw[4] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[18] = (conv2d_nchw[18] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[6] = (conv2d_nchw[6] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[20] = (conv2d_nchw[20] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[8] = (conv2d_nchw[8] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[22] = (conv2d_nchw[22] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[10] = (conv2d_nchw[10] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[24] = (conv2d_nchw[24] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[12] = (conv2d_nchw[12] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[26] = (conv2d_nchw[26] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[1] = (conv2d_nchw[1] + (pad_temp.shared[threadIdx.x]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[15] = (conv2d_nchw[15] + (pad_temp.shared[threadIdx.x]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[3] = (conv2d_nchw[3] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[17] = (conv2d_nchw[17] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[5] = (conv2d_nchw[5] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[19] = (conv2d_nchw[19] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[7] = (conv2d_nchw[7] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[21] = (conv2d_nchw[21] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[9] = (conv2d_nchw[9] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[23] = (conv2d_nchw[23] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[11] = (conv2d_nchw[11] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[25] = (conv2d_nchw[25] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[13] = (conv2d_nchw[13] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[27] = (conv2d_nchw[27] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[0] = (conv2d_nchw[0] + (pad_temp.shared[(threadIdx.x + 1)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[14] = (conv2d_nchw[14] + (pad_temp.shared[(threadIdx.x + 1)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[2] = (conv2d_nchw[2] + (pad_temp.shared[(threadIdx.x + 17)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[16] = (conv2d_nchw[16] + (pad_temp.shared[(threadIdx.x + 17)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[4] = (conv2d_nchw[4] + (pad_temp.shared[(threadIdx.x + 33)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[18] = (conv2d_nchw[18] + (pad_temp.shared[(threadIdx.x + 33)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[6] = (conv2d_nchw[6] + (pad_temp.shared[(threadIdx.x + 49)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[20] = (conv2d_nchw[20] + (pad_temp.shared[(threadIdx.x + 49)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[8] = (conv2d_nchw[8] + (pad_temp.shared[(threadIdx.x + 65)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[22] = (conv2d_nchw[22] + (pad_temp.shared[(threadIdx.x + 65)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[10] = (conv2d_nchw[10] + (pad_temp.shared[(threadIdx.x + 81)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[24] = (conv2d_nchw[24] + (pad_temp.shared[(threadIdx.x + 81)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[12] = (conv2d_nchw[12] + (pad_temp.shared[(threadIdx.x + 97)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[26] = (conv2d_nchw[26] + (pad_temp.shared[(threadIdx.x + 97)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[1] = (conv2d_nchw[1] + (pad_temp.shared[(threadIdx.x + 1)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[15] = (conv2d_nchw[15] + (pad_temp.shared[(threadIdx.x + 1)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[3] = (conv2d_nchw[3] + (pad_temp.shared[(threadIdx.x + 17)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[17] = (conv2d_nchw[17] + (pad_temp.shared[(threadIdx.x + 17)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[5] = (conv2d_nchw[5] + (pad_temp.shared[(threadIdx.x + 33)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[19] = (conv2d_nchw[19] + (pad_temp.shared[(threadIdx.x + 33)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[7] = (conv2d_nchw[7] + (pad_temp.shared[(threadIdx.x + 49)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[21] = (conv2d_nchw[21] + (pad_temp.shared[(threadIdx.x + 49)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[9] = (conv2d_nchw[9] + (pad_temp.shared[(threadIdx.x + 65)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[23] = (conv2d_nchw[23] + (pad_temp.shared[(threadIdx.x + 65)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[11] = (conv2d_nchw[11] + (pad_temp.shared[(threadIdx.x + 81)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[25] = (conv2d_nchw[25] + (pad_temp.shared[(threadIdx.x + 81)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[13] = (conv2d_nchw[13] + (pad_temp.shared[(threadIdx.x + 97)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[27] = (conv2d_nchw[27] + (pad_temp.shared[(threadIdx.x + 97)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[0] = (conv2d_nchw[0] + (pad_temp.shared[(threadIdx.x + 2)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[14] = (conv2d_nchw[14] + (pad_temp.shared[(threadIdx.x + 2)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[2] = (conv2d_nchw[2] + (pad_temp.shared[(threadIdx.x + 18)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[16] = (conv2d_nchw[16] + (pad_temp.shared[(threadIdx.x + 18)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[4] = (conv2d_nchw[4] + (pad_temp.shared[(threadIdx.x + 34)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[18] = (conv2d_nchw[18] + (pad_temp.shared[(threadIdx.x + 34)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[6] = (conv2d_nchw[6] + (pad_temp.shared[(threadIdx.x + 50)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[20] = (conv2d_nchw[20] + (pad_temp.shared[(threadIdx.x + 50)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[8] = (conv2d_nchw[8] + (pad_temp.shared[(threadIdx.x + 66)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[22] = (conv2d_nchw[22] + (pad_temp.shared[(threadIdx.x + 66)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[10] = (conv2d_nchw[10] + (pad_temp.shared[(threadIdx.x + 82)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[24] = (conv2d_nchw[24] + (pad_temp.shared[(threadIdx.x + 82)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[12] = (conv2d_nchw[12] + (pad_temp.shared[(threadIdx.x + 98)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[26] = (conv2d_nchw[26] + (pad_temp.shared[(threadIdx.x + 98)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[1] = (conv2d_nchw[1] + (pad_temp.shared[(threadIdx.x + 2)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[15] = (conv2d_nchw[15] + (pad_temp.shared[(threadIdx.x + 2)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[3] = (conv2d_nchw[3] + (pad_temp.shared[(threadIdx.x + 18)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[17] = (conv2d_nchw[17] + (pad_temp.shared[(threadIdx.x + 18)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[5] = (conv2d_nchw[5] + (pad_temp.shared[(threadIdx.x + 34)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[19] = (conv2d_nchw[19] + (pad_temp.shared[(threadIdx.x + 34)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[7] = (conv2d_nchw[7] + (pad_temp.shared[(threadIdx.x + 50)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[21] = (conv2d_nchw[21] + (pad_temp.shared[(threadIdx.x + 50)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[9] = (conv2d_nchw[9] + (pad_temp.shared[(threadIdx.x + 66)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[23] = (conv2d_nchw[23] + (pad_temp.shared[(threadIdx.x + 66)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[11] = (conv2d_nchw[11] + (pad_temp.shared[(threadIdx.x + 82)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[25] = (conv2d_nchw[25] + (pad_temp.shared[(threadIdx.x + 82)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[13] = (conv2d_nchw[13] + (pad_temp.shared[(threadIdx.x + 98)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[27] = (conv2d_nchw[27] + (pad_temp.shared[(threadIdx.x + 98)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 4
    // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
    // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
    if ((((threadIdx.z*29) + (threadIdx.x*2)) < 114)) {
      if ((threadIdx.x < 15)) {
        pad_temp.shared[((threadIdx.z*29) + (threadIdx.x*2))] = tir.if_then_else((((blockIdx.y < 223) && (1 <= (((blockIdx.x*112) + (threadIdx.z*29)) + (threadIdx.x*2)))) && ((((blockIdx.x*112) + (threadIdx.z*29)) + (threadIdx.x*2)) < 225)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (blockIdx.x*112)) + (threadIdx.z*29)) + (threadIdx.x*2)) + 223)], 0f)
      }
    }
    if ((((threadIdx.z*29) + (threadIdx.x*2)) < 113)) {
      if ((threadIdx.x < 14)) {
        pad_temp.shared[(((threadIdx.z*29) + (threadIdx.x*2)) + 1)] = tir.if_then_else(((blockIdx.y < 223) && ((((blockIdx.x*112) + (threadIdx.z*29)) + (threadIdx.x*2)) < 224)), placeholder[((((((rc.outer*50176) + (blockIdx.y*224)) + (blockIdx.x*112)) + (threadIdx.z*29)) + (threadIdx.x*2)) + 224)], 0f)
      }
    }
    // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 4
    // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
    // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 16
    if (((floordiv(threadIdx.x, 12) + threadIdx.z) < 4)) {
      if ((threadIdx.x < 12)) {
        placeholder.shared[((threadIdx.z*12) + threadIdx.x)] = placeholder[(((((threadIdx.z*108) + (floordiv(threadIdx.x, 3)*27)) + (rc.outer*9)) + floormod(threadIdx.x, 3)) + 6)]
      }
    }
    conv2d_nchw[0] = (conv2d_nchw[0] + (pad_temp.shared[threadIdx.x]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[14] = (conv2d_nchw[14] + (pad_temp.shared[threadIdx.x]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[2] = (conv2d_nchw[2] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[16] = (conv2d_nchw[16] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[4] = (conv2d_nchw[4] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[18] = (conv2d_nchw[18] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[6] = (conv2d_nchw[6] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[20] = (conv2d_nchw[20] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[8] = (conv2d_nchw[8] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[22] = (conv2d_nchw[22] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[10] = (conv2d_nchw[10] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[24] = (conv2d_nchw[24] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[12] = (conv2d_nchw[12] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[(threadIdx.z*6)]))
    conv2d_nchw[26] = (conv2d_nchw[26] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[((threadIdx.z*6) + 24)]))
    conv2d_nchw[1] = (conv2d_nchw[1] + (pad_temp.shared[threadIdx.x]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[15] = (conv2d_nchw[15] + (pad_temp.shared[threadIdx.x]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[3] = (conv2d_nchw[3] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[17] = (conv2d_nchw[17] + (pad_temp.shared[(threadIdx.x + 16)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[5] = (conv2d_nchw[5] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[19] = (conv2d_nchw[19] + (pad_temp.shared[(threadIdx.x + 32)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[7] = (conv2d_nchw[7] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[21] = (conv2d_nchw[21] + (pad_temp.shared[(threadIdx.x + 48)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[9] = (conv2d_nchw[9] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[23] = (conv2d_nchw[23] + (pad_temp.shared[(threadIdx.x + 64)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[11] = (conv2d_nchw[11] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[25] = (conv2d_nchw[25] + (pad_temp.shared[(threadIdx.x + 80)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[13] = (conv2d_nchw[13] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[((threadIdx.z*6) + 3)]))
    conv2d_nchw[27] = (conv2d_nchw[27] + (pad_temp.shared[(threadIdx.x + 96)]*placeholder.shared[((threadIdx.z*6) + 27)]))
    conv2d_nchw[0] = (conv2d_nchw[0] + (pad_temp.shared[(threadIdx.x + 1)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[14] = (conv2d_nchw[14] + (pad_temp.shared[(threadIdx.x + 1)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[2] = (conv2d_nchw[2] + (pad_temp.shared[(threadIdx.x + 17)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[16] = (conv2d_nchw[16] + (pad_temp.shared[(threadIdx.x + 17)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[4] = (conv2d_nchw[4] + (pad_temp.shared[(threadIdx.x + 33)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[18] = (conv2d_nchw[18] + (pad_temp.shared[(threadIdx.x + 33)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[6] = (conv2d_nchw[6] + (pad_temp.shared[(threadIdx.x + 49)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[20] = (conv2d_nchw[20] + (pad_temp.shared[(threadIdx.x + 49)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[8] = (conv2d_nchw[8] + (pad_temp.shared[(threadIdx.x + 65)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[22] = (conv2d_nchw[22] + (pad_temp.shared[(threadIdx.x + 65)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[10] = (conv2d_nchw[10] + (pad_temp.shared[(threadIdx.x + 81)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[24] = (conv2d_nchw[24] + (pad_temp.shared[(threadIdx.x + 81)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[12] = (conv2d_nchw[12] + (pad_temp.shared[(threadIdx.x + 97)]*placeholder.shared[((threadIdx.z*6) + 1)]))
    conv2d_nchw[26] = (conv2d_nchw[26] + (pad_temp.shared[(threadIdx.x + 97)]*placeholder.shared[((threadIdx.z*6) + 25)]))
    conv2d_nchw[1] = (conv2d_nchw[1] + (pad_temp.shared[(threadIdx.x + 1)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[15] = (conv2d_nchw[15] + (pad_temp.shared[(threadIdx.x + 1)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[3] = (conv2d_nchw[3] + (pad_temp.shared[(threadIdx.x + 17)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[17] = (conv2d_nchw[17] + (pad_temp.shared[(threadIdx.x + 17)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[5] = (conv2d_nchw[5] + (pad_temp.shared[(threadIdx.x + 33)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[19] = (conv2d_nchw[19] + (pad_temp.shared[(threadIdx.x + 33)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[7] = (conv2d_nchw[7] + (pad_temp.shared[(threadIdx.x + 49)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[21] = (conv2d_nchw[21] + (pad_temp.shared[(threadIdx.x + 49)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[9] = (conv2d_nchw[9] + (pad_temp.shared[(threadIdx.x + 65)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[23] = (conv2d_nchw[23] + (pad_temp.shared[(threadIdx.x + 65)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[11] = (conv2d_nchw[11] + (pad_temp.shared[(threadIdx.x + 81)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[25] = (conv2d_nchw[25] + (pad_temp.shared[(threadIdx.x + 81)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[13] = (conv2d_nchw[13] + (pad_temp.shared[(threadIdx.x + 97)]*placeholder.shared[((threadIdx.z*6) + 4)]))
    conv2d_nchw[27] = (conv2d_nchw[27] + (pad_temp.shared[(threadIdx.x + 97)]*placeholder.shared[((threadIdx.z*6) + 28)]))
    conv2d_nchw[0] = (conv2d_nchw[0] + (pad_temp.shared[(threadIdx.x + 2)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[14] = (conv2d_nchw[14] + (pad_temp.shared[(threadIdx.x + 2)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[2] = (conv2d_nchw[2] + (pad_temp.shared[(threadIdx.x + 18)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[16] = (conv2d_nchw[16] + (pad_temp.shared[(threadIdx.x + 18)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[4] = (conv2d_nchw[4] + (pad_temp.shared[(threadIdx.x + 34)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[18] = (conv2d_nchw[18] + (pad_temp.shared[(threadIdx.x + 34)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[6] = (conv2d_nchw[6] + (pad_temp.shared[(threadIdx.x + 50)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[20] = (conv2d_nchw[20] + (pad_temp.shared[(threadIdx.x + 50)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[8] = (conv2d_nchw[8] + (pad_temp.shared[(threadIdx.x + 66)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[22] = (conv2d_nchw[22] + (pad_temp.shared[(threadIdx.x + 66)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[10] = (conv2d_nchw[10] + (pad_temp.shared[(threadIdx.x + 82)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[24] = (conv2d_nchw[24] + (pad_temp.shared[(threadIdx.x + 82)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[12] = (conv2d_nchw[12] + (pad_temp.shared[(threadIdx.x + 98)]*placeholder.shared[((threadIdx.z*6) + 2)]))
    conv2d_nchw[26] = (conv2d_nchw[26] + (pad_temp.shared[(threadIdx.x + 98)]*placeholder.shared[((threadIdx.z*6) + 26)]))
    conv2d_nchw[1] = (conv2d_nchw[1] + (pad_temp.shared[(threadIdx.x + 2)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[15] = (conv2d_nchw[15] + (pad_temp.shared[(threadIdx.x + 2)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[3] = (conv2d_nchw[3] + (pad_temp.shared[(threadIdx.x + 18)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[17] = (conv2d_nchw[17] + (pad_temp.shared[(threadIdx.x + 18)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[5] = (conv2d_nchw[5] + (pad_temp.shared[(threadIdx.x + 34)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[19] = (conv2d_nchw[19] + (pad_temp.shared[(threadIdx.x + 34)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[7] = (conv2d_nchw[7] + (pad_temp.shared[(threadIdx.x + 50)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[21] = (conv2d_nchw[21] + (pad_temp.shared[(threadIdx.x + 50)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[9] = (conv2d_nchw[9] + (pad_temp.shared[(threadIdx.x + 66)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[23] = (conv2d_nchw[23] + (pad_temp.shared[(threadIdx.x + 66)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[11] = (conv2d_nchw[11] + (pad_temp.shared[(threadIdx.x + 82)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[25] = (conv2d_nchw[25] + (pad_temp.shared[(threadIdx.x + 82)]*placeholder.shared[((threadIdx.z*6) + 29)]))
    conv2d_nchw[13] = (conv2d_nchw[13] + (pad_temp.shared[(threadIdx.x + 98)]*placeholder.shared[((threadIdx.z*6) + 5)]))
    conv2d_nchw[27] = (conv2d_nchw[27] + (pad_temp.shared[(threadIdx.x + 98)]*placeholder.shared[((threadIdx.z*6) + 29)]))
  }
  T_relu[((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x)] = max(((conv2d_nchw[0]*placeholder[(threadIdx.z*2)]) + placeholder[(threadIdx.z*2)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 401408)] = max(((conv2d_nchw[14]*placeholder[((threadIdx.z*2) + 8)]) + placeholder[((threadIdx.z*2) + 8)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 16)] = max(((conv2d_nchw[2]*placeholder[(threadIdx.z*2)]) + placeholder[(threadIdx.z*2)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 401424)] = max(((conv2d_nchw[16]*placeholder[((threadIdx.z*2) + 8)]) + placeholder[((threadIdx.z*2) + 8)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 32)] = max(((conv2d_nchw[4]*placeholder[(threadIdx.z*2)]) + placeholder[(threadIdx.z*2)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 401440)] = max(((conv2d_nchw[18]*placeholder[((threadIdx.z*2) + 8)]) + placeholder[((threadIdx.z*2) + 8)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 48)] = max(((conv2d_nchw[6]*placeholder[(threadIdx.z*2)]) + placeholder[(threadIdx.z*2)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 401456)] = max(((conv2d_nchw[20]*placeholder[((threadIdx.z*2) + 8)]) + placeholder[((threadIdx.z*2) + 8)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 64)] = max(((conv2d_nchw[8]*placeholder[(threadIdx.z*2)]) + placeholder[(threadIdx.z*2)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 401472)] = max(((conv2d_nchw[22]*placeholder[((threadIdx.z*2) + 8)]) + placeholder[((threadIdx.z*2) + 8)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 80)] = max(((conv2d_nchw[10]*placeholder[(threadIdx.z*2)]) + placeholder[(threadIdx.z*2)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 401488)] = max(((conv2d_nchw[24]*placeholder[((threadIdx.z*2) + 8)]) + placeholder[((threadIdx.z*2) + 8)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 96)] = max(((conv2d_nchw[12]*placeholder[(threadIdx.z*2)]) + placeholder[(threadIdx.z*2)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 401504)] = max(((conv2d_nchw[26]*placeholder[((threadIdx.z*2) + 8)]) + placeholder[((threadIdx.z*2) + 8)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50176)] = max(((conv2d_nchw[1]*placeholder[((threadIdx.z*2) + 1)]) + placeholder[((threadIdx.z*2) + 1)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 451584)] = max(((conv2d_nchw[15]*placeholder[((threadIdx.z*2) + 9)]) + placeholder[((threadIdx.z*2) + 9)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50192)] = max(((conv2d_nchw[3]*placeholder[((threadIdx.z*2) + 1)]) + placeholder[((threadIdx.z*2) + 1)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 451600)] = max(((conv2d_nchw[17]*placeholder[((threadIdx.z*2) + 9)]) + placeholder[((threadIdx.z*2) + 9)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50208)] = max(((conv2d_nchw[5]*placeholder[((threadIdx.z*2) + 1)]) + placeholder[((threadIdx.z*2) + 1)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 451616)] = max(((conv2d_nchw[19]*placeholder[((threadIdx.z*2) + 9)]) + placeholder[((threadIdx.z*2) + 9)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50224)] = max(((conv2d_nchw[7]*placeholder[((threadIdx.z*2) + 1)]) + placeholder[((threadIdx.z*2) + 1)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 451632)] = max(((conv2d_nchw[21]*placeholder[((threadIdx.z*2) + 9)]) + placeholder[((threadIdx.z*2) + 9)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50240)] = max(((conv2d_nchw[9]*placeholder[((threadIdx.z*2) + 1)]) + placeholder[((threadIdx.z*2) + 1)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 451648)] = max(((conv2d_nchw[23]*placeholder[((threadIdx.z*2) + 9)]) + placeholder[((threadIdx.z*2) + 9)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50256)] = max(((conv2d_nchw[11]*placeholder[((threadIdx.z*2) + 1)]) + placeholder[((threadIdx.z*2) + 1)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 451664)] = max(((conv2d_nchw[25]*placeholder[((threadIdx.z*2) + 9)]) + placeholder[((threadIdx.z*2) + 9)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 50272)] = max(((conv2d_nchw[13]*placeholder[((threadIdx.z*2) + 1)]) + placeholder[((threadIdx.z*2) + 1)]), 0f)
  T_relu[(((((threadIdx.z*100352) + (blockIdx.y*224)) + (blockIdx.x*112)) + threadIdx.x) + 451680)] = max(((conv2d_nchw[27]*placeholder[((threadIdx.z*2) + 9)]) + placeholder[((threadIdx.z*2) + 9)]), 0f)
}
},
  relay_primfuncs={cuda -keys=cuda,gpu -arch=sm_75 -max_num_threads=1024 -thread_warp_size=32: fn (%p0: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] */, %p1: Tensor[(16, 3, 3, 3), float32] /* ty=Tensor[(16, 3, 3, 3), float32] */, %p2: Tensor[(16, 1, 1), float32] /* ty=Tensor[(16, 1, 1), float32] */, %p3: Tensor[(16, 1, 1), float32] /* ty=Tensor[(16, 1, 1), float32] */, target=meta[Target][0], prim_funcs={'tvmgen_default_fused_nn_conv2d_multiply_add_nn_relu'=meta[tir.PrimFunc][0]}, out_layout="", data_layout="NCHW", hash="97c4f8c60220fadf", kernel_layout="OIHW", prim_fn_var='tvmgen_default_fused_nn_conv2d_multiply_add_nn_relu', Primitive=1) -> Tensor[(1, 16, 224, 224), float32] {
  %0 = nn.conv2d(%p0, %p1, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
  %1 = multiply(%0, %p2) /* ty=Tensor[(1, 16, 224, 224), float32] */;
  %2 = add(%1, %p3) /* ty=Tensor[(1, 16, 224, 224), float32] */;
  nn.relu(%2) /* ty=Tensor[(1, 16, 224, 224), float32] */
} /* ty=fn (Tensor[(1, 3, 224, 224), float32], Tensor[(16, 3, 3, 3), float32], Tensor[(16, 1, 1), float32], Tensor[(16, 1, 1), float32]) -> Tensor[(1, 16, 224, 224), float32] */
}), "__tvm_main__": FunctionInfoNode(
workspace_sizes={cuda -keys=cuda,gpu -arch=sm_75 -max_num_threads=1024 -thread_warp_size=32: 0},
  io_sizes={cuda -keys=cuda,gpu -arch=sm_75 -max_num_threads=1024 -thread_warp_size=32: 3813376},
  constant_sizes={cuda -keys=cuda,gpu -arch=sm_75 -max_num_threads=1024 -thread_warp_size=32: 1856},
  tir_primfuncs={},
  relay_primfuncs={cuda -keys=cuda,gpu -arch=sm_75 -max_num_threads=1024 -thread_warp_size=32: fn (%data {virtual_device=VirtualDevice(device_type=2, virtual_device_id=0, target=Target(kind='cuda', keys={'cuda', 'gpu'}, attrs={'thread_warp_size': 32, 'max_num_threads': 1024, 'arch': "sm_75"}, host=Target(kind='llvm', keys={'cpu'}, attrs={'link-params': (bool)0})))}: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] */, hash="9f4aa0477d7fec53", executor=meta[Executor][0], kernel_layout="OIHW", data_layout="NCHW", out_layout="", runtime=meta[Runtime][0], virtual_device=VirtualDevice(device_type=2, virtual_device_id=0, target=Target(kind='cuda', keys={'cuda', 'gpu'}, attrs={'thread_warp_size': 32, 'max_num_threads': 1024, 'arch': "sm_75"}, host=Target(kind='llvm', keys={'cpu'}, attrs={'link-params': (bool)0})))) -> Tensor[(1, 16, 224, 224), float32] {
  %3 = fn (%p0: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] */, %p1: Tensor[(16, 3, 3, 3), float32] /* ty=Tensor[(16, 3, 3, 3), float32] */, %p2: Tensor[(16, 1, 1), float32] /* ty=Tensor[(16, 1, 1), float32] */, %p3: Tensor[(16, 1, 1), float32] /* ty=Tensor[(16, 1, 1), float32] */, hash="97c4f8c60220fadf", data_layout="NCHW", kernel_layout="OIHW", Primitive=1, out_layout="") -> Tensor[(1, 16, 224, 224), float32] {
    %0 = nn.conv2d(%p0, %p1, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
    %1 = multiply(%0, %p2) /* ty=Tensor[(1, 16, 224, 224), float32] */;
    %2 = add(%1, %p3) /* ty=Tensor[(1, 16, 224, 224), float32] */;
    nn.relu(%2) /* ty=Tensor[(1, 16, 224, 224), float32] */
  } /* ty=fn (Tensor[(1, 3, 224, 224), float32], Tensor[(16, 3, 3, 3), float32], Tensor[(16, 1, 1), float32], Tensor[(16, 1, 1), float32]) -> Tensor[(1, 16, 224, 224), float32] */;
  %3(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, meta[relay.Constant][1] /* ty=Tensor[(16, 1, 1), float32] */, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 224, 224), float32] */
} /* ty=fn (Tensor[(1, 3, 224, 224), float32]) -> Tensor[(1, 16, 224, 224), float32] */
})}

Use cuDNN for a convolutional layer#

We can replace the convolution kernels with cuDNN ones. To do that, all we need to do is append the option " -libs=cudnn" to the target string.

net, params = testing.create_workload(simple_net)
target = "cuda -libs=cudnn"  # use cudnn for convolution
lib = relay.build_module.build(net, target, params=params)

dev = tvm.device(target, 0)
data = np.random.uniform(-1, 1, size=data_shape).astype("float32")
module = runtime.GraphModule(lib["default"](dev))
module.set_input("data", data)
module.run()
out_shape = (batch_size, out_channels, 224, 224)
out = module.get_output(0, tvm.nd.empty(out_shape))
out_cudnn = out.numpy()
DEBUG:autotvm:Finish loading 825 records
INFO:te_compiler:Using injective.cpu for add based on highest priority (10)
/media/workspace/anaconda3/envs/mxnetx/lib/python3.10/site-packages/tvm/driver/build_module.py:263: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  warnings.warn(
INFO:te_compiler:Using injective.cpu for sqrt based on highest priority (10)
INFO:te_compiler:Using injective.cpu for divide based on highest priority (10)
INFO:te_compiler:Using injective.cpu for multiply based on highest priority (10)
INFO:te_compiler:Using injective.cpu for expand_dims based on highest priority (10)
INFO:te_compiler:Using injective.cpu for negative based on highest priority (10)
INFO:te_compiler:Using injective.cpu for multiply based on highest priority (10)
INFO:te_compiler:Using injective.cpu for add based on highest priority (10)
INFO:te_compiler:Using injective.cpu for expand_dims based on highest priority (10)
[09:19:32] /media/pc/data/4tb/lxw/books/tvm/src/runtime/contrib/cudnn/conv_forward.cc:135: 	CUDNN Found 8 fwd algorithms, choosing CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
[09:19:32] /media/pc/data/4tb/lxw/books/tvm/src/runtime/contrib/cudnn/conv_forward.cc:138: 		0) CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM - time: 0.046912 ms, Memory: 0
[09:19:32] /media/pc/data/4tb/lxw/books/tvm/src/runtime/contrib/cudnn/conv_forward.cc:138: 		1) CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM - time: 0.073984 ms, Memory: 304000
[09:19:32] /media/pc/data/4tb/lxw/books/tvm/src/runtime/contrib/cudnn/conv_forward.cc:138: 		2) CUDNN_CONVOLUTION_FWD_ALGO_GEMM - time: 0.08064 ms, Memory: 5419008
[09:19:32] /media/pc/data/4tb/lxw/books/tvm/src/runtime/contrib/cudnn/conv_forward.cc:138: 		3) CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD - time: 0.089344 ms, Memory: 19200
[09:19:32] /media/pc/data/4tb/lxw/books/tvm/src/runtime/contrib/cudnn/conv_forward.cc:138: 		4) CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM - time: 0.129024 ms, Memory: 304000
[09:19:32] /media/pc/data/4tb/lxw/books/tvm/src/runtime/contrib/cudnn/conv_forward.cc:138: 		5) CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING - time: 0.900576 ms, Memory: 374272
[09:19:32] /media/pc/data/4tb/lxw/books/tvm/src/runtime/contrib/cudnn/conv_forward.cc:138: 		6) CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED - time: 1.11293 ms, Memory: 137288448
[09:19:32] /media/pc/data/4tb/lxw/books/tvm/src/runtime/contrib/cudnn/conv_forward.cc:138: 		7) CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED - time: 1.11664 ms, Memory: 137288448
DEBUG:autotvm:Cannot find tuning records for:
    target=cuda -keys=cuda,gpu -arch=sm_75 -libs=cudnn -max_num_threads=1024 -thread_warp_size=32
    key=('conv2d_cudnn.cuda', ('TENSOR', (1, 3, 224, 224), 'float32'), ('TENSOR', (16, 3, 3, 3), 'float32'), (1, 1), (1, 1, 1, 1), (1, 1), 1, 'NCHW', 'float32')
TVM will apply a default schedule which may negatively impact performance.
INFO:te_compiler:Using conv2d_cudnn.cuda for nn.conv2d based on highest priority (25)
INFO:te_compiler:Using injective.cuda for multiply based on highest priority (10)
INFO:te_compiler:Using injective.cuda for add based on highest priority (10)
INFO:te_compiler:Using injective.cuda for nn.relu based on highest priority (10)

Note

If you use cuDNN, Relay cannot fuse the convolution with the layers following it. This is because layer fusion happens at the level of TVM internal representation (IR). Relay treats external libraries as black boxes, so there is no way to fuse them with TVM IR.

The pseudocode below shows that cuDNN convolution + bias add + batch norm + ReLU turn into two stages of computation, one for the cuDNN call and the other for the rest of the operations.

lib.ir_mod
#[version = "0.0.5"]
def @main(%data: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] */, %weight: Tensor[(16, 3, 3, 3), float32] /* ty=Tensor[(16, 3, 3, 3), float32] */, %bn_gamma: Tensor[(16), float32] /* ty=Tensor[(16), float32] */, %bn_beta: Tensor[(16), float32] /* ty=Tensor[(16), float32] */, %bn_mean: Tensor[(16), float32] /* ty=Tensor[(16), float32] */, %bn_var: Tensor[(16), float32] /* ty=Tensor[(16), float32] */) -> Tensor[(1, 16, 224, 224), float32] {
  %0 = nn.conv2d(%data, %weight, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 224, 224), float32] */;
  %1 = nn.batch_norm(%0, %bn_gamma, %bn_beta, %bn_mean, %bn_var) /* ty=(Tensor[(1, 16, 224, 224), float32], Tensor[(16), float32], Tensor[(16), float32]) */;
  %2 = %1.0 /* ty=Tensor[(1, 16, 224, 224), float32] */;
  nn.relu(%2) /* ty=Tensor[(1, 16, 224, 224), float32] */
}

Verify the result#

We can check that the results of the two runs match.

tvm.testing.assert_allclose(out_cuda, out_cudnn, rtol=1e-5)

Conclusion#

This tutorial has covered the use of cuDNN with Relay. cuBLAS is also supported. If cuBLAS is enabled, it will be used inside fully connected layers (relay.dense). To use cuBLAS, set the target string to "cuda -libs=cublas".
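
As a minimal sketch (the layer sizes below are made up for illustration, and TVM must have been built with cuBLAS enabled), a dense layer built this way would be handled by cuBLAS:

data = relay.var("data", relay.TensorType((1, 512), "float32"))
weight = relay.var("weight", relay.TensorType((1024, 512), "float32"))
dense = relay.nn.dense(data, weight)  # dispatched to cuBLAS under -libs=cublas
f = relay.Function(relay.analysis.free_vars(dense), dense)
net, params = testing.create_workload(f)
lib = relay.build_module.build(net, "cuda -libs=cublas", params=params)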

You can also use cuDNN and cuBLAS at the same time: "cuda -libs=cudnn,cublas".

For the ROCm backend, MIOpen and rocBLAS are supported. They can be enabled with the target "rocm -libs=miopen,rocblas".

Being able to use external libraries is great, but we need to keep in mind some caveats.

  • First, using an external library may restrict your usage of TVM and Relay.

    For example, MIOpen only supports NCHW layout and the fp32 data type at the moment, so you cannot use other layouts or data types in TVM.

  • Second, and more importantly, external libraries restrict the possibility of operator fusion during graph compilation, as shown above.

    TVM and Relay aim to achieve the best performance on a variety of hardware through joint operator-level and graph-level optimization. To achieve this goal, we should keep developing better optimizations for TVM and Relay, while using external libraries as a nice way to fall back to existing implementations when necessary.