def main(
     placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
     placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
     conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
 ) -> None:
     """TVMScript TIR: int8 conv2d in packed NCHWc layout (1x1 kernel).

     Accumulates into the int32 output:
         out[n, oc_chunk, oh, ow, oc_block] +=
             int32(data) * int32(weight)
     over the reduction axes (kh, kw, ic_outer, ic_f_inner, ic_s_inner).
     The "SSSSSRRRRR" remap marks the first five axes spatial and the
     last five as reductions.
     """
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4):
         with T.block("conv2d_NCHWc_int8"):
             (
                 n,
                 oc_chunk,
                 oh,
                 ow,
                 oc_block,
                 kh,
                 kw,
                 ic_outer,
                 ic_f_inner,
                 ic_s_inner,
             ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9])
             T.reads(
                 placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner],
                 placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
             )
             T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
             with T.init():
                 # Zero the accumulator on the first reduction iteration.
                 conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
             conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
                 n, oc_chunk, oh, ow, oc_block
             ] + T.cast(
                 placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32"
             ) * T.cast(
                 placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner],
                 "int32",
             )
Example #2
0
def boolean_handling_after(a: T.Buffer[10, "int8"],
                           b: T.Buffer[10, "int8"]) -> None:
    """TVMScript TIR: element-wise copy of a bool buffer stored as int8.

    Both parameters are flattened int8 backing buffers for logical
    "bool" buffers (declared via T.preflattened_buffer over the same
    data pointers). Each element goes through an int8 -> bool -> int8
    cast round trip.
    """
    T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
    T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
    # body
    for i0 in T.serial(10):
        # Round-trip cast through "bool" before storing back as int8.
        b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
Example #3
0
 def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(
         placeholder: T.handle, placeholder_1: T.handle,
         T_cast: T.handle) -> None:
     """TVMScript TIR: fused requantization over a flat uint8 buffer.

     Per element: cast to int32, subtract 94, fixed-point multiply via
     T.q_multiply_shift(1843157232, 31, 1), add a bias from
     placeholder_3 indexed by the innermost 64 channels, clip to
     [0, 255], then cast to uint8 and finally int16.
     """
     # function attr dict
     T.func_attr({
         "global_symbol":
         "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast",
         "tir.noalias": True
     })
     placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8")
     placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
     T_cast_1 = T.match_buffer(T_cast, [360000], dtype="int16")
     # body
     # Flat index: ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner
     # covers all 75 * 75 * 64 = 360000 elements.
     for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
         T_cast_1[
             ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 +
             ax3_inner] = T.cast(
                 T.cast(
                     T.max(
                         T.min(
                             T.q_multiply_shift(T.cast(
                                 placeholder_2[ax0_ax1_fused * 4800 +
                                               ax2 * 64 + ax3_outer * 16 +
                                               ax3_inner], "int32") - 94,
                                                1843157232,
                                                31,
                                                1,
                                                dtype="int32") +
                             placeholder_3[ax3_outer * 16 + ax3_inner],
                             255), 0), "uint8"), "int16")
 def before(A: T.Buffer[16, "float32"]):
     """TVMScript TIR: conditionally zero A[i] based on ceil(log2(i + 1025)).

     For each i, computes x = int32(ceil(log2(float64(i + 1024 + 1))))
     and stores 0.0 when x == 11.
     """
     for i in T.serial(16):
         # log2/ceil evaluated in float64, then narrowed to int32.
         x = T.cast(
             T.ceil(T.log2(T.cast(i + 1024 + 1, "float64"),
                           dtype="float64"),
                    dtype="float64"),
             dtype="int32",
         )
         if x == 11:
             A[i] = 0.0
Example #5
0
 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(
         placeholder_4: T.handle, placeholder_5: T.handle,
         placeholder_6: T.handle, T_cast_2: T.handle) -> None:
     """TVMScript TIR (legacy T.load/T.store form): 1x1 int16 conv2d,
     64 -> 64 channels on a 75x75 image, followed by bias add,
     fixed-point multiply, clip to [0, 255], and cast uint8 -> int16.
     """
     # function attr dict
     T.func_attr({
         "global_symbol":
         "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast",
         "tir.noalias": True
     })
     placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64],
                                    dtype="int16")
     placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64],
                                    dtype="int16")
     placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64],
                                    dtype="int32")
     T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16")
     # body
     # "Padding" stage is a straight copy here (1x1 kernel, no halo).
     PaddedInput = T.allocate([360000], "int16", "global")
     for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
         T.store(
             PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3,
             T.load("int16", placeholder_7.data,
                    i0_i1_fused * 4800 + i2 * 64 + i3), True)
     for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
         Conv2dOutput = T.allocate([64], "int32", "global")
         for ff in T.serial(0, 64):
             # int32 accumulator per output channel.
             T.store(Conv2dOutput, ff, 0, True)
             for rc in T.serial(0, 64):
                 T.store(
                     Conv2dOutput, ff,
                     T.load("int32", Conv2dOutput, ff) + T.cast(
                         T.load("int16", PaddedInput,
                                ax0_ax1_fused_ax2_fused * 64 + rc), "int32")
                     * T.cast(
                         T.load("int16", placeholder_8.data, rc * 64 + ff),
                         "int32"), True)
         for ax3_inner_1 in T.serial(0, 64):
             # Requantize: bias add, q_multiply_shift, clip, narrow.
             T.store(
                 T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1,
                 T.cast(
                     T.cast(
                         T.max(
                             T.min(
                                 T.q_multiply_shift(
                                     T.load("int32", Conv2dOutput,
                                            ax3_inner_1) +
                                     T.load("int32", placeholder_9.data,
                                            ax3_inner_1),
                                     1843106743,
                                     31,
                                     -6,
                                     dtype="int32"), 255), 0), "uint8"),
                     "int16"), True)
