Beispiel #1
0
 def main(A: T.handle, tensor: T.handle) -> None:
     # Two chained max reductions written as plain loops:
     #   1. a 3x3 window (dh, dw) over A_1 produces the intermediate tensor_2;
     #   2. a 3x5 window (dh_1, dw_1) over tensor_2 produces tensor_1.
     # The row axis is processed in two chunks (ax1_outer), and tensor_2 is
     # always addressed with its row index taken modulo 6.
     # function attr dict
     T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
     # buffer definition
     tensor_2 = T.buffer_decl([1, 10, 12, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
     A_1 = T.match_buffer(A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
     tensor_1 = T.match_buffer(tensor, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
     # body
     T.realize(tensor_1[0:1, 0:8, 0:8, 0:16], "")
     # Only 6 rows of tensor_2 are realized although it is declared with 10;
     # the floormod(..., 6) indexing below keeps every access inside them.
     T.realize(tensor_2[0:1, 0:6, 0:12, 0:16], "")
     for ax1_outer in T.serial(0, 2):
         # Stage 1: fill rows (ax1 + ax1_outer*4) of the intermediate.
         for ax1 in T.serial(0, 6):
             for ax2 in T.serial(0, 12):
                 for ax3 in T.serial(0, 16):
                     # In the second chunk (ax1_outer == 1) rows ax1 = 0, 1
                     # map to rows 4 and 5, which the first chunk already
                     # produced, so the guard skips recomputing them.
                     if T.likely(((ax1_outer < 1) or (ax1 >= 2)), dtype='bool') :
                         tensor_2[0, T.floormod((ax1 + (ax1_outer*4)), 6), ax2, ax3] = T.int8(0)
                     for dh in T.serial(0, 3):
                         for dw in T.serial(0, 3):
                             # Same guard as the initialization above.
                             if T.likely(((ax1_outer < 1) or (ax1 >= 2)), dtype='bool'):
                                 tensor_2[0, T.floormod((ax1 + (ax1_outer*4)), 6), ax2, ax3] = T.max(tensor_2[0, T.floormod((ax1 + (ax1_outer*4)), 6), ax2, ax3], A_1[0, ((ax1 + (ax1_outer*4)) + dh), (ax2 + dw), ax3])
         # Stage 2: consume the intermediate to produce 4 output rows per chunk.
         for ax1_inner in T.serial(0, 4):
             for ax2_inner in T.serial(0, 8):
                 for ax3_inner in T.serial(0, 16):
                     tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = T.int8(0)
                     for dh_1 in T.serial(0, 3):
                         for dw_1 in T.serial(0, 5):
                             # Reads use the same modulo-6 row addressing as
                             # the writes in stage 1.
                             tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = T.max(tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner], tensor_2[0, T.floormod(((ax1_inner + (ax1_outer*4)) + dh_1), 6), (ax2_inner + dw_1), ax3_inner])
Beispiel #2
0
def elementwise_fused(a: T.handle, b: T.handle) -> None:
    # Computes B = A * 2.0 over a 128x128x128 buffer with one flat loop of
    # 128 * 128 * 128 = 2097152 iterations instead of three nested loops.
    A = T.match_buffer(a, (128, 128, 128))
    B = T.match_buffer(b, (128, 128, 128))
    for fused in T.serial(0, 2097152):
        with T.block("B"):
            # Recover the three spatial indices from the flat counter;
            # 16384 = 128 * 128 is the stride of the outermost axis.
            vi = T.axis.S(128, T.floordiv(fused, 16384))
            vj = T.axis.S(128, T.floordiv(T.floormod(fused, 16384), 128))
            vk = T.axis.S(128, T.floormod(fused, 128))
            T.reads([A[vi, vj, vk]])
            T.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
Beispiel #3
0
 def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle,
                                             T_cast_6: T.handle) -> None:
     # Max reduction over a 3x3 window with stride 2 on a [1, 112, 112, 64]
     # uint8 input, followed by a cast of the [1, 56, 56, 64] result to int16.
     # All buffer accesses use flattened offsets via T.load / T.store.
     # function attr dict
     T.func_attr({
         "global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast",
         "tir.noalias": True
     })
     placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64],
                                     dtype="uint8",
                                     elem_offset=0,
                                     align=128,
                                     offset_factor=1)
     T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64],
                               dtype="int16",
                               elem_offset=0,
                               align=128,
                               offset_factor=1)
     # body
     # Intermediate result before the cast: 56 * 56 * 64 = 200704 elements.
     tensor_2 = T.allocate([200704], "uint8", "global")
     for ax0_ax1_fused_4 in T.serial(0, 56):
         for ax2_4 in T.serial(0, 56):
             # Initialize the 64 channels of this output pixel to 0
             # (output row stride 3584 = 56 * 64, pixel stride 64).
             for ax3_init in T.serial(0, 64):
                 T.store(tensor_2, (((ax0_ax1_fused_4 * 3584) +
                                     (ax2_4 * 64)) + ax3_init), T.uint8(0),
                         True)
             # rv0_rv1_fused_1 enumerates the 9 window positions:
             # floordiv(., 3) is the window row, floormod(., 3) the window
             # column. Input coordinates are (out * 2 + window offset); the
             # if_then_else substitutes 0 when they reach 112. Input strides:
             # 14336 = 2 * 112 * 64, 7168 = 112 * 64, 128 = 2 * 64.
             for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64):
                 T.store(
                     tensor_2,
                     (((ax0_ax1_fused_4 * 3584) + (ax2_4 * 64)) + ax3_2),
                     T.max(
                         T.load("uint8", tensor_2,
                                (((ax0_ax1_fused_4 * 3584) +
                                  (ax2_4 * 64)) + ax3_2)),
                         T.if_then_else(
                             ((((ax0_ax1_fused_4 * 2) +
                                T.floordiv(rv0_rv1_fused_1, 3)) < 112) and
                              (((ax2_4 * 2) +
                                T.floormod(rv0_rv1_fused_1, 3)) < 112)),
                             T.load("uint8", placeholder_29.data, (((
                                 ((ax0_ax1_fused_4 * 14336) +
                                  (T.floordiv(rv0_rv1_fused_1, 3) * 7168)) +
                                 (ax2_4 * 128)) + (T.floormod(
                                     rv0_rv1_fused_1, 3) * 64)) + ax3_2)),
                             T.uint8(0),
                             dtype="uint8")), True)
     # Second pass: copy the intermediate into the output, casting to int16.
     for ax0_ax1_fused_5 in T.serial(0, 56):
         for ax2_5, ax3_3 in T.grid(56, 64):
             T.store(
                 T_cast_7.data,
                 (((ax0_ax1_fused_5 * 3584) + (ax2_5 * 64)) + ax3_3),
                 T.cast(
                     T.load("uint8", tensor_2, (((ax0_ax1_fused_5 * 3584) +
                                                 (ax2_5 * 64)) + ax3_3)),
                     "int16"), True)
