def after_matmul_vectorize(
    placeholder: T.Buffer[(64, 768), "float32"],
    placeholder_1: T.Buffer[(768, 768), "float32"],
    T_matmul_NT: T.Buffer[(64, 768), "float32"],
) -> None:
    # 64x768x768 matmul whose second operand is read transposed
    # (placeholder_1[j, k]). Results accumulate into the intermediate buffer
    # T_matmul_NT_global with the innermost 16-wide j loop vectorized; each
    # 64x16 column slice is then copied to T_matmul_NT by a second 16-wide
    # vectorized loop.
    T.func_attr({
        "global_symbol": "main",
        "tir.noalias": True,
        "layout_free_placeholders": [1]
    })
    T_matmul_NT_global = T.alloc_buffer([64, 768], dtype="float32")
    for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 16, 1, 3):
        # Each (i1_0, i1_1) pair selects columns i1_0 * 48 + i1_1 * 16 + [0, 16).
        for i2_0, i0_2, i1_2, i2_1, i0_3 in T.grid(48, 8, 1, 16, 8):
            for i1_3_fused in T.vectorized(16):
                with T.block("T_matmul_NT"):
                    i = T.axis.spatial(64, i0_2 * 8 + i0_3)
                    j = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + i1_3_fused)
                    k = T.axis.reduce(768, i2_0 * 16 + i2_1)
                    T.reads(placeholder[i, k], placeholder_1[j, k])
                    T.writes(T_matmul_NT_global[i, j])
                    with T.init():
                        T_matmul_NT_global[i, j] = T.float32(0)
                    T_matmul_NT_global[i, j] = T_matmul_NT_global[
                        i, j] + placeholder[i, k] * placeholder_1[j, k]
        # Write back the 64x16 slice produced by the loop nest above.
        for ax0 in T.serial(64):
            for ax1_fused in T.vectorized(16):
                with T.block("T_matmul_NT_global"):
                    v0 = T.axis.spatial(64, ax0)
                    v1 = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + ax1_fused)
                    T.reads(T_matmul_NT_global[v0, v1])
                    T.writes(T_matmul_NT[v0, v1])
                    T_matmul_NT[v0, v1] = T_matmul_NT_global[v0, v1]
def gemm_mma_m8n8k4_row_row_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle):
    # One block of 32 threads computes C (16x16, fp32) from A (16x4, fp16) and
    # B (4x16, fp16) with a single ptx_mma "m8n8k4" call using row/row operand
    # layouts. Each thread loads 4 elements of A and 4 of B into local
    # fragments, accumulates into 8 fp32 values, then scatters them into C.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 4], dtype="float16")
    B = T.match_buffer(b, [4, 16], dtype="float16")
    C = T.match_buffer(c, [16, 16], dtype="float32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([4], "float16", scope="local")
    MultiB = T.allocate([4], "float16", scope="local")
    Accum = T.allocate([8], "float32", scope="local")
    # Zero the per-thread accumulator fragment before the MMA.
    for i in range(8):
        Accum[i] = T.float32(0)

    # A fragment: one row of A (chosen from tx), all 4 k-columns.
    for mma_multi_a_col in T.vectorized(4):
        MultiA[mma_multi_a_col] = A[
            ((tx % 32) % 4) + (4 * ((((tx % 32) // 16 + (tx % 32) % 16 // 4 * 2)) % 4)),
            mma_multi_a_col,
        ]
    # B fragment: k-row (tx % 4), 4 consecutive columns starting at 4 * (tx // 8).
    for mma_multi_b_col in T.vectorized(4):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) % 4,
            mma_multi_b_col + (4 * ((tx % 32) // 8)),
        ]
    # The trailing False is forwarded to ptx_mma unchanged
    # (presumably the saturate flag — confirm against the ptx_mma signature).
    T.evaluate(
        T.ptx_mma(
            "m8n8k4",
            "row",
            "row",
            "fp16",
            "fp16",
            "fp32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="float32",
        )
    )
    # Scatter this thread's 8 accumulator values into C.
    for mma_accum_c_id in range(8):
        C[
            ((tx % 32) % 2)
            + ((mma_accum_c_id // 2 % 2) * 2)
            + 4 * ((tx % 32) // 16)
            + ((tx % 32) % 16 // 4) % 2 * 8,
            (tx % 32) % 4 // 2 * 2
            + (tx % 32) % 16 // 8 * 4
            + mma_accum_c_id % 2
            + mma_accum_c_id // 4 * 8,
        ] = T.load("float32", Accum, mma_accum_c_id)
def colsum_decompose_with_vectorization(a: T.handle, b: T.handle) -> None:
    # Column sum B[i] = sum over k of A[k, i] for a 128x32 input. The reduction
    # init is a separate "B_init" block; both the init and the update use a
    # vectorized loop over the 32 output columns.
    A = T.match_buffer(a, [128, 32], dtype="float32")
    B = T.match_buffer(b, [32], dtype="float32")
    for i in T.vectorized(0, 32):
        with T.block("B_init"):
            vi = T.axis.S(32, i)
            B[vi] = T.float32(0)
    for k in T.serial(0, 128):
        for i in T.vectorized(0, 32):
            with T.block("B"):
                vk, vi = T.axis.remap("RS", [k, i])
                B[vi] = B[vi] + A[vk, vi]
def ptx_global_to_shared_dyn_copy_fp16x8(
    A: T.Buffer[(32, 128), "float16"],
    B: T.Buffer[(32, 128), "float16"],
    C: T.Buffer[(32, 128), "float16"],
) -> None:
    # Computes C = A + B (32x128 fp16). A and B are first staged into
    # "shared.dyn"-scoped buffers under an "async_scope" attribute; the staged
    # data is read only after ptx_commit_group / ptx_wait_group(0).
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
        B_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
        T.reads(A[0:32, 0:128], B[0:32, 0:128])
        T.writes(C[0:32, 0:128])

        T.attr("default", "async_scope", 1)
        # Thread tx copies row tx of A and B in 16 chunks of 8 elements.
        for i in T.serial(16):
            for j in T.vectorized(8):
                A_shared[tx, i * 8 + j] = A[tx, i * 8 + j]
                B_shared[tx, i * 8 + j] = B[tx, i * 8 + j]

        T.evaluate(T.ptx_commit_group(dtype=""))
        T.evaluate(T.ptx_wait_group(0, dtype=""))

        for i in range(128):
            C[tx, i] = A_shared[tx, i] + B_shared[tx, i]
def decomposed_gemm_parallelize_init(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
) -> None:
    # C[i, j] = sum over k of A[i, k] * B[j, k] (16x16), tiled into 4x4 tiles.
    # The reduction init is a separate "init" block whose inner jj loop is
    # vectorized; the result goes through `local` before being copied to C.
    # NOTE(review): the name says "parallelize", but the only loop annotation
    # visible on the init nest is T.vectorized — confirm this is intended.
    local = T.alloc_buffer([16, 16], dtype="float32")
    for i, j in T.grid(4, 4):
        for ii in T.serial(4):
            for jj in T.vectorized(4):
                with T.block("init"):
                    vi = T.axis.spatial(16, i * 4 + ii)
                    vj = T.axis.spatial(16, j * 4 + jj)
                    T.reads()
                    T.writes(local[vi, vj])
                    local[vi, vj] = 0
        for k, ii, jj in T.grid(16, 4, 4):
            with T.block("update"):
                vi = T.axis.spatial(16, i * 4 + ii)
                vj = T.axis.spatial(16, j * 4 + jj)
                vk = T.axis.reduce(16, k)
                T.reads(local[vi, vj], A[vi, vk], B[vj, vk])
                T.writes(local[vi, vj])
                local[vi, vj] = local[vi, vj] + A[vi, vk] * B[vj, vk]
        for ii, jj in T.grid(4, 4):
            with T.block("C"):
                vi = T.axis.spatial(16, i * 4 + ii)
                vj = T.axis.spatial(16, j * 4 + jj)
                T.reads(local[vi, vj])
                T.writes(C[vi, vj])
                C[vi, vj] = local[vi, vj]
def gemm_mma_m16n8k8_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle):
    # One block of 32 threads computes C (16x8, fp32) from A (16x8, fp16) and
    # B (8x8, fp16, indexed B[n, k]) with a single ptx_mma "m16n8k8" call using
    # row/col operand layouts.
    #
    # Per-thread fragment sizes follow from the shapes divided over 32 threads:
    # A holds 16*8/32 = 4 elements, B holds 8*8/32 = 2 elements, and the
    # accumulator holds 16*8/32 = 4 fp32 values.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 8], dtype="float16")
    B = T.match_buffer(b, [8, 8], dtype="float16")
    C = T.match_buffer(c, [16, 8], dtype="float32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([4], "float16", scope="local")
    MultiB = T.allocate([2], "float16", scope="local")
    Accum = T.allocate([4], "float32", scope="local")
    # Zero the per-thread accumulator fragment before the MMA.
    for i in range(4):
        Accum[i] = T.float32(0)

    # A fragment: rows tx // 4 and tx // 4 + 8, two consecutive k-columns each.
    for mma_multi_a_col in T.vectorized(4):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col // 2 * 8, (tx % 32) % 4 * 2 + mma_multi_a_col % 2
        ]
    # B fragment: exactly 2 elements per thread, matching MultiB's extent.
    # Loading 4 lanes here would write past the end of MultiB and index B rows
    # 8..15, which do not exist in the 8x8 buffer.
    for mma_multi_b_col in T.vectorized(2):
        MultiB[mma_multi_b_col] = B[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_multi_b_col]
    # The trailing False is forwarded to ptx_mma unchanged
    # (presumably the saturate flag — confirm against the ptx_mma signature).
    T.evaluate(
        T.ptx_mma(
            "m16n8k8",
            "row",
            "col",
            "fp16",
            "fp16",
            "fp32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="float32",
        )
    )

    # Accumulator element c maps to row tx // 4 (+8 for c >= 2) and column
    # 2 * (tx % 4) + (c % 2).
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = T.load("float32", Accum, mma_accum_c_id)
def loop_syntax_sugar(a: T.handle) -> None:
    # Nests one loop of each shorthand kind: serial, parallel, vectorized,
    # unroll, and thread_binding with the thread tag given positionally and
    # by keyword. The body reads and writes A[i, j, k, x] and does not
    # reference y or z.
    # NOTE(review): both y and z bind "threadIdx.x"; presumably this is
    # parser/printer round-trip input rather than a lowerable kernel — confirm.
    A = T.match_buffer(a, (128, 128, 128, 128))
    for i in T.serial(128):
        for j in T.parallel(128):
            for k in T.vectorized(128):
                for x in T.unroll(128):
                    for y in T.thread_binding(128, "threadIdx.x"):
                        for z in T.thread_binding(128, thread="threadIdx.x"):
                            A[i, j, k, x] = A[i, j, k, x] * 2.0
def vector_func(a: T.handle, b: T.handle):
    # In-place elementwise add A += B over (n, 128) buffers, where n is a
    # symbolic int32 extent and the inner 128-wide loop is vectorized.
    n = T.var("int32")
    m = 128
    A = T.match_buffer(a, (n, m))
    B = T.match_buffer(b, (n, m))

    for i in T.serial(n):
        for j in T.vectorized(m):
            A[i, j] = A[i, j] + B[i, j]
def GmmCuda2(
    X: T.Buffer[(1, 128, 128), "float32"],
    Y: T.Buffer[(1, 128, 128), "float32"],
    Z: T.Buffer[(1, 128, 128), "float32"],
) -> None:
    # Batched matmul Z[b, i, j] = sum over k of X[b, i, k] * Y[b, k, j]
    # (batch 1, 128x128x128) on 16 blocks x 128 threads. Each thread owns a
    # 4x2 tile of Z in Z_local. The reduction runs in 4 steps of 32, and each
    # step stages a 32x32 tile of X and of Y into shared-scope buffers.
    # NOTE(review): "Z_update" declares thread_extent_{low,high}_inclusive =
    # 1024 while the visible threadIdx.x extent is 128; presumably this fixture
    # is meant to fail a thread-extent check — confirm.
    Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
    X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"):
            for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"):
                for i1_3_init, i2_4_init in T.grid(4, 2):
                    with T.block("Z_init"):
                        b = T.axis.spatial(1, 0)
                        i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init)
                        j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init)
                        T.reads()
                        T.writes(Z_local[b, i, j])
                        Z_local[b, i, j] = T.float32(0)
                for i3_0 in T.serial(4):
                    # Stage a 32x32 tile of X: 4 x 128 threads x 2 lanes = 1024 elements.
                    for ax0_ax1_ax2_fused_0 in T.serial(4):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            for ax0_ax1_ax2_fused_2 in T.vectorized(2):
                                with T.block("X_shared"):
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32)
                                    v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32)
                                    T.reads(X[v0, v1, v2])
                                    T.writes(X_shared[v0, v1, v2])
                                    X_shared[v0, v1, v2] = X[v0, v1, v2]
                    # Stage a 32x32 tile of Y: 8 x 128 threads = 1024 elements.
                    for ax0_ax1_ax2_fused_0 in T.serial(8):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            with T.block("Y_shared"):
                                v0 = T.axis.spatial(1, 0)
                                v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32)
                                v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32)
                                T.reads(Y[v0, v1, v2])
                                T.writes(Y_shared[v0, v1, v2])
                                Y_shared[v0, v1, v2] = Y[v0, v1, v2]
                    for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2):
                        with T.block("Z_update"):
                            b = T.axis.spatial(1, 0)
                            i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3)
                            j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4)
                            k = T.axis.reduce(128, i3_0 * 32 + i3_2)
                            T.block_attr({
                                "meta_schedule.thread_extent_low_inclusive": 1024,
                                "meta_schedule.thread_extent_high_inclusive": 1024,
                            })
                            T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j])
                            T.writes(Z_local[b, i, j])
                            Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
                # Write this thread's 4x2 tile back to Z after the full reduction.
                for ax0, ax1, ax2 in T.grid(1, 4, 2):
                    with T.block("Z_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1)
                        v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2)
                        T.reads(Z_local[v0, v1, v2])
                        T.writes(Z[v0, v1, v2])
                        Z[v0, v1, v2] = Z_local[v0, v1, v2]
