def threadpool_nested_parallel_loop(
    A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]
) -> None:
    # Writes B[i, j] = A[i, j] * 2.0 for every element of the 4x4 buffers.
    # Both loop levels are declared with T.parallel, i.e. a parallel loop is
    # nested directly inside another parallel loop.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    for i in T.parallel(4):
        for j in T.parallel(4):
            # Multi-dimensional buffer indexing form of the element access.
            B[i, j] = A[i, j] * 2.0
def threadpool_nested_parallel_loop(
    A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]
) -> None:
    # Same computation as B[i, j] = A[i, j] * 2.0 on 4x4 buffers, but spelled
    # with explicit T.load / T.store on the buffers' .data and a flat offset.
    # Both loop levels are declared with T.parallel.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    for i in T.parallel(4):
        for j in T.parallel(4):
            # Element (i, j) is addressed at flat offset i * 4 + j on both sides.
            T.store(B.data, i * 4 + j, T.load("float32", A.data, i * 4 + j) * 2.0)
def scatter_compute_parallelize(
    A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"]
) -> None:
    # Copies A into B with the two 8-element halves swapped, using two
    # separate T.parallel loops of extent 8 (one block per half of B).
    # body
    # with T.block("root")
    for i in T.parallel(8):
        with T.block("first_half"):
            # vi covers 8..15, so this block writes B[8..15] from A[0..7].
            vi = T.axis.spatial(16, 8 + i)
            T.reads(A[vi - 8])
            T.writes(B[vi])
            B[vi] = A[vi - 8]
    for i in T.parallel(8):
        with T.block("last_half"):
            # vi covers 0..7, so this block writes B[0..7] from A[8..15].
            vi = T.axis.spatial(16, i)
            T.reads(A[vi + 8])
            T.writes(B[vi])
            B[vi] = A[vi + 8]
def element_wise_parallelized(a: T.handle, b: T.handle) -> None:
    # Element-wise B[vi, vj] = A[vi, vj] * 2.0 over 128x128 buffers.
    # The outer loop is T.parallel, the inner loop is T.serial.
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i0 in T.parallel(0, 128):
        for i1 in T.serial(0, 128):
            with T.block("B"):
                # Block axes vi, vj are bound to the loop variables i0, i1.
                vi, vj = T.axis.remap("SS", [i0, i1])
                B[vi, vj] = A[vi, vj] * 2.0
def loop_syntax_sugar(a: T.handle) -> None:
    # Six nested loops of extent 128, one per loop constructor spelling:
    # serial, parallel, vectorized, unroll, and thread_binding written twice
    # (thread tag passed positionally, then via the thread= keyword).
    A = T.match_buffer(a, (128, 128, 128, 128))
    for i in T.serial(128):
        for j in T.parallel(128):
            for k in T.vectorized(128):
                for x in T.unroll(128):
                    for y in T.thread_binding(128, "threadIdx.x"):
                        for z in T.thread_binding(128, thread="threadIdx.x"):
                            # The statement indexes only with i, j, k, x;
                            # y and z are not referenced in the body.
                            A[i, j, k, x] = A[i, j, k, x] * 2.0
def elemwise_sum_parallel(a: T.handle, b: T.handle, c: T.handle, n: T.int32):
    # Element-wise sum C[vi] = A[vi] + B[vi] over three 1-D float32 buffers
    # whose length is the function parameter n; the single loop is T.parallel.
    T.func_attr({"global_symbol": "elemwise_sum_parallel", "tir.noalias": True})
    A = T.match_buffer(a, (n,), dtype="float32")
    B = T.match_buffer(b, (n,), dtype="float32")
    C = T.match_buffer(c, (n,), dtype="float32")
    for i in T.parallel(n):
        with T.block("C"):
            # Spatial block axis of extent n bound to the loop variable i.
            vi = T.axis.spatial(n, i)
            C[vi] = A[vi] + B[vi]
