def transformed_func() -> None:
    # Nested-block fixture. Buffer allocation sites, outermost to innermost:
    # A at function scope, B inside the second block, C inside the opaque
    # block under the (ii, jj) loops, D inside the innermost opaque block.
    # NOTE(review): the name suggests this is the expected output of some
    # transform -- confirm against the test that consumes it.
    A = tir.alloc_buffer([128, 128])
    with tir.block([128, 128], "") as [i, j]:
        # Assign float32 zero to every element of A.
        A[i, j] = tir.float32(0)
    with tir.block([32, 32, tir.reduce_axis(0, 32)], "") as [i, j, k]:
        B = tir.alloc_buffer([128, 128])
        # Only when k == 0: copy the 4x4 tile at (i * 4, j * 4) from A into B.
        if k == 0:
            for ii, jj in tir.grid(4, 4):
                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
        for ii, jj in tir.grid(4, 4):
            with tir.block([], ""):
                tir.reads([B[((i * 4) + ii), ((j * 4) + jj)]])
                tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
                C = tir.alloc_buffer([128, 128])
                # First kk loop: accumulate C into B directly (no inner block).
                for kk in tir.serial(0, 4):
                    B[((i * 4) + ii), ((j * 4) + jj)] = (
                        B[((i * 4) + ii), ((j * 4) + jj)]
                        + C[((i * 4) + ii), ((k * 4) + kk)])
                # Second kk loop: accumulate D * C into B inside a nested block.
                for kk in tir.serial(0, 4):
                    with tir.block([], ""):
                        tir.reads([
                            B[((i * 4) + ii), ((j * 4) + jj)],
                            C[((i * 4) + ii), ((k * 4) + kk)],
                        ])
                        tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
                        D = tir.alloc_buffer([128, 128])
                        B[((i * 4) + ii), ((j * 4) + jj)] = B[((i * 4) + ii), ((j * 4) + jj)] + (
                            D[((j * 4) + jj), ( (k * 4) + kk)]
                            * C[((i * 4) + ii), ((k * 4) + kk)])
def get_valid_counts(
    data: ty.handle,
    valid_count: ty.handle,
    out: ty.handle,
    out_indices: ty.handle,
    score_threshold: ty.float32,
    id_index: ty.int32,
    score_index: ty.int32,
) -> None:
    # Copies the entries of data_buf that pass the score/id test to the front
    # of out_buf, records their source index in out_indices_buf, counts them
    # in valid_count_buf, and fills rows at or past the current count with -1.
    # Shapes are fixed: data/out are (1, 2500, 6), out_indices is (1, 2500),
    # valid_count is (1,).
    data_buf = tir.match_buffer(data, (1, 2500, 6), "float32")
    valid_count_buf = tir.match_buffer(valid_count, (1, ), "int32")
    out_buf = tir.match_buffer(out, (1, 2500, 6), "float32")
    out_indices_buf = tir.match_buffer(out_indices, (1, 2500), "int32")
    with tir.block([1], "init") as [vi]:
        valid_count_buf[vi] = tir.int32(0)
        # "update" is nested under "init" because it refers to vi, which only
        # the "init" block binds.
        with tir.block([2500], "update") as [vj]:
            # NOTE(review): the annotations below name index 6 on an axis of
            # extent 6, while the body touches k in tir.serial(0, 6) --
            # confirm this is intended by the consuming test.
            tir.reads([data_buf[vi, vj, 6]])
            tir.writes([
                valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj, 6]
            ])
            # Keep the entry when its score exceeds the threshold and either
            # id_index is negative or the value at id_index is non-negative.
            if (data_buf[vi, vj, score_index] > score_threshold) and (
                    (id_index < 0) or (data_buf[vi, vj, id_index] >= tir.float32(0))):
                for k in tir.serial(0, 6):
                    out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k]
                out_indices_buf[vi, valid_count_buf[vi]] = vj
                valid_count_buf[vi] = valid_count_buf[vi] + 1
            # Rows whose index is not below the running count are set to -1.
            if vj >= valid_count_buf[vi]:
                for k in tir.serial(0, 6):
                    out_buf[vi, vj, k] = tir.float32(-1)
                out_indices_buf[vi, vj] = tir.int32(-1)