def element_wise_split_predicate_vectorized(a: T.handle, b: T.handle) -> None:
    # B = A * 2 over 128x128. The i loop is vectorized, and j is split into
    # 13 x 10 = 130 iterations; T.where masks the two with j_0 * 10 + j_1 >= 128.
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for i in T.vectorized(0, 128):
        for j_0, j_1 in T.grid(13, 10):
            with T.block("B"):
                T.where(j_0 * 10 + j_1 < 128)
                vi = T.axis.S(128, i)
                vj = T.axis.S(128, j_0 * 10 + j_1)
                B[vi, vj] = A[vi, vj] * 2.0
def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle):
    # One block of 32 threads computes C (8x8, int32) from A (8x32, int4) and
    # B (8x32, uint4, indexed B[n, k]) with a single ptx_mma "m8n8k32" call
    # using row/col operand layouts. Per thread: 8 A elements, 8 B elements and
    # 2 int32 accumulators.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [8, 32], dtype="int4")
    B = T.match_buffer(b, [8, 32], dtype="uint4")
    C = T.match_buffer(c, [8, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([8], "int4", scope="local")
    MultiB = T.allocate([8], "uint4", scope="local")
    Accum = T.allocate([2], "int32", scope="local")
    # Zero the per-thread accumulator fragment before the MMA.
    for i in range(2):
        Accum[i] = T.int32(0)

    # A fragment: row tx // 4, 8 consecutive k-columns starting at 8 * (tx % 4).
    for mma_multi_a_col in T.vectorized(8):
        MultiA[mma_multi_a_col] = A[(tx % 32) // 4, mma_multi_a_col + (tx % 32) % 4 * 8]
    # B fragment: same indexing pattern on B[n, k].
    for mma_multi_b_col in T.vectorized(8):
        MultiB[mma_multi_b_col] = B[(tx % 32) // 4, mma_multi_b_col + (tx % 32) % 4 * 8]
    T.evaluate(
        T.ptx_mma(
            "m8n8k32",
            "row",
            "col",
            "int4",
            "uint4",
            "int32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="int32",
        )
    )

    # Accumulator element c maps to row tx // 4, column 2 * (tx % 4) + c.
    for mma_accum_c_id in range(2):
        C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load(
            "int32", Accum, mma_accum_c_id
        )
def element_wise_split_predicate_vectorized(a: T.handle, b: T.handle) -> None:
    # B = A * 2 over 128x128, with i vectorized and j split into 13 x 10 = 130
    # iterations masked by T.where. Written in the older
    # `T.block([extents], name) as [vars]` + `T.bind` block syntax.
    # NOTE(review): another function with this exact name appears earlier in
    # this chunk; if both share a module, this later definition shadows it.
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for i in T.vectorized(0, 128):
        for j_0, j_1 in T.grid(13, 10):
            with T.block([128, 128], "B") as [vi, vj]:
                T.where(j_0 * 10 + j_1 < 128)
                T.bind(vi, i)
                T.bind(vj, j_0 * 10 + j_1)
                B[vi, vj] = A[vi, vj] * 2.0
def gemm_mma_m16n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle):
    # One block of 32 threads computes C (16x8, int32) from A (16x16, int8) and
    # B (8x16, uint8, indexed B[n, k]) with a single ptx_mma "m16n8k16" call
    # using row/col operand layouts. Per thread: 8 A elements, 4 B elements and
    # 4 int32 accumulators. Unlike the other MMA fixtures, this one passes the
    # fragments as `.data` and reads Accum by direct indexing.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 16], dtype="int8")
    B = T.match_buffer(b, [8, 16], dtype="uint8")
    C = T.match_buffer(c, [16, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([8], "int8", scope="local")
    MultiB = T.allocate([4], "uint8", scope="local")
    Accum = T.allocate([4], "int32", scope="local")
    # Zero the per-thread accumulator fragment before the MMA.
    for i in range(4):
        Accum[i] = T.int32(0)

    # A fragment (serial loop): rows tx // 4 and tx // 4 + 8, 4 consecutive
    # k-columns starting at 4 * (tx % 4).
    for mma_multi_a_col in range(8):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col // 4 * 8,
            (tx % 32) % 4 * 4 + mma_multi_a_col % 4,
        ]
    # B fragment: row tx // 4 of B[n, k], 4 consecutive k-columns.
    for mma_multi_b_col in T.vectorized(4):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 4 + mma_multi_b_col,
        ]
    T.evaluate(
        T.ptx_mma(
            "m16n8k16",
            "row",
            "col",
            "int8",
            "uint8",
            "int32",
            MultiA.data,
            0,
            MultiB.data,
            0,
            Accum.data,
            0,
            False,
            dtype="int32",
        )
    )

    # Accumulator element c maps to row tx // 4 (+8 for c >= 2) and column
    # 2 * (tx % 4) + (c % 2).
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = Accum[mma_accum_c_id]
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None:
    # 512x512x512 matmul C = A @ B on 16 blocks x 16 vthreads x 8 threads.
    # Each thread accumulates a 32x4 tile in C_local.
    # NOTE(review): the A_shared copy covers all 512 * 512 elements of A
    # (32768 * 8 iterations), whereas B_shared covers only a 512x32 column strip
    # per block; presumably intentional for this fixture — confirm.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(var_A, [512, 512], dtype="float32")
    B = T.match_buffer(var_B, [512, 512], dtype="float32")
    C = T.match_buffer(var_C, [512, 512], dtype="float32")
    # body
    # with T.block("root")
    C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
    A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
        for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
            for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"):
                for i2_0 in T.serial(0, 1):
                    for ax0_ax1_fused_0 in T.serial(0, 32768):
                        for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"):
                            with T.block("A_shared"):
                                v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512)
                                v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512)
                                T.reads([A[v0, v1]])
                                T.writes([A_shared[v0, v1]])
                                A_shared[v0, v1] = A[v0, v1]
                    for ax0_ax1_fused_0 in T.serial(0, 1024):
                        for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for ax0_ax1_fused_2 in T.vectorized(0, 2):
                                with T.block("B_shared"):
                                    v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32)
                                    v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32)
                                    T.reads([B[v0, v1]])
                                    T.writes([B_shared[v0, v1]])
                                    B_shared[v0, v1] = B[v0, v1]
                    for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2):
                        with T.block("C"):
                            i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
                            j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4)
                            k = T.axis.reduce(512, i2_1 * 32 + i2_2)
                            T.reads([A_shared[i, k], B_shared[k, j]])
                            T.writes([C_local[i, j]])
                            with T.init():
                                C_local[i, j] = T.float32(0)
                            C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j]
                # Write this thread's 32x4 tile back to C.
                for ax0, ax1 in T.grid(32, 4):
                    with T.block("C_local"):
                        v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0)
                        v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1)
                        T.reads([C_local[v0, v1]])
                        T.writes([C[v0, v1]])
                        C[v0, v1] = C_local[v0, v1]
