Example #1
0
 def main(placeholder: T.Buffer[(1, 384), "int64"],
          placeholder_1: T.Buffer[(30522, 768), "float32"],
          placeholder_2: T.Buffer[(1, 384, 768), "float32"],
          T_add: T.Buffer[(1, 384, 768), "float32"]) -> None:
     # For every (ax0, ax1, ax2), select a row of placeholder_1 using the int64
     # index stored in placeholder[ax0, ax1], and write that row's ax2-th
     # element plus placeholder_2[ax0, ax1, ax2] into T_add[ax0, ax1, ax2].
     # A negative index is shifted by 30522, and the result is then clamped to
     # [0, 30521] before it is used as the row index.
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     for i0, i1, i2 in T.grid(1, 384, 768):
         with T.block("T_add_1"):
             ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
             # The declared read region of placeholder_1 runs from the clamped
             # index up to (and including) the clamped "index + 30522".
             T.reads(
                 placeholder[ax0, ax1], placeholder_1[
                     T.min(T.max(T.int64(0), placeholder[
                         ax0, ax1]), T.int64(30521)):T.min(
                             T.max(T.int64(0), placeholder[ax0, ax1] +
                                   T.int64(30522)), T.int64(30521)) +
                     T.int64(1), ax2], placeholder_2[ax0, ax1, ax2])
             T.writes(T_add[ax0, ax1, ax2])
             # T.Select picks "index + 30522" when the index is negative and
             # the raw index otherwise; T.max/T.min then clamp it.
             T_add[ax0, ax1, ax2] = placeholder_1[T.min(
                 T.max(
                     T.int64(0),
                     T.Select(
                         T.cast(placeholder[ax0, ax1] < T.int64(0), "int32"
                                ) != 0, placeholder[ax0, ax1] +
                         T.int64(30522), placeholder[ax0, ax1])
                 ), T.int64(30521)), ax2] + placeholder_2[ax0, ax1, ax2]
def compacted_sparse_read_cache(
    A_data: T.Buffer[(819,), "float32"],
    B: T.Buffer[(128,), "float32"],
    A_indptr: T.Buffer[(129,), "int32"],
    A_indices: T.Buffer[(819,), "int32"],
) -> None:
    # For each i in [0, 128), sum the A_data entries at positions
    # A_indptr[i] .. A_indptr[i + 1] - 1 into B[i]. Each entry is first copied
    # into a one-element "local" buffer and then accumulated from there.
    # A_indices is accepted but never read in this function.
    for i in T.serial(128):
        with T.block("rowsum_outer"):
            T.reads(
                A_indptr[i : i + 1],
                A_data[A_indptr[i] + 0 : A_indptr[i] + 0 + (A_indptr[i + 1] - A_indptr[i])],
            )
            T.writes(B[i])
            # Reset the accumulator for this row.
            with T.block("rowsum_init"):
                T.reads()
                T.writes(B[i])
                B[i] = T.float32(0)
            # The trip count is data dependent: the number of entries in row i.
            for k in T.serial(A_indptr[i + 1] - A_indptr[i]):
                with T.block():
                    T.reads(A_indptr[i], A_data[A_indptr[i] + k], B[i])
                    T.writes(B[i])
                    A_data_local = T.alloc_buffer([1], dtype="float32", scope="local")
                    # NOTE(review): the T.min(A_indptr[i] + k, 0) index into the
                    # one-element buffer presumably results from buffer
                    # compaction (going by the function name) - confirm.
                    with T.block("A_data_cache_read"):
                        T.reads(A_indptr[i], A_data[A_indptr[i] + k])
                        T.writes(A_data_local[T.min(A_indptr[i] + k, 0)])
                        A_data_local[T.min(A_indptr[i] + k, 0)] = A_data[A_indptr[i] + k]
                    with T.block("rowsum_inner"):
                        T.reads(B[i], A_indptr[i], A_data[A_indptr[i] + k])
                        T.writes(B[i])
                        B[i] = B[i] + A_data_local[T.min(A_indptr[i] + k, 0)]
def elementwise_not_affine_fused(a: T.handle, b: T.handle) -> None:
    # Element-wise copy of A into B over 127 x 128 buffers.
    # For each i in [0, 4) a single fused loop runs for
    # (T.min(31, 126 - i * 32) + 1) * 128 iterations (written below as
    # "T.min(...) * 128 + 128"); the fused index is split back into a row
    # offset and a column with floordiv / floormod by 128.
    A = T.match_buffer(a, [127, 128])
    B = T.match_buffer(b, [127, 128])
    for i in T.grid(4):
        for j_k_fused in T.serial(0, T.min(31, 126 - i * 32) * 128 + 128):
            with T.block("B"):
                # Row: i * 32 plus the row offset recovered from the fused index.
                vi = T.axis.S(
                    127,
                    i * 32 + T.floormod(T.floordiv(j_k_fused, 128), T.min(31, 126 - i * 32) + 1),
                )
                # Column: the fused index modulo 128.
                vj = T.axis.S(128, T.floormod(j_k_fused, 128))
                T.reads([A[vi, vj]])
                T.writes([B[vi, vj]])
                B[vi, vj] = A[vi, vj]