def element_wise_parallelized(a: T.handle, b: T.handle) -> None:
    # Element-wise B[vi, vj] = A[vi, vj] * 2.0 over 128x128 buffers.
    # The outer loop is T.parallel, the inner loop is T.serial.
    # Block axes are declared in the `T.block([extents], name) as [...]` form
    # and tied to the loop variables with explicit T.bind calls.
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i0 in T.parallel(0, 128):
        for i1 in T.serial(0, 128):
            with T.block([128, 128], "B") as [vi, vj]:
                T.bind(vi, i0)
                T.bind(vj, i1)
                B[vi, vj] = A[vi, vj] * 2.0
def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(
    placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle
) -> None:
    # Two-stage kernel over flat (1-D) buffers:
    #   1. copy placeholder_33 into the locally allocated PaddedInput_3;
    #   2. for each of 784 x 16 outputs, accumulate 192 int32 products, add
    #      placeholder_35[ax3_2], apply T.q_multiply_shift, clamp to [0, 255],
    #      cast to uint8 and then to int16.
    # Both top-level loops are T.parallel.
    # function attr dict
    T.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True})
    placeholder_33 = T.match_buffer(placeholder_30, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_34 = T.match_buffer(placeholder_31, [3072], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_35 = T.match_buffer(placeholder_32, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    T_cast_9 = T.match_buffer(T_cast_8, [12544], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # body
    PaddedInput_3 = T.allocate([150528], "int16", "global")
    for i0_i1_fused_3 in T.parallel(0, 28):
        for i2_3, i3_3 in T.grid(28, 192):
            # Same flat offset on both sides: a plain element-for-element copy
            # (5376 == 28 * 192 is the stride of the fused outer index).
            PaddedInput_3[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3) ] = placeholder_33[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)]
    for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784):
        for ax3_2 in T.serial(0, 16):
            # One-element int32 accumulator, allocated and zeroed per output.
            Conv2dOutput_3 = T.allocate([1], "int32", "global")
            Conv2dOutput_3[0] = 0
            for rc_3 in T.serial(0, 192):
                # Both int16 operands are widened to int32 before multiplying.
                Conv2dOutput_3[0] = (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_34[((rc_3*16) + ax3_2)], "int32")))
            # T.min(..., 255) then T.max(..., 0) clamps the value to [0, 255]
            # before the uint8 -> int16 cast chain.
            T_cast_9[((ax0_ax1_fused_ax2_fused_3*16) + ax3_2)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_35[ax3_2]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16")
def element_wise_split_predicate_parallelized(a: T.handle, b: T.handle) -> None:
    # Element-wise B[vi, vj] = A[vi, vj] * 2.0 over 128x128 buffers, with the
    # column index expressed as j_0 * 10 + j_1 (j_0 parallel with extent 13,
    # j_1 serial with extent 10).
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for i in T.serial(0, 128):
        for j_0 in T.parallel(0, 13):
            for j_1 in T.serial(0, 10):
                with T.block("B"):
                    # 13 * 10 == 130 > 128, so j_0 * 10 + j_1 can reach 129;
                    # the predicate restricts the block to indices below 128.
                    T.where(j_0 * 10 + j_1 < 128)
                    vi = T.axis.S(128, i)
                    vj = T.axis.S(128, j_0 * 10 + j_1)
                    B[vi, vj] = A[vi, vj] * 2.0
def rowsum_not_serial(a: T.handle, b: T.handle) -> None:
    # Row sum: B[vi] accumulates A[vi, vk] over vk for a 128x128 input.
    # The loop k that drives the second block axis is T.parallel rather than
    # T.serial, as the function name indicates.
    # NOTE(review): presumably a negative fixture for schedule validation -
    # confirm against the test that consumes it.
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128,))
    for i in T.serial(0, 128):
        for k in T.parallel(0, 128):
            with T.block("B"):
                # "SR": vi and vk are bound to i and k respectively.
                vi, vk = T.axis.remap("SR", [i, k])
                with T.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
def rowsum_not_serial(a: T.handle, b: T.handle) -> None:
    # Row sum: B[vi] accumulates A[vi, vk] over vk for a 128x128 input.
    # The second block axis is declared with T.reduce_axis(0, 128) and is
    # bound to loop k, which is T.parallel rather than T.serial (hence the
    # function name).
    # NOTE(review): presumably a negative fixture for schedule validation -
    # confirm against the test that consumes it.
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, ))
    for i in T.serial(0, 128):
        for k in T.parallel(0, 128):
            with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
                T.bind(vi, i)
                T.bind(vk, k)
                with T.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