def element_wise_invalid_annotation(a: ty.handle, c: ty.handle) -> None:
    # Two-stage element-wise fixture: B = A * 2, then C = B + 1, with both
    # stages sharing the outer loop i0. Block "B" carries the block attribute
    # {"buffer_dim_align": [0]}.
    # NOTE(review): per the function name that attribute value is presumably
    # malformed on purpose -- confirm with the test that uses this fixture.
    C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with tir.block([], "root"):
        tir.reads([])
        tir.writes([])
        B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
        for i0 in tir.serial(0, 128):
            for ax1 in tir.serial(0, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    tir.block_attr({"buffer_dim_align": [0]})
                    tir.bind(vi, i0)
                    tir.bind(vj, ax1)
                    tir.reads([A[vi, vj]])
                    tir.writes([B[vi, vj]])
                    B[vi, vj] = (A[vi, vj] * tir.float32(2))
            # This loop sits inside i0: block "C" binds vi_1 to i0.
            for i1 in tir.serial(0, 128):
                with tir.block([128, 128], "C") as [vi_1, vj_1]:
                    tir.bind(vi_1, i0)
                    tir.bind(vj_1, i1)
                    tir.reads([B[vi_1, vj_1]])
                    tir.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1))
def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None:
    # For each of 8 chunks of 128 elements: an extern call "test" is issued
    # with A_cache.data and A.data plus offset/length arguments, then the
    # chunk of A_cache is copied element by element into B.
    # A_cache is allocated inside the outer block.
    # NOTE(review): the name suggests this is the expected output of a
    # transform -- confirm against the consuming test.
    A = tir.match_buffer(a, [1024])
    B = tir.match_buffer(b, [1024])
    for i in tir.serial(0, 8):
        with tir.block([8]) as [vi]:
            tir.reads(A[vi * 128:vi * 128 + 128])
            tir.writes(B[vi * 128:vi * 128 + 128])
            A_cache = tir.alloc_buffer([1024])
            with tir.block([8]) as [v]:
                tir.bind(v, vi)
                tir.reads([A[v * 128:v * 128 + 128]])
                tir.writes([A_cache[v * 128:v * 128 + 128]])
                # Buffers are passed to the extern call via their .data
                # handles, so the regions are declared explicitly above.
                tir.evaluate(
                    tir.call_extern("test",
                                    A_cache.data,
                                    v * 128,
                                    128,
                                    A.data,
                                    v * 128,
                                    128,
                                    dtype="float32"))
            for j in tir.serial(0, 128):
                with tir.block([1024]) as [v]:
                    tir.bind(v, ((vi * 128) + j))
                    tir.reads([A_cache[v]])
                    tir.writes([B[v]])
                    B[v] = A_cache[v]
def range_missing_args(a: ty.handle) -> None:  # body assigns 0.0 to each A[i, j] of a 16x16 buffer
    A = tir.match_buffer(a, (16, 16), "float32")
    tir.attr(A, "realize_scope", "")
    tir.realize(A[0:16, 0:16])
    for i in tir.serial(16):  # one argument here vs. (0, 16) below; NOTE(review): presumably the intentional defect named by the function
        for j in tir.serial(0, 16):
            A[i, j] = 0.0
def undefined_buffer(a: ty.handle) -> None:  # body assigns 0.0 to each A[i, j] of a 16x16 buffer
    A = tir.match_buffer(a, (16, 16), "float32")
    tir.attr(A, "realize_scope", "")
    tir.realize(C[0:16, 0:16])  # C is never defined in this function; NOTE(review): presumably the intentional defect named by the function
    for i in tir.serial(16):  # one argument here vs. (0, 16) on the inner loop
        for j in tir.serial(0, 16):
            A[i, j] = 0.0
