from tvm.script import tir as T


@T.prim_func
def mma_sync_m16n16k16_desc_manual(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    with T.block("root"):
        T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
        T.writes(C[0:32, 0:8])
        for i, j, k in T.grid(16, 16, 16):
            with T.block("C"):
                i, j, k = T.axis.remap("SSR", [i, j, k])
                T.reads(
                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
                    A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
                    B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
                )
                T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
                C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = (
                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2]
                    + A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2]
                    * B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2]
                )
@T.prim_func
def factorized_after_reverse_compute_at(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [16, 16, 16], "float32")
    B = T.match_buffer(b, [16], "float32")
    B_rf_local = T.alloc_buffer([16, 16], "float32", scope="local")
    for j in T.thread_binding(0, 16, thread="blockIdx.x"):
        for i_o in T.thread_binding(0, 4, thread="threadIdx.x"):
            for i_i, k in T.grid(4, 16):
                with T.block([16, 16, T.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]:
                    T.bind(vi, i_o * 4 + i_i)
                    T.bind(vj, j)
                    T.bind(vk, k)
                    with T.init():
                        B_rf_local[vi, vj] = 0.0
                    B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk]
            for k in T.serial(0, 4):
                with T.block([16, T.reduce_axis(0, 16)], "B") as [vi, vk]:
                    T.bind(vi, j)
                    T.bind(vk, i_o * 4 + k)
                    with T.init():
                        B[vi] = 0.0
                    B[vi] = B[vi] + B_rf_local[vk, vi]
@T.prim_func
def gemm_dyn_shape(a: T.handle, b: T.handle, c: T.handle):
    N = T.var("int32")
    M = T.var("int32")
    K = T.var("int32")
    A = T.match_buffer(a, (N, K), "float32")
    B = T.match_buffer(b, (K, M), "float32")
    C = T.match_buffer(c, (N, M), "float32")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
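# A minimal scheduling sketch, not part of the original fixtures: it assumes
# `gemm_dyn_shape` is decorated with @T.prim_func (as above) and that `tvm`
# is importable, and shows how a fixture like this is typically driven
# through tir.Schedule. The split factor 32 is an arbitrary illustrative choice.
def _example_schedule_gemm_dyn_shape():
    import tvm

    sch = tvm.tir.Schedule(gemm_dyn_shape)
    block = sch.get_block("gemm")
    i, j, k = sch.get_loops(block)
    # With a dynamic extent N, the inferred outer extent of the split becomes
    # ceildiv(N, 32) and a predicate guards the tail iterations.
    sch.split(i, factors=[None, 32])
    return sch.mod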
@T.prim_func
def blockized_2_after_reverse_compute_at(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128], "float32")
    B = T.alloc_buffer([128, 128], "float32")
    C = T.match_buffer(c, [128, 128], "float32")
    for i_o, j_o in T.grid(8, 8):
        with T.block([8, 8], "B_outer") as [vio, vjo]:
            T.bind(vio, i_o)
            T.bind(vjo, j_o)
            T.reads([A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]])
            T.writes([B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]])
            for i_i, j_i in T.grid(16, 16):
                with T.block([128, 128], "B_inner") as [vi, vj]:
                    T.bind(vi, vio * 16 + i_i)
                    T.bind(vj, vjo * 16 + j_i)
                    B[vi, vj] = A[vi, vj] * 2.0
        for ax0, ax1 in T.grid(16, 16):
            with T.block([128, 128], "C") as [vi, vj]:
                T.bind(vi, i_o * 16 + ax0)
                T.bind(vj, j_o * 16 + ax1)
                T.reads([B[vi, vj]])
                T.writes([C[vi, vj]])
                C[vi, vj] = B[vi, vj] + 1.0
@T.prim_func
def unschedulable_func(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i in range(0, 16):
        with T.block():
            T.reads(A[i, 0:16])
            T.writes(C[i, 0:16])
            B = T.alloc_buffer((16, 16), "float32")
            for j in range(0, 16):
                T.evaluate(T.call_extern("dummy_extern_function", B.data, dtype="int32"))
                B[i, j] = A[i, j] + 1.0
            for j in range(0, 16):
                C[i, j] = B[i, j] * 2.0
@T.prim_func
def reduction_loop_only(
    A: T.Buffer[2, "float32"],
    B: T.Buffer[2, "float32"],
    C: T.Buffer[(), "float32"],
) -> None:
    for i0 in T.serial(2):
        with T.block("C"):
            k0 = T.axis.reduce(2, i0)
            T.reads(A[k0], B[k0])
            T.writes(C[()])
            with T.init():
                C[()] = T.float32(1.0)
            C[()] = T.min(C[()], A[k0] / B[k0])
@T.prim_func
def matmul_loop_multiple_children(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    D = T.match_buffer(d, [128, 128])
    for k, i, j in T.grid(128, 128, 128):
        with T.block([T.reduce_axis(0, 128), 128, 128], "C") as [ck, ci, cj]:
            T.bind(ck, k)
            T.bind(ci, i)
            T.bind(cj, j)
            with T.init():
                C[ci, cj] = 0.0
            C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj]
        with T.block([T.reduce_axis(0, 128), 128, 128], "D") as [dk, di, dj]:
            T.bind(dk, k)
            T.bind(di, i)
            T.bind(dj, j)
            with T.init():
                D[di, dj] = 0.0
            D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj]
@T.prim_func
def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
        with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]:
            T.bind(b, i0)
            T.bind(i, T.floordiv(i1_i2_fused_outer, 256))
            T.bind(j, T.floormod(i1_i2_fused_outer, 256))
            T.reads([C[b], A[b, i, j]])
            T.writes([C[b]])
            with T.init():
                C[b] = 0.0
            C[b] = C[b] + (A[b, i, j] * A[b, i, j])
    for i0_1 in T.serial(0, 16):
        with T.block([16], "D") as [b_1]:
            T.bind(b_1, i0_1)
            T.reads([C[b_1]])
            T.writes([D[b_1]])
            D[b_1] = T.sqrt(C[b_1], dtype="float32")
@T.prim_func
def opaque_access(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [16, 16], "float32")
    B = T.match_buffer(b, [16, 16], "float32")
    for i, j in T.grid(16, 16):
        with T.block("A"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([])
            T.writes([A[0:16, 0:16]])
            A[vi, vj] = 1
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([])
            T.writes([B[0:16, 0:16]])
            T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle"))
@T.prim_func
def elementwise_split_with_predicate(a: T.handle, b: T.handle) -> None:
    B = T.match_buffer(b, [128, 128, 128])
    A = T.match_buffer(a, [128, 128, 128])
    for i0, i1, i2, j0, j1, k0, k1 in T.grid(1000, 2, 3, 1, 129, 3, 43):
        with T.block("B"):
            T.where((i0 * 2 + i1) * 3 + i2 < 128 and j1 < 128 and k0 * 43 + k1 < 128)
            vi = T.axis.S(128, i0 * 6 + i1 * 3 + i2)
            vj = T.axis.S(128, j1)
            vk = T.axis.S(128, k0 * 43 + k1)
            T.reads([A[vi, vj, vk]])
            T.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
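# Hedged sketch of where the `T.where` predicates above come from: splitting a
# loop by factors whose product exceeds the loop extent makes tir.Schedule emit
# a guard. `_elementwise` is a hypothetical 128x128x128 elementwise PrimFunc
# with a block named "B"; it stands in for the pre-split input to this fixture.
def _example_split_with_predicate(_elementwise):
    import tvm

    sch = tvm.tir.Schedule(_elementwise)
    i, j, k = sch.get_loops(sch.get_block("B"))
    # 1000 * 2 * 3 = 6000 > 128, so the split loop is guarded by a predicate.
    sch.split(i, factors=[1000, 2, 3])
    return sch.mod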
@T.prim_func
def main(a: T.handle, d: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main"})
    A = T.match_buffer(a, [1024, 1024], dtype="float32")
    D = T.match_buffer(d, [1024, 1024], dtype="float32")
    # body
    # with T.block("root")
    B = T.alloc_buffer([1024, 1024], dtype="float32")
    for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16):
        with T.block("A"):
            vi = T.axis.S(1024, i0_0 * 64 + i0_1)
            vj = T.axis.S(1024, i1_0 * 16 + i1_1)
            T.reads([A[vi, vj]])
            T.writes([B[vi, vj]])
            B[vi, vj] = A[vi, vj] * T.float32(2)
    for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16):
        with T.block("C"):
            vi = T.axis.S(1024, i0_0 * 64 + i0_1)
            vj = T.axis.S(1024, i1_0 * 16 + i1_1)
            T.reads([B[vi, vj]])
            T.writes([D[vi, vj]])
            D[vi, vj] = (B[vi, vj] + T.float32(3)) * T.float32(5)
@T.prim_func
def element_wise_invalid_annotation(a: T.handle, c: T.handle) -> None:
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block("root"):
        T.reads([])
        T.writes([])
        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
        for i0 in T.serial(0, 128):
            for ax1 in T.serial(0, 128):
                with T.block("B"):
                    T.block_attr({"buffer_dim_align": [0]})
                    vi, vj = T.axis.remap("SS", [i0, ax1])
                    T.reads([A[vi, vj]])
                    T.writes([B[vi, vj]])
                    B[vi, vj] = A[vi, vj] * T.float32(2)
            for i1 in T.serial(0, 128):
                with T.block("C"):
                    vi_1, vj_1 = T.axis.remap("SS", [i0, i1])
                    T.reads([B[vi_1, vj_1]])
                    T.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = B[vi_1, vj_1] + T.float32(1)
@T.prim_func
def cuda_matmul_2(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
    A = T.match_buffer(a, [2048, 2048], "float32")
    B = T.match_buffer(b, [2048, 2048], "float32")
    C = T.match_buffer(c, [2048, 2048], "float32")
    A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    with T.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with T.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with T.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for k_0 in T.serial(0, 256):
                                for k_1 in T.unroll(0, 8):
                                    for i, j in T.grid(1, 4):
                                        with T.block([2048, 2048], "A_shared_local") as [v0, v1]:
                                            T.bind(v0, k_0 * 8 + k_1 + i)
                                            T.bind(v1, by * 64 + vy * 32 + ty * 4 + j)
                                            A_shared_local[v0, v1] = A_shared[v0, v1]
                                    for _, i, j in T.grid(1, 4, 4):
                                        with T.block(
                                            [2048, 2048, T.reduce_axis(0, 2048)], "C"
                                        ) as [vi, vj, vk]:
                                            T.bind(vi, by * 64 + vy * 32 + ty * 4 + i)
                                            T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j)
                                            T.bind(vk, k_0 * 8 + k_1)
                                            with T.init():
                                                C_local[vi, vj] = T.float32(0)
                                            C_local[vi, vj] = (
                                                C_local[vi, vj]
                                                + A_shared_local[vk, vi] * B_shared_local[vk, vj]
                                            )
                            for i, j in T.grid(4, 4):
                                with T.block([2048, 2048], "C_local") as [v0, v1]:
                                    T.bind(v0, by * 64 + vy * 32 + ty * 4 + i)
                                    T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[v0, v1] = C_local[v0, v1]
@T.prim_func
def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Free names (WARP_SIZE, local_size, local_size_out, in_dtype, out_dtype,
    # M_DIM, N_DIM, k_dim, maybe_swap, maybe_cast, index_map_*) are captured
    # from the enclosing intrin-generator scope, not defined in this file.
    A = T.match_buffer(a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp")
    B = T.match_buffer(b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp")
    C = T.match_buffer(c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp")
    with T.block("root"):
        T.reads(
            C[0:WARP_SIZE, 0:local_size_out],
            A[0:WARP_SIZE, 0:local_size],
            B[0:WARP_SIZE, 0:local_size],
        )
        T.writes(C[0:WARP_SIZE, 0:local_size_out])
        for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
            with T.block("C"):
                i, j, k = T.axis.remap("SSR", [i, j, k])
                b_row_ind, b_col_ind = maybe_swap(k, j)
                thread_id_C, local_id_C = index_map_C(i, j)
                thread_id_A, local_id_A = index_map_A(i, k)
                thread_id_B, local_id_B = index_map_B(b_row_ind, b_col_ind)
                T.reads(
                    C[thread_id_C, local_id_C],
                    A[thread_id_A, local_id_A],
                    B[thread_id_B, local_id_B],
                )
                T.writes(C[thread_id_C, local_id_C])
                C[thread_id_C, local_id_C] += maybe_cast(
                    A[thread_id_A, local_id_A]
                ) * maybe_cast(B[thread_id_B, local_id_B])
@T.prim_func
def spatial_tiled_pad_and_pooling(
    X: T.Buffer[(64, 112, 112), "int32"], Y: T.Buffer[(64, 56, 56), "int32"]
) -> None:
    for h_o, w_o in T.grid(14, 14):
        with T.block():
            X_cache = T.alloc_buffer([112, 112, 64], dtype="int32")
            for ax0, ax1, ax2 in T.grid(64, 9, 9):
                with T.block("cache"):
                    T.where(1 <= h_o * 8 + ax1 and 1 <= w_o * 8 + ax2)
                    T.reads(X[ax0, h_o * 8 - 1 + ax1, w_o * 8 - 1 + ax2])
                    T.writes(X_cache[h_o * 8 - 1 + ax1, w_o * 8 - 1 + ax2, ax0])
                    X_cache[h_o * 8 - 1 + ax1, w_o * 8 - 1 + ax2, ax0] = X[
                        ax0, h_o * 8 - 1 + ax1, w_o * 8 - 1 + ax2
                    ]
            for h_i, w_i, kh, kw, c in T.grid(4, 4, 3, 3, 64):
                with T.block("compute"):
                    T.reads(X_cache[(h_o * 4 + h_i) * 2 + kh - 1, (w_o * 4 + w_i) * 2 + kw - 1, c])
                    T.writes(Y[h_o * 4 + h_i, w_o * 4 + w_i, c])
                    if kh == 0 and kw == 0:
                        Y[h_o * 4 + h_i, w_o * 4 + w_i, c] = 0
                    Y[h_o * 4 + h_i, w_o * 4 + w_i, c] = T.max(
                        Y[h_o * 4 + h_i, w_o * 4 + w_i, c],
                        T.if_then_else(
                            T.likely(1 <= (h_o * 4 + h_i) * 2 + kh, dtype="bool")
                            and T.likely((h_o * 4 + h_i) * 2 + kh < 113, dtype="bool")
                            and T.likely(1 <= (w_o * 4 + w_i) * 2 + kw, dtype="bool")
                            and T.likely((w_o * 4 + w_i) * 2 + kw < 113, dtype="bool"),
                            X_cache[(h_o * 4 + h_i) * 2 + kh - 1, (w_o * 4 + w_i) * 2 + kw - 1, c],
                            0,
                            dtype="int32",
                        ),
                    )
@T.prim_func
def multiple_reduction_blocks(a: T.handle, f: T.handle) -> None:
    A = T.match_buffer(a, (16, 16, 16))
    C = T.alloc_buffer((16, 16))
    D = T.alloc_buffer((16, 16))
    E = T.alloc_buffer((16, 16))
    F = T.match_buffer(f, (16, 16))
    for i in T.serial(0, 16):
        for j1 in T.serial(0, 16):
            for k1o, k1i in T.grid(4, 4):
                with T.block("C"):
                    ci, cj = T.axis.remap("SS", [i, j1])
                    ck = T.axis.R(16, k1o * 4 + k1i)
                    with T.init():
                        C[ci, cj] = 0.0
                    C[ci, cj] = C[ci, cj] + A[ci, cj, ck]
            for k2o, k2i in T.grid(4, 4):
                with T.block("D"):
                    di, dj = T.axis.remap("SS", [i, j1])
                    dk = T.axis.R(16, k2o * 4 + k2i)
                    with T.init():
                        D[di, dj] = 0.0
                    D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj]
        for j2 in T.serial(0, 16):
            for k3o, k3i in T.grid(4, 4):
                with T.block("E"):
                    ei, ej = T.axis.remap("SS", [i, j2])
                    ek = T.axis.R(16, k3o * 4 + k3i)
                    with T.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej]
            for k4o, k4i in T.grid(4, 4):
                with T.block("F"):
                    fi, fj = T.axis.remap("SS", [i, j2])
                    fk = T.axis.R(16, k4o * 4 + k4i)
                    with T.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj]
@T.prim_func
def get_valid_counts(
    data: T.handle,
    valid_count: T.handle,
    out: T.handle,
    out_indices: T.handle,
    score_threshold: T.float32,
    id_index: T.int32,
    score_index: T.int32,
) -> None:
    data_buf = T.match_buffer(data, (1, 2500, 6), "float32")
    valid_count_buf = T.match_buffer(valid_count, (1,), "int32")
    out_buf = T.match_buffer(out, (1, 2500, 6), "float32")
    out_indices_buf = T.match_buffer(out_indices, (1, 2500), "int32")
    with T.block("init"):
        vi = T.axis.S(1, 0)
        valid_count_buf[vi] = T.int32(0)
        for j in range(2500):
            with T.block("update"):
                vj = T.axis.S(2500, j)
                # The last axis has extent 6, so the accessed region is 0:6.
                T.reads([data_buf[vi, vj, 0:6]])
                T.writes([valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj, 0:6]])
                if (data_buf[vi, vj, score_index] > score_threshold) and (
                    (id_index < 0) or (data_buf[vi, vj, id_index] >= T.float32(0))
                ):
                    for k in T.serial(0, 6):
                        out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k]
                    out_indices_buf[vi, valid_count_buf[vi]] = vj
                    valid_count_buf[vi] = valid_count_buf[vi] + 1
                if vj >= valid_count_buf[vi]:
                    for k in T.serial(0, 6):
                        out_buf[vi, vj, k] = T.float32(-1)
                    out_indices_buf[vi, vj] = T.int32(-1)
@T.prim_func
def cache_write_multi_consumer() -> None:
    A = T.alloc_buffer((128,))
    B = T.alloc_buffer((128,))
    C = T.alloc_buffer((128,))
    A_global = T.alloc_buffer((128,))
    for i in T.grid(8):
        for j in T.grid(16):
            with T.block("A_global"):
                vi = T.axis.S(128, i * 16 + j)
                A_global[vi] = 1.0
        for j in T.grid(16):
            with T.block("A"):
                vi = T.axis.S(128, i * 16 + j)
                A[vi] = A_global[vi]
        for j in T.grid(16):
            with T.block("B"):
                vi = T.axis.S(128, i * 16 + j)
                B[vi] = A[vi] + 1.0
    for i in T.grid(128):
        with T.block("C"):
            vi = T.axis.S(128, i)
            C[vi] = A[vi]
@T.prim_func
def cache_read_shape_int64(var_A: T.handle, var_C: T.handle) -> None:
    A = T.match_buffer(var_A, (T.int64(128), T.int64(128)), dtype="float32")
    C = T.match_buffer(var_C, (T.int64(128), T.int64(128)), dtype="float32")
    B = T.alloc_buffer([T.int64(128), T.int64(128)], dtype="float32")
    A_global = T.alloc_buffer([T.int64(128), T.int64(128)], dtype="float32")
    for ax0, ax1 in T.grid(T.int64(128), T.int64(128)):
        with T.block("A_global"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v0, v1])
            T.writes(A_global[v0, v1])
            A_global[v0, v1] = A[v0, v1]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_global[vi, vj])
            T.writes(B[vi, vj])
            B[vi, vj] = A_global[vi, vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi, vj])
            T.writes(C[vi, vj])
            C[vi, vj] = B[vi, vj] + T.float32(1)
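# A sketch, labeled as an assumption rather than taken from the original file,
# of how a staging block like `A_global` above is typically produced: cache_read
# inserts a copy stage for a read buffer of the consumer block. `mod` is a
# hypothetical pre-transform module containing a block named "B" that reads A.
def _example_cache_read(mod):
    import tvm

    sch = tvm.tir.Schedule(mod)
    block_b = sch.get_block("B")
    # Stage the buffer at read index 0 of block "B" into "global" scope.
    sch.cache_read(block_b, 0, "global")
    return sch.mod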
@T.prim_func
def decomposed_gemm(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
):
    local = T.alloc_buffer((16, 16), "float32")
    for i, j in T.grid(4, 4):
        for ii, jj in T.grid(4, 4):
            with T.block("init"):
                vi = T.axis.S(16, i * 4 + ii)
                vj = T.axis.S(16, j * 4 + jj)
                local[vi, vj] = 0
        for k, ii, jj in T.grid(16, 4, 4):
            with T.block("update"):
                vi = T.axis.S(16, i * 4 + ii)
                vj = T.axis.S(16, j * 4 + jj)
                vk = T.axis.R(16, k)
                local[vi, vj] += A[vi, vk] * B[vj, vk]
        for ii, jj in T.grid(4, 4):
            with T.block("C"):
                vi = T.axis.S(16, i * 4 + ii)
                vj = T.axis.S(16, j * 4 + jj)
                C[vi, vj] = local[vi, vj]
@T.prim_func
def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4):
        with T.block("update"):
            vi, vj = T.axis.remap("SS", [i0, i1])
            vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner)
            T.reads([A[vi, vk], B[vj, vk]])
            T.writes([C[vi, vj]])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
@T.prim_func
def buffer_load_store(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16, 16))
    C = T.match_buffer(c, (16, 16))
    for i, j, k in T.grid(4, 16, 8):
        with T.block():
            T.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2])
            T.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2])
            sub_A = T.match_buffer(
                A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2], (4, 1, 2), offset_factor=1
            )
            sub_C = T.match_buffer(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2], (4, 2), offset_factor=1)
            for ii, kk in T.grid(4, 2):
                sub_A[ii, 0, kk] += sub_C[ii, kk]
@T.prim_func
def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i0 in range(0, 4):
        with T.block():
            T.reads(A[i0 * 4 : i0 * 4 + 4, 0:16])
            T.writes(C[i0 * 4 : i0 * 4 + 4, 0:16])
            B = T.alloc_buffer([4, 16], "float32", strides=[17, 1], scope="global")
            for i1 in range(0, 4):
                for j in range(0, 16):
                    with T.block():
                        T.reads(A[i0 * 4 + i1, j])
                        T.writes(B[i1, j])
                        B[i1, j] = A[i0 * 4 + i1, j] + 1.0
            for i1 in range(0, 4):
                for j in range(0, 16):
                    with T.block():
                        T.reads(B[i1, j])
                        T.writes(C[i0 * 4 + i1, j])
                        C[i0 * 4 + i1, j] = B[i1, j] * 2.0
@T.prim_func
def tir_matmul(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
) -> None:
    T.func_attr({"layout_free_buffers": [1]})
    for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
        with T.block("matmul"):
            vi = T.axis.S(16, i0 * 4 + i1)
            vj = T.axis.S(16, j)
            vk = T.axis.R(16, k0 * 4 + k1)
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
@T.prim_func
def main(a: T.handle, b: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main"})
    A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32")
    B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32")
    # body
    with T.block("root"):
        T.block_attr({"meta_schedule.parallel": 128, "meta_schedule.vectorize": 32})
        for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32):
            with T.block("move"):
                vi = T.axis.spatial(1024, i0 * 16 + i1 * 4 + i2)
                vj = T.axis.spatial(1024, j0 * 32 + j1 * 8 + j2)
                vk = T.axis.spatial(1024, k0 * 32 + k1)
                T.where(
                    (i0 * 4 + i1) * 4 + i2 < 1024
                    and (j0 * 4 + j1) * 8 + j2 < 1024
                    and k0 * 32 + k1 < 1024
                )
                T.reads([A[vi, vj, vk]])
                T.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk]
@T.prim_func
def multiple_bufferstore(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    C = T.alloc_buffer([], dtype="float32")
    for i in T.serial(0, 128):
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B"):
                vi, vk = T.axis.remap("SR", [i, k])
                T.reads([A[vi, vk], B[vi], C[()]])
                T.writes([B[vi], C[()]])
                with T.init():
                    B[vi] = T.float32(0)
                C[()] = A[vi, vk]
                B[vi] = B[vi] + C[()]
@T.prim_func
def with_block_predicate(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 120], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    for i, ko in T.grid(128, 4):
        for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("B"):
                vi = T.axis.spatial(128, i)
                vk = T.axis.reduce(120, ko * 32 + ki)
                T.where(ko * 32 + ki < 120)
                T.reads([B[vi], A[vi, vk]])
                T.writes([B[vi]])
                with T.init():
                    B[vi] = T.float32(0)
                B[vi] = B[vi] + A[vi, vk]
@T.prim_func
def gemm() -> None:
    A = T.alloc_buffer([16, 16], "float32")
    B = T.alloc_buffer([16, 16], "float32")
    C = T.alloc_buffer([16, 16], "float32")
    for i, j, k, ii, jj in T.grid(4, 4, 16, 4, 4):
        with T.block("update"):
            vi = T.axis.S(16, i * 4 + ii)
            vj = T.axis.S(16, j * 4 + jj)
            vk = T.axis.R(16, k)
            T.reads(A[vi, vk], B[vj, vk])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = 0
            C[vi, vj] += A[vi, vk] * B[vj, vk]
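# Hedged sketch relating `gemm` to `decomposed_gemm` above: decomposing the
# reduction at the `k` loop hoists the zero-initialization into its own block,
# which is what the separate "init" block in `decomposed_gemm` reflects.
# Assumes `gemm` is wrapped as a PrimFunc, as above.
def _example_decompose_reduction():
    import tvm

    sch = tvm.tir.Schedule(gemm)
    update = sch.get_block("update")
    i, j, k, ii, jj = sch.get_loops(update)
    # The update stays under loop `k`; the init is placed above it under i, j.
    sch.decompose_reduction(update, k)
    return sch.mod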
@T.prim_func
def expected_bufferslice_indices(data: T.handle, index: T.handle) -> None:
    index_buf = T.match_buffer(index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    data_buf = T.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1)
    with T.block("root"):
        T.reads([])
        T.writes([])
        out_buf = T.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1)
        for i0, i1 in T.grid(16, 16):
            with T.block():
                vi, vj = T.axis.remap("SS", [i0, i1])
                T.reads([data_buf[vi, index_buf[0]], index_buf[0]])
                T.writes([out_buf[vi, vj]])
                out_buf[vi, vj] = data_buf[vi, index_buf[0]]
@T.prim_func
def two_bound_loops(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    for i in T.serial(0, 128):
        for ko in T.thread_binding(0, 4, thread="threadIdx.x"):
            for ki in T.thread_binding(0, 32, thread="threadIdx.y"):
                with T.block("B"):
                    vi = T.axis.spatial(128, i)
                    vk = T.axis.reduce(128, ko * 32 + ki)
                    T.reads([B[vi], A[vi, vk]])
                    T.writes([B[vi]])
                    with T.init():
                        B[vi] = T.float32(0)
                    B[vi] = B[vi] + A[vi, vk]
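# Closing sketch, an assumption rather than part of the fixtures: nested thread
# bindings like the ones in `two_bound_loops` are usually created by splitting
# the reduction loop and binding both halves. `mod` is a hypothetical
# pre-transform module with a reduction block "B" over loops (i, k).
def _example_bind_reduction(mod):
    import tvm

    sch = tvm.tir.Schedule(mod)
    i, k = sch.get_loops(sch.get_block("B"))
    ko, ki = sch.split(k, factors=[4, 32])
    sch.bind(ko, "threadIdx.x")
    sch.bind(ki, "threadIdx.y")
    return sch.mod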