def main(
    placeholder: T.Buffer[(1, 16, 7, 7, 32), "float32"],
    placeholder_1: T.Buffer[(25088,), "float32"],
    T_layout_trans: T.Buffer[(1, 1, 7, 7, 512), "float32"]
) -> None:
    # Fills T_layout_trans from a single T.parallel loop of extent 25088
    # (== 7 * 7 * 512). Each iteration gathers one element of `placeholder`
    # through a flattened index, and passes it through a T.Select that also
    # involves placeholder_1 at the same flattened index.
    # function attr dict
    T.func_attr({"tir.noalias": True, "global_symbol": "main"})
    # body
    # with T.block("root")
    for i0_i1_i2_i3_i4_fused in T.parallel(25088, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}):
        with T.block("T_layout_trans_1"):
            # ax0 and ax1 have extent 1 and are fixed at 0; the fused loop
            # variable is decomposed into ax2, ax3, ax4 (3584 == 7 * 512).
            ax0 = T.axis.spatial(1, 0)
            ax1 = T.axis.spatial(1, 0)
            ax2 = T.axis.spatial(7, i0_i1_i2_i3_i4_fused // 3584)
            ax3 = T.axis.spatial(7, i0_i1_i2_i3_i4_fused % 3584 // 512)
            ax4 = T.axis.spatial(512, i0_i1_i2_i3_i4_fused % 512)
            T.reads(
                placeholder[0, (ax4 * 49 + ax2 * 7 + ax3) % 25088 // 1568, (ax2 * 7 + ax3) % 49 // 7, ax3 % 7, (ax4 * 49 + ax2 * 7 + ax3) % 1568 // 49],
                placeholder_1[(ax4 * 49 + ax2 * 7 + ax3) % 25088]
            )
            T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
            # Shape of the right-hand side, with x standing for the guarded
            # gather T.if_then_else(<bounds>, placeholder[...], 0):
            #   T.if_then_else(<outer bounds>,
            #                  T.Select(0 < x, x, x * placeholder_1[...]),
            #                  0)
            # The three occurrences of x below are textually identical.
            T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(
                ax0 < 1 and ax1 * 512 + ax4 < 512 and ax2 < 7 and ax3 < 7,
                T.Select(
                    # First Select argument: the comparison 0 < x.
                    T.float32(0) < T.if_then_else(
                        0 < 1
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7,
                        placeholder[
                            0,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32
                        ],
                        T.float32(0),
                        dtype="float32"
                    ),
                    # Second Select argument: x itself.
                    T.if_then_else(
                        0 < 1
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7,
                        placeholder[
                            0,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32
                        ],
                        T.float32(0),
                        dtype="float32"
                    ),
                    # Third Select argument: x multiplied by placeholder_1 at
                    # the flattened index.
                    T.if_then_else(
                        0 < 1
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7,
                        placeholder[
                            0,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32
                        ],
                        T.float32(0),
                        dtype="float32"
                    ) * placeholder_1[((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088]
                ),
                T.float32(0),
                dtype="float32"
            )
def Move_PUV0(a: T.handle, b: T.handle) -> None:
    # Element-wise copy B[vi, vj, vk] = A[vi, vj, vk] over 1024^3 float32
    # buffers. The loop nest is a fused T.parallel outer loop, a T.grid of
    # five serial loops, and a T.vectorized innermost loop of extent 32.
    # function attr dict
    T.func_attr({"global_symbol": "main"})
    A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32")
    B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32")
    # body
    with T.block("root"):
        for i0_j0_fused in T.parallel(0, 8192):
            for i1, j1, k0, i2, j2 in T.grid(4, 4, 64, 4, 8):
                for k1_fused in T.vectorized(0, 32):
                    with T.block("move"):
                        # i0_j0_fused // 64 supplies the i part and
                        # i0_j0_fused % 64 the j part of the fused index.
                        vi = T.axis.spatial(
                            1024, i0_j0_fused // 64 * 16 + i1 * 4 + i2)
                        vj = T.axis.spatial(
                            1024, i0_j0_fused % 64 * 32 + j1 * 8 + j2)
                        vk = T.axis.spatial(1024, k0 * 32 + k1_fused)
                        # With these loop extents each of the three index
                        # expressions can reach 2047, so the predicate is what
                        # keeps the block inside the 1024-wide axes.
                        T.where(i0_j0_fused // 64 * 16 + i1 * 4 + i2 < 1024
                                and i0_j0_fused % 64 * 32 + j1 * 8 + j2 < 1024
                                and k0 * 32 + k1_fused < 1024)
                        T.reads([A[vi, vj, vk]])
                        T.writes([B[vi, vj, vk]])
                        B[vi, vj, vk] = A[vi, vj, vk]
def loops() -> None:
    # Minimal nest of three loop kinds (parallel, serial, vectorized) around a
    # no-op T.evaluate(0); it reads and writes no buffers.
    for i in T.parallel(0, 2):
        for j in T.serial(0, 1):
            # Two-argument form with a non-zero first argument (3, 4).
            for z in T.vectorized(3, 4):
                T.evaluate(0)
def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(
    placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle
) -> None:
    # Two-stage kernel written with explicit T.load / T.store and flat offsets
    # (buffers are declared with 4-D shapes):
    #   1. copy placeholder_33 into the locally allocated PaddedInput_3;
    #   2. for each of 784 x 16 outputs, accumulate 192 int32 products, add
    #      the placeholder_35 element at ax3_2, apply T.q_multiply_shift,
    #      clamp to [0, 255], cast to uint8 and then to int16.
    # Both top-level loops are T.parallel.
    # NOTE(review): the trailing `True` passed to every T.store is presumably
    # the store predicate - confirm against the T.store signature in use.
    # function attr dict
    T.func_attr({
        "global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2",
        "tir.noalias": True
    })
    placeholder_33 = T.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_34 = T.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_35 = T.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    T_cast_9 = T.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # body
    PaddedInput_3 = T.allocate([1, 28, 28, 192], "int16", "global")
    for i0_i1_fused_3 in T.parallel(0, 28):
        for i2_3, i3_3 in T.grid(28, 192):
            # Load and store use the same flat offset: a plain copy
            # (5376 == 28 * 192 is the stride of the fused outer index).
            T.store(
                PaddedInput_3,
                (((i0_i1_fused_3 * 5376) + (i2_3 * 192)) + i3_3),
                T.load("int16", placeholder_33.data, (((i0_i1_fused_3 * 5376) + (i2_3 * 192)) + i3_3)),
                True)
    for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784):
        for ax3_2 in T.serial(0, 16):
            # int32 accumulator, allocated and zeroed per output element.
            Conv2dOutput_3 = T.allocate([1, 1, 1, 1], "int32", "global")
            T.store(Conv2dOutput_3, 0, 0, True)
            for rc_3 in T.serial(0, 192):
                # Both int16 operands are widened to int32 before multiplying.
                T.store(Conv2dOutput_3, 0, (T.load("int32", Conv2dOutput_3, 0) + (T.cast(
                    T.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3 * 192) + rc_3)), "int32") * T.cast(
                    T.load("int16", placeholder_34.data, ((rc_3 * 16) + ax3_2)), "int32"))), True)
            # T.min(..., 255) then T.max(..., 0) clamps the value to [0, 255]
            # before the uint8 -> int16 cast chain.
            T.store(
                T_cast_9.data,
                ((ax0_ax1_fused_ax2_fused_3 * 16) + ax3_2),
                T.cast(
                    T.cast(
                        T.max(
                            T.min(
                                T.q_multiply_shift(
                                    (T.load("int32", Conv2dOutput_3, 0) + T.load("int32", placeholder_35.data, ax3_2)),
                                    1764006585, 31, -7, dtype="int32"),
                                255),
                            0),
                        "uint8"),
                    "int16"),
                True)