Example #6
0
def dp4a_desc(
    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
    """Tensor-intrinsic description: 4-element int8 dot product.

    Computes C[0] += sum(int32(A[i]) * int32(B[i]) for i in 0..3),
    i.e. the computation pattern matched when tensorizing to a dp4a
    instruction (per the name — implementation counterpart not shown
    here).
    """
    with T.block("root"):
        T.reads(C[0], A[0:4], B[0:4])
        T.writes(C[0])
        for i in range(0, 4):
            with T.block("update"):
                # Single reduction axis over the 4 lanes.
                vi = T.axis.remap("R", [i])
                C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")
Example #7
0
 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(
         placeholder_16: T.handle, placeholder_17: T.handle,
         placeholder_18: T.handle, T_add: T.handle) -> None:
     """TVMScript TIR: 1x1 int16 conv2d, 64 -> 256 channels (4 outer x 64
     inner), then a double requantization chain per element:
     q_multiply_shift + 132, clip to [0, 255], cast uint8 -> int32,
     subtract 132, second q_multiply_shift, add 136; int32 output.
     """
     # function attr dict
     T.func_attr({
         "global_symbol":
         "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_",
         "tir.noalias": True
     })
     placeholder_19 = T.match_buffer(placeholder_16, [360000],
                                     dtype="int16")
     placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16")
     placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32")
     T_add_1 = T.match_buffer(T_add, [1440000], dtype="int32")
     # body
     # 1x1 kernel: the "padded" input is a direct copy of the input.
     PaddedInput_2 = T.allocate([360000], "int16", "global")
     for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64):
         PaddedInput_2[i0_i1_fused_2 * 4800 + i2_2 * 64 +
                       i3_2] = placeholder_19[i0_i1_fused_2 * 4800 +
                                              i2_2 * 64 + i3_2]
     for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625):
         Conv2dOutput_2 = T.allocate([64], "int32", "global")
         # Output channels processed in 4 tiles of 64.
         for ax3_outer_1 in T.serial(0, 4):
             for ff_2 in T.serial(0, 64):
                 Conv2dOutput_2[ff_2] = 0
                 for rc_2 in T.serial(0, 64):
                     Conv2dOutput_2[ff_2] = Conv2dOutput_2[ff_2] + T.cast(
                         PaddedInput_2[ax0_ax1_fused_ax2_fused_2 * 64 +
                                       rc_2],
                         "int32") * T.cast(
                             placeholder_20[rc_2 * 256 + ax3_outer_1 * 64 +
                                            ff_2], "int32")
             for ax3_inner_3 in T.serial(0, 64):
                 # Bias add + first requantize (clip to uint8 range),
                 # then re-center and apply the second fixed-point scale.
                 T_add_1[
                     ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 +
                     ax3_inner_3] = T.q_multiply_shift(T.cast(
                         T.cast(
                             T.max(
                                 T.min(
                                     T.q_multiply_shift(
                                         Conv2dOutput_2[ax3_inner_3] +
                                         placeholder_21[ax3_outer_1 * 64 +
                                                        ax3_inner_3],
                                         1711626602,
                                         31,
                                         -8,
                                         dtype="int32") + 132, 255), 0),
                             "uint8"), "int32") - 132,
                                                       2094289803,
                                                       31,
                                                       -2,
                                                       dtype="int32") + 136
Example #8
0
 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(
         placeholder_10: T.handle, placeholder_11: T.handle,
         placeholder_12: T.handle, T_cast_4: T.handle) -> None:
     """TVMScript TIR: 3x3 int16 conv2d with 1-pixel zero padding,
     64 -> 64 channels on a 75x75 image, followed by bias add,
     fixed-point multiply, clip to [0, 255], and cast uint8 -> int16.
     """
     # function attr dict
     T.func_attr({
         "global_symbol":
         "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1",
         "tir.noalias": True
     })
     placeholder_13 = T.match_buffer(placeholder_10, [360000],
                                     dtype="int16")
     placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16")
     placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32")
     T_cast_5 = T.match_buffer(T_cast_4, [360000], dtype="int16")
     # body
     # 77x77x64 padded copy of the 75x75x64 input (1-pixel zero border).
     PaddedInput_1 = T.allocate([379456], "int16", "global")
     for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64):
         PaddedInput_1[i0_i1_fused_1 * 4928 + i2_1 * 64 +
                       i3_1] = T.if_then_else(
                           1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76
                           and 1 <= i2_1 and i2_1 < 76,
                           placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 +
                                          i3_1 - 4864],
                           T.int16(0),
                           dtype="int16")
     for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625):
         Conv2dOutput_1 = T.allocate([64], "int32", "global")
         for ff_1 in T.serial(0, 64):
             Conv2dOutput_1[ff_1] = 0
             # 3x3 spatial window x 64 input channels.
             for ry, rx, rc_1 in T.grid(3, 3, 64):
                 Conv2dOutput_1[ff_1] = Conv2dOutput_1[ff_1] + T.cast(
                     PaddedInput_1[
                         T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 +
                         ry * 4928 + rx * 64 +
                         T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 +
                         rc_1], "int32") * T.cast(
                             placeholder_14[ry * 12288 + rx * 4096 +
                                            rc_1 * 64 + ff_1], "int32")
         for ax3_inner_2 in T.serial(0, 64):
             # Requantize: bias add, q_multiply_shift, clip, narrow.
             T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 +
                      ax3_inner_2] = T.cast(
                          T.cast(
                              T.max(
                                  T.min(
                                      T.q_multiply_shift(
                                          Conv2dOutput_1[ax3_inner_2] +
                                          placeholder_15[ax3_inner_2],
                                          1608879842,
                                          31,
                                          -7,
                                          dtype="int32"), 255), 0),
                              "uint8"), "int16")
Example #9
0
File: x86.py  Project: chenghanpeng/tvm
def dot_product_16x4_u8i8i32_desc(
    A: T.Buffer((4,), "uint8", offset_factor=1),
    B: T.Buffer((16, 4), "int8", offset_factor=1),
    C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
    """Tensor-intrinsic description: 16x4 uint8 x int8 dot product.

    For each of the 16 output lanes:
        C[i] += sum(int32(A[k]) * int32(B[i, k]) for k in 0..3)
    This is the pattern matched when tensorizing to an x86 VNNI-style
    instruction (per the module name — implementation not shown here).
    """
    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])
        for i in T.serial(0, 16):
            for k in T.serial(0, 4):
                with T.block("update"):
                    # i is spatial (output lane), k is the reduction.
                    vi, vk = T.axis.remap("SR", [i, k])
                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
