def symbolic_match(a: ty.handle, b: ty.handle, n: ty.int32, m: ty.int32) -> None:
    A = tir.match_buffer(a, (n * m, m))
    B = tir.match_buffer(b, (n * 2, m * 4))
    for i in range(0, n):
        with tir.block([]):
            tir.reads([])
            tir.writes([A[i * m : i * m + n, 0:m], B[i * n : i * n + 2, 0 : m * 4]])
            Bs_0 = tir.var("int32")
            Bs_1 = tir.var("int32")
            sub_A = tir.match_buffer(A[i * m : i * m + m, 0:m], (m, m), offset_factor=1)
            sub_B = tir.match_buffer(
                B[i * n : i * n + 2, 0 : m * 4],
                (2, m * 4),
                strides=[Bs_0, Bs_1],
                offset_factor=1,
            )
            for ii, jj in tir.grid(m, m):
                sub_A[ii, jj] = 1
            for j in range(0, 4):
                tir.evaluate(
                    tir.intrin_test(
                        sub_B.data,
                        sub_B.elem_offset,
                        sub_B.strides[0],
                        sub_B.strides[1],
                        sub_B.shape[0],
                        sub_B.shape[1],
                        dtype="handle",
                    )
                )
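# A minimal driver sketch, not part of the original test functions: assuming the
# TVMScript functions in this file are decorated with @tvm.script.tir as in TVM's
# test suite, match_buffer kernels such as symbolic_match above are typically
# lowered and then compared against their expected "transformed_*" counterparts.
def _check_lower_match_buffer_sketch(original, transformed):
    import tvm

    mod = tvm.IRModule.from_expr(original)
    # LowerMatchBuffer rewrites tir.match_buffer sub-regions into explicit
    # offset/stride arithmetic on the parent buffers.
    mod = tvm.tir.transform.LowerMatchBuffer()(mod)
    mod = tvm.tir.transform.Simplify()(mod)
    tvm.ir.assert_structural_equal(mod["main"], transformed)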
def get_valid_counts(
    data: ty.handle,
    valid_count: ty.handle,
    out: ty.handle,
    out_indices: ty.handle,
    score_threshold: ty.float32,
    id_index: ty.int32,
    score_index: ty.int32,
) -> None:
    data_buf = tir.match_buffer(data, (1, 2500, 6), "float32")
    valid_count_buf = tir.match_buffer(valid_count, (1,), "int32")
    out_buf = tir.match_buffer(out, (1, 2500, 6), "float32")
    out_indices_buf = tir.match_buffer(out_indices, (1, 2500), "int32")
    with tir.block([1], "init") as [vi]:
        valid_count_buf[vi] = tir.int32(0)
        with tir.block([2500], "update") as [vj]:
            tir.reads([data_buf[vi, vj, 6]])
            tir.writes([valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj, 6]])
            if (data_buf[vi, vj, score_index] > score_threshold) and (
                (id_index < 0) or (data_buf[vi, vj, id_index] >= tir.float32(0))
            ):
                for k in tir.serial(0, 6):
                    out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k]
                out_indices_buf[vi, valid_count_buf[vi]] = vj
                valid_count_buf[vi] = valid_count_buf[vi] + 1
            if vj >= valid_count_buf[vi]:
                for k in tir.serial(0, 6):
                    out_buf[vi, vj, k] = tir.float32(-1)
                out_indices_buf[vi, vj] = tir.int32(-1)
def transformed_recursive_match(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, (64, 64, 64))
    B = tir.match_buffer(b, (64, 64, 64))
    for i, j, k in tir.grid(64, 4, 4):
        with tir.block([]):
            tir.reads([])
            tir.writes(
                [
                    A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                    B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                ]
            )
            for jj, kk in tir.grid(4, 4):
                with tir.block([]):
                    tir.reads([])
                    tir.writes(
                        [
                            A[
                                i,
                                j * 16 + jj * 4 : j * 16 + jj * 4 + 4,
                                k * 16 + kk * 4 : k * 16 + kk * 4 + 4,
                            ],
                            B[
                                i,
                                j * 16 + jj * 4 : j * 16 + jj * 4 + 4,
                                k * 16 + kk * 4 : k * 16 + kk * 4 + 4,
                            ],
                        ]
                    )
                    tir.evaluate(
                        tir.intrin_test(
                            A.data,
                            i * 4096 + j * 1024 + jj * 256 + k * 16 + kk * 4,
                            64,
                            1,
                            4,
                            4,
                            dtype="handle",
                        )
                    )
                    for jjj, kkk in tir.grid(4, 4):
                        B[i, j * 16 + jj * 4 + jjj, k * 16 + kk * 4 + kkk] = 1
def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, (32, 64, 128))
    B = tir.match_buffer(b, (64, 64, 64))
    for i, j, k in tir.grid(2, 64, 8):
        with tir.block([]):
            tir.reads([])
            tir.writes(A[i * 16 : i * 16 + 16, j, k * 16 : k * 16 + 16])
            tir.evaluate(
                tir.intrin_test(
                    A.data,
                    i * 131072 + j * 128 + k * 16,
                    8192,
                    128,
                    16,
                    1,
                    dtype="handle",
                )
            )
    for i, j, k in tir.grid(64, 2, 8):
        with tir.block([]):
            tir.reads([])
            tir.writes(B[i, j * 32 : j * 32 + 32, k * 8 : k * 8 + 8])
            tir.evaluate(
                tir.intrin_test(
                    B.data,
                    i * 4096 + j * 2048 + k * 8,
                    64,
                    1,
                    32,
                    8,
                    dtype="handle",
                )
            )
def expected_recursive_bufferslice_indices(data: ty.handle, index: ty.handle) -> None:
    index_buf = tir.match_buffer(
        index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1
    )
    data_buf = tir.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1)
    with tir.block([], "root"):
        tir.reads([])
        tir.writes([])
        out_buf = tir.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1)
        for i0, i1 in tir.grid(16, 16):
            with tir.block([16, 16], "") as [vi, vj]:
                tir.bind(vi, i0)
                tir.bind(vj, i1)
                tir.reads([data_buf[0:16, 0:16], index_buf[0]])
                tir.writes([out_buf[vi, vj]])
                out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]]
def buffer_opaque_access(b: ty.handle, c: ty.handle) -> None:
    B = tir.match_buffer(b, [16, 16], "float32")
    C = tir.match_buffer(c, [16, 16], "float32")
    with tir.block([]):
        tir.reads([])
        tir.writes(B[0:16, 0:16])
        A = tir.allocate([256], "float32", "global")
        for i, j in tir.grid(16, 16):
            tir.store(A, i * 16 + j, 1)
        for i in range(0, 16):
            for j in range(0, 16):
                tir.evaluate(tir.load("float32", A, i * 16 + j))
            for j in range(0, 16):
                tir.evaluate(
                    tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, tir.float32(0), dtype="handle")
                )
    for i, j in tir.grid(16, 16):
        with tir.block([16, 16]) as [vi, vj]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            C[vi, vj] = B[vi, vj]
def transformed_func() -> None:
    A = tir.alloc_buffer([128, 128])
    with tir.block([128, 128], "") as [i, j]:
        A[i, j] = tir.float32(0)
    with tir.block([32, 32, tir.reduce_axis(0, 32)], "") as [i, j, k]:
        B = tir.alloc_buffer([128, 128])
        if k == 0:
            for ii, jj in tir.grid(4, 4):
                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
        for ii, jj in tir.grid(4, 4):
            with tir.block([], ""):
                tir.reads([B[((i * 4) + ii), ((j * 4) + jj)]])
                tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
                C = tir.alloc_buffer([128, 128])
                for kk in tir.serial(0, 4):
                    B[((i * 4) + ii), ((j * 4) + jj)] = (
                        B[((i * 4) + ii), ((j * 4) + jj)] + C[((i * 4) + ii), ((k * 4) + kk)]
                    )
                for kk in tir.serial(0, 4):
                    with tir.block([], ""):
                        tir.reads(
                            [
                                B[((i * 4) + ii), ((j * 4) + jj)],
                                C[((i * 4) + ii), ((k * 4) + kk)],
                            ]
                        )
                        tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
                        D = tir.alloc_buffer([128, 128])
                        B[((i * 4) + ii), ((j * 4) + jj)] = B[((i * 4) + ii), ((j * 4) + jj)] + (
                            D[((j * 4) + jj), ((k * 4) + kk)] * C[((i * 4) + ii), ((k * 4) + kk)]
                        )
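# Hedged sketch, not from the original file: transformed_func above reads as the
# expected output of buffer-allocation planning. Assuming the corresponding
# un-planned input PrimFunc (not shown in this excerpt) is available, a check
# would look roughly like this; the helper name here is illustrative only.
def _check_plan_update_sketch(original, expected):
    import tvm

    mod = tvm.IRModule.from_expr(original)
    # Moves alloc_buffer statements to the innermost block that covers their uses.
    mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
    tvm.ir.assert_structural_equal(mod["main"], expected)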
def opaque_access(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, [1024])
    B = tir.match_buffer(b, [1024])
    A_cache = tir.alloc_buffer([1024])
    for i in tir.serial(0, 8):
        with tir.block([8]) as [vi]:
            with tir.block([8]) as [v]:
                tir.bind(v, vi)
                tir.reads([A[(v * 128) : ((v * 128) + 128)]])
                tir.writes([A_cache[(v * 128) : ((v * 128) + 128)]])
                tir.evaluate(
                    tir.call_extern(
                        "test",
                        A_cache.data,
                        (v * 128),
                        128,
                        A.data,
                        (v * 128),
                        128,
                        dtype="float32",
                    )
                )
            for j in tir.serial(0, 128):
                with tir.block([1024]) as [v]:
                    tir.bind(v, ((vi * 128) + j))
                    tir.reads([A_cache[v]])
                    tir.writes([B[v]])
                    B[v] = A_cache[v]
def fail_match_load(a: ty.handle) -> None:
    A = tir.match_buffer(a, (8, 8))
    for i, j in tir.grid(8, 8):
        with tir.block([]):
            tir.reads(A[i, j])
            tir.writes([])
            sub_A = tir.match_buffer(A[i, j], ())
            tir.evaluate(tir.load("float32", sub_A.data, 0))
def fail_match_store(a: ty.handle) -> None:
    A = tir.match_buffer(a, (8, 8))
    for i, j in tir.grid(8, 8):
        with tir.block([]):
            tir.reads([])
            tir.writes(A[i, j])
            sub_A = tir.match_buffer(A[i, j], ())
            sub_A.data[0] = 1
def transformed_buffer_load_store(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (16, 16, 16))
    C = tir.match_buffer(c, (16, 16))
    for i, j, k in tir.grid(4, 16, 8):
        with tir.block([]):
            tir.reads(C[i * 4 : i * 4 + 4, k * 2 : k * 2 + 2])
            tir.writes(A[i * 4 : i * 4 + 4, j, k * 2 : k * 2 + 2])
            for ii, kk in tir.grid(4, 2):
                A[i * 4 + ii, j, k * 2 + kk] += C[i * 4 + ii, k * 2 + kk]
def compacted_unit_loop_func(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (32), "float32")
    C = tir.match_buffer(c, (32), "float32")
    for x, y, z in tir.grid(4, 1, 8):
        with tir.block([]) as []:
            tir.reads(A[x * 8 + y * 8 + z])
            tir.writes(C[x * 8 + y * 8 + z])
            C[x * 8 + y * 8 + z] = A[x * 8 + y * 8 + z] + 1.0
def compacted_predicate_func(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (32), "float32")
    C = tir.match_buffer(c, (32), "float32")
    for i, j in tir.grid(5, 7):
        with tir.block([]) as []:
            tir.reads(A[i * 7 + j])
            tir.writes(C[i * 7 + j])
            tir.where(i * 7 + j < 32)
            C[i * 7 + j] = A[i * 7 + j] + 1.0
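# A hedged usage sketch (assumed, not part of the original file): the
# compacted_* functions above are expected results of buffer-region compaction;
# given the corresponding pre-compaction PrimFunc, the comparison would look
# roughly like this.
def _check_compact_sketch(original, expected):
    import tvm

    mod = tvm.IRModule.from_expr(original)
    # Shrinks each allocated buffer to the region actually touched by its block.
    mod = tvm.tir.transform.CompactBufferAllocation()(mod)
    mod = tvm.tir.transform.Simplify()(mod)
    tvm.ir.assert_structural_equal(mod["main"], expected)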
def opaque_access_load(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128))
    B = tir.alloc_buffer((128, 128))
    C = tir.match_buffer(c, (128, 128))
    with tir.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    with tir.block([128, 128], "C") as [vi, vj]:
        tir.reads(B[0:128, 0:128])
        tir.writes(C[0:128, 0:128])
        C[vi, vj] = tir.load("float32", B.data, vi * 128 + vj) + 1.0
def recursive_match(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, (64, 64, 64))
    B = tir.match_buffer(b, (64, 64, 64))
    for i, j, k in tir.grid(64, 4, 4):
        with tir.block([]):
            tir.reads([])
            tir.writes(
                [
                    A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                    B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                ]
            )
            As_0 = tir.var("int32")
            As_1 = tir.var("int32")
            sub_A = tir.match_buffer(
                A[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                (16, 16),
                strides=[As_0, As_1],
                offset_factor=1,
            )
            sub_B = tir.match_buffer(
                B[i, j * 16 : j * 16 + 16, k * 16 : k * 16 + 16],
                (16, 16),
                offset_factor=1,
            )
            for jj, kk in tir.grid(4, 4):
                with tir.block([]):
                    tir.reads([])
                    tir.writes(
                        [
                            sub_A[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4],
                            sub_B[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4],
                        ]
                    )
                    Ass_0 = tir.var("int32")
                    Ass_1 = tir.var("int32")
                    sub_sub_A = tir.match_buffer(
                        sub_A[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4],
                        (4, 4),
                        strides=[Ass_0, Ass_1],
                        offset_factor=1,
                    )
                    sub_sub_B = tir.match_buffer(
                        sub_B[jj * 4 : jj * 4 + 4, kk * 4 : kk * 4 + 4],
                        (4, 4),
                        offset_factor=1,
                    )
                    tir.evaluate(
                        tir.intrin_test(
                            sub_sub_A.data,
                            sub_sub_A.elem_offset,
                            sub_sub_A.strides[0],
                            sub_sub_A.strides[1],
                            sub_sub_A.shape[0],
                            sub_sub_A.shape[1],
                            dtype="handle",
                        )
                    )
                    for jjj, kkk in tir.grid(4, 4):
                        sub_sub_B[jjj, kkk] = 1
def elementwise_subblock_uncovered(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128), "float32")
    C = tir.match_buffer(c, (128, 128), "float32")
    B = tir.alloc_buffer((128, 128), "float32")
    with tir.block([32, 32], "B") as [vi, vj]:
        tir.reads([A[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]])
        tir.writes([B[vi * 4 : vi * 4 + 2, vj * 4 : vj * 4 + 2]])
        with tir.block([2, 2], "B_sub") as [vi_i, vj_i]:
            B[vi * 4 + vi_i, vj * 4 + vj_i] = A[vi * 4 + vi_i, vj * 4 + vj_i] * 2.0
    with tir.block([128, 128], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
def elementwise_with_wrong_block_var_type(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j, k in tir.grid(128, 128, 128):
        with tir.block([128, 128, tir.scan_axis(0, 128)], "B") as [vi, vj, vk]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            tir.bind(vk, k)
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def opaque_access_func() -> None:
    A = tir.alloc_buffer([1024])
    B = tir.alloc_buffer([1024])
    for i in tir.serial(0, 8):
        with tir.block([8]) as [v]:
            tir.bind(v, i)
            tir.reads([A[v * 128 : v * 128 + 128]])
            tir.writes([B[v * 128 : v * 128 + 128]])
            tir.evaluate(
                tir.call_extern("test", B.data, v * 128, 128, A.data, v * 128, 128, dtype="float32")
            )
def elementwise_split_case0(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, [128, 128, 128])
    B = tir.match_buffer(b, [128, 128, 128])
    for i1, i2, i3, j1, j2, k1, k2 in tir.grid(2, 1, 64, 4, 32, 16, 8):
        with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
            tir.bind(vi, ((i1 * 64) + i3))
            tir.bind(vj, ((j1 * 32) + j2))
            tir.bind(vk, ((k1 * 8) + k2))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
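# A hedged sketch (assumed, not part of the original file) of how a split form
# like elementwise_split_case0 above is usually produced: split each loop around
# block "B" of a plain 128x128x128 elementwise kernel by the factors implied by
# the bindings, then read back the scheduled PrimFunc.
def _split_sketch(elementwise):
    import tvm

    sch = tvm.tir.Schedule(elementwise)
    block_b = sch.get_block("B")
    i, j, k = sch.get_loops(block_b)
    sch.split(i, factors=[2, 1, 64])
    sch.split(j, factors=[4, 32])
    sch.split(k, factors=[16, 8])
    return sch.mod["main"]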
def elementwise_fused(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for fused in tir.serial(0, 2097152):
        with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
            tir.bind(vi, tir.floordiv(fused, 16384))
            tir.bind(vj, tir.floormod(tir.floordiv(fused, 128), 128))
            tir.bind(vk, tir.floormod(fused, 128))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
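# A hedged sketch (assumed, not part of the original file) of how a fused form
# like elementwise_fused above is usually produced: fuse the three loops around
# block "B" of a plain 128x128x128 elementwise kernel with tir.Schedule.
def _fuse_sketch(elementwise):
    import tvm

    sch = tvm.tir.Schedule(elementwise)
    block_b = sch.get_block("B")
    i, j, k = sch.get_loops(block_b)
    sch.fuse(i, j, k)
    return sch.mod["main"]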
def buffer_shape_mismatch(a: ty.handle) -> None:
    A = tir.match_buffer(a, (8, 8))
    for i, j in tir.grid(8, 2):
        with tir.block([]):
            tir.reads([])
            tir.writes([A[i, j * 4 : j * 4 + 4]])
            sub_A = tir.match_buffer(
                A[i, j * 4 : j * 4 + 4], (5)
            )  # error: shape mismatched between 4 and 5
            for jj in range(0, 4):
                sub_A[i, j * 4 + jj] = 1
def param_buffer_access_func(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (20, 20), "float32")
    B = tir.match_buffer(c, (20, 20), "float32")
    for i in range(0, 16):
        with tir.block([]):
            tir.reads(A[i, 0:16])
            tir.writes(B[i, 0:16])
            for j in range(0, 16):
                with tir.block([]) as []:
                    tir.reads(A[i, j])
                    tir.writes(B[i, j])
                    B[i, j] = A[i, j] + 1.0
def unschedulable_func(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (16, 16), "float32")
    C = tir.match_buffer(c, (16, 16), "float32")
    for i in range(0, 16):
        with tir.block([]):
            tir.reads(A[i, 0:16])
            tir.writes(C[i, 0:16])
            B = tir.alloc_buffer((16, 16), "float32")
            for j in range(0, 16):
                tir.store(B.data, i * 16 + j, A[i, j] + 1.0)
            for j in range(0, 16):
                C[i, j] = B[i, j] * 2.0
def elementwise_with_anno(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        for k in tir.serial(0, 128, annotations={"useless_annotation": True}):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def elementwise_symbolic_fused(a: ty.handle, b: ty.handle, n: ty.int32) -> None:
    A = tir.match_buffer(a, (128, 128, n))
    B = tir.match_buffer(b, (128, 128, n))
    for i_j_k_fused in tir.serial(0, (n * 16384)):
        with tir.block([128, 128, n], "B") as [vi, vj, vk]:
            tir.bind(vi, tir.floordiv(i_j_k_fused, (n * 128)))
            tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, n), 128))
            tir.bind(vk, tir.floormod(i_j_k_fused, n))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def elementwise_split_case1(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, [128, 128, 128])
    B = tir.match_buffer(b, [128, 128, 128])
    for i1, i2, i3, j1, j2, j3, k1, k2, k3 in tir.grid(2, 1, 64, 2, 1, 64, 2, 1, 64):
        with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
            tir.bind(vi, i1 * 64 + i3)
            tir.bind(vj, j1 * 64 + j3)
            tir.bind(vk, k1 * 64 + k3)
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def elementwise_with_starting_point(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        for k in tir.serial(10, 128):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def elementwise_with_thread_binding(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        for k in tir.thread_binding(0, 128, thread="threadIdx.x"):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def compacted_multi_alloc_func(a: ty.handle, d: ty.handle) -> None:
    A = tir.match_buffer(a, (32), "float32")
    D = tir.match_buffer(d, (32), "float32")
    for i in range(0, 32):
        with tir.block([]) as []:
            tir.reads(A[i])
            tir.writes(D[i])
            B = tir.alloc_buffer((32,))
            C = tir.alloc_buffer((32,))
            B[i] = A[i] + 1.0
            C[i] = A[i] + B[i]
            D[i] = C[i] * 2.0
def elementwise_symbolic_split(a: ty.handle, b: ty.handle, n: ty.int32) -> None:
    A = tir.match_buffer(a, (128, 128, n))
    B = tir.match_buffer(b, (128, 128, n))
    for i, j, k0, k1 in tir.grid(128, 128, 10, tir.floordiv((n + 9), 10)):
        with tir.block([128, 128, n], "B") as [vi, vj, vk]:
            tir.where((((k0 * tir.floordiv((n + 9), 10)) + k1) < n))
            tir.bind(vi, i)
            tir.bind(vj, j)
            tir.bind(vk, ((k0 * tir.floordiv((n + 9), 10)) + k1))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
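# Hedged sketch (assumed, not from the original file): a symbolic split such as
# elementwise_symbolic_split above is typically obtained by splitting the
# innermost loop with the outer factor fixed to 10 and the inner factor left as
# None, so the extent tir.floordiv(n + 9, 10) and the tir.where predicate are
# inferred automatically.
def _symbolic_split_sketch(elementwise_symbolic):
    import tvm

    sch = tvm.tir.Schedule(elementwise_symbolic)
    block_b = sch.get_block("B")
    _, _, k = sch.get_loops(block_b)
    sch.split(k, factors=[10, None])
    return sch.mod["main"]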