Beispiel #4
0
def elementwise_symbolic_fused(a: T.handle, b: T.handle, n: T.int32) -> None:
    # Computes B = A * 2.0 over a (128, 128, n) buffer with one flat loop;
    # the innermost extent n is a symbolic function parameter.
    A = T.match_buffer(a, (128, 128, n))
    B = T.match_buffer(b, (128, 128, n))
    # Total trip count: 128 * 128 * n = n * 16384.
    for i_j_k_fused in T.serial(0, (n * 16384)):
        with T.block("B"):
            # Split the flat counter back into indices; n * 128 is the
            # stride of the outermost axis and n the stride of the middle one.
            vi = T.axis.S(128, T.floordiv(i_j_k_fused, n * 128))
            vj = T.axis.S(128, T.floordiv(T.floormod(i_j_k_fused, n * 128), n))
            vk = T.axis.S(n, T.floormod(i_j_k_fused, n))
            T.reads([A[vi, vj, vk]])
            T.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def elementwise_fuse_with_opaque_block(a: T.handle, b: T.handle) -> None:
    # Computes B = A * 2.0 over a 128x128x128 buffer with one flat loop of
    # 2097152 = 128 * 128 * 128 iterations. The computation block "B" is
    # nested inside an outer block "opaque" that declares no block axes and
    # expresses its read/write regions directly in terms of the loop variable.
    B = T.match_buffer(b, [128, 128, 128])
    A = T.match_buffer(a, [128, 128, 128])
    for i_j_k_fused in T.serial(0, 2097152):
        with T.block("opaque"):
            # Region indices are derived from the flat counter by repeated
            # floordiv / floormod by 128.
            T.reads([
                A[T.floormod(T.floordiv(T.floordiv(i_j_k_fused, 128), 128), 128
                             ),
                  T.floormod(T.floordiv(i_j_k_fused, 128), 128),
                  T.floormod(i_j_k_fused, 128), ]
            ])
            T.writes([
                B[T.floormod(T.floordiv(T.floordiv(i_j_k_fused, 128), 128), 128
                             ),
                  T.floormod(T.floordiv(i_j_k_fused, 128), 128),
                  T.floormod(i_j_k_fused, 128), ]
            ])
            with T.block("B"):
                # 16384 = 128 * 128 is the stride of the outermost axis.
                vi = T.axis.S(128, T.floordiv(i_j_k_fused, 16384))
                vj = T.axis.S(128, T.floormod(T.floordiv(i_j_k_fused, 128),
                                              128))
                vk = T.axis.S(128, T.floormod(i_j_k_fused, 128))
                T.reads([A[vi, vj, vk]])
                T.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def elementwise_not_affine_fused(a: T.handle, b: T.handle) -> None:
    # Copies A into B for a [127, 128] buffer. Rows are handled in 4 chunks
    # of up to 32 rows (outer loop i); within each chunk the row and column
    # loops are combined into the single loop j_k_fused.
    A = T.match_buffer(a, [127, 128])
    B = T.match_buffer(b, [127, 128])
    for i in T.grid(4):
        # Rows in this chunk: T.min(31, 126 - i * 32) + 1. That is 32 for
        # i = 0..2 and 31 for i = 3 (126 - 96 = 30), which keeps vi below 127.
        # The extent below is that row count multiplied by 128 columns.
        for j_k_fused in T.serial(0, T.min(31, 126 - i * 32) * 128 + 128):
            with T.block("B"):
                vi = T.axis.S(
                    127,
                    i * 32 + T.floormod(T.floordiv(j_k_fused, 128), T.min(31, 126 - i * 32) + 1),
                )
                vj = T.axis.S(128, T.floormod(j_k_fused, 128))
                T.reads([A[vi, vj]])
                T.writes([B[vi, vj]])
                B[vi, vj] = A[vi, vj]
