def cache_write_under_scope(b: ty.handle, c: ty.handle) -> None:
    # TVMScript fixture: inside the [8, 8] "scope" block, A and B are produced
    # through staging buffers (A_local -> A_global, A_global -> B_global -> B).
    # After "scope", A_global is copied into A, and C is computed from A.
    # NOTE(review): reads like the expected output of cache_write applied to
    # blocks nested under "scope" -- inferred from the buffer names; confirm
    # against the test that consumes this fixture.
    A = tir.alloc_buffer((128, 128))
    B = tir.match_buffer(b, (128, 128))
    C = tir.match_buffer(c, (128, 128))
    # Allocated at the root, written inside "scope", and only copied back
    # into A after "scope" finishes (see the "A_global" block below).
    A_global = tir.alloc_buffer((128, 128))
    with tir.block([8, 8], "scope") as [i, j]:
        # Staging buffers whose lifetime is limited to "scope".
        A_local = tir.alloc_buffer((128, 128), scope="local")
        B_global = tir.alloc_buffer((128, 128))
        # Each stage below covers the 16x16 tile starting at (i * 16, j * 16).
        for x, y in tir.grid(16, 16):
            with tir.block([128, 128], "A_local") as [vi, vj]:
                tir.bind(vi, i * 16 + x)
                tir.bind(vj, j * 16 + y)
                A_local[vi, vj] = 1.0
        for x, y in tir.grid(16, 16):
            with tir.block([128, 128], "A") as [vi, vj]:
                tir.bind(vi, i * 16 + x)
                tir.bind(vj, j * 16 + y)
                A_global[vi, vj] = A_local[vi, vj]
        for x, y in tir.grid(16, 16):
            with tir.block([128, 128], "B_global") as [vi, vj]:
                tir.bind(vi, i * 16 + x)
                tir.bind(vj, j * 16 + y)
                B_global[vi, vj] = A_global[vi, vj] + 1.0
        # Write-back of the staged tile into the output buffer B.
        # NOTE(review): this block reuses the name "B_global" from the previous
        # stage -- confirm the duplicate block name is intended.
        for x, y in tir.grid(16, 16):
            with tir.block([128, 128], "B_global") as [vi, vj]:
                tir.bind(vi, i * 16 + x)
                tir.bind(vj, j * 16 + y)
                B[vi, vj] = B_global[vi, vj]
    # Copy the staged A back into A, then compute C from it.
    with tir.block([128, 128], "A_global") as [vi, vj]:
        A[vi, vj] = A_global[vi, vj]
    with tir.block([128, 128], "C") as [vi, vj]:
        C[vi, vj] = A[vi, vj] * 2.0
def transformed_func() -> None:
    # TVMScript fixture: A is zero-filled, then a [32, 32] spatial x 32-step
    # reduction block updates 4x4 tiles of B.
    # - B is allocated at the reduction block.
    # - C is allocated inside the opaque ii/jj block.
    # - D is allocated inside the innermost opaque kk block.
    # NOTE(review): this matches `original_func` with allocations moved next
    # to their uses -- presumably the expected output of an allocation-
    # placement pass; confirm against the test.
    A = tir.alloc_buffer([128, 128])
    with tir.block([128, 128], "") as [i, j]:
        A[i, j] = tir.float32(0)
    with tir.block([32, 32, tir.reduce_axis(0, 32)], "") as [i, j, k]:
        B = tir.alloc_buffer([128, 128])
        # On the first reduction step, seed the 4x4 tile of B from A.
        if k == 0:
            for ii, jj in tir.grid(4, 4):
                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
        for ii, jj in tir.grid(4, 4):
            # Opaque block with explicit regions: reads and writes one element of B.
            with tir.block([], ""):
                tir.reads([B[((i * 4) + ii), ((j * 4) + jj)]])
                tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
                # C is only read in this function (never written).
                C = tir.alloc_buffer([128, 128])
                for kk in tir.serial(0, 4):
                    B[((i * 4) + ii), ((j * 4) + jj)] = (B[((i * 4) + ii), ((j * 4) + jj)] + C[((i * 4) + ii), ((k * 4) + kk)])
                for kk in tir.serial(0, 4):
                    with tir.block([], ""):
                        tir.reads([
                            B[((i * 4) + ii), ((j * 4) + jj)],
                            C[((i * 4) + ii), ((k * 4) + kk)],
                        ])
                        tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
                        # D is only read in this function (never written).
                        D = tir.alloc_buffer([128, 128])
                        B[((i * 4) + ii), ((j * 4) + jj)] = B[((i * 4) + ii), ((j * 4) + jj)] + (D[((j * 4) + jj), ( (k * 4) + kk)] * C[((i * 4) + ii), ((k * 4) + kk)])
