Example #1
0
 def main(placeholder: T.Buffer[(1, 384), "int64"],
          placeholder_1: T.Buffer[(30522, 768), "float32"],
          placeholder_2: T.Buffer[(1, 384, 768), "float32"],
          T_add: T.Buffer[(1, 384, 768), "float32"]) -> None:
     # Fused gather-and-add over a 1 x 384 x 768 iteration grid.
     #
     # For every (ax0, ax1, ax2):
     #   row = placeholder[ax0, ax1]
     #   row = row + 30522 if row < 0 else row      (negative indices wrap)
     #   row = min(max(0, row), 30521)              (clamp to valid rows)
     #   T_add[ax0, ax1, ax2] = placeholder_1[row, ax2]
     #                          + placeholder_2[ax0, ax1, ax2]
     #
     # Parameters:
     #   placeholder   - int64 row indices, shape (1, 384).
     #   placeholder_1 - float32 table of shape (30522, 768) that is indexed
     #                   by the clamped row.
     #   placeholder_2 - float32 addend, shape (1, 384, 768).
     #   T_add         - float32 output, shape (1, 384, 768).
     #
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     for i0, i1, i2 in T.grid(1, 384, 768):
         with T.block("T_add_1"):
             # All three axes are spatial ("SSS"); there is no reduction.
             ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
             # The row read from placeholder_1 is data dependent, so the
             # declared read region is a row *range*: it starts at the clamped
             # raw index and ends (exclusive) one past the clamped wrapped
             # index (raw + 30522).
             T.reads(
                 placeholder[ax0, ax1], placeholder_1[
                     T.min(T.max(T.int64(0), placeholder[
                         ax0, ax1]), T.int64(30521)):T.min(
                             T.max(T.int64(0), placeholder[ax0, ax1] +
                                   T.int64(30522)), T.int64(30521)) +
                     T.int64(1), ax2], placeholder_2[ax0, ax1, ax2])
             T.writes(T_add[ax0, ax1, ax2])
             # Select wraps negative indices by adding the table height; the
             # surrounding max/min pair then clamps the result to [0, 30521].
             T_add[ax0, ax1, ax2] = placeholder_1[T.min(
                 T.max(
                     T.int64(0),
                     T.Select(
                         T.cast(placeholder[ax0, ax1] < T.int64(0), "int32"
                                ) != 0, placeholder[ax0, ax1] +
                         T.int64(30522), placeholder[ax0, ax1])
                 ), T.int64(30521)), ax2] + placeholder_2[ax0, ax1, ax2]
