def elementwise_predicate(a: T.handle, b: T.handle) -> None:
    # Element-wise kernel over two 128x128x128x128 buffers: B = A * 2.0,
    # with the block guarded by a predicate on the loop variables.
    A = T.match_buffer(a, (128, 128, 128, 128))
    B = T.match_buffer(b, (128, 128, 128, 128))
    for i, j, k, l in T.grid(128, 128, 128, 128):
        # No explicit T.bind is written; the block iterators rely on the
        # parser's default binding (presumably onto i, j, k, l in order --
        # TODO confirm against the TVMScript version in use).
        with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
            # 2097152 == 128**3 and 16384 == 128**2, so the left-hand side is
            # the row-major flattened index of (i, j, k, l); only flattened
            # indices 0..99 satisfy the predicate.
            T.where(i * 2097152 + j * 16384 + k * 128 + l < 100)
            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
def elementwise_predicate_inlined(a: T.handle, c: T.handle) -> None:
    # Computes C = A * 2.0 + 1.0 over 128x128 buffers; the block predicate
    # itself reads A and re-computes A * 2.0.
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block([128, 128], "C") as [vi, vj]:
            # The predicate is expressed with the loop variables (i, j), while
            # the store below uses the block iterators (vi, vj).
            T.where(A[i, j] * 2.0 < 10.0)
            C[vi, vj] = A[vi, vj] * 2.0 + 1.0
def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None:
    # 128x128 reduction C[vi, vj] += A[vi, vk] * B[vj, vk], written as a
    # zero-initialisation block followed by a separate update block.
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    with T.block("root"):
        T.reads([])
        T.writes([])
        for i0_0 in T.serial(0, 16):
            # Zero the 8 rows of C that belong to the current i0_0 tile.
            for i0_1_init in T.serial(0, 8):
                for i1_init in T.serial(0, 128):
                    with T.block("update_init"):
                        vi_init = T.axis.spatial(128, i0_0 * 8 + i0_1_init)
                        vj_init = T.axis.spatial(128, i1_init)
                        C[vi_init, vj_init] = T.float32(0)
            # The reduction axis is covered by 19 * 7 = 133 iterations, so the
            # predicate masks the 5 iterations that fall past 128.
            for i0_1, i1 in T.grid(8, 128):
                for i2_0, i2_1 in T.grid(19, 7):
                    with T.block("update_update"):
                        T.where(i2_0 * 7 + i2_1 < 128)
                        vi = T.axis.spatial(128, i0_0 * 8 + i0_1)
                        vj = T.axis.spatial(128, i1)
                        vk = T.axis.reduce(128, i2_0 * 7 + i2_1)
                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None:
    # 128x128 reduction C[vi, vj] += A[vi, vk] * B[vj, vk], written as a
    # zero-initialisation block ("update_init") followed by a separate update
    # block ("update_update").
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block([], "root"):
        T.reads([])
        T.writes([])
        for i0_0 in T.serial(0, 16):
            # Zero the 8 rows of C that belong to the current i0_0 tile.
            for i0_1_init, i1_init in T.grid(8, 128):
                with T.block([128, 128], "update_init") as [vi_init, vj_init]:
                    T.bind(vi_init, ((i0_0 * 8) + i0_1_init))
                    T.bind(vj_init, i1_init)
                    C[vi_init, vj_init] = T.float32(0)
            for i0_1, i1, i2_0, i2_1 in T.grid(8, 128, 19, 7):
                with T.block([128, 128, T.reduce_axis(0, 128)], "update_update") as [
                    vi,
                    vj,
                    vk,
                ]:
                    # 19 * 7 = 133 iterations cover the 128-long reduction
                    # axis; the predicate masks the 5 that fall past the end.
                    T.where((((i2_0 * 7) + i2_1) < 128))
                    T.bind(vi, ((i0_0 * 8) + i0_1))
                    T.bind(vj, i1)
                    T.bind(vk, ((i2_0 * 7) + i2_1))
                    C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None:
    # For each 8x8 tile of a 224x224 input, stage a 10x10 window of X in
    # `cache`, then take a 3x3 max over that window into Y, using 0.0 for
    # positions outside the 224x224 range.
    X = T.match_buffer(a, [224, 224], dtype="float32")
    Y = T.match_buffer(b, [224, 224], dtype="float32")
    cache = T.alloc_buffer([224, 224], dtype="float32")
    for hh_0 in T.serial(0, 28):
        for ww_0 in T.serial(0, 28):
            # The window starts one element before the tile (offset -1), so
            # the predicate keeps only rows/columns whose index lies in
            # [0, 224).
            for ax0, ax1 in T.grid(10, 10):
                with T.block("cache"):
                    h = T.axis.S(224, hh_0 * 8 - 1 + ax0)
                    w = T.axis.S(224, ww_0 * 8 - 1 + ax1)
                    T.where(1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225 and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225)
                    cache[h, w] = X[h, w]
            for hh_1, ww_1 in T.grid(8, 8):
                for khh, kww in T.grid(3, 3):
                    with T.block("compute"):
                        h = T.axis.S(224, hh_0 * 8 + hh_1)
                        w = T.axis.S(224, ww_0 * 8 + ww_1)
                        kh = T.axis.R(3, khh)
                        kw = T.axis.R(3, kww)
                        with T.init():
                            Y[h, w] = 0.0
                        # Window element (kh, kw) maps to cache[h + kh - 1,
                        # w + kw - 1]; out-of-range positions contribute 0.0.
                        Y[h, w] = T.max(
                            Y[h, w],
                            T.if_then_else(
                                T.likely(1 <= h + kh, dtype="bool")
                                and T.likely(h + kh < 225, dtype="bool")
                                and T.likely(1 <= w + kw, dtype="bool")
                                and T.likely(w + kw < 225, dtype="bool"),
                                cache[h + kh - 1, w + kw - 1],
                                0.0,
                                dtype="float32",
                            ),
                        )