Beispiel #7
0
def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None:
    # For each of the 16 leading indices b, computes
    # D[b] = sqrt(sum over i, j of A[b, i, j] ** 2) in three stages:
    # partial sums into C_rf, a second reduction from C_rf into C, and an
    # elementwise square root from C into D.
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])
    C_rf = T.alloc_buffer([1, 16])

    # Stage 1: the 256x256 reduction domain is walked by a single loop of
    # 65536 = 256 * 256 iterations; the extent-1 loop i1_i2_fused_inner
    # provides the (spatial) leading index of C_rf.
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
        with T.block("C_rf"):
            vi1_i2_fused_inner, b = T.axis.remap("SS", [i1_i2_fused_inner, i0])
            i = T.axis.R(256, T.floordiv(i1_i2_fused_outer, 256))
            j = T.axis.R(256, T.floormod(i1_i2_fused_outer, 256))
            with T.init():
                C_rf[vi1_i2_fused_inner, b] = 0.0
            C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j])

    # Stage 2: reduce over the leading axis of C_rf (extent 1) into C.
    for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1):
        with T.block("C"):
            vi1_i2_fused_inner_1, b_1 = T.axis.remap("RS", [i1_i2_fused_inner_1, i0_1])
            with T.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]

    # Stage 3: elementwise square root.
    for i0_2 in T.serial(0, 16):
        with T.block("D"):
            b_2 = T.axis.S(16, i0_2)
            D[b_2] = T.sqrt(C[b_2], dtype="float32")
Beispiel #8
0
def opaque_access_fused(a: T.handle, b: T.handle) -> None:
    # Two blocks over 16x16 buffers, each driven by one flat loop of
    # 256 = 16 * 16 iterations. Both blocks declare the whole buffer as
    # their write region and an empty read region.
    A = T.match_buffer(a, [16, 16])
    B = T.match_buffer(b, [16, 16])
    for i_j_fused in T.serial(0, 256):
        with T.block("A"):
            # Row / column recovered from the flat counter.
            vi = T.axis.S(16, T.floordiv(i_j_fused, 16))
            vj = T.axis.S(16, T.floormod(i_j_fused, 16))
            T.reads([])
            T.writes([A[0:16, 0:16]])
            A[vi, vj] = 1
    for i_j_fused in T.serial(0, 256):
        with T.block("B"):
            vi = T.axis.S(16, T.floordiv(i_j_fused, 16))
            vj = T.axis.S(16, T.floormod(i_j_fused, 16))
            T.reads([])
            T.writes([B[0:16, 0:16]])
            # B is only touched through its data handle inside this call;
            # the flat position (vi * 16 + vj) is passed as an argument.
            T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle"))
