def war_dependency(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Fixture with two blocks where "C" reads B[vi, vj] and "B" then writes
    # the same element of B.
    # NOTE(review): nesting was reconstructed from a whitespace-collapsed
    # source; both blocks are assumed to sit under the same loop nest.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, 128))
    C = tir.match_buffer(c, (128, 128))
    for i, j in tir.grid(128, 128):
        # The reader of B comes first in program order ...
        with tir.block([128, 128], "C") as [vi, vj]:
            C[vi, vj] = B[vi, vj] + 1.0
        # ... and the writer of B follows it.
        with tir.block([128, 128], "B") as [vi, vj]:
            B[vi, vj] = A[vi, vj] * 2.0
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # 128x128 matrix product written as a separate "init" block followed by
    # an "update" block that reduces over k.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    for i, j in tir.grid(128, 128):
        # Zero the accumulator once per (i, j).
        with tir.block([128, 128], "init") as [vi, vj]:
            C[vi, vj] = tir.float32(0)
        for k in range(0, 128):
            # The third block axis is a reduction axis of extent 128.
            with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
                # B is indexed as [vj, vk], i.e. by (column, reduction index).
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
def compacted_predicate_func(a: ty.handle, c: ty.handle) -> None:
    # Element-wise A + 1.0 over 32-element buffers, driven by a 5 x 7 loop
    # nest whose flattened index i * 7 + j is guarded by a predicate.
    A = tir.match_buffer(a, (32), "float32")
    C = tir.match_buffer(c, (32), "float32")
    for i, j in tir.grid(5, 7):
        # Block without block vars; it indexes the buffers with the loop vars.
        with tir.block([]) as []:
            tir.reads(A[i * 7 + j])
            tir.writes(C[i * 7 + j])
            # 5 * 7 = 35 iterations; the predicate masks indices >= 32.
            tir.where(i * 7 + j < 32)
            C[i * 7 + j] = A[i * 7 + j] + 1.0
def elementwise_predicate(a: ty.handle, c: ty.handle) -> None:
    # Producer block "B" (A * 2.0 into an allocated buffer) followed by a
    # consumer block "C" whose predicate itself reads B.
    # NOTE(review): nesting was reconstructed from a whitespace-collapsed
    # source; block "B" is assumed to be at function level with no explicit
    # enclosing loops.
    A = tir.match_buffer(a, (128, 128))
    B = tir.alloc_buffer((128, 128))
    C = tir.match_buffer(c, (128, 128))
    with tir.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    for i, j in tir.grid(128, 128):
        with tir.block([128, 128], "C") as [vi, vj]:
            # The predicate indexes B with the loop vars (i, j), not with the
            # block vars (vi, vj).
            tir.where(B[i, j] < 10.0)
            C[vi, vj] = B[vi, vj] + 1.0
def elementwise_split_case0(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 over a 128^3 domain driven by seven loops:
    # i -> (i1, i2, i3), j -> (j1, j2), k -> (k1, k2).
    A = tir.match_buffer(a, [128, 128, 128])
    B = tir.match_buffer(b, [128, 128, 128])
    for i1, i2, i3, j1, j2, k1, k2 in tir.grid(2, 1, 64, 4, 32, 16, 8):
        with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
            # i2 has extent 1 and does not appear in the binding of vi.
            tir.bind(vi, ((i1 * 64) + i3))
            tir.bind(vj, ((j1 * 32) + j2))
            tir.bind(vk, ((k1 * 8) + k2))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def rowsum_transformed(a: ty.handle, b: ty.handle) -> None:
    # Row sum B[vi] = sum over vk of A[vi, vk], with the block vars rebuilt
    # from three loops through floordiv / floormod of a fused loop var.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))
    for io, ii_ko_fused, ki in tir.grid(32, 128, 4):
        with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
            # The quotient of the fused var by 32 contributes to vi ...
            tir.bind(vi, io * 4 + tir.floordiv(ii_ko_fused, 32))
            # ... and its remainder modulo 32 contributes to vk.
            tir.bind(vk, tir.floormod(ii_ko_fused, 32) * 4 + ki)
            with tir.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
def rowsum_not_quasi_affine(a: ty.handle, b: ty.handle) -> None:
    # Row-sum style reduction whose reduction var is bound to an expression
    # that is quadratic in the loop var k.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))
    for i, k in tir.grid(128, 16):
        with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
            tir.bind(vi, i)
            # k * k makes this binding non-linear in k.
            tir.bind(vk, tir.floordiv(k * k, 2))
            with tir.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
def elementwise_with_seq(a: ty.handle, b: ty.handle) -> None:
    # Two chained element-wise stages (A -> C -> B, each * 2.0) expressed as
    # two sequential k loops.
    # NOTE(review): nesting was reconstructed from a whitespace-collapsed
    # source; both k loops are assumed to sit under the same (i, j) nest.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    C = tir.alloc_buffer((128, 128, 128))
    for i, j in tir.grid(128, 128):
        # First k loop: produce the intermediate buffer C.
        for k in tir.serial(0, 128):
            with tir.block([128, 128, 128], "C") as [vi, vj, vk]:
                C[vi, vj, vk] = A[vi, vj, vk] * 2.0
        # Second k loop: consume C.
        for k in tir.serial(0, 128):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                B[vi, vj, vk] = C[vi, vj, vk] * 2.0
def buffer_shape_mismatch(a: ty.handle) -> None:
    # Fixture that matches a 4-element slice of A against a sub-buffer
    # declared with shape (5); the inline comment below marks the mismatch
    # as the intended error.
    A = tir.match_buffer(a, (8, 8))
    for i, j in tir.grid(8, 2):
        with tir.block([]):
            tir.reads([])
            tir.writes([A[i, j * 4:j * 4 + 4]])
            sub_A = tir.match_buffer(
                A[i, j * 4:j * 4 + 4], (5))  # error: shape mismatched between 4 and 5
            for jj in range(0, 4):
                sub_A[i, j * 4 + jj] = 1
def tiled(a: ty.handle, c: ty.handle) -> None:
    # Block "B" runs under an 8 x 8 x 16 x 16 loop nest; block "C" consumes B.
    # NOTE(review): nesting was reconstructed from a whitespace-collapsed
    # source; block "C" is assumed to be at function level, outside the
    # tiled loops.
    A = tir.match_buffer(a, [128, 128], "float32")
    B = tir.alloc_buffer([128, 128], "float32")
    C = tir.match_buffer(c, [128, 128], "float32")
    for i_0, j_0, i_1, j_1 in tir.grid(8, 8, 16, 16):
        with tir.block([128, 128], "B") as [vi, vj]:
            # Each block var is rebuilt from an outer and an inner loop var.
            tir.bind(vi, i_0 * 16 + i_1)
            tir.bind(vj, j_0 * 16 + j_1)
            B[vi, vj] = A[vi, vj] * 2.0
    with tir.block([128, 128], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
def elementwise_affine_producer(a: ty.handle, c: ty.handle) -> None:
    # Producer block "B" whose block vars are bound to expressions over four
    # loop vars (including // and %), followed by consumer block "C".
    # NOTE(review): nesting was reconstructed from a whitespace-collapsed
    # source; block "C" is assumed to be at function level.
    A = tir.match_buffer(a, (128, 128), "float32")
    C = tir.match_buffer(c, (128, 128), "float32")
    B = tir.alloc_buffer((128, 128), "float32")
    for i, j, k, l in tir.grid(16, 2, 32, 16):
        with tir.block([128, 128], "B") as [vi, vj]:
            # k is divided between the two axes: k // 8 goes to vi and
            # k % 8 goes to vj.
            tir.bind(vi, i * 8 + j * 4 + k // 8)
            tir.bind(vj, k % 8 * 16 + l)
            B[vi, vj] = A[vi, vj] * 2.0
    with tir.block([128, 128], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
def blockized_2(a: ty.handle, c: ty.handle) -> None:
    # Stage B is an 8 x 8 outer block ("B_outer") that wraps a 16 x 16 inner
    # block ("B_inner"); stage C is a single block under 4 x 4 x 32 x 32 loops.
    A = tir.match_buffer(a, [128, 128], "float32")
    B = tir.alloc_buffer([128, 128], "float32")
    C = tir.match_buffer(c, [128, 128], "float32")
    for i_o, j_o in tir.grid(8, 8):
        with tir.block([8, 8], "B_outer") as [vio, vjo]:
            tir.bind(vio, i_o)
            tir.bind(vjo, j_o)
            # The outer block declares the 16 x 16 regions it reads and writes.
            tir.reads([A[vio * 16:vio * 16 + 16, vjo * 16:vjo * 16 + 16, ]])
            tir.writes([B[vio * 16:vio * 16 + 16, vjo * 16:vjo * 16 + 16]])
            for i_i, j_i in tir.grid(16, 16):
                with tir.block([128, 128], "B_inner") as [vi, vj]:
                    # Inner bindings use the outer block vars, not i_o / j_o.
                    tir.bind(vi, vio * 16 + i_i)
                    tir.bind(vj, vjo * 16 + j_i)
                    B[vi, vj] = A[vi, vj] * 2.0
    # The consumer uses a different split factor (32) than the producer (16).
    for i_o, j_o, i_i, j_i in tir.grid(4, 4, 32, 32):
        with tir.block([128, 128], "C") as [vi, vj]:
            tir.bind(vi, i_o * 32 + i_i)
            tir.bind(vj, j_o * 32 + j_i)
            C[vi, vj] = B[vi, vj] + 1.0
def elementwise_reordered_with_predicate(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 over a 128^4 domain with the loops in the
    # order (l, j, k, i) and a predicate on the flattened index.
    A = tir.match_buffer(a, (128, 128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128, 128))
    for l, j, k, i in tir.grid(128, 128, 128, 128):
        with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
            # 2097152 = 128^3 and 16384 = 128^2, so the left-hand side is the
            # row-major flattening of (i, j, k, l).
            tir.where(i * 2097152 + j * 16384 + k * 128 + l < 100)
            tir.bind(vi, i)
            tir.bind(vj, j)
            tir.bind(vk, k)
            tir.bind(vl, l)
            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
def elementwise_with_wrong_block_var_type(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 where the third block axis is declared with
    # tir.scan_axis instead of a plain extent.
    # NOTE(review): judging by the function name this is a negative fixture;
    # confirm against the test that uses it.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j, k in tir.grid(128, 128, 128):
        with tir.block([128, 128, tir.scan_axis(0, 128)], "B") as [vi, vj, vk]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            tir.bind(vk, k)
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def func() -> None:
    # Fixture with an outer block that declares its read/write regions by
    # hand and contains a plain loop nest, a nested 2 x 2 block and an
    # evaluate of D.data.
    # NOTE(review): nesting was reconstructed from a whitespace-collapsed
    # source; tir.evaluate(D.data) is assumed to be the last statement of the
    # outer block.
    A = tir.alloc_buffer((128, 128), "float32")
    B = tir.alloc_buffer((128, 128), "float32")
    C = tir.alloc_buffer((128, 128), "float32")
    D = tir.alloc_buffer((128, 128), "float32")
    with tir.block([]):
        # The read/write regions are declared manually so that the block
        # access region detector is not triggered for this block.
        tir.reads([B[0, 0], C[0:16, 0:16], A[4:12, 4:12]])
        tir.writes([A[0:12, 0:12]])
        for i, j in tir.grid(8, 8):
            A[i, j] = B[0, 0] + C[0, 0]
        with tir.block([2, 2]) as [vi, vj]:
            # Each (vi, vj) addresses one 4 x 4 tile of A offset by (4, 4).
            tir.reads([
                A[vi * 4 + 4:vi * 4 + 8, vj * 4 + 4:vj * 4 + 8], C[12:16, 12:16]
            ])
            tir.writes([A[vi * 4 + 4:vi * 4 + 8, vj * 4 + 4:vj * 4 + 8]])
            for i, j in tir.grid(4, 4):
                A[vi * 4 + 4 + i, vj * 4 + 4 + j] += C[i + 12, j + 12]
        # D is only touched through its data handle, not through indexing.
        tir.evaluate(D.data)
def elementwise_with_starting_point(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 whose innermost loop is declared as
    # tir.serial(10, 128) rather than starting from 0.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        # First argument is 10 here, unlike the other k loops in this file.
        for k in tir.serial(10, 128):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def elementwise_with_thread_binding(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 whose innermost loop is a thread-binding loop
    # on "threadIdx.x".
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        for k in tir.thread_binding(0, 128, thread="threadIdx.x"):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def elementwise_with_anno(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 whose innermost loop carries an annotation
    # dictionary.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        # The annotation key/value are runtime data; keep them verbatim.
        for k in tir.serial(0, 128, annotations={"useless_annotation": True}):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def blockized_after_compute_at(a: ty.handle, c: ty.handle) -> None:
    # Under each (i0_0, i1_0) iteration a 16 x 16 tile of B is produced and
    # then consumed by an outer block "C_outer" wrapping "C_inner".
    A = tir.match_buffer(a, [128, 128], "float32")
    B = tir.alloc_buffer([128, 128], "float32")
    C = tir.match_buffer(c, [128, 128], "float32")
    for i0_0, i1_0 in tir.grid(8, 8):
        # Producer: one 16 x 16 tile of B per outer iteration.
        for ax0, ax1 in tir.grid(16, 16):
            with tir.block([128, 128], "B") as [vi, vj]:
                tir.bind(vi, i0_0 * 16 + ax0)
                tir.bind(vj, i1_0 * 16 + ax1)
                B[vi, vj] = A[vi, vj] * 2.0
        # Consumer: declares the matching 16 x 16 regions of B and C.
        with tir.block([8, 8], "C_outer") as [vi_o, vj_o]:
            tir.bind(vi_o, i0_0)
            tir.bind(vj_o, i1_0)
            tir.reads(
                [B[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16, ]])
            tir.writes([C[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16]])
            for i0_1, i1_1 in tir.grid(16, 16):
                with tir.block([128, 128], "C_inner") as [vi, vj]:
                    # Inner bindings use the outer block vars vi_o / vj_o.
                    tir.bind(vi, vi_o * 16 + i0_1)
                    tir.bind(vj, vj_o * 16 + i1_1)
                    C[vi, vj] = B[vi, vj] + 1.0
def elementwise_split_case1(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 over a 128^3 domain where each axis is driven
    # by three loops with extents (2, 1, 64).
    A = tir.match_buffer(a, [128, 128, 128])
    B = tir.match_buffer(b, [128, 128, 128])
    for i1, i2, i3, j1, j2, j3, k1, k2, k3 in tir.grid(2, 1, 64, 2, 1, 64, 2, 1, 64):
        with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
            # The extent-1 loops (i2, j2, k2) do not appear in the bindings.
            tir.bind(vi, i1 * 64 + i3)
            tir.bind(vj, j1 * 64 + j3)
            tir.bind(vk, k1 * 64 + k3)
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def matmul_rfactor(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Matrix product in two stages: "update_rf" accumulates partial sums into
    # C_rf[4, 128, 128], then "update" reduces C_rf over its first axis into C.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    C_rf = tir.alloc_buffer([4, 128, 128])
    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(
            128, 128, 4, 8, 4):
        # vi2_inner_inner is a plain (non-reduce) axis in this block; the two
        # remaining pieces of the reduction index are reduce axes.
        with tir.block(
                [4, 128, 128, tir.reduce_axis(0, 4), tir.reduce_axis(0, 8)],
                "update_rf") as [
                    vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer
                ]:
            tir.bind(vi2_inner_inner, i2_inner_inner)
            tir.bind(vi, i0)
            tir.bind(vj, i1)
            tir.bind(vi2_outer, i2_outer)
            tir.bind(vi2_inner_outer, i2_inner_outer)
            with tir.init():
                C_rf[vi2_inner_inner, vi, vj] = 0.0
            # The reduction index is rebuilt as outer * 32 + mid * 4 + inner.
            C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + (A[vi, (
                ((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)] * B[vj, (
                    ((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)])
    for i0_1, i1_1, i2_inner_inner_1 in tir.grid(128, 128, 4):
        # Here the axis of extent 4 is the reduce axis.
        with tir.block([tir.reduce_axis(0, 4), 128, 128], "update") as [
                vi2_inner_inner_1,
                vi_1,
                vj_1,
        ]:
            tir.bind(vi2_inner_inner_1, i2_inner_inner_1)
            tir.bind(vi_1, i0_1)
            tir.bind(vj_1, i1_1)
            with tir.init():
                C[vi_1, vj_1] = 0.0
            C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1]
def cache_write_multi_consumer() -> None:
    # A_global is filled with 1.0 and copied into A; A is then read by two
    # consumers: block "B" (inside the outer i loop) and block "C" (in its
    # own loop).
    A = tir.alloc_buffer((128))
    B = tir.alloc_buffer((128))
    C = tir.alloc_buffer((128))
    A_global = tir.alloc_buffer((128))
    for i in tir.grid(8):
        # All three inner loops cover the same 16 elements i * 16 + j.
        for j in tir.grid(16):
            with tir.block([128], "A_global") as [vi]:
                tir.bind(vi, i * 16 + j)
                A_global[vi] = 1.0
        for j in tir.grid(16):
            with tir.block([128], "A") as [vi]:
                tir.bind(vi, i * 16 + j)
                A[vi] = A_global[vi]
        for j in tir.grid(16):
            with tir.block([128], "B") as [vi]:
                tir.bind(vi, i * 16 + j)
                B[vi] = A[vi] + 1.0
    # Second consumer of A, over the full 128-element range.
    for i in tir.grid(128):
        with tir.block([128], "C") as [vi]:
            C[vi] = A[vi]
def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # 128x128 matrix product written without blocks, using flat-index
    # tir.load / store on the buffers' data handles.
    # function attr dict
    tir.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    # body
    for x, y in tir.grid(128, 128):
        # Flat index x * 128 + y addresses element (x, y) of C.
        C.data[x * 128 + y] = 0.0
        for k in tir.serial(0, 128):
            # B is read at y * 128 + k, i.e. element (y, k).
            C.data[x * 128 + y] = tir.load("float32", C.data, x * 128 + y) + tir.load(
                "float32", A.data, x * 128 + k
            ) * tir.load("float32", B.data, y * 128 + k)
def elementwise_with_loops_not_same_scope(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 where the (i, j) loops sit outside block "A"
    # while the k loop sits inside it, around the nested block "B".
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        with tir.block([128, 128], "A") as [vi, vj]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            # This loop is in the scope of block "A", unlike i and j.
            for k in tir.serial(0, 128):
                with tir.block([128], "B") as [vk]:
                    tir.bind(vk, k)
                    # vi and vj come from the enclosing block "A".
                    tir.reads([A[vi, vj, vk]])
                    tir.writes([B[vi, vj, vk]])
                    B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def factorized(a: ty.handle, b: ty.handle) -> None:
    # Two-stage reduction: "B_rf" sums A over vk into B_rf_local (allocated
    # with scope="local") under thread-binding loops, then "B" sums
    # B_rf_local over its first axis into B.
    A = tir.match_buffer(a, [16, 16, 16], "float32")
    B = tir.match_buffer(b, [16], "float32")
    B_rf_local = tir.alloc_buffer([16, 16], "float32", scope="local")
    for j in tir.thread_binding(0, 16, thread="blockIdx.x"):
        for i_o in tir.thread_binding(0, 4, thread="threadIdx.x"):
            for i_i, k in tir.grid(4, 16):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "B_rf") as [vi, vj, vk]:
                    tir.bind(vi, i_o * 4 + i_i)
                    tir.bind(vj, j)
                    tir.bind(vk, k)
                    with tir.init():
                        B_rf_local[vi, vj] = 0.0
                    # A is indexed as [vj, vi, vk]: vj first, then vi.
                    B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk]
    for i, k in tir.grid(16, 16):
        with tir.block([16, tir.reduce_axis(0, 16)], "B") as [vi, vk]:
            tir.bind(vi, i)
            tir.bind(vk, k)
            with tir.init():
                B[vi] = 0.0
            # The reduce var indexes the first axis of B_rf_local.
            B[vi] = B[vi] + B_rf_local[vk, vi]
def tir_multi_output(a0: ty.handle, a1: ty.handle, b0: ty.handle, b1: ty.handle) -> None:
    # Two independent element-wise outputs computed in one loop nest over
    # symbolic extents: B0 = A0 + 2.0 and B1 = A1 * 3.0.
    # m and n are symbolic int32 sizes shared by all four buffers.
    m = tir.var("int32")
    n = tir.var("int32")
    A0 = tir.match_buffer(a0, (m, n))
    A1 = tir.match_buffer(a1, (m, n))
    B0 = tir.match_buffer(b0, (m, n))
    B1 = tir.match_buffer(b1, (m, n))
    for i0, i1 in tir.grid(m, n):
        # The block vars are named i, j; the loop vars are i0, i1.
        with tir.block([m, n], "B.v0") as [i, j]:
            B0[i, j] = A0[i, j] + 2.0
        with tir.block([m, n], "B.v1") as [i, j]:
            B1[i, j] = A1[i, j] * 3.0
def two_elementwise_after_compute_at(a: ty.handle, c: ty.handle) -> None:
    # For each i, one row of B is produced (ax0 has extent 1) and then the
    # j loop consumes it to write C.
    # NOTE(review): nesting was reconstructed from a whitespace-collapsed
    # source; the j loop is assumed to sit inside the i loop.
    A = tir.match_buffer(a, (128, 128), "float32")
    B = tir.alloc_buffer((128, 128), "float32")
    C = tir.match_buffer(c, (128, 128), "float32")
    for i in range(0, 128):
        for ax0, ax1 in tir.grid(1, 128):
            with tir.block([128, 128], "B") as [vi, vj]:
                tir.bind(vi, i + ax0)
                tir.bind(vj, ax1)
                B[vi, vj] = A[vi, vj] * 2.0
        for j in range(0, 128):
            # NOTE(review): this block is also named "B" although it writes
            # C; confirm whether the duplicate name is intended.
            with tir.block([128, 128], "B") as [vi, vj]:
                C[vi, vj] = B[vi, vj] + 1.0
def square_sum_rfactor(a: ty.handle, c: ty.handle) -> None:
    # Sum of squares in two stages: "C_rf" reduces A * A over i into
    # C_rf[16, 256], then "C" reduces C_rf over its second axis into C[16].
    A = tir.match_buffer(a, [16, 256, 256])
    C = tir.match_buffer(c, [16])
    C_rf = tir.alloc_buffer([16, 256])
    for i0, i1, i2 in tir.grid(16, 256, 256):
        # Block var order is (vi2, b, i), which differs from the loop order.
        with tir.block([256, 16, tir.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]:
            tir.bind(vi2, i2)
            tir.bind(b, i0)
            tir.bind(i, i1)
            with tir.init():
                C_rf[b, vi2] = 0.0
            C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2])
    for i0_1, i2_1 in tir.grid(16, 256):
        # The reduce axis is listed first here and bound to the second loop.
        with tir.block([tir.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]:
            tir.bind(vi2_1, i2_1)
            tir.bind(b_1, i0_1)
            with tir.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[b_1, vi2_1]
def elementwise_symbolic_split(a: ty.handle, b: ty.handle, n: ty.int32) -> None:
    # Element-wise B = A * 2.0 where the last axis has symbolic extent n and
    # is driven by two loops of extents 10 and floordiv(n + 9, 10).
    A = tir.match_buffer(a, (128, 128, n))
    B = tir.match_buffer(b, (128, 128, n))
    for i, j, k0, k1 in tir.grid(128, 128, 10, tir.floordiv((n + 9), 10)):
        with tir.block([128, 128, n], "B") as [vi, vj, vk]:
            # The predicate restricts the rebuilt index to values below n.
            tir.where((((k0 * tir.floordiv((n + 9), 10)) + k1) < n))
            tir.bind(vi, i)
            tir.bind(vj, j)
            tir.bind(vk, ((k0 * tir.floordiv((n + 9), 10)) + k1))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def buffer_opaque_access(b: ty.handle, c: ty.handle) -> None:
    # A block that accesses memory only through tir.store / tir.load and an
    # intrinsic on B.data, followed by an ordinary block copying B into C.
    B = tir.match_buffer(b, [16, 16], "float32")
    C = tir.match_buffer(c, [16, 16], "float32")
    with tir.block([]):
        # Regions are declared by hand: no reads, all of B written.
        tir.reads([])
        tir.writes(B[0:16, 0:16])
        # A comes from tir.allocate and is addressed by flat index i * 16 + j.
        A = tir.allocate([256], "float32", "global")
        for i, j in tir.grid(16, 16):
            tir.store(A, i * 16 + j, 1)
        for i in range(0, 16):
            for j in range(0, 16):
                tir.evaluate(tir.load("float32", A, i * 16 + j))
            for j in range(0, 16):
                # B is touched only via its data handle in this call.
                tir.evaluate(
                    tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, tir.float32(0), dtype="handle")
                )
    for i, j in tir.grid(16, 16):
        with tir.block([16, 16]) as [vi, vj]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            C[vi, vj] = B[vi, vj]