def elementwise_predicate(a: T.handle, b: T.handle) -> None:
    # Element-wise B = A * 2.0 over 128x128x128x128 buffers, restricted by a
    # predicate on the flattened loop index.
    A = T.match_buffer(a, (128, 128, 128, 128))
    B = T.match_buffer(b, (128, 128, 128, 128))
    for i, j in T.grid(128, 128):
        for k, l in T.grid(128, 128):
            with T.block("B"):
                # 2097152 == 128**3 and 16384 == 128**2: the sum is the
                # row-major flattened index, and only indices 0..99 pass.
                T.where(i * 2097152 + j * 16384 + k * 128 + l < 100)
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                vk = T.axis.spatial(128, k)
                vl = T.axis.spatial(128, l)
                B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
def elementwise_predicate_inlined(a: T.handle, c: T.handle) -> None:
    # C = A * 2.0 + 1.0 over 128x128 buffers; the block predicate reads A and
    # re-computes A * 2.0 itself.
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    for i in T.serial(0, 128):
        for j in T.serial(0, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                # Predicate uses the loop variables; the store uses the block
                # iterators.
                T.where(A[i, j] * 2.0 < 10.0)
                C[vi, vj] = A[vi, vj] * 2.0 + 1.0
def narrow_shape(A: T.Buffer[(10,), "float32"], B: T.Buffer[(10,), "float32"]) -> None:
    # Fills a 10-element scratch buffer from B through a 3x4 loop nest, then
    # writes A[i] = B_cache[i] + 1.
    B_cache = T.alloc_buffer(10, "float32")
    # 3 * 4 = 12 iterations address 10 slots; the predicate drops the last 2.
    for j, k in T.grid(3, 4):
        with T.block("B_cache"):
            T.where(j * 4 + k < 10)
            B_cache[j * 4 + k] = B[j]
    for i in T.serial(0, 10):
        A[i] = B_cache[i] + T.float32(1)
def element_wise_split_predicate(a: T.handle, b: T.handle) -> None:
    # Element-wise B = A * 2.0 over 128x128 buffers, with the column axis
    # expressed as a 13 x 10 loop pair.
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for i, j_0, j_1 in T.grid(128, 13, 10):
        with T.block([128, 128], "B") as [vi, vj]:
            # 13 * 10 = 130 iterations cover 128 columns; the predicate masks
            # the 2 iterations past the end.
            T.where(j_0 * 10 + j_1 < 128)
            T.bind(vi, i)
            T.bind(vj, j_0 * 10 + j_1)
            B[vi, vj] = A[vi, vj] * 2.0
def element_wise_split_predicate(a: T.handle, b: T.handle) -> None:
    # Element-wise B = A * 2.0 over 128x128 buffers, with the column axis
    # expressed as a 13 x 10 loop pair.
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for i in T.serial(0, 128):
        for j_0, j_1 in T.grid(13, 10):
            with T.block("B"):
                # 13 * 10 = 130 > 128, hence the guard on the fused index.
                T.where(j_0 * 10 + j_1 < 128)
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j_0 * 10 + j_1)
                B[vi, vj] = A[vi, vj] * 2.0