Example #2
0
def tir_argmax_val_idx(
    var_val: T.handle, var_idx: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle
) -> None:
    # Row-wise arg-max that carries a (value, index) pair through a reduction.
    #
    # Parameters (opaque handles, bound to buffers below):
    #   var_val       - float32 [m, n] candidate values.
    #   var_idx       - int32   [m, n] index associated with each candidate.
    #   var_argmax_v0 - float32 [m] output: the selected value per row.
    #   var_argmax_v1 - int32   [m] output: the idx entry paired with it.
    #
    # Note that the exported global symbol is "main", not the Python name.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # m and n are symbolic int32 extents shared by the matched buffers.
    m = T.var("int32")
    n = T.var("int32")
    val = T.match_buffer(var_val, [m, n], dtype="float32")
    idx = T.match_buffer(var_idx, [m, n], dtype="int32")
    argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="float32")
    argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="int32")
    for i0, i1 in T.grid(m, n):
        with T.block("argmax"):
            # i is spatial (one result per row); k is the reduction axis.
            i, k = T.axis.remap("SR", [i0, i1])
            T.reads(val[i, k], idx[i, k])
            T.writes(argmax_v0[i], argmax_v1[i])
            with T.init():
                # Identity element of the reduction: the smallest float32
                # value, paired with the sentinel index -1.
                argmax_v0[i] = T.min_value("float32")
                argmax_v1[i] = T.int32(-1)
            # Both selects use the same predicate and are evaluated before
            # either output is stored, so value and index are updated as a
            # consistent pair. With ">=", a tie keeps the stored pair.
            v_argmax_v0: T.float32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v0[i], val[i, k])
            v_argmax_v1: T.int32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v1[i], idx[i, k])
            argmax_v0[i] = v_argmax_v0
            argmax_v1[i] = v_argmax_v1
 def main(
     T_reshape: T.Buffer[(1, 12, 384, 384), "float32"],
     placeholder_1: T.Buffer[(T.int64(1), T.int64(12), T.int64(384),
                              384), "bool"],
     T_where: T.Buffer[(T.int64(1), T.int64(12), T.int64(384), 384),
                       "float32"]
 ) -> None:
     # Masked fill, thread-bound for a GPU target:
     #   T_where[...] = -1e9 where placeholder_1[...] is true,
     #                  otherwise T_reshape[...].
     #
     # Parameters:
     #   T_reshape     - float32 input, shape (1, 12, 384, 384).
     #   placeholder_1 - bool mask of the same shape; its first three extents
     #                   are int64 and the last one is a plain (int32) 384.
     #   T_where       - float32 output with the same shape as the mask.
     #
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     # The 1*12*384*384 = 1,769,472 elements are covered by a flat index
     #   flat = (fused_0 * 256 + fused_1) * 1024 + fused_2
     # where fused_1 is bound to blockIdx.x (256), fused_2 to threadIdx.x
     # (1024) and fused_0 is a serial loop of 7 steps per thread.
     for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(256),
                                                 thread="blockIdx.x"):
         for i0_i1_i2_i3_fused_2 in T.thread_binding(
                 T.int64(1024), thread="threadIdx.x"):
             for i0_i1_i2_i3_fused_0 in T.serial(T.int64(7)):
                 with T.block("T_where"):
                     ax0 = T.axis.spatial(T.int64(1), T.int64(0))
                     # 147456 = 384 * 384: flat // 147456 selects axis 1.
                     ax1 = T.axis.spatial(
                         T.int64(12),
                         ((i0_i1_i2_i3_fused_0 * T.int64(256) +
                           i0_i1_i2_i3_fused_1) * T.int64(1024) +
                          i0_i1_i2_i3_fused_2) % T.int64(1769472) //
                         T.int64(147456))
                     ax2 = T.axis.spatial(
                         T.int64(384),
                         ((i0_i1_i2_i3_fused_0 * T.int64(256) +
                           i0_i1_i2_i3_fused_1) * T.int64(1024) +
                          i0_i1_i2_i3_fused_2) % T.int64(147456) //
                         T.int64(384))
                     # The innermost extent is a plain 384, so the int64 flat
                     # index is cast down to int32 for this axis.
                     ax3 = T.axis.spatial(
                         384,
                         T.cast(((i0_i1_i2_i3_fused_0 * T.int64(256) +
                                  i0_i1_i2_i3_fused_1) * T.int64(1024) +
                                 i0_i1_i2_i3_fused_2) % T.int64(384),
                                "int32"))
                     # 7 * 256 * 1024 = 1,835,008 iterations exceed the
                     # 1,769,472 elements, so the tail is predicated off.
                     T.where((i0_i1_i2_i3_fused_0 * T.int64(256) +
                              i0_i1_i2_i3_fused_1) * T.int64(1024) +
                             i0_i1_i2_i3_fused_2 < T.int64(1769472))
                     T.reads(placeholder_1[ax0, ax1, ax2, ax3],
                             T_reshape[ax0, ax1, ax2, ax3])
                     T.writes(T_where[ax0, ax1, ax2, ax3])
                     # The bool mask is cast to int32 and compared with 0 to
                     # form the Select condition.
                     T_where[ax0, ax1, ax2, ax3] = T.Select(
                         T.cast(placeholder_1[ax0, ax1, ax2,
                                              ax3], "int32") != 0,
                         T.float32(-1000000000), T_reshape[ax0, ax1,
                                                           ax2, ax3])
 def main(placeholder: T.Buffer[(1, 16, 7, 7, 32), "float32"], placeholder_1: T.Buffer[(25088,), "float32"], T_layout_trans: T.Buffer[(1, 1, 7, 7, 512), "float32"]) -> None:
     # Layout transform fused with an element-wise activation.
     #
     # Parameters:
     #   placeholder    - float32 input, shape (1, 16, 7, 7, 32).
     #   placeholder_1  - float32 vector of 25088 per-element multipliers.
     #   T_layout_trans - float32 output, shape (1, 1, 7, 7, 512).
     #
     # With flat = ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 and
     # x = placeholder[0, flat // 49 // 32, flat % 49 // 7, flat % 7,
     #                 flat // 49 % 32] (guarded by bounds checks), the stored
     # value is Select(0 < x, x, x * placeholder_1[flat]).
     # NOTE(review): this has the shape of a PReLU / leaky-ReLU with a
     # per-element slope - confirm against the originating graph.
     #
     # function attr dict
     T.func_attr({"tir.noalias": True, "global_symbol": "main"})
     # body
     # with T.block("root")
     # One parallel loop over all 7 * 7 * 512 = 25088 output elements; the
     # annotations carry unroll pragmas (max step 64, explicit unroll).
     for i0_i1_i2_i3_i4_fused in T.parallel(25088, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}):
         with T.block("T_layout_trans_1"):
             ax0 = T.axis.spatial(1, 0)
             ax1 = T.axis.spatial(1, 0)
             # 3584 = 7 * 512, so the fused index splits into (ax2, ax3, ax4).
             ax2 = T.axis.spatial(7, i0_i1_i2_i3_i4_fused // 3584)
             ax3 = T.axis.spatial(7, i0_i1_i2_i3_i4_fused % 3584 // 512)
             ax4 = T.axis.spatial(512, i0_i1_i2_i3_i4_fused % 512)
             T.reads(placeholder[0, (ax4 * 49 + ax2 * 7 + ax3) % 25088 // 1568, (ax2 * 7 + ax3) % 49 // 7, ax3 % 7, (ax4 * 49 + ax2 * 7 + ax3) % 1568 // 49], placeholder_1[(ax4 * 49 + ax2 * 7 + ax3) % 25088])
             T.writes(T_layout_trans[ax0, ax1, ax2, ax3, ax4])
             # The guarded read of placeholder appears three times (Select
             # condition, true value, and the factor of the false value); the
             # outer if_then_else yields 0 outside the output bounds.
             T_layout_trans[ax0, ax1, ax2, ax3, ax4] = T.if_then_else(ax0 < 1 and ax1 * 512 + ax4 < 512 and ax2 < 7 and ax3 < 7, T.Select(T.float32(0) < T.if_then_else(0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32"), T.if_then_else(0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32"), T.if_then_else(0 < 1 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 < 512 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7 < 7 and ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7 < 7, placeholder[0, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 // 32, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 49 // 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 7, ((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088 % 25088 // 49 % 32], T.float32(0), dtype="float32") * placeholder_1[((ax1 * 512 + ax4) * 49 + ax2 * 7 + ax3) % 25088]), T.float32(0), dtype="float32")
Example #5
0
def conv2d_winograd_cuda(  # type: ignore
    placeholder: T.Buffer[(1, 14, 14, 128), "float32"],  # type: ignore
    placeholder_1: T.Buffer[(6, 6, 128, 128), "float32"],  # type: ignore
    conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"],  # type: ignore
) -> None:
    # Winograd-style 2-D convolution expressed as eight TensorIR stages:
    #   data_pad -> input_tile -> (B) data_pack -> bgemm -> (A) inverse
    #   -> conv2d_winograd
    #
    # Parameters:
    #   placeholder     - float32 input, shape (1, 14, 14, 128).
    #   placeholder_1   - float32 weights, shape (6, 6, 128, 128), indexed
    #                     as [eps, nu, co, ci] in the bgemm stage.
    #                     NOTE(review): presumably already transformed into
    #                     the 6x6 Winograd domain - confirm with the caller.
    #   conv2d_winograd - float32 output, shape (1, 12, 12, 128).
    #
    # type: ignore
    # Intermediate buffers, one per pipeline stage.
    data_pad = T.alloc_buffer([1, 16, 16, 128])
    input_tile = T.alloc_buffer([6, 6, 9, 128])
    B = T.alloc_buffer([6, 6])
    data_pack = T.alloc_buffer([6, 6, 9, 128])
    bgemm = T.alloc_buffer([6, 6, 9, 128])
    A = T.alloc_buffer([6, 4])
    inverse = T.alloc_buffer([4, 4, 9, 128])
    # Stage 1: copy the 14x14 input into a 16x16 buffer; positions with a
    # row or column index of 14 or 15 are filled with zero.
    for i0, i1, i2, i3 in T.grid(1, 16, 16, 128):
        with T.block("data_pad"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.block_attr({"schedule_rule": "None"})
            T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]])
            T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]])
            data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14,  # type: ignore
                placeholder[i0_1, i1_1, i2_1, i3_1],
                T.float32(0),
                dtype="float32",
            )
    # Stage 2: gather 6x6 input tiles. p enumerates a 3x3 grid of tiles
    # (p % 9 // 3 and p % 3) whose origins are 4 elements apart.
    for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128):
        with T.block("input_tile"):
            eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2])
            T.block_attr({"schedule_rule": "None"})
            T.reads(
                [
                    data_pad[
                        T.floordiv(p, 9),  # type: ignore
                        ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                        ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                        ci,
                    ]
                ]
            )
            T.writes([input_tile[eps, nu, p, ci]])
            input_tile[eps, nu, p, ci] = data_pad[
                T.floordiv(p, 9),  # type: ignore
                ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                ci,
            ]
    # Stage 3: materialise the constant 6x6 matrix B. Each entry is chosen
    # by a chain of nested Selects keyed on (i % 6, j % 6); the block is
    # annotated for compute-inlining.
    for i0_3, i1_3 in T.grid(6, 6):
        with T.block("B"):
            i, j = T.axis.remap("SS", [i0_3, i1_3])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([B[i, j]])
            # fmt: off
            B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), 