def multiple_reduction_blocks_rfactor(a: ty.handle, f: ty.handle) -> None:
    # Chain of reductions over the last axis of A: C_rf -> C -> D -> E -> F.
    # Block "C_rf" accumulates A into a [16, 16, 4] buffer indexed by the
    # outer part of the reduction (vk1o); block "C" then sums C_rf over that
    # axis. D, E and F each add A over their reduce axis plus the previous
    # buffer's value.
    # NOTE(review): the name suggests this is the expected result of applying
    # rfactor to the first reduction of `multiple_reduction_blocks` --
    # confirm against the consuming test.
    A = tir.match_buffer(a, [16, 16, 16])
    C = tir.alloc_buffer([16, 16])
    D = tir.alloc_buffer([16, 16])
    E = tir.alloc_buffer([16, 16])
    F = tir.match_buffer(f, [16, 16])
    C_rf = tir.alloc_buffer([16, 16, 4])
    for i, j1, k1o, k1i in tir.grid(16, 16, 4, 4):
        # Only vk1i is a reduce axis here; vk1o is a regular block axis.
        with tir.block([4, 16, 16, tir.reduce_axis(0, 4)], "C_rf") as [vk1o, ci, cj, vk1i]:
            tir.bind(vk1o, k1o)
            tir.bind(ci, i)
            tir.bind(cj, j1)
            tir.bind(vk1i, k1i)
            with tir.init():
                C_rf[ci, cj, vk1o] = 0.0
            C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)]
    for i_1 in tir.serial(0, 16):
        for j1_1 in tir.serial(0, 16):
            for k1o_1 in tir.serial(0, 4):
                # Here vk1o_1 is the reduce axis, listed first in the domain.
                with tir.block([tir.reduce_axis(0, 4), 16, 16], "C") as [vk1o_1, ci_1, cj_1]:
                    tir.bind(vk1o_1, k1o_1)
                    tir.bind(ci_1, i_1)
                    tir.bind(cj_1, j1_1)
                    with tir.init():
                        C[ci_1, cj_1] = 0.0
                    C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1]
            for k2o, k2i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "D") as [di, dj, dk]:
                    tir.bind(di, i_1)
                    tir.bind(dj, j1_1)
                    tir.bind(dk, (k2o * 4) + k2i)
                    with tir.init():
                        D[di, dj] = 0.0
                    D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj]
        # Second column loop under i_1; blocks "E" and "F" bind to i_1 and j2.
        for j2 in tir.serial(0, 16):
            for k3o, k3i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "E") as [ei, ej, ek]:
                    tir.bind(ei, i_1)
                    tir.bind(ej, j2)
                    tir.bind(ek, (k3o * 4) + k3i)
                    with tir.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej]
            for k4o, k4i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "F") as [fi, fj, fk]:
                    tir.bind(fi, i_1)
                    tir.bind(fj, j2)
                    tir.bind(fk, (k4o * 4) + k4i)
                    with tir.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]
def flattened_elementwise_func(a: ty.handle, c: ty.handle) -> None:
    # Block-free form of C = (A + 1.0) * 2.0 on 16x16 float32 buffers.
    # A and C are addressed through .data with the flat index i * 16 + j;
    # the 16-element intermediate B_new is allocated once per iteration of i.
    A = tir.match_buffer(a, (16, 16), "float32")
    C = tir.match_buffer(c, (16, 16), "float32")
    for i in tir.serial(0, 16):
        B_new = tir.allocate([16], "float32", "global")
        for j in tir.serial(0, 16):
            B_new[j] = tir.load("float32", A.data, ((i * 16) + j)) + 1.0
        for j in tir.serial(0, 16):
            C.data[((i * 16) + j)] = tir.load("float32", B_new, j) * 2.0