def compacted_complex_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None:
    # TVMScript fixture with shrunk intermediate buffers:
    # - B is (1, 8) because it is only ever indexed at row 0 here.
    # - D is (6, 1) because it is indexed by k - 2 for k in [2, 8), i.e. rows
    #   0..5, and a single column.
    # NOTE(review): `n` is not used in the body; presumably it is kept so the
    # signature matches its uncompacted counterpart -- confirm.
    A = tir.match_buffer(a, (8, 8), "float32")
    C = tir.match_buffer(c, (8, 8), "float32")
    for i in range(0, 8):
        with tir.block([]):
            # NOTE(review): A[0, 8] / C[0, 8] are single-point regions with
            # column index 8, which lies outside an (8, 8) buffer. Possibly a
            # range such as [0, 0:8] was meant, but the paired test may depend
            # on this exact form -- confirm before changing.
            tir.reads(A[0, 8])
            tir.writes(C[0, 8])
            B = tir.alloc_buffer((1, 8), "float32")
            for j in range(0, 4):
                with tir.block([]) as []:
                    D = tir.alloc_buffer((6, 1), "float32")
                    tir.reads(A[i, j])
                    tir.writes(B[0, j])
                    for k in range(4, 8):
                        D[k - 2, 0] = 1.0
                    # Opaque store through B's raw data pointer at flat offset j.
                    for k in range(2, 4):
                        tir.store(B.data, j, A[i, j] + D[k - 2, 0])
            # Copy columns 3..4 and 6..7 of B into row i of C.
            for j in range(3, 5):
                with tir.block([]) as []:
                    tir.reads(B[0, j])
                    tir.writes(C[i, j])
                    C[i, j] = B[0, j]
            for j in range(6, 8):
                with tir.block([]) as []:
                    tir.reads(B[0, j])
                    tir.writes(C[i, j])
                    C[i, j] = B[0, j]
def multiple_reduction_blocks_rfactor(a: ty.handle, f: ty.handle) -> None:
    # TVMScript fixture: a chain of four 16x16 reductions over the last axis
    # of A (C -> D -> E -> F). The C reduction is split through an rfactor
    # buffer C_rf:
    # - C_rf[ci, cj, ko] accumulates A[ci, cj, ko * 4 : ko * 4 + 4].
    # - Block "C" then sums the 4 partials along ko.
    A = tir.match_buffer(a, [16, 16, 16])
    C = tir.alloc_buffer([16, 16])
    D = tir.alloc_buffer([16, 16])
    E = tir.alloc_buffer([16, 16])
    F = tir.match_buffer(f, [16, 16])
    C_rf = tir.alloc_buffer([16, 16, 4])
    # Partial reductions: vk1o is a spatial axis of C_rf, vk1i is reduced.
    for i, j1, k1o, k1i in tir.grid(16, 16, 4, 4):
        with tir.block([4, 16, 16, tir.reduce_axis(0, 4)], "C_rf") as [vk1o, ci, cj, vk1i]:
            tir.bind(vk1o, k1o)
            tir.bind(ci, i)
            tir.bind(cj, j1)
            tir.bind(vk1i, k1i)
            with tir.init():
                C_rf[ci, cj, vk1o] = 0.0
            C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)]
    for i_1 in tir.serial(0, 16):
        for j1_1 in tir.serial(0, 16):
            # Final reduction of the 4 rfactor partials into C.
            for k1o_1 in tir.serial(0, 4):
                with tir.block([tir.reduce_axis(0, 4), 16, 16], "C") as [vk1o_1, ci_1, cj_1]:
                    tir.bind(vk1o_1, k1o_1)
                    tir.bind(ci_1, i_1)
                    tir.bind(cj_1, j1_1)
                    with tir.init():
                        C[ci_1, cj_1] = 0.0
                    C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1]
            # D reduces A over its last axis; C[di, dj] is added on every one
            # of the 16 reduction steps.
            for k2o, k2i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "D") as [di, dj, dk]:
                    tir.bind(di, i_1)
                    tir.bind(dj, j1_1)
                    tir.bind(dk, (k2o * 4) + k2i)
                    with tir.init():
                        D[di, dj] = 0.0
                    D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj]
        for j2 in tir.serial(0, 16):
            # E and F follow the same pattern, consuming D and E respectively.
            for k3o, k3i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "E") as [ei, ej, ek]:
                    tir.bind(ei, i_1)
                    tir.bind(ej, j2)
                    tir.bind(ek, (k3o * 4) + k3i)
                    with tir.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej]
            for k4o, k4i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "F") as [fi, fj, fk]:
                    tir.bind(fi, i_1)
                    tir.bind(fj, j2)
                    tir.bind(fk, (k4o * 4) + k4i)
                    with tir.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]