def rowsum_transformed(a: T.handle, b: T.handle) -> None:
    # Row sum B[vi] = sum over vk of A[vi, vk] for a 128x128 input, with
    # the row axis and the reduction axis each split by 4 and the two
    # middle pieces combined into the single loop ii_ko_fused.
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, ))

    for io, ii_ko_fused, ki in T.grid(32, 128, 4):
        with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
            # ii_ko_fused ranges over 4 * 32 values: floordiv(., 32) is the
            # inner row part (0..3), floormod(., 32) the outer reduction part.
            T.bind(vi, io * 4 + T.floordiv(ii_ko_fused, 32))
            T.bind(vk, T.floormod(ii_ko_fused, 32) * 4 + ki)
            with T.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
Beispiel #10
0
def rowsum_transformed(a: T.handle, b: T.handle) -> None:
    # Row sum B[vi] = sum over vk of A[vi, vk] for a 128x128 input, with
    # the row axis and the reduction axis each split by 4 and the two
    # middle pieces combined into the single loop ii_ko_fused.
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128,))

    for io, ii_ko_fused, ki in T.grid(32, 128, 4):
        with T.block("B"):
            # ii_ko_fused ranges over 4 * 32 values: floordiv(., 32) is the
            # inner row part (0..3), floormod(., 32) the outer reduction part.
            vi = T.axis.S(128, io * 4 + T.floordiv(ii_ko_fused, 32))
            vk = T.axis.R(128, T.floormod(ii_ko_fused, 32) * 4 + ki)
            with T.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
Beispiel #11
0
 def main(placeholder1: T.Buffer[(100, ), "int8"],
          placeholder2: T.Buffer[(100, ), "int8"]) -> None:
     # Copies 6 groups of 16 int8 values from placeholder2 into placeholder1.
     # Destination group i3 is read from source group (i3 + 4) mod 6, i.e.
     # the groups are rotated. Only the first 96 elements of each 100-element
     # buffer are touched.
     T.attr("i0", "pragma_layout", "NHCWB16")
     # i0, i1 and i2 have extent 1 and do not appear in the index expressions.
     for i0 in T.serial(0, 1):
         for i1 in T.serial(0, 1):
             for i2 in T.serial(0, 1):
                 for i3 in T.serial(0, 6):
                     for i4 in T.serial(0, 16):
                         placeholder1[((i3 * 16) +
                                       i4)] = placeholder2[((T.floormod(
                                           (i3 + 4), 6) * 16) + i4)]
Beispiel #12
0
 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(
         placeholder_10: T.handle, placeholder_11: T.handle,
         placeholder_12: T.handle, T_cast_4: T.handle) -> None:
     # 3x3 convolution with 64 input and 64 output channels over flattened
     # buffers: zero-pad the input, accumulate products in int32, add a
     # per-channel int32 term, apply T.q_multiply_shift, clamp to [0, 255],
     # then cast to uint8 and finally to int16.
     # function attr dict
     T.func_attr({
         "global_symbol":
         "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1",
         "tir.noalias": True
     })
     # Flattened sizes: 360000 = 75 * 75 * 64, 36864 = 3 * 3 * 64 * 64.
     placeholder_13 = T.match_buffer(placeholder_10, [360000],
                                     dtype="int16")
     placeholder_14 = T.match_buffer(placeholder_11, [36864], dtype="int16")
     placeholder_15 = T.match_buffer(placeholder_12, [64], dtype="int32")
     T_cast_5 = T.match_buffer(T_cast_4, [360000], dtype="int16")
     # body
     # Padded input: 77 * 77 * 64 = 379456 elements (one extra row/column on
     # each side); padded row stride 4928 = 77 * 64.
     PaddedInput_1 = T.allocate([379456], "int16", "global")
     for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64):
         # Interior positions (1..75 on both axes) copy from the input with
         # row stride 4800 = 75 * 64; the offset 4864 = 4800 + 64 undoes the
         # one-row, one-column shift. Border positions are set to 0.
         PaddedInput_1[i0_i1_fused_1 * 4928 + i2_1 * 64 +
                       i3_1] = T.if_then_else(
                           1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76
                           and 1 <= i2_1 and i2_1 < 76,
                           placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 +
                                          i3_1 - 4864],
                           T.int16(0),
                           dtype="int16")
     # One iteration per output pixel: 5625 = 75 * 75.
     for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625):
         # Per-pixel int32 accumulator, one entry per output channel.
         Conv2dOutput_1 = T.allocate([64], "int32", "global")
         for ff_1 in T.serial(0, 64):
             Conv2dOutput_1[ff_1] = 0
             # Reduce over the 3x3 window (ry, rx) and 64 input channels.
             # Weight strides: 12288 = 3 * 64 * 64, 4096 = 64 * 64.
             for ry, rx, rc_1 in T.grid(3, 3, 64):
                 Conv2dOutput_1[ff_1] = Conv2dOutput_1[ff_1] + T.cast(
                     PaddedInput_1[
                         T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 +
                         ry * 4928 + rx * 64 +
                         T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 +
                         rc_1], "int32") * T.cast(
                             placeholder_14[ry * 12288 + rx * 4096 +
                                            rc_1 * 64 + ff_1], "int32")
         # Post-processing per output channel, then store as int16.
         for ax3_inner_2 in T.serial(0, 64):
             T_cast_5[ax0_ax1_fused_ax2_fused_1 * 64 +
                      ax3_inner_2] = T.cast(
                          T.cast(
                              T.max(
                                  T.min(
                                      T.q_multiply_shift(
                                          Conv2dOutput_1[ax3_inner_2] +
                                          placeholder_15[ax3_inner_2],
                                          1608879842,
                                          31,
                                          -7,
                                          dtype="int32"), 255), 0),
                              "uint8"), "int16")