def bound_to_thread(a: ty.handle, c: ty.handle) -> None:
    # B = A * 2.0 followed by C = B + 1.0, with the shared outer loop i bound
    # to "threadIdx.x" and B allocated with scope="shared".
    # Block "C" indexes both B and C as [vj, vi], i.e. with the block
    # variables swapped relative to block "B".
    # NOTE(review): no tir.bind calls appear; block vars are presumably bound
    # implicitly to the enclosing loops -- confirm.
    A = tir.match_buffer(a, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    B = tir.alloc_buffer([128, 128], scope="shared")
    for i in tir.thread_binding(0, 128, thread="threadIdx.x"):
        for j in tir.serial(0, 128):
            with tir.block([128, 128], "B") as [vi, vj]:
                B[vi, vj] = A[vi, vj] * 2.0
        for j in tir.serial(0, 128):
            with tir.block([128, 128], "C") as [vi, vj]:
                C[vj, vi] = B[vj, vi] + 1.0
def elementwise_with_seq(a: ty.handle, b: ty.handle) -> None:
    # C = A * 2.0 then B = C * 2.0 on 128x128x128 buffers. The (i, j) loop
    # nest contains two sibling k loops in sequence, one per block.
    # NOTE(review): no tir.bind calls appear; block vars are presumably bound
    # implicitly to the enclosing loops -- confirm.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    C = tir.alloc_buffer((128, 128, 128))
    for i, j in tir.grid(128, 128):
        for k in tir.serial(0, 128):
            with tir.block([128, 128, 128], "C") as [vi, vj, vk]:
                C[vi, vj, vk] = A[vi, vj, vk] * 2.0
        for k in tir.serial(0, 128):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                B[vi, vj, vk] = C[vi, vj, vk] * 2.0
def warp_memory(a: ty.handle, c: ty.handle) -> None:
    # B = A * 2.0 followed by C = B + 1.0, staged through a [128, 4, 32]
    # buffer allocated with scope="warp". Loop i_o (extent 4) is bound to
    # "threadIdx.y" and i_i (extent 32) to "threadIdx.x"; the row index into
    # A and C is warp_id * 32 + lane_id.
    # NOTE(review): no tir.bind calls appear; block vars are presumably bound
    # implicitly to the enclosing loops -- confirm.
    A = tir.match_buffer(a, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    B = tir.alloc_buffer([128, 4, 32], scope="warp")
    for i_o in tir.thread_binding(0, 4, thread="threadIdx.y"):
        for i_i in tir.thread_binding(0, 32, thread="threadIdx.x"):
            for j in tir.serial(0, 128):
                with tir.block([4, 32, 128], "B") as [warp_id, lane_id, vj]:
                    B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0
            for j in tir.serial(0, 128):
                with tir.block([4, 32, 128], "C") as [warp_id, lane_id, vj]:
                    C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0
def read_out_of_bound(a: ty.handle, c: ty.handle) -> None:
    # B is a copy of A; C[v] is max(B[v], B[v + 1]) for v < 15 and B[v]
    # otherwise. The two stages live in separate top-level loops.
    A = tir.match_buffer(a, [16], "float32")
    B = tir.alloc_buffer([16], "float32")
    C = tir.match_buffer(c, [16], "float32")
    for i in tir.serial(0, 16):
        with tir.block([16], "B") as [v]:
            B[v] = A[v]
    for j in tir.serial(0, 16):
        with tir.block([16], "C") as [v]:
            # The declared region B[v:v + 2] reaches index 16 when v == 15,
            # one past B's extent of 16, even though the body guards the
            # B[v + 1] access with v < 15.
            tir.reads(B[v:v + 2])
            C[v] = tir.if_then_else(v < 15, tir.max(B[v], B[v + 1]), B[v], dtype="float32")
def tiled_after_reverse_compute_at(a: ty.handle, c: ty.handle) -> None:
    # B = A * 2.0 followed by C = B + 1.0 on 128x128 buffers, tiled 8x8 with
    # 16x16 tiles. Both blocks share the (i_0, j_0, i_1) loops and each has
    # its own innermost j_1 loop.
    # NOTE(review): the name suggests this is the expected result of a
    # reverse_compute_at schedule step -- confirm against the consuming test.
    A = tir.match_buffer(a, [128, 128], "float32")
    B = tir.alloc_buffer([128, 128], "float32")
    C = tir.match_buffer(c, [128, 128], "float32")
    for i_0, j_0, i_1 in tir.grid(8, 8, 16):
        for j_1 in tir.serial(0, 16):
            with tir.block([128, 128], "B") as [vi, vj]:
                tir.bind(vi, i_0 * 16 + i_1)
                tir.bind(vj, j_0 * 16 + j_1)
                B[vi, vj] = A[vi, vj] * 2.0
        for j_1 in tir.serial(0, 16):
            with tir.block([128, 128], "C") as [vi, vj]:
                tir.bind(vi, i_0 * 16 + i_1)
                tir.bind(vj, j_0 * 16 + j_1)
                C[vi, vj] = B[vi, vj] + 1.0
def elementwise_under_loop(a: ty.handle, c: ty.handle) -> None:
    # B = A * 2.0 followed by C = B + 1.0 on 128x128 buffers. Both blocks sit
    # under the same outer loop i (each binds vi to i) and each has its own
    # j loop.
    A = tir.match_buffer(a, (128, 128))
    C = tir.match_buffer(c, (128, 128))
    B = tir.alloc_buffer((128, 128))
    for i in tir.serial(0, 128):
        for j in tir.serial(0, 128):
            with tir.block([128, 128], "B") as [vi, vj]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                B[vi, vj] = A[vi, vj] * 2.0
        for j in tir.serial(0, 128):
            with tir.block([128, 128], "C") as [vi, vj]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                C[vi, vj] = B[vi, vj] + 1.0
def elementwise_dependent_loop(a: ty.handle, b: ty.handle) -> None:
    # B = A * 2.0 on 4-D buffers, where the extent of loop k (and of the
    # third block axis) is the outer loop variable i rather than a constant.
    A = tir.match_buffer(a, (128, 128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128, 128))
    for i in tir.serial(0, 128):
        # k's extent depends on i.
        for j, k, l in tir.grid(128, i, 128):
            with tir.block([128, 128, i, 128], "B") as [vi, vj, vk, vl]:
                B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
def read_out_of_bound_after_compute_at(a: ty.handle, c: ty.handle) -> None:
    # Same computation as a copy B = A followed by
    # C[v] = max(B[v], B[v + 1]) (or B[v] when v >= 15), but with block "B"
    # placed under the j loop of block "C".
    # NOTE(review): the name suggests this is the expected result of a
    # compute_at schedule step -- confirm against the consuming test.
    A = tir.match_buffer(a, [16], "float32")
    B = tir.alloc_buffer([16], "float32")
    C = tir.match_buffer(c, [16], "float32")
    for j in tir.serial(0, 16):
        # Second argument is min(1, 15 - j) + 1: 2 for j <= 14, 1 for j == 15,
        # so v = j + i never exceeds 15.
        for i in tir.serial(0, tir.min(1, 15 - j) + 1):
            with tir.block([16], "B") as [v]:
                tir.bind(v, j + i)
                B[v] = A[v]
        with tir.block([16], "C") as [v]:
            tir.bind(v, j)
            tir.reads([B[v:v + 2]])
            C[v] = tir.if_then_else(v < 15, tir.max(B[v], B[v + 1]), B[v], dtype="float32")
def equal_ranked_threads(a: ty.handle, c: ty.handle) -> None:
    # B = A * 2.0 followed by C = B + 1.0 with B allocated with
    # scope="shared". Loop i_o (extent 16) is bound to "threadIdx.x" and
    # i_i (extent 8) to "threadIdx.y"; both blocks bind vi to i_o * 8 + i_i.
    # Block "C" indexes B and C as [vj, vi], swapped relative to block "B".
    A = tir.match_buffer(a, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    B = tir.alloc_buffer([128, 128], scope="shared")
    for i_o in tir.thread_binding(0, 16, thread="threadIdx.x"):
        for i_i in tir.thread_binding(0, 8, thread="threadIdx.y"):
            for j in tir.serial(0, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    tir.bind(vi, i_o * 8 + i_i)
                    tir.bind(vj, j)
                    B[vi, vj] = A[vi, vj] * 2.0
            for j in tir.serial(0, 128):
                with tir.block([128, 128], "C") as [vi, vj]:
                    tir.bind(vi, i_o * 8 + i_i)
                    tir.bind(vj, j)
                    C[vj, vi] = B[vj, vi] + 1.0
def multiple_reduction_blocks(a: ty.handle, f: ty.handle) -> None:
    # Chain of four reductions over the last axis of A: C -> D -> E -> F.
    # Each block zero-initialises its output in tir.init() and accumulates
    # A over a reduce axis of extent 16 that is split into a 4x4 loop pair;
    # D, E and F additionally add the previous buffer's value on every step.
    # Blocks "C"/"D" sit under (i, j1); blocks "E"/"F" sit under (i, j2).
    A = tir.match_buffer(a, (16, 16, 16))
    C = tir.alloc_buffer((16, 16))
    D = tir.alloc_buffer((16, 16))
    E = tir.alloc_buffer((16, 16))
    F = tir.match_buffer(f, (16, 16))
    for i in tir.serial(0, 16):
        for j1 in tir.serial(0, 16):
            for k1o, k1i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "C") as [ci, cj, ck]:
                    tir.bind(ci, i)
                    tir.bind(cj, j1)
                    tir.bind(ck, k1o * 4 + k1i)
                    with tir.init():
                        C[ci, cj] = 0.0
                    C[ci, cj] = C[ci, cj] + A[ci, cj, ck]
            for k2o, k2i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "D") as [di, dj, dk]:
                    tir.bind(di, i)
                    tir.bind(dj, j1)
                    tir.bind(dk, k2o * 4 + k2i)
                    with tir.init():
                        D[di, dj] = 0.0
                    D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj]
        for j2 in tir.serial(0, 16):
            for k3o, k3i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "E") as [ei, ej, ek]:
                    tir.bind(ei, i)
                    tir.bind(ej, j2)
                    tir.bind(ek, k3o * 4 + k3i)
                    with tir.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej]
            for k4o, k4i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "F") as [fi, fj, fk]:
                    tir.bind(fi, i)
                    tir.bind(fj, j2)
                    tir.bind(fk, k4o * 4 + k4i)
                    with tir.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj]