Example #4
0
def expected_recursive_bufferslice_indices(data: T.handle,
                                           index: T.handle) -> None:
    # Fills every element of a locally allocated 16 x 16 buffer with the single
    # value data_buf[index_buf[index_buf[0]], index_buf[0]], i.e. the row index
    # is obtained by indexing index_buf with a value loaded from index_buf.
    index_buf = T.match_buffer(index, [1],
                               dtype="int32",
                               elem_offset=0,
                               align=128,
                               offset_factor=1)
    data_buf = T.match_buffer(data, [16, 16],
                              elem_offset=0,
                              align=128,
                              offset_factor=1)
    with T.block("root"):
        T.reads([])
        T.writes([])
        out_buf = T.alloc_buffer([16, 16],
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
        for i0, i1 in T.grid(16, 16):
            with T.block():
                vi, vj = T.axis.remap("SS", [i0, i1])
                # The declared reads cover the data element itself and the
                # index_buf range [min(index_buf[0], 0), max(index_buf[0], 0)].
                T.reads([
                    data_buf[index_buf[index_buf[0]], index_buf[0]],
                    index_buf[T.min(index_buf[0], 0):T.max(index_buf[0], 0) +
                              1],
                ])
                T.writes([out_buf[vi, vj]])
                out_buf[vi, vj] = data_buf[index_buf[index_buf[0]],
                                           index_buf[0]]
Example #5
0
 def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(
         placeholder: T.handle, placeholder_1: T.handle,
         T_cast: T.handle) -> None:
     # Element-wise pipeline over a flat 360000-element uint8 input:
     #   1. cast the input element to int32 and subtract 94,
     #   2. apply T.q_multiply_shift with arguments (1843157232, 31, 1),
     #   3. add placeholder_3[ax3_outer * 16 + ax3_inner] (64 int32 values),
     #   4. clamp to [0, 255], cast to uint8 and then to int16.
     # The result is stored at the same flat offset in the int16 output.
     # function attr dict
     T.func_attr({
         "global_symbol":
         "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast",
         "tir.noalias": True
     })
     placeholder_2 = T.match_buffer(placeholder, [360000], dtype="uint8")
     placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
     T_cast_1 = T.match_buffer(T_cast, [360000], dtype="int16")
     # body
     # Flat offset = ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner,
     # so the four loops (75, 75, 4, 16) together cover all 360000 elements.
     for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
         T_cast_1[
             ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 +
             ax3_inner] = T.cast(
                 T.cast(
                     T.max(
                         T.min(
                             T.q_multiply_shift(T.cast(
                                 placeholder_2[ax0_ax1_fused * 4800 +
                                               ax2 * 64 + ax3_outer * 16 +
                                               ax3_inner], "int32") - 94,
                                                1843157232,
                                                31,
                                                1,
                                                dtype="int32") +
                             placeholder_3[ax3_outer * 16 + ax3_inner],
                             255), 0), "uint8"), "int16")
Example #6
0
def elementwise_not_affine(a: T.handle, b: T.handle) -> None:
    # Element-wise copy of A into B over 127 x 128 buffers.
    # The rows are walked in 4 tiles of up to 32 rows; the row-loop extent
    # T.min(31, 126 - tile * 32) + 1 depends on the tile index, which keeps the
    # last tile inside the 127-row buffer.
    A = T.match_buffer(a, (127, 128))
    B = T.match_buffer(b, (127, 128))
    for tile in T.serial(0, 4):
        for row in T.serial(0, T.min(31, 126 - tile * 32) + 1):
            for col in T.serial(0, 128):
                with T.block("B"):
                    v_row = T.axis.spatial(127, tile * 32 + row)
                    v_col = T.axis.spatial(128, col)
                    B[v_row, v_col] = A[v_row, v_col]
Example #7
0
 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(
         placeholder_4: T.handle, placeholder_5: T.handle,
         placeholder_6: T.handle, T_cast_2: T.handle) -> None:
     # Three stages, all expressed with explicit T.store / T.load on flat offsets:
     #   1. copy the int16 input (placeholder_7) into the PaddedInput scratch buffer,
     #   2. for each of the 5625 fused positions, accumulate 64 int32 sums, each
     #      over 64 products of PaddedInput and placeholder_8 values,
     #   3. add placeholder_9, apply T.q_multiply_shift, clamp to [0, 255] and
     #      cast to uint8 and then int16 into T_cast_3.
     # function attr dict
     T.func_attr({
         "global_symbol":
         "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast",
         "tir.noalias": True
     })
     placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64],
                                    dtype="int16")
     placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64],
                                    dtype="int16")
     placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64],
                                    dtype="int32")
     T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16")
     # body
     PaddedInput = T.allocate([360000], "int16", "global")
     # Stage 1: straight copy; source and destination use the same flat offset,
     # so no padding values are inserted here.
     for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
         T.store(
             PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3,
             T.load("int16", placeholder_7.data,
                    i0_i1_fused * 4800 + i2 * 64 + i3), True)
     for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
         # Per-position accumulator holding 64 int32 partial sums.
         Conv2dOutput = T.allocate([64], "int32", "global")
         # Stage 2: zero the accumulator entry, then reduce over rc.
         for ff in T.serial(0, 64):
             T.store(Conv2dOutput, ff, 0, True)
             for rc in T.serial(0, 64):
                 T.store(
                     Conv2dOutput, ff,
                     T.load("int32", Conv2dOutput, ff) + T.cast(
                         T.load("int16", PaddedInput,
                                ax0_ax1_fused_ax2_fused * 64 + rc), "int32")
                     * T.cast(
                         T.load("int16", placeholder_8.data, rc * 64 + ff),
                         "int32"), True)
         # Stage 3: add placeholder_9, q_multiply_shift with (1843106743, 31, -6),
         # clamp to [0, 255], then cast uint8 -> int16.
         for ax3_inner_1 in T.serial(0, 64):
             T.store(
                 T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1,
                 T.cast(
                     T.cast(
                         T.max(
                             T.min(
                                 T.q_multiply_shift(
                                     T.load("int32", Conv2dOutput,
                                            ax3_inner_1) +
                                     T.load("int32", placeholder_9.data,
                                            ax3_inner_1),
                                     1843106743,
                                     31,
                                     -6,
                                     dtype="int32"), 255), 0), "uint8"),
                     "int16"), True)