def func(
    A: T.Buffer[(960, 770), "float32"],
    B: T.Buffer[(770, 2304), "float32"],
    C: T.Buffer[(960, 2304), "float32"],
) -> None:
    # C += A @ B for A (960x770) and B (770x2304). 144 blocks tile C into
    # 128x128 tiles (row tile bx // 18, column tile bx % 18). The reduction runs
    # in 193 steps of 4 (193 * 4 = 772 >= 770). Each step stages a 128x4 tile of
    # A and a 4x128 tile of B, with the copies guarded by T.where. The shared
    # buffers are declared at the full A / B shapes.
    # NOTE(review): the "update_update" block has no T.where, so its indices
    # reach C row 1023 (C has 960 rows) and k = 771 in A_shared / B_shared
    # (770 wide); presumably tolerated because this IR is compared, not
    # executed — confirm.
    for bx in T.thread_binding(144, thread="blockIdx.x"):
        for vx in T.thread_binding(2, thread="vthread.x"):
            for tx_p in T.thread_binding(256, thread="threadIdx.x"):
                with T.block():
                    for k_0 in T.serial(193):
                        with T.block():
                            A_shared = T.alloc_buffer([960, 770], dtype="float32", scope="shared")
                            B_shared = T.alloc_buffer([770, 2304], dtype="float32", scope="shared")
                            # 256 threads x 3 lanes, masked down to the 512 = 128x4 elements of the A tile.
                            for _u in T.serial(1):
                                for tx in T.thread_binding(256, thread="threadIdx.x"):
                                    for vec in T.vectorized(3):
                                        with T.block("A_shared"):
                                            T.where(bx // 18 * 128 + ((_u * 256 + tx) * 3 + vec) // 4 < 960 and k_0 * 4 + ((_u * 256 + tx) * 3 + vec) % 4 < 770 and (_u * 256 + tx) * 3 + vec < 512)
                                            A_shared[bx // 18 * 128 + (_u * 768 + tx * 3 + vec) // 4, k_0 * 4 + (_u * 768 + tx * 3 + vec) % 4] = A[bx // 18 * 128 + (_u * 768 + tx * 3 + vec) // 4, k_0 * 4 + (_u * 768 + tx * 3 + vec) % 4]
                            # 256 threads x 4 lanes, masked down to the 512 = 4x128 elements of the B tile.
                            for _u in T.serial(1):
                                for tx in T.thread_binding(256, thread="threadIdx.x"):
                                    for vec in T.vectorized(4):
                                        with T.block("B_shared"):
                                            T.where(k_0 * 4 + ((_u * 256 + tx) * 4 + vec) // 128 < 770 and (_u * 256 + tx) * 4 + vec < 512)
                                            B_shared[k_0 * 4 + (_u * 1024 + tx * 4 + vec) // 128, bx % 18 * 128 + (_u * 1024 + tx * 4 + vec) % 128] = B[k_0 * 4 + (_u * 1024 + tx * 4 + vec) // 128, bx % 18 * 128 + (_u * 1024 + tx * 4 + vec) % 128]
                            for k_1, i_3, j_3, k_2, i_4, j_4 in T.grid(1, 8, 1, 4, 2, 2):
                                with T.block("update_update"):
                                    C[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] = C[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] + A_shared[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, (k_0 + k_1) * 4 + k_2] * B_shared[(k_0 + k_1) * 4 + k_2, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4]
def compacted_func(
    A: T.Buffer[(960, 770), "float32"],
    B: T.Buffer[(770, 2304), "float32"],
    C: T.Buffer[(960, 2304), "float32"],
) -> None:
    # C += A @ B for A (960x770) and B (770x2304), tiled into 128x128 C tiles
    # with a reduction of 193 steps of 4. The shared buffers are sized to one
    # step's tiles (A_shared 128x4, B_shared 4x128) and indexed relative to the
    # tile origin.
    # NOTE(review): the "update_update" block has no T.where, so its indices
    # reach C row 1023 (C has 960 rows); presumably tolerated because this IR
    # is compared, not executed — confirm.
    for bx in T.thread_binding(144, thread="blockIdx.x"):
        for vx in T.thread_binding(2, thread="vthread.x"):
            for tx_p in T.thread_binding(256, thread="threadIdx.x"):
                with T.block():
                    for k_0 in T.serial(193):
                        with T.block():
                            A_shared = T.alloc_buffer([128, 4], dtype="float32", scope="shared")
                            B_shared = T.alloc_buffer([4, 128], dtype="float32", scope="shared")
                            # 256 threads x 3 lanes, masked down to the 512 elements of the 128x4 A tile.
                            for v_u in T.serial(1):
                                for tx in T.thread_binding(256, thread="threadIdx.x"):
                                    for vec in T.vectorized(3):
                                        with T.block("A_shared"):
                                            T.where(bx // 18 * 128 + (tx * 3 + vec) // 4 < 960 and k_0 * 4 + (tx * 3 + vec) % 4 < 770 and tx * 3 + vec < 512)
                                            A_shared[(tx * 3 + vec) // 4, (tx * 3 + vec) % 4] = A[bx // 18 * 128 + (tx * 3 + vec) // 4, k_0 * 4 + (tx * 3 + vec) % 4]
                            # 256 threads x 4 lanes, masked down to the 512 elements of the 4x128 B tile.
                            for v_u in T.serial(1):
                                for tx in T.thread_binding(256, thread="threadIdx.x"):
                                    for vec in T.vectorized(4):
                                        with T.block("B_shared"):
                                            T.where(k_0 * 4 + tx // 32 < 770 and tx * 4 + vec < 512)
                                            B_shared[tx // 32, tx % 32 * 4 + vec] = B[k_0 * 4 + tx // 32, bx % 18 * 128 + tx % 32 * 4 + vec]
                            for k_1, i_3, j_3, k_2, i_4, j_4 in T.grid(1, 8, 1, 4, 2, 2):
                                with T.block("update_update"):
                                    C[bx // 18 * 128 + tx_p // 32 * 16 + i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 + j_4] = C[bx // 18 * 128 + tx_p // 32 * 16 + i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 + j_4] + A_shared[tx_p // 32 * 16 + i_3 * 2 + i_4, k_2] * B_shared[k_2, vx * 64 + tx_p % 32 * 2 + j_4]
def element_wise_compute_at_split_vectorized(a: T.handle, c: T.handle) -> None:
    # Two-stage elementwise pipeline C = (A * 2) + 1 over 128x128. For each row
    # i, the whole row of the intermediate B is produced first. C's row is then
    # computed with j split into 32 x 4 and the inner 4-wide loop vectorized.
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    B = T.alloc_buffer((128, 128))
    for i in T.serial(0, 128):
        for j0 in T.serial(0, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j0])
                B[vi, vj] = A[vi, vj] * 2.0
        for j1o in T.serial(0, 32):
            for j1i in T.vectorized(0, 4):
                with T.block("C"):
                    vi = T.axis.S(128, i)
                    vj = T.axis.S(128, j1o * 4 + j1i)
                    C[vi, vj] = B[vi, vj] + 1.0
def element_wise_compute_at_split_vectorized(a: T.handle, c: T.handle) -> None:
    # Two-stage elementwise pipeline C = (A * 2) + 1 over 128x128 with B's row
    # computed inside C's row loop and the inner 4-wide j loop vectorized.
    # Written in the older `T.block([extents], name) as [vars]` + `T.bind`
    # block syntax.
    # NOTE(review): another function with this exact name appears earlier in
    # this chunk; if both share a module, this later definition shadows it.
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    B = T.alloc_buffer((128, 128))
    for i in T.serial(0, 128):
        for j0 in T.serial(0, 128):
            with T.block([128, 128], "B") as [vi, vj]:
                T.bind(vi, i)
                T.bind(vj, j0)
                B[vi, vj] = A[vi, vj] * 2.0
        for j1o in T.serial(0, 32):
            for j1i in T.vectorized(0, 4):
                with T.block([128, 128], "C") as [vi, vj]:
                    T.bind(vi, i)
                    T.bind(vj, j1o * 4 + j1i)
                    C[vi, vj] = B[vi, vj] + 1.0
def Move_PUV0(a: T.handle, b: T.handle) -> None:
    # Copies A into B (both 1024^3). The outer fused loop is parallel and the
    # innermost 32-wide k loop is vectorized.
    # NOTE(review): with these extents, vi and vj each range up to 2047
    # (128 * 16 rows, 64 * 32 columns), so T.where masks three quarters of the
    # iterations; presumably this mirrors the schedule under test — confirm.
    # function attr dict
    T.func_attr({"global_symbol": "main"})
    A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32")
    B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32")
    # body
    with T.block("root"):
        for i0_j0_fused in T.parallel(0, 8192):
            for i1, j1, k0, i2, j2 in T.grid(4, 4, 64, 4, 8):
                for k1_fused in T.vectorized(0, 32):
                    with T.block("move"):
                        vi = T.axis.spatial(
                            1024, i0_j0_fused // 64 * 16 + i1 * 4 + i2)
                        vj = T.axis.spatial(
                            1024, i0_j0_fused % 64 * 32 + j1 * 8 + j2)
                        vk = T.axis.spatial(1024, k0 * 32 + k1_fused)
                        T.where(i0_j0_fused // 64 * 16 + i1 * 4 + i2 < 1024 and i0_j0_fused % 64 * 32 + j1 * 8 + j2 < 1024 and k0 * 32 + k1_fused < 1024)
                        T.reads([A[vi, vj, vk]])
                        T.writes([B[vi, vj, vk]])
                        B[vi, vj, vk] = A[vi, vj, vk]
def ptx_global_to_shared_copy(A: T.Buffer[(32, 128), dtype], B: T.Buffer[(32, 128), dtype]) -> None:
    # Copies A to B (32x128) through a "shared"-scoped staging buffer. The
    # global-to-shared copy sits under an "async_scope" attribute, and the
    # staged data is read only after ptx_commit_group / ptx_wait_group(0).
    # NOTE(review): dtype, num_iters, vector_size and vector_size_expr are not
    # defined here; presumably this is built inside a factory that closes over
    # them (e.g. num_iters * vector_size == 128) — confirm.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], dtype, scope="shared")
        T.reads(A[0:32, 0:128])
        T.writes(B[0:32, 0:128])

        T.attr("default", "async_scope", 1)
        # Thread tx copies row tx in chunks of vector_size elements.
        for i in T.serial(num_iters):
            for j in T.vectorized(vector_size):
                A_shared[tx, i * vector_size_expr + j] = A[tx, i * vector_size_expr + j]

        T.evaluate(T.ptx_commit_group(dtype=""))
        T.evaluate(T.ptx_wait_group(0, dtype=""))

        for i in range(128):
            B[tx, i] = A_shared[tx, i]
def decomposed_gemm_after_vectorize(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
):
    # C[i, j] = sum over k of A[i, k] * B[j, k] (16x16), tiled into 4x4 tiles,
    # with the reduction init in a separate "init" block. Only the final copy
    # from `local` to C has its inner jj loop vectorized.
    local = T.alloc_buffer((16, 16), "float32")
    for i, j in T.grid(4, 4):
        for ii, jj in T.grid(4, 4):
            with T.block("init"):
                vi = T.axis.S(16, i * 4 + ii)
                vj = T.axis.S(16, j * 4 + jj)
                local[vi, vj] = 0

        for k, ii, jj in T.grid(16, 4, 4):
            with T.block("update"):
                vi = T.axis.S(16, i * 4 + ii)
                vj = T.axis.S(16, j * 4 + jj)
                vk = T.axis.R(16, k)
                local[vi, vj] += A[vi, vk] * B[vj, vk]

        for ii in range(4):
            for jj in T.vectorized(4):
                with T.block("C"):
                    vi = T.axis.S(16, i * 4 + ii)
                    vj = T.axis.S(16, j * 4 + jj)
                    C[vi, vj] = local[vi, vj]
def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle, handle_c: T.handle) -> None:
    # pylint: disable=missing-function-docstring
    # 1024x1024 fp16 inputs, fp32 output, computed with wmma fragment
    # intrinsics (tvm_fill_fragment / tvm_load_matrix_sync / tvm_mma_sync /
    # tvm_store_matrix_sync) on 16x16 fragments. The read/write regions of the
    # mma block are wmma_c[i, j] <- wmma_a[i, k], wmma_b[j, k], i.e. B is
    # indexed [n, k].
    # Launch: 16 x 8 blocks; threadIdx.y (2) x threadIdx.z (2) each own 2 x 4
    # accumulator fragments, and threadIdx.x (32) is used only for the
    # global->shared copies.
    # The shared copies store at column offset +8, and the fragment loads pass
    # elem_offset + 8 to tvm_access_ptr, matching that offset.
    # match buffer
    match_buffer_a = T.match_buffer(handle_a, [1024, 1024], "float16")
    match_buffer_b = T.match_buffer(handle_b, [1024, 1024], "float16")
    match_buffer_c = T.match_buffer(handle_c, [1024, 1024], "float32")

    # body
    for block_idx_x in T.thread_binding(0, 16, "blockIdx.x"):
        for block_idx_y in T.thread_binding(0, 8, "blockIdx.y"):
            with T.block():
                axis_bx, axis_by = T.axis.remap("SS", [block_idx_x, block_idx_y])
                shared_a = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_b = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_a = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_b = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_c = T.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")

                # pylint: disable=too-many-nested-blocks
                for thread_ty in T.thread_binding(0, 2, "threadIdx.y"):
                    for thread_tz in T.thread_binding(0, 2, "threadIdx.z"):
                        # Zero-fill this thread's 2 x 4 accumulator fragments.
                        for index_i, index_jj in T.grid(2, 4):
                            with T.block():
                                new_axis_vi = T.axis.S(
                                    64, axis_bx * 4 + thread_ty * 2 + index_i)
                                new_axis_vj = T.axis.S(
                                    64, axis_by * 8 + thread_tz * 4 + index_jj)
                                T.reads([])
                                T.writes(wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                new_axis_vj * 16:new_axis_vj * 16 + 16, ])
                                match_buffer_c0 = T.match_buffer(
                                    wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                           new_axis_vj * 16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_fill_fragment(
                                        match_buffer_c0.data,
                                        16,
                                        16,
                                        16,
                                        index_i * 4 + index_jj,
                                        T.float32(0),  # pylint: disable=not-callable
                                        dtype="handle",
                                    ))

                        # Reduction over k in 32 outer steps of 32 (2 fragments of 16 each).
                        for k_o in range(0, 32):
                            # copy data from global to shared
                            for thread_tx in T.thread_binding(0, 32, "threadIdx.x"):
                                for index_i0, index_j0 in T.grid(1, 4):
                                    for index_j1 in T.vectorized(0, 4):
                                        with T.block():
                                            new_axis_vi = T.axis.S(
                                                1024,
                                                axis_bx * 64 + thread_ty * 32 + thread_tx + index_i0,
                                            )
                                            new_axis_vj = T.axis.S(
                                                1024,
                                                k_o * 32 + thread_tz * 16 + index_j0 * 4 + index_j1,
                                            )
                                            shared_a[new_axis_vi, new_axis_vj + 8] = match_buffer_a[
                                                new_axis_vi, new_axis_vj]

                                for index_i0, index_j0 in T.grid(2, 4):
                                    for index_j1 in T.vectorized(0, 4):
                                        with T.block():
                                            new_axis_vi = T.axis.S(
                                                1024,
                                                axis_by * 128 + thread_ty * 64 + thread_tx * 2 + index_i0,
                                            )
                                            new_axis_vj = T.axis.S(
                                                1024,
                                                k_o * 32 + thread_tz * 16 + index_j0 * 4 + index_j1,
                                            )
                                            shared_b[new_axis_vi, new_axis_vj + 8] = match_buffer_b[
                                                new_axis_vi, new_axis_vj]

                            for k_i in range(0, 2):
                                # Load this thread's 2 A fragments (row_major).
                                for index_i in range(0, 2):
                                    with T.block():
                                        new_axis_vi = T.axis.S(
                                            64, axis_bx * 4 + thread_ty * 2 + index_i)
                                        axis_vk = T.axis.S(64, k_o * 2 + k_i)
                                        T.reads(shared_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                         axis_vk * 16:axis_vk * 16 + 16 + 8, ])
                                        T.writes(
                                            wmma_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ])
                                        stride0 = T.var("int32")
                                        stride1 = T.var("int32")
                                        match_buffer_a0 = T.match_buffer(
                                            shared_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                     axis_vk * 16:axis_vk * 16 + 16 + 8, ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[stride0, stride1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_a0 = T.match_buffer(
                                            wmma_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_a0.data,
                                                16,
                                                16,
                                                16,
                                                index_i,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(
                                                        dtype="float16"),
                                                    match_buffer_a0.data,
                                                    match_buffer_a0.elem_offset + 8,
                                                    match_buffer_a0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                match_buffer_a0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            ))
                                # Load this thread's 4 B fragments (col_major).
                                for index_jj in range(0, 4):
                                    with T.block():
                                        new_axis_vj = T.axis.S(
                                            64, axis_by * 8 + thread_tz * 4 + index_jj)
                                        axis_vk = T.axis.S(64, k_o * 2 + k_i)
                                        T.reads(shared_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                         axis_vk * 16:axis_vk * 16 + 16 + 8, ])
                                        T.writes(
                                            wmma_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ])
                                        stride0 = T.var("int32")
                                        stride1 = T.var("int32")
                                        match_buffer_b0 = T.match_buffer(
                                            shared_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                     axis_vk * 16:axis_vk * 16 + 16 + 8, ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[stride0, stride1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_b0 = T.match_buffer(
                                            wmma_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_b0.data,
                                                16,
                                                16,
                                                16,
                                                index_jj,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(
                                                        dtype="float16"),
                                                    match_buffer_b0.data,
                                                    match_buffer_b0.elem_offset + 8,
                                                    match_buffer_b0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                match_buffer_b0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            ))
                                # Accumulate each (A fragment, B fragment) pair into its C fragment.
                                for index_i, index_jj in T.grid(2, 4):
                                    with T.block():
                                        new_axis_vi = T.axis.S(
                                            64, axis_bx * 4 + thread_ty * 2 + index_i)
                                        new_axis_vj = T.axis.S(
                                            64, axis_by * 8 + thread_tz * 4 + index_jj)
                                        axis_vk = T.axis.R(64, k_o * 2 + k_i)
                                        T.reads([
                                            wmma_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            wmma_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   new_axis_vj * 16:new_axis_vj * 16 + 16, ],
                                        ])
                                        T.writes(
                                            wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   new_axis_vj * 16:new_axis_vj * 16 + 16, ])
                                        wmma_a1 = T.match_buffer(
                                            wmma_a[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_b1 = T.match_buffer(
                                            wmma_b[new_axis_vj * 16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 + 16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_c1 = T.match_buffer(
                                            wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   new_axis_vj * 16:new_axis_vj * 16 + 16, ],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_mma_sync(
                                                wmma_c1.data,
                                                index_i * 4 + index_jj,
                                                wmma_a1.data,
                                                index_i,
                                                wmma_b1.data,
                                                index_jj,
                                                wmma_c1.data,
                                                index_i * 4 + index_jj,
                                                dtype="handle",
                                            ))

                        # Store this thread's 2 x 4 accumulator fragments to C (row_major).
                        for index_i, index_jj in T.grid(2, 4):
                            with T.block():
                                new_axis_vi = T.axis.S(
                                    64, axis_bx * 4 + thread_ty * 2 + index_i)
                                new_axis_vj = T.axis.S(
                                    64, axis_by * 8 + thread_tz * 4 + index_jj)
                                T.reads(wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                               new_axis_vj * 16:new_axis_vj * 16 + 16, ])
                                T.writes(
                                    match_buffer_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   new_axis_vj * 16:new_axis_vj * 16 + 16, ])
                                stride0 = T.var("int32")
                                stride1 = T.var("int32")
                                wmma_c2 = T.match_buffer(
                                    wmma_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                           new_axis_vj * 16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                match_buffer_c1 = T.match_buffer(
                                    match_buffer_c[new_axis_vi * 16:new_axis_vi * 16 + 16,
                                                   new_axis_vj * 16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[stride0, stride1],
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_store_matrix_sync(
                                        wmma_c2.data,
                                        16,
                                        16,
                                        16,
                                        index_i * 4 + index_jj,
                                        T.tvm_access_ptr(
                                            T.type_annotation(dtype="float32"),
                                            match_buffer_c1.data,
                                            match_buffer_c1.elem_offset,
                                            match_buffer_c1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        match_buffer_c1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    ))
def main(
    placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
    placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
    conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
) -> None:
    # NCHWc int8 convolution with a 1x1 kernel (kh and kw are reduce axes of
    # extent 1). Input is uint8 (1, 4, 56, 56, 16), weight is int8
    # (16, 4, 1, 1, 4, 16, 4) and output is int32 (1, 16, 56, 56, 16).
    # The init zero-fills 16 output lanes with a vectorized loop. Each update
    # step accumulates one 4-byte input group against a 16x4 weight tile, using
    # the LLVM intrinsic "llvm.x86.avx512.vpdpbusd.512" on 16 int32 lanes.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0 in T.grid(
        1, 1, 2, 1, 1, 1, 4, 1, 14, 1, 1, 1):
        for i1_2_init, i2_2_init, i2_3_init, i3_3_init in T.grid(
            4, 7, 4, 4):
            with T.block("conv2d_NCHWc_int8_o_init"):
                n = T.axis.spatial(1, 0)
                oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2_init)
                oh = T.axis.spatial(56, i2_0 * 28 + i2_2_init * 4 + i2_3_init)
                ow = T.axis.spatial(56, i3_1 * 4 + i3_3_init)
                oc_block_o = T.axis.spatial(1, 0)
                T.reads()
                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
                for i4_1 in T.vectorized(16):
                    with T.block("conv2d_NCHWc_int8_init"):
                        oc_block_init = T.axis.spatial(16, i4_1)
                        T.reads()
                        T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init])
                        conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_init] = 0
        for (
                i7_0,
                i8_0,
                i9_0_0,
                i0_2,
                i1_2,
                i2_2,
                i3_2,
                i4_0_2,
                i5_1,
                i6_1,
                i7_1,
                i8_1,
                i9_0_1,
                i0_3,
                i1_3,
                i2_3,
                i3_3,
                i4_0_3,
        ) in T.grid(4, 1, 1, 1, 4, 7, 1, 1, 1, 1, 1, 4, 1, 1, 1, 4, 4, 1):
            with T.block("conv2d_NCHWc_int8_o_update"):
                n = T.axis.spatial(1, 0)
                oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2)
                oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3)
                ow = T.axis.spatial(56, i3_1 * 4 + i3_3)
                oc_block_o = T.axis.spatial(1, 0)
                kh = T.axis.reduce(1, 0)
                kw = T.axis.reduce(1, 0)
                ic_outer, ic_f_inner = T.axis.remap("RR", [i7_0, i8_1])
                ic_s_inner_o = T.axis.reduce(1, 0)
                T.reads(
                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16],
                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4:ic_f_inner * 4 + 4],
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4],
                )
                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
                A = T.match_buffer(
                    placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4:ic_f_inner * 4 + 4],
                    [4],
                    dtype="uint8",
                    offset_factor=1,
                )
                B = T.match_buffer(
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16, 0:4],
                    [16, 4],
                    dtype="int8",
                    offset_factor=1,
                )
                C = T.match_buffer(
                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16],
                    [16],
                    dtype="int32",
                    offset_factor=1,
                )
                # 4 uint8 inputs -> one int32, broadcast to all 16 lanes.
                A_u8x4 = A.vload([0], "uint8x4")
                A_i32 = T.reinterpret(A_u8x4, dtype="int32")
                # 16x4 int8 weights -> 16 int32 lanes (4 bytes per lane).
                B_i8x64 = B.vload([0, 0], dtype="int8x64")
                B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")

                # The intrinsic is called with a zero accumulator operand
                # (T.broadcast(0, 16)); its result is then added to C explicitly.
                # T.uint32(0) is forwarded to call_llvm_pure_intrin unchanged
                # (presumably its signature-argument count — confirm).
                C[T.ramp(
                    0, 1, 16)] = C[T.ramp(0, 1, 16)] + T.call_llvm_pure_intrin(
                        T.llvm_lookup_intrinsic_id(
                            "llvm.x86.avx512.vpdpbusd.512"),
                        T.uint32(0),
                        T.broadcast(0, 16),
                        T.broadcast(A_i32, 16),
                        B_i32x16,
                        dtype="int32x16",
                    )