def unsupported_function_call(a: ty.handle) -> None:  # body assigns 0.0 to each A[i, j] of a 16x16 buffer
    A = tir.match_buffer(a, (16, 16), "float32")
    tir.attr(A, "realize_scope", "")
    tir.realize(A[0:16, 0:16])
    for i in tir.const_range(16):  # tir.const_range is used nowhere else in this file; NOTE(review): presumably the unsupported call named by the function
        for j in tir.serial(0, 16):
            A[i, j] = 0.0
def elementwise_non_single_branch(a: ty.handle, b: ty.handle) -> None:
    # C = A * 2.0 then B = C * 2.0 on 128x128x128 buffers. The (i, j) loop
    # nest has two child k loops (one per block), so the loop tree branches
    # below j.
    A = tir.match_buffer(a, (128, 128, 128))
    C = tir.alloc_buffer((128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        for k in tir.serial(0, 128):
            with tir.block([128, 128, 128], "C") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                C[vi, vj, vk] = A[vi, vj, vk] * 2.0
        for k in tir.serial(0, 128):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                B[vi, vj, vk] = C[vi, vj, vk] * 2.0
def transformed_element_func(a: ty.handle, c: ty.handle) -> None:
    # C = (A + 1.0) * 2.0 on 16x16 buffers, processed one row (i_0) at a
    # time. The intermediate B is allocated inside the per-row opaque block,
    # whose read/write regions are row i_0 of A and of C.
    # NOTE(review): the outer loop uses plain range(0, 16) rather than
    # tir.serial -- confirm that is intended. The name suggests this is the
    # expected output of a transform.
    A = tir.match_buffer(a, [16, 16])
    C = tir.match_buffer(c, [16, 16])
    for i_0 in range(0, 16):
        with tir.block([]):
            tir.reads([A[i_0, 0:16]])
            tir.writes([C[i_0, 0:16]])
            B = tir.alloc_buffer([16, 16])
            for j_0 in tir.serial(0, 16):
                with tir.block([16, 16], "") as [i, j]:
                    tir.bind(i, i_0)
                    tir.bind(j, j_0)
                    B[i, j] = A[i, j] + 1.0
            for j_0 in tir.serial(0, 16):
                with tir.block([16, 16], "") as [i, j]:
                    tir.bind(i, i_0)
                    tir.bind(j, j_0)
                    C[i, j] = B[i, j] * 2.0
def elementwise_fused(a: ty.handle, b: ty.handle) -> None:
    # B = A * 2.0 on 128x128x128 buffers, driven by one loop of extent
    # 2097152 (= 128 * 128 * 128). The three block variables are recovered
    # from the fused index with floordiv/floormod (16384 = 128 * 128).
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for fused in tir.serial(0, 2097152):
        with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
            tir.bind(vi, tir.floordiv(fused, 16384))
            tir.bind(vj, tir.floormod(tir.floordiv(fused, 128), 128))
            tir.bind(vk, tir.floormod(fused, 128))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def opaque_access_func() -> None:
    # For each of 8 chunks of 128 elements, issues the extern call "test"
    # with B.data and A.data plus offset/length arguments. Both buffers are
    # allocated at function scope and are only touched through .data, so the
    # block declares its read/write regions explicitly.
    A = tir.alloc_buffer([1024])
    B = tir.alloc_buffer([1024])
    for i in tir.serial(0, 8):
        with tir.block([8]) as [v]:
            tir.bind(v, i)
            tir.reads([A[v * 128 : v * 128 + 128]])
            tir.writes([B[v * 128 : v * 128 + 128]])
            tir.evaluate(
                tir.call_extern("test", B.data, v * 128, 128, A.data, v * 128, 128, dtype="float32")
            )
def cuda_matmul_1(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # pylint: disable=undefined-loop-variable
    # 2048x2048 matmul fixture. A and B are staged through "shared" and then
    # "local" scoped copies by four top-level blocks; the reduction
    # accumulates A_shared_local[vk, vi] * B_shared_local[vk, vj] into
    # C_local, which is then copied out to C.
    # Row/column indices are built as block * 64 + vthread * 32 + thread * 4
    # + inner (extents 32 / 2 / 8 / 4); the reduce index is k_0 * 8 + k_1.
    # NOTE(review): the loop variable `ty` shadows the `ty` name used in the
    # signature annotations -- confirm this is harmless for the script parser.
    A = tir.match_buffer(a, [2048, 2048], "float32")
    B = tir.match_buffer(b, [2048, 2048], "float32")
    C = tir.match_buffer(c, [2048, 2048], "float32")
    A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    with tir.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with tir.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with tir.block([2048, 2048], "A_shared_local") as [v0, v1]:
        A_shared_local[v0, v1] = A_shared[v0, v1]
    with tir.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    for by in tir.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in tir.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in tir.thread_binding(0, 2, thread="vthread.y"):
                for vx in tir.thread_binding(0, 2, thread="vthread.x"):
                    for ty in tir.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in tir.thread_binding(0, 8, thread="threadIdx.x"):
                            for k_0 in tir.serial(0, 256):
                                for k_1 in tir.unroll(0, 8):
                                    for _, i, j in tir.grid(1, 4, 4):
                                        with tir.block([
                                                2048, 2048,
                                                tir.reduce_axis(0, 2048)
                                        ], "C") as [vi, vj, vk]:
                                            tir.bind(
                                                vi,
                                                by * 64 + vy * 32 + ty * 4 + i)
                                            tir.bind(
                                                vj,
                                                bx * 64 + vx * 32 + tx * 4 + j)
                                            tir.bind(vk, k_0 * 8 + k_1)
                                            with tir.init():
                                                C_local[vi, vj] = 0.0
                                            C_local[vi, vj] = C_local[
                                                vi, vj] + A_shared_local[
                                                    vk, vi] * B_shared_local[
                                                        vk, vj]
                            # Write-back loop is a sibling of k_0 under tx: it
                            # uses by/vy/ty and bx/vx/tx but no k variable.
                            for i, j in tir.grid(4, 4):
                                with tir.block([2048, 2048], "C_local") as [vi, vj]:
                                    tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i)
                                    tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[vi, vj] = C_local[vi, vj]
def rowsum_not_serial(a: ty.handle, b: ty.handle) -> None:
    # Row sum B[vi] = sum over vk of A[vi, vk] on a 128x128 input. The loop k
    # that drives the reduce axis vk is declared with tir.parallel rather
    # than tir.serial.
    # NOTE(review): per the function name, the non-serial reduction loop is
    # presumably the property under test -- confirm with the consuming test.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))
    for i in tir.serial(0, 128):
        for k in tir.parallel(0, 128):
            with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
                tir.bind(vi, i)
                tir.bind(vk, k)
                with tir.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
def elementwise_with_anno(a: ty.handle, b: ty.handle) -> None:
    # B = A * 2.0 on 128x128x128 buffers. The innermost loop k carries the
    # loop annotation {"useless_annotation": True}.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        for k in tir.serial(0, 128, annotations={"useless_annotation": True}):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def elementwise_symbolic_fused(a: ty.handle, b: ty.handle, n: ty.int32) -> None:
    # B = A * 2.0 on (128, 128, n) buffers where n is a function parameter.
    # A single loop of extent n * 16384 (16384 = 128 * 128) drives the block;
    # the three block variables are recovered with floordiv/floormod by n.
    A = tir.match_buffer(a, (128, 128, n))
    B = tir.match_buffer(b, (128, 128, n))
    for i_j_k_fused in tir.serial(0, (n * 16384)):
        with tir.block([128, 128, n], "B") as [vi, vj, vk]:
            tir.bind(vi, tir.floordiv(i_j_k_fused, (n * 128)))
            tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, n), 128))
            tir.bind(vk, tir.floormod(i_j_k_fused, n))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def elementwise_with_starting_point(a: ty.handle, b: ty.handle) -> None:
    # B = A * 2.0 on 128x128x128 buffers. The innermost loop is declared as
    # tir.serial(10, 128), i.e. its first argument is 10 rather than 0.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        for k in tir.serial(10, 128):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def test_tir_fma(A: ty.handle, B: ty.handle, C: ty.handle, d: ty.handle) -> None:
    # Element-wise d = A * B + C over n elements, written with explicit
    # tir.load calls and .data stores. Each buffer has its own symbolic
    # stride, which is multiplied into the index.
    # Note that the "global_symbol" attribute is "test_fma", not the Python
    # function name.
    # function attr dict
    tir.func_attr({"global_symbol": "test_fma", "tir.noalias": True})
    n = tir.var("int32")
    stride = tir.var("int32")
    stride_1 = tir.var("int32")
    stride_2 = tir.var("int32")
    stride_3 = tir.var("int32")
    A_1 = tir.match_buffer(
        A,
        [n],
        strides=[stride],
        elem_offset=0,
        align=128,
        offset_factor=1,
        type="auto",
    )
    B_1 = tir.match_buffer(
        B,
        [n],
        strides=[stride_1],
        elem_offset=0,
        align=128,
        offset_factor=1,
        type="auto",
    )
    C_1 = tir.match_buffer(
        C,
        [n],
        strides=[stride_2],
        elem_offset=0,
        align=128,
        offset_factor=1,
        type="auto",
    )
    d_1 = tir.match_buffer(
        d,
        [n],
        strides=[stride_3],
        elem_offset=0,
        align=128,
        offset_factor=1,
        type="auto",
    )
    # body
    for i in tir.serial(0, n):
        # The product is parenthesised first, then C's element is added.
        d_1.data[(i * stride_3)] = (tir.load("float32", A_1.data, (i * stride)) *
                                    tir.load("float32", B_1.data, (i * stride_1))) + tir.load(
                                        "float32", C_1.data, (i * stride_2))
def elementwise_with_loops_not_same_scope(a: ty.handle, b: ty.handle) -> None:
    # B = A * 2.0 on 128x128x128 buffers, split across two nested blocks:
    # outer block "A" binds (vi, vj) to the (i, j) loops, and the k loop plus
    # inner block "B" (which binds vk) live inside that outer block. The k
    # loop is therefore separated from the i/j loops by a block boundary.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        with tir.block([128, 128], "A") as [vi, vj]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            for k in tir.serial(0, 128):
                with tir.block([128], "B") as [vk]:
                    tir.bind(vk, k)
                    tir.reads([A[vi, vj, vk]])
                    tir.writes([B[vi, vj, vk]])
                    B[vi, vj, vk] = A[vi, vj, vk] * 2.0