Beispiel #13
0
 def main(placeholder1: T.Buffer[(100, ), "int8"],
          placeholder2: T.Buffer[(100, ), "int8"]) -> None:
     # Copies int8 values from placeholder2 into placeholder1 over a
     # 5 x 6 x 4 iteration space (i1, i2, i3). The destination is addressed
     # contiguously (strides 24 and 4); the source row is remapped through
     # floordiv / floormod of i1 plus a constant offset of 96.
     T.attr("i0", "pragma_layout", "NHWC")
     # i0 has extent 1 and does not appear in the index expressions.
     for i0 in T.serial(0, 1):
         for i1 in T.serial(0, 5):
             for i2 in T.serial(0, 6):
                 for i3 in T.serial(0, 4):
                     placeholder1[(
                         ((i1 * 24) + (i2 * 4)) +
                         i3)] = placeholder2[(((((T.floordiv(
                             (i1 - 1), 2) * 48) + (T.floormod(
                                 (i1 + 1), 2) * 24)) + (i2 * 4)) + i3) +
                                              96)]
Beispiel #14
0
def transformed_square_sum_square_root_factor_one_2(a: T.handle, d: T.handle) -> None:
    # For each of the 16 leading indices b, computes
    # D[b] = sqrt(sum over i, j of A[b, i, j] ** 2).
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])

    # The 256x256 reduction domain is walked by i1_i2_fused_inner
    # (65536 = 256 * 256 iterations); i1_i2_fused_outer has extent 1 and is
    # not referenced inside the block.
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 1, 65536):
        with T.block("C"):
            b = T.axis.S(16, i0)
            i = T.axis.R(256, T.floordiv(i1_i2_fused_inner, 256))
            j = T.axis.R(256, T.floormod(i1_i2_fused_inner, 256))
            with T.init():
                C[b] = 0.0
            C[b] = C[b] + (A[b, i, j] * A[b, i, j])
    # Elementwise square root of the accumulated sums.
    for i0_1 in T.serial(0, 16):
        with T.block("D"):
            b_1 = T.axis.S(16, i0_1)
            D[b_1] = T.sqrt(C[b_1], dtype="float32")
def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None:
    # For each of the 16 leading indices b, computes
    # D[b] = sqrt(sum over i, j of A[b, i, j] ** 2).
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])

    # The 256x256 reduction domain is walked by i1_i2_fused_outer
    # (65536 = 256 * 256 iterations); i1_i2_fused_inner has extent 1 and is
    # not referenced inside the block.
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
        with T.block([16, T.reduce_axis(0, 256),
                      T.reduce_axis(0, 256)], "C") as [b, i, j]:
            T.bind(b, i0)
            T.bind(i, T.floordiv(i1_i2_fused_outer, 256))
            T.bind(j, T.floormod(i1_i2_fused_outer, 256))
            # C[b] is listed as both read and written because the update
            # accumulates into it.
            T.reads([C[b], A[b, i, j]])
            T.writes([C[b]])
            with T.init():
                C[b] = 0.0
            C[b] = C[b] + (A[b, i, j] * A[b, i, j])
    # Elementwise square root of the accumulated sums.
    for i0_1 in T.serial(0, 16):
        with T.block([16], "D") as [b_1]:
            T.bind(b_1, i0_1)
            T.reads([C[b_1]])
            T.writes([D[b_1]])
            D[b_1] = T.sqrt(C[b_1], dtype="float32")