Example #10
0
 def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle,
                                        placeholder_3: T.handle,
                                        T_subtract: T.handle) -> None:
     """TVMScript TIR: cast a 1x224x224x3 uint8 tensor to int16 and
     subtract a scalar (rank-0 int16 buffer) element-wise.

     T_subtract[idx] = int16(placeholder_4[idx]) - placeholder_5[()],
     iterating the tensor through fused H and flat (W, C) loops.

     Fix: the first T.match_buffer call used the misspelled keyword
     ``dTpe=``; T.match_buffer takes ``dtype``, so this block would
     raise a TypeError when parsed/executed.
     """
     # function attr dict
     T.func_attr({
         "global_symbol": "tvmgen_default_fused_cast_subtract",
         "tir.noalias": True
     })
     placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3],
                                    dtype="uint8",  # was misspelled "dTpe"
                                    elem_offset=0,
                                    align=128,
                                    offset_factor=1)
     placeholder_5 = T.match_buffer(placeholder_3, [],
                                    dtype="int16",
                                    elem_offset=0,
                                    align=128,
                                    offset_factor=1)
     T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3],
                                   dtype="int16",
                                   elem_offset=0,
                                   align=128,
                                   offset_factor=1)
     # body
     for ax0_ax1_fused_1 in T.serial(0, 224):
         for ax2_1, ax3_inner_1 in T.grid(224, 3):
             # Flat NHWC offset: h * 672 + w * 3 + c (672 = 224 * 3).
             T.store(T_subtract_1.data, (((ax0_ax1_fused_1 * 672) +
                                          (ax2_1 * 3)) + ax3_inner_1),
                     (T.cast(
                         T.load("uint8", placeholder_4.data,
                                (((ax0_ax1_fused_1 * 672) +
                                  (ax2_1 * 3)) + ax3_inner_1)), "int16") -
                      T.load("int16", placeholder_5.data, 0)), True)
Example #11
0
 def main(placeholder: T.Buffer[(1, 384), "int64"],
          placeholder_1: T.Buffer[(30522, 768), "float32"],
          placeholder_2: T.Buffer[(1, 384, 768), "float32"],
          T_add: T.Buffer[(1, 384, 768), "float32"]) -> None:
     """TVMScript TIR: embedding-table gather plus element-wise add.

     For each (ax1, ax2): take the row index from placeholder (wrapping
     negative indices by adding 30522, then clamping into
     [0, 30521]), gather placeholder_1[row, ax2], and add
     placeholder_2[ax0, ax1, ax2]. The T.reads range expresses the
     data-dependent row access as a conservative slice.
     """
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     for i0, i1, i2 in T.grid(1, 384, 768):
         with T.block("T_add_1"):
             ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
             T.reads(
                 placeholder[ax0, ax1], placeholder_1[
                     T.min(T.max(T.int64(0), placeholder[
                         ax0, ax1]), T.int64(30521)):T.min(
                             T.max(T.int64(0), placeholder[ax0, ax1] +
                                   T.int64(30522)), T.int64(30521)) +
                     T.int64(1), ax2], placeholder_2[ax0, ax1, ax2])
             T.writes(T_add[ax0, ax1, ax2])
             # Negative indices select from the end of the table
             # (index + 30522); the result is clamped to valid rows.
             T_add[ax0, ax1, ax2] = placeholder_1[T.min(
                 T.max(
                     T.int64(0),
                     T.Select(
                         T.cast(placeholder[ax0, ax1] < T.int64(0), "int32"
                                ) != 0, placeholder[ax0, ax1] +
                         T.int64(30522), placeholder[ax0, ax1])
                 ), T.int64(30521)), ax2] + placeholder_2[ax0, ax1, ax2]
 def main(
     T_reshape: T.Buffer[(1, 12, 384, 384), "float32"],
     placeholder_1: T.Buffer[(T.int64(1), T.int64(12), T.int64(384),
                              384), "bool"],
     T_where: T.Buffer[(T.int64(1), T.int64(12), T.int64(384), 384),
                       "float32"]
 ) -> None:
     """TVMScript TIR: masked select on GPU thread bindings.

     T_where = -1e9 where the bool mask placeholder_1 is set, else the
     corresponding T_reshape value. The 1x12x384x384 iteration space
     (1769472 elements) is flattened over 7 serial steps x 256 blocks
     x 1024 threads, with T.where guarding the overhang; indices are
     recovered with int64 div/mod arithmetic.
     """
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(256),
                                                 thread="blockIdx.x"):
         for i0_i1_i2_i3_fused_2 in T.thread_binding(
                 T.int64(1024), thread="threadIdx.x"):
             for i0_i1_i2_i3_fused_0 in T.serial(T.int64(7)):
                 with T.block("T_where"):
                     ax0 = T.axis.spatial(T.int64(1), T.int64(0))
                     ax1 = T.axis.spatial(
                         T.int64(12),
                         ((i0_i1_i2_i3_fused_0 * T.int64(256) +
                           i0_i1_i2_i3_fused_1) * T.int64(1024) +
                          i0_i1_i2_i3_fused_2) % T.int64(1769472) //
                         T.int64(147456))
                     ax2 = T.axis.spatial(
                         T.int64(384),
                         ((i0_i1_i2_i3_fused_0 * T.int64(256) +
                           i0_i1_i2_i3_fused_1) * T.int64(1024) +
                          i0_i1_i2_i3_fused_2) % T.int64(147456) //
                         T.int64(384))
                     # Last axis is int32 (matches the mixed int/int64
                     # shape in the buffer declarations above).
                     ax3 = T.axis.spatial(
                         384,
                         T.cast(((i0_i1_i2_i3_fused_0 * T.int64(256) +
                                  i0_i1_i2_i3_fused_1) * T.int64(1024) +
                                 i0_i1_i2_i3_fused_2) % T.int64(384),
                                "int32"))
                     # Guard: the flattened launch overshoots 1769472.
                     T.where((i0_i1_i2_i3_fused_0 * T.int64(256) +
                              i0_i1_i2_i3_fused_1) * T.int64(1024) +
                             i0_i1_i2_i3_fused_2 < T.int64(1769472))
                     T.reads(placeholder_1[ax0, ax1, ax2, ax3],
                             T_reshape[ax0, ax1, ax2, ax3])
                     T.writes(T_where[ax0, ax1, ax2, ax3])
                     T_where[ax0, ax1, ax2, ax3] = T.Select(
                         T.cast(placeholder_1[ax0, ax1, ax2,
                                              ax3], "int32") != 0,
                         T.float32(-1000000000), T_reshape[ax0, ax1,
                                                           ax2, ax3])