def func_multi_producer() -> None:
    # TVMScript fixture: buffer A has two producer blocks ("A0" and "A1"),
    # each writing every element; block "B" then reads A.
    A = tir.alloc_buffer((128))
    B = tir.alloc_buffer((128))
    with tir.block([128], "A0") as [vi]:
        A[vi] = 1.0
    # Overwrites every element written by "A0".
    with tir.block([128], "A1") as [vi]:
        A[vi] = 2.0
    with tir.block([128], "B") as [vi]:
        B[vi] = A[vi]
def fail_multi_reader_writer(a: ty.handle, d: ty.handle) -> None:
    # TVMScript fixture: block "B" writes two buffers (B and C) from A, and
    # block "C" reads both of them to produce D.
    # NOTE(review): the "fail_" prefix suggests a schedule primitive is
    # expected to reject this IR (e.g. because "B" has more than one write)
    # -- confirm in the consuming test.
    A = tir.match_buffer(a, (128, 128))
    B = tir.alloc_buffer((128, 128))
    C = tir.alloc_buffer((128, 128))
    D = tir.match_buffer(d, (128, 128))
    with tir.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
        C[vi, vj] = A[vi, vj] + 2.0
    with tir.block([128, 128], "C") as [vi, vj]:
        D[vi, vj] = B[vi, vj] + C[vi, vj]
def opaque_access_func() -> None:
    # TVMScript fixture: each of 8 blocks moves a 128-element window of A into
    # B through an extern call on the raw data pointers (an opaque access).
    A = tir.alloc_buffer([1024])
    B = tir.alloc_buffer([1024])
    for i in tir.serial(0, 8):
        with tir.block([8]) as [v]:
            tir.bind(v, i)
            # Regions are declared explicitly; the actual accesses happen
            # inside the extern call below.
            tir.reads([A[v * 128 : v * 128 + 128]])
            tir.writes([B[v * 128 : v * 128 + 128]])
            # Arguments: destination (B.data, offset, length), then source
            # (A.data, offset, length). The argument order is interpreted
            # from the reads/writes above -- confirm with the test.
            tir.evaluate(
                tir.call_extern("test", B.data, v * 128, 128, A.data, v * 128, 128, dtype="float32")
            )