Example #8
0
def access_of_padding_pattern() -> None:
    # Two blocks share one 32 x 32 loop nest over locally allocated buffers:
    #   "padding"         writes X_pad[vi, vj] = X[vi - 2, vj - 2] when
    #                     2 <= vi < 30 and 2 <= vj < 30, and 0.0 otherwise;
    #   "padding_reverse" writes Y[vi - 2, vj - 2] = X_pad[vi, vj] under the
    #                     same condition and does nothing otherwise.
    # The declared read/write regions use T.max / T.min to clamp the shifted
    # indices into the bounds of the 28 x 28 buffers.
    X = T.alloc_buffer([28, 28])
    X_pad = T.alloc_buffer([32, 32])
    Y = T.alloc_buffer([28, 28])
    for i, j in T.grid(32, 32):
        with T.block("padding"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([
                X[T.max(vi - 2, 0):T.min(vi - 2, 27) + 1,
                  T.max(vj - 2, 0):T.min(vj - 2, 27) + 1, ]
            ])
            T.writes([X_pad[vi, vj]])
            X_pad[vi, vj] = T.if_then_else(2 <= vi and vi < 30 and 2 <= vj
                                           and vj < 30,
                                           X[vi - 2, vj - 2],
                                           0.0,
                                           dtype="float32")
        with T.block("padding_reverse"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([
                X_pad[T.max(vi, 2):T.min(vi, 29) + 1,
                      T.max(vj, 2):T.min(vj, 29) + 1]
            ])
            T.writes([
                Y[T.max(vi - 2, 0):T.min(vi - 2, 27) + 1,
                  T.max(vj - 2, 0):T.min(vj - 2, 27) + 1, ]
            ])
            # Only the interior 28 x 28 window of X_pad is copied back.
            if 2 <= vi and vi < 30 and 2 <= vj and vj < 30:
                Y[vi - 2, vj - 2] = X_pad[vi, vj]
Example #9
0
 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(
         placeholder_10: T.handle, placeholder_11: T.handle,
         placeholder_12: T.handle, T_cast_4: T.handle) -> None:
     # Three stages over flat buffers:
     #   1. build PaddedInput_1 (77 * 77 * 64 elements) from the input with a
     #      one-element border of zeros in the first two loop dimensions,
     #   2. for each of the 5625 fused positions, accumulate 64 int32 sums, each
     #      over a 3 x 3 x 64 (ry, rx, rc_1) window of PaddedInput_1 multiplied
     #      by placeholder_14 values,
     #   3. add placeholder_15, apply T.q_multiply_shift, clamp to [0, 255] and
     #      cast to uint8 and then int16 into T_cast_5.
     # function attr dict
     T.func_attr({
         "global_symbol":
         "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1",
         "tir.noalias": True
     })
     placeholder_13 = T.match_buffer(placeholder_10, [360000],
                                     dtype="int16")
     placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16")
     placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32")
     T_cast_5 = T.match_buffer(T_cast_4, [360000], dtype="int16")
     # body
     PaddedInput_1 = T.allocate([379456], "int16", "global")
     # Stage 1: positions with 1 <= index < 76 in both dimensions read the input
     # at a flat offset shifted by -4864 (= 4800 + 64); all others become 0.
     for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64):
         PaddedInput_1[i0_i1_fused_1 * 4928 + i2_1 * 64 +
                       i3_1] = T.if_then_else(
                           1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76
                           and 1 <= i2_1 and i2_1 < 76,
                           placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 +
                                          i3_1 - 4864],
                           T.int16(0),
                           dtype="int16")
     for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625):
         # Per-position accumulator holding 64 int32 partial sums.
         Conv2dOutput_1 = T.allocate([64], "int32", "global")
         # Stage 2: the fused position is split with floordiv / floormod by 75
         # to locate the window's origin inside PaddedInput_1.
         for ff_1 in T.serial(0, 64):
             Conv2dOutput_1[ff_1] = 0
             for ry, rx, rc_1 in T.grid(3, 3, 64):
                 Conv2dOutput_1[ff_1] = Conv2dOutput_1[ff_1] + T.cast(
                     PaddedInput_1[
                         T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 +
                         ry * 4928 + rx * 64 +
                         T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 +
                         rc_1], "int32") * T.cast(
                             placeholder_14[ry * 12288 + rx * 4096 +
                                            rc_1 * 64 + ff_1], "int32")
         # Stage 3: add placeholder_15, q_multiply_shift with (1608879842, 31, -7),
         # clamp to [0, 255], then cast uint8 -> int16.
         for ax3_inner_2 in T.serial(0, 64):
             T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 +
                      ax3_inner_2] = T.cast(
                          T.cast(
                              T.max(
                                  T.min(
                                      T.q_multiply_shift(
                                          Conv2dOutput_1[ax3_inner_2] +
                                          placeholder_15[ax3_inner_2],
                                          1608879842,
                                          31,
                                          -7,
                                          dtype="int32"), 255), 0),
                              "uint8"), "int16")
Example #10
0
 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(
         placeholder_16: T.handle, placeholder_17: T.handle,
         placeholder_18: T.handle, T_add: T.handle) -> None:
     # Three stages over flat buffers:
     #   1. copy the int16 input into PaddedInput_2 unchanged,
     #   2. for each of the 5625 fused positions and each of 4 groups
     #      (ax3_outer_1), accumulate 64 int32 sums over 64 rc_2 products,
     #   3. per output element: add placeholder_21, q_multiply_shift, add 132,
     #      clamp to [0, 255], cast uint8 -> int32, subtract 132, apply a second
     #      q_multiply_shift and add 136; the result goes to the int32 output.
     # function attr dict
     T.func_attr({
         "global_symbol":
         "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_",
         "tir.noalias": True
     })
     placeholder_19 = T.match_buffer(placeholder_16, [360000],
                                     dtype="int16")
     placeholder_20 = T.match_buffer(placeholder_17, [16384], dtype="int16")
     placeholder_21 = T.match_buffer(placeholder_18, [256], dtype="int32")
     T_add_1 = T.match_buffer(T_add, [1440000], dtype="int32")
     # body
     PaddedInput_2 = T.allocate([360000], "int16", "global")
     # Stage 1: straight copy; both sides use the same flat offset.
     for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64):
         PaddedInput_2[i0_i1_fused_2 * 4800 + i2_2 * 64 +
                       i3_2] = placeholder_19[i0_i1_fused_2 * 4800 +
                                              i2_2 * 64 + i3_2]
     for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625):
         # The 64-entry accumulator is allocated once per position and reused
         # for each of the 4 ax3_outer_1 groups.
         Conv2dOutput_2 = T.allocate([64], "int32", "global")
         for ax3_outer_1 in T.serial(0, 4):
             # Stage 2: zero the accumulator entry, then reduce over rc_2.
             for ff_2 in T.serial(0, 64):
                 Conv2dOutput_2[ff_2] = 0
                 for rc_2 in T.serial(0, 64):
                     Conv2dOutput_2[ff_2] = Conv2dOutput_2[ff_2] + T.cast(
                         PaddedInput_2[ax0_ax1_fused_ax2_fused_2 * 64 +
                                       rc_2],
                         "int32") * T.cast(
                             placeholder_20[rc_2 * 256 + ax3_outer_1 * 64 +
                                            ff_2], "int32")
             # Stage 3: inner q_multiply_shift uses (1711626602, 31, -8), the
             # outer one uses (2094289803, 31, -2).
             for ax3_inner_3 in T.serial(0, 64):
                 T_add_1[
                     ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 +
                     ax3_inner_3] = T.q_multiply_shift(T.cast(
                         T.cast(
                             T.max(
                                 T.min(
                                     T.q_multiply_shift(
                                         Conv2dOutput_2[ax3_inner_3] +
                                         placeholder_21[ax3_outer_1 * 64 +
                                                        ax3_inner_3],
                                         1711626602,
                                         31,
                                         -8,
                                         dtype="int32") + 132, 255), 0),
                             "uint8"), "int32") - 132,
                                                       2094289803,
                                                       31,
                                                       -2,
                                                       dtype="int32") + 136
Example #11
0
def reduction_loop_only(
    A: T.Buffer[2, "float32"],
    B: T.Buffer[2, "float32"],
    C: T.Buffer[(), "float32"],
) -> None:
    # Reduce the two quotients A[k] / B[k] into the zero-dimensional buffer C
    # with T.min, starting from an initial value of 1.0. The only loop in the
    # function is the reduction loop itself.
    for k in T.serial(0, 2):
        with T.block("C"):
            vk = T.axis.R(2, k)
            T.reads(A[vk], B[vk])
            T.writes(C[()])
            with T.init():
                C[()] = T.float32(1.0)
            C[()] = T.min(C[()], A[vk] / B[vk])