def unified_element_wise_thread_x_different_dtype(
    A: T.Buffer[(128, 128), "float32"],
    B: T.Buffer[(128, 128), "float32"],
    C: T.Buffer[(128, 128), "float32"],
) -> None:
    """TVMScript TIR: two element-wise stages sharing thread bindings.

    Stage 1: B = A * 2.0; stage 2: C = B + 1.0. Both use the same
    blockIdx.x/threadIdx.x bindings, but the second stage indexes with
    int64 arithmetic (threadIdx cast to int64) while the first uses
    int32 — hence "different dtype" in the name.
    """
    for blockIdx_x in T.thread_binding(128, "blockIdx.x"):
        for threadIdx_x in T.thread_binding(4, "threadIdx.x"):
            for j0_1 in T.serial(0, 32):
                with T.block(""):
                    B[blockIdx_x, threadIdx_x * 32 +
                      j0_1] = (A[blockIdx_x, threadIdx_x * 32 + j0_1] * 2.0)
            for j1_1 in T.serial(T.int64(32)):
                with T.block(""):
                    # Same flat index as stage 1, computed in int64.
                    C[blockIdx_x,
                      T.cast(threadIdx_x, "int64") * T.int64(32) +
                      j1_1] = (B[blockIdx_x,
                                 T.cast(threadIdx_x, "int64") * T.int64(32) +
                                 j1_1] + 1.0)
def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None:
    """TVMScript TIR: 1x1 int16 conv2d, 192 -> 16 channels on a 28x28
    image (outer H loop parallelized), then bias add, fixed-point
    multiply, clip to [0, 255], and cast uint8 -> int16.
    """
    # function attr dict
    T.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True})
    placeholder_33 = T.match_buffer(placeholder_30, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_34 = T.match_buffer(placeholder_31, [3072], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_35 = T.match_buffer(placeholder_32, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    T_cast_9 = T.match_buffer(T_cast_8, [12544], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # body
    # 1x1 kernel: the "padded" input is a direct parallel copy.
    PaddedInput_3 = T.allocate([150528], "int16", "global")
    for i0_i1_fused_3 in T.parallel(0, 28):
        for i2_3, i3_3 in T.grid(28, 192):
            PaddedInput_3[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3) ] = placeholder_33[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)]
    for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784):
        for ax3_2 in T.serial(0, 16):
            # Single-element int32 accumulator per output channel.
            Conv2dOutput_3 = T.allocate([1], "int32", "global")
            Conv2dOutput_3[0] = 0
            for rc_3 in T.serial(0, 192):
                Conv2dOutput_3[0] = (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_34[((rc_3*16) + ax3_2)], "int32")))
            # Requantize: bias add, q_multiply_shift, clip, narrow.
            T_cast_9[((ax0_ax1_fused_ax2_fused_3*16) + ax3_2)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_35[ax3_2]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16")
Example #15
0
 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None:
     """TVMScript TIR (sugared buffer-subscript form): 1x1 int16 conv2d,
     64 -> 64 channels on a 75x75 image, then bias add, fixed-point
     multiply, clip to [0, 255], and cast uint8 -> int16.
     """
     # function attr dict
     T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True})
     placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16")
     placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16")
     placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32")
     T_cast_3 = T.match_buffer(T_cast_2, [360000], dtype="int16")
     # body
     # 1x1 kernel: the "padded" input is a direct copy of the input.
     PaddedInput = T.allocate([360000], "int16", "global")
     for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
         PaddedInput[i0_i1_fused * 4800 + i2 * 64 + i3] = placeholder_7[i0_i1_fused * 4800 + i2 * 64 + i3]
     for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
         Conv2dOutput = T.allocate([64], "int32", "global")
         for ff in T.serial(0, 64):
             Conv2dOutput[ff] = 0
             for rc in T.serial(0, 64):
                 Conv2dOutput[ff] = Conv2dOutput[ff] + T.cast(PaddedInput[ax0_ax1_fused_ax2_fused * 64 + rc], "int32") * T.cast(placeholder_8[rc * 64 + ff], "int32")
         for ax3_inner_1 in T.serial(0, 64):
             # Requantize: bias add, q_multiply_shift, clip, narrow.
             T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput[ax3_inner_1] + placeholder_9[ax3_inner_1], 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16")
 def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None:
     """TVMScript TIR: cast a flat uint8 buffer to int16 and subtract a
     scalar (placeholder_5[0]) element-wise, iterating 224 x 224 x 3.
     """
     # function attr dict
     T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True})
     placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
     placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1)
     # NOTE(review): shape [452] looks garbled — the loop below writes
     # flat indices up to 224*672 - 1 = 150527 (cf. placeholder_4's
     # [150528]); verify against the original source.
     T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1)
     # body
     for ax0_ax1_fused_1 in T.serial(0, 224):
         for ax2_1, ax3_inner_1 in T.grid(224, 3):
             # Flat NHWC offset: h * 672 + w * 3 + c (672 = 224 * 3).
             T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0])
