def main(placeholder: T.Buffer[(1, 384), "int64"], placeholder_1: T.Buffer[(30522, 768), "float32"], placeholder_2: T.Buffer[(1, 384, 768), "float32"], T_add: T.Buffer[(1, 384, 768), "float32"]) -> None:
    # Single fused block computing, for every (s, d) in 384 x 768:
    #   row = clamp(placeholder[0, s] < 0 ? placeholder[0, s] + 30522 : placeholder[0, s], 0, 30521)
    #   T_add[0, s, d] = placeholder_1[row, d] + placeholder_2[0, s, d]
    # i.e. a row gather from the (30522, 768) table placeholder_1, with negative
    # indices wrapped once and then clamped into range, followed by an element-wise add.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0, i1, i2 in T.grid(1, 384, 768):
        with T.block("T_add_1"):
            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
            # The placeholder_1 row range below is a conservative interval: its lower
            # bound comes from the unwrapped index and its upper bound from the wrapped
            # (+30522) index, so it contains the row read by either Select branch.
            T.reads(
                placeholder[ax0, ax1],
                placeholder_1[
                    T.min(T.max(T.int64(0), placeholder[
                        ax0, ax1]), T.int64(30521)):T.min(
                        T.max(T.int64(0), placeholder[ax0, ax1] + T.int64(30522)), T.int64(30521)) + T.int64(1),
                    ax2],
                placeholder_2[ax0, ax1, ax2])
            T.writes(T_add[ax0, ax1, ax2])
            # The int64 comparison is cast to int32 and tested against 0 to pick the branch.
            T_add[ax0, ax1, ax2] = placeholder_1[T.min(
                T.max(
                    T.int64(0),
                    T.Select(
                        T.cast(placeholder[ax0, ax1] < T.int64(0), "int32"
                               ) != 0,
                        placeholder[ax0, ax1] + T.int64(30522),
                        placeholder[ax0, ax1])
                ), T.int64(30521)), ax2] + placeholder_2[ax0, ax1, ax2]
def tir_argmax_val_idx(
    var_val: T.handle, var_idx: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle
) -> None:
    # Row-wise argmax reduction over an (m, n) value buffer. For each row i it keeps
    # the (value, index) pair with the largest val[i, k]. The stored index is taken
    # from idx[i, k], not from the loop position k itself.
    # NOTE(review): global_symbol is "main" although the function is named
    # tir_argmax_val_idx; confirm this is intended.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # Symbolic row/column extents, bound through the match_buffer shapes below.
    m = T.var("int32")
    n = T.var("int32")
    val = T.match_buffer(var_val, [m, n], dtype="float32")
    idx = T.match_buffer(var_idx, [m, n], dtype="int32")
    argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="float32")
    argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="int32")
    for i0, i1 in T.grid(m, n):
        with T.block("argmax"):
            # i is spatial (one output per row); k is the reduction axis.
            i, k = T.axis.remap("SR", [i0, i1])
            T.reads(val[i, k], idx[i, k])
            T.writes(argmax_v0[i], argmax_v1[i])
            with T.init():
                # Reduction identity: lowest float32 value paired with index -1.
                argmax_v0[i] = T.min_value("float32")
                argmax_v1[i] = T.int32(-1)
            # Both Selects read the old accumulator argmax_v0[i] before either output
            # is written, hence the two temporaries. `>=` keeps the accumulated pair
            # on ties.
            v_argmax_v0: T.float32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v0[i], val[i, k])
            v_argmax_v1: T.int32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v1[i], idx[i, k])
            argmax_v0[i] = v_argmax_v0
            argmax_v1[i] = v_argmax_v1
