def main(a: T.handle, b: T.handle) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "T.noalias": True})
     # var definition
     threadIdx_x = T.env_thread("threadIdx.x")
     threadIdx_y = T.env_thread("threadIdx.y")
     blockIdx_x = T.env_thread("blockIdx.x")
     blockIdx_y = T.env_thread("blockIdx.y")
     blockIdx_z = T.env_thread("blockIdx.z")
     A = T.match_buffer(a, [14 * 14 * 256 * 256], dtype="float32")
     B = T.match_buffer(b, [14 * 14 * 512 * 256], dtype="float32")
     # body
     T.launch_thread(blockIdx_z, 196)
     B_local = T.allocate([64], "float32", "local")
     Apad_shared = T.allocate([512000], "float32", "shared")
     Apad_shared_local = T.allocate([8], "float32", "local")
     T.launch_thread(blockIdx_y, 8)
     T.launch_thread(blockIdx_x, 4)
     T.launch_thread(threadIdx_y, 8)
     T.launch_thread(threadIdx_x, 8)
     for ff_c_init, nn_c_init in T.grid(8, 8):
         B_local[ff_c_init * 8 + nn_c_init] = T.float32(0)
     for rc_outer, ry, rx in T.grid(32, 3, 3):
         for ax3_inner_outer in T.serial(0, 2):
             Apad_shared[T.ramp(
                 threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4,
                 1, 4)] = T.if_then_else(
                     1 <= blockIdx_z // 14 + ry
                     and blockIdx_z // 14 + ry < 15
                     and 1 <= rx + blockIdx_z % 14
                     and rx + blockIdx_z % 14 < 15,
                     A[T.ramp(
                         ry * 917504 + blockIdx_z * 65536 + rx * 65536 +
                         rc_outer * 2048 + threadIdx_y * 256 +
                         blockIdx_x * 64 + threadIdx_x * 8 +
                         ax3_inner_outer * 4 - 983040, 1, 4)],
                     T.broadcast(T.float32(0), 4),
                     dtype="float32x4",
                 )
             # Access of the last element of Apad_shared prevents
             # buffer compacting from reducing the amount of shared
             # memory used.
             Apad_shared[512000 - 1] = 0.0
         for rc_inner in T.serial(0, 8):
             for ax3 in T.serial(0, 8):
                 Apad_shared_local[ax3] = Apad_shared[rc_inner * 64 +
                                                      threadIdx_x * 8 + ax3]
             for ff_c, nn_c in T.grid(8, 8):
                 B_local[ff_c * 8 +
                         nn_c] = B_local[ff_c * 8 +
                                         nn_c] + Apad_shared_local[nn_c]
     for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
         B[blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 +
           ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 +
           nn_inner_inner_inner] = B_local[ff_inner_inner_inner * 8 +
                                            nn_inner_inner_inner]
 def main(a: T.handle, b: T.handle) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "T.noalias": True})
     # var definition
     threadIdx_x = T.env_thread("threadIdx.x")
     threadIdx_y = T.env_thread("threadIdx.y")
     blockIdx_x = T.env_thread("blockIdx.x")
     blockIdx_y = T.env_thread("blockIdx.y")
     blockIdx_z = T.env_thread("blockIdx.z")
     A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32")
     B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32")
     # body
     T.launch_thread(blockIdx_z, 196)
     B_local = T.allocate([6400000], "float32", "local")
     Apad_shared = T.allocate([512], "float32", "shared")
     Apad_shared_local = T.allocate([8], "float32", "local")
     T.launch_thread(blockIdx_y, 8)
     T.launch_thread(blockIdx_x, 4)
     T.launch_thread(threadIdx_y, 8)
     T.launch_thread(threadIdx_x, 8)
     for ff_c_init, nn_c_init in T.grid(8, 8):
         T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True)
     for rc_outer, ry, rx in T.grid(32, 3, 3):
         for ax3_inner_outer in T.serial(0, 2):
             T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4))
         for rc_inner in T.serial(0, 8):
             for ax3 in T.serial(0, 8):
                 T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True)
             for ff_c, nn_c in T.grid(8, 8):
                 T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True)
     for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
         T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)
 def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # var definition
     threadIdx_x = T.env_thread("threadIdx.x")
     blockIdx_x = T.env_thread("blockIdx.x")
     T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data)
     T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data)
     T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data)
     # body
     T.launch_thread(blockIdx_x, 64)
     conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local")
     PadInput_shared = T.allocate([768], "float32", "shared")
     weight_shared = T.allocate([4096], "float32", "shared")
     T.launch_thread(threadIdx_x, 32)
     for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2):
         conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0)
     for i6_0 in T.serial(16):
         for ax0_ax1_ax2_ax3_fused_0 in T.serial(24):
             PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32")
         for ax0_ax1_ax2_ax3_fused_0 in T.serial(32):
             weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + ax0_ax1_ax2_ax3_fused_0 % 2 * 4096 + threadIdx_x // 2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)]
         for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2):
             conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024]
     for ax1, ax2 in T.grid(2, 4):
         conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2]
Example #4
def dot_product_4x4_i8i8i32_neon(
        A: T.Buffer((4, ), "int8", offset_factor=1),
        B: T.Buffer((4, 4), "int8", offset_factor=1),
        C: T.Buffer((4, ), "int32", offset_factor=1),
) -> None:
    with T.block("root"):
        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
        T.writes(C[0:4])

        A_int8 = A.vload([0], "int8x4")
        re_int32 = T.reinterpret(A_int8, dtype="int32")
        vec_ai32 = T.broadcast(re_int32, 2)
        vec_a = T.reinterpret(vec_ai32, dtype="int8x8")

        vec_b = B.vload([0, 0], dtype="int8x16")

        # TODO(masahi): Remove duplication when inlined function call is supported
        vec_b_low = T.vectorlow(vec_b, dtype="int8x8")

        multiply_low = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
            T.uint32(2),
            vec_a,
            vec_b_low,
            dtype="int16x8",
        )

        pairwise_reduction_low = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
            T.uint32(1),
            multiply_low,
            dtype="int32x4",
        )

        vec_b_high = T.vectorhigh(vec_b, dtype="int8x8")

        multiply_high = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
            T.uint32(2),
            vec_a,
            vec_b_high,
            dtype="int16x8",
        )

        pairwise_reduction_high = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
            T.uint32(1),
            multiply_high,
            dtype="int32x4",
        )

        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
            T.uint32(2),
            pairwise_reduction_low,
            pairwise_reduction_high,
            dtype="int32x4",
        )
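
The NEON implementation above replaces a plain TIR description of the same 4x4 i8i8i32 dot product when a schedule is tensorized. For context, here is a minimal sketch of what such a descriptor typically looks like; the function name, imports, and loop structure are illustrative assumptions, not part of the example above.

from tvm.script import tir as T

@T.prim_func
def dot_product_4x4_i8i8i32_desc(
    A: T.Buffer((4,), "int8", offset_factor=1),
    B: T.Buffer((4, 4), "int8", offset_factor=1),
    C: T.Buffer((4,), "int32", offset_factor=1),
) -> None:
    # Plain-TIR statement of the update C[i] += sum_k A[k] * B[i, k],
    # i.e. the computation the hand-written NEON intrinsic above carries out.
    with T.block("root"):
        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
        T.writes(C[0:4])
        for i in T.serial(0, 4):
            for k in T.serial(0, 4):
                with T.block("update"):
                    vi, vk = T.axis.remap("SR", [i, k])
                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")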
Example #5
def dot_product_16x4_u8i8i32_vnni(
    A: T.Buffer((4,), "uint8", offset_factor=1),
    B: T.Buffer((16, 4), "int8", offset_factor=1),
    C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])

        A_u8x4 = A.vload([0], "uint8x4")
        A_i32 = T.reinterpret(A_u8x4, dtype="int32")

        B_i8x64 = B.vload([0, 0], dtype="int8x64")
        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")

        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this is an update +=
            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
            T.uint32(0),
            T.int32x16(0),
            T.broadcast(A_i32, 16),
            B_i32x16,
            dtype="int32x16",
        )
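
To make a VNNI implementation like the one above usable by the scheduler, it is paired with a plain-TIR descriptor of the same 16x4 uint8/int8 update (analogous to the 4x4 descriptor sketched under Example #4) and registered as a tensor intrinsic. The sketch below assumes TVM's tvm.tir.TensorIntrin.register API; the intrinsic name and the descriptor function are illustrative, and both functions must be decorated with @T.prim_func to be registered.

from tvm.tir import TensorIntrin

# "dot_16x4_u8i8i32_vnni" is an assumed name; dot_product_16x4_u8i8i32_desc is a
# hypothetical plain-TIR descriptor of the same update as the implementation above.
TensorIntrin.register(
    "dot_16x4_u8i8i32_vnni",
    dot_product_16x4_u8i8i32_desc,   # descriptor: what is computed
    dot_product_16x4_u8i8i32_vnni,   # implementation: how the hardware computes it
)

# A schedule can then replace a matching block with the intrinsic:
#     sch.tensorize(block, "dot_16x4_u8i8i32_vnni")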
Example #6
def dot_product_4x4_i8i8i32_sdot(
        A: T.Buffer((4, ), "int8", offset_factor=1),
        B: T.Buffer((4, 4), "int8", offset_factor=1),
        C: T.Buffer((4, ), "int32", offset_factor=1),
) -> None:
    with T.block("root"):
        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
        T.writes(C[0:4])

        A_i8x4 = A.vload([0], "int8x4")
        A_i32 = T.reinterpret(A_i8x4, dtype="int32")
        vec_ai32 = T.broadcast(A_i32, 4)
        vec_a = T.reinterpret(vec_ai32, dtype="int8x16")

        vec_b = B.vload([0, 0], dtype="int8x16")

        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
            T.uint32(3),
            T.int32x4(0),
            vec_a,
            vec_b,
            dtype="int32x4",
        )
 def main(
         inputs: T.Buffer[(1, 4, 4, 512),
                          "float32"], weight: T.Buffer[(4, 4, 512, 256),
                                                       "float32"],
         conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256),
                                         "float32"]) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # var definition
     threadIdx_x = T.env_thread("threadIdx.x")
     blockIdx_x = T.env_thread("blockIdx.x")
     # body
     T.launch_thread(blockIdx_x, 64)
     conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local")
     PadInput_shared = T.allocate([768], "float32", "shared")
     weight_shared = T.allocate([4096], "float32", "shared")
     T.launch_thread(threadIdx_x, 32)
     for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2):
         T.store(conv2d_transpose_nhwc_local,
                 i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0),
                 True)
     for i6_0 in T.serial(16):
         for ax0_ax1_ax2_ax3_fused_0 in T.serial(24):
             T.store(
                 PadInput_shared,
                 ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x,
                 T.if_then_else(
                     128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x
                     and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640
                     and 1 <= blockIdx_x // 32 * 2 +
                     (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 //
                     32 and blockIdx_x // 32 * 2 +
                     (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 //
                     32 < 5,
                     T.load(
                         "float32", inputs.data, blockIdx_x // 32 * 1024 +
                         ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 +
                         threadIdx_x - 2560),
                     T.float32(0),
                     dtype="float32"), True)
         for ax0_ax1_ax2_ax3_fused_0 in T.serial(32):
             T.store(
                 weight_shared,
                 T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1,
                        4),
                 T.load(
                     "float32x4", weight.data,
                     T.ramp(
                         (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4)
                         // 256 * 131072 + i6_0 * 8192 +
                         (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) %
                         256 // 8 * 256 + blockIdx_x % 32 * 8 +
                         threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)),
                 T.broadcast(True, 4))
         for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(
                 4, 2, 4, 4, 8, 2, 2):
             T.store(
                 conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4,
                 T.load("float32", conv2d_transpose_nhwc_local,
                        i1_4 * 4 + i2_3 * 2 + i2_4) +
                 T.if_then_else(
                     (i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0,
                     T.load(
                         "float32", PadInput_shared,
                         threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 +
                         (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 +
                         i6_2),
                     T.float32(0),
                     dtype="float32") * T.load(
                         "float32", weight_shared, i6_1 * 64 + i6_2 * 8 +
                         threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024),
                 True)
     for ax1, ax2 in T.grid(2, 4):
         T.store(
             conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 +
             ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 +
             blockIdx_x % 32 * 8 + threadIdx_x % 8,
             T.load("float32", conv2d_transpose_nhwc_local,
                    ax1 * 4 + ax2), True)
Example #8
 def main(
     X: T.Buffer[(128, 128), "int8"],
     W: T.Buffer[(128, 128), "int8"],
     compute: T.Buffer[(128, 128), "int32"],
 ) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     compute_local = T.alloc_buffer([128, 128],
                                    dtype="int32",
                                    scope="local")
     X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
     W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
     for i0_0_i1_0_fused in T.thread_binding(16, thread="blockIdx.x"):
         for i0_1_i1_1_fused in T.thread_binding(2, thread="vthread.x"):
             for i0_2_i1_2_fused in T.thread_binding(2,
                                                     thread="threadIdx.x"):
                 for i0_3_init, i1_3_init, i0_4_init in T.grid(4, 16, 4):
                     with T.block("compute_o_init"):
                         i = T.axis.spatial(
                             128, i0_0_i1_0_fused // 2 * 16 +
                             i0_3_init * 4 + i0_4_init)
                         j = T.axis.spatial(
                             128,
                             i0_0_i1_0_fused % 2 * 64 +
                             i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 +
                             i1_3_init,
                         )
                         T.reads()
                         T.writes(compute_local[i, j])
                         T.block_attr(
                             {"meta_schedule.auto_tensorize": "dp4a"})
                         with T.block("compute_init"):
                             T.reads()
                             T.writes(compute_local[i, j])
                             compute_local[i, j] = 0
                 for i2_0_0 in T.serial(2):
                     for ax0_ax1_fused in T.serial(1024):
                         with T.block("X_shared"):
                             v0 = T.axis.spatial(
                                 128, i0_0_i1_0_fused // 2 * 16 +
                                 ax0_ax1_fused // 64)
                             v1 = T.axis.spatial(
                                 128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                             T.reads(X[v0, v1])
                             T.writes(X_shared[v0, v1])
                             T.block_attr(
                                 {"meta_schedule.cooperative_fetch": 4})
                             X_shared[v0, v1] = X[v0, v1]
                     for ax0_ax1_fused in T.serial(4096):
                         with T.block("W_shared"):
                             v0 = T.axis.spatial(
                                 128, i0_0_i1_0_fused % 2 * 64 +
                                 ax0_ax1_fused // 64)
                             v1 = T.axis.spatial(
                                 128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                             T.reads(W[v0, v1])
                             T.writes(W_shared[v0, v1])
                             T.block_attr(
                                 {"meta_schedule.cooperative_fetch": 1})
                             W_shared[v0, v1] = W[v0, v1]
                     for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(
                             2, 4, 16, 8, 4, 1):
                         with T.block("compute_o_update"):
                             i = T.axis.spatial(
                                 128, i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 +
                                 i0_4)
                             j = T.axis.spatial(
                                 128,
                                 i0_0_i1_0_fused % 2 * 64 +
                                 i0_1_i1_1_fused * 32 +
                                 i0_2_i1_2_fused * 16 + i1_3,
                             )
                             k_o = T.axis.reduce(
                                 32, i2_0_0 * 16 + i2_0_1 * 8 + i2_0_2)
                             T.reads(
                                 compute_local[i, j],
                                 X_shared[i, k_o * 4:k_o * 4 + 4],
                                 W_shared[j, k_o * 4:k_o * 4 + 4],
                             )
                             T.writes(compute_local[i, j])
                             A = T.match_buffer(
                                 X_shared[i, k_o * 4:k_o * 4 + 4],
                                 [4],
                                 dtype="int8",
                                 scope="shared",
                                 align=4,
                                 offset_factor=1,
                             )
                             B = T.match_buffer(
                                 W_shared[j, k_o * 4:k_o * 4 + 4],
                                 [4],
                                 dtype="int8",
                                 scope="shared",
                                 align=4,
                                 offset_factor=1,
                             )
                             C = T.match_buffer(
                                 compute_local[i, j],
                                 [1],
                                 dtype="int32",
                                 scope="local",
                                 align=4,
                                 offset_factor=1,
                             )
                             C[0] = C[0] + T.call_pure_extern(
                                 "__dp4a",
                                 A[T.ramp(0, 1, 4)],
                                 B[T.ramp(0, 1, 4)],
                                 0,
                                 dtype="int32",
                             )
                 for ax0, ax1 in T.grid(16, 16):
                     with T.block("compute_local"):
                         v0 = T.axis.spatial(
                             128, i0_0_i1_0_fused // 2 * 16 + ax0)
                         v1 = T.axis.spatial(
                             128,
                             i0_0_i1_0_fused % 2 * 64 +
                             i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 +
                             ax1,
                         )
                         T.reads(compute_local[v0, v1])
                         T.writes(compute[v0, v1])
                         compute[v0, v1] = compute_local[v0, v1]
Example #9
 def main(
     placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
     placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
     conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
 ) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0 in T.grid(
             1, 1, 2, 1, 1, 1, 4, 1, 14, 1, 1, 1):
         for i1_2_init, i2_2_init, i2_3_init, i3_3_init in T.grid(
                 4, 7, 4, 4):
             with T.block("conv2d_NCHWc_int8_o_init"):
                 n = T.axis.spatial(1, 0)
                 oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2_init)
                 oh = T.axis.spatial(56,
                                     i2_0 * 28 + i2_2_init * 4 + i2_3_init)
                 ow = T.axis.spatial(56, i3_1 * 4 + i3_3_init)
                 oc_block_o = T.axis.spatial(1, 0)
                 T.reads()
                 T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
                 for i4_1 in T.vectorized(16):
                     with T.block("conv2d_NCHWc_int8_init"):
                         oc_block_init = T.axis.spatial(16, i4_1)
                         T.reads()
                         T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow,
                                                    oc_block_init])
                         conv2d_NCHWc_int8[n, oc_chunk, oh, ow,
                                           oc_block_init] = 0
         for (
                 i7_0,
                 i8_0,
                 i9_0_0,
                 i0_2,
                 i1_2,
                 i2_2,
                 i3_2,
                 i4_0_2,
                 i5_1,
                 i6_1,
                 i7_1,
                 i8_1,
                 i9_0_1,
                 i0_3,
                 i1_3,
                 i2_3,
                 i3_3,
                 i4_0_3,
         ) in T.grid(4, 1, 1, 1, 4, 7, 1, 1, 1, 1, 1, 4, 1, 1, 1, 4, 4, 1):
             with T.block("conv2d_NCHWc_int8_o_update"):
                 n = T.axis.spatial(1, 0)
                 oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2)
                 oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3)
                 ow = T.axis.spatial(56, i3_1 * 4 + i3_3)
                 oc_block_o = T.axis.spatial(1, 0)
                 kh = T.axis.reduce(1, 0)
                 kw = T.axis.reduce(1, 0)
                 ic_outer, ic_f_inner = T.axis.remap("RR", [i7_0, i8_1])
                 ic_s_inner_o = T.axis.reduce(1, 0)
                 T.reads(
                     conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16],
                     placeholder[n, ic_outer, oh + kh, ow + kw,
                                 ic_f_inner * 4:ic_f_inner * 4 + 4],
                     placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner,
                                   0:16, 0:4],
                 )
                 T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
                 A = T.match_buffer(
                     placeholder[n, ic_outer, oh + kh, ow + kw,
                                 ic_f_inner * 4:ic_f_inner * 4 + 4],
                     [4],
                     dtype="uint8",
                     offset_factor=1,
                 )
                 B = T.match_buffer(
                     placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner,
                                   0:16, 0:4],
                     [16, 4],
                     dtype="int8",
                     offset_factor=1,
                 )
                 C = T.match_buffer(
                     conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16],
                     [16],
                     dtype="int32",
                     offset_factor=1,
                 )
                 A_u8x4 = A.vload([0], "uint8x4")
                 A_i32 = T.reinterpret(A_u8x4, dtype="int32")
                 B_i8x64 = B.vload([0, 0], dtype="int8x64")
                 B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
                 C[T.ramp(
                     0, 1,
                     16)] = C[T.ramp(0, 1, 16)] + T.call_llvm_pure_intrin(
                         T.llvm_lookup_intrinsic_id(
                             "llvm.x86.avx512.vpdpbusd.512"),
                         T.uint32(0),
                         T.broadcast(0, 16),
                         T.broadcast(A_i32, 16),
                         B_i32x16,
                         dtype="int32x16",
                     )