Example #12
0
def read_out_of_bound_after_compute_at(a: T.handle, c: T.handle) -> None:
    # For each position of the 16-element output, block "B" first copies the
    # (at most two) elements of A that block "C" needs into the intermediate
    # buffer B; block "C" then writes max(B[v], B[v + 1]) for v < 15 and B[v]
    # for the last element. The inner extent T.min(1, 15 - pos) + 1 shrinks to
    # one iteration at the last position so that B is never written past index 15.
    A = T.match_buffer(a, [16], dtype="float32")
    B = T.alloc_buffer([16], dtype="float32")
    C = T.match_buffer(c, [16], dtype="float32")
    for pos in T.serial(0, 16):
        for off in T.serial(0, T.min(1, 15 - pos) + 1):
            with T.block("B"):
                vb = T.axis.spatial(16, pos + off)
                B[vb] = A[vb]
        with T.block("C"):
            vc = T.axis.spatial(16, pos)
            # The declared read region covers two elements even at vc == 15.
            T.reads([B[vc : vc + 2]])
            C[vc] = T.if_then_else(vc < 15, T.max(B[vc], B[vc + 1]), B[vc], dtype="float32")
Example #13
0
def compacted_symbolic_func(a: T.handle, c: T.handle, n: T.int32) -> None:
    # Processes A in n chunks of 8 elements: each chunk is copied into the
    # intermediate buffer B with 1.0 added, and B is then written to the
    # matching chunk of C multiplied by 2.0. The buffer sizes depend on the
    # symbolic parameter n.
    A = T.match_buffer(a, (n * 8, ), "float32")
    C = T.match_buffer(c, (n * 8, ), "float32")
    for i in range(0, n):
        with T.block():
            T.reads(A[i * 8:i * 8 + 8])
            T.writes(C[i * 8:i * 8 + 8])
            # B is allocated inside the per-chunk block with T.min(n, 1) * 8
            # elements and is indexed only by j in [0, 8).
            B = T.alloc_buffer((T.min(n, 1) * 8, ), "float32")
            for j in range(0, 8):
                with T.block() as []:
                    T.reads(A[i * 8 + j])
                    T.writes(B[j])
                    B[j] = A[i * 8 + j] + 1.0
            for j in range(0, 8):
                with T.block() as []:
                    T.reads(B[j])
                    T.writes(C[i * 8 + j])
                    C[i * 8 + j] = B[j] * 2.0