def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Tensor-core GEMM on 1024 x 1024 operands: float16 inputs A and B,
    # float32 output C. Written in the legacy `T.block([...]) as [...]` +
    # `T.bind(...)` TVMScript syntax with explicit wmma intrinsics
    # (tvm_fill_fragment / tvm_load_matrix_sync / tvm_mma_sync /
    # tvm_store_matrix_sync).
    # B is indexed [n, k] (its first index follows `by`, the output-column
    # tile), so this appears to compute C = A @ B^T -- confirm.
    # Work split: 16 x 8 thread blocks (bx, by). Inside each block, every
    # (ty, tz) pair owns a 2 x 4 grid of 16 x 16 accumulator fragments,
    # addressed by fragment index i * 4 + j.
    # match buffer
    A = T.match_buffer(a, [1024, 1024], "float16")
    B = T.match_buffer(b, [1024, 1024], "float16")
    C = T.match_buffer(c, [1024, 1024], "float32")

    # body
    for blockIdx_x in T.thread_binding(0, 16, "blockIdx.x"):
        for blockIdx_y in T.thread_binding(0, 8, "blockIdx.y"):
            with T.block([16, 8]) as [bx, by]:
                T.bind(bx, blockIdx_x)
                T.bind(by, blockIdx_y)
                # NOTE(review): these buffers are declared full-size
                # (1024 x 1024) although each block only touches one tile.
                # The fragment strides used below ([16 * 4, 1] for wmma_C,
                # [16, 1] for wmma_A / wmma_B) only fit if a later pass
                # compacts them to the per-(ty, tz) tile -- confirm.
                shared_A = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_B = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_A = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_B = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_C = T.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                for ty in T.thread_binding(0, 2, "threadIdx.y"):
                    for tz in T.thread_binding(0, 2, "threadIdx.z"):
                        # Zero the 2 x 4 accumulator fragments.
                        for i, j in T.grid(2, 4):
                            with T.block([64, 64]) as [vi, vj]:
                                T.bind(vi, bx * 4 + ty * 2 + i)
                                T.bind(vj, by * 8 + tz * 4 + j)
                                T.reads([])
                                T.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                C0 = T.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_fill_fragment(
                                        C0.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        T.float32(0),
                                        dtype="handle",
                                    )
                                )

                        # K loop: 32 steps of 32 elements along K.
                        for ko in range(0, 32):
                            # copy data from global to shared
                            # NOTE(review): stores land at column vj + 8, and
                            # the fragment loads below compensate with
                            # elem_offset + 8. vj + 8 can reach 1031, past the
                            # declared 1024 extent, so this presumably relies
                            # on buffer compaction -- confirm.
                            for tx in T.thread_binding(0, 32, "threadIdx.x"):
                                # A tile: 64 rows x 32 columns per block.
                                for i0, j0 in T.grid(1, 4):
                                    for j1 in T.vectorized(0, 4):
                                        with T.block([1024, 1024]) as [vi, vj]:
                                            T.bind(vi, bx * 64 + ty * 32 + tx + i0)
                                            T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_A[vi, vj + 8] = A[vi, vj]

                                # B tile: 128 rows x 32 columns per block.
                                for i0, j0 in T.grid(2, 4):
                                    for j1 in T.vectorized(0, 4):
                                        with T.block([1024, 1024]) as [vi, vj]:
                                            T.bind(vi, by * 128 + ty * 64 + tx * 2 + i0)
                                            T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_B[vi, vj + 8] = B[vi, vj]

                            # Two 16-wide K sub-steps per ko.
                            for ki in range(0, 2):
                                # Load this (ty, tz)'s 2 A fragments (row_major).
                                for i in range(0, 2):
                                    with T.block([64, 64]) as [vi, vk]:
                                        T.bind(vi, bx * 4 + ty * 2 + i)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        T.writes(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        # Symbolic strides, so match_buffer binds
                                        # them to the shared tile's real strides.
                                        s0 = T.var("int32")
                                        s1 = T.var("int32")
                                        A0 = T.match_buffer(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_A0 = T.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_A0.data,
                                                16,
                                                16,
                                                16,
                                                i,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(dtype="float16"),
                                                    A0.data,
                                                    A0.elem_offset + 8,
                                                    A0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                A0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            )
                                        )
                                # Load this (ty, tz)'s 4 B fragments (col_major).
                                for j in range(0, 4):
                                    with T.block([64, 64]) as [vj, vk]:
                                        T.bind(vj, by * 8 + tz * 4 + j)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        T.writes(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        # Symbolic strides, so match_buffer binds
                                        # them to the shared tile's real strides.
                                        s0 = T.var("int32")
                                        s1 = T.var("int32")
                                        B0 = T.match_buffer(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_B0 = T.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_B0.data,
                                                16,
                                                16,
                                                16,
                                                j,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(dtype="float16"),
                                                    B0.data,
                                                    B0.elem_offset + 8,
                                                    B0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                B0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            )
                                        )
                                # Accumulate: fragment (i * 4 + j) +=
                                # A fragment i x B fragment j.
                                for i, j in T.grid(2, 4):
                                    with T.block([64, 64, T.reduce_axis(0, 64)]) as [
                                        vi,
                                        vj,
                                        vk,
                                    ]:
                                        T.bind(vi, bx * 4 + ty * 2 + i)
                                        T.bind(vj, by * 8 + tz * 4 + j)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            [
                                                wmma_A[
                                                    vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_B[
                                                    vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_C[
                                                    vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16
                                                ],
                                            ]
                                        )
                                        T.writes(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]
                                        )
                                        wmma_A1 = T.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_B1 = T.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_C1 = T.match_buffer(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_mma_sync(
                                                wmma_C1.data,
                                                i * 4 + j,
                                                wmma_A1.data,
                                                i,
                                                wmma_B1.data,
                                                j,
                                                wmma_C1.data,
                                                i * 4 + j,
                                                dtype="handle",
                                            )
                                        )
                        # Store the 2 x 4 accumulator fragments to C (row_major).
                        for i, j in T.grid(2, 4):
                            with T.block([64, 64]) as [vi, vj]:
                                T.bind(vi, bx * 4 + ty * 2 + i)
                                T.bind(vj, by * 8 + tz * 4 + j)
                                T.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                s0 = T.var("int32")
                                s1 = T.var("int32")
                                wmma_C2 = T.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                C1 = T.match_buffer(
                                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[s0, s1],
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_store_matrix_sync(
                                        wmma_C2.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        T.tvm_access_ptr(
                                            T.type_annotation(dtype="float32"),
                                            C1.data,
                                            C1.elem_offset,
                                            C1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        C1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    )
                                )
def main(var_X: T.handle, var_W: T.handle, var_B: T.handle, var_bn_scale: T.handle, var_bn_offset: T.handle, var_compute: T.handle) -> None:
    # Fused 3x3 conv2d (zero padding 1, stride 1) over a 1x512x56x56 float32
    # input with 512x512x3x3 weights, followed by the per-channel epilogue
    # max((conv + B) * bn_scale + bn_offset, 0), i.e. a ReLU.
    # The convolution carries a GPU schedule: blockIdx.x / vthread.x /
    # threadIdx.x bindings, shared-memory staging of pad_temp and W, and a
    # local accumulator (compute_local).
    # NOTE(review): per block, the shared fetches below copy 122880 + 49152
    # float32 elements (~672 KiB), far above typical per-block shared-memory
    # limits -- presumably intentional for this fixture (e.g. to be rejected
    # by a GPU-code verifier); confirm.
    X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32")
    W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32")
    B = T.match_buffer(var_B, [512, 1, 1], dtype="float32")
    bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32")
    bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32")
    compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32")
    pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32")
    compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    compute_local = T.alloc_buffer([1, 512, 56, 56], dtype="float32", scope="local")
    pad_temp_shared = T.alloc_buffer([1, 512, 58, 58], dtype="float32", scope="shared")
    W_shared = T.alloc_buffer([512, 512, 3, 3], dtype="float32", scope="shared")
    # Zero-pad X by one pixel on each spatial side (58 = 56 + 2).
    for i0, i1, i2, i3 in T.grid(1, 512, 58, 58):
        with T.block("pad_temp"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57,
                X[i0_1, i1_1, i2_1 - 1, i3_1 - 1],
                T.float32(0),
                dtype="float32")
    # 224 thread blocks = 16 output-channel tiles (fused // 14, 32 channels
    # each) x 7 row tiles (8 rows each) x 2 column tiles (fused % 2, 28 cols).
    for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(0, 224, thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(
                0, 2, thread="vthread.x"):
            for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(
                    0, 8, thread="threadIdx.x"):
                # ry (i5_0) is the only outer reduction loop with extent > 1.
                for i4_0, i5_0, i6_0 in T.grid(1, 3, 1):
                    # Cooperative fetch of padded input for the current ry:
                    # 512 channels x 8 rows x 30 columns (28 + 2 halo)
                    # = 40960 * 3 elements.
                    for ax0_ax1_ax2_ax3_fused_0 in T.serial(
                            0, 40960, annotations={
                                "meta_schedule.cooperative_fetch": 1
                            }):
                        for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 3):
                            with T.block("pad_temp_shared"):
                                v0 = T.axis.spatial(1, 0)
                                v1 = T.axis.spatial(
                                    512, (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 // 8 % 512)
                                v2 = T.axis.spatial(
                                    58, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 + i5_0
                                    + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) // 30 % 8)
                                v3 = T.axis.spatial(
                                    58, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28
                                    + (ax0_ax1_ax2_ax3_fused_0 * 3 + ax0_ax1_ax2_ax3_fused_1) % 30)
                                pad_temp_shared[v0, v1, v2, v3] = pad_temp[v0, v1, v2, v3]
                    # Cooperative fetch of weights for the current ry:
                    # 32 output channels x 512 input channels x 3 rx taps
                    # = 12288 * 4 elements.
                    for ax0_ax1_ax2_ax3_fused_0 in T.serial(
                            0, 12288, annotations={
                                "meta_schedule.cooperative_fetch": 1
                            }):
                        for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 4):
                            with T.block("W_shared"):
                                v0 = T.axis.spatial(
                                    512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32
                                    + (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 1536)
                                v1 = T.axis.spatial(
                                    512, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) // 3 % 512)
                                v2 = T.axis.spatial(3, i5_0)
                                v3 = T.axis.spatial(
                                    3, (ax0_ax1_ax2_ax3_fused_0 * 4 + ax0_ax1_ax2_ax3_fused_1) % 3)
                                W_shared[v0, v1, v2, v3] = W[v0, v1, v2, v3]
                    # Accumulate into compute_local over rc (i4_1 x i4_2 =
                    # 32 x 16), ry (i5_0) and rx (i6_2); each thread covers
                    # 8 channels x 2 rows x 28 columns per vthread.
                    for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(
                            32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28):
                        with T.block("compute"):
                            nn = T.axis.spatial(1, 0)
                            ff = T.axis.spatial(
                                512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32
                                + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4)
                            # `fused // 2 % 7` equals the `fused % 14 // 2` used
                            # by the fetch and write-back blocks for every value.
                            yy = T.axis.spatial(
                                56, i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8
                                + i0_1_i1_1_i2_1_i3_1_fused * 4
                                + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4)
                            xx = T.axis.spatial(
                                56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4)
                            rc = T.axis.reduce(512, i4_1 * 16 + i4_2)
                            ry, rx = T.axis.remap("RR", [i5_0, i6_2])
                            with T.init():
                                compute_local[nn, ff, yy, xx] = T.float32(0)
                            compute_local[nn, ff, yy, xx] = compute_local[
                                nn, ff, yy, xx] + pad_temp_shared[
                                nn, rc, yy + ry, xx + rx] * W_shared[ff, rc, ry, rx]
                # Write this thread's 8 x 2 x 28 tile back to compute_1.
                for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28):
                    with T.block("compute_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(
                            512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32
                            + i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1)
                        v2 = T.axis.spatial(
                            56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8
                            + i0_1_i1_1_i2_1_i3_1_fused * 4
                            + i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2)
                        v3 = T.axis.spatial(
                            56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3)
                        compute_1[v0, v1, v2, v3] = compute_local[v0, v1, v2, v3]
    # Epilogue: per-output-channel (i1_2) bias add, scale, offset, then ReLU.
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("compute_1"):
            i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            compute[i0_2, i1_2, i2_2, i3_2] = T.max(
                (compute_1[i0_2, i1_2, i2_2, i3_2] + B[i1_2, 0, 0]) * bn_scale[i1_2, 0, 0]
                + bn_offset[i1_2, 0, 0],
                T.float32(0))
def loops() -> None:
    # Minimal nest of three loop kinds (parallel / serial / vectorized) with
    # an empty body (T.evaluate(0)).
    for i in T.parallel(0, 2):
        for j in T.serial(0, 1):
            # NOTE(review): T.vectorized(3, 4) iterates over [3, 4) -- a single
            # iteration with a non-zero start; presumably intentional, confirm.
            for z in T.vectorized(3, 4):
                T.evaluate(0)
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None:
    # 512 x 512 x 512 float32 matmul C[i, j] = sum_k A[i, k] * B[k, j] with a
    # GPU schedule: 16 blocks (32-column strips of C) x 16 vthreads (32-row
    # strips) x 8 threads (4 columns each). A and B are staged in shared
    # memory, C is accumulated in C_local, and the reduction is split into
    # explicit C_init / C_update blocks.
    # NOTE(review): every block copies all of A (512 x 512 float32 = 1 MiB)
    # into shared memory, far above typical per-block limits -- presumably
    # acceptable for this IR fixture; confirm.
    A = T.match_buffer(var_A, [512, 512], dtype="float32")
    B = T.match_buffer(var_B, [512, 512], dtype="float32")
    C = T.match_buffer(var_C, [512, 512], dtype="float32")
    C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
    A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
        for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
            for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"):
                for i2_0 in T.serial(0, 1):
                    # Copy all of A: 32768 * 8 = 512 * 512 elements.
                    for ax0_ax1_fused_0 in T.serial(0, 32768):
                        for ax0_ax1_fused_1 in T.thread_binding(
                                0, 8, thread="threadIdx.x"):
                            with T.block("A_shared"):
                                v0 = T.axis.spatial(
                                    512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512)
                                v1 = T.axis.spatial(
                                    512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512)
                                T.reads([A[v0, v1]])
                                T.writes([A_shared[v0, v1]])
                                T.block_attr(
                                    {"meta_schedule.cooperative_fetch": 1})
                                A_shared[v0, v1] = A[v0, v1]
                    # Copy this block's 512 x 32 column strip of B
                    # (1024 * 16 elements, 2-wide vectors).
                    for ax0_ax1_fused_0 in T.serial(0, 1024):
                        for ax0_ax1_fused_1 in T.thread_binding(
                                0, 8, thread="threadIdx.x"):
                            for ax0_ax1_fused_2 in T.vectorized(0, 2):
                                with T.block("B_shared"):
                                    v0 = T.axis.spatial(
                                        512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32)
                                    v1 = T.axis.spatial(
                                        512, i0_0_i1_0_fused * 32
                                        + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32)
                                    T.reads([B[v0, v1]])
                                    T.writes([B_shared[v0, v1]])
                                    T.block_attr({
                                        "meta_schedule.cooperative_fetch": 2
                                    })
                                    B_shared[v0, v1] = B[v0, v1]
                    # Zero this thread's 32 x 4 accumulator tile.
                    for i0_3_init, i1_3_init, i0_4_init, i1_4_init in T.grid(
                            2, 2, 16, 2):
                        with T.block("C_init"):
                            i = T.axis.spatial(
                                512, i0_1_i1_1_fused * 32 + i0_3_init * 16 + i0_4_init)
                            j = T.axis.spatial(
                                512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4
                                + i1_3_init * 2 + i1_4_init)
                            T.reads([])
                            T.writes([C_local[i, j]])
                            C_local[i, j] = T.float32(0)
                    # Accumulate over k = i2_1 * 32 + i2_2 (all 512 steps).
                    for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(
                            16, 2, 2, 32, 16, 2):
                        with T.block("C_update"):
                            i = T.axis.spatial(
                                512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
                            j = T.axis.spatial(
                                512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4
                                + i1_3 * 2 + i1_4)
                            k = T.axis.reduce(512, i2_1 * 32 + i2_2)
                            T.reads([
                                C_local[i, j], A_shared[i, k], B_shared[k, j]
                            ])
                            T.writes([C_local[i, j]])
                            C_local[i, j] = C_local[
                                i, j] + A_shared[i, k] * B_shared[k, j]
                # Write the 32 x 4 accumulator tile back to C.
                for ax0, ax1 in T.grid(32, 4):
                    with T.block("C_local"):
                        v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0)
                        v1 = T.axis.spatial(
                            512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1)
                        T.reads([C_local[v0, v1]])
                        T.writes([C[v0, v1]])
                        C[v0, v1] = C_local[v0, v1]
def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024), "float32"], C: T.Buffer[(1024, 1024), "float32"]) -> None:
    # 1024 x 1024 x 1024 float32 matmul C = A @ B. Each (blockIdx_y,
    # blockIdx_x) pair owns a 32 x 32 tile of C, and each of its 2 x 2
    # threads computes a 16 x 16 sub-tile. Per 32-wide k_0 step, the A and B
    # tiles are staged through shared memory.
    # function attr dict
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    # body
    # with T.block("root")
    for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"):
        for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"):
            for threadIdx_y in T.thread_binding(2, thread="threadIdx.y"):
                for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"):
                    for k_0 in T.serial(32):
                        with T.block():
                            T.reads(
                                A[blockIdx_y * 32:blockIdx_y * 32 + 32, k_0 * 32:k_0 * 32 + 32],
                                B[k_0 * 32:k_0 * 32 + 32, blockIdx_x * 32:blockIdx_x * 32 + 32])
                            T.writes(
                                C[blockIdx_y * 32:blockIdx_y * 32 + 32, blockIdx_x * 32:blockIdx_x * 32 + 32])
                            A_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared")
                            B_shared = T.alloc_buffer([1024, 1024], dtype="float32", scope="shared")
                            # Copy the 32 x 32 A tile: 64 iterations x 4 threads
                            # x 4 vector lanes = 1024 elements.
                            # NOTE(review): the "tir.manifest_shared_memory_local_stage"
                            # attribute presumably asks a later pass to stage
                            # this copy through a local buffer -- confirm.
                            for ax0_ax1_fused_0 in T.serial(64):
                                for ax0_ax1_fused_3 in T.vectorized(4):
                                    with T.block("A_shared"):
                                        T.reads(A[
                                            blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                            k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
                                        T.writes(A_shared[
                                            blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                            k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
                                        T.block_attr({
                                            "tir.manifest_shared_memory_local_stage": 1
                                        })
                                        A_shared[
                                            blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                            k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = A[
                                            blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                            k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]
                            # Copy the 32 x 32 B tile the same way.
                            for ax0_ax1_fused_0 in T.serial(64):
                                for ax0_ax1_fused_3 in T.vectorized(4):
                                    with T.block("B_shared"):
                                        T.reads(B[
                                            k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                            blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
                                        T.writes(B_shared[
                                            k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                            blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32])
                                        T.block_attr({
                                            "tir.manifest_shared_memory_local_stage": 1
                                        })
                                        B_shared[
                                            k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                            blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = B[
                                            k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32,
                                            blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]
                            # Each thread's 16 x 16 sub-tile, reducing over the
                            # 32 k values of this k_0 step (k_1 x k_2 = 2 x 16).
                            for k_1, i_2, j_2, k_2 in T.grid(
                                    2, 16, 16, 16):
                                with T.block("C"):
                                    T.reads(
                                        A_shared[blockIdx_y * 32 + threadIdx_y * 16 + i_2, k_0 * 32 + k_1 * 16 + k_2],
                                        B_shared[k_0 * 32 + k_1 * 16 + k_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2])
                                    T.writes(C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2])
                                    # Zero C on the very first reduction step
                                    # (global k == 0), then always accumulate.
                                    if k_0 * 32 + k_1 * 16 + k_2 == 0:
                                        C[blockIdx_y * 32 + threadIdx_y * 16 + i_2, blockIdx_x * 32 + threadIdx_x * 16 + j_2] = T.float32(0)
                                    C[
                                        blockIdx_y * 32 + threadIdx_y * 16 + i_2,
                                        blockIdx_x * 32 + threadIdx_x * 16 + j_2] = C[
                                        blockIdx_y * 32 + threadIdx_y * 16 + i_2,
                                        blockIdx_x * 32 + threadIdx_x * 16 + j_2] + A_shared[
                                        blockIdx_y * 32 + threadIdx_y * 16 + i_2,
                                        k_0 * 32 + k_1 * 16 + k_2] * B_shared[
                                        k_0 * 32 + k_1 * 16 + k_2,
                                        blockIdx_x * 32 + threadIdx_x * 16 + j_2]