def cuda_matmul_1(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # pylint: disable=undefined-loop-variable
    # TVMScript fixture: 2048x2048 matmul on GPU thread bindings.
    # - The block body computes C[i, j] = sum_k A[k, i] * B[k, j]; A is read
    #   transposed via A_shared_local[vk, vi].
    # - The result is accumulated in C_local, then written to C.
    # - A and B are staged through *_shared and *_shared_local copies, each
    #   made by a full 2048x2048 block at the root.
    A = tir.match_buffer(a, [2048, 2048], "float32")
    B = tir.match_buffer(b, [2048, 2048], "float32")
    C = tir.match_buffer(c, [2048, 2048], "float32")
    A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    with tir.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with tir.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with tir.block([2048, 2048], "A_shared_local") as [v0, v1]:
        A_shared_local[v0, v1] = A_shared[v0, v1]
    with tir.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    # Each output row/column index decomposes as
    #   block (32 x 64) + vthread (2 x 32) + thread (8 x 4) + inner (4)
    # = 2048 in total; k decomposes as k_0 (256) x k_1 (8).
    # NOTE(review): the loop variable `ty` shadows the `ty` name used in the
    # signature annotations.
    for by in tir.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in tir.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in tir.thread_binding(0, 2, thread="vthread.y"):
                for vx in tir.thread_binding(0, 2, thread="vthread.x"):
                    for ty in tir.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in tir.thread_binding(0, 8, thread="threadIdx.x"):
                            for k_0 in tir.serial(0, 256):
                                for k_1 in tir.unroll(0, 8):
                                    for _, i, j in tir.grid(1, 4, 4):
                                        with tir.block([
                                            2048, 2048, tir.reduce_axis(0, 2048)
                                        ], "C") as [vi, vj, vk]:
                                            tir.bind(
                                                vi, by * 64 + vy * 32 + ty * 4 + i)
                                            tir.bind(
                                                vj, bx * 64 + vx * 32 + tx * 4 + j)
                                            tir.bind(vk, k_0 * 8 + k_1)
                                            with tir.init():
                                                C_local[vi, vj] = 0.0
                                            C_local[vi, vj] = C_local[
                                                vi, vj] + A_shared_local[
                                                    vk, vi] * B_shared_local[
                                                        vk, vj]
                            # Write this thread's 4x4 accumulator tile back to C.
                            for i, j in tir.grid(4, 4):
                                with tir.block([2048, 2048], "C_local") as [vi, vj]:
                                    tir.bind(vi, by * 64 + vy * 32 + ty * 4 + i)
                                    tir.bind(vj, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[vi, vj] = C_local[vi, vj]
def compacted_multi_alloc_func(a: ty.handle, d: ty.handle) -> None:
    # TVMScript fixture: for each i, an opaque block allocates two temporaries
    # B and C and computes D[i] = (A[i] + (A[i] + 1)) * 2 through them.
    # NOTE(review): B and C keep their full (32,) extent here despite the
    # "compacted" name -- confirm this is the intended expected output.
    A = tir.match_buffer(a, (32), "float32")
    D = tir.match_buffer(d, (32), "float32")
    for i in range(0, 32):
        with tir.block([]) as []:
            tir.reads(A[i])
            tir.writes(D[i])
            B = tir.alloc_buffer((32, ))
            C = tir.alloc_buffer((32, ))
            B[i] = A[i] + 1.0
            C[i] = A[i] + B[i]
            D[i] = C[i] * 2.0
def continuous_cache_write(a: ty.handle, c: ty.handle) -> None:
    # TVMScript fixture: B's value (A * 2) is written to B_local, copied to
    # B_shared, then copied to B. C reads B.
    # NOTE(review): all three stages are named "B" -- presumably the result
    # of two successive cache_write steps on block "B"; confirm the duplicate
    # block names are intended.
    A = tir.match_buffer(a, (128, 128))
    B = tir.alloc_buffer((128, 128))
    C = tir.match_buffer(c, (128, 128))
    B_shared = tir.alloc_buffer((128, 128), scope="shared")
    B_local = tir.alloc_buffer((128, 128), scope="local")
    with tir.block([128, 128], "B") as [vi, vj]:
        B_local[vi, vj] = A[vi, vj] * 2.0
    with tir.block([128, 128], "B") as [vi, vj]:
        B_shared[vi, vj] = B_local[vi, vj]
    with tir.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = B_shared[vi, vj]
    with tir.block([128, 128], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
def cache_write_elementwise(a: ty.handle, c: ty.handle) -> None:
    # TVMScript fixture: elementwise B = A * 2, C = B + 1, where each
    # producer first writes a cache buffer that is then copied to the real
    # output.
    # NOTE(review): the buffer names do not match their scopes. `B_global` is
    # allocated with scope="local", while `C_local` uses the default scope.
    # Structural comparison presumably ignores buffer names, but confirm the
    # naming is intentional.
    A = tir.match_buffer(a, (128, 128))
    C = tir.match_buffer(c, (128, 128))
    B = tir.alloc_buffer((128, 128))
    B_global = tir.alloc_buffer((128, 128), scope="local")
    C_local = tir.alloc_buffer((128, 128))
    with tir.block([128, 128], "B_global") as [vi, vj]:
        B_global[vi, vj] = A[vi, vj] * 2.0
    with tir.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = B_global[vi, vj]
    with tir.block([128, 128], "C_local") as [vi, vj]:
        C_local[vi, vj] = B[vi, vj] + 1.0
    with tir.block([128, 128], "C") as [vi, vj]:
        C[vi, vj] = C_local[vi, vj]
def fail_all_producers_under_loop(a: ty.handle, d: ty.handle) -> None:
    # TVMScript fixture: producers B and C and consumer D each sit under their
    # own 128x128 loop nest.
    # NOTE(review): the "fail_" prefix suggests a schedule primitive is
    # expected to reject this IR -- confirm in the consuming test.
    # NOTE(review): no block binds vi/vj to i/j explicitly; this relies on the
    # parser's default binding of block vars to the enclosing loops.
    A = tir.match_buffer(a, (128, 128), "float32")
    B = tir.alloc_buffer((128, 128), "float32")
    C = tir.alloc_buffer((128, 128), "float32")
    D = tir.match_buffer(d, (128, 128), "float32")
    for i, j in tir.grid(128, 128):
        with tir.block([128, 128], "B") as [vi, vj]:
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in tir.grid(128, 128):
        with tir.block([128, 128], "C") as [vi, vj]:
            C[vi, vj] = A[vi, vj] + 1.0
    for i, j in tir.grid(128, 128):
        with tir.block([128, 128], "D") as [vi, vj]:
            D[vi, vj] = B[vi, vj] + C[vi, vj]
def match_buffer_func() -> None:
    # TVMScript fixture with nested match_buffer:
    # - C0 views row vi of the root-allocated C as a 1-D buffer of 128 elements.
    # - C1 views the single element C0[jj] as a 0-d buffer, which is then set to 0.
    C = tir.alloc_buffer((128, 128))
    with tir.block([128]) as [vi]:
        C0 = tir.match_buffer(C[vi, 0:128], (128))
        with tir.block([128]) as [jj]:
            C1 = tir.match_buffer(C0[jj], ())
            C1[()] = 0
def opaque_block_func() -> None:
    # TVMScript fixture: nested opaque blocks (no block iter vars) computing
    # B = A + 1 row by row, all under an explicit "root" block that owns the
    # allocations of A and B.
    with tir.block([], "root"):
        A = tir.alloc_buffer((16, 16), "float32")
        B = tir.alloc_buffer((16, 16), "float32")
        tir.reads([])
        tir.writes([])
        # Read/write regions are given manually so that the block access
        # region detector is not triggered.
        for i in range(0, 16):
            with tir.block([]):
                tir.reads(A[i, 0:16])
                tir.writes([B[i, 0:16]])
                for j in range(0, 16):
                    with tir.block([]):
                        tir.reads(A[i, j])
                        tir.writes(B[i, j])
                        B[i, j] = A[i, j] + 1.0
def func_with_bufferslice_indices(data: ty.handle, index: ty.handle) -> None:
    # TVMScript fixture: a buffer load whose column index is itself a buffer
    # load (index_buf[0]), i.e. a data-dependent (gather) access.
    data_buf = tir.match_buffer(data, (16, 16), "float32")
    index_buf = tir.match_buffer(index, (1, ), "int32")
    out_buf = tir.alloc_buffer((16, 16), "float32")
    with tir.block([16, 16]) as [vi, vj]:
        out_buf[vi, vj] = data_buf[vi, index_buf[0]]
def element_wise_invalid_annotation(a: ty.handle, c: ty.handle) -> None:
    # TVMScript fixture: elementwise B = A * 2, C = B + 1, with block "C"
    # nested under B's outer loop.
    # Block "B" carries a "buffer_dim_align" annotation whose value is the
    # bare list [0].
    # NOTE(review): per the function name this annotation is deliberately
    # invalid, presumably so a pass/verifier rejects it -- confirm.
    C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with tir.block([], "root"):
        tir.reads([])
        tir.writes([])
        B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
        for i0 in tir.serial(0, 128):
            for ax1 in tir.serial(0, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    tir.block_attr({"buffer_dim_align": [0]})
                    tir.bind(vi, i0)
                    tir.bind(vj, ax1)
                    tir.reads([A[vi, vj]])
                    tir.writes([B[vi, vj]])
                    B[vi, vj] = (A[vi, vj] * tir.float32(2))
            # Consumes row i0 of B right after it is produced.
            for i1 in tir.serial(0, 128):
                with tir.block([128, 128], "C") as [vi_1, vj_1]:
                    tir.bind(vi_1, i0)
                    tir.bind(vj_1, i1)
                    tir.reads([B[vi_1, vj_1]])
                    tir.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1))
def expected_recursive_bufferslice_indices(data: ty.handle, index: ty.handle) -> None:
    # TVMScript fixture: the row index of data_buf is a nested buffer load
    # (index_buf[index_buf[0]]) and the column index is index_buf[0].
    # Because both indices are data-dependent, the declared read region of
    # data_buf is the whole 16x16 buffer.
    index_buf = tir.match_buffer(index, [1], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    data_buf = tir.match_buffer(data, [16, 16], elem_offset=0, align=128, offset_factor=1)
    with tir.block([], "root"):
        tir.reads([])
        tir.writes([])
        out_buf = tir.alloc_buffer([16, 16], elem_offset=0, align=128, offset_factor=1)
        for i0, i1 in tir.grid(16, 16):
            with tir.block([16, 16], "") as [vi, vj]:
                tir.bind(vi, i0)
                tir.bind(vj, i1)
                tir.reads([data_buf[0:16, 0:16], index_buf[0]])
                tir.writes([out_buf[vi, vj]])
                out_buf[vi, vj] = data_buf[index_buf[index_buf[0]], index_buf[0]]
def buffer_load_store_func(a: ty.handle, b: ty.handle) -> None:
    # TVMScript fixture: A is zero-filled, then a [32, 32] spatial x 32-step
    # reduction block accumulates into 4x4 tiles of B.
    # - Its init copies the A tile into B.
    # - Each step adds a C term, then a D * C term, over 4 inner kk values.
    A = tir.match_buffer(a, (128, 128), "float32")
    B = tir.match_buffer(b, (128, 128), "float32")
    C = tir.alloc_buffer((128, 128), "float32")
    D = tir.alloc_buffer((128, 128), "float32")
    with tir.block([128, 128]) as [i, j]:
        A[i, j] = tir.float32(0)
    with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]:
        with tir.init():
            for ii, jj in tir.grid(4, 4):
                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
        for ii, jj in tir.grid(4, 4):
            for kk in range(0, 4):
                B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk]
            for kk in range(0, 4):
                B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk]