def elementwise_predicate(a: T.handle, c: T.handle) -> None:
    # Two-stage element-wise pipeline over 128x128 buffers:
    # B = A * 2.0, then C = B + 1.0 where the predicate B < 10.0 holds.
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    # Block "B" is written without surrounding loops; presumably the parser
    # generates the 128x128 loop nest for it -- TODO confirm for the
    # TVMScript version in use.
    with T.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block([128, 128], "C") as [vi, vj]:
            # The predicate reads the intermediate buffer B through the loop
            # variables rather than the block iterators.
            T.where(B[i, j] < 10.0)
            C[vi, vj] = B[vi, vj] + 1.0
def compacted_predicate_func(a: T.handle, c: T.handle) -> None:
    # C[x] = A[x] + 1.0 over 32-element buffers, with the index x = i * 7 + j
    # produced by a 5 x 7 loop nest.
    A = T.match_buffer(a, (32), "float32")
    C = T.match_buffer(c, (32), "float32")
    for i, j in T.grid(5, 7):
        # The block declares no iterators; buffers are indexed with the loop
        # variables directly and the access regions are spelled out.
        with T.block() as []:
            T.reads(A[i * 7 + j])
            T.writes(C[i * 7 + j])
            # 5 * 7 = 35 iterations address 32 elements; the predicate masks
            # the 3 iterations past the end.
            T.where(i * 7 + j < 32)
            C[i * 7 + j] = A[i * 7 + j] + 1.0
def elementwise_reordered_with_predicate(a: T.handle, b: T.handle) -> None:
    # Element-wise B = A * 2.0 over 128x128x128x128 buffers. The loop nest is
    # ordered (l, j, k, i), while the block iterators stay bound as
    # vi<-i, vj<-j, vk<-k, vl<-l.
    A = T.match_buffer(a, (128, 128, 128, 128))
    B = T.match_buffer(b, (128, 128, 128, 128))
    for l, j, k, i in T.grid(128, 128, 128, 128):
        with T.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
            # 2097152 == 128**3 and 16384 == 128**2: the sum is the row-major
            # flattened index of (i, j, k, l); only indices 0..99 pass.
            T.where(i * 2097152 + j * 16384 + k * 128 + l < 100)
            T.bind(vi, i)
            T.bind(vj, j)
            T.bind(vk, k)
            T.bind(vl, l)
            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
def rowsum_predicate(a: T.handle, b: T.handle) -> None:
    # Row sum B[vi] = sum over vk of A[vi, vk] for a 128x128 input, with the
    # reduction axis expressed as a 13 x 10 loop pair.
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    for i, k_0, k_1 in T.grid(128, 13, 10):
        with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
            # 13 * 10 = 130 iterations cover the 128-long reduction axis; the
            # predicate masks the 2 iterations past the end.
            T.where(k_0 * 10 + k_1 < 128)
            T.bind(vi, i)
            T.bind(vk, k_0 * 10 + k_1)
            with T.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
def rowsum_predicate(a: T.handle, b: T.handle) -> None:
    # Row sum B[vi] = sum over vk of A[vi, vk] for a 128x128 input, with the
    # reduction axis expressed as a 13 x 10 loop pair.
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    for i in T.serial(0, 128):
        for k_0, k_1 in T.grid(13, 10):
            with T.block("B"):
                # 13 * 10 = 130 > 128, hence the guard on the fused index.
                T.where(k_0 * 10 + k_1 < 128)
                vi = T.axis.spatial(128, i)
                vk = T.axis.reduce(128, k_0 * 10 + k_1)
                with T.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