T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0)))))))))))))))))))))))))))))))))))))  # type: ignore
            # fmt: on
    # Stage 4: input transform. Reduces over (r_a, r_b):
    #   data_pack[eps, nu, p, ci] += input_tile[r_a, r_b, p, ci]
    #                                * B[r_a, eps] * B[r_b, nu]
    for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6):
        with T.block("data_pack"):
            eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap(
                "SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cuda"})
            # B is read at [r_a, eps_1] and [r_b, nu_1]; the declared region
            # is the bounding box min..max+1 along each dimension that
            # covers both accesses.
            T.reads(
                [
                    data_pack[eps_1, nu_1, p_1, ci_1],
                    input_tile[r_a, r_b, p_1, ci_1],
                    B[
                        T.min(r_a, r_b) : (  # type: ignore
                            T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b))  # type: ignore
                        ),
                        T.min(eps_1, nu_1) : (  # type: ignore
                            T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1))  # type: ignore
                        ),
                    ],
                ]
            )
            T.writes([data_pack[eps_1, nu_1, p_1, ci_1]])
            with T.init():
                data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0)
            data_pack[eps_1, nu_1, p_1, ci_1] = data_pack[eps_1, nu_1, p_1, ci_1] + (
                (input_tile[r_a, r_b, p_1, ci_1] * B[r_a, eps_1]) * B[r_b, nu_1]
            )
    # Stage 5: for each (eps, nu) position, contract transformed data with
    # the weights over the input-channel axis ci (the single "R" axis).
    for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128):
        with T.block("bgemm"):
            eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1])
            T.block_attr({"meta_schedule.write_cache_level": [3]})
            T.reads(
                [
                    bgemm[eps_2, nu_2, p_2, co],
                    data_pack[eps_2, nu_2, p_2, ci_2],
                    placeholder_1[eps_2, nu_2, co, ci_2],
                ]
            )
            T.writes([bgemm[eps_2, nu_2, p_2, co]])
            with T.init():
                bgemm[eps_2, nu_2, p_2, co] = T.float32(0)
            bgemm[eps_2, nu_2, p_2, co] = bgemm[eps_2, nu_2, p_2, co] + (
                data_pack[eps_2, nu_2, p_2, ci_2] * placeholder_1[eps_2, nu_2, co, ci_2]
            )
    # Stage 6: materialise the constant 6x4 matrix A, again as a nested
    # Select chain keyed on (i_1 % 6, j_1 % 4) and marked for inlining.
    for i0_6, i1_6 in T.grid(6, 4):
        with T.block("A"):
            i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([A[i_1, j_1]])
            # fmt: off
            A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), 