Example #17
0
 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None:
     """TVMScript TIR: strided 7x7 int16 conv2d, 3 -> 64 channels, with
     zero padding (229x229x3 padded input), then bias add, fixed-point
     multiply, clip to [0, 255], and cast to uint8.
     """
     # function attr dict
     T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True})
     placeholder_65 = T.match_buffer(placeholder_62, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1)
     placeholder_66 = T.match_buffer(placeholder_63, [9408], dtype="int16", elem_offset=0, align=128, offset_factor=1)
     placeholder_67 = T.match_buffer(placeholder_64, [64], dtype="int32", elem_offset=0, align=128, offset_factor=1)
     T_cast_21 = T.match_buffer(T_cast_20, [802816], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
     # body
     # Zero-pad into a 229 x 229 x 3 buffer; interior rows/cols 2..225
     # come from the input, the border stays int16(0).
     PaddedInput_7 = T.allocate([157323], "int16", "global")
     for i0_i1_fused_7 in T.serial(0, 229):
         for i2_7, i3_7 in T.grid(229, 3):
             PaddedInput_7[(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7)] = T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), placeholder_65[((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)], T.int16(0), dtype="int16")
     for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544):
         Conv2dOutput_7 = T.allocate([64], "int32", "global")
         for ff_3 in T.serial(0, 64):
             Conv2dOutput_7[ff_3] = 0
             # 7x7 spatial window x 3 input channels; the *1374 and *6
             # row/column factors encode the stride-2 access.
             for ry_2, rx_2, rc_7 in T.grid(7, 7, 3):
                 Conv2dOutput_7[ff_3] = (Conv2dOutput_7[ff_3] + (T.cast(PaddedInput_7[(((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)], "int32")*T.cast(placeholder_66[((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)], "int32")))
         for ax3_inner_7 in T.serial(0, 64):
             # Requantize: bias add, q_multiply_shift, clip to uint8.
             T_cast_21[((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7)] = T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_7[ax3_inner_7] + placeholder_67[ax3_inner_7]), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8")
Example #18
0
 def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle,
                                             T_cast_6: T.handle) -> None:
     """TVMScript TIR (legacy T.load/T.store form): 3x3 stride-2 max
     pooling of a 1x112x112x64 uint8 tensor down to 1x56x56x64, then a
     cast of the pooled result to int16.
     """
     # function attr dict
     T.func_attr({
         "global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast",
         "tir.noalias": True
     })
     placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64],
                                     dtype="uint8",
                                     elem_offset=0,
                                     align=128,
                                     offset_factor=1)
     T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64],
                               dtype="int16",
                               elem_offset=0,
                               align=128,
                               offset_factor=1)
     # body
     tensor_2 = T.allocate([200704], "uint8", "global")
     for ax0_ax1_fused_4 in T.serial(0, 56):
         for ax2_4 in T.serial(0, 56):
             # Initialise the running max for all 64 channels to 0.
             for ax3_init in T.serial(0, 64):
                 T.store(tensor_2, (((ax0_ax1_fused_4 * 3584) +
                                     (ax2_4 * 64)) + ax3_init), T.uint8(0),
                         True)
             # Reduce over the fused 3x3 window (9 taps); positions
             # past the 112-pixel input edge contribute uint8(0).
             for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64):
                 T.store(
                     tensor_2,
                     (((ax0_ax1_fused_4 * 3584) + (ax2_4 * 64)) + ax3_2),
                     T.max(
                         T.load("uint8", tensor_2,
                                (((ax0_ax1_fused_4 * 3584) +
                                  (ax2_4 * 64)) + ax3_2)),
                         T.if_then_else(
                             ((((ax0_ax1_fused_4 * 2) +
                                T.floordiv(rv0_rv1_fused_1, 3)) < 112) and
                              (((ax2_4 * 2) +
                                T.floormod(rv0_rv1_fused_1, 3)) < 112)),
                             T.load("uint8", placeholder_29.data, (((
                                 ((ax0_ax1_fused_4 * 14336) +
                                  (T.floordiv(rv0_rv1_fused_1, 3) * 7168)) +
                                 (ax2_4 * 128)) + (T.floormod(
                                     rv0_rv1_fused_1, 3) * 64)) + ax3_2)),
                             T.uint8(0),
                             dtype="uint8")), True)
     # Second pass: widen the pooled uint8 values to int16.
     for ax0_ax1_fused_5 in T.serial(0, 56):
         for ax2_5, ax3_3 in T.grid(56, 64):
             T.store(
                 T_cast_7.data,
                 (((ax0_ax1_fused_5 * 3584) + (ax2_5 * 64)) + ax3_3),
                 T.cast(
                     T.load("uint8", tensor_2, (((ax0_ax1_fused_5 * 3584) +
                                                 (ax2_5 * 64)) + ax3_3)),
                     "int16"), True)
Example #19
0
 def main(
     placeholder: T.Buffer[(1024, 1024), "uint8"],
     placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"],
     compute: T.Buffer[(1024, 1024), "int32"],
 ) -> None:
     """TVMScript TIR: 1024x1024x1024 uint8 x int8 matmul with int32
     accumulation; the weight is pre-packed as (j//16, k//4, j%16, k%4).
     """
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     with T.block("root"):
         T.reads()
         T.writes()
         for i0, i1, i2 in T.grid(1024, 1024, 1024):
             with T.block("compute"):
                 # i, j spatial; k is the reduction axis.
                 i, j, k = T.axis.remap("SSR", [i0, i1, i2])
                 T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4,
                                                          j % 16, k % 4])
                 T.writes(compute[i, j])
                 with T.init():
                     compute[i, j] = 0
                 compute[i, j] = compute[i, j] + T.cast(
                     placeholder[i, k], "int32") * T.cast(
                         placeholder_1[j // 16, k // 4, j % 16, k % 4],
                         "int32")
    def tir_packed_call() -> None:
        """TVMScript TIR: invoke "tvm_test_cpacked" via the C-packed
        calling convention, wrapping the three opaque handles A, B and C
        in on-stack DLTensor-style arrays (T.tvm_stack_make_array with a
        rank-1 shape, null strides, float32 dtype triple, zero offset)
        and passing the device context handle last.
        """
        A = T.var("handle")
        B = T.var("handle")
        C = T.var("handle")
        device_context = T.var("handle")

        # body
        T.evaluate(
            T.tvm_call_cpacked(
                "tvm_test_cpacked",
                T.tvm_stack_make_array(
                    A,
                    T.tvm_stack_make_shape(1, dtype="handle"),
                    T.reinterpret(T.uint64(0), dtype="handle"),  # null strides
                    T.uint32(1),
                    T.cast(0, dtype="float32"),
                    0,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    B,
                    T.tvm_stack_make_shape(1, dtype="handle"),
                    T.reinterpret(T.uint64(0), dtype="handle"),
                    T.uint32(1),
                    T.cast(0, dtype="float32"),
                    0,
                    dtype="handle",
                ),
                T.tvm_stack_make_array(
                    C,
                    T.tvm_stack_make_shape(1, dtype="handle"),
                    T.reinterpret(T.uint64(0), dtype="handle"),
                    T.uint32(1),
                    T.cast(0, dtype="float32"),
                    0,
                    dtype="handle",
                ),
                device_context,
                dtype="int32",
            ))
 def main(
     placeholder: T.Buffer[(1024, 1024), "uint8"],
     placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"],
     compute: T.Buffer[(1024, 1024), "int32"],
 ) -> None:
     """TVMScript TIR: same uint8 x int8 packed matmul as the block-form
     version, but with j and k pre-split into (outer, inner) loops of
     16 and 4 — matching the (64, 256, 16, 4) weight packing.
     """
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     for i0, i1_0, i2_0, i1_1, i2_1 in T.grid(1024, 64, 256, 16, 4):
         with T.block("compute"):
             i = T.axis.spatial(1024, i0)
             # Recombine the split loops into the logical j and k axes.
             j = T.axis.spatial(1024, i1_0 * 16 + i1_1)
             k = T.axis.reduce(1024, i2_0 * 4 + i2_1)
             T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4,
                                                      j % 16, k % 4])
             T.writes(compute[i, j])
             with T.init():
                 compute[i, j] = 0
             compute[i, j] = compute[i, j] + T.cast(
                 placeholder[i, k], "int32") * T.cast(
                     placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32")
Example #22
0
 def main(
     placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
     placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
     conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
 ) -> None:
     """TVMScript TIR: NCHWc int8 conv2d (1x1 kernel), scheduled form.

     Same computation as the unsplit version: oc_block (i4) and
     ic_s_inner (i9) are each split into outer/inner loops of 16 and 4,
     and the unit axes (n, kh, kw) are pinned to constant 0.
     """
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     for i0, i1, i2, i3, i4_0, i5, i6, i7, i8, i9_0, i4_1, i9_1 in T.grid(
             1, 16, 56, 56, 1, 1, 1, 4, 4, 1, 16, 4):
         with T.block("conv2d_NCHWc_int8"):
             n = T.axis.spatial(1, 0)
             oc_chunk, oh, ow, oc_block = T.axis.remap(
                 "SSSS", [i1, i2, i3, i4_1])
             # 1x1 kernel: both kernel axes are unit reductions.
             kh = T.axis.reduce(1, 0)
             kw = T.axis.reduce(1, 0)
             ic_outer, ic_f_inner, ic_s_inner = T.axis.remap(
                 "RRR", [i7, i8, i9_1])
             T.reads(
                 placeholder[n, ic_outer, oh + kh, ow + kw,
                             ic_f_inner * 4 + ic_s_inner],
                 placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner,
                               oc_block, ic_s_inner],
             )
             T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block])
             with T.init():
                 # Zero the accumulator on the first reduction iteration.
                 conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0
             conv2d_NCHWc_int8[
                 n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
                     n, oc_chunk, oh, ow, oc_block] + T.cast(
                         placeholder[n, ic_outer, oh + kh, ow + kw,
                                     ic_f_inner * 4 + ic_s_inner], "int32"
                     ) * T.cast(
                         placeholder_1[oc_chunk, ic_outer, kh, kw,
                                       ic_f_inner, oc_block, ic_s_inner],
                         "int32",
                     )