def element_wise_split_predicate_parallelized(a: T.handle, b: T.handle) -> None:
    # Element-wise B = A * 2.0 over 128x128 buffers; the column axis is a
    # 13 x 10 loop pair whose outer loop (j_0) is a T.parallel loop.
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for i in T.serial(0, 128):
        for j_0 in T.parallel(0, 13):
            for j_1 in T.serial(0, 10):
                with T.block("B"):
                    # 13 * 10 = 130 > 128, hence the guard on the fused index.
                    T.where(j_0 * 10 + j_1 < 128)
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j_0 * 10 + j_1)
                    B[vi, vj] = A[vi, vj] * 2.0
def elementwise_symbolic_split(a: T.handle, b: T.handle, n: T.int32) -> None:
    # Element-wise B = A * 2.0 over (128, 128, n) buffers. The symbolic last
    # axis is covered by 10 outer iterations of floordiv(n + 9, 10) inner
    # iterations each.
    A = T.match_buffer(a, (128, 128, n))
    B = T.match_buffer(b, (128, 128, n))
    for i, j in T.grid(128, 128):
        for k0 in T.serial(0, 10):
            for k1 in T.serial(0, T.floordiv(n + 9, 10)):
                with T.block("B"):
                    # The two k loops may overshoot n, so the fused index is
                    # guarded against it.
                    T.where(k0 * T.floordiv(n + 9, 10) + k1 < n)
                    vi = T.axis.spatial(128, i)
                    vj = T.axis.spatial(128, j)
                    vk = T.axis.spatial(n, k0 * T.floordiv(n + 9, 10) + k1)
                    T.reads([A[vi, vj, vk]])
                    T.writes([B[vi, vj, vk]])
                    B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def transformed_three_stage_compute(
    A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]
) -> None:
    # Per thread tx, chains three element-wise stages through two-slot
    # buffers: B = A * 2, C = B + 2, D = C + 1. The work is laid out as three
    # consecutive regions handling A[tx, 0:2], then A[tx, 2:16] / D[tx, 0:14],
    # then D[tx, 14:16].
    # NOTE(review): this reads like the prologue / steady state / epilogue of
    # a pipelined loop -- confirm against the pass that produces it.
    for tx in T.thread_binding(16, thread="threadIdx.x"):
        with T.block():
            T.reads(A[tx, 0:16])
            T.writes(D[tx, 0:16])
            # B and C hold two slots each along their first axis.
            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
            C = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
            # Region 1: fills B from A[tx, 0:2]; the C stage only runs for
            # i >= 1.
            with T.block():
                T.reads(A[tx, 0:2], B[0:2, tx, 0])
                T.writes(B[0:2, tx, 0], C[0:2, tx, 0])
                for i in T.unroll(2):
                    with T.block():
                        T.reads(A[tx, i])
                        T.writes(B[0:2, tx, 0])
                        B[i, tx, 0] = A[tx, i] * T.float32(2)
                    with T.block():
                        T.where(1 <= i)
                        T.reads(B[0:2, tx, 0])
                        T.writes(C[0:2, tx, 0])
                        C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
            # Region 2: all three stages run in every iteration; the B stage
            # consumes A[tx, i + 2] while the D stage produces D[tx, i].
            with T.block():
                T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0])
                T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14])
                for i in T.serial(14):
                    with T.block():
                        T.reads(A[tx, i + 2])
                        T.writes(B[0:2, tx, 0])
                        B[i % 2, tx, 0] = A[tx, i + 2] * T.float32(2)
                    with T.block():
                        T.reads(B[0:2, tx, 0])
                        T.writes(C[0:2, tx, 0])
                        C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
                    with T.block():
                        T.reads(C[0:2, tx, 0])
                        T.writes(D[tx, i])
                        D[tx, i] = C[i % 2, tx, 0] + T.float32(1)
            # Region 3: the C stage only runs for i < 1, and the remaining D
            # values (D[tx, 14:16]) are written.
            with T.block():
                T.reads(B[0:2, tx, 0], C[0:2, tx, 0])
                T.writes(C[0:2, tx, 0], D[tx, 14:16])
                for i in T.unroll(2):
                    with T.block():
                        T.where(i < 1)
                        T.reads(B[0:2, tx, 0])
                        T.writes(C[0:2, tx, 0])
                        C[(i + 1) % 2, tx, 0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
                    with T.block():
                        T.reads(C[0:2, tx, 0])
                        T.writes(D[tx, i + 14])
                        D[tx, i + 14] = C[i, tx, 0] + T.float32(1)
def elementwise_split_with_predicate(a: T.handle, b: T.handle) -> None:
    # Element-wise B = A * 2.0 over 128x128x128 buffers, where each axis is
    # covered by a loop group that overshoots 128 and is guarded by the block
    # predicate.
    B = T.match_buffer(b, [128, 128, 128])
    A = T.match_buffer(a, [128, 128, 128])
    # Loop groups per axis: i -> 1000 * 2 * 3, j -> 1 * 129, k -> 3 * 43.
    for i0, i1, i2 in T.grid(1000, 2, 3):
        for j0, j1 in T.grid(1, 129):
            for k0, k1 in T.grid(3, 43):
                with T.block("B"):
                    vi = T.axis.spatial(128, i0 * 6 + i1 * 3 + i2)
                    vj = T.axis.spatial(128, j0 * 129 + j1)
                    vk = T.axis.spatial(128, k0 * 43 + k1)
                    # (i0 * 2 + i1) * 3 + i2 equals the vi binding above.
                    T.where((i0 * 2 + i1) * 3 + i2 < 128 and j0 * 129 + j1 < 128 and k0 * 43 + k1 < 128)
                    T.reads([A[vi, vj, vk]])
                    T.writes([B[vi, vj, vk]])
                    B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def compacted_narrow_shape(A: T.Buffer[(10,), "float32"], B: T.Buffer[(10,), "float32"]) -> None:
    # Fills a 10-element scratch buffer from B through a 3x4 loop nest with
    # explicit read/write regions, then writes A[i] = B_cache[i] + 1.
    B_cache = T.alloc_buffer([10], dtype="float32")
    for j in T.serial(0, 3):
        for k in T.serial(0, 4):
            with T.block("B_cache"):
                # 3 * 4 = 12 iterations address 10 slots; the last 2 are
                # masked.
                T.where(j * 4 + k < 10)
                T.reads(B[j])
                T.writes(B_cache[j * 4 + k])
                B_cache[j * 4 + k] = B[j]
    for i in T.serial(0, 10):
        A[i] = B_cache[i] + T.float32(1)
def elementwise_predicate(a: T.handle, c: T.handle) -> None:
    # Two-stage element-wise pipeline over 128x128 buffers:
    # B = A * 2.0, then C = B + 1.0 where the predicate B < 10.0 holds.
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    for i in T.serial(0, 128):
        for j in T.serial(0, 128):
            with T.block("B"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                B[vi, vj] = A[vi, vj] * 2.0
    for i in T.serial(0, 128):
        for j in T.serial(0, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, i)
                vj = T.axis.spatial(128, j)
                # The predicate reads the intermediate buffer through the loop
                # variables rather than the block iterators.
                T.where(B[i, j] < 10.0)
                C[vi, vj] = B[vi, vj] + 1.0
def func_with_block_predicate() -> None:
    # Producer/consumer pair over 120-element buffers: A is zero-filled, then
    # B = A + 1.0. Both blocks iterate a 16 x 8 loop nest with a predicate.
    A = T.alloc_buffer((120))
    B = T.alloc_buffer((120))
    # 16 * 8 = 128 iterations address 120 elements; the predicate masks the
    # last 8 in each stage.
    for i in T.serial(0, 16):
        for j in T.serial(0, 8):
            with T.block("producer"):
                T.where(i * 8 + j < 120)
                ax = T.axis.spatial(120, i * 8 + j)
                A[ax] = 0.0
    for i in T.serial(0, 16):
        for j in T.serial(0, 8):
            with T.block("consumer"):
                T.where(i * 8 + j < 120)
                ax = T.axis.spatial(120, i * 8 + j)
                B[ax] = A[ax] + 1.0
def with_block_predicate(a: T.handle, b: T.handle) -> None:
    # Row sum B[vi] = sum over vk of A[vi, vk] for a 128x120 input. The
    # reduction axis is split into a serial loop of 4 and a loop of 32 bound
    # to threadIdx.x.
    A = T.match_buffer(a, [128, 120], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    for i in T.serial(0, 128):
        for ko in T.serial(0, 4):
            for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("B"):
                    vi = T.axis.S(128, i)
                    vk = T.axis.R(120, ko * 32 + ki)
                    # 4 * 32 = 128 iterations cover 120 columns; the last 8
                    # are masked.
                    T.where(ko * 32 + ki < 120)
                    T.reads([B[vi], A[vi, vk]])
                    T.writes([B[vi]])
                    with T.init():
                        B[vi] = T.float32(0)
                    B[vi] = B[vi] + A[vi, vk]
def compacted_spatial_tiled_pad_and_pooling(
    X: T.Buffer[(64, 112, 112), "int32"], Y: T.Buffer[(64, 56, 56), "int32"]
) -> None:
    # Tiled 3x3 / stride-2 max over X with a one-element border treated as 0.
    # Each (h_o, w_o) tile stages a 9x9x64 window of X into X_cache and then
    # reduces it into a 4x4x64 patch of Y.
    # NOTE(review): Y is declared with shape (64, 56, 56) but is indexed as
    # (h, w, c) below, with c ranging over 64 on an axis of extent 56. This
    # looks inconsistent with the declared layout -- confirm whether it is
    # intentional before relying on it.
    for h_o, w_o in T.grid(14, 14):
        with T.block():
            T.reads(X[0:64, h_o * 8 - 1 : h_o * 8 + 8, w_o * 8 - 1 : w_o * 8 + 8])
            T.writes(Y[h_o * 4 : h_o * 4 + 4, w_o * 4 : w_o * 4 + 4, 0:64])
            X_cache = T.alloc_buffer([9, 9, 64], dtype="int32")
            for ax0, ax1, ax2 in T.grid(64, 9, 9):
                with T.block("cache"):
                    # The window starts at offset -1; the predicate skips the
                    # row/column whose source index would be negative.
                    T.where(1 <= h_o * 8 + ax1 and 1 <= w_o * 8 + ax2)
                    T.reads(X[ax0, h_o * 8 + ax1 - 1, w_o * 8 + ax2 - 1])
                    # X_cache indices are rebased by the first source
                    # row/column of the tile, T.max(0, h_o * 8 - 1) and
                    # T.max(0, w_o * 8 - 1).
                    T.writes(
                        X_cache[
                            h_o * 8 + ax1 - T.max(0, h_o * 8 - 1) - 1,
                            w_o * 8 + ax2 - T.max(0, w_o * 8 - 1) - 1,
                            ax0,
                        ]
                    )
                    X_cache[
                        h_o * 8 + ax1 - T.max(0, h_o * 8 - 1) - 1,
                        w_o * 8 + ax2 - T.max(0, w_o * 8 - 1) - 1,
                        ax0,
                    ] = X[ax0, h_o * 8 + ax1 - 1, w_o * 8 + ax2 - 1]
            for h_i, w_i, kh, kw, c in T.grid(4, 4, 3, 3, 64):
                with T.block("compute"):
                    T.reads(
                        X_cache[
                            h_o * 8 + h_i * 2 + kh - T.max(0, h_o * 8 - 1) - 1,
                            w_o * 8 + w_i * 2 + kw - T.max(0, w_o * 8 - 1) - 1,
                            c,
                        ]
                    )
                    T.writes(Y[h_o * 4 + h_i, w_o * 4 + w_i, c])
                    # The first window element resets the accumulator to 0.
                    if kh == 0 and kw == 0:
                        Y[h_o * 4 + h_i, w_o * 4 + w_i, c] = 0
                    # Window positions that map to a negative source index
                    # contribute 0 instead of a cached value.
                    Y[h_o * 4 + h_i, w_o * 4 + w_i, c] = T.max(
                        Y[h_o * 4 + h_i, w_o * 4 + w_i, c],
                        T.if_then_else(
                            T.likely(1 <= h_o * 8 + h_i * 2 + kh, dtype="bool")
                            and T.likely(1 <= w_o * 8 + w_i * 2 + kw, dtype="bool"),
                            X_cache[
                                h_o * 8 + h_i * 2 + kh - T.max(0, h_o * 8 - 1) - 1,
                                w_o * 8 + w_i * 2 + kw - T.max(0, w_o * 8 - 1) - 1,
                                c,
                            ],
                            0,
                            dtype="int32",
                        ),
                    )
def main(
    T_reshape: T.Buffer[(1, 12, 384, 384), "float32"],
    placeholder_1: T.Buffer[(T.int64(1), T.int64(12), T.int64(384), 384), "bool"],
    T_where: T.Buffer[(T.int64(1), T.int64(12), T.int64(384), 384), "float32"]
) -> None:
    # Element-wise select over a (1, 12, 384, 384) tensor: where the boolean
    # mask is non-zero the output is -1000000000, otherwise it is copied from
    # T_reshape. The four axes are recovered from one fused int64 index.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
        for i0_i1_i2_i3_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
            for i0_i1_i2_i3_fused_0 in T.serial(T.int64(7)):
                with T.block("T_where"):
                    # Fused index = (fused_0 * 256 + fused_1) * 1024 + fused_2.
                    # 1769472 == 12 * 384 * 384 and 147456 == 384 * 384, so
                    # the mod/div pairs below peel off ax1, ax2 and ax3.
                    ax0 = T.axis.spatial(T.int64(1), T.int64(0))
                    ax1 = T.axis.spatial(
                        T.int64(12),
                        ((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(1769472) // T.int64(147456))
                    ax2 = T.axis.spatial(
                        T.int64(384),
                        ((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(147456) // T.int64(384))
                    # The last axis has an int32 extent (384), so its binding
                    # is cast from int64 to int32.
                    ax3 = T.axis.spatial(
                        384,
                        T.cast(((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(384), "int32"))
                    # 7 * 256 * 1024 = 1835008 iterations exceed the 1769472
                    # elements, so the tail is masked.
                    T.where((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2 < T.int64(1769472))
                    T.reads(placeholder_1[ax0, ax1, ax2, ax3], T_reshape[ax0, ax1, ax2, ax3])
                    T.writes(T_where[ax0, ax1, ax2, ax3])
                    T_where[ax0, ax1, ax2, ax3] = T.Select(
                        T.cast(placeholder_1[ax0, ax1, ax2, ax3], "int32") != 0,
                        T.float32(-1000000000),
                        T_reshape[ax0, ax1, ax2, ax3])
def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None:
    # Row sum of a 128x120 input into B[128]. Each threadIdx.x lane
    # accumulates a partial sum in normal_reduce_temp0, the partial sums are
    # combined through T.tvm_thread_allreduce into reduce_temp0, and the
    # result is written back to B.
    A = T.match_buffer(a, [128, 120], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    # Single-element scratch buffers in "local" scope.
    reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local")
    for i in T.serial(0, 128):
        for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
            # Reset the per-lane accumulator.
            with T.block("B_in_thread_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.float32(0)
            for ko in T.serial(0, 4):
                with T.block("B_normal_reduction"):
                    vi = T.axis.spatial(128, i)
                    vk = T.axis.reduce(120, ko * 32 + ki)
                    # 4 * 32 = 128 iterations cover 120 columns; the last 8
                    # are masked.
                    T.where(ko * 32 + ki < 120)
                    T.reads([A[vi, vk], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk]
            with T.block("B_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                # Declares an additive reducer with identity 0 for the
                # all-reduce that follows.
                T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                # NOTE(review): presumably combines normal_reduce_temp0 across
                # the ki (threadIdx.x) lanes into reduce_temp0; the intrinsic's
                # argument contract is not visible here -- confirm against the
                # TVM documentation.
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        ki,
                        dtype="handle",
                    ))
            with T.block("B_write_back"):
                vi = T.axis.spatial(128, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
def rowsum_predicate_rfactor(a: T.handle, b: T.handle) -> None:
    # Two-step row sum of a 128x128 input: block "B_rf" accumulates 13 partial
    # sums per row (one per k_0) into B_rf, and block "B" sums those partials
    # into B.
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    B_rf = T.alloc_buffer([128, 13], dtype="float32")
    for i in T.serial(0, 128):
        for k_0, k_1 in T.grid(13, 10):
            with T.block("B_rf"):
                # k_0 is a spatial axis here; only k_1 is reduced.
                vk_0 = T.axis.spatial(13, k_0)
                vi = T.axis.spatial(128, i)
                vk_1 = T.axis.reduce(10, k_1)
                # 13 * 10 = 130 > 128, hence the guard on the fused index.
                T.where(k_0 * 10 + k_1 < 128)
                with T.init():
                    B_rf[vi, vk_0] = T.float32(0)
                B_rf[vi, vk_0] = B_rf[vi, vk_0] + A[vi, vk_0 * 10 + vk_1]
    for i in T.serial(0, 128):
        for k_0 in T.serial(0, 13):
            with T.block("B"):
                vk_0 = T.axis.reduce(13, k_0)
                vi = T.axis.spatial(128, i)
                with T.init():
                    B[vi] = T.float32(0)
                B[vi] = B[vi] + B_rf[vi, vk_0]
def main(a: T.handle, b: T.handle) -> None:
    # Copies A into B over 1024x1024x1024 buffers through a tiled 8-deep loop
    # nest; every axis is over-covered and guarded by the block predicate.
    T.func_attr({"global_symbol": "main"})
    A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32")
    B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32")
    with T.block("root"):
        T.block_attr({"meta_schedule.parallel": 128, "meta_schedule.vectorize": 32})
        # Loop order is i0, j0, i1, j1, k0, i2, j2, k1. Per axis the extents
        # multiply to 128*4*4 = 2048 (i), 64*4*8 = 2048 (j), 64*32 = 2048 (k).
        for i0, j0 in T.grid(128, 64):
            for i1, j1, k0 in T.grid(4, 4, 64):
                for i2, j2, k1 in T.grid(4, 8, 32):
                    with T.block("move"):
                        vi = T.axis.S(1024, i0 * 16 + i1 * 4 + i2)
                        vj = T.axis.S(1024, j0 * 32 + j1 * 8 + j2)
                        vk = T.axis.S(1024, k0 * 32 + k1)
                        # Each predicate term is the same fused index as the
                        # corresponding binding above.
                        T.where((i0 * 4 + i1) * 4 + i2 < 1024 and (j0 * 4 + j1) * 8 + j2 < 1024 and k0 * 32 + k1 < 1024)
                        T.reads([A[vi, vj, vk]])
                        T.writes([B[vi, vj, vk]])
                        B[vi, vj, vk] = A[vi, vj, vk]
def block_predicate_cache_write_output_buf() -> None:
    # Producer/consumer pair over 120-element buffers where the consumer
    # writes into a "shared"-scope staging buffer that is then copied to B.
    A = T.alloc_buffer([120], dtype="float32")
    B = T.alloc_buffer([120], dtype="float32")
    B_shared = T.alloc_buffer([120], dtype="float32", scope="shared")
    # 16 * 8 = 128 iterations address 120 elements; the predicate masks the
    # last 8 in the producer and consumer stages.
    for i in T.serial(0, 16):
        for j in T.serial(0, 8):
            with T.block("producer"):
                ax = T.axis.S(120, i * 8 + j)
                T.where(i * 8 + j < 120)
                A[ax] = T.float32(0)
    for i in T.serial(0, 16):
        for j in T.serial(0, 8):
            with T.block("consumer"):
                ax = T.axis.S(120, i * 8 + j)
                T.where(i * 8 + j < 120)
                B_shared[ax] = A[ax] + T.float32(1)
    # Copy-out stage: exactly 120 iterations, so no predicate is needed.
    for ax0 in T.serial(0, 120):
        with T.block("B_shared"):
            v0 = T.axis.S(120, ax0)
            B[v0] = B_shared[v0]
def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1, ), "float32"]) -> None:
    # Computes D[0] = sqrt(sum of A[0, i, j] ** 2) over a (1, 256, 256) input.
    # Both stages sit under blockIdx.x / threadIdx.x bindings, and only the
    # lane with i0_fused_1 == 0 passes the block predicate.
    C = T.alloc_buffer([1], dtype="float32")
    for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
        for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
            for i1 in T.serial(0, 256):
                for i2 in T.serial(0, 256):
                    with T.block("C"):
                        b = T.axis.spatial(1, 0)
                        i = T.axis.reduce(256, i1)
                        j = T.axis.reduce(256, i2)
                        # 32 lanes are launched for a single output element.
                        T.where(i0_fused_1 < 1)
                        with T.init():
                            C[b] = T.float32(0)
                        C[b] = C[b] + A[b, i, j] * A[b, i, j]
    for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
        for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
            with T.block("D"):
                b = T.axis.spatial(1, 0)
                T.where(i0_fused_1 < 1)
                D[b] = T.sqrt(C[b], dtype="float32")