# TVMScript fixture: B = A * 2.0 over a 128x128x128x128 iteration space.
# The block's `tir.where` condition only admits iterations whose row-major
# linearized index i*128**3 + j*128**2 + k*128 + l is below 100
# (2097152 == 128**3, 16384 == 128**2).
# NOTE(review): presumably decorated as a TVM script function outside this view;
# `ty` and `tir` come from imports not shown here.
def elementwise_predicate(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128, 128))
    for i, j, k, l in tir.grid(128, 128, 128, 128):
        with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
            # Condition is written on the loop variables, not the block vars.
            tir.where(i * 2097152 + j * 16384 + k * 128 + l < 100)
            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
# TVMScript fixture: single block "C" computing C = A * 2.0 + 1.0 over 128x128,
# guarded by a `tir.where` condition that reads buffer A (A[i, j] * 2.0 < 10.0).
# NOTE(review): the "* 2.0" in both the condition and the body suggests this is the
# expected result after inlining a producer stage (B = A * 2.0) into its consumer;
# confirm against the test that uses it.
def elementwise_predicate_inlined(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128))
    C = tir.match_buffer(c, (128, 128))
    for i, j in tir.grid(128, 128):
        with tir.block([128, 128], "C") as [vi, vj]:
            # Data-dependent condition: reads A at the loop indices.
            tir.where(A[i, j] * 2.0 < 10.0)
            C[vi, vj] = A[vi, vj] * 2.0 + 1.0
# TVMScript fixture: C = A + 1.0 on 32-element float32 buffers, iterated by a
# 5x7 loop nest (35 iterations) over the flattened index i * 7 + j.
# The `tir.where` condition masks off the 3 iterations whose index is >= 32.
# The block has no iteration variables (`tir.block([])`) and declares its
# read/write regions explicitly.
def compacted_predicate_func(a: ty.handle, c: ty.handle) -> None:
    # NOTE: `(32)` is the int 32, not a 1-tuple; presumably the script parser
    # accepts a scalar shape here — confirm if this fixture is edited.
    A = tir.match_buffer(a, (32), "float32")
    C = tir.match_buffer(c, (32), "float32")
    for i, j in tir.grid(5, 7):
        with tir.block([]) as []:
            tir.reads(A[i * 7 + j])
            tir.writes(C[i * 7 + j])
            # 5 * 7 = 35 > 32, so out-of-range iterations must be excluded.
            tir.where(i * 7 + j < 32)
            C[i * 7 + j] = A[i * 7 + j] + 1.0
# TVMScript fixture: two-stage 128x128 elementwise pipeline.
#   Block "B": B = A * 2.0 into an intermediate buffer allocated here.
#   Block "C": C = B + 1.0, guarded by a data-dependent `tir.where` on B.
# NOTE(review): this has the same name as the 4-D `elementwise_predicate` fixture
# earlier in this view. If both are in one module, this later definition rebinds
# the name. They presumably come from separate test files; confirm.
def elementwise_predicate(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128))
    B = tir.alloc_buffer((128, 128))
    C = tir.match_buffer(c, (128, 128))
    # Unlike block "C", no explicit loop nest is written around block "B".
    with tir.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    for i, j in tir.grid(128, 128):
        with tir.block([128, 128], "C") as [vi, vj]:
            # Condition reads the intermediate buffer B at the loop indices.
            tir.where(B[i, j] < 10.0)
            C[vi, vj] = B[vi, vj] + 1.0
# TVMScript fixture: the same computation and `tir.where` condition as a 4-D
# B = A * 2.0 over 128**4, but with the loop nest ordered (l, j, k, i).
# The explicit `tir.bind` calls keep vi=i, vj=j, vk=k, vl=l, so only the loop
# order changes. Presumably this is the expected output of a loop-reorder test;
# confirm.
def elementwise_reordered_with_predicate(a: ty.handle, b: ty.handle) -> None:
    A = tir.match_buffer(a, (128, 128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128, 128))
    for l, j, k, i in tir.grid(128, 128, 128, 128):
        with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
            # 2097152 == 128**3, 16384 == 128**2: row-major linear index < 100.
            tir.where(i * 2097152 + j * 16384 + k * 128 + l < 100)
            tir.bind(vi, i)
            tir.bind(vj, j)
            tir.bind(vk, k)
            tir.bind(vl, l)
            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
# TVMScript fixture: B = A * 2.0 on (128, 128, n) buffers with a symbolic last
# extent n. The k axis is split into 10 outer iterations (k0) times
# floordiv(n + 9, 10) inner iterations (k1). That is ceil(n / 10) for
# non-negative n (assumed; n is a runtime int32).
# Since 10 * ceil(n / 10) can exceed n, the `tir.where` condition masks out
# k0 * inner + k1 >= n.
def elementwise_symbolic_split(a: ty.handle, b: ty.handle, n: ty.int32) -> None:
    A = tir.match_buffer(a, (128, 128, n))
    B = tir.match_buffer(b, (128, 128, n))
    for i, j, k0, k1 in tir.grid(128, 128, 10, tir.floordiv((n + 9), 10)):
        with tir.block([128, 128, n], "B") as [vi, vj, vk]:
            tir.where((((k0 * tir.floordiv((n + 9), 10)) + k1) < n))
            tir.bind(vi, i)
            tir.bind(vj, j)
            # vk is the fused index; it matches the expression in the condition above.
            tir.bind(vk, ((k0 * tir.floordiv((n + 9), 10)) + k1))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
# TVMScript fixture: B = A * 2.0 on 128x128x128 buffers where each axis is split
# with factors whose product exceeds 128:
#   i -> (i0, i1, i2) with extents (1000, 2, 3)  => 6000 iterations
#   j -> (j0, j1)     with extents (1, 129)      => 129 iterations
#   k -> (k0, k1)     with extents (3, 43)       => 129 iterations
# The `tir.where` condition keeps only fused indices < 128 on every axis.
# Presumably this is the expected output of a split test; confirm.
def elementwise_split_with_predicate(a: ty.handle, b: ty.handle) -> None:
    B = tir.match_buffer(b, [128, 128, 128])
    A = tir.match_buffer(a, [128, 128, 128])
    for i0, i1, i2, j0, j1, k0, k1 in tir.grid(1000, 2, 3, 1, 129, 3, 43):
        with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
            # ((i0*2)+i1)*3+i2 == i0*6 + i1*3 + i2 (the vi binding below);
            # j0 has extent 1, so j0*129 + j1 == j1 (the vj binding below).
            tir.where((((((((i0 * 2) + i1) * 3) + i2) < 128) and (((j0 * 129) + j1) < 128)) and (((k0 * 43) + k1) < 128)))
            tir.bind(vi, (((i0 * 6) + (i1 * 3)) + i2))
            tir.bind(vj, j1)
            tir.bind(vk, ((k0 * 43) + k1))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
# Intentionally malformed TVMScript fixture: a single block declares two
# `tir.where` conditions. The second call is marked as the expected error site.
# Presumably a test checks that the script parser rejects this; confirm.
def duplicate_predicate() -> None:
    with tir.block([16, 16]) as [vi, vj]:
        tir.where(1)
        tir.where(0)  # error