def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None:
    # TVMScript fixture: for each of 8 windows of 128 elements:
    # - an extern call fills the A_cache window from A (opaque access);
    # - then an elementwise loop copies that window of A_cache into B.
    # A_cache is allocated inside the outer block, i.e. scoped to it.
    A = tir.match_buffer(a, [1024])
    B = tir.match_buffer(b, [1024])
    for i in tir.serial(0, 8):
        with tir.block([8]) as [vi]:
            tir.reads(A[vi * 128:vi * 128 + 128])
            tir.writes(B[vi * 128:vi * 128 + 128])
            A_cache = tir.alloc_buffer([1024])
            with tir.block([8]) as [v]:
                tir.bind(v, vi)
                # Explicit regions: the actual access is inside the extern call.
                tir.reads([A[v * 128:v * 128 + 128]])
                tir.writes([A_cache[v * 128:v * 128 + 128]])
                tir.evaluate(
                    tir.call_extern("test", A_cache.data, v * 128, 128, A.data, v * 128, 128, dtype="float32"))
            for j in tir.serial(0, 128):
                with tir.block([1024]) as [v]:
                    tir.bind(v, ((vi * 128) + j))
                    tir.reads([A_cache[v]])
                    tir.writes([B[v]])
                    B[v] = A_cache[v]