T.float32(1), T.float32(0)))))))))))))))))))))))))  # type: ignore
            # fmt: on
    # Stage 7: output transform. Reduces over (r_a_1, r_b_1):
    #   inverse[vh, vw, p, co] += bgemm[r_a_1, r_b_1, p, co]
    #                             * A[r_a_1, vh] * A[r_b_1, vw]
    for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6):
        with T.block("inverse"):
            vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap(
                "SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse"})
            # As in data_pack, the read region of A is the bounding box of
            # the two accesses A[r_a_1, vh] and A[r_b_1, vw].
            T.reads(
                [
                    inverse[vh, vw, p_3, co_1],
                    bgemm[r_a_1, r_b_1, p_3, co_1],
                    A[
                        T.min(r_a_1, r_b_1) : (  # type: ignore
                            T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1))  # type: ignore
                        ),
                        T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))),  # type: ignore
                    ],
                ]
            )
            T.writes([inverse[vh, vw, p_3, co_1]])
            with T.init():
                inverse[vh, vw, p_3, co_1] = T.float32(0)
            inverse[vh, vw, p_3, co_1] = inverse[vh, vw, p_3, co_1] + (
                (bgemm[r_a_1, r_b_1, p_3, co_1] * A[r_a_1, vh]) * A[r_b_1, vw]
            )
    # Stage 8: scatter the 4x4 output tiles into the 12x12 result. Output
    # pixel (h, w) lives at offset (h % 4, w % 4) inside tile
    # n * 9 + (h // 4) * 3 + (w // 4).
    for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128):
        with T.block("conv2d_winograd"):
            n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6])
            T.reads(
                [
                    inverse[
                        T.floormod(h, 4),  # type: ignore
                        T.floormod(w, 4),  # type: ignore
                        (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                        co_2,
                    ]
                ]
            )
            T.writes([conv2d_winograd[n, h, w, co_2]])
            conv2d_winograd[n, h, w, co_2] = inverse[
                T.floormod(h, 4),  # type: ignore
                T.floormod(w, 4),  # type: ignore
                (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                co_2,
            ]
Example #6
0
 def main(
     placeholder: T.Buffer[(1, 384), "int64"],
     placeholder_1: T.Buffer[(30522, 768), "float32"],
     placeholder_2: T.Buffer[(1, 384, 768), "float32"],
     T_add: T.Buffer[(1, 384, 768), "float32"],
 ) -> None:
     # Unfused, stage-by-stage gather-and-add:
     #   T_less  = placeholder < 0
     #   T_add_1 = placeholder + 30522
     #   T_where = T_add_1 where T_less else placeholder   (wrap negatives)
     #   T_take  = placeholder_1[clamp(T_where, 0, 30521), :]
     #   T_add   = T_take + placeholder_2
     #
     # Parameters:
     #   placeholder   - int64 row indices, shape (1, 384).
     #   placeholder_1 - float32 table, shape (30522, 768).
     #   placeholder_2 - float32 addend, shape (1, 384, 768).
     #   T_add         - float32 output, shape (1, 384, 768).
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # Intermediate buffers; the two compile_engine_const buffers are
     # zero-dimensional scalars holding the constants 0 and 30522.
     compile_engine_const = T.alloc_buffer([], dtype="int64")
     T_less = T.alloc_buffer([1, 384], dtype="bool")
     compile_engine_const_1 = T.alloc_buffer([], dtype="int64")
     T_add_1 = T.alloc_buffer([1, 384], dtype="int64")
     T_where = T.alloc_buffer([1, 384], dtype="int64")
     T_take = T.alloc_buffer([1, 384, 768], dtype="float32")
     with T.block("compile_engine_const"):
         # vi is declared but not referenced in the block body.
         vi = T.axis.spatial(1, 0)
         T.reads()
         T.writes(compile_engine_const[()])
         compile_engine_const[()] = T.int64(0)
     # Flag the indices that are below the scalar constant 0.
     for i0, i1 in T.grid(1, 384):
         with T.block("T_less"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(placeholder[ax0, ax1], compile_engine_const[()])
             T.writes(T_less[ax0, ax1])
             T_less[ax0,
                    ax1] = placeholder[ax0, ax1] < compile_engine_const[()]
     with T.block("compile_engine_const_1"):
         # vi is declared but not referenced in the block body.
         vi = T.axis.spatial(1, 0)
         T.reads()
         T.writes(compile_engine_const_1[()])
         compile_engine_const_1[()] = T.int64(30522)
     # Shift every index by 30522 (the first extent of placeholder_1);
     # whether the shifted value is used is decided in T_where.
     for i0, i1 in T.grid(1, 384):
         with T.block("T_add"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(placeholder[ax0, ax1], compile_engine_const_1[()])
             T.writes(T_add_1[ax0, ax1])
             T_add_1[ax0,
                     ax1] = placeholder[ax0,
                                        ax1] + compile_engine_const_1[()]
     # Take the shifted index where the original was negative, otherwise
     # keep the original index.
     for i0, i1 in T.grid(1, 384):
         with T.block("T_where"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(T_less[ax0, ax1], T_add_1[ax0, ax1], placeholder[ax0,
                                                                      ax1])
             T.writes(T_where[ax0, ax1])
             T_where[ax0, ax1] = T.Select(
                 T.cast(T_less[ax0, ax1], "int32") != 0, T_add_1[ax0, ax1],
                 placeholder[ax0, ax1])
     # Gather rows of placeholder_1, clamping the index to [0, 30521].
     for i0, i1, i2 in T.grid(1, 384, 768):
         with T.block("T_take"):
             ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
             T.reads(
                 placeholder_1[T.min(T.max(T.int64(0), T_where[
                     ax0, ax1]), T.int64(30521)), ax2],
                 T_where[ax0, ax1],
             )
             T.writes(T_take[ax0, ax1, ax2])
             T_take[ax0, ax1, ax2] = placeholder_1[
                 T.min(T.max(T.int64(0), T_where[ax0,
                                                 ax1]), T.int64(30521)),
                 ax2]
     # Element-wise addition of the gathered rows and placeholder_2.
     for i0, i1, i2 in T.grid(1, 384, 768):
         with T.block("T_add_1"):
             ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
             T.reads(T_take[ax0, ax1, ax2], placeholder_2[ax0, ax1, ax2])
             T.writes(T_add[ax0, ax1, ax2])
             T_add[ax0, ax1,
                   ax2] = T_take[ax0, ax1, ax2] + placeholder_2[ax0, ax1,
                                                                ax2]