{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tensor Core 编程教程\n", "\n", "在课程中,我们使用 `tmm` 内在函数来演示 `Tensorize` 的进展。 在本教程中,我们将把 TensorIR 运用到 NVIDIA GPU 上的 Tensor Cores。 请注意,Tensor Cores仅在具有 Volta 或更新架构的 NVIDIA GPU 上受支持(例如,`V100`、`T4`、`RTX-20X0`、`A100`、`RTX-30X0`)。 不幸的是,Colab 提供的大多数 GPU 都太旧,无法支持 Tensor Core。 您可能需要为本教程准备自己的设备。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 什么是 Tensor Core?\n", "\n", "张量核心是具有 Volta 和更新架构的 GPU 上的可编程矩阵乘法和累加单元。 每个 Tensor Core 都提供了一个矩阵处理数组,它执行 `D = A * B + C` 运算,如果我们使用 `nvcuda::wmma`,则`A`、`B`、`C` 和 `D` 是 `16x16` 的矩阵。 其中,矩阵乘法输入 `A` 和 `B` 是 `fp16` 矩阵,而累加矩阵 `C` 和 `D` 可以是 `fp16` 或 `fp32` 矩阵。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![WMMA16x16x16.png](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAABAAAAAD+CAYAAACkynf5AAAABmJLR0QA/wD/AP+gvaeTAAAACXBIWXMAAA7DAAAOwwHHb6hkAAAAB3RJTUUH4QoQDS8bm8RPawAAIABJREFUeNrt3Xt0VPW9///XFILaxCapSUaUQkISMEBuHK/JAUxUQkD8pR4P1f6E8PV4giyPJRyh2C6US1ltKUcIq8slybdVQm2tVG2OSLhVEHuIBVdNCCCGJIagXCahTWITTwW75vvHTIaZyeRG5rYnz8das9Zc9kzy2jP7PZ/9mc/+bFkB9Kuzs9P6q1/9yrply5Y+l9u1a5c1Pj7eKilkLiNHjiQPefx2iY+Pt7788sset6/du3dbf/nLX1rb2tqoE3wOyUOdMGSdGM5tJOofecjju/r31ltvWbds2WK1WCz9bocmq9VqFYBenT17Vs8++6za2toUHh6uV155pddlR48erQsXLrDSgCG49tpr9b//+7897n/22Wd1/PhxhYeH6+mnn1ZmZiZ1AqBOGKZODPc2EvUP8F39e+aZZ1RXVzeg2vc1ViHQu66uLq1du1ZtbW2O201NTb0uz5caMHR///vfe9x3+fJlnTp1yrEd/vSnP9Xx48epEwB1whB1gjYS9Q/wZf3r3va6urr0/PPP6+zZs70+fySrEOjdK6+8opaWlh4bWEJCQv9P3rra+Ctg4WrykCcwedzU1dXp0qVLjtuXLl3Sxo0btXHjRkVFRRm2ThRnbDX821ZSs5A85AlIHiPVCdpIPX1/j/HXwc/yyEOewORx3+6ca193J8DGjRs9Ls8IAEC9b0y7d+92uS83N1dTpkxh5QB+NmbMGM2fP9/lvra2Nr322mvUCQBBXSdoIwHwpejoaP3Lv/xLj2308OHDdAAAg3HgwIEeDYsnnnhCcXFxrBzAz6KiovTggw/2aNzv3r27xy9Q1AmAOhFMdYI2EgBf175HH31Ud9xxh8v9v//97+kAAAbq8uXL+p//+R+X+x599FGFhYWxcoAAmjt3rqKjo13ue/vtt6kTAIKyTtBGAuAvjz32mMvturo6j/Ny0AEAeHD48GHHpDaSbWiNe68aAP8LCwvTt7/9bZf79u/fr8uXL1MnAARdnaCNBMBf4uLiemyL7qN16AAAerFnj+tMITNnzmSlAEFi1qxZCg8Pd9zu6uoKyEzf1AmAOkEbidoHBJOcnByX2546P+kAANx0dXU5TiPk3JAAEBzCwsI0bdo0l/vch6NSJwDqRKDrBG0kAP52xx13uBwC1dXVpbq6OjoAgL4cP37c5VQaCQkJnD4ICDLZ2dk9tlvqBIBgqhO0kQAEQmZmpsvtI0eO0AEA9OXEiRMut2+//XZWChBkJk6cqFGjRjlut7S0+HWWb+oEQJ2gjUTtA4KR+3bp3vlJBwDgxv2cmZzTFgg+YWFhmjBhQp9fcNQJgDoRyDpBGwlAILhvl01NTerq6qIDAPDE/deBUaNGaeLEiawYIAi5D3Fz/2WKOgEgUHWCNhKAQAkPD1dCQoLLfc6dn3QAAE7cJ7aZMmUK57UFDNKwd/9lijoBIFB1gjYSgEByHwVQX19PBwDgyenTp11uu/eeAQgeCQkJPU7zdfbsWeoEgIDXCdpIAALJ/fCnpqYmOgAATz777DOX2/Hx8awUIIi5Dz91/oKjTgAIVJ2gjQQgkNw76OgAAHrh/qvAzTffzEoBDPQF5/4LFXUCQCDqBG0kAIF08803u5wFpa2tzTERIB0AgN3ly5dderdHjRqlMWPGsGKAIOb+C5T7L1TUCQD+rhO0kQAEA/fOz+55POgAAHppENx8881MbgMY7MvN10N7qRMAdYI2ErUPMGLt696OR7JqABv3oW2B6Nkuylml0sg/Krtiv6pCYJ2Shzy+1j3E7dKlS5Jsp6nq6upymfQrVOrE1JJCzZ/h4YHmWm0oqNY5g7135CFPqNYJ2kg+Yq7U+vR8l7tOHDVpm8WgK5U85PEx99FP587ZKjcjAAA753PbStJNN93k5/8gVQVjJUWmaF50KKxR8pDHP9x7uH05w3fg64QH49K0vDpHU0OlGJOHPAavE7SRvC1LeVnWHjtjkjQ53ar16UUGW5vkIY9/xMXFedyO6QAA7FpbW/vcaHxebjKny1ZqYjQn3mz49Uke8gSqYe/L4b2BrhOSpIPvaWlmue3y3Bn7nWM1c4lBe3LIQ54QqxO0kby9M3ZIuRH2HZjGbK3Ya7Jdju6y3Rk+SaMNtXNJHvIEpgOgra2NDgDAmXvvtn8b9mbNGxfjuJU0brKyDL02yUMe/4mNje2zoRo6dcKDHad1LJQKMXnIEwJ1gjaSN7+uVjp2xk4cNen5RqcD1iyzVdLYoBbLdp03zNcvecgTuA4ARgAAwdSwj56sOZGSOk5qV4ekyGlamWDglUke8gTwC667h3tYdADMjVeqJKlDtXvaZHjkIU8I1AnaSN6SpbxE+zDszs3a5+HY6/OWDXrHYpRZbMhDHv8KCwtzmeukq6tLXV1dTAIIBEPDvihjmpIkNTQf1DrFKj8tRvnxqVKTMX9rIQ95Atmwd9+WQ64DYMZ0baqefuW2QSeZIw95QrVO0Ebylikydw/F7u1X184y1RhmLZKHPIGpfc6HPLW3tzMCAOjeGLpnB/Z/o94+GZsuaudpi6pOn1SDJI1NUZEh1yZ5yBOaDfvA1ok+jEvT8upCPTY3RAoyechDBwBtJEmKmKTuAzdaO0Pg3DvkIU+Q1D46AAAF+Fe9hBTbZGwdJ7W9TVLbCe3skKQULc804ORs5CGPn0VFRWnUqFE+b9gHzfB/50nZMt9zHJOdutagM82ThzwhVCdoIwEIJtHRrhO2trW10QEASHLp2Zbkx3MDm1WSniJJamg+YT8Xu0Xbmy9KMuLkbOQhT+Aa9301WI1dJ/rSpJccM7NH6sZbjF6NyUMeY9cJ2ki+ERuRFVLrlDzk8ZfIyEiX2+3t7cwBAHhqBPjty617MjZJSWlPyJrmvtWmaF70flV1z1WUME/W6SmOhxtqtyi52uL5tQezrBHyOOdKv6jsiv2qMvL7Ewg+zpNjHqe9981X2CvrHPdZC1f1WM5Uvsbr0eLi4ly245aWFq//ShWwOtGf+nZZNFZmRerGZEkfS5qbo01rxzoWsWx9Sz/d3MukZ4NZ1gh5nHP9W3vgj1f3Vp5gMcQ8ibeatfjF+7Tstlcc922qLuyx3NLMckPWCdpIXtS5Xcc7lyg3QoqLmCIN5FvfXOly/vaWxmzXmdyvdlkj5HF+TuIplVQV+3a2en/l8Rcf5xkflaPHb92rH/4hzHHf+pnWHsut2GvyaiznkU+SrUOPEQCApMuXLwfkyy0rPkVJfS4RoyUZqfbrqaqcnqKG2i0yla9Rdu1FJaU9pBKPp2IezLJGyGNTlLPKZYfUuO9PYPgyz7qMHO2buUAjR4zwuMPvfPEF90Z8e3t7yNSJfiVHyXbwRocu1EtSgh5bO1aWrW9paWa5NmztkHnhDM32+GvtYJY1Qh6bqSWFLjukxn1/gtAQ8uQ/eacWb5mpESNHeNzhd74YtU7QRvKmKtVaGuw7WsuVF+GxZaC8xCzH9QXp+Y5zt5c0NigusbzX5w18WSPksclIt7rskBr3/QkE3+WZmbhOj9+6TyO+NtLjDr/zxdvcRz51dHTQAQBIPU8H9PWvf90PfzVVK9Ns52Lv3slyvmTX2oZlOyZnS0hRvn3iNkn2ydtiNCHKU5t+EMsaIY+krMwnVXB6jRad8denwrd5/M+3eRZPvFWPH3orYOncG6TuQ1aNWyf6Y9v5kiQ1N6vmY9lP1XblNG3n9jTL0v1rrbvBLGuEPJJuWlKgjHfK9auDMv77E3SGlif7X8fptbWHQrpO0EbyrvONG3RCkpSk3CyrFpjdduJnliq3e6fLXKDJatBx+2nazlt2qkVJMnvqrxjMskbII2l0Yr3SLpj0qh8HHvoyTyD4Ks+d31qsN048HhQZL126xCEAgCR98cUXfTYSfNOOsk/G5rST5azq9Ek1pE1TklJUkCAdj4qV1KqPur+H21pVLyk5yizJ9flZg1jWCHkkqar6Bc2WVBTvr3aub/P4v93u2zw3vLZBkvTytIIej9UVPKnR112vqtZPtej9t9Xc1eH1eO4NUl8c2xuQOuGJ+2nZJEkd2veMbaj7TeMjJXXowsf2hz5uV4ukuPHRklwb8oNZ1gh5JOnc5gq9JGnqPTL8+xMUvJhn5d2vSZK+u2Zajz/zg4oCfSPmOp0+2qrt695X2/kuQ9YJ2kjeVqZtVZP0dNYSxUmanG7V+l6WHB2RLKlelk77HZ0fqVX247ndzt0+mGWNkMe285qsbZIybpTh35/A8U2ete/eIEmal/pyj9dZllWn668ZreaOKr15YpHav2z2aiJPI58YAQBI6urq8vOX25XJ2ByzsbtzzM4u5afnKrh/HCIPeTw70vqZHj74usb8bqNOdlxU6V33+ySh+zFuoVEnBqi5VhsyK1T5cYgUZPIMuzzNx1u1bcVBrc77nSxNHZq38i7D1gnaSD7QWazn93r+Zds2/DpZezoNtELJQx67M+1H9Ovah/Xj98aopfOkHpxc6vN4jAAA7NyPA3Q/ZYb3WVRcsUbFg1gmKzMliNcgecjj2R2Vv3RcX13zrj59aKlPErof49ba2hoCdcLVh8Xl+jCE6i55yNNt+/xKx0SNu7fUaNXuhwxbJ2gj+U7NUZNqQmi9koc8Lxy5w3F9X+Nq/XD6p17/G+7bKyMAADv34wDDwsKC7n+sam+VFKtJ3dtxdKySJdW3W4a0rBHyGAF5+meS1PmVb4659ccve0aoE5J07pMOuZyi7ZYoxUlq+aRtSMsaIY8RkMfDa7jd/vKLrwxbJ2gjBc75znpJyTJ3T8IWMUmxklo7q4a0rBHyGAF5+msjmXTpH94fKuG+vXIWAMBI2i+qQTGaE2+bkcQ2o/tJVTRJUqoqC1fJWpBrO497n8saME+ovT8hl8fc68usSp+htRl3KzLsGi2ddJf2nmv0yb/L+b2d1LfLokil5dl6b27KGyezzqhmhyQl6LHqQm2qyNRN/S5rwDyh9v6EXB7PvxznLUrXrMUZujYiTHfPn6S6P52jTmDwuk6pRUmaYrZ9M402z1GcdqnWItkmcLNqfVaJRve7rAHzhNr7E3J5PLf+7h2/SvclrtU1IyI1LX6pTv1lr1/+dToAAKNo26/C2ou287cXrtKhtBjtem+7yiRJx1RxRlJkjKb0u6wB84Ta+6NUWQtXOZ7a2/VQeH8ufPE3LRifrk8e/J5uiYzRsg/2si372sfV2ra1Q+aFD2hTdaGWL4zUsecO2IdzN6nmoKRxUbqx32UNmCfU3h8laFN1oeOpvV03+vvz+cUvdNv947Xy7QcVFx+p/974AdsxBq+zWL9pbFBc4iGtn2lVcWKSThydbR/GXWbbMYuYYOuy7nNZA+YJtfdHRVo/0+p4am/Xjf7+/O3LC/qn0Qu0Ytonig2/RTvrlvnlX2cOAEA9fwVwnzEzWFRVvyBTdR8LnDnp2Mnvd1mD5elWdmBNcHVkXFWeYzKVH3PcbSpf4/G6cfJYev3/S+s/VGm973cn3bdZX/yyZ5Q6Idlmwl+6uY8FDp527OT3u6zB8nQL5mPyB56nSUszrwzdWppZ7vG6cfK0efzf33+jXu+/UR8SdYI2UmCdb0zWir4GmlkqHDv5/S5rsDzdgvlY/IHnKdOKvVdaeyv2mjxeN06eKo//++GzpTp81rcT/3mqe4wAAEJBwjyV6k2ZDhwjD3nIg97NzdF8vaelxU3kIQ95MHyYK/WIFmnF0TLykGfY52EEABAKmrbL1EQe8pAH/dhxQEt3kIc85MEwY5mtFRbykIc8EnMAAAAAAAAwLNABAAAwPI7tBUCdAAA6AIAB6erqcrkdHh7OSgFAnQBA7aP2AXQAAHy5AaBOUCcAUPsA0AEAAAAAAADoAAAAAAAAAHQAAAAAAAAAOgAAAAAAAAAdAAAAAAAAgA4AAAAAAADoAAAAAAAAACFqJKsA8JGFq8lDHvKgTyU1C8lDHvJgWPpZHnnIQ55AYAQA4EUmk4mVAIA6AQDUP4AOACDUWa1WVgIA6gQAUP+AoMQhAICvbF1t/AzOw8rJQx5/5hkmijO2Gj6D87By8pDHn3lgbN/fY/wMzsPKyUMef+YZCkYAAAAAAAAwDNABAAAAAAAAHQAAAAAAAIAOAAAAAAAAQAcAAAAAAACgAwAAAAAAANABAAAAAAAA6AAAAAAAAAB0AAAAAAAAADoAAAAAAACgAwAAAAAAANABAAAAAAAA6AAAAAAAAAB0AAAAAAAAADoAAAAAAAAAHQAAAAAAAOAqjGQVAMGjKGeVSiP/qOyK/aoiD3nIAydTSwo1f4aHB5prtaGgWufIQx7yIJSZK7U+Pd/lrhNHTdpmIQ95yDMYjAAAgkaqCsZKikzRvGjykIc8GKBxaVpenaOp5CEPeRCSspSXZe2xMyZJk9OtWp9eRB7ykIcOAMCAX2+Z02UrNTGaE28mD3nIA88OvqelmeW2y3Nn7HeO1cwl0eQhD3kQgjv/h5QbYbvV0pitFXtNtsvRXbY7wydpNHnIQx46AABjMWveuBjHraRxk5VFHvKQB/3ZcVrHyEMe8iBkm0crHTtjJ46a9Hyj0wFrltkqaWxQi2W7zpOHPOShAwAwlOjJmhMpqeOkdnVIipymlQnkIQ950I+58UqVJHWodk8bechDHoSQLOUl2odhd27WPg/HXp+3bNA7lirykIc8g8AkgEAQKMqYpiRJDc0HtU6xyk+LUX58qtR0jDzkIQ9czZiuTdXTr9w2+qRs5CEP4NEUmbuHYvf2q2tnmWrIQx7y0AEwWM8//7zH+xsaGrRly5Y+n/tf//VfkqTMzEyNGTNGZrNZkZGRrFQMgn0yNl3UztMWVemkGtKmKWlsiop0TGXkIQ950JdxaVpenaZjz5XrpR3kIQ95ECIiJinWfrW1s4o85CEPHQDes2zZMq8+d+XKlZozZ47uvPNOVi76l5Bim4yt46S2t0nSCe3smKYlkSlanmlWWbWFPOQhD644+J6WFjd1v6F6rHq6UiWlrs3R1B0H9CF5yEOeoHfkyBG9+uqrV/38vLw8zZo1ixUJYNCYA8AH1q1bp7vuuksPPPCA9u/fzwpBH8wqSU+RJDU0n7Cfi92i7c0XJRlxcjbykAf+1aSXHDOzR+rGW8hDHvIYwV//+lfV1NRc9eXChQvDan3FRmSRhzzkoQMg+O3YsUP33HOPHnjgAZ06dYoVgp66J2OTlJT2hKyFq2QtXKVDafYZ2ns7R3vCPFkLcvveWUuY53g9a+Eq1WeajZ3napY1Qp5Q+bxJyjGP0+VHV7rc5/wZ7L7Ay+rbZeneIUt2e2xujjZVZOqmvp4/N0ebqgsdl2cCfXq3oea5mmWNkCdEPm+Jt5r1Xx886nKf8+ev+wKoc7uOd9quxkVMGdxzzZVan1XS9+nazJVaP9PquDydmGXsPFezrBHyhMjnbXxUjn5872WX+5w/f90XOgBCqCNg4sSJ+u1vf8vKgIus+BQl9blEjJZkpLrcU5SzStbpKf28cqoqp6eooXaLTOVrlF17UUlpD6kk2qh5Br+sEfKEzudNWpeRo30zF2jkiBE9HjOVr3G5wMuSo2Tr3uvQhford08tKdSmtWP7eXKCHls7Vpatb2lpZrk2bO2QeeEMzb7FqHkGv6wR8oTK5y3/yQwt3jJTI0b2rBNLM8tdLoBUpVpLg30Ha7nyIjy2DJTntuOekW7V+vT8/loUWpCe7zjPe0ljg+ISy3v5G0bIM/hlg/39CaXP28zEdXr81n0a8bWeR9+v2GtyudABEGIeeeSRXiccxHCUqpX2X167d9SdL9m1tmHZGpviKINZmU+q4PQaLTrTX5s+Rfn2Sd4kqer0STUoRhOiDJpnkMsaIU9Ifd4kLZ54qx4/9Babtd/ZduAlSc3NqvnYdvWmJQXKeKdcvzrYz9PnxivV6ZRu5/Y0y+Lpl12j5BnkskbIE0qft+x/najX1h5is8WAnW/coBOSpCTlZlm1wOy2Ez+zVLlOO+6jE+uVdsGkV/ubzsZcoMlq0HH7Kd3OW3aqRUkyhxs0zyCXDf48ZSH1ebvzW4v1xonHg2a7YhJAP1u2bJna29v1ox/9iJUx7Nvt9snYnHbUnVWdts/OrhQVJEhlTVJV9QuaLakovp+d5ahYSa36qPs0zW2tqpeUHGWWZDFcHmlwyxohTyh93iTphtc2SJJenlbQ47G6gic1+rrrVdX6qRa9/7aauzrY/ofC/bRskqQO7XvmyqnZzm2u0EuSpt7Tz87y+EhJHbpg35HTx+1qkRQ3PlpSm+HyDHZZI+QJpc/byrtfkyR9d820Ho/9oKJA34i5TqePtmr7uvfVdr6LbR2SyrStapKezlqiOEmT061a3+cOXLK2Scq4sZ+d5YhkSfWy2Id8q/Mjtcp+7LdPz/PumzyDXdYIeULp87b23RskSfNSX+65b5hVp+uvGa3mjiq9eWKR2r9s9nlKRgAMQHNzs6xWa4+LxWJRc3Oz3nnnHcfpAAdi3bp1HA4w7F2ZjO3KbOxu2k5op30/KT89N8gnZyMPeTw70vqZHj74usb8bqNOdlxU6V33s/l7/UuqVhsyK1T5MXnIY8w8zcdbtW3FQa3O+50sTR2at/Iutmtc0Vms5/d6/pXVNoQ/WXs6yUMe4+U5035Ev659WD9+b4xaOk/qwcmlfonICIABGDvW8zFtcXFxjsdzc3P1+OOP689//rPuuaf/Lv1HHnlEU6dO1YQJE1jBw5JFxRVrVDzkZchDnuDOc0flLx3XV9e8q08fWsrmf5U+LC4PqVOukYc83UrmVzqu795So1W7H2KDRw81R02qIQ95QijPC0fucFzf17haP5z+qV+yMQLAiyIjI5Wbm6v29natXLmy3+WXLVvGSoNPVLW3SorVpO5J/6JjlSypvp1zvCNwTJI6v7rEiggi5z7pkMvp3G6JUpyklk/aWDkImC+/+IqVAJ8631kvKVnm7oneIiYpVlJrZxUrBwFqI5l06R/+GSpBB4CPOgJ+9KMf6dVXX+1zuR07dmjHjh2sMHhBqioLV105VVv7RTUoRnPibbOX2GZ/P6mKJoPmCbX3J+Ty9H6KyVXpM7Q2425Fhl2jpZPu0t5zjWyuAZWgx6oLr5yqrb5dFkUqLc/WW3hT3jiZdUY1OwyaJ9Ten2GSJ29RumYtztC1EWG6e/4k1f3pHJsqvKxIC2Zar5yqreuUWpSkKWbbt9ho8xzFaZdqLQbNE2rvzzDJc+/4Vbovca2uGRGpafFLdeove+kAMLqHH36435EADzzwACsKXnBMFWckRcZoiiS17Vdh7UXHud4PpcVo13vbAzSnqhfyhNr7I8lauMrxaG/XQ+H9ufDF37RgfLo+efB7uiUyRss+2MvmGlBNqjkoaVyUbpSkj6u1bWuHzAsf0KbqQi1fGKljzx0w0FB2tzyh9v4oQZuqCx2P9nbd6O/P5xe/0G33j9fKtx9UXHyk/nvjB2yq8LIy2859xARbl3VnsX7T2KC4xENaP9Oq4sQknTg620BD2N3yhNr7Yz8f0fqZVscSvV038vvzty8v6J9GL9CKaZ8oNvwW7azzz+hw5gDwsaeeekrr1q3rc5k//elPuvPOO1lZGHhZObDG8878mZOO+6uqX5CpOnTy9LuswfKYytc47u7tunHyWHr9/0vrP1Rp/YdstAHS63HcB0877j+3uUJLN4dOnn6XNVSeJi3NvDJ0a2lmucfrRnt/3P/399+o1/tv1LPBwmt6PX7bUuG4/3xjslY0hk6efpc1VB5b62LFXpPjod6uG+n9cf+/D58t1eGzpX7/fxkB4GNxcXH9HgpQXl7OisLQJMxTqd6U6cAx8pCHPOjd3BzN13taWtxEHvKQB8OHuVKPaJFWHC0jD3mGfR5GAPhBbm5un49v2bJFP/3pTxUZGcnKwtVp2i5TE3nIQx70Y8cBLd1BHvKQB8OMZbZWWMhDHvJIjADwi7i4uH7nAjh9+jQrCgAAAABAB4DR3X777X0+fubMGVYSAAAAAMBnOATAT8aOHdvn46dOnfLr/2My+XfyDKvVyocAAAAAAOgACH3R0dF9Pt7Q0MBKAgAACHIvvPDCkNttf/3rX4f0/D179ujo0aNDeo1vfvObevbZZ3lDAToA4Av9jQDYsmWLXnzxRVYUAABAEGtoaFBNTWBPtHbhwgVduHBhSK9x44038mYCwxBzAAAAAAAAQAcAAAAAAACgAwAAAAAAABgCcwAMU8zKDwAAAADDCyMAgsTcuXNZCQAAAAAAn2EEgJ+cOXOmz8dnzJjBSgo1C1eThzzkQZ9KahaShzzkwbD0szzykIc8gcAIAD/5+9//zkoAAAAAANABEOo+++yzPh+fMGECKwkAAAAA4DMcAuAn1dXVfT4eGxvLSgo1W1cbP4PzsHLykMefeYaJ4oyths/gPKycPOTxZx4Y2/f3GD+D87By8pDHn3mGghEAfrJs2bI+Hx8/fjwrCQAAAADgM4wA8IOjR4/2+fjcuXMVFxfn1//JZDL59e9x2kEAAAAACCxGAPjB66+/3m8HAAAAAAAAdAAY2JkzZ7Ru3bo+l7n99ttZUQAAAAAAOgCM7Cc/+Umfj8+dO1fp6emsKAAAAACATzEHgA/99re/1ZYtW/pc5rvf/S4rCgAAwCA2bdo05NfYvXu31q9ff9XPLyws1MKFC3kzAAwaIwB8uPP/yCOP9Ltcfn4+KwsAAAAA4HOMAAjgzv9bb72lyMjIgPyPzMoPAAAAAMMLIwC86OjRo1q8ePGAdv6feOIJZv8HAAAAAPgNIwAGoKWlRXFxcT3u7+jokMViUV1dnSorK/s93t/ZmjVrWLEAAAAAADoAgonZbPbq69XU1HjsUAAAAAAAwFc4BMDPampqOO0fAAAAAIAOgFA1d+5cNTc3s/MPAAAAAAgIDgHwg1dffVUPP/wwKwIAAAAAQAdAKCorK1NeXp7Gjh3LygAAAAAA0AEQSlauXKlilNhdAAAgAElEQVQ5c+YoJSVFkZGRrBAAAAAAAB0ARvbEE08oKSlJ3/jGN5SYmKgxY8ZowoQJrBgMSVHOKpVG/lHZFftVRR7ykAdOppYUav4MDw8012pDQbXOkYc85EEoM1dqfXq+y10njpq0zUIe8pBnMJgEUJLVah305cUXX9TTTz+tf//3f1dubi47//CCVBWMlRSZonnR5CEPeTBA49K0vDpHU8lDHvIgJGUpL8vaY2dMkianW7U+vYg85CEPHQCAAb/eMqfLVmpiNCfeTB7ykAeeHXxPSzPLbZfnztjvHKuZS6LJQx7yIAR3/g8pN8J2q6UxWyv2mmyXo7tsd4ZP0mjykIc8dAAAxmLWvHExjltJ4yYrizzkIQ/6s+O0jpGHPORByDaPVjp2xk4cNen5RqcD1iyzVdLYoBbLdp0nD3nIQwcAYCjRkzUnUlLHSe3qkBQ5TSsTyEMe8qAfc+OVKknqUO2eNvKQhzwIIVnKS7QPw+7crH0ejr0+b9mgdyxV5CEPeQaBSQCBIFCUMU1JkhqaD2qdYpWfFqP8+FSp6Rh5yEMeuJoxXZuqp1+5bfRJ2chDHsCjKTJ3D8Xu7VfXzjLVkIc85KEDADAW+2Rsuqidpy2q0kk1pE1T0tgUFemYyshDHvKgL+PStLw6TceeK9dLO8hDHvIYwaxZszRr1ixWRF8iJinWfrW1s4o85CGPl3AIABBoCSm2ydg6Tmp7m6S2E9rZIUkpWp5pJg95yANXzpOyZb7nOCY7da1BZ2YnD3kAAHQAAMODWSXpKZKkhuYT9nOxW7S9+aIkI07ORh7ywL+a9JJjZvZI3XgLechDHoSe2Igs8pCHPF7CIQBAIHVPxiYpKe0JWdPcHo9M0bzo/aq6mrmWEubJOj3FcbOhdouSqy3GzeOcK/2isiv2q8rI70+ofd4k5ZjHae998xX2yjrHfdbCVT2WM5WvYdv3pvp2WTRWZkXqxmRJHw/y+XNztGntWMdNy9a39NPNbcbN45zr39oDf7y6t/KEyOct8VazFr94n5bd9orjvk3VhT2WW5pZzrY93HVu1/HOJcqNkOIipkje/tY3V7qc672lMdt11nej5XHOlXhKJVXFvp2t3l95QuTzNj4qR4/fulc//EOY4771M609llux1+TzqIwAAAIoKz5FSX0uEaMlGalX8cqpqpyeoobaLTKVr1F27UUlpT2kkmij5rEpylnl0qlh3Pcn1D5v0rqMHO2buUAjR4zwuMPvfIGXJUfJdvBGhy7UD/bJCXps7VhZtr6lpZnl2rC1Q+aFMzT7lmDKM/hTVEwtKXTp1Aiu9ychxD5vA8+T/2SGFm+ZqREjR3jc4Xe+AFKVai0N9p3a5cqL8NgyUF5iluP6IFoUWpCe7zjPe0ljg+ISy3v5G0bIY5ORbnXp1Aiu92f4ft5mJq7T47fu04ivjfS4w+988Qc6AICASdXKNNu52Lt31J0v2bW2YdkamzL4r4CEFOXbJ3mTpKrTJ9WgGE2IMmgeSVmZT6rg9BotOhMC70+ofd4kLZ54qx4/9Babtd/ZduAlSc3Nqhn0r//xSnU6pdu5Pc2ydP+yGzR5mgb1CjctKVDGO+X61cFgfX+aQuzzNvA82f86Ua+tPcRmiwE737hBJyRJScrNsmqB2W0nfmapch077oOYxtZcoMlq0HH7Kd3OW3aqRUkyhxs0j6TRifVKu2DSq5ZgfX+G7+ftzm8t1hsnHg+anBwCAASsHWWfjM1pR91Z1Wn77OxKUUGCVDaINmNWVKykVn3UPYq3rVX1kpKjzJIshssjSVXVL2i2pKJ4478/ofZ5k6QbXtsgSXp5WkGPx+oKntTo665XVeunWvT+22ru6mD7Hwr307JJkjq075nBD3W/aXykpA5d6O44+LhdLZLixkdLajNcHkk6t7lCL0maeo/x359Q+7ytvPs1SdJ310zr8dgPKgr0jZjrdPpoq7ave19t57vY1iGpTNuqJunprCWKkzQ53ar1XnjV0RHJkupl6bTf0fmRWmU/9tun53n3TR7bzmuytknKuNH470+ofd7WvnuDJGle6ss9HluWVafrrxmt5o4qvXlikdq/bPZ5SkYAAAFxZTI2x2zs7hyzs0v56blBPjkbecjj2ZHWz/Twwdc15ncbdbLjokrvup/N39uaa7Uhs0KVH5OHPMbM03y8VdtWHNTqvN/J0tSheSvvYrvGFZ3Fen6v51+2bUP4k7WnkzzkMV6eM+1H9Ovah/Xj98aopfOkHpxc6peIjAAAAsKi4oo1Kh7yMuQhT3DnuaPyl47rq2ve1acPLWXzv0ofFpfrQ/KQJwTzlMyvdFzfvaVGq3Y/xAaPHmqOmlRDHvKEUJ4XjtzhuL6vcbV+OP1Tv2RjBAAQgqraWyXFalL3pH/RsUqWVN9uYeUgYEySOr+6xIoIIuc+6ZDL6dxuiVKcpJZP2lg5CJgvv/iKlQCfOt9ZLylZ5u5j0yMmKVZSa2cVKwcBaiOZdOkf/hkqQQcAEBJSVVm4StYC+9Dt9otqUIzmxNtmL7HN/n5SFU0GzRNq788wyrMqfYbWZtytyLBrtHTSXdp7rpHNNaAS9Fh1oTZVZOomyX5Kt0il5dl6C2/KGyezzqhmh0HzhNr7EwJc83g+FU3eonTNWpyhayPCdPf8Sar70zk2VXhZkRbMtGp9VolGS1LXKbUoSVPMtm+x0eY5itMu1VoMmifU3p9hkufe8at0X+JaXTMiUtPil+rUX/bSAQDgCufzqfe8fkwVZyRFxmiKJLXtV2HtRdu53gtX6VBajHa9t32Q88kGUZ5Qe3/6XT503p8LX/xNC8an65MHv6dbImO07IO9bMw+5Hw+dc/Xm1RzUNK4KN0oSR9Xa9vWDpkXPqBN1YVavjBSx547EDRD2QedJ9TeHyUM4DnBnEcDen8+v/iFbrt/vFa+/aDi4iP13xs/YGPGoFw5n3qRy7nVr1wvs+3cR0ywncays1i/aWxQXOIhrZ9pVXFikk4cnR00Q9ht/3eRW4Y+8hjm/RlonqIBPMf478/fvrygfxq9QCumfaLY8Fu0s26ZX/5/k9VqtVI2MNx9+9vfdrn9+9///uo2KJPT+Tu3rvZrhqKcVSrVmzIdOOa9F124mjzkCUieq/lq8tZ27M86UZyx1e/1bmpJoebrPS0t9s6QoJKaheQhT0DyBGOdoI008Pr3/T3+zZGRbtUjWqQVR733c8jP8kQe8gQkz2Dqn/s2zAgAIBQkzPP+zhh5yDNc8gwnc3O8ujNGHvIMqzwwLnOl13fGyEMeo+bhLABAKGjaLlMTechDHvRjxwEt3UEe8pAHw4xltlZYyEMe8kiMAAAAAAAAYFigAwAAAAAAADoAAAAAAAAAHQAAAAAAAIAOAAAAAAAAQAcAAAAAAACgAwAAAAAAANABAAAAAAAA6AAAAAAAAAB0AAAAAAAAQAcAAAAAAAAILSNZBYCPLFxNHvKQB30qqVlIHvKQB8PSz/LIQx7yBAIjAAAvMplMrAQA1AkAoP4BdAAAoc5qtbISAFAnAID6BwQlDgEAfGXrauNncB5WTh7y+DPPMFGcsdXwGZyHlZOHPP7MA2P7/h7jZ3AeVk4e8vgzz1AwAgAAAAAAgGGADgAAAAAAAOgAAAAAAAAAdAAAAAAAAAA6AAAAAAAAAB0AAAAAAACADgAAAAAAAEAHAAAAAAAAoAMAGIxRo0a53L58+TIrBQB1AgC1j9oH0AEAhJqoqCiX221tbawUANQJANQ+ah9ABwAAAMEsPDyclQCAOgEAdAAAAGjYA6BOUCcA0AEAAAAAAADoAABCU1xcHCsBMLD29nbqBICA1wnaSADoAABoJADwsUuXLrnc9sfQXuoEQJ2gjQTASMLDw+kAAAbSSABgrAapPxr21AmAOkEbCYDROgBGshqA4FGUs0qlkX9UdsV+VZGHPOS56gZpKDbsp5YUav4MDw8012pDQbXOkYc85Bn2dSKkmSu1Pj3f5a4TR03aZiEPecgzGIwAACRFR0cHwX+RqoKxkiJTNC86FNYqecjjPy0tLT5v2AdHnfBgXJqWV+doaqgUZPKQx8B1gjaSL2QpL8vaY2dMkianW7U+vchga5Q85Alc3aMDALALCwvrd2PxebnJnC5bqYnRnHiz4dcpecgTSL5o2AdDnZAkHXxPSzPLbZfnztjvHKuZSwzak0Me8oRQnaCN5IudsUPKjbD/7cZsrdhrsl2O7rK/kZM02lA7l+QhT+BER0fTAQB0bwzOLl++7Of/wKx542Ict5LGTVaWodcoecjjX+4N0q9//eshWCc82HFax0KpGJOHPAavE7SRvP11tdKxM3biqEnPNzodsGaZrZLGBrVYtuu8Yb5+yUOewNa9sLAwOgAASRo50nU6jLa2Nj9/u07WnEhJHSe1q0NS5DStTDBya4E85AmuBmtI1AlP5sYrVZLUodo9bcZ/48hDHoPXCdpI3pSlvET7MOzOzdrn4djr85YNesdilFlsyEMe/3PvsIuOjmYSQMBTI+Crr77y698vypimJEkNzQe1TrHKT4tRfnyq1GTM31rIQx5/a21tdbntPmQ1FOqEw4zp2lQ9/cptA0/KRh7yhFqdoI3kTVNk7h6K3duvrp1lqjHMmiQPefzPvcOOEQBAL40A//6yZ5+MTRe187RFVadPqkGSxqaoyJBrkzzk8b+urq4+G6zGrxN9GJem5dWFemxuiBRk8pDHwHWCNpIXRUxSrP1qa2cInHuHPOQJAE8jAOgAABTg49sSUmyTsXWc1PY2SW0ntLNDklK0PNOAk7ORhzwB4H6MW1xcXGjVCWfOk7Jlvuc4Jjt1rUFnmicPeUKoTtBGAhBM3DvsRo4cSQcAIAXylz2zStJTJEkNzSfs52K3aHvzRUlGnJyNPOQJjoZ9VFRUCNWJvjTpJcfM7JG68RajV2PykMfYdYI2km/ERmSF1DolD3n8xf2QHeYAAJw2hr4aCb77w/bJ2CQlpT0ha5rb45Epmhe9X1Vu37WJEdF6d1ahvvV6ieM+a+GqHi9vKl/j+e8mzJN1eorjZkPtFiVXW4ybxzlX+kVlV+xXVQDfn+D9oIdYHruuri6Xob3h4eE+Ob1XwOpEf+rbZdFYmRWpG5MlfWy7+4YxEXryF7O0dtbrjkU3VRf2ePrSzHLPrzs3R5vWjnXctGx9Sz/d3GbcPM65/q3df8ex95LHsAyax191gjaSF3Vu1/HOJcqNkOIipkgD/Gb/5nWJWnTru/rJH7/luG/9TGuP5VbsNXl+AXOlyzngWxqzXWeDN1oe51yJp1RSVeydWeyvMk/QCrU8vWyv4eHhdAAAUs9hgP76csuKT1FSn0vEaElGqooPHJNklmTRA2Mm6BdZDyj2uvDB7yBLklJVOT3FsdOflfmkDqU9pJLTL6i4zYh5bIpyVql0rKSOPwb0/fFqh0ZAP29D75zJMY/T3vvmK+yVdT5dD/4a1huoOtGv5CjZDt7o0IV6SRqpyTNu1Heey9L137xu8DvIkqQEPbZ2rGOn/6YlBVq+cIZm76lQ5cf+zhOtyTPCh5jHZmpJoebPkNTcHsD3J1pSm/c6NAL+eYvWDWMuX3WWxFvNWvzifVp22ys+XQ8M/zdi7atSraVBuRFJknm58iLKtKezx7e/8hKPa09jlaQspcTE6KHJv1DENbGD30G2v96C9HzHTv/oxHoVJ5Yrz5Ls4W8bIY9NRrpVj5gldZ4K6PsjVXmvQyMIPm/fvM5y1VnGR+Xo8Vv36od/CPPpevBU++gAAGQb3hYdHe0yrK2lpcXHDYRUrUyznYvd0y/wth3zGPvkbMdUZt+5/F7KHbr/nd/o8P3/fnV/NiFF+bqozadtr1d1+qQa0qZpQpSkNgPmsb92wek1WqRVKo0M7PvjrQ6NwH/ehpZlXUaOnkn9Z4342td8vh781bAPTJ0Y2I66JKm5WTUfS9JXmv5Iiv7v997Rf75y/9W97Nx4papD++yneju3p1mWhWl++MXXU542TX/k1qHlkXTTkgJlvFOuX6lQ8+MD+f60afKMMV7p0AiGz9tQOmfyn8zQPf8nVSNGfM3n64EOAGPWvvONG3QisVSTlaTcLKvMR03aZnHaWZ9ZqslqkCzJ2tNZpexxf9DL1ffrqTsPX90fNBdoshq0336qt/OWnWpJXCJzuKROA+aRNDqxXmkXTHpVVj0SHtj3JyXmAa90aATD5+3MtTF6aHLFVWWZmbhOdyc8oxFfG+Hz9UAHANCHuLg4/365dU/GZp+N3V33jnmSUlSQIJU12e6/d9+ven3JuoInNfq661XV+qkWvf+2mrs6eu4sR8VKatVH3VHbWlUvKTnqyi/YRsojSVXVL2i2pKL4wL8/3ujQCJbP21CyLJ54qx4/9JZenlbg82juX26+nNnb73XCE/fTskmSOrTvmStD2l98Yl+vT/9BRYG+EXOdTh9t1fZ176vtfFfPneXxkZI6dKF7Z//jdrVIihsfPdSewoDkkaRzmyv0kqSp9wT+/RlyB00Qfd7+vyFkyf7XiXpt7SF9d820kKoTtJG8qUzbqibp6awlipM0Od2q9X0s/Ys/39vrY8uy6nT9NaPV3FGlN08sUvuXzT13liOSJdXL0r2z3/mRWmU/Jtwr53/3bx7bTm2ytknKuDHw70/2uO8NuUMjWD5vQ8ly57cW640Tj2te6ss+T+Q+Z0dUVBSTAALOX259NRa868pkbI7Z2Htssd2zs0v56bn9Ts52pPUzPXzwdY353Uad7Lio0rv82bAkT3eee/f9Skf+cq7XDo3PH3lGu+/9/zUuPDKo8ww1yw2vbdDWT476ZT24n9s7NjY2ROrEADXXakPmwIbmNx9v1bYVB7U673eyNHVo3sq7gq8YD4M8Lz6xT5+e+EuvHRo/+Z9HtOiFexU9Ojzo8wwly8q7X9MHb33il/XgzzpBG8nLOov1/F6TXvXwG0VLY7ZW7O1/eP6Z9iP6de3D+vF7Y9TSeVIPTi4N3Mocxnl+8ed79dnnR3rt0FiT87kem7pbUdeMC/o8Q8my9t0b9OfzW32+Hnob+cQIACAgX24WFVesUfGQl7nijspfOq6vrnlXnz601I9rjzz9OdL6mZ54/201/q1NazJzVHrX/Zr1h18bMs9QsvhiPZw9e9bl9s033xySjeAPi8v14RBfo2R+peP67i01WrX7oYDVXPJ47tD43br3dfGzvyl/cabmrbxLpU/+wZB5hpLFF+vBn3WCNpJv1Bw1qeYqn/vCkTsc1/c1rtYPp38a8HVKHtcOjTc/ekJ//d9G3Ze4Rg9OLtVLH84yZJ6hZPH2euitA4ARAICd+68B7r8WGIlJUudXlzw+VtXeKilWk7pHP0bHKllSfbvFkHmM4o7KX6q6zaLPv7qk1TXvKiv2W8Myiy/WQ11dncvtMWPGUCcG6MsvvvJ4/7lPOuRymrdbohQnqeWTNkPmMYqS+ZU6W9emL7u+0u4tNYpPjx2WWXyxHvxZJ2gjBXsbyaRL//D8E/v5znpJyTJH2O+ImKRYSa2dVYbMYxQvHLlD5zur9eU/Pte+xtUaF5k1LLN4ez30dugTHQCAnXvvtvuvBcFuVfoMrc24W5Fh12jppLu091yj/ZFUVRaukrXAPky9/aIaFKM58bZ5nG0zw59URZNB8xi+g8bYefrOYh7Ec6/+y8391F7+HAEQnHWi98F9eYvSNWtxhq6NCNPd8yep7k/dh3gk6LHqQm2qyNRNkv1Ub5FKy7M1Fm7KGyezzqhmR3AlHXCeoDLwY8/76tBwnim/t+vBlGconTND7djxd52gjRR87h2/SvclrtU1IyI1LX6pTv1lr+OxBTOtWp9VotGS1HVKLUrSFLPtW2y0eY7itEu1FoPmCSoDa+W4d2gEb56r65z5j9uPqOjWd/3SsXP69GmP2zEdAEB3czEhweX2Z599Zqj//8IXf9OC8en65MHv6ZbIGC37oPvL4JgqzkiKjNEUSWrbr8Lai7bzwBeu0qG0GO16b7vKjJrH8B00PfM4nyqvt+tGyTLw516dpibXnitf/6pn9Drx+cUvdNv947Xy7QcVFx+p/974QfeaVM1BSeOidKMkfVytbVs7ZF74gDZVF2r5wkgde+7AkIe4ByyPQQymQ8N5lvzergdnFvXbOdPXc41QJ2gjBZ+/fXlB/zR6gVZM+0Sx4bdoZ90yx2O1FkkRE2xd1p3F+k1jg+ISD2n9TKuKE5N04ujsqx7a7tM8Ny3sP09QqeqzQyM/6WceOzRc8xS5nCqvt+vB2jkjSe9/9qJuuj7jqp471NqXnJwsiTkAAIeoqCiX09x0dXWpvb1dUVFRQfn/up9KrbT+Q5XW99E8P3PSsZNfVf2CTNWhk6db2YE1QdeR4dyh8YPUaXpy4m36w/kmPfmnnb3mcV4XvV03RhbL4J87xC+3CRMmUCf0Va87g++/Ua/336jv/akHTzt28s9trtDSzcG1HQ0lTzdvHKs/NG19dmjc+1iq/vk7E3Xq8Hm9/pM/OXVoTFfqDFuHxjklSAqWYVttg8wi1RyUU5bBPdcIdYI2UuC5n0rt8NlSHT7bx0R5lgrHTv75xmStaAydPI7tbgjH6PujQyMn4Qe6bcy/qeGvf1DFR0/2kqdMK/aWeVwvwXAaQecsd33rSY9Zznb8Wcdb3ryq5w619nV3ftIBADhJSEhwOV1GU1OTMjMzDR5qnkr1pkwHjoXIm2SMPAPu0DBAnqFkGXTHziC5D2+Lj4+nTlyNuTmar/e0tLgpNOqEQfJcfYdGUwhluYrnGqBO0EYyCHOlHtEirThaFhpvkkHyDLhDwwB5Bts5c9uYx7Tj4+Kreu5g9HXoE4cAAE7cjwk02hA3j5q2h87OP3nI4v4n3Xq33YepUicGaMeB0Nn5D8U8odRBE4AsgagTtJEMwjI7dHb+yWOMr6e6Yn35j88DWvcYAQA4cf9VwP1XAwDBo6WlxWWG21GjRvnl2F7qBPzfYjygpTvIYqQ6QRsJQCD11QHACACgl41Dko4fP85KAYKU+/Y5YcIEhYWFUScABLxO0EYCEEgnTpxwue3cgUcHAOD25RYeHu643dLSYrhT3QDD9ctt8uTJ1AkAQVEnaCMBCJTLly/r1KlTLvdNmTKFDgCgN1OnTnW5/eGHH7JSgCDk/uuT85cbdQJAoOsEbSQAgVBXV6dLly45bsfFxSkuLo4OAKA3GRmu5+Z0//UAQOCdPXu2x3G9EydOpE4ACJo6QRsJQCD01/FJBwDgxn0jOX78uC5fvsyKAYLIH//4R5fb/j6ulzoBUCdoI1H7gGBUU1Pjctv90Cc6AAA3cXFxLjMEd3V19WhEAAisQ4cOudz+53/+Z+oEgKCqE7SRAPhbU1OT6urqXO5jBAAwADk5OS639+7dy0oBgkR1dXWP80/fcccd1AkAQVcnaCMB8KcDBw702Pl3Pv6fDgCgF7m5uS636+rqepxPE0Bg7Nmzp0ejPioqijoBIOjqBG0kAP5y+fLlHrXv/vvv77EcHQCAB1FRUT1+KXjttddYMUCAHT58WIcPH+73y406AVAngqFO0EYC4C9vvPGGy+z/0dHRPc7cQQcA0Af3IW6HDx/W/v37WTFAgJw9e1alpaUu902cODGgp/WiTgDUCdpI1D4g0JqamvT73//e5b7c3FyPE5+arFarlVUGePaf//mfPYa1Pf30071OJGQymVhpgBe4fzUdP35czz77bI/lNm7cqISEBOoEQJ0I6jpBG4n6B/iq/p09e1b/8R//4fJ4dHS0fv7znys8PLzHcxkBAPThqaee0qhRo1zuc59Iw9m1117LSgN8wFPjff78+UHRqKdOANQJ2kj91z7qH+AbnuY3eeyxxzzu/NMBAAygMfHUU0/128DotmrVKlYaMERPPPFEj/vCw8Ndtr3c3Fw9+OCD1AmAOmGIOkEbifoH+Kr+ude+73znO32PxOEQAKB/1dXVeumllxQWFqaNGzeyQoAAeOmll3T48GF95zvf6TELNXUCgBHqBG0kAL7w85//XNXV1Xr00Uf7rX10AAAAAAAAMAxwCAAAAAAAAHQAAAAAAAAAOgAAAAAAAAAdAAAAAAAAgA4AAAAAAABgmA6Ahs3KNplkGsAle3OD7Tm7FvW+3KJdg/nj2pxt0uCeYv9/B/WkAPLj+m3YnO3xOX2uKo//X7a6/xWAOgKAekG9AKgj7L+w/xIqHQDeVjbb9cPQ58aTrOKqgb/0rkUmmZKLVSWpSKeG505qb+u3YbMKe1mZZbNNMmVvVoOnDc6+Pl1VqTiZxg6oIwCoF9QLgDpCHWH/JahYh6q+xJolWaUsa0n9AJ9TWWSVZJWKrJWuD1iLJPtjsmb1+oKuyxVVDuRPdi8/iP8zGPhx/daXZHm8z/PyTq+VVWKt7+U5A3lvAOoIAOoFAOoI+y/sv/hekI0AyFeptV4lWfZ+mO07PffamGarbBCv2rA5W7PLJClLJfWHtCRpuHZP9r1+k5Yc0iG3lZO05JAqizwtn6/S+hJlFVXKemiJklyeU+74G2UVjAIAdQQA9QIAdYQ6wv5LMAjCSQCTtGSl4x3TzgbXDSG52DYAptJaqaKBbT2OoSFFld7YeHZp0UCPGXEcy7JIu9TzGJXAjDDpff32JnmSfWuYMsFlQ1HSEh0qzff4NyZM4esA1BEA1AvqBUAdYf+F/Rc6AIagqNIqq7VU+QNaukGbC23HeGSV1Ks0f6jbzqJeeu9sx4z0ddxPxSKTfeO3yyrR8nxjrPP6j2z/d9ak5AE+o0GnjmuQzwGGSR0BQL0AQB1h/4X9FzoAPJmiCU5dNklLDg1uI2jYqe1VklSkeSoc2kyPDZuVPbvM8eGvt1pltV8cQ0yKk3vpFStTWZltI+5+jvuwk2BYv55j24cfFVX2GF7Te6HZYJ/cJEvz5vDTB6gjAKgX1AuAOsL+C/svIdYBYJ85ccinVGjQ5k+0IpcAAAObSURBVHXdH9RJGkr/S8PO7fYZHstU3GOGyP57vVw/E/bZIrNKVO/24c8vvbIRla3rOduk7Wn1A/8ABsH63bXI9vrJxVJJvVXWAVcup9cvWsnQR1BHAFAvqBcAdYT9F/ZfQq8DwBt2aZHpyqkxilYOpZepQTu3O200RZVXeq+cJ5IoLhzAB3CXKsr6/p/yC7q70T5Sfc/NJ0h6kq5m/XZvuLbjgPpd65sL7a9fpErGPoI6AoB6Qb0AqCPsv7D/EoodAFm2nhar+6WviSvKNNulN+jK8SlePXYuq0T1Li+WpCWHuifhqNL2/maSaDil4/aMvR4SkjxJtm3yuE71eLn+h6oE2/rNL7W/fnfXoMo0u79ZP5wmLMkqWS52/0EdAUC9oF4A1BH2X9h/CckOAC9+2K3WIQ43cXvJeXM89BTlq8Dj6SFCvc4Ncv3ml17ZiMrW9dHbeGXCkkEdbwNQRwBQLwBQR6gj7L8Mhw6AIlW69wj5cXIJx+khQpaX1m9+gaO38aP6XjYfhv6DOgKAekG9AKgj1BH2X+gACIz+z+XYfXqI/l9qgqb08wFS/Uf2CTu8MVzGYFzOVVrK0H9QRwBQL6gXANh/Yf+FDgD/6p7YwvMQmSsTY3geYuPySo7hNmUVno4jcZ45siD0PkC7KuzH3ng6hmiXFiVfGTrDj/+gjgCgXlAvALD/wv4LHQCB2IJsQz+qipXsMgFEgzZnz3Z8KAYyw6Vjlsyy2TL1eK1kx3kjS5Yb9RPUoM3ZHk7J4XL+0HlyX1W7FnWvR4b+gzoCgHpBvQDA/gv7L3QABG4LUml9iW12y7LZTrNJOp1KovLQwM71mF+q+u5zb/TyWlkl5cY9b+SuDSqu8nCuzu7eMRWp0v34m12L1L1t9Zyx0+mSvZnJjkAdAUC9oF4AYP+F/Rc6AHwsaYkOOZ038wrbBBOD6fRJWnJIVmv36Td6vpaRZ45sSF5+pUC4ySqpl9XqfmzMLi26svUA1BEAoF4AYP+F/ZcgZ7JarVa2MgAAAAAAQtvXWAUAAAAAANABAAAAAAAA6AAAAAAAAAB0AAAAAAAAADoAAAAAAAAAHQAAAAAAAIAOAAAAAAAAQAcAAAAAAACgAwAAAAAAADoAAAAAAAAAHQAAAAAAAIAOAAAAAAAAQAcAAAAAAACgAwAAAAAAANABAAAAAAAA6AAAAAAAAAB0AAAAAAAAQAcAAAAAAACgAwAAAAAAABjV/wPupMRARu5eMQAAAABJRU5ErkJggg==)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "CUDA 语言只能使用 `warp-level` 原语 `wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag)` 在 Tensor Core 上执行 `16x16x16` 半精度矩阵乘法。 在调用矩阵乘法之前,我们必须使用原始的 `wmma::load_matrix_sync` 显式地将数据从内存加载到寄存器中(类似于我们在第6章第2部分的 `tmm` 演示中所做的)。 NVCC 编译器将该原语转换为多个内存加载指令。 在运行时,每个线程从矩阵 `A` 加载 16 个元素或从 `B` 加载 16 个元素。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 准备工作" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tvm\n", "from tvm.script import tir as T\n", "from tvm import tir\n", "\n", "import numpy as np" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 编写 Matmul 的 Tensor IR 程序" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "@tvm.script.ir_module\n", "class MatmulModule:\n", " @T.prim_func\n", " def main(\n", " X: T.Buffer[(1024, 1024), \"float16\"],\n", " Y: T.Buffer[(1024, 1024), \"float16\"],\n", " Z: T.Buffer[(1024, 1024), \"float32\"],\n", " ) -> None:\n", " T.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n", " for i, j, k in T.grid(1024, 1024, 1024):\n", " with T.block(\"matmul\"):\n", " vi, vj, vk = T.axis.remap(\"SSR\", [i, j, k])\n", " with T.init():\n", " Z[vi, vj] = T.float32(0)\n", " Z[vi, vj] += T.cast(X[vi, vk], \"float32\") * T.cast(Y[vj, vk], \"float32\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "注意计算 `Z[vi, vj] += T.cast(X[vi, vk], \"float32\") * T.cast(Y[vj, vk], \"float32\")` 与常规表示有点不同。 由于 Tensor Cores 加载 `fp16` 的数据,但在 `fp32` 进行计算。 所以我们必须在计算之前将数据 `cast` 到`fp32`。\n", "\n", "![image11.png](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAA3cAAAFaCAIAAAAy/DOAAAAgAElEQVR4nO3dP2gcZ8LH8QdOJAuOczImPnEhRhcjkCFnVLiQZRL2XBhzwbAugsWhQiSgCAy2cCDY1TYOKlysqghCQJBGviKvCoPEVVu6VKlChcotVW65bzGzz87OPM/MM7PPzDN/vh+msFez8/eZZ377zDMzYgQAAADYJlwvAAAAAGqIlAkAAAD7SJkAAACwj5QJAAAA+0iZAAAAsI+UCQAAAPtImQAAALCPlAkAAAD7SJkAAACwj5QJAAAA+0iZAAAAsI+UCQAAAPtImQAAALCPlAkAaIbjLSGEEFvHKb+xtnee41IB9UXKBABMnO+tiSm1iViTNYus0vGWKnvKL9RmEwAFI2UCADyRhFmroKlMmZMPSZmAdaRMAMBoJK8nK6S4xFxqiivmpEwgR6RMAMAoGDKDoep8b602IVOFlAnkiJQJABgFUmZcqNQFL0VYmzQchhtJ1/bOo1fnE6OsfvnO99Z0n45nN70I/rhJHQSCU5ieALETMELKBACMAikuLkKlT5kpJGQ3Xcz0561borW98/DCpE2ZWZYWACkTAOAJJjFdy2KuKTOpQVM9c3U6luN6k8whZdanuyqQG1ImAMATCoaK1rosKTMQxqZmoMqjqWPmVBKcfDsUMnXLaNgvc+qvxksLgJQJAJiINEFORakZU2bipwmXoSOzH18t39paC05UFTJnSJmhpTJdWqDxSJkAgDDdzS6OU+ZkTG9E/39re+dTnTPVM7GWMrn3HDBEygQAKAWuGKtuu1aPmnPKnIqZU8+/nLwMcjqJxi0jKRPIFSkTAKARyWDuU+YkQ/pXyeWUxplzSx0ySZlA4UiZAACdcKDShbICU2bo1m/l2yKVkyJlAkUjZQIARqOR4j0/gdgWajAUys6axaRMzb3fcX8I/TX4t8kKTZZ8LxSpSZlANqRMAMBoFP+AS5nAkp4hWUTKjGmAjHuBkfpbqhUKN9ySMoFsSJkAgFFcyJwOUwlPWy8iZWpe9xOYkupJlppsqoiZpEzADlImAMAXSZDqHBV9B3mR/TLlN5Qjn++tJS51JIOqH0dPygRmRMoEAACAfaRMAAAA2EfKBAAAgH2kTAAAANhHygQAAIB9pEwAAADYR8oEAACAfaRMAAAA2EfKBAAAgH2kTAAAANhHygQAAIB9pEzANv8lx9OvSlZ+WCH+e57TLH+GrwAAaqTiKdM/jcUYn+E0Y67tnZvOxuBceb63ZjpJlIR5ETLkp8npLyo/rJLJChiX7wxfQV3kWDNPylVS4QqNWdVjDymF93t99/3xVvrf/cVviaanTIMiOP5m/M6ZKtj1LNH1RMo0QspEGjnVzLqRo0eVKmlU+QCEkbjSVKs9n/aM4vIMRMr0qU+EwW/F7JzpGm3r+Pycs2pVkDINccUc5vKpmWPGnS5ksROlPNaT9odFpBzVACmzMLIySSxCyjGDxTI8hUiR1e6cQI1Ws5LcBOZFyFBdUyZgLrea+XgreAhN6t6psVSHW6Ca5hisoakTdv1PxKTMwsxYl410Gz9Fx45J5UXdVUVmRUgWiOTiM1vKDCzO3nR7TPiLwUWabrqJnJb1f1TPOzTL6c/jjhM5cf1X1LOKLJU/5bW98xT98FAaedXMYboDczQ631vTXW6npq6dQCVhXkFYqoViS6queCdVy3HV+7Gm0Va35rpG3sj4iRskm8anzPif0lvHCVWT9ZYwFKysKTOxWjDoeKadWGL1MjWiJjIqv5MYTGOuayX9ztNMECWUV82s+7JBkdAHUlRd6j7gdmshbUhQljmjajmuejdOjQnTMpqfhcq2LikzcdukbsuMdBtSbGzNnqEWqxCzIlSSlBmcf1LhS5iS8rKjemLKyKievbqzndG8phY+oX4nZ5ZdXjVz5I/GxYGQWV9pLwdbr4U0hUvxsWm1HFe920+Z5o0bWTQ7Zcb1y4x8McVtjOrRUUalTZm639Tjz+OfahD4q2ZC4zXR3eE2fcUxsgLJLU2KdVZ3YU5aO/rWVU/ONbO6eBgtEAWnftJdUsyjFopvtFRlx4RqOemhNTb7ZZpvkGyalzLNRld8MblzmekkUSJlTZna64fRxkRF1WpSasMNlbFLFlkBTX0Z95WYecWsnebmYVqkSi3nmlnbIThpcaiU6yjV7s2pFoouw/gTRZNhYrWc2AfAYspMsUGyqUvKNO/9oxL3bX3BUGf9TN2Q4UxZ+2Xqu05GqqGUrexTxdP0OmLsiqonEBNM01Sy0+Ny3bMaCqiZQ1+OObDImDWX5rdnbrWQ/NAbN/Tf6cklF/mkM4a9lJlmg2TT+JSZuItMUqamCFKpVUD1U6Ziwc2rM9MjSLMC4cNKfRkonDLjlpiUWQsF1Mye5COLjFl7aU65OdZCwcZLRcjMljI1B1AOKdNgg2TTvJSZdnOZNHLrrgFRrZVfvVOmtePC+F4MVdd52jIbp4Ca2ej7/ORvhBSXdvOshSZn/q3I1fKYr8UuZnEpk7ZMNbcpM+nGMiq2CkiZMpPbrQvvl6lacONaN6n7j/EKRBcnl36ZpMxqKEfKpCpuihT3BeZZC4VbK7V9kM3DcO4pk36ZCdymzMBOUD78gJqtAgwLRsKuzjllqu4Dj6+GdHeOj0bHW+rf15EpHQffsmKeMhU3aEb6w0/PK/7uTlJmFeVUM5/vreluzk1z9kTtTOc7VaWhesab/VooWPHG/zW2Wk6u6BSl+3jP5AaT6PjcYx7LccqM72hB1VYFhgUjsUuN7ZSpoci4Sfcgxk0kaZZxa6WZQ0xbZuJmjLncbrbWKIlcambzMj1KOpKonGsnuddjNDzmUAvFXx0yLsKJFZ1qSnGHUNz4xhskG1Km8Sw09ZKmMuMkWBHGBcP4uVXKJhTTdhXzc2Ny3rLxNBn9r3r1d5S/hWNvGFJ9lZRZdXnUzDHnwph2TMMvoAbi09La1rEsYTnWQuM/Jt6NFlsqkys6xcoa9vpSjW+2QbIhZRrPIqZeCu0/ToAVku0hGOMikdBwqW7FizvDaQ92/SMOTC/16yelXDndOo27LJ2fhyedfGeUdqnSP8CBg6zU8quZE4vOKE2zFurHuMrLqxY63koszYnLaFbRhSaT9oEO+s7/Vo+SiqdMoF7s3NQHAEAJkDKBEiFlAgBqg5QJlAgpEwBQG6RMoERImQCA2iBlAiVCygQA1AYpEwAAAPaRMgEAAGAfKRMAAAD2kTIBAABgHykTAAAA9pEyAQAAYB8p0wb/8TO8Gxcl5pVSnpCE5vDfBk3NDCTJ7WDJKWUeb9X5yJYvsvdWUf63bGfweu8FpFHQczgpciiNScWcuUhSnktM1mla412nGVNfF06KTlKlGRqzsqUl7mCZ9SiwnjItHNllN17FcdErYcpswF5AKrmnTIocSmamIkl5Lr2ZU6Z6B+tGjpaDcMKMGbX01OXdzlFAykwtHDJHJbxiXv+9gHRImWgaUma92UqZ05VizLjTJSF2opUrNKRMi463ZjzNKkJm+ZR9L6BopEw0DSmz3szrNOWYwZbI6SkcbwV3+iRNTo2lKiCB4Fm1UlO2lBluJ5bbXteAHCkE4V8BkRHklNb2zqdH9kaN/IxQl7NIn4mtLaMFiim14zH9Uaa/KPeEP+O1vfOkHh6B8r83vRDh3RrcJuqVDHcTNdlEyFG2nasq8OEJmhdUf8S1cK8O0+Kkn+ZkFIpcSRnWk+pxFecV/TgpitPkyo+yOCd0dEsxI21RTtosxZ3IMIMZU+bINETpitxodL63pitaCanMtGAYJBNtHpvxYLFXq6dOmcpm4hQnG21XBk2EVoxktvb6GQmxtnUcvtqdNLnQ2OqmeMUuS1rT+LZ85ZYjZVZFhp2rLStmB07CqCJzytStijIYqNYQRTPfK7E712yc9CkzhSzVoPojs81S0IkMs5k9ZRpNYlJek3eaPpDqRtQXDKNkEpfHqpoy1TM2XyzT065mSsoGxNQzinao1ImUqnDIzJYyjYpSZGFJmZWTYedqRrCQV0PjpkmZ+rl7X6bIlZLZD+iknWs2Tq4pM1gWZ2/LTN4sRZ3IMBPjHpTZ2jIz3D1uHDJnrPDj2/UsHSzOUmZy9De84Kb54ahcQdWENPXk+GNVV4rANKP7SPP7M7SakZCpW2XVOql7bah/KMVuE4MjhU5FpZBh5yr3mfmBM1U7BaYzGTd9LNBMMnS5iCJXPon1pNnONRknS8pU9n3TFGY7zTPGm6WgExlmMWPK1PfLjP7dbNeZXi4fJRUM02SSlMcsNvwX2C9Ts/bKMWLOleE/Rf4SPSnGC+9e9e4OfRpTJuJLpa6nkEn9Fp5uclt+ZJuQMqsiw85VFfgUB475pSHj4mR4wYgiVwGRKs9k5xoVgBlTZuKn0XbGGfplJm6W+C/neCJDChlSptnonpj+jvGLY1IBmlX4SckkKY9VM2VGNr2+u3NksYx+HIZ7+WhXTlNuYg7u8B5JurQYXr/x+CarrFv+6F7X/xrRbhNSZlVk2LmKHZbiwNFfscmcMk0vAlHkSsmwnozZuWYFoGIpM36zxH/Z7okMmWXol6linB5jd2KqjGla4ScvdXweq2rKHI2i+8xsseJqq/DfYms2kz6XsTsqvv+icnoxIZOUCY3Zd27Kv+mruswp07Qqp8iVjUE9mermh5qkTMMu+wWcyDCb2VOmYVWVotHLtPIzKkRxYu4uUeSACqbM8AKoWwyt/ASMaThMu+3D+yBdBaAPmaRMaFhOmckHjvkMSZn1ZlZPNi5lmp8+8j6RYWYZUmbG3RD//SxVX9ZmBYMpivSVvHYlnKfMqYWIVgJ2urOEt7OmO4OqFKhSproGS9FZN76XsM2UGdPFRzNNUmbZzN7pVj229i/63j76lJlQnPSTnEaRKxPDetJk55oVgBS1k6WUmTyj6EcpTh85n8gwu3KkzGwVX8YKP+VUZzpYsq9cmK2UqahDAh/thX45aw5ng3bGpIsgkfQWV6x0905G7kKUYyb0EraWMpVLproXcuq3eWBeMXsBhTHfufFno2z3tAYnpTi9mhYn3SRHo2PNKzIocs6Z1pMmO9esAJjXTrOlzBQzikuZCaePUd4nMswup5R5vremSVqpfnEkz0S/RObJRDPV8bdmOVh065e+Vp/17p+EdQj91fh6hX4fJNwtZjAnZQ/vuIVJCJmWUmbyYiR210hYMyq6opnv3ISzUZYLfUnjmhenuNVQtZ8rFwyFMqwnzXauyTgZilO2lGk+o9g2ECX1D7Po3y2cyDCzXFKmaSaYnq5SXO6cocKf+kZCHpvpYNFNIHVZTpsy1Zs1ab0NektqphFdoaTtFgnx8fOKGS9SNWlKTY4pMzTHhDpS9wMsutYohPnOTT4bmR04yXM1K/vh5dOPGhszKXLOGNaTo5HZzjUYx7Q4zZgyU5RbRUNMis2S84kMs8ojZZo1T4WmavoF1XzSV/jJh6Wtg0WzQfJOmefn5+GlTmzxVYyjvykqvGrqFVIsQ3QbGd5LqBlVUQnrCk0+KTMx0o5H03edSNgLyJv5zjU8GyUeOJHpjUfcUtezaYqTor5Jqngpck4Z1ZMek52bPI5RcZo5ZZrOaGq0mFQds1lyPpFhFvn1yzSpZxObCWdMmcoFCU3VLI/NdLAoliL3tsxqK+EdCrP2S0aJsXMBAE1W45R5vBWTyUmZyB87FwDQZLVNmdn75RaLIFJj7FwAQJPVNGXG9pko1SmfIFJj7FwAQJPVNGX6Ig2a5WnDHCOI1Bg7FwDQZPVOmQAAAHCDlAkAAAD7SJkAAACwj5QJAAAA+0iZAAAAsI+UCQAAAPtImQAAALCPlAkAAAD7SJkAAACwj5QJAAAA+0iZAAAAsI+UCQAAAPtImQAAALCPlAkAAAD7SJkAAACwj5QJAAAA+0iZAAAAsI+UCQAAAPtImQAAALCPlAkAAAD7SJkAAACwj5QJAAAA+0iZAAAAsI+UCQAAAPtImQAAALCPlAkAAAD7SJkAAACwj5QJAAAA+0iZAAAAsI+UCQAAAPtImQAAALCPlAkAAAD7SJkAAACwj5QJAAAA+0iZAAAAsI+UCQAAAPtImQAAALCPlAkAAAD7SJkAAACwj5QJAAAA+0iZAAAAsI+UCQAAAPtImQAAALCPlAkAqLjzvTWxtnfuejEATCNlAgAqLkPKJJgC+SNlAgAqjpQJlBIpEwAQcb63trZ3vLcmhBBCbB2Pzsf/nkQz+dHks+MtsbZ3frwlPzveEiI8UnhGIjpKZDqKeam+G8yOUzkyMOrW8dRiTa2c/2cAVpAyAQAR53trMtAFwt7oeGucw7wgKEeefLi2JoPaZGxt22E4F+qmoxwnMnF1ypx8ZzQaHW9tHYfmG5g8AItImQCACF2joPx3IOMF/jP1qWYc/YxSTSf4YXzKVM45PCZtmIB9pEwAQIRJypy+dj5Oh4FWweB1ad3V6Mh1bfV0ovMKfpiYMqNNleG2VX9ZadMELCJlAgAisrdl6tKhwYxippNrW+bUhzRqAtaQMgEAEYkpMxgElZ01x2NP4uGxKr6F+keqp6ObV7TT5+TDQM/SUEqNfH0UnBEpE7CGlAkAiEhOmTH3mAcnlHSTedw95srRYu8xD3y2trUVaiUNXbaXY25tTd9xDsASUiYAwB2eWwnUFykTAOAOKROoL1ImAMAdUiZQX6RMAAAA2EfKBAAAgH2kTAAAANhHygQA5KI7VukpAMiMlAkAyIV8BmWlpwAgMw48AEAuypARSZmAQxx4AIBclCEjkjIBhzjwAAC5KENGJGUCDlXpwPvtwY16DK43JAAUoQwZkZQJOFSlA895OiRlAoC5MmREUibgUJUOPOfpkJQJAObKkBFJmYBDVTrwnKdDUiYAmCtDRiRlAg5V6cBzng5JmQBgrgwZkZQJOFSlA895OiRlAoC5MmREUibgUJUOPOfpkJQJAObKkBFJmYBDVTrwnKdDUiYAmCtDRiRlAg5V6cBzng5JmQBgrgwZkZQJOFSlA895OiRlAoC5zAlvfX293W632205Be+/6+vrhS0DgNlV6cBzng5JmQBgLnPC6/V6QqXX6xW2DABmV6UDz3k6JGUCgLnMCW84HC4sLIQi5sLCwnA4LGwZAMyuSgee83RIygQAc7MkvGhzZoaGzBmXAcCMqnTgOU+HpEwASPThw4d+v9/v92XC64+ZTyTUnJmtIXNEygScqtKB5zwdkjIBINHOzo6yV2Wn00k1nWBzZraGzBEpE3CqSgee83RIygSARIPBoNVqRVPm6elpqunI5swMDZmdTkeZdFdWVlJNB8As6psy7139UlnHjH25dF31xetPPxVf3/nMaOKfz5MyASAq2pyZtiHT4zVnZmjIPD09Vdb8R0dHGRYDQDakzHDEFCIhZT7/3J/C159ffXovKY+SMgE0T7Q5M21Dpmc4HK6srGTrkRltzqQhEyhY7VPmnHEQnP96XBPFpMxxxDSfLCkTQBMFmzOzNWR6Li8vs30x2pxJQyZQMFLmjd8e3HizNBesiXQpczzaTBGTlAmgCYLNmdkaMmcXbM6kIRMoHilTZsfW83Fzpjplji/BJ/faJGUCwLg5c5aGzBkFmzNpyASKR8q88WZpbhwcY1Km32VTc89QlpTZ7XbDfUUbY2FhYXd3121xqo3hcOg9ibDX63W73e3t7XbA5uZmt9vd3d31xsl88RHIwGvOTNWQeXZ21u/3j46Out1ut9t99OhRsDx7Hx4eHvb7fcPJes2ZNGQCTpAyg4M+ZfpTaz2durae8dK5tzpNTplCiFarReKZxWAw2N/f1z2uJcbq6mqv1zs7O3O9BmiEk5MTk9H6/f7Ozs7y8nKqwrywsLC9vX10dBRze5DXnElDJuBE7VNmVEw01KbMUMfNoAytm97qNDxlCiEuLi6cFqhKOjs7293dXVlZmX37Ly4u7uzspHodC2DR5eXl4eHh+vr6/Pz8jIW51Wp1Op39/f3BYBCdUbfbLXzlAIxGpEyzlOlfLhci+IxM+WHqFk3XG9KxxcVFb8ORMlM5PT1tt9vKMn3jlrh5R9x9Iu5viAfbYv3tZPj3T+L+hri/IW7eETfvqA+JxcVFWnpQpOFw2Ov1lOHy40/EzTtiac0vt0+6k8L83S/+h7fb4uYdceWaOm7u7OxwnQQoidqnTCtXzMeB8tOrbx58Fh0/bXOm6w3pGCkzrYuLC+WV8aU18fC5ePZO/Pw/0+HFn+Lxa3G7LT7+JDy1lZUV2jVRgIODA1kJSNe/EHefiPW3KQrzz/8Tm7+K1afixq1wYZ6fn9/d3c32lE0AFpEyU6TMaJr0H58ZTp+kzDikTHODwWBzczMaLh+/Fi/+THc+jg7f/SJWvhVzH01NvN1uf/jwwfV6o55OTk5CnT2uXBMPtsUPv89amH/8Qzx8Ho6bCwsLBwcHrlcaaDRS5kwp0++vScpMg5RpaH9/P/T2lH/cFZu/zno+Dg3P3om7T8JNQdvb27QDwaLhcLi+vh4sYx9/Ir75Xrx8b7k8//un8JX01dVVZWdNAAUgZZqkTL/NkrZMK0iZJra3t4Nnyr/fTn0xMW1T0O12uFGTzm2w4uLiItiEOfeRuPvEQmO8bnj5Xnzz/VSfkIWFBVroASdImUYp87c7LVWapF9mFqTMeJeXl8G7fK5/IZ50c8yXwWHzV/GPu5Nz8+LiIg88wow+fPgQvMvndlv8+EcRhfnFn1ON9K1Wa39/3/XGABqHlGmWMuVbzrnHfGakzBhnZ2fBGyOW1nJs8tEN9zcm5+b5+XnD5x0CUaFeHw+fF12Y199ONWrSFQSJ0iWNEg+uN6SPlGmYMrWPRsrwwknXG9IxUqbOyclJsNXn/kbRp2Q5PH49dVcQt1Agg1evXski9PEn+Xb5iBl+/GPqrqBHjx4RNBHDeTokZTrjOGU+uDH14EwhhGg9T9Mds2z73hVSptLZ2ZmMmHMficevnUVMb9j8Vfz1b5OyTosmUtnf35eF58atgq6S64aX76e6HW9vb7vePCgv5+mQlOmM831Ws33vCikz6vLyUm6Wv/7N/o3k2YZn7yaNQGnfRo0m6/f78kL5P+7av5E82xDsCtLr9VxvJJSU84RQs6RBymzuvneFlBkyHA7l7T5zH5UlYnrDiz8nLZqLi4s8EQaJLi4uZKv8jVtliZjesPItzfNI4Dwh1CxpkDKbu+9dIWWGBB9aVNjt5ObDD79P7p9YXV2lTxtiXF5eLi8vy1Z5txfKo8PL95P3rM7Pz/MIBUQ5Twg1SxqkzObue1dImUG9Xk9GzG++d38aVg7f/TK5GajT6bjeZiiv0rbKyyHUPM+vJoQ4Twg1SxqkzObue1eakDJPT09NejEOBgPZfe122/0JOGZ4+HxyqfHo6KiAbYjKOTo6KnOrvByCzfM7OzuuNxvKxXlCqFnSIGU2d9+70oSUeXBw4DX7xWdN+c69m3fK1X1NOcg+bbQAIWo4HMpD++4T98U1fnj82i/MrVaL3sYIcp4QapY0SJnN3feuNCdlykvMyqx5enoqxynntcXQ8Ozd5A3R3KKLENn348o18eyd++KaOMjnJ6yvr7veeCgR5wmhZkmDlNncfe+KTJmrq6vtmpI3QMRkzdXV1UpcKw8OD8b3Kc3Pz9MCBGkwGMj7yh9suy+oJsN3v0wOT57SBcl5QqhZ0iBlNnffuxJ8g2LTyKwpe7DNfVSNhh9vePl+cucEHdog7ezseKXiyrUK9P2Qw9KaX5jb7bbrTYiycJ4QapY0SJnN3feubG5uFpztSmJhYaHX6w2Hw2r1YAsNT7r+6vCcdniCN7GV+aaf6LD56+Tw5PGZ8DhPCDVLGqTM5u57h/p1F3yDczBfeqt/cnLifV6thh85yA5tr169cluQUAayR+aNW+4LZ9pB3tO2ubnpekOiFJwnhJolDVJmc/c98iPv/gnlS498DPvqU/dn2QyD7NC2vLzsagujPNrjZ2T++yf3hTPtsLHnF+b5+XnXGxKl4Dwh1CxpkDKbu++Rn4ODA2W+9CwsLHgnto0992fZDMPL95OHtNf4QQEwcXl5KdvsK9TDODjIrsb9ft/15oR76U7o965+GdtR6sul696Yb5bmQn/6+s5nBhOce3pPNVp1kgYpk5QJ+waDge6Jkv1+X14ud35+zTzc9luvxO7ubsHbFqUim+1v3nFfLLMN8qI5N7RhlFPK1I326dU3DyYhMppEfZ/PVzdpkDJJmSiU7LK58q3782vm4d8/+bUfN+c2XKfT8UpCVR5gFB3W3/qFeXFx0fXmhHuZUmZyi+ObpTnZrhkMlIEP57+ORE85mrrhswpJo0opE6gBeXf5+lv359fMw4s/JxfNLy8vXW9UuDEcDuXd5T/+kUtJC8qvPMsXTvLYBOSUMqPD888jzZn3rn4Zbra8/vTTjM2Zrjekj5QJFEe+78fi5fJn/xX/2koY519b4tl/LZ+Yb97xT8yHh4eutyvckN0/8ru7vJiU+dVDfxb7+/uuNyocKyxl+u2USfHRD6OkTACJ5DOMltasRcxbq0IIce8/CSPcWrUcNO9v+CfmbrfrervCjcPDQ68M5Pf+qmJS5jffU5ibYmdnJ77FuuC2zOBldNXgt2UmjUbKBDAa7e/veyczW50y/7U1OQevfBvOkTJiehKbPFMND5/7k93e3na9XeGGfFJmfi8XKCZlyn7GPDWz9rw+S9FX/krFpEzDhszf7ng9UrKk2II3rA4pEyhOt+u/Oef+hrUT5L3/TE7DwQbLUMS0frORfAlQp9NxvV2Rr4ODg4ODg+jn8sWS33xf7ZQpbwDibrbaC77iWJk1M6XMKG0u9K+AGwXH7J0ySZlAE8m3a9p9frV8FIsMmqGIaf1y+c+Bx1mvrq663q7Il/fraHFxMZQ119fXvTLw+HW1U+YPv/uzkC8akD8IUXuhrFlUyvS0nj/QZs3xDeZx45AyAUw8evTIq1q++8XmOTKaKfOOmD//T/z4hz/9hRXg6ToAAA+fSURBVIUF19sV+QpGrmDWlG/9sfXAhLRsFeYXf/oTlG8AImU2zfb2tveQ40wpM/0Vbf9SuL6dcpxfM/TIJGUCDbWysuJVLJu/Ws58oaAp5RQxQ5mgjVoLXmT0eFlTfv6jpccYpWWxMIeezEXKbI7Qe9oKSpkP4vtcznStnJQJNJR8t6Sts3J80Mw1Yv78P3HlWoGnApTPX/7yF+8fNUiZ8pGZg8HAdT2BHAV/MinfA1xcyhw/hj36xPUZr5WTMoGGWl5e9mq3H36vQ8pEk3U6nX/+85/evzf2qp0yX76fTNN1JYF8eSlTmS897lPm+Fp5hvf9kDKBRmvb7scWEzE9+QXNZ+/8WXzyySd91Jq8a03mS+9uCfl6ySfdXMrYz0Xd/fPjuJMxL5msvdXVVV2+9BSXMtVXzMfvmZzhWjkpE2ionO4xd3L3z+av/vRXVlZcb1fkS3ZSDN2Nu7297X3+8Hm1UyZPMoKUT8q8/vTT6XHkzenBN0xO7kCf6Vo5KRNoqFevXnlViMXnC774P0WmLOBJRt/94k/80aNHrrcr8tXtdpVPFtzd3fXKgMXnvzpJmY9f+7NYX193soVRHrmkTHlHedh0mtSOpsijpEwAYXm8K8Xwqey3VsWL/7N5YuZ1Kc2hu7x4cHDglYGvHlY7ZT7w22TFzs5OwdsWZZNHynxz7/r4hp6JyCOKxtfKSZkAsjk6OvKqC1vvfXb4hkneY45+v++VgZt3qp0yV5/6s+j1eq43Khyb8Tp1eQbXG9JHygSK8+HDB+9kdv0LO2dHmSPv/SdhBOsXzb966J+Y9/f3XW9XuDEYDLwyMPeRePFnhVPm9S/8WZycnLjeqHDMeTokZQLIaDgczs/Pe+czWw8zevbfhEbKZ/9VNHPOOLx8P3m+4NnZmevtCmdWV/3W8vxeMpn3IF8vOT8/H3PrMRrCeTokZTrjfJ/VbN/DCfnqZ4s3ABU/yHtyefJLw8kbgGx1Ail++OZ7vzBz6w9GJA3bSJnN3fdw4vDw0Dul/f22+/Nr5mHlW//E/OrVK9dbFC6dnZ15JeHjT8TL9+5LZobh77f9wnx4eOh6c8I95wmhZkmDlNncfQ8nhsNhq+U/riKP90wWM/z1b/6Jud/vu96icEy+0eq7X9yXzLSDfB57q9Xy3mCOhnOeEGqWNEiZzd33yM/JyUnM6yUePXpU6YvmG3uTfmwFb1iUkHwK7NKa+8KZdpDPMOKxr/A4Twg1SxqkzObue+THe46g7lW58imDN265P8tmGOTlcp6UiVHgyQlCiM1f3ZdP8+HFn+LKNX/JeVQCPM4TQs2SBimzufse+ZE5Upk1Ly8v5UXz/F4AndMgXywpeOwLxmTzfLWaM+8+mRyk3F0Oj/OEULOkQcps7r535eTkRHbkao5Q1pQvhr5yrWL3TCyt+WvEG58hnZ6eyqJelR9Oz96JuY/8ZT44OHC9CVEWzhNCzZIGKbO5+94V+YC9BlpYWPjw4cNoNBoOhwsLC96HD7bdn3ENhyfdybpE32qNJtvc3PQKRlX6gdxu+yV5ZWXF9cZDiThPCDVLGqTM5u57VxYXF4uLdWXS6XSCyUxeVf/4E/HsnfuTrslw45a/LvTIRMhgMJD9QP79k/uyGj8EO37wnAQEOU8INUsapMzm7ntXZMq8uLhwvSx5CfbLjOZLaWVlxRth5Vv3593E4eFzf3VardZgMCh+q6Lk5M3mV66V+ildL99Pfi91Oh3Xmw3l4jwh1CxpkDKbu+9daVTK1OVLT7/fl0m05C/o++H3ySslu91ugdsSlRHsB3L9i7zebD77IPsWt1otOn4gxHlCqFnSIGU2d9+70pCUGZ8vJfnCybmPxMae+xOwcnj2bvIY9sXFRe7Ghc6HDx/kdfN/3HVfdKPD6tPJRQaeXoQo5wmhZkmDlNncfe9KE1KmueFwKK+bl/M648v3k1fwzc/Pn52dud5mKDX5DtUSdgV5/HoSMXd2dlxvKpSR84RQs6RBymzuvneFlBlycXExPz/vbZMbt0r3YCN5K67gAZkws7u7K8tMeR6hsLE3eXQRb/oBikHKJGUWjZQZ1e/35XXGpbUSBc37G5OI2ev1XG8nVIbsCiKEuPvEfUl+/HoSMZeXl3llOVAMUiYps2ikTKX9/X15Vv77bffPNnr5Xnz1cBIxt7e3XW8hVMlwOAw+GffmHZc3A8l3/Hi9Pqh5gMLUN2Xeu/qliPPl0nXVF68//VR8feczo4l/Pk/KzICUqbOzsyPL55VrLl8J/ezdpC+md3mRO36Q1uXlZafTkaXo+hfih9+LLskv/hQ370xK8vLyMh2LgSKRMsMRU4iElPn8c38KX39+9em9pDxKyowgZcbo9XqyiM595ObxRht74sq1yZGyvr5OxERm8mWqQoiPPym0SK+/nTwbQQjR6XS4UA4UrPYpc844CM5/Pa6MYlLmOGKaT5aUGUbKjHd0dCRvBhJC3H1S6KXGh88n3dcEfTFhQ6hI//22WH+bbzH+4ffJQzE9POQVcIKUeeO3BzfeLM0F6yNdyhyPNlPEJGWSMhOdnZ0tLy8HW4C++T73W4KedKcafubn57mjHLaEirQQYmktlz4hz96JlW+n8mWr1To8PHS9AYCGImXK7Nh6Pm7OVKfM8SX45F6bpMxYpEwToT5tQogr18TD57nky/W3kxfueei+BusuLy9fvXoln6Xg+eqhtc6aP/4h7m9MtcQLITqdDiUZcIiUeePN0tw4OMakTL/LpuaeIVJmCqRMcwcHB/KVfZ6//k08fm3tGvp3v0zdG+E1/HS7XbqvIScXFxfBhxx5rn8h7j7J+O6rzV/F/Y3wzyQhxOrqar/fd726QNORMoODPmX6U2s9nbq2nvHSuesN6RgpM5XhcNjr9YLd2jxLa+Lh8ywPPHrxp3j8WtxuT95LLu3s7AwGA9drjPo7PT1tt9vh8ifElWviq4fi/ob47hex/lbRzPnjH2L9rVh/K+5viJVvp/p4SMvLy0dHR65XEcBo1ICUGRUTDbUpM9RxMyhD66a3OhcXF/1Gko1zpExzl5eX3W43dLXRc+OWuL8hHj73z77R7psbe2L9rXj8WtzfEP+4qy7G6+vr7A4U7OjoaH19PfoLKir4XK0YnU7n4ODA9WoBmCBlmqRM/3K5EMFnZMoPU7doeqsTfMBHMxFr0hoMBru7u/K95zGuXBPXv0jeBYuLizs7O6enp67XDI12cnKys7Mjr3KksrCwsLm5eXR0xPO2gBKqfcq0csV8HCg/vfrmwWfR8dM2Z3qr0/CU2Wq16PyX2WAw2N/fD90eZG51dXV3d5e7IlA2Z2dnvV6v2+12Op12ux26LV0Isbi42G632+12t9vtdrv8QAJKjpSZImVG06T/+Mxw+jRKmQcHB+0G48KWFcPh8PDwsNvtbm9vexs2elV9dXW13W5vbm52u92DgwN6XqJy+v0+TZVAFZEyZ0qZfn/NTCkTyM/FxQVNlQAAt0iZJinTb7O025YJAABQY6RMo5T5252WKk3O1C8TAGCCPtxARZEyzVKmfMu5vXvMAQCJhsPhysoK/TKBKiJlGqZM7aORMrxw0vWGBIDK6PV6Qoher+d6QQCkRso0TplT7Zee1vM03TFJmQCQynA49N7jsLCwQHMmUDn1TZklHlxvSACoBq8h00NzJlA5pExSJgCUkWzI9NCcCVQOKZOUCQBlFGzIpDkTqCJSJikTAEon1JBJcyZQRaRMUiYAlE60IZPmTKBySJmkTAAonf39/W632+12Zb70/ru/v+960QCYImWSMgGgvGTKdL0gAFKr0nHrPB2SMgGgYKRMoLqqdNw6T4ekTAAoGCkTqK4qHbfO0yEpEwAKRsoEqqtKx63zdEjKBICCkTKB6qrSces8HZIyAaBgpEyguqp03DpPh6RMACgYKROoriodt87TISkTAApGygSqq0rHrfN0SMoEgIKRMoHq4rgFAJQXKROoLo5bAEB5kTKB6uK4BQCUFykTqC6OWwBAeZEygeriuAUAlBcpE6gujlsAQHmRMoHq4rgFAJQXKROoLo5bAEB5kTKB6uK4BQCUFykTqC6OWwBAeZEygeriuAUAlBcpE6gujlsAQHmRMoHq4rgFAJQXKROoLo5bAEB5kTKB6uK4BQCUFykTqC6OWwBAeXXHXC8IgNRImQAAALCPlAkAAAD7SJkAAACwj5QJAAAA+0iZAAAAsI+UCQDI7rcHN+oxuN6QQA2RMgEA2TlPh6RMoLRImQCA7JynQ1ImUFqkTABAds7TISkTKC1SJgAgO+fpkJQJlBYpEwCQnfN0SMoESouUCQDIznk6JGUCpUXKBABk5zwdkjKB0iJlAgCyc54OSZlAaZEyAQDZOU+HpEygtEiZAIDsnKdDUiZQWqRMAEB2ztMhKRMoLVImACA75+mQlAmUFikTAJCd83RIygRKi5QJAMjOeTokZQKlRcoEAGTnPB2SMoHSImUCALJLF+buXf1SxPly6bo35puludCfvr7zmcEE557eU41GygRcIGUCALLLJWXqRvv06psHkxAZTaK+z+dJmUAZkDIBANllSpnJLY5vluZku2YwUAY+nP86Ej3laOqGT1ImUCxSJgAgu5xSZnR4/nmkOfPe1S/DzZbXn36asTnT9YYEaoiUCQDIrrCU6bdTJsVHP4ySMoESIGUCALIruC0zeBldNfhtmUmjkTKBIpAyAQDZFZMyDRsyf7vTynynuesNCdQQKRMAkF2mlBmlzYX+FXCj4Ji9UyYpE8gDKRMAkF1RKdPTev5AmzXHN5jHjUPKBIpEygQAZJcpZaa/ou1fCte3U47za4YemaRMICekTABAdgWlzAfxfS5nulZOygRyQsoEAGRXXMocP4Y9+sT1Ga+VkzKBnJAyAQDZuU+Z42vlGd73Q8oEckXKBABkV1zKVF8xH79ncoZr5aRMICekTABAdvmkzOtPP50eR96cHnzD5OQO9JmulZMygZyQMgEA2eWSMuUd5WHTaVI7miKPkjKB4pEyAQDZ5ZEy39y7Pr6hZyLyiKLxtXJSJlBWpEwAQHYzXqcuz+B6QwI1RMoEAGTnPB2SMoHSImUCALJzng5JmUBpkTIBANk5T4ekTKC0SJkAgOycp0NSJlBapEwAQHbO0yEpEygtUiYAIDvn6ZCUCZQWKRMAkJ3zdEjKBEqLlAkAyM55OiRlAqVFygQAZOc8HZIygdIiZQIAsnOeDkmZQGmRMgEA2TlPh6RMoLRImQCA7JynQ1ImUFqkTABAds7TISkTKC1SJgAgO+fpkJQJlBYpEwAAAPaRMgEAAGAfKRMAAAD2kTIBAABgHykTAAAA9v0/Blnct8TEPoIAAAAASUVORK5CYIIA)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 内存层级\n", "\n", "在传统的 GPU 调度中,我们有 “全局内存”、“共享内存”和“本地寄存器”的内存层级。 为了支持 Tensor Cores,我们引入了另外三个特殊的内存范围:`wmma.matrix_a`、`wmma.matrix_b` 和 `wmma.accumulator`(类似于在第 6 章第 2 部分的演示中的 `global.A_reg`、`global.B_reg` 和 `global. accumulator`)。 在硬件上,所有 `wmma` 的内存相关层级都存储在片上寄存器级别,与本地内存相同。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 注册 Tensor Intrinsic\n", "\n", "这里我们注册了所有的 Tensor Core intrinsics,包括`load_matrix_a`、`load_matrix_b`、`wmma_fill`(初始化`C = 0`)、`wmma_sync`(累加计算`C += A * B`)和 `store_matrix`。 在本教程中,我们不会解释如何编写 intrinsic ,而是关注如何将给定的 intrinsic 应用于张量化程序。" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "@T.prim_func\n", "def wmma_load_a_desc(a: T.handle, c: T.handle) -> None:\n", " A = T.match_buffer(a, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"shared\")\n", " C = T.match_buffer(c, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"wmma.matrix_a\")\n", "\n", " with T.block(\"root\"):\n", " T.reads(A[0:16, 0:16])\n", " T.writes(C[0:16, 0:16])\n", " for i, j in T.grid(16, 16):\n", " with T.block(\"load\"):\n", " vii, vjj = T.axis.remap(\"SS\", [i, j])\n", " C[vii, vjj] = A[vii, vjj]\n", "\n", "\n", "@T.prim_func\n", "def wmma_load_a_impl(a: T.handle, c: T.handle) -> None:\n", " s1 = T.var(\"int32\")\n", " s0 = T.var(\"int32\")\n", " A = T.match_buffer(\n", " a,\n", " (16, 16),\n", " \"float16\",\n", " align=128,\n", " offset_factor=16,\n", " scope=\"shared\",\n", " strides=[s1, s0],\n", " )\n", " C = T.match_buffer(c, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"wmma.matrix_a\")\n", "\n", " with T.block(\"root\"):\n", " T.reads(A[0:16, 0:16])\n", " T.writes(C[0:16, 0:16])\n", " T.evaluate(\n", " T.tvm_load_matrix_sync(\n", " C.data,\n", " 16,\n", " 16,\n", " 16,\n", " C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),\n", " A.access_ptr(\"r\"),\n", " s1,\n", " \"row_major\",\n", " dtype=\"handle\",\n", " )\n", " )\n", "\n", "\n", "@T.prim_func\n", "def wmma_load_b_desc(a: T.handle, c: T.handle) -> None:\n", " A = T.match_buffer(a, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"shared\")\n", " C = T.match_buffer(c, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"wmma.matrix_b\")\n", "\n", " with T.block(\"root\"):\n", " T.reads(A[0:16, 0:16])\n", " T.writes(C[0:16, 0:16])\n", " for i, j in T.grid(16, 16):\n", " with T.block(\"load\"):\n", " vii, vjj = T.axis.remap(\"SS\", [i, j])\n", " C[vii, vjj] = A[vii, vjj]\n", "\n", "\n", "@T.prim_func\n", "def wmma_load_b_impl(a: T.handle, c: T.handle) -> None:\n", " s1 = T.var(\"int32\")\n", " s0 = T.var(\"int32\")\n", " A = T.match_buffer(\n", " a,\n", " (16, 16),\n", " \"float16\",\n", " align=128,\n", " offset_factor=16,\n", " scope=\"shared\",\n", " strides=[s1, s0],\n", " )\n", " C = T.match_buffer(c, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"wmma.matrix_b\")\n", "\n", " with T.block(\"root\"):\n", " T.reads(A[0:16, 0:16])\n", " T.writes(C[0:16, 0:16])\n", " T.evaluate(\n", " T.tvm_load_matrix_sync(\n", " C.data,\n", " 16,\n", " 16,\n", " 16,\n", " C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),\n", " A.access_ptr(\"r\"),\n", " s1,\n", " \"col_major\",\n", " dtype=\"handle\",\n", " )\n", " )\n", "\n", "\n", "@T.prim_func\n", "def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:\n", " A = T.match_buffer(a, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"wmma.matrix_a\")\n", " B = T.match_buffer(b, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"wmma.matrix_b\")\n", " C = T.match_buffer(\n", " c, (16, 16), \"float32\", align=128, offset_factor=16, scope=\"wmma.accumulator\"\n", " )\n", "\n", " with T.block(\"root\"):\n", " T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])\n", " T.writes(C[0:16, 0:16])\n", " for i, j, k in T.grid(16, 16, 16):\n", " with T.block(\"\"):\n", " vii, vjj, vkk = T.axis.remap(\"SSR\", [i, j, k])\n", " C[vii, vjj] += T.cast(A[vii, vkk], \"float32\") * T.cast(B[vjj, vkk], \"float32\")\n", "\n", "\n", "@T.prim_func\n", "def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:\n", " A = T.match_buffer(a, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"wmma.matrix_a\")\n", " B = T.match_buffer(b, (16, 16), \"float16\", align=128, offset_factor=16, scope=\"wmma.matrix_b\")\n", " C = T.match_buffer(\n", " c, (16, 16), \"float32\", align=128, offset_factor=16, scope=\"wmma.accumulator\"\n", " )\n", "\n", " with T.block(\"root\"):\n", " T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])\n", " T.writes(C[0:16, 0:16])\n", " T.evaluate(\n", " T.tvm_mma_sync(\n", " C.data,\n", " C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),\n", " A.data,\n", " A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16),\n", " B.data,\n", " B.elem_offset // 256 + T.floordiv(T.floormod(B.elem_offset, 256), 16),\n", " C.data,\n", " C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),\n", " dtype=\"handle\",\n", " )\n", " )\n", "\n", "\n", "@T.prim_func\n", "def wmma_fill_desc(c: T.handle) -> None:\n", " C = T.match_buffer(\n", " c, (16, 16), \"float32\", align=128, offset_factor=16, scope=\"wmma.accumulator\"\n", " )\n", "\n", " with T.block(\"root\"):\n", " T.reads()\n", " T.writes(C[0:16, 0:16])\n", " for i, j in T.grid(16, 16):\n", " with T.block(\"init\"):\n", " vii, vjj = T.axis.remap(\"SS\", [i, j])\n", " C[vii, vjj] = T.float32(0)\n", "\n", "\n", "@T.prim_func\n", "def wmma_fill_impl(c: T.handle) -> None:\n", " C = T.match_buffer(\n", " c, (16, 16), \"float32\", align=128, offset_factor=16, scope=\"wmma.accumulator\"\n", " )\n", " with T.block(\"root\"):\n", " T.reads()\n", " T.writes(C[0:16, 0:16])\n", " T.evaluate(\n", " T.tvm_fill_fragment(\n", " C.data,\n", " 16,\n", " 16,\n", " 16,\n", " C.elem_offset // 256 + T.floordiv(T.floormod(C.elem_offset, 256), 16),\n", " T.float32(0),\n", " dtype=\"handle\",\n", " )\n", " )\n", "\n", "\n", "@T.prim_func\n", "def wmma_store_desc(a: T.handle, c: T.handle) -> None:\n", " A = T.match_buffer(\n", " a, (16, 16), \"float32\", align=128, offset_factor=16, scope=\"wmma.accumulator\"\n", " )\n", " C = T.match_buffer(c, (16, 16), \"float32\", align=128, offset_factor=16, scope=\"global\")\n", " with T.block(\"root\"):\n", " T.reads(A[0:16, 0:16])\n", " T.writes(C[0:16, 0:16])\n", " for i, j in T.grid(16, 16):\n", " with T.block(\"store\"):\n", " vii, vjj = T.axis.remap(\"SS\", [i, j])\n", " C[vii, vjj] = A[vii, vjj]\n", "\n", "\n", "@T.prim_func\n", "def wmma_store_impl(a: T.handle, c: T.handle) -> None:\n", " s1 = T.var(\"int32\")\n", " s0 = T.var(\"int32\")\n", " A = T.match_buffer(\n", " a, (16, 16), \"float32\", align=128, offset_factor=16, scope=\"wmma.accumulator\"\n", " )\n", " C = T.match_buffer(\n", " c,\n", " (16, 16),\n", " \"float32\",\n", " align=128,\n", " offset_factor=16,\n", " scope=\"global\",\n", " strides=[s1, s0],\n", " )\n", " with T.block(\"root\"):\n", " T.reads(A[0:16, 0:16])\n", " T.writes(C[0:16, 0:16])\n", " T.evaluate(\n", " T.tvm_store_matrix_sync(\n", " A.data,\n", " 16,\n", " 16,\n", " 16,\n", " A.elem_offset // 256 + T.floordiv(T.floormod(A.elem_offset, 256), 16),\n", " C.access_ptr(\"w\"),\n", " s1,\n", " \"row_major\",\n", " dtype=\"handle\",\n", " )\n", " )\n", "\n", "\n", "try:\n", " # handle exception if we register multi times\n", " tir.TensorIntrin.register(\"wmma_load_a\", wmma_load_a_desc, wmma_load_a_impl)\n", " tir.TensorIntrin.register(\"wmma_load_b\", wmma_load_b_desc, wmma_load_b_impl)\n", " tir.TensorIntrin.register(\"wmma_sync\", wmma_sync_desc, wmma_sync_impl)\n", " tir.TensorIntrin.register(\"wmma_fill\", wmma_fill_desc, wmma_fill_impl)\n", " tir.TensorIntrin.register(\"wmma_store\", wmma_store_desc, wmma_store_impl)\n", "except ValueError:\n", " pass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Blockize 张量计算\n", "\n", "正如课程中所说,我们可以使用 TensorIR 来表示一组带有 `Block` 的张量化计算。 我们可以直接用 `Block` 编写一个 TensorIR 程序,也可以通过`blockize` 生成新的 `block`。 请记住,`wmma` 操作适用于 `16x16x16` 矩阵乘法,我们需要切分循环,而最里面的循环是 `16x16x16`。" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(X: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Y: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Z: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0, j_0, k_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m64\u001b[39m, \u001b[38;5;28m64\u001b[39m, \u001b[38;5;28m64\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o, vj_o, vk_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_0, j_0, k_0])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], Y[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i_init, vj_i_init \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init])\n", " Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1, k_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i, vj_i, vk_i \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1, k_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i], X[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], Y[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i])\n", " Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m=\u001b[39;00m Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(X[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(Y[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "sch = tir.Schedule(MatmulModule)\n", "block = sch.get_block(\"matmul\")\n", "i, j, k = sch.get_loops(block)\n", "\n", "i, ii = sch.split(i, factors=[None, 16])\n", "j, ji = sch.split(j, factors=[None, 16])\n", "k, ki = sch.split(k, factors=[None, 16])\n", "sch.reorder(i, j, k, ii, ji, ki)\n", "wmma_sync = sch.blockize(loop=ii)\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 切分循环并绑定 threadIdx\n", "\n", "### Warp 指令\n", "请注意,所有 Tensor Core 指令都是 warp 指令,这意味着一个 warp 中的所有 32 个线程应该同时执行此指令。 使 `threadIdx.x` extent=32 是解决此问题的最简单方法之一。 然后我们可以将`threadIdx.x`绑定到任何循环**除了**那些直接或间接包含Tensor Core内在函数的循环。 另请注意,这不是唯一的解决方案。 我们唯一应该做的就是确保一个 warp 中的所有线程都可以同时调用 Tensor Core。" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(X: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Y: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Z: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_0_j_0_0_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m64\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_0, k_0_1, i_0_2, j_0_2, k_0_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2)\n", " vk_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mreduce(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], Y[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i_init, vj_i_init \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init])\n", " Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1, k_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i, vj_i, vk_i \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1, k_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i], X[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], Y[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i])\n", " Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m=\u001b[39;00m Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(X[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(Y[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "i0, i1, i2 = sch.split(i, factors=[8, 4, 2])\n", "j0, j1, j2 = sch.split(j, factors=[8, 4, 2])\n", "k0, k1, k2 = sch.split(k, factors=[16, 2, 2])\n", "\n", "sch.reorder(i0, j0, i1, j1, k0, k1, i2, j2, k2)\n", "bx = sch.fuse(i0, j0)\n", "sch.bind(bx, \"blockIdx.x\")\n", "ty = sch.fuse(i1, j1)\n", "sch.bind(ty, \"threadIdx.y\")\n", "# We can't bind to `threadIdx.x` since we have warp-level operators under the loop\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 将`A`和`B`缓存到共享内存中\n", "\n", "与 Cuda Cores 的优化技巧类似,我们仍然需要将 `A` 和 `B` 缓存到共享内存中。 此外,还需要利用 cooperative fetching 技术。" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(X: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Y: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Z: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " X_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_0_j_0_0_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m64\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared[v0, v1])\n", " X_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared[v0, v1])\n", " Y_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_1, i_0_2, j_0_2, k_0_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2)\n", " vk_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mreduce(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], Y_shared[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i_init, vj_i_init \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init])\n", " Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1, k_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i, vj_i, vk_i \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1, k_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i], X_shared[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], Y_shared[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i])\n", " Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m=\u001b[39;00m Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(X_shared[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(Y_shared[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "X_shared = sch.cache_read(wmma_sync, read_buffer_index=0, storage_scope=\"shared\")\n", "Y_shared = sch.cache_read(wmma_sync, read_buffer_index=1, storage_scope=\"shared\")\n", "\n", "\n", "def schedule_shared(block):\n", " sch.compute_at(block, k0)\n", " x, y = sch.get_loops(block)[-2:]\n", " fused = sch.fuse(x, y)\n", " x0, x1, x2, x3 = sch.split(fused, factors=[None, 16, 32, 8])\n", " sch.bind(x1, \"threadIdx.y\")\n", " # here we must bind threadIdx.x == 32 to satisfy the requirements of warp-level operation.\n", " sch.bind(x2, \"threadIdx.x\") \n", " sch.vectorize(x3)\n", "\n", "\n", "schedule_shared(X_shared)\n", "schedule_shared(Y_shared)\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 将输入输出数据缓存到特殊内存层级\n", "\n", "Tensor Cores 不能直接使用共享内存或本地内存中的数据。 我们必须将数据缓存到 `wmma.matrix_a`、`wmma.matrix_b` 并更新 `wmma.accumulator` 中的计算。" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 缓存输入数据" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(X: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Y: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Z: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " X_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " X_shared_wmma_matrix_a \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared_wmma_matrix_b \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_0_j_0_0_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m64\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared[v0, v1])\n", " X_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared[v0, v1])\n", " Y_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0, ax1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m32\u001b[39m, \u001b[38;5;28m32\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared_wmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared_wmma_matrix_a[v0, v1])\n", " X_shared_wmma_matrix_a[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X_shared[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0, ax1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m32\u001b[39m, \u001b[38;5;28m32\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared_wmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y_shared[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared_wmma_matrix_b[v0, v1])\n", " Y_shared_wmma_matrix_b[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y_shared[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_2, j_0_2, k_0_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2)\n", " vk_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mreduce(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i_init, vj_i_init \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init])\n", " Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1, k_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i, vj_i, vk_i \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1, k_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i], X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i])\n", " Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m=\u001b[39;00m Z[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \n", "\n" ] } ], "source": [ "X_local = sch.cache_read(wmma_sync, 0, storage_scope=\"wmma.matrix_a\")\n", "Y_local = sch.cache_read(wmma_sync, 1, storage_scope=\"wmma.matrix_b\")\n", "sch.compute_at(X_local, k1)\n", "sch.compute_at(Y_local, k1)\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 缓存输出数据" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(X: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Y: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Z: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " X_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " X_shared_wmma_matrix_a \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared_wmma_matrix_b \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Z_wmma_accumulator \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_0_j_0_0_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m64\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared[v0, v1])\n", " X_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared[v0, v1])\n", " Y_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0, ax1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m32\u001b[39m, \u001b[38;5;28m32\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared_wmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared_wmma_matrix_a[v0, v1])\n", " X_shared_wmma_matrix_a[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X_shared[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0, ax1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m32\u001b[39m, \u001b[38;5;28m32\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared_wmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y_shared[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared_wmma_matrix_b[v0, v1])\n", " Y_shared_wmma_matrix_b[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y_shared[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_2, j_0_2, k_0_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2)\n", " vk_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mreduce(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i_init, vj_i_init \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init])\n", " Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1, k_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i, vj_i, vk_i \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1, k_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i], X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i])\n", " Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m=\u001b[39;00m Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0, ax1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m32\u001b[39m, \u001b[38;5;28m32\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mZ_wmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[v0, v1])\n", " Z[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Z_wmma_accumulator[v0, v1]\n", " \n", "\n" ] } ], "source": [ "write_back_block = sch.cache_write(wmma_sync, 0, storage_scope=\"wmma.accumulator\")\n", "sch.reverse_compute_at(write_back_block, ty)\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 切分 Tensor Core 内存拷贝\n", "\n", "`wmma.load_matrix` 和 `wmma.store_matrix` 使用 `16x16` 矩阵执行内存复制。 然后我们对循环进行切分,以此来匹配 intrinsic。" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(X: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Y: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Z: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " X_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " X_shared_wmma_matrix_a \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared_wmma_matrix_b \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Z_wmma_accumulator \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_0_j_0_0_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m64\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared[v0, v1])\n", " X_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared[v0, v1])\n", " Y_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0, ax0_1, ax1_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared_wmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_1)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared_wmma_matrix_a[v0, v1])\n", " X_shared_wmma_matrix_a[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X_shared[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0, ax0_1, ax1_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared_wmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_1)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y_shared[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared_wmma_matrix_b[v0, v1])\n", " Y_shared_wmma_matrix_b[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y_shared[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_2, j_0_2, k_0_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2)\n", " vk_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mreduce(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00minit():\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i_init, vj_i_init \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init])\n", " Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1, k_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i, vj_i, vk_i \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1, k_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i], X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i])\n", " Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m=\u001b[39;00m Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0, ax0_1, ax1_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mZ_wmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_1)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[v0, v1])\n", " Z[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Z_wmma_accumulator[v0, v1]\n", " \n", "\n" ] } ], "source": [ "def schedule_copy(block):\n", " x, y = sch.get_loops(block)[-2:]\n", " x0, x1 = sch.split(x, factors=[None, 16])\n", " y0, y1 = sch.split(y, factors=[None, 16])\n", " sch.reorder(x0, y0, x1, y1)\n", "\n", "schedule_copy(X_local)\n", "schedule_copy(Y_local)\n", "schedule_copy(write_back_block)\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tensorize\n", "\n", "tensorize 之前,我们需要先执行 `decompose_reduction`,因为 `wmma_sync` 和 `wmma_fill` 是两个 intrinsic,需要对 init block 和 update block 进行两次 tensorize\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(X: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Y: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Z: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " X_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " X_shared_wmma_matrix_a \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared_wmma_matrix_b \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Z_wmma_accumulator \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_0_j_0_0_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m64\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_2_init, j_0_2_init \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2_init)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2_init)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i_init, vj_i_init \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSS\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init])\n", " Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i_init, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i_init] \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared[v0, v1])\n", " X_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared[v0, v1])\n", " Y_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0, ax0_1, ax1_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared_wmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_1)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared_wmma_matrix_a[v0, v1])\n", " X_shared_wmma_matrix_a[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X_shared[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0, ax0_1, ax1_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared_wmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_1)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y_shared[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared_wmma_matrix_b[v0, v1])\n", " Y_shared_wmma_matrix_b[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y_shared[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_2, j_0_2, k_0_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o_update\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2)\n", " vk_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mreduce(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_1, j_1, k_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_i, vj_i, vk_i \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mremap(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSSR\u001b[39m\u001b[38;5;124m\"\u001b[39m, [i_1, j_1, k_1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i], X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i])\n", " Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m=\u001b[39;00m Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i] \u001b[38;5;129;01m+\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vi_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01m*\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mcast(Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vj_i, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m vk_i], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0, ax0_1, ax1_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mZ_wmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_1)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m32\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_1)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[v0, v1])\n", " Z[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Z_wmma_accumulator[v0, v1]\n", " \n", "\n" ] } ], "source": [ "init = sch.decompose_reduction(wmma_sync, k0)\n", "sch.mod.show()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[38;5;129m@tvm\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mscript\u001b[38;5;129;01m.\u001b[39;00mir_module\n", "\u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mModule\u001b[39;00m:\n", " \u001b[38;5;129m@T\u001b[39m\u001b[38;5;129;01m.\u001b[39;00mprim_func\n", " \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmain\u001b[39m(X: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Y: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m], Z: T\u001b[38;5;129;01m.\u001b[39;00mBuffer[(\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;129;01m-\u001b[39;00m\u001b[38;5;129;01m>\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", " \u001b[38;5;30;03m# function attr dict\u001b[39;00m\n", " T\u001b[38;5;129;01m.\u001b[39;00mfunc_attr({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mglobal_symbol\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmain\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtir.noalias\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mTrue\u001b[39;00m})\n", " s0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " s0_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " s0_2 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " s1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " s1_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " s1_2 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvar(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;30;03m# body\u001b[39;00m\n", " \u001b[38;5;30;03m# with T.block(\"root\")\u001b[39;00m\n", " X_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " X_shared_wmma_matrix_a \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Y_shared_wmma_matrix_b \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " Z_wmma_accumulator \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00malloc_buffer([\u001b[38;5;28m1024\u001b[39m, \u001b[38;5;28m1024\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_0_j_0_0_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m64\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblockIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_2_init, j_0_2_init \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o_init\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2_init)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2_init)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads()\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " C \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mevaluate(T\u001b[38;5;129;01m.\u001b[39;00mtvm_fill_fragment(C\u001b[38;5;129;01m.\u001b[39;00mdata, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, C\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m C\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m16\u001b[39m, T\u001b[38;5;129;01m.\u001b[39;00mfloat32(\u001b[38;5;28m0\u001b[39m), dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhandle\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m16\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared[v0, v1])\n", " X_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m X[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m16\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.y\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mthread_binding(\u001b[38;5;28m32\u001b[39m, thread\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mthreadIdx.x\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_ax1_fused_3 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mvectorized(\u001b[38;5;28m8\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m128\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " v1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m1024\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m64\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m (ax0_ax1_fused_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4096\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_ax1_fused_3) \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m64\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y[v0, v1])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared[v0, v1])\n", " Y_shared[v0, v1] \u001b[38;5;129;01m=\u001b[39;00m Y[v0, v1]\n", " \u001b[38;5;28;01mfor\u001b[39;00m k_0_1 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mserial(\u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_shared_wmma.matrix_a_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0)\n", " v1_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(X_shared[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(X_shared_wmma_matrix_a[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " A \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(X_shared[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, strides\u001b[38;5;129;01m=\u001b[39;00m[s1, s0], scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " C_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(X_shared_wmma_matrix_a[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mevaluate(T\u001b[38;5;129;01m.\u001b[39;00mtvm_load_matrix_sync(C_1\u001b[38;5;129;01m.\u001b[39;00mdata, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, C_1\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m C_1\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m16\u001b[39m, T\u001b[38;5;129;01m.\u001b[39;00mtvm_access_ptr(T\u001b[38;5;129;01m.\u001b[39;00mtype_annotation(dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m), A\u001b[38;5;129;01m.\u001b[39;00mdata, A\u001b[38;5;129;01m.\u001b[39;00melem_offset, s1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m1\u001b[39m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhandle\u001b[39m\u001b[38;5;124m\"\u001b[39m), s1, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrow_major\u001b[39m\u001b[38;5;124m\"\u001b[39m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhandle\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mY_shared_wmma.matrix_b_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0)\n", " v1_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Y_shared[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Y_shared_wmma_matrix_b[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " A_1 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(Y_shared[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, strides\u001b[38;5;129;01m=\u001b[39;00m[s1_1, s0_1], scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mshared\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " C_2 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(Y_shared_wmma_matrix_b[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mevaluate(T\u001b[38;5;129;01m.\u001b[39;00mtvm_load_matrix_sync(C_2\u001b[38;5;129;01m.\u001b[39;00mdata, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, C_2\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m C_2\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m16\u001b[39m, T\u001b[38;5;129;01m.\u001b[39;00mtvm_access_ptr(T\u001b[38;5;129;01m.\u001b[39;00mtype_annotation(dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m), A_1\u001b[38;5;129;01m.\u001b[39;00mdata, A_1\u001b[38;5;129;01m.\u001b[39;00melem_offset, s1_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m1\u001b[39m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhandle\u001b[39m\u001b[38;5;124m\"\u001b[39m), s1_1, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcol_major\u001b[39m\u001b[38;5;124m\"\u001b[39m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhandle\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n", " \u001b[38;5;28;01mfor\u001b[39;00m i_0_2, j_0_2, k_0_2 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmatmul_o_update\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " vi_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_2)\n", " vj_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m j_0_2)\n", " vk_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mreduce(\u001b[38;5;28m64\u001b[39m, k_0_0 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_1 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m k_0_2)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " A_2 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(X_shared_wmma_matrix_a[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_a\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " B \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(Y_shared_wmma_matrix_b[vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vk_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat16\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.matrix_b\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " C_3 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(Z_wmma_accumulator[vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vi_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : vj_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mevaluate(T\u001b[38;5;129;01m.\u001b[39;00mtvm_mma_sync(C_3\u001b[38;5;129;01m.\u001b[39;00mdata, C_3\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m C_3\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m16\u001b[39m, A_2\u001b[38;5;129;01m.\u001b[39;00mdata, A_2\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m A_2\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m16\u001b[39m, B\u001b[38;5;129;01m.\u001b[39;00mdata, B\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m B\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m16\u001b[39m, C_3\u001b[38;5;129;01m.\u001b[39;00mdata, C_3\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m C_3\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m16\u001b[39m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhandle\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n", " \u001b[38;5;28;01mfor\u001b[39;00m ax0_0, ax1_0 \u001b[38;5;28;01min\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mgrid(\u001b[38;5;28m2\u001b[39m, \u001b[38;5;28m2\u001b[39m):\n", " \u001b[38;5;28;01mwith\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mblock(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mZ_wmma.accumulator_o\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", " v0_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax0_0)\n", " v1_o \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00maxis\u001b[38;5;129;01m.\u001b[39;00mspatial(\u001b[38;5;28m64\u001b[39m, i_0_0_j_0_0_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m8\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m i_0_1_j_0_1_fused \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m4\u001b[39m \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m2\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m ax1_0)\n", " T\u001b[38;5;129;01m.\u001b[39;00mreads(Z_wmma_accumulator[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " T\u001b[38;5;129;01m.\u001b[39;00mwrites(Z[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m])\n", " A_3 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(Z_wmma_accumulator[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, scope\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwmma.accumulator\u001b[39m\u001b[38;5;124m\"\u001b[39m, offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " C_4 \u001b[38;5;129;01m=\u001b[39;00m T\u001b[38;5;129;01m.\u001b[39;00mmatch_buffer(Z[v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v0_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m, v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m : v1_o \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m \u001b[38;5;28m16\u001b[39m], [\u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m], dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m, strides\u001b[38;5;129;01m=\u001b[39;00m[s1_2, s0_2], offset_factor\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;28m16\u001b[39m)\n", " T\u001b[38;5;129;01m.\u001b[39;00mevaluate(T\u001b[38;5;129;01m.\u001b[39;00mtvm_store_matrix_sync(A_3\u001b[38;5;129;01m.\u001b[39;00mdata, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m16\u001b[39m, A_3\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m+\u001b[39;00m A_3\u001b[38;5;129;01m.\u001b[39;00melem_offset \u001b[38;5;129;01m%\u001b[39;00m \u001b[38;5;28m256\u001b[39m \u001b[38;5;129;01m/\u001b[39;00m\u001b[38;5;129;01m/\u001b[39;00m \u001b[38;5;28m16\u001b[39m, T\u001b[38;5;129;01m.\u001b[39;00mtvm_access_ptr(T\u001b[38;5;129;01m.\u001b[39;00mtype_annotation(dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfloat32\u001b[39m\u001b[38;5;124m\"\u001b[39m), C_4\u001b[38;5;129;01m.\u001b[39;00mdata, C_4\u001b[38;5;129;01m.\u001b[39;00melem_offset, s1_2 \u001b[38;5;129;01m*\u001b[39;00m \u001b[38;5;28m16\u001b[39m, \u001b[38;5;28m2\u001b[39m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhandle\u001b[39m\u001b[38;5;124m\"\u001b[39m), s1_2, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrow_major\u001b[39m\u001b[38;5;124m\"\u001b[39m, dtype\u001b[38;5;129;01m=\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhandle\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n", " \n", "\n" ] } ], "source": [ "sch.tensorize(sch.get_loops(X_local)[-2], \"wmma_load_a\")\n", "sch.tensorize(sch.get_loops(Y_local)[-2], \"wmma_load_b\")\n", "sch.tensorize(init, \"wmma_fill\")\n", "sch.tensorize(wmma_sync, \"wmma_sync\")\n", "sch.tensorize(sch.get_loops(write_back_block)[-2], \"wmma_store\")\n", "sch.mod.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 构建并评估结果" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Performance: 9443.610543 GFLOPS\n" ] } ], "source": [ "rt_mod = tvm.build(sch.mod, target=\"cuda\")\n", "\n", "dev = tvm.cuda()\n", "num_flop = 1024**3 * 2\n", "A_np = np.random.randn(1024, 1024).astype(\"float16\")\n", "B_np = np.random.randn(1024, 1024).astype(\"float16\")\n", "C_np = A_np.astype(\"float32\") @ (B_np.astype(\"float32\").T)\n", "\n", "A_nd = tvm.nd.array(A_np, dev)\n", "B_nd = tvm.nd.array(B_np, dev)\n", "C_nd = tvm.nd.array(np.empty((1024, 1024), dtype=\"float32\"), dev)\n", "\n", "rt_mod(A_nd, B_nd, C_nd)\n", "np.testing.assert_allclose(C_np, C_nd.numpy(), rtol=1e-3, atol=1e-3)\n", "\n", "evaluator = rt_mod.time_evaluator(\"main\", dev, number=10)\n", "print(\"Performance: %f GFLOPS\" % (num_flop / evaluator(A_nd, B_nd, C_nd).mean / 1e9))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 讨论\n", "\n", "请考虑如何使这个程序运行得更快?(极限性能在 50T 左右,而这个程序在 RTX-3080 上只有 23T)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 参考文献\n", "\n", "https://developer.nvidia.com/blog/programming-tensor-cores-cuda-9/\n", "\n", "https://tvm.apache.org/docs/how_to/optimize_operators/opt_conv_tensorcore.html" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.13 ('py38': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "28558e8daad512806f5c536a1a04c119185f99f65b79002708a12162d02a79c7" } } }, "nbformat": 4, "nbformat_minor": 2 }