def func_multi_consumer() -> None:
    # TVMScript fixture: A has two consumers.
    # - Block "B" reads A inside the same outer i loop as producer "A".
    # - Block "C" reads A in a separate root-level loop.
    A = tir.alloc_buffer((128))
    B = tir.alloc_buffer((128))
    C = tir.alloc_buffer((128))
    for i in tir.grid(8):
        # Both inner loops cover the 16-element chunk i * 16 .. i * 16 + 15.
        for j in tir.grid(16):
            with tir.block([128], "A") as [vi]:
                tir.bind(vi, i * 16 + j)
                A[vi] = 1.0
        for j in tir.grid(16):
            with tir.block([128], "B") as [vi]:
                tir.bind(vi, i * 16 + j)
                B[vi] = A[vi] + 1.0
    for i in tir.grid(128):
        with tir.block([128], "C") as [vi]:
            C[vi] = A[vi]
def original_func() -> None:
    # TVMScript fixture: A is zero-filled, then a [32, 32] spatial x 32-step
    # reduction block updates 4x4 tiles of B. B, C and D are all allocated at
    # the top of the reduction block. C and D are only read in this function
    # (never written).
    A = tir.alloc_buffer((128, 128), "float32")
    with tir.block([128, 128]) as [i, j]:
        A[i, j] = tir.float32(0)
    with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]:
        B = tir.alloc_buffer((128, 128), "float32")
        C = tir.alloc_buffer((128, 128), "float32")
        D = tir.alloc_buffer((128, 128), "float32")
        # Manual "init": seed the tile of B from A on the first reduction step.
        if k == 0:
            for ii, jj in tir.grid(4, 4):
                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
        for ii, jj in tir.grid(4, 4):
            for kk in range(0, 4):
                B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk]
            for kk in range(0, 4):
                B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk]
