Beispiel #1
0
# TensorIR implementation of a 4x4 int8 dot-product step for AArch64 NEON that
# does not rely on the SDOT instruction; it is built from smull/saddlp/addp.
#
# For each i in 0..3 it performs C[i] += sum_k A[k] * B[i, k] (k in 0..3), with
# the int8 products widened to int16 and the partial sums widened to int32.
#
# A: 4 int8 values shared by every output lane.
# B: 4x4 int8 matrix; row B[i, :] is paired with A to produce C[i]. No strides
#    are declared, so the single 16-byte load below assumes B is compact
#    row-major.
# C: 4 int32 accumulators, updated in place (+=).
#
# NOTE(review): the T.uint32(N) operand after each intrinsic id is presumably
# TVM's `num_signature` (how many leading operands select the overloaded LLVM
# intrinsic signature); confirm for the TVM version in use.
def dot_product_4x4_i8i8i32_neon(
        A: T.Buffer((4, ), "int8", offset_factor=1),
        B: T.Buffer((4, 4), "int8", offset_factor=1),
        C: T.Buffer((4, ), "int32", offset_factor=1),
) -> None:
    with T.block("root"):
        # Declare the buffer regions this block reads and writes.
        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
        T.writes(C[0:4])

        # Duplicate A's four bytes: load them as one int32, broadcast that to
        # two lanes, and view the result as int8x8 = [A0..A3, A0..A3].
        A_int8 = A.vload([0], "int8x4")
        re_int32 = T.reinterpret(A_int8, dtype="int32")
        vec_ai32 = T.broadcast(re_int32, 2)
        vec_a = T.reinterpret(vec_ai32, dtype="int8x8")

        # All 16 bytes of B in one vector: rows 0..3 back to back.
        vec_b = B.vload([0, 0], dtype="int8x16")

        # TODO(masahi): Remove duplication when inlined function call is supported
        # (the low and high halves below repeat the same smull + saddlp sequence).
        # Low half = first 8 lanes = B rows 0 and 1.
        vec_b_low = T.vectorlow(vec_b, dtype="int8x8")

        # Widening multiply int8x8 * int8x8 -> int16x8: lane j = vec_a[j] * vec_b_low[j].
        multiply_low = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
            T.uint32(2),
            vec_a,
            vec_b_low,
            dtype="int16x8",
        )

        # Pairwise add-long int16x8 -> int32x4: each lane sums two adjacent
        # products, i.e. half of one row's dot product (rows 0, 0, 1, 1).
        pairwise_reduction_low = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
            T.uint32(1),
            multiply_low,
            dtype="int32x4",
        )

        # High half = last 8 lanes = B rows 2 and 3.
        vec_b_high = T.vectorhigh(vec_b, dtype="int8x8")

        multiply_high = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
            T.uint32(2),
            vec_a,
            vec_b_high,
            dtype="int16x8",
        )

        # Same pairwise reduction for rows 2 and 3 (lanes: rows 2, 2, 3, 3).
        pairwise_reduction_high = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
            T.uint32(1),
            multiply_high,
            dtype="int32x4",
        )

        # addp adds adjacent lanes across the concatenation [low, high], giving
        # one full dot product per lane in row order 0..3; accumulate into C.
        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
            T.uint32(2),
            pairwise_reduction_low,
            pairwise_reduction_high,
            dtype="int32x4",
        )
Beispiel #2
0
# TensorIR implementation of a 4-element int8 dot product for AMD GPUs via the
# LLVM `llvm.amdgcn.sdot4` intrinsic.
#
# Computes C[0] += sum_k A[k] * B[k] for k in 0..3.
#
# A, B: 4 int8 values each in "shared" scope, declared with align=4 so each can
#       be loaded as one int8x4 and reinterpreted as a single int32.
# C:    one int32 accumulator in "local" scope, updated in place (+=).
def sdot4(
    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
    with T.block("root"):
        # Declare the buffer regions this block reads and writes.
        T.reads(C[0], A[0:4], B[0:4])
        T.writes(C[0])

        C[0] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.amdgcn.sdot4"),
            # NOTE(review): presumably TVM's `num_signature` operand count; confirm.
            T.uint32(4),
            # A and B each packed as four int8 lanes inside one 32-bit word.
            T.reinterpret(A.vload([0], "int8x4"), dtype="int32"),
            T.reinterpret(B.vload([0], "int8x4"), dtype="int32"),
            # The intrinsic's own accumulator operand is 0; accumulation into
            # C[0] happens through the outer `+=`.
            T.int32(0),
            # NOTE(review): presumably the intrinsic's clamp/saturate flag. With a
            # zero accumulator, four int8*int8 products sum to at most 65536 in
            # magnitude, so saturation should never trigger here.
            T.bool(1),
            dtype="int32",
        )