def main(
    T_reshape: T.Buffer[(1, 12, 384, 384), "float32"],
    placeholder_1: T.Buffer[(T.int64(1), T.int64(12), T.int64(384), 384), "bool"],
    T_where: T.Buffer[(T.int64(1), T.int64(12), T.int64(384), 384), "float32"]
) -> None:
    # Element-wise masked fill over a 1 x 12 x 384 x 384 tensor (1769472 elements):
    #   T_where = -1e9 where placeholder_1 is true, otherwise T_reshape.
    # Looks like an attention-mask fill. NOTE(review): confirm against the source model.
    # placeholder_1 and T_where declare int64 extents for the first three dims but a
    # plain 384 for the last one; T_reshape uses plain ints throughout. ax3 is cast to
    # int32 to match. NOTE(review): confirm this dtype mix is intentional.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    # Flat element index:
    #   f = (i0_i1_i2_i3_fused_0 * 256 + i0_i1_i2_i3_fused_1) * 1024 + i0_i1_i2_i3_fused_2
    # The loops cover 7 * 256 * 1024 = 1835008 >= 1769472 slots; T.where masks the excess.
    for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
        for i0_i1_i2_i3_fused_2 in T.thread_binding(
                T.int64(1024), thread="threadIdx.x"):
            for i0_i1_i2_i3_fused_0 in T.serial(T.int64(7)):
                with T.block("T_where"):
                    # Decompose f into (ax0, ax1, ax2, ax3) using strides
                    # 147456 (= 384 * 384) and 384.
                    ax0 = T.axis.spatial(T.int64(1), T.int64(0))
                    ax1 = T.axis.spatial(
                        T.int64(12),
                        ((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(1769472) // T.int64(147456))
                    ax2 = T.axis.spatial(
                        T.int64(384),
                        ((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(147456) // T.int64(384))
                    ax3 = T.axis.spatial(
                        384,
                        T.cast(((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(384), "int32"))
                    T.where((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2 < T.int64(1769472))
                    T.reads(placeholder_1[ax0, ax1, ax2, ax3], T_reshape[ax0, ax1, ax2, ax3])
                    T.writes(T_where[ax0, ax1, ax2, ax3])
                    T_where[ax0, ax1, ax2, ax3] = T.Select(
                        T.cast(placeholder_1[ax0, ax1, ax2, ax3], "int32") != 0,
                        T.float32(-1000000000),
                        T_reshape[ax0, ax1, ax2, ax3])
def main(placeholder: T.Buffer[(1, 16, 7, 7, 32), "float32"], placeholder_1: T.Buffer[(25088,), "float32"], T_layout_trans: T.Buffer[(1, 1, 7, 7, 512), "float32"]) -> None:
    # Layout transform fused with an element-wise op.
    #
    # Input layout: placeholder (1, 16, 7, 7, 32) stores channel c at [0, c // 32, h, w, c % 32].
    # Output layout: T_layout_trans (1, 1, 7, 7, 512) stores channel ax4 at [0, 0, ax2, ax3, ax4].
    #
    # With x = placeholder[0, ax4 // 32, ax2, ax3, ax4 % 32] and f = ax4 * 49 + ax2 * 7 + ax3:
    #   out = x              if 0 < x
    #   out = x * placeholder_1[f]   otherwise
    # This looks like a PReLU / leaky-ReLU with a per-element slope.
    # NOTE(review): confirm against the source model.
    # function attr dict
    T.func_attr({"tir.noalias": True, "global_symbol": "main"})
    # body
    # with T.block("root")
    # 25088 = 7 * 7 * 512. The fused index enumerates (ax2, ax3, ax4) with ax4 innermost
    # (3584 = 7 * 512).
    for i0_i1_i2_i3_i4_fused in T.parallel(25088, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}):
        with T.block("T_layout_trans_1"):
            ax0 = T.axis.spatial(1, 0)
            ax1 = T.axis.spatial(1, 0)
            ax2 = T.axis.spatial(7, i0_i1_i2_i3_i4_fused // 3584)
            ax3 = T.axis.spatial(7, i0_i1_i2_i3_i4_fused % 3584 // 512)
            ax4 = T.axis.spatial(512, i0_i1_i2_i3_i4_fused % 512)
            # Simplified form of the body's indices: f // 1568 == ax4 // 32 and
            # f % 1568 // 49 == ax4 % 32.
            T.reads(
                placeholder[
                    0,
                    (ax4 * 49 + ax2 * 7 + ax3) % 25088 // 1568,
                    (ax2 * 7 + ax3) % 49 // 7,
                    ax3 % 7,
                    (ax4 * 49 + ax2 * 7 + ax3) % 1568 // 49],
                placeholder_1[(ax4 * 49 + ax2 * 7 + ax3) % 25088])
            T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
            # Every bounds check below is always true for the declared axis extents
            # (ax0 == ax1 == 0, ax2 < 7, ax3 < 7, ax4 < 512), so each if_then_else
            # yields its "then" value. The same guarded load of x appears three times:
            # as the Select condition, the positive branch, and the scaled branch.
            T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(
                ax0 < 1 and ax1 * 512 + ax4 < 512 and ax2 < 7 and ax3 < 7,
                T.Select(
                    T.float32(0) < T.if_then_else(
                        0 < 1
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7,
                        placeholder[
                            0,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32],
                        T.float32(0), dtype="float32"),
                    T.if_then_else(
                        0 < 1
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7,
                        placeholder[
                            0,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32],
                        T.float32(0), dtype="float32"),
                    T.if_then_else(
                        0 < 1
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7
                        and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7,
                        placeholder[
                            0,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7,
                            ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32],
                        T.float32(0), dtype="float32") * placeholder_1[((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088]),
                T.float32(0), dtype="float32")
def conv2d_winograd_cuda(  # type: ignore
    placeholder: T.Buffer[(1, 14, 14, 128), "float32"],  # type: ignore
    placeholder_1: T.Buffer[(6, 6, 128, 128), "float32"],  # type: ignore
    conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"],  # type: ignore
) -> None:  # type: ignore
    # Winograd convolution pipeline with 6x6 transform tiles, 4x4 output tiles and
    # 3 x 3 = 9 tiles covering the 12x12 output.
    #
    # Stages:
    #   data_pad        : 14x14 input zero-extended to 16x16 (indices 14..15 -> 0)
    #   input_tile      : 6x6 windows at stride 4, one per tile p
    #   B               : 6x6 constant matrix
    #   data_pack       : B^T * tile * B   (summed over r_a, r_b)
    #   bgemm           : per-(eps, nu) GEMM over ci with placeholder_1[eps, nu, co, ci]
    #   A               : 6x4 constant matrix
    #   inverse         : A^T * bgemm * A  (summed over r_a, r_b)
    #   conv2d_winograd : scatter of the 4x4 output tiles back to the 12x12 output
    #
    # NOTE(review): 14 -> 12 implies a 3x3 filter (F(4x4, 3x3)). placeholder_1 is
    # indexed by the transform coordinates (eps, nu), so it is presumably already in
    # the Winograd domain; confirm with the caller.
    # The "schedule_rule" / "meta_schedule.*" block attrs are scheduler annotations,
    # carried through unchanged.
    data_pad = T.alloc_buffer([1, 16, 16, 128])
    input_tile = T.alloc_buffer([6, 6, 9, 128])
    B = T.alloc_buffer([6, 6])
    data_pack = T.alloc_buffer([6, 6, 9, 128])
    bgemm = T.alloc_buffer([6, 6, 9, 128])
    A = T.alloc_buffer([6, 4])
    inverse = T.alloc_buffer([4, 4, 9, 128])
    for i0, i1, i2, i3 in T.grid(1, 16, 16, 128):
        with T.block("data_pad"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.block_attr({"schedule_rule": "None"})
            # NOTE(review): the declared read index ranges up to 15, beyond
            # placeholder's extent of 14. The actual load is guarded by the
            # if_then_else below.
            T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]])
            T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]])
            data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14,  # type: ignore
                placeholder[i0_1, i1_1, i2_1, i3_1],
                T.float32(0),
                dtype="float32",
            )
    # Tile p (0..8) starts at row (p % 9) // 3 * 4 and column p % 3 * 4, and spans 6
    # elements. The largest index read is therefore 2 * 4 + 5 = 13, so data_pad
    # rows/cols 14..15 are never read here. p // 9 is the batch index (always 0).
    for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128):
        with T.block("input_tile"):
            eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2])
            T.block_attr({"schedule_rule": "None"})
            T.reads(
                [
                    data_pad[
                        T.floordiv(p, 9),  # type: ignore
                        ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                        ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                        ci,
                    ]
                ]
            )
            T.writes([input_tile[eps, nu, p, ci]])
            input_tile[eps, nu, p, ci] = data_pad[
                T.floordiv(p, 9),  # type: ignore
                ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                ci,
            ]
    # Constant 6x6 matrix B, built as a nested Select table. floormod(x, 6) == x
    # because both axes have extent 6; pairs that match no entry fall through to 0.
    # Resulting B (rows i = 0..5, cols j = 0..5):
    #   [ 1,    0,    0,    0,    0,    0  ]
    #   [-1.5,  1,   -1,   -2,    0.5,  1  ]
    #   [-2,   -2.5,  0.5, -1,   -1,   -1.5]
    #   [ 1.5,  0.5,  2.5,  2,   -0.5, -2  ]
    #   [ 1,    1,    1,    1,    1,    1.5]
    #   [ 0,    0,    0,    0,    0,    1  ]
    for i0_3, i1_3 in T.grid(6, 6):
        with T.block("B"):
            i, j = T.axis.remap("SS", [i0_3, i1_3])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([B[i, j]])
            # The final line closes the 36 nested T.Select calls (six groups of six).
            # fmt: off
            B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1),
                      T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5),
                      T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1),
                      T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1),
                      T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1),
                      T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1),
                      T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1),
                      T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2),
                      T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5),
                      T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2),
                      T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5),
                      T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5),
                      T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5),
                      T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5),
                      T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1),
                      T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1),
                      T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5),
                      T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5),
                      T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2),
                      T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1),
                      T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5),
                      T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2),
                      T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1),
                      T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1),
                      T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5),
                      T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0),
                      T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1),
                      T.float32(0) )))))) )))))) )))))) )))))) )))))) ))))))  # type: ignore
            # fmt: on
    # data_pack[eps, nu, p, ci] = sum_{r_a, r_b} input_tile[r_a, r_b, p, ci] * B[r_a, eps] * B[r_b, nu]
    for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6):
        with T.block("data_pack"):
            eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap(
                "SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cuda"})
            # B is read at [r_a, eps_1] and [r_b, nu_1]. The declared region is the
            # bounding box [min(r_a, r_b), max(r_a, r_b)] x [min(eps_1, nu_1), max(eps_1, nu_1)].
            T.reads(
                [
                    data_pack[eps_1, nu_1, p_1, ci_1],
                    input_tile[r_a, r_b, p_1, ci_1],
                    B[
                        T.min(r_a, r_b) : (  # type: ignore
                            T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b))  # type: ignore
                        ),
                        T.min(eps_1, nu_1) : (  # type: ignore
                            T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1))  # type: ignore
                        ),
                    ],
                ]
            )
            T.writes([data_pack[eps_1, nu_1, p_1, ci_1]])
            with T.init():
                data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0)
            data_pack[eps_1, nu_1, p_1, ci_1] = data_pack[eps_1, nu_1, p_1, ci_1] + (
                (input_tile[r_a, r_b, p_1, ci_1] * B[r_a, eps_1]) * B[r_b, nu_1]
            )
    # bgemm[eps, nu, p, co] = sum_ci data_pack[eps, nu, p, ci] * placeholder_1[eps, nu, co, ci]
    for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128):
        with T.block("bgemm"):
            eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1])
            T.block_attr({"meta_schedule.write_cache_level": [3]})
            T.reads(
                [
                    bgemm[eps_2, nu_2, p_2, co],
                    data_pack[eps_2, nu_2, p_2, ci_2],
                    placeholder_1[eps_2, nu_2, co, ci_2],
                ]
            )
            T.writes([bgemm[eps_2, nu_2, p_2, co]])
            with T.init():
                bgemm[eps_2, nu_2, p_2, co] = T.float32(0)
            bgemm[eps_2, nu_2, p_2, co] = bgemm[eps_2, nu_2, p_2, co] + (
                data_pack[eps_2, nu_2, p_2, ci_2] * placeholder_1[eps_2, nu_2, co, ci_2]
            )
    # Constant 6x4 matrix A, built as a nested Select table; unmatched pairs fall
    # through to 0.
    # Resulting A (rows i = 0..5, cols j = 0..3):
    #   [1,  0,    0,     0    ]
    #   [1, -1,    1,    -1    ]
    #   [1,  1,    1,     1    ]
    #   [1,  0.5,  0.25,  0.125]
    #   [1, -2,    4,    -8    ]
    #   [0,  0,    0,     1    ]
    for i0_6, i1_6 in T.grid(6, 4):
        with T.block("A"):
            i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([A[i_1, j_1]])
            # The final line closes the 24 nested T.Select calls (four groups of six).
            # fmt: off
            A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0),
                          T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0),
                          T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0),
                          T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8),
                          T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4),
                          T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2),
                          T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125),
                          T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25),
                          T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5),
                          T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1),
                          T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1),
                          T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                          T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0),
                          T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0),
                          T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0),
                          T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1),
                          T.float32(0) )))))) )))))) )))))) ))))))  # type: ignore
            # fmt: on
    # inverse[vh, vw, p, co] = sum_{r_a, r_b} bgemm[r_a, r_b, p, co] * A[r_a, vh] * A[r_b, vw]
    for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6):
        with T.block("inverse"):
            vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap(
                "SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse"})
            # A is read at [r_a_1, vh] and [r_b_1, vw]; the declared region is their bounding box.
            T.reads(
                [
                    inverse[vh, vw, p_3, co_1],
                    bgemm[r_a_1, r_b_1, p_3, co_1],
                    A[
                        T.min(r_a_1, r_b_1) : (  # type: ignore
                            T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1))  # type: ignore
                        ),
                        T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))),  # type: ignore
                    ],
                ]
            )
            T.writes([inverse[vh, vw, p_3, co_1]])
            with T.init():
                inverse[vh, vw, p_3, co_1] = T.float32(0)
            inverse[vh, vw, p_3, co_1] = inverse[vh, vw, p_3, co_1] + (
                (bgemm[r_a_1, r_b_1, p_3, co_1] * A[r_a_1, vh]) * A[r_b_1, vw]
            )
    # Scatter: output (h, w) comes from tile n * 9 + (h // 4) * 3 + w // 4 at
    # in-tile position (h % 4, w % 4).
    for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128):
        with T.block("conv2d_winograd"):
            n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6])
            T.reads(
                [
                    inverse[
                        T.floormod(h, 4),  # type: ignore
                        T.floormod(w, 4),  # type: ignore
                        (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                        co_2,
                    ]
                ]
            )
            T.writes([conv2d_winograd[n, h, w, co_2]])
            conv2d_winograd[n, h, w, co_2] = inverse[
                T.floormod(h, 4),  # type: ignore
                T.floormod(w, 4),  # type: ignore
                (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                co_2,
            ]
def main(
    placeholder: T.Buffer[(1, 384), "int64"],
    placeholder_1: T.Buffer[(30522, 768), "float32"],
    placeholder_2: T.Buffer[(1, 384, 768), "float32"],
    T_add: T.Buffer[(1, 384, 768), "float32"],
) -> None:
    # Unfused multi-block pipeline:
    #   1. T_less  = placeholder < 0
    #   2. T_add_1 = placeholder + 30522
    #   3. T_where = T_less ? T_add_1 : placeholder          (wrap negative indices once)
    #   4. T_take[0, s, d] = placeholder_1[clamp(T_where[0, s], 0, 30521), d]   (row gather)
    #   5. T_add   = T_take + placeholder_2
    # The constants 0 and 30522 are materialized in zero-dimensional buffers
    # (compile_engine_const*) rather than inlined.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    compile_engine_const = T.alloc_buffer([], dtype="int64")
    T_less = T.alloc_buffer([1, 384], dtype="bool")
    compile_engine_const_1 = T.alloc_buffer([], dtype="int64")
    T_add_1 = T.alloc_buffer([1, 384], dtype="int64")
    T_where = T.alloc_buffer([1, 384], dtype="int64")
    T_take = T.alloc_buffer([1, 384, 768], dtype="float32")
    with T.block("compile_engine_const"):
        # vi is a dummy extent-1 spatial iterator; the body does not use it.
        vi = T.axis.spatial(1, 0)
        T.reads()
        T.writes(compile_engine_const[()])
        compile_engine_const[()] = T.int64(0)
    for i0, i1 in T.grid(1, 384):
        with T.block("T_less"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(placeholder[ax0, ax1], compile_engine_const[()])
            T.writes(T_less[ax0, ax1])
            T_less[ax0, ax1] = placeholder[ax0, ax1] < compile_engine_const[()]
    with T.block("compile_engine_const_1"):
        vi = T.axis.spatial(1, 0)
        T.reads()
        T.writes(compile_engine_const_1[()])
        compile_engine_const_1[()] = T.int64(30522)
    # NOTE: the block named "T_add" writes the intermediate buffer T_add_1, while the
    # block named "T_add_1" (last loop) writes the output buffer T_add. Block names
    # are string identifiers, presumably looked up by name elsewhere, so they are
    # left as-is.
    for i0, i1 in T.grid(1, 384):
        with T.block("T_add"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(placeholder[ax0, ax1], compile_engine_const_1[()])
            T.writes(T_add_1[ax0, ax1])
            T_add_1[ax0, ax1] = placeholder[ax0, ax1] + compile_engine_const_1[()]
    for i0, i1 in T.grid(1, 384):
        with T.block("T_where"):
            ax0, ax1 = T.axis.remap("SS", [i0, i1])
            T.reads(T_less[ax0, ax1], T_add_1[ax0, ax1], placeholder[ax0, ax1])
            T.writes(T_where[ax0, ax1])
            T_where[ax0, ax1] = T.Select(
                T.cast(T_less[ax0, ax1], "int32") != 0, T_add_1[ax0, ax1], placeholder[ax0, ax1])
    for i0, i1, i2 in T.grid(1, 384, 768):
        with T.block("T_take"):
            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
            # The row index is clamped into [0, 30521] so the gather stays in bounds.
            T.reads(
                placeholder_1[T.min(T.max(T.int64(0), T_where[
                    ax0, ax1]), T.int64(30521)), ax2],
                T_where[ax0, ax1],
            )
            T.writes(T_take[ax0, ax1, ax2])
            T_take[ax0, ax1, ax2] = placeholder_1[
                T.min(T.max(T.int64(0), T_where[ax0, ax1]), T.int64(30521)), ax2]
    for i0, i1, i2 in T.grid(1, 384, 768):
        with T.block("T_add_1"):
            ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
            T.reads(T_take[ax0, ax1, ax2], placeholder_2[ax0, ax1, ax2])
            T.writes(T_add[ax0, ax1, ax2])
            T_add[ax0, ax1, ax2] = T_take[ax0, ax1, ax2] + placeholder_2[ax0, ax1, ax2]