def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None:
    # For each of the 16 leading indices b, computes
    # D[b] = sqrt(sum over i, j of A[b, i, j] ** 2) in three stages:
    # partial sums into C_rf, a second reduction from C_rf into C, and an
    # elementwise square root from C into D.
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])
    C_rf = T.alloc_buffer([1, 16])

    # Stage 1: the 256x256 reduction domain is walked by a single loop of
    # 65536 = 256 * 256 iterations; the extent-1 loop i1_i2_fused_inner
    # provides the (spatial) leading index of C_rf.
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
        with T.block(
            [1, 16, T.reduce_axis(0, 256),
             T.reduce_axis(0, 256)], "C_rf") as [
                 vi1_i2_fused_inner,
                 b,
                 i,
                 j,
             ]:
            T.bind(vi1_i2_fused_inner, i1_i2_fused_inner)
            T.bind(b, i0)
            T.bind(i, T.floordiv(i1_i2_fused_outer, 256))
            T.bind(j, T.floormod(i1_i2_fused_outer, 256))
            with T.init():
                C_rf[vi1_i2_fused_inner, b] = 0.0
            C_rf[vi1_i2_fused_inner,
                 b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j])

    # Stage 2: reduce over the leading axis of C_rf (extent 1) into C.
    for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1):
        with T.block([T.reduce_axis(0, 1), 16],
                     "C") as [vi1_i2_fused_inner_1, b_1]:
            T.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1)
            T.bind(b_1, i0_1)
            with T.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]

    # Stage 3: elementwise square root.
    for i0_2 in T.serial(0, 16):
        with T.block([16], "D") as [b_2]:
            T.bind(b_2, i0_2)
            D[b_2] = T.sqrt(C[b_2], dtype="float32")
Beispiel #17
0
 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(
         placeholder_62: T.handle, placeholder_63: T.handle,
         placeholder_64: T.handle, T_cast_20: T.handle) -> None:
     # 7x7 convolution with stride 2, 3 input channels and 64 output
     # channels, written with flattened T.load / T.store accesses: zero-pad
     # the input, accumulate products in int32, add a per-channel int32 term,
     # apply T.q_multiply_shift, clamp to [0, 255] and cast to uint8.
     # function attr dict
     T.func_attr({
         "global_symbol":
         "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast",
         "tir.noalias": True
     })
     placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3],
                                     dtype="int16",
                                     elem_offset=0,
                                     align=128,
                                     offset_factor=1)
     placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64],
                                     dtype="int16",
                                     elem_offset=0,
                                     align=128,
                                     offset_factor=1)
     placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64],
                                     dtype="int32",
                                     elem_offset=0,
                                     align=128,
                                     offset_factor=1)
     T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64],
                                dtype="uint8",
                                elem_offset=0,
                                align=128,
                                offset_factor=1)
     # body
     # Padded input: 229 * 229 * 3 = 157323 elements; padded row stride
     # 687 = 229 * 3.
     PaddedInput_7 = T.allocate([157323], "int16", "global")
     for i0_i1_fused_7 in T.serial(0, 229):
         for i2_7, i3_7 in T.grid(229, 3):
             # Positions 2..225 on both axes copy from the input (row stride
             # 672 = 224 * 3); the offset 1350 = 2 * 672 + 2 * 3 undoes the
             # two-row, two-column shift. All other positions are set to 0.
             T.store(
                 PaddedInput_7,
                 (((i0_i1_fused_7 * 687) + (i2_7 * 3)) + i3_7),
                 T.if_then_else(
                     ((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and
                       (2 <= i2_7)) and (i2_7 < 226)),
                     T.load("int16", placeholder_65.data,
                            ((((i0_i1_fused_7 * 672) +
                               (i2_7 * 3)) + i3_7) - 1350)),
                     T.int16(0),
                     dtype="int16"), True)
     # One iteration per output pixel: 12544 = 112 * 112.
     for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544):
         # Per-pixel int32 accumulator, one entry per output channel.
         Conv2dOutput_7 = T.allocate([64], "int32", "global")
         for ff_3 in T.serial(0, 64):
             T.store(Conv2dOutput_7, ff_3, 0, True)
             # Reduce over the 7x7 window and 3 input channels. Stride 2
             # shows up as 1374 = 2 * 687 per output row and 6 = 2 * 3 per
             # output column. Weight strides: 1344 = 7 * 3 * 64, 192 = 3 * 64.
             for ry_2, rx_2, rc_7 in T.grid(7, 7, 3):
                 T.store(
                     Conv2dOutput_7, ff_3,
                     (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(
                         T.load("int16", PaddedInput_7, ((
                             (((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112) *
                                1374) + (ry_2 * 687)) +
                              (T.floormod(ax0_ax1_fused_ax2_fused_7, 112) *
                               6)) +
                             (rx_2 * 3)) + rc_7)), "int32") * T.cast(
                                 T.load("int16", placeholder_66.data,
                                        ((((ry_2 * 1344) + (rx_2 * 192)) +
                                          (rc_7 * 64)) + ff_3)), "int32"))),
                     True)
         # Post-processing per output channel, then store as uint8.
         for ax3_inner_7 in T.serial(0, 64):
             T.store(
                 T_cast_21.data,
                 ((ax0_ax1_fused_ax2_fused_7 * 64) + ax3_inner_7),
                 T.cast(
                     T.max(
                         T.min(
                             T.q_multiply_shift(
                                 (T.load("int32", Conv2dOutput_7,
                                         ax3_inner_7) +
                                  T.load("int32", placeholder_67.data,
                                         ax3_inner_7)),
                                 1939887962,
                                 31,
                                 -9,
                                 dtype="int32"), 255), 0), "uint8"), True)