def func() -> None:
    # TVMScript fixture with hand-written access regions on an opaque block.
    # - The outer block declares writes A[0:12, 0:12]: the union of the 8x8
    #   loop (A[0:8, 0:8]) and the [2, 2] sub-block (A[4:12, 4:12]).
    # - The outer block also has an opaque use of D.data via tir.evaluate.
    A = tir.alloc_buffer((128, 128), "float32")
    B = tir.alloc_buffer((128, 128), "float32")
    C = tir.alloc_buffer((128, 128), "float32")
    D = tir.alloc_buffer((128, 128), "float32")
    with tir.block([]):
        # Read/write regions are given manually so that the block access
        # region detector is not triggered.
        tir.reads([B[0, 0], C[0:16, 0:16], A[4:12, 4:12]])
        tir.writes([A[0:12, 0:12]])
        for i, j in tir.grid(8, 8):
            A[i, j] = B[0, 0] + C[0, 0]
        # Each (vi, vj) instance updates the 4x4 tile of A at
        # (vi * 4 + 4, vj * 4 + 4) using the fixed C[12:16, 12:16] tile.
        with tir.block([2, 2]) as [vi, vj]:
            tir.reads([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8], C[12:16, 12:16]])
            tir.writes([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8]])
            for i, j in tir.grid(4, 4):
                A[vi * 4 + 4 + i, vj * 4 + 4 + j] += C[i + 12, j + 12]
        tir.evaluate(D.data)
def multiple_reduction_blocks(a: ty.handle, f: ty.handle) -> None:
    # TVMScript fixture: a chain of four 16x16 reductions over the last axis
    # of A. Each one also adds its predecessor's value on every step:
    #   C = sum A;  D = sum (A + C);  E = sum (A + D);  F = sum (A + E).
    # C and D share the j1 loop; E and F share the j2 loop; all sit under i.
    A = tir.match_buffer(a, (16, 16, 16))
    C = tir.alloc_buffer((16, 16))
    D = tir.alloc_buffer((16, 16))
    E = tir.alloc_buffer((16, 16))
    F = tir.match_buffer(f, (16, 16))
    for i in tir.serial(0, 16):
        for j1 in tir.serial(0, 16):
            # The reduction axis is split 4 x 4 in every block below.
            for k1o, k1i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "C") as [ci, cj, ck]:
                    tir.bind(ci, i)
                    tir.bind(cj, j1)
                    tir.bind(ck, k1o * 4 + k1i)
                    with tir.init():
                        C[ci, cj] = 0.0
                    C[ci, cj] = C[ci, cj] + A[ci, cj, ck]
            for k2o, k2i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "D") as [di, dj, dk]:
                    tir.bind(di, i)
                    tir.bind(dj, j1)
                    tir.bind(dk, k2o * 4 + k2i)
                    with tir.init():
                        D[di, dj] = 0.0
                    D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj]
        for j2 in tir.serial(0, 16):
            for k3o, k3i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "E") as [ei, ej, ek]:
                    tir.bind(ei, i)
                    tir.bind(ej, j2)
                    tir.bind(ek, k3o * 4 + k3i)
                    with tir.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej]
            for k4o, k4i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)], "F") as [fi, fj, fk]:
                    tir.bind(fi, i)
                    tir.bind(fj, j2)
                    tir.bind(fk, k4o * 4 + k4i)
                    with tir.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj]