Beispiel #3
0
# TensorIR implementation of a 16x4 uint8 x int8 -> int32 dot-product step using
# AVX-512 VNNI (`llvm.x86.avx512.vpdpbusd.512`).
#
# For each i in 0..15 it performs C[i] += sum_k A[k] * B[i, k] (k in 0..3), with
# A unsigned and B signed.
#
# A: 4 uint8 values shared by all 16 output lanes.
# B: 16x4 int8 matrix; row B[i, :] is paired with A to produce C[i]. No strides
#    are declared, so the single 64-byte load below assumes B is compact
#    row-major.
# C: 16 int32 accumulators, updated in place (+=).
def dot_product_16x4_u8i8i32_vnni(
    A: T.Buffer((4,), "uint8", offset_factor=1),
    B: T.Buffer((16, 4), "int8", offset_factor=1),
    C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
    with T.block("root"):
        # Declare the buffer regions this block reads and writes.
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])

        # Pack A's four bytes into one int32 so it can be splatted to every lane.
        A_u8x4 = A.vload([0], "uint8x4")
        A_i32 = T.reinterpret(A_u8x4, dtype="int32")

        # Load all 64 bytes of B and view each 4-byte row as one int32 lane.
        B_i8x64 = B.vload([0, 0], dtype="int8x64")
        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")

        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this is an update +=
            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
            # NOTE(review): presumably TVM's `num_signature`; 0 because this
            # intrinsic is not overloaded. Confirm for the TVM version in use.
            T.uint32(0),
            # Zero accumulator operand; accumulation into C happens via `+=`.
            T.int32x16(0),
            # Unsigned-byte operand: A's 4 bytes repeated in all 16 lanes.
            T.broadcast(A_i32, 16),
            # Signed-byte operand: lane i holds row B[i, :].
            B_i32x16,
            dtype="int32x16",
        )
Beispiel #4
0
# TensorIR implementation of a 4x4 int8 dot-product step using the AArch64 SDOT
# instruction (`llvm.aarch64.neon.sdot.v4i32.v16i8`, dot-product extension).
#
# For each i in 0..3 it performs C[i] += sum_k A[k] * B[i, k] (k in 0..3).
#
# A: 4 int8 values shared by every output lane.
# B: 4x4 int8 matrix; row B[i, :] is paired with A to produce C[i]. No strides
#    are declared, so the single 16-byte load below assumes B is compact
#    row-major.
# C: 4 int32 accumulators, updated in place (+=).
def dot_product_4x4_i8i8i32_sdot(
        A: T.Buffer((4, ), "int8", offset_factor=1),
        B: T.Buffer((4, 4), "int8", offset_factor=1),
        C: T.Buffer((4, ), "int32", offset_factor=1),
) -> None:
    with T.block("root"):
        # Declare the buffer regions this block reads and writes.
        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
        T.writes(C[0:4])

        # Replicate A's four bytes into all four 32-bit lanes and view the result
        # as int8x16 = [A0..A3] x 4, so each group of 4 bytes lines up with one row of B.
        A_i8x4 = A.vload([0], "int8x4")
        A_i32 = T.reinterpret(A_i8x4, dtype="int32")
        vec_ai32 = T.broadcast(A_i32, 4)
        vec_a = T.reinterpret(vec_ai32, dtype="int8x16")

        # All 16 bytes of B in one vector: rows 0..3 back to back.
        vec_b = B.vload([0, 0], dtype="int8x16")

        # SDOT: int32 lane i = acc[i] + sum of the four int8 products in byte
        # group i, i.e. dot(A, B[i, :]); accumulate into C.
        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
            # NOTE(review): presumably TVM's `num_signature` operand count; confirm.
            T.uint32(3),
            # Zero accumulator operand; accumulation into C happens via `+=`.
            T.int32x4(0),
            vec_a,
            vec_b,
            dtype="int32x4",
        )