def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None:
    # Three stages over flat buffers, with the outermost loops marked T.parallel:
    #   1. copy the int16 input into PaddedInput_3 unchanged,
    #   2. for each of 784 fused positions and 16 values of ax3_2, accumulate one
    #      int32 sum over 192 products of PaddedInput_3 and placeholder_34 values,
    #   3. add placeholder_35, apply T.q_multiply_shift with (1764006585, 31, -7),
    #      clamp to [0, 255] and cast to uint8 and then int16 into T_cast_9.
    # function attr dict
    T.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True})
    placeholder_33 = T.match_buffer(placeholder_30, [150528], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_34 = T.match_buffer(placeholder_31, [3072], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    placeholder_35 = T.match_buffer(placeholder_32, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1)
    T_cast_9 = T.match_buffer(T_cast_8, [12544], dtype="int16", elem_offset=0, align=128, offset_factor=1)
    # body
    PaddedInput_3 = T.allocate([150528], "int16", "global")
    # Stage 1: straight copy; both sides use the same flat offset.
    for i0_i1_fused_3 in T.parallel(0, 28):
        for i2_3, i3_3 in T.grid(28, 192):
            PaddedInput_3[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3) ] = placeholder_33[(((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)]
    for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784):
        for ax3_2 in T.serial(0, 16):
            # Stage 2: a one-element accumulator is allocated per output element.
            Conv2dOutput_3 = T.allocate([1], "int32", "global")
            Conv2dOutput_3[0] = 0
            for rc_3 in T.serial(0, 192):
                Conv2dOutput_3[0] = (Conv2dOutput_3[0] + (T.cast(PaddedInput_3[((ax0_ax1_fused_ax2_fused_3*192) + rc_3)], "int32")*T.cast(placeholder_34[((rc_3*16) + ax3_2)], "int32")))
            # Stage 3: bias add, fixed-point multiply, clamp, uint8 -> int16.
            T_cast_9[((ax0_ax1_fused_ax2_fused_3*16) + ax3_2)] = T.cast(T.cast(T.max(T.min(T.q_multiply_shift((Conv2dOutput_3[0] + placeholder_35[ax3_2]), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16")
Example #15
0
def different_access_indices(a: T.handle, b: T.handle) -> None:
    # Sums A[vi, vj, vk] over vk into B[vi, vj], with the reduction loop bound
    # to threadIdx.x. The init statement writes B[vj, vi] while the update
    # writes B[vi, vj], so the two statements access B with different indices;
    # the declared write region spans both via T.min / T.max of vi and vj.
    # NOTE(review): the transposed init index looks deliberate given the
    # function name - confirm before "fixing" it.
    A = T.match_buffer(a, [128, 128, 128], dtype="float32")
    B = T.match_buffer(b, [128, 128], dtype="float32")
    for i, j in T.grid(128, 128):
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads([B[vi, vj], A[vi, vj, vk]])
                T.writes([
                    B[T.min(vj, vi):T.min(vj, vi) +
                      (T.max(vj, vi) + 1 - T.min(vj, vi)),
                      T.min(vi, vj):T.min(vi, vj) +
                      (T.max(vi, vj) + 1 - T.min(vi, vj)), ]
                ])
                with T.init():
                    B[vj, vi] = T.float32(0)
                B[vi, vj] = B[vi, vj] + A[vi, vj, vk]
Example #16
0
 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(
         placeholder_4: T.handle, placeholder_5: T.handle,
         placeholder_6: T.handle, T_cast_2: T.handle) -> None:
     # Three stages over flat one-dimensional buffers, using direct buffer
     # indexing:
     #   1. copy the int16 input (placeholder_7) into PaddedInput unchanged,
     #   2. for each of the 5625 fused positions, accumulate 64 int32 sums, each
     #      over 64 products of PaddedInput and placeholder_8 values,
     #   3. add placeholder_9, apply T.q_multiply_shift, clamp to [0, 255] and
     #      cast to uint8 and then int16 into T_cast_3.
     # function attr dict
     T.func_attr({
         "global_symbol":
         "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast",
         "tir.noalias": True
     })
     placeholder_7 = T.match_buffer(placeholder_4, [360000], dtype="int16")
     placeholder_8 = T.match_buffer(placeholder_5, [4096], dtype="int16")
     placeholder_9 = T.match_buffer(placeholder_6, [64], dtype="int32")
     T_cast_3 = T.match_buffer(T_cast_2, [360000], dtype="int16")
     # body
     PaddedInput = T.allocate([360000], "int16", "global")
     # Stage 1: straight copy; both sides use the same flat offset.
     for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
         PaddedInput[i0_i1_fused * 4800 + i2 * 64 +
                     i3] = placeholder_7[i0_i1_fused * 4800 + i2 * 64 + i3]
     for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
         # Per-position accumulator holding 64 int32 partial sums.
         Conv2dOutput = T.allocate([64], "int32", "global")
         # Stage 2: zero the accumulator entry, then reduce over rc.
         for ff in T.serial(0, 64):
             Conv2dOutput[ff] = 0
             for rc in T.serial(0, 64):
                 Conv2dOutput[ff] = Conv2dOutput[ff] + T.cast(
                     PaddedInput[ax0_ax1_fused_ax2_fused * 64 + rc],
                     "int32") * T.cast(placeholder_8[rc * 64 + ff], "int32")
         # Stage 3: add placeholder_9, q_multiply_shift with (1843106743, 31, -6),
         # clamp to [0, 255], then cast uint8 -> int16.
         for ax3_inner_1 in T.serial(0, 64):
             T_cast_3[ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1] = T.cast(
                 T.cast(
                     T.max(
                         T.min(
                             T.q_multiply_shift(Conv2dOutput[ax3_inner_1] +
                                                placeholder_9[ax3_inner_1],
                                                1843106743,
                                                31,
                                                -6,
                                                dtype="int32"), 255), 0),
                     "uint8"), "int16")
Example #17
0
 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(
         placeholder_62: T.handle, placeholder_63: T.handle,
         placeholder_64: T.handle, T_cast_20: T.handle) -> None:
     # Three stages, all expressed with explicit T.store / T.load on flat offsets:
     #   1. build PaddedInput_7 (229 * 229 * 3 elements) from the int16 input,
     #      writing zeros outside the window 2 <= index < 226 in the first two
     #      loop dimensions,
     #   2. for each of the 12544 fused positions, accumulate 64 int32 sums, each
     #      over a 7 x 7 x 3 (ry_2, rx_2, rc_7) window of PaddedInput_7 multiplied
     #      by placeholder_66 values,
     #   3. add placeholder_67, apply T.q_multiply_shift, clamp to [0, 255] and
     #      cast to uint8 into T_cast_21.
     # function attr dict
     T.func_attr({
         "global_symbol":
         "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast",
         "tir.noalias": True
     })
     placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3],
                                     dtype="int16",
                                     elem_offset=0,
                                     align=128,
                                     offset_factor=1)
     placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64],
                                     dtype="int16",
                                     elem_offset=0,
                                     align=128,
                                     offset_factor=1)
     placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64],
                                     dtype="int32",
                                     elem_offset=0,
                                     align=128,
                                     offset_factor=1)
     T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64],
                                dtype="uint8",
                                elem_offset=0,
                                align=128,
                                offset_factor=1)
     # body
     PaddedInput_7 = T.allocate([157323], "int16", "global")
     # Stage 1: inside the window the input is read at a flat offset shifted by
     # -1350 (= 2 * 672 + 2 * 3); everywhere else the value is 0.
     for i0_i1_fused_7 in T.serial(0, 229):
         for i2_7, i3_7 in T.grid(229, 3):
             T.store(
                 PaddedInput_7,
                 (((i0_i1_fused_7 * 687) + (i2_7 * 3)) + i3_7),
                 T.if_then_else(
                     ((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and
                       (2 <= i2_7)) and (i2_7 < 226)),
                     T.load("int16", placeholder_65.data,
                            ((((i0_i1_fused_7 * 672) +
                               (i2_7 * 3)) + i3_7) - 1350)),
                     T.int16(0),
                     dtype="int16"), True)
     for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544):
         # Per-position accumulator holding 64 int32 partial sums.
         Conv2dOutput_7 = T.allocate([64], "int32", "global")
         # Stage 2: the fused position is split with floordiv / floormod by 112;
         # the window origin advances by 1374 (= 2 * 687) per row step and by
         # 6 (= 2 * 3) per column step.
         for ff_3 in T.serial(0, 64):
             T.store(Conv2dOutput_7, ff_3, 0, True)
             for ry_2, rx_2, rc_7 in T.grid(7, 7, 3):
                 T.store(
                     Conv2dOutput_7, ff_3,
                     (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(
                         T.load("int16", PaddedInput_7, ((
                             (((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112) *
                                1374) + (ry_2 * 687)) +
                              (T.floormod(ax0_ax1_fused_ax2_fused_7, 112) *
                               6)) +
                             (rx_2 * 3)) + rc_7)), "int32") * T.cast(
                                 T.load("int16", placeholder_66.data,
                                        ((((ry_2 * 1344) + (rx_2 * 192)) +
                                          (rc_7 * 64)) + ff_3)), "int32"))),
                     True)
         # Stage 3: add placeholder_67, q_multiply_shift with (1939887962, 31, -9),
         # clamp to [0, 255], then cast to uint8.
         for ax3_inner_7 in T.serial(0, 64):
             T.store(
                 T_cast_21.data,
                 ((ax0_ax1_fused_ax2_fused_7 * 64) + ax3_inner_7),
                 T.cast(
                     T.max(
                         T.min(
                             T.q_multiply_shift(
                                 (T.load("int32", Conv2dOutput_7,
                                         ax3_inner_7) +
                                  T.load("int32", placeholder_67.data,
                                         ax3_inner_7)),
                                 1939887962,
                                 31,
                                 -9,
                                 dtype="int32"), 255), 0), "uint8"), True)