Beispiel #18
0
def conv2d_winograd_cuda(  # type: ignore
    placeholder: T.Buffer[(1, 14, 14, 128), "float32"],  # type: ignore
    placeholder_1: T.Buffer[(6, 6, 128, 128), "float32"],  # type: ignore
    conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"],  # type: ignore
) -> None:
    # type: ignore
    # Multi-stage computation from a (1, 14, 14, 128) input and a
    # (6, 6, 128, 128) second operand to a (1, 12, 12, 128) output:
    #   data_pad   -> zero-extend the input to 16x16
    #   input_tile -> gather nine 6x6 tiles (3x3 grid, step 4)
    #   B          -> constant 6x6 coefficient matrix
    #   data_pack  -> transform each tile with B on both tile axes
    #   bgemm      -> per-(eps, nu) reduction over the 128 input channels
    #   A          -> constant 6x4 coefficient matrix
    #   inverse    -> transform each 6x6 result with A down to 4x4
    #   conv2d_winograd -> scatter the 4x4 tiles into the 12x12 output
    # Intermediate buffers for the stages listed above.
    data_pad = T.alloc_buffer([1, 16, 16, 128])
    input_tile = T.alloc_buffer([6, 6, 9, 128])
    B = T.alloc_buffer([6, 6])
    data_pack = T.alloc_buffer([6, 6, 9, 128])
    bgemm = T.alloc_buffer([6, 6, 9, 128])
    A = T.alloc_buffer([6, 4])
    inverse = T.alloc_buffer([4, 4, 9, 128])
    # Stage data_pad: copy the 14x14 input and write 0 for rows/columns 14, 15.
    for i0, i1, i2, i3 in T.grid(1, 16, 16, 128):
        with T.block("data_pad"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.block_attr({"schedule_rule": "None"})
            T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]])
            T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]])
            data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14,  # type: ignore
                placeholder[i0_1, i1_1, i2_1, i3_1],
                T.float32(0),
                dtype="float32",
            )
    # Stage input_tile: tile p (0..8) starts at row floordiv(p, 3) * 4 and
    # column floormod(p, 3) * 4 of data_pad; (eps, nu) index inside the tile.
    for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128):
        with T.block("input_tile"):
            eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2])
            T.block_attr({"schedule_rule": "None"})
            T.reads(
                [
                    data_pad[
                        T.floordiv(p, 9),  # type: ignore
                        ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                        ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                        ci,
                    ]
                ]
            )
            T.writes([input_tile[eps, nu, p, ci]])
            input_tile[eps, nu, p, ci] = data_pad[
                T.floordiv(p, 9),  # type: ignore
                ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                ci,
            ]
    # Stage B: fill the constant 6x6 matrix. The nested T.Select chain is a
    # lookup table keyed on (i mod 6, j mod 6), one literal per entry.
    for i0_3, i1_3 in T.grid(6, 6):
        with T.block("B"):
            i, j = T.axis.remap("SS", [i0_3, i1_3])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([B[i, j]])
            # fmt: off
            B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), 
