def gemm() -> None:
    A = T.alloc_buffer([16, 16], "float32")
    B = T.alloc_buffer([16, 16], "float32")
    C = T.alloc_buffer([16, 16], "float32")
    for i, j, k, ii, jj in T.grid(4, 4, 16, 4, 4):
        with T.block("update"):
            vi = T.axis.S(16, i * 4 + ii)
            vj = T.axis.S(16, j * 4 + jj)
            vk = T.axis.R(16, k)
            T.reads(A[vi, vk], B[vj, vk])
            T.writes(C[vi, vj])
            with T.init():
                T.reads([])
                T.writes(C[vi, vj])
                C[vi, vj] = 0
            C[vi, vj] += A[vi, vk] * B[vj, vk]

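# Note: every function in this file is a TVMScript PrimFunc. As written they
# assume the usual test-file preamble, i.e. `from tvm.script import tir as T`
# and a `@T.prim_func` decorator on each function (omitted throughout).
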
def main(a: T.handle, b: T.handle, c: T.handle) -> None:
    T.func_attr({"global_symbol": "main"})
    A = T.match_buffer(a, (1024, 1024), "float32")
    B = T.match_buffer(b, (1024, 1024), "float32")
    C = T.match_buffer(c, (1024, 1024), "float32")
    with T.block("root"):
        for i, j, k in T.grid(1024, 1024, 1024):
            with T.block("matmul"):
                T.block_attr({"schedule_rule": "tvm.meta_schedule.test.custom_search_space"})
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

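# The "schedule_rule" block attribute above is a packed-function name that the
# meta-schedule space generator looks up at tuning time. A minimal sketch of
# registering such a rule; the (sch, block) -> [sch] signature and the rule
# body are assumptions for illustration, not the definitive interface.
import tvm


@tvm.register_func("tvm.meta_schedule.test.custom_search_space")
def custom_search_space(sch, block):
    # hypothetical placeholder rule: return the schedule unmodified
    return [sch]
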
def outer_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 1), offset_factor=1)
    B = T.match_buffer(b, (16, 1), offset_factor=1)
    C = T.match_buffer(c, (16, 16), offset_factor=1)
    with T.block("root"):
        T.reads(
            C[0:16, 0:16],
            A[0:16, 0:1],
            B[0:16, 0:1],
        )
        T.writes(C[0:16, 0:16])
        for i, j in T.grid(16, 16):
            with T.block("update"):
                vii, vjj = T.axis.remap("SS", [i, j])
                C[vii, vjj] = C[vii, vjj] + A[vii, 0] * B[vjj, 0]

def func(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [64, 32], dtype="float32")
    B = T.match_buffer(b, [64, 32], dtype="float32")
    C = T.match_buffer(c, [64, 32], dtype="float32")
    for i, j in T.grid(64, 32):  # type: ignore
        with T.block():
            T.reads([A[i, j], B[i, j]])  # type: ignore
            T.writes([B[i, j], C[i, j]])  # type: ignore
            with T.block("B"):
                T.reads([A[i, j]])  # type: ignore
                T.writes([B[i, j]])  # type: ignore
                B[i, j] = A[i, j]  # type: ignore
            with T.block("C"):
                T.reads([B[i, j]])  # type: ignore
                T.writes([C[i, j]])  # type: ignore
                C[i, j] = B[i, j]  # type: ignore

def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4):
        with T.block("update"):
            vi, vj = T.axis.remap("SS", [i0, i1])
            vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner)
            T.reads([C[vi, vj], A[vi, vk], B[vj, vk]])
            T.writes([C[vi, vj]])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])

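# A sketch of producing the split reduction loop above with the tir.Schedule
# API, assuming an untransformed matmul PrimFunc (global symbol "main") whose
# reduction block is named "update"; the factors 4 * 8 * 4 = 128 mirror
# i2_outer / i2_inner_outer / i2_inner_inner.
import tvm
from tvm import tir


def split_reduction(mod: tvm.IRModule) -> tvm.IRModule:
    sch = tir.Schedule(mod)
    i, j, k = sch.get_loops(sch.get_block("update"))
    sch.split(k, factors=[4, 8, 4])
    return sch.mod
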
def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle) -> None:
    m = T.var("int32")
    n = T.var("int32")
    A0 = T.match_buffer(a0, (m, n))
    A1 = T.match_buffer(a1, (m, n))
    B0 = T.match_buffer(b0, (m, n))
    B1 = T.match_buffer(b1, (m, n))
    for i0, i1 in T.grid(m, n):
        with T.block("B.v0"):
            i, j = T.axis.remap("SS", [i0, i1])
            B0[i, j] = A0[i, j] + 2.0
        with T.block("B.v1"):
            i, j = T.axis.remap("SS", [i0, i1])
            B1[i, j] = A1[i, j] * 3.0

def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    m = T.var("int32")
    n = T.var("int32")
    A0 = T.match_buffer(a0, (m, n))
    A1 = T.match_buffer(a1, (m, n))
    B0 = T.match_buffer(b0, (m, n))
    B1 = T.match_buffer(b1, (m, n))
    for i0, i1 in T.grid(m, n):
        with T.block("B.v0"):
            i, j = T.axis.remap("SS", [i0, i1])
            B0[i, j] = A0[i, j] + 2.0
        with T.block("B.v1"):
            i, j = T.axis.remap("SS", [i0, i1])
            B1[i, j] = A1[i, j] * 3.0

def main(a: T.handle, b: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main"})
    A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32")
    B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32")
    # body
    with T.block("root"):
        T.block_attr({"meta_schedule.parallel": 128, "meta_schedule.vectorize": 32})
        for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32):
            with T.block("move"):
                vi = T.axis.spatial(1024, i0 * 16 + i1 * 4 + i2)
                vj = T.axis.spatial(1024, j0 * 32 + j1 * 8 + j2)
                vk = T.axis.spatial(1024, k0 * 32 + k1)
                T.where(
                    (i0 * 4 + i1) * 4 + i2 < 1024
                    and (j0 * 4 + j1) * 8 + j2 < 1024
                    and k0 * 32 + k1 < 1024
                )
                T.reads([A[vi, vj, vk]])
                T.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk]

def matmul_relu_ann1(a: T.handle, b: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, (1024, 1024))
    B = T.match_buffer(b, (1024, 1024))
    C = T.alloc_buffer((1024, 1024))
    D = T.match_buffer(d, (1024, 1024))
    for i in T.serial(0, 1024, annotations={"test1": "aaa"}):
        for j in T.serial(0, 1024, annotations={"test2": 612}):
            for k in T.serial(0, 1024):
                with T.block("matmul"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        C[vi, vj] = 0.0
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    for i, j in T.grid(1024, 1024):
        with T.block("relu"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = T.max(C[vi, vj], 0.0)

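# The annotations={...} entries above can also be attached through the schedule
# API. A sketch, assuming a module containing the same "matmul" block:
import tvm
from tvm import tir


def annotate_matmul_loops(mod: tvm.IRModule) -> tvm.IRModule:
    sch = tir.Schedule(mod)
    i, j, k = sch.get_loops(sch.get_block("matmul"))
    sch.annotate(i, ann_key="test1", ann_val="aaa")
    sch.annotate(j, ann_key="test2", ann_val=612)
    return sch.mod
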
def transformed_high_dim_opaque_access(a: T.handle) -> None:
    A = T.match_buffer(a, (16, 32, 64))
    for i, j, k in T.grid(16, 2, 4):
        with T.block([]):
            T.reads([])
            T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
            T.evaluate(
                T.intrin_test(
                    A.data,
                    i * 2048 + j * 1024 + k * 16,
                    64,
                    1,
                    16,
                    16,
                    dtype="handle",
                )
            )

def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> None:
    C = T.alloc_buffer([1], dtype="float32")
    for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
        for i0_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
            for i1, i2 in T.grid(256, 256):
                with T.block("C"):
                    b = T.axis.S(1, 0)
                    i, j = T.axis.remap("RR", [i1, i2])
                    with T.init():
                        C[b] = T.float32(0)
                    C[b] = C[b] + A[b, i, j] * A[b, i, j]
    for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
        for i0_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
            with T.block("D"):
                b = T.axis.S(1, 0)
                D[b] = T.sqrt(C[b], dtype="float32")

def main(  # type: ignore
    placeholder: T.Buffer[(1, 3, 16, 16), "float32"],  # type: ignore
    T_layout_trans: T.Buffer[(1, 1, 16, 16, 3), "float32"],  # type: ignore
) -> None:  # type: ignore
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0, i1, i2, i3, i4 in T.grid(1, 1, 16, 16, 3):
        with T.block("T_layout_trans"):
            ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
            T.reads(placeholder[0, ax4, ax2, ax3])
            T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
            T_layout_trans[ax0, ax1, ax2, ax3, ax4] = placeholder[0, ax4, ax2, ax3]

def transformed_square_sum_square_root_factor_one_2(a: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 1, 65536):
        with T.block("C"):
            b = T.axis.S(16, i0)
            i = T.axis.R(256, T.floordiv(i1_i2_fused_inner, 256))
            j = T.axis.R(256, T.floormod(i1_i2_fused_inner, 256))
            with T.init():
                C[b] = 0.0
            C[b] = C[b] + (A[b, i, j] * A[b, i, j])
    for i0_1 in T.serial(0, 16):
        with T.block("D"):
            b_1 = T.axis.S(16, i0_1)
            D[b_1] = T.sqrt(C[b_1], dtype="float32")

def transformed_high_dim_opaque_access_with_source_strides(a: T.handle) -> None:
    A = T.match_buffer(a, (16, 32, 64), strides=[2576, 80, 1])
    for i, j, k in T.grid(16, 2, 4):
        with T.block():
            T.reads([])
            T.writes(A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16])
            T.evaluate(
                T.intrin_test(
                    A.data,
                    i * 2576 + j * 1280 + k * 16,
                    80,
                    1,
                    16,
                    16,
                    dtype="handle",
                )
            )

def matmul_loop_multiple_children(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    D = T.match_buffer(d, [128, 128])
    for k, i, j in T.grid(128, 128, 128):
        with T.block("C"):
            ck, ci, cj = T.axis.remap("RSS", [k, i, j])
            with T.init():
                C[ci, cj] = 0.0
            C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj]
        with T.block("D"):
            dk, di, dj = T.axis.remap("RSS", [k, i, j])
            with T.init():
                D[di, dj] = 0.0
            D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj]

def different_access_indices(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128, 128], dtype="float32")
    B = T.match_buffer(b, [128, 128], dtype="float32")
    for i, j in T.grid(128, 128):
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads([B[vi, vj], A[vi, vj, vk]])
                T.writes(
                    [
                        B[
                            T.min(vj, vi) : T.min(vj, vi) + (T.max(vj, vi) + 1 - T.min(vj, vi)),
                            T.min(vi, vj) : T.min(vi, vj) + (T.max(vi, vj) + 1 - T.min(vi, vj)),
                        ]
                    ]
                )
                with T.init():
                    B[vj, vi] = T.float32(0)
                B[vi, vj] = B[vi, vj] + A[vi, vj, vk]

def elementwise_non_single_branch(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128, 128))
    C = T.alloc_buffer((128, 128, 128))
    B = T.match_buffer(b, (128, 128, 128))
    for i, j in T.grid(128, 128):
        for k in T.serial(0, 128):
            with T.block([128, 128, 128], "C") as [vi, vj, vk]:
                T.bind(vi, i)
                T.bind(vj, j)
                T.bind(vk, k)
                C[vi, vj, vk] = A[vi, vj, vk] * 2.0
        for k in T.serial(0, 128):
            with T.block([128, 128, 128], "B") as [vi, vj, vk]:
                T.bind(vi, i)
                T.bind(vj, j)
                T.bind(vk, k)
                B[vi, vj, vk] = C[vi, vj, vk] * 2.0

def main(
    placeholder: T.Buffer[(12, 64, 64), "float32"], T_reshape: T.Buffer[(64, 768), "float32"]
) -> None:
    # 1536 * 32 == 64 * 768, so the fused loop covers T_reshape exactly once
    for i0_i1_fused_0, i0_i1_fused_1 in T.grid(1536, 32):
        with T.block("T_reshape_1"):
            ax0 = T.axis.spatial(64, (i0_i1_fused_0 * 32 + i0_i1_fused_1) // 768)
            ax1 = T.axis.spatial(768, (i0_i1_fused_0 * 32 + i0_i1_fused_1) % 768)
            T.reads(placeholder[ax1 % 768 // 64, (ax1 // 768 + ax0) % 64, ax1 % 64])
            T.writes(T_reshape[ax0, ax1])
            T_reshape[ax0, ax1] = placeholder[
                ((ax1 % 64 // 64 + (ax1 // 768 + ax0) % 64) // 64 + ax1 % 768 // 64) % 12,
                (ax1 % 64 // 64 + (ax1 // 768 + ax0) % 64) % 64,
                ax1 % 64 % 64,
            ]

def matmul(
    A: T.Buffer[(512, 512), "float32"],
    B: T.Buffer[(512, 512), "float32"],
    C: T.Buffer[(512, 512), "float32"],
) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0, i1, i2 in T.grid(512, 512, 512):
        with T.block("C"):
            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
            T.reads(C[i, j], A[i, k], B[k, j])
            T.writes(C[i, j])
            with T.init():
                C[i, j] = T.float32(0)
            C[i, j] = C[i, j] + A[i, k] * B[k, j]

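# A minimal sketch of compiling and running a PrimFunc like `matmul` above,
# assuming it carries the @T.prim_func decorator and an LLVM target is available.
import numpy as np
import tvm


def run_matmul(matmul_func):
    f = tvm.build(matmul_func, target="llvm")
    a = tvm.nd.array(np.random.rand(512, 512).astype("float32"))
    b = tvm.nd.array(np.random.rand(512, 512).astype("float32"))
    c = tvm.nd.array(np.zeros((512, 512), dtype="float32"))
    f(a, b, c)
    return c.numpy()
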
def access_of_padding_pattern() -> None:
    X = T.alloc_buffer([28, 28])
    X_pad = T.alloc_buffer([32, 32])
    Y = T.alloc_buffer([28, 28])
    for i, j in T.grid(32, 32):
        with T.block("padding"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([X[vi - 2, vj - 2]])
            T.writes([X_pad[vi, vj]])
            X_pad[vi, vj] = T.if_then_else(
                2 <= vi and vi < 30 and 2 <= vj and vj < 30,
                X[vi - 2, vj - 2],
                0.0,
                dtype="float32",
            )
        with T.block("padding_reverse"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([X_pad[vi, vj]])
            T.writes([Y[vi - 2, vj - 2]])
            if 2 <= vi and vi < 30 and 2 <= vj and vj < 30:
                Y[vi - 2, vj - 2] = X_pad[vi, vj]

def thread_bound_nested_block_after_cache_read(
    A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16,), "float32"]
) -> None:
    for i in T.thread_binding(16, thread="blockIdx.x"):
        with T.block("outer"):
            vi = T.axis.spatial(16, i)
            A_shared = T.alloc_buffer([1, 16], dtype="float32", scope="shared")
            for ax0, ax1 in T.grid(1, 16):
                with T.block("A_shared"):
                    v0 = T.axis.spatial(16, vi + ax0)
                    v1 = T.axis.spatial(16, ax1)
                    A_shared[v0, v1] = A[v0, v1]
            for j in T.thread_binding(16, thread="threadIdx.x"):
                with T.block("inner"):
                    vj = T.axis.reduce(16, j)
                    with T.init():
                        B[vi] = T.float32(0)
                    B[vi] = B[vi] + A_shared[vi, vj]

def wmma_store_desc(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(
        a, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope="wmma.accumulator"
    )
    C = T.match_buffer(c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=scope)
    with T.block("root"):
        T.reads(A[0:m_dim, 0:n_dim])
        T.writes(C[0:m_dim, 0:n_dim])
        for i, j in T.grid(m_dim, n_dim):
            with T.block("store"):
                vii, vjj = T.axis.remap("SS", [i, j])
                C[vii, vjj] = A[vii, vjj]

def wmma_load_desc(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=shared_scope)
    C = T.match_buffer(
        c, (m_dim, n_dim), dtype, align=128, offset_factor=16, scope=wmma_fragment_scope
    )
    with T.block("root"):
        T.reads(A[0:m_dim, 0:n_dim])
        T.writes(C[0:m_dim, 0:n_dim])
        for i, j in T.grid(m_dim, n_dim):
            with T.block("load"):
                vii, vjj = T.axis.remap("SS", [i, j])
                C[vii, vjj] = A[vii, vjj]

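# Descriptor PrimFuncs like wmma_load_desc/wmma_store_desc describe the
# semantics of a hardware intrinsic; each is normally paired with an *_impl
# PrimFunc that emits the actual WMMA builtin, and the pair is registered as a
# tensor intrinsic. A sketch, assuming matching impl functions exist and that
# the free variables (m_dim, n_dim, dtype, scope, ...) are bound in the
# enclosing scope:
#
#     tvm.tir.TensorIntrin.register("wmma_load", wmma_load_desc, wmma_load_impl)
#     tvm.tir.TensorIntrin.register("wmma_store", wmma_store_desc, wmma_store_impl)
#
# A schedule can then replace a matching loop nest with the intrinsic:
#
#     sch.tensorize(loop_rv, "wmma_load")
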
def main(a: T.handle, b: T.handle, c: T.handle) -> None:
    # function attr dict
    T.func_attr(
        {"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True}
    )
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    # body
    for x, y in T.grid(128, 128):
        C.data[x * 128 + y] = 0.0
        for k in T.serial(0, 128):
            C.data[x * 128 + y] = T.load("float32", C.data, x * 128 + y) + T.load(
                "float32", A.data, x * 128 + k
            ) * T.load("float32", B.data, y * 128 + k)

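# Note: T.load and raw C.data indexing are legacy flat-memory TIR (see the
# "from_legacy_te_schedule" attr above); in current TVMScript the same updates
# are written directly against the buffer, e.g. C[x, y] = C[x, y] + A[x, k] * B[y, k].
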
def transformed_rank0_buffer(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (8, 8))
    B = T.match_buffer(b, (8, 8))
    for i, j in T.grid(8, 8):
        with T.block():
            T.reads([])
            T.writes([A[i, j], B[i, j]])
            A[i, j] = 1
            T.evaluate(
                T.intrin_test(
                    B.data,
                    i * 8 + j,
                    0,
                    0,
                    0,
                    0,
                    dtype="handle",
                )
            )

def main(
    var_X: T.handle,
    var_W: T.handle,
    var_B: T.handle,
    var_bn_scale: T.handle,
    var_bn_offset: T.handle,
    var_compute: T.handle,
) -> None:
    X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32")
    W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32")
    B = T.match_buffer(var_B, [512, 1, 1], dtype="float32")
    bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32")
    bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32")
    compute = T.match_buffer(var_compute, [1, 512, 56, 56], dtype="float32")
    pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32")
    compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    bias_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    bn_mul = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    bn_add = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
    for i0, i1, i2, i3 in T.grid(1, 512, 58, 58):
        with T.block("pad_temp"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                i2_1 >= 1 and i2_1 < 57 and i3_1 >= 1 and i3_1 < 57,
                X[i0_1, i1_1, i2_1 - 1, i3_1 - 1],
                T.float32(0),
                dtype="float32",
            )
    for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 512, 56, 56, 512, 3, 3):
        with T.block("compute"):
            nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
            with T.init():
                compute_1[nn, ff, yy, xx] = T.float32(0)
            compute_1[nn, ff, yy, xx] = (
                compute_1[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * W[ff, rc, ry, rx]
            )
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("bias_add"):
            i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3])
            bias_add[i, j, k, l] = compute_1[i, j, k, l] + B[j, 0, 0]
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("bn_mul"):
            i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3])
            bn_mul[i, j, k, l] = bias_add[i, j, k, l] * bn_scale[j, 0, 0]
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("bn_add"):
            i, j, k, l = T.axis.remap("SSSS", [i0, i1, i2, i3])
            bn_add[i, j, k, l] = bn_mul[i, j, k, l] + bn_offset[j, 0, 0]
    for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
        with T.block("compute_1"):
            i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            compute[i0_2, i1_2, i2_2, i3_2] = T.max(bn_add[i0_2, i1_2, i2_2, i3_2], T.float32(0))

def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4):
        with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
            T.bind(vi, i0)
            T.bind(vj, i1)
            T.bind(vk, (((i2_outer * 32) + (i2_inner_outer * 4)) + i2_inner_inner))
            T.reads([C[vi, vj], A[vi, vk], B[vj, vk]])
            T.writes([C[vi, vj]])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])

def func_match_buffer(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]):
    with T.block("root"):
        s = T.var("int32")
        e = T.var("int32")
        # A0 should be remapped
        A0 = T.match_buffer(
            A[0:128, 0:128],
            shape=(128, 128),
            dtype="float32",
            # s and e should be remapped
            strides=[s, s],
            elem_offset=e,
        )
        for i, j in T.grid(128, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A0[vi, vj] * 2.0

def concat_func_3(
    placeholder: T.Buffer[(50176,), "int8"],
    placeholder_1: T.Buffer[(25088,), "int8"],
    placeholder_2: T.Buffer[(25088,), "int8"],
    T_concat: T.Buffer[(100352,), "int8"],
) -> None:
    T.preflattened_buffer(placeholder, (1, 64, 28, 28), "int8", data=placeholder.data)
    T.preflattened_buffer(placeholder_1, (1, 32, 28, 28), "int8", data=placeholder_1.data)
    T.preflattened_buffer(placeholder_2, (1, 32, 28, 28), "int8", data=placeholder_2.data)
    T.preflattened_buffer(T_concat, (1, 128, 28, 28), "int8", data=T_concat.data)
    for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
        for i2, i3 in T.grid(28, 28):
            if 96 <= i1:
                T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_2[i1 * 784 + i2 * 28 + i3 - 75264]
            if 64 <= i1 and i1 < 96:
                T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_1[i1 * 784 + i2 * 28 + i3 - 50176]
            if i1 < 64:
                T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3]

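# The pragma_loop_partition_hint annotation above asks the LoopPartition pass
# to split the i1 loop at the if-condition boundaries so each branch becomes an
# unconditional copy loop. A sketch of running the pass, assuming `mod` is an
# IRModule containing concat_func_3:
import tvm


def partition_concat(mod: tvm.IRModule) -> tvm.IRModule:
    with tvm.transform.PassContext(
        config={"tir.LoopPartition": {"partition_const_loop": True}}
    ):
        return tvm.tir.transform.LoopPartition()(mod)
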
def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(
        a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
    )
    B = T.match_buffer(
        b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
    )
    C = T.match_buffer(
        c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
    )
    with T.block("root"):
        T.reads(
            C[0:WARP_SIZE, 0:local_size_out],
            A[0:WARP_SIZE, 0:local_size],
            B[0:WARP_SIZE, 0:local_size],
        )
        T.writes(C[0:WARP_SIZE, 0:local_size_out])
        for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
            with T.block("C"):
                i, j, k = T.axis.remap("SSR", [i, j, k])
                b_row_ind, b_col_ind = maybe_swap(k, j)
                thread_id_C, local_id_C = index_map_C(i, j)
                thread_id_A, local_id_A = index_map_A(i, k)
                thread_id_B, local_id_B = index_map_B(b_row_ind, b_col_ind)
                T.reads(
                    C[thread_id_C, local_id_C],
                    A[thread_id_A, local_id_A],
                    B[thread_id_B, local_id_B],
                )
                T.writes(C[thread_id_C, local_id_C])
                C[thread_id_C, local_id_C] += maybe_cast(
                    A[thread_id_A, local_id_A]
                ) * maybe_cast(B[thread_id_B, local_id_B])

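# mma_sync_desc is written against free variables (WARP_SIZE, local_size,
# index_map_A/B/C, maybe_swap, maybe_cast, ...), so it is meant to be defined
# inside a generator that closes over them. A sketch of plausible bindings,
# assuming a 16x16x16 warp-level MMA shape; the exact values depend on the
# target intrinsic:
#
#     WARP_SIZE = 32
#     M_DIM = N_DIM = k_dim = 16
#     local_size = (M_DIM * k_dim) // WARP_SIZE       # fragment elems per lane
#     local_size_out = (M_DIM * N_DIM) // WARP_SIZE   # accumulator elems per lane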