Beispiel #5
0
 # Scheduled TensorIR PrimFunc for an int8 NCHWc conv2d (uint8 data x int8
 # weights -> int32 output) whose inner reduction step is tensorized: the body of
 # the "conv2d_NCHWc_int8_o_update" block is the AVX-512 VNNI vpdpbusd
 # dot product over a 4-channel input slice and a 16x4 weight tile.
 #
 # Layouts, as fixed by the indexing below:
 #   placeholder       (data):    [n, ic_outer, h, w, ic_inner(16)], uint8
 #   placeholder_1     (weights): [oc_chunk, ic_outer, kh, kw, ic_f_inner(4),
 #                                 oc_block(16), ic_s_inner(4)], int8
 #   conv2d_NCHWc_int8 (output):  [n, oc_chunk, oh, ow, oc_block(16)], int32
 # kh and kw have extent 1 and the input is read at (oh + kh, ow + kw), so this
 # is a 1x1, stride-1 convolution.
 #
 # Buffer parameters use the call form `T.Buffer((shape), dtype)` like the rest
 # of this file; the subscript form `T.Buffer[(shape), dtype]` is deprecated in
 # TVMScript.
 def main(
     placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"),
     placeholder_1: T.Buffer((16, 4, 1, 1, 4, 16, 4), "int8"),
     conv2d_NCHWc_int8: T.Buffer((1, 16, 56, 56, 16), "int32"),
 ) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0 in T.grid(
             1, 1, 2, 1, 1, 1, 4, 1, 14, 1, 1, 1):
         # Zero-initialise the 16-lane output vector of every
         # (oc_chunk, oh, ow) point covered by this outer tile.
         for i1_2_init, i2_2_init, i2_3_init, i3_3_init in T.grid(
                 4, 7, 4, 4):
             with T.block("conv2d_NCHWc_int8_o_init"):
                 n = T.axis.spatial(1, 0)
                 oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2_init)
                 oh = T.axis.spatial(56,
                                     i2_0 * 28 + i2_2_init * 4 + i2_3_init)
                 ow = T.axis.spatial(56, i3_1 * 4 + i3_3_init)
                 oc_block_o = T.axis.spatial(1, 0)
                 T.reads()
                 T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
                 for i4_1 in T.vectorized(16):
                     with T.block("conv2d_NCHWc_int8_init"):
                         oc_block_init = T.axis.spatial(16, i4_1)
                         T.reads()
                         T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow,
                                                    oc_block_init])
                         conv2d_NCHWc_int8[n, oc_chunk, oh, ow,
                                           oc_block_init] = 0
         # Reduction: i7_0 walks ic_outer (4) and i8_1 walks ic_f_inner (4);
         # kh, kw and ic_s_inner_o are fixed at 0 (extent 1).
         for (
                 i7_0,
                 i8_0,
                 i9_0_0,
                 i0_2,
                 i1_2,
                 i2_2,
                 i3_2,
                 i4_0_2,
                 i5_1,
                 i6_1,
                 i7_1,
                 i8_1,
                 i9_0_1,
                 i0_3,
                 i1_3,
                 i2_3,
                 i3_3,
                 i4_0_3,
         ) in T.grid(4, 1, 1, 1, 4, 7, 1, 1, 1, 1, 1, 4, 1, 1, 1, 4, 4, 1):
             with T.block("conv2d_NCHWc_int8_o_update"):
                 n = T.axis.spatial(1, 0)
                 oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2)
                 oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3)
                 ow = T.axis.spatial(56, i3_1 * 4 + i3_3)
                 oc_block_o = T.axis.spatial(1, 0)
                 kh = T.axis.reduce(1, 0)
                 kw = T.axis.reduce(1, 0)
                 ic_outer, ic_f_inner = T.axis.remap("RR", [i7_0, i8_1])
                 ic_s_inner_o = T.axis.reduce(1, 0)
                 T.reads(
                     conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16],
                     placeholder[n, ic_outer, oh + kh, ow + kw,
                                 ic_f_inner * 4:ic_f_inner * 4 + 4],
                     placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner,
                                   0:16, 0:4],
                 )
                 T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
                 # Views for the intrinsic: A = 4 input channels, B = 16x4
                 # weight tile, C = the 16 output lanes being accumulated.
                 A = T.match_buffer(
                     placeholder[n, ic_outer, oh + kh, ow + kw,
                                 ic_f_inner * 4:ic_f_inner * 4 + 4],
                     [4],
                     dtype="uint8",
                     offset_factor=1,
                 )
                 B = T.match_buffer(
                     placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner,
                                   0:16, 0:4],
                     [16, 4],
                     dtype="int8",
                     offset_factor=1,
                 )
                 C = T.match_buffer(
                     conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16],
                     [16],
                     dtype="int32",
                     offset_factor=1,
                 )
                 # Tensorized step: C[0:16] += vpdpbusd(0, A splatted as int32,
                 # B rows viewed as int32x16).
                 A_u8x4 = A.vload([0], "uint8x4")
                 A_i32 = T.reinterpret(A_u8x4, dtype="int32")
                 B_i8x64 = B.vload([0, 0], dtype="int8x64")
                 B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
                 C[T.ramp(
                     0, 1,
                     16)] = C[T.ramp(0, 1, 16)] + T.call_llvm_pure_intrin(
                         T.llvm_lookup_intrinsic_id(
                             "llvm.x86.avx512.vpdpbusd.512"),
                         T.uint32(0),
                         T.broadcast(0, 16),
                         T.broadcast(A_i32, 16),
                         B_i32x16,
                         dtype="int32x16",
                     )