{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2.5. TensorIR 练习"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import IPython\n",
    "import numpy as np\n",
    "import tvm\n",
    "from tvm.ir.module import IRModule\n",
    "from tvm.script import tir as T"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 如何编写 TensorIR\n",
    "\n",
    "### 然后编写向量加法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# init data\n",
    "a = np.arange(16).reshape(4, 4)\n",
    "b = np.arange(16, 0, -1).reshape(4, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[16, 16, 16, 16],\n",
       "       [16, 16, 16, 16],\n",
       "       [16, 16, 16, 16],\n",
       "       [16, 16, 16, 16]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# numpy version\n",
    "c_np = a + b\n",
    "c_np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[16, 16, 16, 16],\n",
       "       [16, 16, 16, 16],\n",
       "       [16, 16, 16, 16],\n",
       "       [16, 16, 16, 16]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# low-level numpy version\n",
    "def lnumpy_add(a: np.ndarray, b: np.ndarray, c: np.ndarray):\n",
    "  for i in range(4):\n",
    "    for j in range(4):\n",
    "      c[i, j] = a[i, j] + b[i, j]\n",
    "c_lnumpy = np.empty((4, 4), dtype=np.int64)\n",
    "lnumpy_add(a, b, c_lnumpy)\n",
    "c_lnumpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm.script.ir_module\n",
    "class MyAdd:\n",
    "    @T.prim_func\n",
    "    def add(A: T.Buffer[(4, 4), \"int64\"],\n",
    "            B: T.Buffer[(4, 4), \"int64\"],\n",
    "            C: T.Buffer[(4, 4), \"int64\"]):\n",
    "        T.func_attr({\"global_symbol\": \"add\"})\n",
    "        for i, j in T.grid(4, 4):\n",
    "            with T.block(\"C\"):\n",
    "                vi, vj = T.axis.remap(\"SS\", [i, j])\n",
    "                C[vi, vj] = A[vi, vj] + B[vi, vj]\n",
    "\n",
    "\n",
    "rt_lib = tvm.build(MyAdd, target=\"llvm\")\n",
    "a_tvm = tvm.nd.array(a)\n",
    "b_tvm = tvm.nd.array(b)\n",
    "c_tvm = tvm.nd.array(np.empty((4, 4), dtype=np.int64))\n",
    "rt_lib[\"add\"](a_tvm, b_tvm, c_tvm)\n",
    "np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 广播加法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# init data\n",
    "a = np.arange(16).reshape(4, 4)\n",
    "b = np.arange(4, 0, -1).reshape(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 4,  4,  4,  4],\n",
       "       [ 8,  8,  8,  8],\n",
       "       [12, 12, 12, 12],\n",
       "       [16, 16, 16, 16]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# numpy version\n",
    "c_np = a + b\n",
    "c_np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm.script.ir_module\n",
    "class MyAdd:\n",
    "    @T.prim_func\n",
    "    def add(A: T.Buffer[(4, 4), \"int64\"],\n",
    "            B: T.Buffer[(4), \"int64\"],\n",
    "            C: T.Buffer[(4, 4), \"int64\"]):\n",
    "        T.func_attr({\"global_symbol\": \"add\"})\n",
    "        for i, j in T.grid(4, 4):\n",
    "            with T.block(\"C\"):\n",
    "                vi, vj = T.axis.remap(\"SS\", [i, j])\n",
    "                C[vi, vj] = A[vi, vj] + B[vj]\n",
    "\n",
    "rt_lib = tvm.build(MyAdd, target=\"llvm\")\n",
    "a_tvm = tvm.nd.array(a)\n",
    "b_tvm = tvm.nd.array(b)\n",
    "c_tvm = tvm.nd.array(np.empty((4, 4), dtype=np.int64))\n",
    "rt_lib[\"add\"](a_tvm, b_tvm, c_tvm)\n",
    "np.testing.assert_allclose(c_tvm.numpy(), c_np, rtol=1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2D 卷积\n",
    "\n",
    "使用 NCHW 布局的卷积的数学定义:\n",
    "\n",
    "$$\n",
    "\\text{Conv}[b, k, i, j] =\n",
    "    \\sum_{d_i, d_j, q} A[b, q, \\text{strides} * i + d_i, \\text{strides} * j + d_j] * W[k, q, d_i, d_j],\n",
    "$$\n",
    "\n",
    "其中,$A$ 是输入张量,$W$ 是权重张量,$b$ 是批次索引,$k$ 是输出通道,$i$ 和 $j$ 是图像高度和宽度的索引,$d_i$ 和 $d_j$ 是权重的索引,$q$ 是输入通道,`strides` 是过滤器窗口的步幅。\n",
    "\n",
    "下面考虑简单的情况:`stride=1, padding=0`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3\n",
    "OUT_H, OUT_W = H - K + 1, W - K + 1\n",
    "data = np.arange(N*CI*H*W).reshape(N, CI, H, W)\n",
    "weight = np.arange(CO*CI*K*K).reshape(CO, CI, K, K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[[ 474,  510,  546,  582,  618,  654],\n",
       "         [ 762,  798,  834,  870,  906,  942],\n",
       "         [1050, 1086, 1122, 1158, 1194, 1230],\n",
       "         [1338, 1374, 1410, 1446, 1482, 1518],\n",
       "         [1626, 1662, 1698, 1734, 1770, 1806],\n",
       "         [1914, 1950, 1986, 2022, 2058, 2094]],\n",
       "\n",
       "        [[1203, 1320, 1437, 1554, 1671, 1788],\n",
       "         [2139, 2256, 2373, 2490, 2607, 2724],\n",
       "         [3075, 3192, 3309, 3426, 3543, 3660],\n",
       "         [4011, 4128, 4245, 4362, 4479, 4596],\n",
       "         [4947, 5064, 5181, 5298, 5415, 5532],\n",
       "         [5883, 6000, 6117, 6234, 6351, 6468]]]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# torch version\n",
    "import torch\n",
    "\n",
    "data_torch = torch.Tensor(data)\n",
    "weight_torch = torch.Tensor(weight)\n",
    "conv_torch = torch.nn.functional.conv2d(data_torch, weight_torch)\n",
    "conv_torch = conv_torch.numpy().astype(np.int64)\n",
    "conv_torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm.script.ir_module\n",
    "class MyConv:\n",
    "  @T.prim_func\n",
    "  def conv(A: T.Buffer[(1, 1, 8, 8), \"int64\"], # 1,1,8,8\n",
    "          B: T.Buffer[(2, 1, 3, 3), \"int64\"], # 2,1,3,3\n",
    "          C: T.Buffer[(1, 2, 6, 6), \"int64\"]): # 1,2,6,6\n",
    "    T.func_attr({\"global_symbol\": \"conv\", \"tir.noalias\": True})\n",
    "    for n, c, h, w, i, k1, k2 in T.grid(N, CO, OUT_H, OUT_W, CI, K, K):\n",
    "      with T.block(\"C\"):\n",
    "        vn = T.axis.spatial(1, n)\n",
    "        vc = T.axis.spatial(2, c)\n",
    "        vh = T.axis.spatial(6, h)\n",
    "        vw = T.axis.spatial(6, w)\n",
    "        vi = T.axis.spatial(1, i)\n",
    "        vk1 = T.axis.reduce(3, k1)\n",
    "        vk2 = T.axis.reduce(3, k2)\n",
    "        with T.init():\n",
    "          C[vn, vc, vh, vw] = T.int64(0)\n",
    "        C[vn, vc, vh, vw] = C[vn, vc, vh, vw] + A[vn, vi, vh + vk1, vw + vk2] * B[vc, vi, vk1, vk2]\n",
    "\n",
    "\n",
    "rt_lib = tvm.build(MyConv, target=\"llvm\")\n",
    "data_tvm = tvm.nd.array(data)\n",
    "weight_tvm = tvm.nd.array(weight)\n",
    "conv_tvm = tvm.nd.array(np.empty((N, CO, OUT_H, OUT_W), dtype=np.int64))\n",
    "rt_lib[\"conv\"](data_tvm, weight_tvm, conv_tvm)\n",
    "np.testing.assert_allclose(conv_tvm.numpy(), conv_torch, rtol=1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 如何变换 TensorIR"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.4 ('mlc': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "aaffa6035209f65dc356111783931130a4c4995d4af64fce84e8309d82e28b87"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}