示例#23
0
 # TVMScript fixture: flattened (1-D buffer) fused 1x1 conv2d followed by a
 # bias add, two fixed-point requantize steps with clip-to-uint8, and a
 # residual add from placeholder_28.  Buffers arrive as opaque T.handle and
 # are bound to flat typed buffers via T.match_buffer.
 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True})
     placeholder_29 = T.match_buffer(placeholder_22, [360000], dtype="int16")   # input activations (75x75x64 flattened)
     placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16")    # weights (64x256 flattened)
     placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32")      # per-channel bias
     placeholder_28 = T.match_buffer(placeholder_25, [1440000], dtype="int32")  # residual input
     T_cast_7 = T.match_buffer(T_cast_6, [1440000], dtype="uint8")              # output
     # body
     # "Padding" stage is an identity copy here (no actual border padding).
     PaddedInput_3 = T.allocate([360000], "int16", "global")
     for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64):
         PaddedInput_3[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] = placeholder_29[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3]
     # One iteration per output pixel (75*75 = 5625); 256 output channels are
     # produced as 4 blocks of 64 (ax3_outer_2 x ff_3).
     for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625):
         Conv2dOutput_3 = T.allocate([64], "int32", "global")
         for ax3_outer_2 in T.serial(0, 4):
             for ff_3 in T.serial(0, 64):
                 Conv2dOutput_3[ff_3] = 0
                 # 1x1 conv: reduce over the 64 input channels.
                 for rc_3 in T.serial(0, 64):
                     Conv2dOutput_3[ff_3] = Conv2dOutput_3[ff_3] + T.cast(PaddedInput_3[ax0_ax1_fused_ax2_fused_3 * 64 + rc_3], "int32") * T.cast(placeholder_27[rc_3 * 256 + ax3_outer_2 * 64 + ff_3], "int32")
             # Requantize twice (q_multiply_shift + zero-point 136/-136), clip to
             # [0, 255], add the residual, clip again, and store as uint8.
             for ax3_inner_4 in T.serial(0, 64):
                 T_cast_7[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4] = T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(Conv2dOutput_3[ax3_inner_4] + placeholder_26[ax3_outer_2 * 64 + ax3_inner_4], 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + placeholder_28[ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4], 255), 0), "uint8")
示例#24
0
 # TVMScript fixture: constants explicitly wrapped in typed TIR immediates
 # (T.int32 / T.float32) before use in an evaluated expression.
 def constant_binds_wrapped():
     x = T.int32(1)
     y = T.float32(42.0)
     T.evaluate(T.cast(x, "float32") + y)