T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0)))))))))))))))))))))))))))))))))))))  # type: ignore
            # fmt: on
    # Stage data_pack: for every tile and channel, reduce over (r_a, r_b):
    # data_pack[eps, nu] += input_tile[r_a, r_b] * B[r_a, eps] * B[r_b, nu].
    for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6):
        with T.block("data_pack"):
            eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap(
                "SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cuda"})
            # The declared B region is a min/max bounding box that covers
            # both accesses in the update, B[r_a, eps_1] and B[r_b, nu_1].
            T.reads(
                [
                    data_pack[eps_1, nu_1, p_1, ci_1],
                    input_tile[r_a, r_b, p_1, ci_1],
                    B[
                        T.min(r_a, r_b) : (  # type: ignore
                            T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b))  # type: ignore
                        ),
                        T.min(eps_1, nu_1) : (  # type: ignore
                            T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1))  # type: ignore
                        ),
                    ],
                ]
            )
            T.writes([data_pack[eps_1, nu_1, p_1, ci_1]])
            with T.init():
                data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0)
            data_pack[eps_1, nu_1, p_1, ci_1] = data_pack[eps_1, nu_1, p_1, ci_1] + (
                (input_tile[r_a, r_b, p_1, ci_1] * B[r_a, eps_1]) * B[r_b, nu_1]
            )
    # Stage bgemm: for each (eps, nu, p, co), sum data_pack * placeholder_1
    # over the 128 values of ci.
    for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128):
        with T.block("bgemm"):
            eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1])
            T.block_attr({"meta_schedule.write_cache_level": [3]})
            T.reads(
                [
                    bgemm[eps_2, nu_2, p_2, co],
                    data_pack[eps_2, nu_2, p_2, ci_2],
                    placeholder_1[eps_2, nu_2, co, ci_2],
                ]
            )
            T.writes([bgemm[eps_2, nu_2, p_2, co]])
            with T.init():
                bgemm[eps_2, nu_2, p_2, co] = T.float32(0)
            bgemm[eps_2, nu_2, p_2, co] = bgemm[eps_2, nu_2, p_2, co] + (
                data_pack[eps_2, nu_2, p_2, ci_2] * placeholder_1[eps_2, nu_2, co, ci_2]
            )
    # Stage A: fill the constant 6x4 matrix, again as a T.Select lookup table
    # keyed on (i_1 mod 6, j_1 mod 4).
    for i0_6, i1_6 in T.grid(6, 4):
        with T.block("A"):
            i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([A[i_1, j_1]])
            # fmt: off
            A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), 
T.float32(1), T.float32(0)))))))))))))))))))))))))  # type: ignore
            # fmt: on
    # Stage inverse: for every tile and output channel, reduce over
    # (r_a_1, r_b_1):
    # inverse[vh, vw] += bgemm[r_a_1, r_b_1] * A[r_a_1, vh] * A[r_b_1, vw].
    for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6):
        with T.block("inverse"):
            vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap(
                "SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse"})
            # The declared A region is a min/max bounding box that covers
            # both accesses in the update, A[r_a_1, vh] and A[r_b_1, vw].
            T.reads(
                [
                    inverse[vh, vw, p_3, co_1],
                    bgemm[r_a_1, r_b_1, p_3, co_1],
                    A[
                        T.min(r_a_1, r_b_1) : (  # type: ignore
                            T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1))  # type: ignore
                        ),
                        T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))),  # type: ignore
                    ],
                ]
            )
            T.writes([inverse[vh, vw, p_3, co_1]])
            with T.init():
                inverse[vh, vw, p_3, co_1] = T.float32(0)
            inverse[vh, vw, p_3, co_1] = inverse[vh, vw, p_3, co_1] + (
                (bgemm[r_a_1, r_b_1, p_3, co_1] * A[r_a_1, vh]) * A[r_b_1, vw]
            )
    # Final stage: output pixel (h, w) comes from position (h mod 4, w mod 4)
    # of tile n * 9 + floordiv(h, 4) * 3 + floordiv(w, 4).
    for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128):
        with T.block("conv2d_winograd"):
            n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6])
            T.reads(
                [
                    inverse[
                        T.floormod(h, 4),  # type: ignore
                        T.floormod(w, 4),  # type: ignore
                        (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                        co_2,
                    ]
                ]
            )
            T.writes([conv2d_winograd[n, h, w, co_2]])
            conv2d_winograd[n, h, w, co_2] = inverse[
                T.floormod(h, 4),  # type: ignore
                T.floormod(w, 4),  # type: ignore
                (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                co_2,
            ]