Example #18
0
def conv2d_winograd_cuda(  # type: ignore
    placeholder: T.Buffer[(1, 14, 14, 128), "float32"],  # type: ignore
    placeholder_1: T.Buffer[(6, 6, 128, 128), "float32"],  # type: ignore
    conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"],  # type: ignore
) -> None:
    # type: ignore
    # Staged 2-D convolution written as a chain of TIR blocks. Going by the
    # block names, the stages are: pad -> tile -> input transform (matrix B)
    # -> per-(eps, nu) channel contraction -> inverse transform (matrix A)
    # -> scatter of the tiles into the output.
    #
    # Buffers:
    #   placeholder     -- (1, 14, 14, 128) float32 input.
    #   placeholder_1   -- (6, 6, 128, 128) float32, indexed as
    #                      [eps, nu, co, ci] by the "bgemm" block.
    #                      NOTE(review): presumably kernel weights that were
    #                      transformed ahead of time -- confirm with the caller.
    #   conv2d_winograd -- (1, 12, 12, 128) float32 output, written only by
    #                      the last block.
    #
    # Intermediate buffers; no dtype is passed to T.alloc_buffer here.
    data_pad = T.alloc_buffer([1, 16, 16, 128])
    input_tile = T.alloc_buffer([6, 6, 9, 128])
    B = T.alloc_buffer([6, 6])
    data_pack = T.alloc_buffer([6, 6, 9, 128])
    bgemm = T.alloc_buffer([6, 6, 9, 128])
    A = T.alloc_buffer([6, 4])
    inverse = T.alloc_buffer([4, 4, 9, 128])
    # Stage 1: copy the 14x14 input into the 16x16 buffer. Positions whose
    # second or third index is 14 or 15 fail the bounds test and are set to 0.
    for i0, i1, i2, i3 in T.grid(1, 16, 16, 128):
        with T.block("data_pad"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.block_attr({"schedule_rule": "None"})
            T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]])
            T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]])
            data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14,  # type: ignore
                placeholder[i0_1, i1_1, i2_1, i3_1],
                T.float32(0),
                dtype="float32",
            )
    # Stage 2: gather 6x6 windows from the padded input. p runs over 0..8, so
    # T.floordiv(p, 9) is always 0 and p picks one cell of a 3x3 grid of
    # windows; window origins are 4 apart, so neighbouring windows overlap by 2.
    for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128):
        with T.block("input_tile"):
            eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2])
            T.block_attr({"schedule_rule": "None"})
            T.reads(
                [
                    data_pad[
                        T.floordiv(p, 9),  # type: ignore
                        ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                        ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                        ci,
                    ]
                ]
            )
            T.writes([input_tile[eps, nu, p, ci]])
            input_tile[eps, nu, p, ci] = data_pad[
                T.floordiv(p, 9),  # type: ignore
                ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                ci,
            ]
    # Constant 6x6 matrix B, spelled out element by element as one nested
    # T.Select chain keyed on (i mod 6, j mod 6); the innermost fallback is 0.
    # The statement is wrapped across two physical lines inside the open
    # parentheses, which is why the second line starts at column 0.
    for i0_3, i1_3 in T.grid(6, 6):
        with T.block("B"):
            i, j = T.axis.remap("SS", [i0_3, i1_3])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([B[i, j]])
            # fmt: off
            B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), 
T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0)))))))))))))))))))))))))))))))))))))  # type: ignore
            # fmt: on
    # Stage 3: reduction over (r_a, r_b):
    #   data_pack[eps, nu, p, ci] = sum input_tile[r_a, r_b, p, ci]
    #                                   * B[r_a, eps] * B[r_b, nu]
    # The declared read region of B is the bounding box that contains both
    # accessed elements: rows min(r_a, r_b)..max(r_a, r_b) and columns
    # min(eps, nu)..max(eps, nu).
    for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6):
        with T.block("data_pack"):
            eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap(
                "SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cuda"})
            T.reads(
                [
                    data_pack[eps_1, nu_1, p_1, ci_1],
                    input_tile[r_a, r_b, p_1, ci_1],
                    B[
                        T.min(r_a, r_b) : (  # type: ignore
                            T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b))  # type: ignore
                        ),
                        T.min(eps_1, nu_1) : (  # type: ignore
                            T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1))  # type: ignore
                        ),
                    ],
                ]
            )
            T.writes([data_pack[eps_1, nu_1, p_1, ci_1]])
            with T.init():
                data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0)
            data_pack[eps_1, nu_1, p_1, ci_1] = data_pack[eps_1, nu_1, p_1, ci_1] + (
                (input_tile[r_a, r_b, p_1, ci_1] * B[r_a, eps_1]) * B[r_b, nu_1]
            )
    # Stage 4: for every (eps, nu) position, contract data_pack with
    # placeholder_1 over the 128-long reduction axis ci_2.
    for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128):
        with T.block("bgemm"):
            eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1])
            T.block_attr({"meta_schedule.write_cache_level": [3]})
            T.reads(
                [
                    bgemm[eps_2, nu_2, p_2, co],
                    data_pack[eps_2, nu_2, p_2, ci_2],
                    placeholder_1[eps_2, nu_2, co, ci_2],
                ]
            )
            T.writes([bgemm[eps_2, nu_2, p_2, co]])
            with T.init():
                bgemm[eps_2, nu_2, p_2, co] = T.float32(0)
            bgemm[eps_2, nu_2, p_2, co] = bgemm[eps_2, nu_2, p_2, co] + (
                data_pack[eps_2, nu_2, p_2, ci_2] * placeholder_1[eps_2, nu_2, co, ci_2]
            )
    # Constant 6x4 matrix A, built the same way as B: a nested T.Select chain
    # keyed on (i_1 mod 6, j_1 mod 4) with a final fallback of 0, again wrapped
    # across two physical lines.
    for i0_6, i1_6 in T.grid(6, 4):
        with T.block("A"):
            i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([A[i_1, j_1]])
            # fmt: off
            A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), 