示例#25
0
 # TVMScript fixture: same expression as constant_binds_wrapped, but the
 # constants are bare Python literals left for the parser to type.
 def constant_binds():
     x = 1
     y = 42.0
     T.evaluate(T.cast(x, "float32") + y)
 # Transform-test input: each store wraps its value in a T.Let binding
 # x = cast(i + 1) that only depends on the outer loop variable i.
 def before(A: T.Buffer[(4, 4), "float32"]):
     for i, j in T.grid(4, 4):
         x = T.var("float32")
         A[i, j] = T.Let(x, T.cast(i + 1, "float32"),
                         5.0 * x + T.cast(j, "float32"))
 # Transform-test expectation: the j-invariant Let value from `before` has
 # been hoisted out of the inner loop as a plain binding of x.
 def expected(A: T.Buffer[(4, 4), "float32"]):
     for i in T.serial(4):
         x = T.cast(i + 1, "float32")
         for j in T.serial(4):
             A[i, j] = 5.0 * x + T.cast(j, "float32")
 # TVMScript fixture: detection postprocess over three feature-map scales
 # (13x13, 26x26, 52x52; 3 anchors x 85 channels each).  Per scale: slice
 # channel 4 (objectness, presumably - TODO confirm against the model) and
 # channels 5..84, apply sigmoid to both, multiply, reshape to (boxes, 80),
 # then concat all scales to (10647, 80), transpose, and expand to
 # (1, 80, 10647).  Box counts: 13*13*3=507, 26*26*3=2028, 52*52*3=8112.
 def main(placeholder: T.Buffer[(1, 13, 13, 3, 85), "float32"], placeholder_1: T.Buffer[(1, 26, 26, 3, 85), "float32"], placeholder_2: T.Buffer[(1, 52, 52, 3, 85), "float32"], T_expand_dims: T.Buffer[(1, 80, 10647), "float32"]) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     T_strided_slice_with_axes = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32")
     T_sigmoid = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32")
     T_strided_slice_with_axes_1 = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
     T_sigmoid_1 = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
     T_multiply = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
     T_reshape = T.alloc_buffer([8112, 80], dtype="float32")
     T_strided_slice_with_axes_2 = T.alloc_buffer([1, 26, 26, 3, 1], dtype="float32")
     T_sigmoid_2 = T.alloc_buffer([1, 26, 26, 3, 1], dtype="float32")
     T_strided_slice_with_axes_3 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32")
     T_sigmoid_3 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32")
     T_multiply_1 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32")
     T_reshape_1 = T.alloc_buffer([2028, 80], dtype="float32")
     T_strided_slice_with_axes_4 = T.alloc_buffer([1, 13, 13, 3, 1], dtype="float32")
     T_sigmoid_4 = T.alloc_buffer([1, 13, 13, 3, 1], dtype="float32")
     T_strided_slice_with_axes_5 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32")
     T_sigmoid_5 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32")
     T_multiply_2 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32")
     T_reshape_2 = T.alloc_buffer([507, 80], dtype="float32")
     T_concat = T.alloc_buffer([10647, 80], dtype="float32")
     T_transpose = T.alloc_buffer([80, 10647], dtype="float32")
     # --- 52x52 scale (placeholder_2) ---
     for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 1):
         with T.block("T_strided_slice_with_axes"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)])
             T.writes(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4])
             T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4] = placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]
     for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 1):
         with T.block("T_sigmoid"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_sigmoid[ax0, ax1, ax2, ax3, ax4])
             T_sigmoid[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4], dtype="float32")
     for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80):
         with T.block("T_strided_slice_with_axes_1"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)])
             T.writes(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4])
             T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4] = placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]
     for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80):
         with T.block("T_sigmoid_1"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_sigmoid_1[ax0, ax1, ax2, ax3, ax4])
             T_sigmoid_1[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4], dtype="float32")
     for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80):
         with T.block("T_multiply"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_sigmoid[ax0, ax1, ax2, ax3, 0], T_sigmoid_1[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_multiply[ax0, ax1, ax2, ax3, ax4])
             # Broadcast the single objectness channel over all 80 class scores.
             T_multiply[ax0, ax1, ax2, ax3, ax4] = T_sigmoid[ax0, ax1, ax2, ax3, 0] * T_sigmoid_1[ax0, ax1, ax2, ax3, ax4]
     for i0, i1 in T.grid(8112, 80):
         with T.block("T_reshape"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(T_multiply[0, (ax1 // 80 + ax0) % 8112 // 156, (ax1 // 80 + ax0) % 156 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80])
             T.writes(T_reshape[ax0, ax1])
             T_reshape[ax0, ax1] = T_multiply[0, (ax1 // 80 + ax0) % 8112 // 156, (ax1 // 80 + ax0) % 156 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]
     # --- 26x26 scale (placeholder_1) ---
     for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 1):
         with T.block("T_strided_slice_with_axes_2"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)])
             T.writes(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4])
             T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4] = placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]
     for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 1):
         with T.block("T_sigmoid_2"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_sigmoid_2[ax0, ax1, ax2, ax3, ax4])
             T_sigmoid_2[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4], dtype="float32")
     for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80):
         with T.block("T_strided_slice_with_axes_3"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)])
             T.writes(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4])
             T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4] = placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]
     for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80):
         with T.block("T_sigmoid_3"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_sigmoid_3[ax0, ax1, ax2, ax3, ax4])
             T_sigmoid_3[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4], dtype="float32")
     for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80):
         with T.block("T_multiply_1"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_sigmoid_2[ax0, ax1, ax2, ax3, 0], T_sigmoid_3[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_multiply_1[ax0, ax1, ax2, ax3, ax4])
             T_multiply_1[ax0, ax1, ax2, ax3, ax4] = T_sigmoid_2[ax0, ax1, ax2, ax3, 0] * T_sigmoid_3[ax0, ax1, ax2, ax3, ax4]
     for i0, i1 in T.grid(2028, 80):
         with T.block("T_reshape_1"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(T_multiply_1[0, (ax1 // 80 + ax0) % 2028 // 78, (ax1 // 80 + ax0) % 78 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80])
             T.writes(T_reshape_1[ax0, ax1])
             T_reshape_1[ax0, ax1] = T_multiply_1[0, (ax1 // 80 + ax0) % 2028 // 78, (ax1 // 80 + ax0) % 78 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]
     # --- 13x13 scale (placeholder) ---
     for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 1):
         with T.block("T_strided_slice_with_axes_4"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)])
             T.writes(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4])
             T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4] = placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]
     for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 1):
         with T.block("T_sigmoid_4"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_sigmoid_4[ax0, ax1, ax2, ax3, ax4])
             T_sigmoid_4[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4], dtype="float32")
     for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80):
         with T.block("T_strided_slice_with_axes_5"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)])
             T.writes(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4])
             T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4] = placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]
     for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80):
         with T.block("T_sigmoid_5"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_sigmoid_5[ax0, ax1, ax2, ax3, ax4])
             T_sigmoid_5[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4], dtype="float32")
     for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80):
         with T.block("T_multiply_2"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_sigmoid_4[ax0, ax1, ax2, ax3, 0], T_sigmoid_5[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_multiply_2[ax0, ax1, ax2, ax3, ax4])
             T_multiply_2[ax0, ax1, ax2, ax3, ax4] = T_sigmoid_4[ax0, ax1, ax2, ax3, 0] * T_sigmoid_5[ax0, ax1, ax2, ax3, ax4]
     for i0, i1 in T.grid(507, 80):
         with T.block("T_reshape_2"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(T_multiply_2[0, (ax1 // 80 + ax0) % 507 // 39, (ax1 // 80 + ax0) % 39 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80])
             T.writes(T_reshape_2[ax0, ax1])
             T_reshape_2[ax0, ax1] = T_multiply_2[0, (ax1 // 80 + ax0) % 507 // 39, (ax1 // 80 + ax0) % 39 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]
     # --- concat (order: 13x13 boxes [0,507), 26x26 [507,2535), 52x52 [2535,10647)) ---
     for i0, i1 in T.grid(10647, 80):
         with T.block("T_concat"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(T_reshape[ax0 - 2535, ax1], T_reshape_1[ax0 - 507, ax1], T_reshape_2[ax0, ax1])
             T.writes(T_concat[ax0, ax1])
             T_concat[ax0, ax1] = T.if_then_else(2535 <= ax0, T_reshape[ax0 - 2535, ax1], T.if_then_else(507 <= ax0, T_reshape_1[ax0 - 507, ax1], T_reshape_2[ax0, ax1], dtype="float32"), dtype="float32")
     for i0, i1 in T.grid(80, 10647):
         with T.block("T_transpose"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(T_concat[ax1, ax0])
             T.writes(T_transpose[ax0, ax1])
             T_transpose[ax0, ax1] = T_concat[ax1, ax0]
     for i0, i1, i2 in T.grid(1, 80, 10647):
         with T.block("T_expand_dims"):
             ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
             T.reads(T_transpose[ax1, ax2])
             T.writes(T_expand_dims[ax0, ax1, ax2])
             # Prepend a unit batch dimension to the transposed scores.
             T_expand_dims[ax0, ax1, ax2] = T_transpose[ax1, ax2]
 # Transform-test input: constant expression ceil(log2(14.0)) computed in
 # float64 and truncated to int32, stored into a single-element buffer.
 def before(A: T.Buffer[1, "int32"]):
     A[0] = T.cast(T.ceil(T.log2(T.cast(14, "float64"), dtype="float64"),
                          dtype="float64"),
                   dtype="int32")
示例#30
0
 # TVMScript fixture: 128x128 int8 GEMM scheduled for GPU with shared-memory
 # staging for X and W, a thread-local accumulator, and the inner 4-wide
 # reduction marked for dp4a auto-tensorization.
 def main(
     X: T.Buffer[(128, 128), "int8"],
     W: T.Buffer[(128, 128), "int8"],
     compute: T.Buffer[(128, 128), "int32"],
 ) -> None:
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     compute_local = T.alloc_buffer([128, 128],
                                    dtype="int32",
                                    scope="local")
     X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
     W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
     for i0_0_i1_0_fused in T.thread_binding(16, thread="blockIdx.x"):
         for i0_1_i1_1_fused in T.thread_binding(2, thread="vthread.x"):
             for i0_2_i1_2_fused in T.thread_binding(2,
                                                     thread="threadIdx.x"):
                 # k (128) is split 2 x 2 x 8 x 4 across i2_0_0/i2_0_1/i2_0_2/i2_1.
                 for i2_0_0 in T.serial(2):
                     # Cooperative fetch of a 16x64 tile of X into shared memory.
                     for ax0_ax1_fused in T.serial(1024):
                         with T.block("X_shared"):
                             v0 = T.axis.spatial(
                                 128, i0_0_i1_0_fused // 2 * 16 +
                                 ax0_ax1_fused // 64)
                             v1 = T.axis.spatial(
                                 128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                             T.reads(X[v0, v1])
                             T.writes(X_shared[v0, v1])
                             T.block_attr(
                                 {"meta_schedule.cooperative_fetch": 4})
                             X_shared[v0, v1] = X[v0, v1]
                     # Cooperative fetch of a 64x64 tile of W into shared memory.
                     for ax0_ax1_fused in T.serial(4096):
                         with T.block("W_shared"):
                             v0 = T.axis.spatial(
                                 128, i0_0_i1_0_fused % 2 * 64 +
                                 ax0_ax1_fused // 64)
                             v1 = T.axis.spatial(
                                 128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                             T.reads(W[v0, v1])
                             T.writes(W_shared[v0, v1])
                             T.block_attr(
                                 {"meta_schedule.cooperative_fetch": 1})
                             W_shared[v0, v1] = W[v0, v1]
                     for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(
                             2, 4, 16, 8, 4, 1):
                         # Outer block reduces over k_o; the 4-element inner
                         # reduction below is the dp4a tensorization target.
                         with T.block("compute_o"):
                             i = T.axis.spatial(
                                 128, i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 +
                                 i0_4)
                             j = T.axis.spatial(
                                 128,
                                 i0_0_i1_0_fused % 2 * 64 +
                                 i0_1_i1_1_fused * 32 +
                                 i0_2_i1_2_fused * 16 + i1_3,
                             )
                             k_o = T.axis.reduce(
                                 32, i2_0_0 * 16 + i2_0_1 * 8 + i2_0_2)
                             T.reads(
                                 X_shared[i, k_o * 4:k_o * 4 + 4],
                                 W_shared[j, k_o * 4:k_o * 4 + 4],
                             )
                             T.writes(compute_local[i, j])
                             T.block_attr(
                                 {"meta_schedule.auto_tensorize": "dp4a"})
                             with T.init():
                                 with T.block("compute_init"):
                                     T.reads()
                                     T.writes(compute_local[i, j])
                                     compute_local[i, j] = 0
                             for i2_1 in T.serial(4):
                                 with T.block("compute"):
                                     k = T.axis.reduce(4, i2_1)
                                     T.reads(
                                         compute_local[i, j],
                                         X_shared[i, k_o * 4 + k],
                                         W_shared[j, k_o * 4 + k],
                                     )
                                     T.writes(compute_local[i, j])
                                     T.block_attr({
                                         "meta_schedule.tiling_structure":
                                         "SSSRRSRS"
                                     })
                                     # int8 x int8 widened to int32 and accumulated.
                                     compute_local[
                                         i,
                                         j] = compute_local[i, j] + T.cast(
                                             X_shared[i, k_o * 4 + k],
                                             "int32") * T.cast(
                                                 W_shared[j, k_o * 4 + k],
                                                 "int32")
                 # Write the 16x16 local accumulator tile back to global memory.
                 for ax0, ax1 in T.grid(16, 16):
                     with T.block("compute_local"):
                         v0 = T.axis.spatial(
                             128, i0_0_i1_0_fused // 2 * 16 + ax0)
                         v1 = T.axis.spatial(
                             128,
                             i0_0_i1_0_fused % 2 * 64 +
                             i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 +
                             ax1,
                         )
                         T.reads(compute_local[v0, v1])
                         T.writes(compute[v0, v1])
                         compute[v0, v1] = compute_local[v0, v1]