def elementwise(a: ty.handle, c: ty.handle) -> None:
    # TVMScript fixture: two chained 128x128 elementwise blocks,
    # B = A * 2 followed by C = B + 1, with B as an intermediate buffer.
    A = tir.match_buffer(a, (128, 128))
    B = tir.alloc_buffer((128, 128))
    C = tir.match_buffer(c, (128, 128))
    with tir.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    with tir.block([128, 128], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
def elementwise_multi_reverse_loads(a: ty.handle, c: ty.handle) -> None:
    # TVMScript fixture: like a two-stage elementwise chain, except the
    # consumer "C" loads B[vi, vj] twice in one expression.
    A = tir.match_buffer(a, (128, 128))
    B = tir.alloc_buffer((128, 128))
    C = tir.match_buffer(c, (128, 128))
    with tir.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    with tir.block([128, 128], "C") as [vi, vj]:
        C[vi, vj] = (B[vi, vj] + 1.0) * (B[vi, vj] * 2.0) + 3.0
def transformed_match_buffer_func() -> None:
    # TVMScript fixture: the same nested match_buffer chain (C -> C0 -> C1) as
    # match_buffer_func, except:
    # - the outer block sits under an explicit loop bound via tir.bind;
    # - C is allocated inside that block rather than at the function root.
    # NOTE(review): presumably the expected output of an allocation-placement
    # pass applied to match_buffer_func -- confirm against the test.
    for i in range(0, 128):
        with tir.block([128]) as [vi]:
            tir.bind(vi, i)
            C = tir.alloc_buffer((128, 128))
            C0 = tir.match_buffer(C[vi, 0:128], (128))
            with tir.block([128]) as [jj]:
                C1 = tir.match_buffer(C0[jj], ())
                C1[()] = 0
def tir_element_wise(a: ty.handle, c: ty.handle) -> None:
    # TVMScript fixture: unnamed elementwise blocks computing B = A * 2 and
    # then C = B + 1 over 128x128.
    A = tir.match_buffer(a, (128, 128))
    C = tir.match_buffer(c, (128, 128))
    B = tir.alloc_buffer((128, 128))
    with tir.block([128, 128]) as [i, j]:
        B[i, j] = A[i, j] * 2.0
    with tir.block([128, 128]) as [i, j]:
        C[i, j] = B[i, j] + 1.0
def buffer_matched(a: ty.handle, c: ty.handle) -> None:
    # TVMScript fixture: the consumer "C" reads B through a match_buffer view.
    # Bb is a (1, 1) view of the single element B[vi, vj], so Bb[0, 0] is
    # that element.
    A = tir.match_buffer(a, (128, 128))
    B = tir.alloc_buffer((128, 128))
    C = tir.match_buffer(c, (128, 128))
    with tir.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    with tir.block([128, 128], "C") as [vi, vj]:
        Bb = tir.match_buffer(B[vi:vi + 1, vj], (1, 1))
        C[vi, vj] = Bb[0, 0] + 1.0
def elementwise_predicate(a: ty.handle, c: ty.handle) -> None:
    # TVMScript fixture: the consumer block "C" has a predicate
    # (tir.where) that reads the producer's buffer B.
    A = tir.match_buffer(a, (128, 128))
    B = tir.alloc_buffer((128, 128))
    C = tir.match_buffer(c, (128, 128))
    with tir.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    for i, j in tir.grid(128, 128):
        with tir.block([128, 128], "C") as [vi, vj]:
            # The predicate indexes B with the loop vars i, j, not the block
            # vars vi, vj.
            tir.where(B[i, j] < 10.0)
            C[vi, vj] = B[vi, vj] + 1.0
def opaque_access_load(a: ty.handle, c: ty.handle) -> None:
    # TVMScript fixture: the consumer "C" reads B through an opaque
    # tir.load on B.data rather than a buffer load. It therefore declares
    # whole-buffer read/write regions explicitly.
    A = tir.match_buffer(a, (128, 128))
    B = tir.alloc_buffer((128, 128))
    C = tir.match_buffer(c, (128, 128))
    with tir.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    with tir.block([128, 128], "C") as [vi, vj]:
        tir.reads(B[0:128, 0:128])
        tir.writes(C[0:128, 0:128])
        # Flat offset vi * 128 + vj addresses element (vi, vj) of B.
        # NOTE(review): this assumes B uses a dense row-major layout -- confirm.
        C[vi, vj] = tir.load("float32", B.data, vi * 128 + vj) + 1.0