T.float32(1), T.float32(0)))))))))))))))))))))))))  # type: ignore
            # fmt: on
    # Stage 5: reduction over (r_a_1, r_b_1):
    #   inverse[vh, vw, p, co] = sum bgemm[r_a_1, r_b_1, p, co]
    #                                * A[r_a_1, vh] * A[r_b_1, vw]
    # As in "data_pack", the read region of A is the bounding box of the two
    # elements accessed.
    for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6):
        with T.block("inverse"):
            vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap(
                "SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse"})
            T.reads(
                [
                    inverse[vh, vw, p_3, co_1],
                    bgemm[r_a_1, r_b_1, p_3, co_1],
                    A[
                        T.min(r_a_1, r_b_1) : (  # type: ignore
                            T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1))  # type: ignore
                        ),
                        T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))),  # type: ignore
                    ],
                ]
            )
            T.writes([inverse[vh, vw, p_3, co_1]])
            with T.init():
                inverse[vh, vw, p_3, co_1] = T.float32(0)
            inverse[vh, vw, p_3, co_1] = inverse[vh, vw, p_3, co_1] + (
                (bgemm[r_a_1, r_b_1, p_3, co_1] * A[r_a_1, vh]) * A[r_b_1, vw]
            )
    # Stage 6: lay the nine 4x4 result tiles out as the 12x12 output. The tile
    # index is n*9 + (h // 4)*3 + (w // 4); the offset inside the tile is
    # (h mod 4, w mod 4).
    for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128):
        with T.block("conv2d_winograd"):
            n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6])
            T.reads(
                [
                    inverse[
                        T.floormod(h, 4),  # type: ignore
                        T.floormod(w, 4),  # type: ignore
                        (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                        co_2,
                    ]
                ]
            )
            T.writes([conv2d_winograd[n, h, w, co_2]])
            conv2d_winograd[n, h, w, co_2] = inverse[
                T.floormod(h, 4),  # type: ignore
                T.floormod(w, 4),  # type: ignore
                (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                co_2,
            ]
Example #19
0
 def main(
     placeholder: T.Buffer[(1, 384), "int64"],
     placeholder_1: T.Buffer[(30522, 768), "float32"],
     placeholder_2: T.Buffer[(1, 384, 768), "float32"],
     T_add: T.Buffer[(1, 384, 768), "float32"],
 ) -> None:
     # Row gather followed by an element-wise add, written as one TIR block
     # per primitive step:
     #   idx   = placeholder < 0 ? placeholder + 30522 : placeholder
     #   T_add = placeholder_1[clamp(idx, 0, 30521), :] + placeholder_2
     #
     # Buffers:
     #   placeholder   -- (1, 384) int64 row indices; negative values are
     #                    shifted up by 30522 before use.
     #   placeholder_1 -- (30522, 768) float32 table the rows are taken from.
     #   placeholder_2 -- (1, 384, 768) float32 addend.
     #   T_add         -- (1, 384, 768) float32 output.
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # Zero-dimensional buffers hold the two scalar constants (0 and 30522).
     compile_engine_const = T.alloc_buffer([], dtype="int64")
     T_less = T.alloc_buffer([1, 384], dtype="bool")
     compile_engine_const_1 = T.alloc_buffer([], dtype="int64")
     T_add_1 = T.alloc_buffer([1, 384], dtype="int64")
     T_where = T.alloc_buffer([1, 384], dtype="int64")
     T_take = T.alloc_buffer([1, 384, 768], dtype="float32")
     # Scalar constant 0. The axis vi is bound but not referenced in the block.
     with T.block("compile_engine_const"):
         vi = T.axis.spatial(1, 0)
         T.reads()
         T.writes(compile_engine_const[()])
         compile_engine_const[()] = T.int64(0)
     # T_less marks the indices that are negative.
     for i0, i1 in T.grid(1, 384):
         with T.block("T_less"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(placeholder[ax0, ax1], compile_engine_const[()])
             T.writes(T_less[ax0, ax1])
             T_less[ax0,
                    ax1] = placeholder[ax0, ax1] < compile_engine_const[()]
     # Scalar constant 30522, equal to the first extent of placeholder_1.
     with T.block("compile_engine_const_1"):
         vi = T.axis.spatial(1, 0)
         T.reads()
         T.writes(compile_engine_const_1[()])
         compile_engine_const_1[()] = T.int64(30522)
     # T_add_1 is every index shifted up by 30522, whether or not it is
     # negative; the selection happens in the next block.
     for i0, i1 in T.grid(1, 384):
         with T.block("T_add"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(placeholder[ax0, ax1], compile_engine_const_1[()])
             T.writes(T_add_1[ax0, ax1])
             T_add_1[ax0,
                     ax1] = placeholder[ax0,
                                        ax1] + compile_engine_const_1[()]
     # Pick the shifted index where the original was negative, otherwise keep
     # the original index.
     for i0, i1 in T.grid(1, 384):
         with T.block("T_where"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(T_less[ax0, ax1], T_add_1[ax0, ax1], placeholder[ax0,
                                                                      ax1])
             T.writes(T_where[ax0, ax1])
             T_where[ax0, ax1] = T.Select(
                 T.cast(T_less[ax0, ax1], "int32") != 0, T_add_1[ax0, ax1],
                 placeholder[ax0, ax1])
     # Gather: the row index is clamped to [0, 30521] so the read of
     # placeholder_1 always stays inside its first dimension.
     for i0, i1, i2 in T.grid(1, 384, 768):
         with T.block("T_take"):
             ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
             T.reads(
                 placeholder_1[T.min(T.max(T.int64(0), T_where[
                     ax0, ax1]), T.int64(30521)), ax2],
                 T_where[ax0, ax1],
             )
             T.writes(T_take[ax0, ax1, ax2])
             T_take[ax0, ax1, ax2] = placeholder_1[
                 T.min(T.max(T.int64(0), T_where[ax0,
                                                 ax1]), T.int64(30521)),
                 ax2]
     # Final element-wise add of the gathered rows and placeholder_2.
     for i0, i1, i2 in T.grid(1, 384, 768):
         with T.block("T_add_1"):
             ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
             T.reads(T_take[ax0, ax1, ax2], placeholder_2[ax0, ax1, ax2])
             T.writes(T_add[ax0, ax1, ax2])
             T_add[ax0, ax1,
                   ax2] = T_take[ax0, ax1, ax2] + placeholder_2[ax0, ax1,
                                                                ax2]
Example #20
0
 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(
         placeholder_22: T.handle, placeholder_23: T.handle,
         placeholder_24: T.handle, placeholder_25: T.handle,
         T_cast_6: T.handle) -> None:
     # Fused integer kernel over flat (1-D) buffers. For each of 5625
     # positions it accumulates a 64-term int32 dot product for each of 256
     # outputs, then applies two T.q_multiply_shift rescales with clipping to
     # [0, 255] and adds a per-element int32 term before the final cast.
     #
     # Handles, in the order they are bound below:
     #   placeholder_22 -- int16[360000] input (5625 positions x 64).
     #   placeholder_23 -- int16[16384], indexed rc*256 + out_index.
     #   placeholder_24 -- int32[256], added to the accumulator per output.
     #   placeholder_25 -- int32[1440000], added after the second rescale.
     #   T_cast_6       -- uint8[1440000] result (5625 positions x 256).
     # function attr dict
     T.func_attr({
         "global_symbol":
         "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_",
         "tir.noalias": True
     })
     placeholder_29 = T.match_buffer(placeholder_22, [360000],
                                     dtype="int16")
     placeholder_27 = T.match_buffer(placeholder_23, [16384], dtype="int16")
     placeholder_26 = T.match_buffer(placeholder_24, [256], dtype="int32")
     placeholder_28 = T.match_buffer(placeholder_25, [1440000],
                                     dtype="int32")
     T_cast_7 = T.match_buffer(T_cast_6, [1440000], dtype="uint8")
     # body
     # Despite the name, this is a plain element-for-element copy: the flat
     # index is identical on both sides, so no padding is introduced.
     PaddedInput_3 = T.allocate([360000], "int16", "global")
     for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64):
         PaddedInput_3[i0_i1_fused_3 * 4800 + i2_3 * 64 +
                       i3_3] = placeholder_29[i0_i1_fused_3 * 4800 +
                                              i2_3 * 64 + i3_3]
     for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625):
         # 64-entry int32 accumulator, reused for each of the four chunks of
         # 64 outputs selected by ax3_outer_2.
         Conv2dOutput_3 = T.allocate([64], "int32", "global")
         for ax3_outer_2 in T.serial(0, 4):
             for ff_3 in T.serial(0, 64):
                 Conv2dOutput_3[ff_3] = 0
                 # Both int16 operands are widened to int32 before the
                 # multiply-accumulate.
                 for rc_3 in T.serial(0, 64):
                     Conv2dOutput_3[ff_3] = Conv2dOutput_3[ff_3] + T.cast(
                         PaddedInput_3[ax0_ax1_fused_ax2_fused_3 * 64 +
                                       rc_3],
                         "int32") * T.cast(
                             placeholder_27[rc_3 * 256 + ax3_outer_2 * 64 +
                                            ff_3], "int32")
             # Epilogue, read from the innermost call outwards:
             #   1. accumulator + placeholder_26 term
             #   2. T.q_multiply_shift(., 1343014664, 31, -8), then + 136
             #   3. clip to [0, 255], cast to uint8 and back to int32, - 136
             #   4. T.q_multiply_shift(., 1073903788, 31, 1)
             #   5. + placeholder_28 term, clip to [0, 255], cast to uint8
             # NOTE(review): 136 looks like a quantization zero point -- confirm.
             for ax3_inner_4 in T.serial(0, 64):
                 T_cast_7[
                     ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 +
                     ax3_inner_4] = T.cast(
                         T.max(
                             T.min(
                                 T.q_multiply_shift(T.cast(
                                     T.cast(
                                         T.max(
                                             T.min(
                                                 T.q_multiply_shift(
                                                     Conv2dOutput_3[
                                                         ax3_inner_4] +
                                                     placeholder_26[
                                                         ax3_outer_2 * 64 +
                                                         ax3_inner_4],
                                                     1343014664,
                                                     31,
                                                     -8,
                                                     dtype="int32") + 136,
                                                 255), 0), "uint8"),
                                     "int32") - 136,
                                                    1073903788,
                                                    31,
                                                    1,
                                                    dtype="int32") +
                                 placeholder_28[ax0_ax1_fused_ax2_fused_3 *
                                                256 + ax3_outer_2 * 64 +
                                                ax3_inner_4], 255), 0),
                         "uint8")
def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(
        placeholder_30: T.handle, placeholder_31: T.handle,
        placeholder_32: T.handle, T_cast_8: T.handle) -> None:
    # Fused integer kernel written with explicit T.load / T.store on flat
    # indices. For each of 784 positions and 16 outputs it accumulates a
    # 192-term int32 dot product, adds a per-output int32 term, rescales with
    # T.q_multiply_shift, clips to [0, 255] and stores the result as int16.
    #
    # Handles, in the order they are bound below:
    #   placeholder_30 -- int16 [1, 28, 28, 192] input.
    #   placeholder_31 -- int16 [1, 1, 192, 16], read at flat index rc*16 + out.
    #   placeholder_32 -- int32 [1, 1, 1, 16], one term per output.
    #   T_cast_8       -- int16 [1, 28, 28, 16] result.
    # function attr dict
    T.func_attr({
        "global_symbol":
        "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2",
        "tir.noalias": True
    })
    placeholder_33 = T.match_buffer(placeholder_30, [1, 28, 28, 192],
                                    dtype="int16",
                                    elem_offset=0,
                                    align=128,
                                    offset_factor=1)
    placeholder_34 = T.match_buffer(placeholder_31, [1, 1, 192, 16],
                                    dtype="int16",
                                    elem_offset=0,
                                    align=128,
                                    offset_factor=1)
    placeholder_35 = T.match_buffer(placeholder_32, [1, 1, 1, 16],
                                    dtype="int32",
                                    elem_offset=0,
                                    align=128,
                                    offset_factor=1)
    T_cast_9 = T.match_buffer(T_cast_8, [1, 28, 28, 16],
                              dtype="int16",
                              elem_offset=0,
                              align=128,
                              offset_factor=1)
    # body
    # Despite the name, this is a plain element-for-element copy: the flat
    # index (row*5376 + col*192 + channel, with 5376 = 28*192) is identical
    # on the load and the store, so no padding is introduced.
    # NOTE(review): the trailing True passed to every T.store is presumably
    # the store predicate -- confirm against the TVM version in use.
    PaddedInput_3 = T.allocate([1, 28, 28, 192], "int16", "global")
    for i0_i1_fused_3 in T.parallel(0, 28):
        for i2_3, i3_3 in T.grid(28, 192):
            T.store(
                PaddedInput_3,
                (((i0_i1_fused_3 * 5376) + (i2_3 * 192)) + i3_3),
                T.load("int16", placeholder_33.data,
                       (((i0_i1_fused_3 * 5376) + (i2_3 * 192)) + i3_3)), True)
    for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784):
        for ax3_2 in T.serial(0, 16):
            # Single-element int32 accumulator, allocated per (position,
            # output) pair and zeroed before the reduction over rc_3.
            Conv2dOutput_3 = T.allocate([1, 1, 1, 1], "int32", "global")
            T.store(Conv2dOutput_3, 0, 0, True)
            # Both int16 operands are widened to int32 before the
            # multiply-accumulate.
            for rc_3 in T.serial(0, 192):
                T.store(Conv2dOutput_3, 0,
                        (T.load("int32", Conv2dOutput_3, 0) + (T.cast(
                            T.load("int16", PaddedInput_3,
                                   ((ax0_ax1_fused_ax2_fused_3 * 192) + rc_3)),
                            "int32") * T.cast(
                                T.load("int16", placeholder_34.data,
                                       ((rc_3 * 16) + ax3_2)), "int32"))),
                        True)
            # Epilogue, read from the innermost call outwards: accumulator +
            # placeholder_35 term, T.q_multiply_shift(., 1764006585, 31, -7),
            # clip to [0, 255], cast to uint8, then widen to int16 for the
            # output buffer.
            T.store(
                T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3 * 16) + ax3_2),
                T.cast(
                    T.cast(
                        T.max(
                            T.min(
                                T.q_multiply_shift(
                                    (T.load("int32", Conv2dOutput_3, 0) +
                                     T.load("int32", placeholder_35.data,
                                            ax3_2)),
                                    1764006585,
                                    31,
                                    -7,
                                    dtype="int32"), 255), 0), "uint8"),
                    "int16"), True)