Esempio n. 1
0
def get_valid_counts(
    data: T.handle,
    valid_count: T.handle,
    out: T.handle,
    out_indices: T.handle,
    score_threshold: T.float32,
    id_index: T.int32,
    score_index: T.int32,
) -> None:
    # Compact the entries of `data` that pass the filter to the front of
    # `out`, record each kept entry's original position in `out_indices`,
    # and store how many were kept in `valid_count`. An entry passes when
    # data[..., score_index] > score_threshold and, if id_index >= 0,
    # data[..., id_index] >= 0. Every slot at or past the final count ends up
    # holding -1 in both `out` and `out_indices`.
    # Shapes are fixed here: 1 x 2500 entries x 6 fields.

    data_buf = T.match_buffer(data, (1, 2500, 6), "float32")
    valid_count_buf = T.match_buffer(valid_count, (1, ), "int32")
    out_buf = T.match_buffer(out, (1, 2500, 6), "float32")
    out_indices_buf = T.match_buffer(out_indices, (1, 2500), "int32")

    with T.block([1], "init") as [vi]:
        # valid_count_buf[vi] doubles as the running write cursor.
        valid_count_buf[vi] = T.int32(0)
        with T.block([2500], "update") as [vj]:
            # The region annotations must cover every access in the body:
            # - the field axis has extent 6, so the field range is 0:6
            #   (index 6 would be out of bounds);
            # - the cursor valid_count_buf[vi] is read as well as written;
            # - out_buf / out_indices_buf are written at the data-dependent
            #   cursor position (anywhere in 0..vj), so the whole row is
            #   declared rather than just row vj.
            T.reads([data_buf[vi, vj, 0:6], valid_count_buf[vi]])
            T.writes([
                valid_count_buf[vi],
                out_indices_buf[vi, 0:2500],
                out_buf[vi, 0:2500, 0:6],
            ])
            # NOTE(review): TIR `or` is not guaranteed to short-circuit, so
            # data_buf[..., id_index] may still be evaluated when
            # id_index < 0 — confirm this is acceptable for the target.
            if (data_buf[vi, vj, score_index] > score_threshold) and (
                (id_index < 0) or
                (data_buf[vi, vj, id_index] >= T.float32(0))):
                # Copy the kept entry to the cursor slot and advance the cursor.
                for k in T.serial(0, 6):
                    out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k]
                out_indices_buf[vi, valid_count_buf[vi]] = vj
                valid_count_buf[vi] = valid_count_buf[vi] + 1
            # Slot vj is past the compacted prefix so far. Mark it invalid;
            # a later kept entry may still overwrite it while the cursor is
            # below vj.
            if vj >= valid_count_buf[vi]:
                for k in T.serial(0, 6):
                    out_buf[vi, vj, k] = T.float32(-1)
                out_indices_buf[vi, vj] = T.int32(-1)
Esempio n. 2
0
def dot_product_4x4_i8i8i32_neon(
        A: T.Buffer((4, ), "int8", offset_factor=1),
        B: T.Buffer((4, 4), "int8", offset_factor=1),
        C: T.Buffer((4, ), "int32", offset_factor=1),
) -> None:
    # Tensor-intrinsic implementation of C[i] += sum_k A[k] * B[i, k]
    # (i, k in 0..3) using plain AArch64 NEON (smull / saddlp / addp), i.e.
    # without the dot-product extension.
    with T.block("root"):
        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
        T.writes(C[0:4])

        # Load A's 4 int8 values as one int32 and duplicate it twice, giving
        # int8x8 = [a0 a1 a2 a3 a0 a1 a2 a3]. This lines up with two rows of B.
        A_int8 = A.vload([0], "int8x4")
        re_int32 = T.reinterpret(A_int8, dtype="int32")
        vec_ai32 = T.broadcast(re_int32, 2)
        vec_a = T.reinterpret(vec_ai32, dtype="int8x8")

        # All 16 bytes of B, row-major: rows 0-1 in the low half, 2-3 in the high.
        vec_b = B.vload([0, 0], dtype="int8x16")

        # TODO(masahi): Remove duplication when inlined function call is supported
        vec_b_low = T.vectorlow(vec_b, dtype="int8x8")

        # Widening multiply int8 x int8 -> int16x8: the products for rows 0-1.
        # NOTE(review): the T.uint32(N) argument is presumably the number of
        # args used to resolve the overloaded LLVM intrinsic signature — confirm.
        multiply_low = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
            T.uint32(2),
            vec_a,
            vec_b_low,
            dtype="int16x8",
        )

        # Pairwise widening add int16x8 -> int32x4: each lane is a 2-term partial sum.
        pairwise_reduction_low = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
            T.uint32(1),
            multiply_low,
            dtype="int32x4",
        )

        # Same two steps for rows 2-3 (high half of B).
        vec_b_high = T.vectorhigh(vec_b, dtype="int8x8")

        multiply_high = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
            T.uint32(2),
            vec_a,
            vec_b_high,
            dtype="int16x8",
        )

        pairwise_reduction_high = T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
            T.uint32(1),
            multiply_high,
            dtype="int32x4",
        )

        # A final pairwise add across (low, high) combines the 2-term partial
        # sums into the four 4-term row dot products, then accumulates into C.
        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
            T.uint32(2),
            pairwise_reduction_low,
            pairwise_reduction_high,
            dtype="int32x4",
        )
Esempio n. 3
0
def dp4a_impl(
    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
    # Tensor-intrinsic implementation of C[0] += sum_k A[k] * B[k] (k in 0..3)
    # via the external `__dp4a` function, with A and B in "shared" scope and
    # C in "local" scope (as declared in the buffer signatures).
    with T.block("root"):
        T.reads(C[0], A[0:4], B[0:4])
        T.writes(C[0])

        # Both operands are passed as packed int8x4 vectors, and the third
        # argument (0) is __dp4a's accumulator input. Accumulation into C is
        # therefore done by the `+=` rather than by __dp4a itself.
        C[0] += T.call_pure_extern(
            "__dp4a", A.vload([0], "int8x4"), B.vload([0], "int8x4"), T.int32(0), dtype="int32"
        )
Esempio n. 4
0
def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle):
    # Single-warp kernel computing C (16x8, int32) from A (16x32, int8,
    # "row" operand) and B (8x32, int8, "col" operand) with one PTX
    # mma.m16n8k32 instruction. Each of the 32 threads gathers its
    # 16-element A fragment and 8-element B fragment into local storage,
    # issues the mma, and scatters its 4 accumulator elements back to C.
    # In the index math below, lane = tx % 32, groupID = lane // 4 and
    # threadID_in_group = lane % 4.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 32], dtype="int8")
    B = T.match_buffer(b, [8, 32], dtype="int8")
    C = T.match_buffer(c, [16, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([16], "int8", scope="local")
    MultiB = T.allocate([8], "int8", scope="local")
    Accum = T.allocate([4], "int32", scope="local")
    # Zero the per-thread accumulator fragment.
    for i in range(4):
        Accum[i] = T.int32(0)

    # A fragment element i: row groupID (+8 when i % 8 >= 4),
    # column threadID_in_group * 4 + i % 4 (+16 when i >= 8).
    for mma_multi_a_col in range(16):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 8 // 4 * 8,
            (tx % 32) % 4 * 4 + mma_multi_a_col % 4 + mma_multi_a_col // 8 * 16,
        ]
    # B fragment element i: row groupID,
    # column threadID_in_group * 4 + i % 4 (+16 when i >= 4).
    for mma_multi_b_col in range(8):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 4 + mma_multi_b_col % 4 + mma_multi_b_col // 4 * 16,
        ]
    # The fragments are passed as their underlying data pointers (`.data`),
    # each with element offset 0.
    # NOTE(review): the trailing False is presumably the saturate flag of
    # T.ptx_mma — confirm.
    T.evaluate(
        T.ptx_mma(
            "m16n8k32",
            "row",
            "col",
            "int8",
            "int8",
            "int32",
            MultiA.data,
            0,
            MultiB.data,
            0,
            Accum.data,
            0,
            False,
            dtype="int32",
        )
    )
    # Accumulator element j: row groupID (+8 when j >= 2),
    # column threadID_in_group * 2 + j % 2.
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = Accum[mma_accum_c_id]
Esempio n. 5
0
def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle):
    # Single-warp kernel producing C (16x8, int32) from A (16x256, int1,
    # "row" operand) and B (8x256, int1, "col" operand) with one PTX
    # mma.m16n8k256 instruction.
    # NOTE(review): the bit operation used for 1-bit operands is whatever
    # T.ptx_mma selects by default — confirm.
    # Each of the 32 threads gathers its 128-element A fragment and
    # 64-element B fragment into local storage, issues the mma, and scatters
    # its 4 accumulator elements to C. In the index math below,
    # lane = tx % 32, groupID = lane // 4 and threadID_in_group = lane % 4.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 256], dtype="int1")
    B = T.match_buffer(b, [8, 256], dtype="int1")
    C = T.match_buffer(c, [16, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([128], "int1", scope="local")
    MultiB = T.allocate([64], "int1", scope="local")
    Accum = T.allocate([4], "int32", scope="local")
    # Zero the per-thread accumulator fragment.
    for i in range(4):
        Accum[i] = T.int32(0)

    # A fragment element i: row groupID (+8 when i % 64 >= 32),
    # column threadID_in_group * 32 + i % 32 (+128 when i >= 64).
    for mma_multi_a_col in range(128):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 64 // 32 * 8,
            (tx % 32) % 4 * 32 + mma_multi_a_col % 32 + mma_multi_a_col // 64 * 128,
        ]
    # B fragment element i: row groupID,
    # column threadID_in_group * 32 + i % 32 (+128 when i >= 32).
    # MultiB holds 64 elements, and the loop must fill all of them. The
    # `// 32 * 128` term selects the second half of K for i >= 32.
    for mma_multi_b_col in range(64):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 32 + mma_multi_b_col % 32 + mma_multi_b_col // 32 * 128,
        ]
    # The fragments are passed as their underlying data pointers, each with
    # element offset 0.
    # NOTE(review): the trailing False is presumably the saturate flag — confirm.
    T.evaluate(
        T.ptx_mma(
            "m16n8k256",
            "row",
            "col",
            "int1",
            "int1",
            "int32",
            MultiA.data,
            0,
            MultiB.data,
            0,
            Accum.data,
            0,
            False,
            dtype="int32",
        )
    )
    # Accumulator element j: row groupID (+8 when j >= 2),
    # column threadID_in_group * 2 + j % 2.
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = Accum[mma_accum_c_id]
Esempio n. 6
0
def dot_product_16x4_u8i8i32_desc(
    A: T.Buffer((4,), "uint8", offset_factor=1),
    B: T.Buffer((16, 4), "int8", offset_factor=1),
    C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
    # Tensor-intrinsic description of a 16-lane u8 x i8 dot product that
    # accumulates into C:
    #     C[i] += sum_k int32(A[k]) * int32(B[i, k])   for i in 0..15, k in 0..3
    # C is deliberately not zero-initialised, since the intrinsic is an update
    # of C (hence C appears in T.reads). A `T.init()` here would also have to
    # sit directly inside a block; placed inside the `for i` loop, it attached
    # to the root block and referenced `i` outside its scope.
    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])
        for i in T.serial(0, 16):
            for k in T.serial(0, 4):
                with T.block("update"):
                    # vi is a spatial axis ("S") and vk a reduction axis ("R").
                    vi, vk = T.axis.remap("SR", [i, k])
                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
Esempio n. 7
0
def sdot4(
    A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
    C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
) -> None:
    # Tensor-intrinsic implementation of C[0] += sum_k A[k] * B[k] (k in 0..3)
    # via the LLVM AMDGPU intrinsic llvm.amdgcn.sdot4.
    with T.block("root"):
        T.reads(C[0], A[0:4], B[0:4])
        T.writes(C[0])

        # Each 4 x int8 operand is packed into a single int32. The accumulator
        # argument is 0, so accumulation into C comes from the `+=`.
        # NOTE(review): the trailing T.bool(1) is presumably the intrinsic's
        # clamp flag, and T.uint32(4) the signature-arg count — confirm against
        # the LLVM AMDGPU intrinsic definition.
        C[0] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.amdgcn.sdot4"),
            T.uint32(4),
            T.reinterpret(A.vload([0], "int8x4"), dtype="int32"),
            T.reinterpret(B.vload([0], "int8x4"), dtype="int32"),
            T.int32(0),
            T.bool(1),
            dtype="int32",
        )
Esempio n. 8
0
def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle):
    # Single-warp kernel computing C (8x8, int32) from A (8x32, int4, "row"
    # operand) and B (8x32, uint4, "col" operand) with one PTX mma.m8n8k32
    # instruction. Each of the 32 threads gathers 8 elements of A and 8 of B,
    # issues the mma, and scatters its 2 accumulator elements back to C.
    # In the index math below, lane = tx % 32, groupID = lane // 4 and
    # threadID_in_group = lane % 4.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [8, 32], dtype="int4")
    B = T.match_buffer(b, [8, 32], dtype="uint4")
    C = T.match_buffer(c, [8, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([8], "int4", scope="local")
    MultiB = T.allocate([8], "uint4", scope="local")
    Accum = T.allocate([2], "int32", scope="local")
    # Zero the per-thread accumulator fragment.
    for i in range(2):
        Accum[i] = T.int32(0)

    # Both fragments take 8 consecutive elements of row groupID, starting at
    # column threadID_in_group * 8; the copies are vectorized.
    for mma_multi_a_col in T.vectorized(8):
        MultiA[mma_multi_a_col] = A[(tx % 32) // 4, mma_multi_a_col + (tx % 32) % 4 * 8]
    for mma_multi_b_col in T.vectorized(8):
        MultiB[mma_multi_b_col] = B[(tx % 32) // 4, mma_multi_b_col + (tx % 32) % 4 * 8]
    # The fragments are passed as their underlying data pointers (`.data`),
    # each with element offset 0.
    # NOTE(review): the trailing False is presumably the saturate flag — confirm.
    T.evaluate(
        T.ptx_mma(
            "m8n8k32",
            "row",
            "col",
            "int4",
            "uint4",
            "int32",
            MultiA.data,
            0,
            MultiB.data,
            0,
            Accum.data,
            0,
            False,
            dtype="int32",
        )
    )
    # Accumulator element j: row groupID, column threadID_in_group * 2 + j.
    for mma_accum_c_id in range(2):
        C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = Accum[mma_accum_c_id]
Esempio n. 9
0
def tir_argmax_val_idx(
    var_val: T.handle, var_idx: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle
) -> None:
    # Row-wise arg-max over an m x n matrix `val`, where `idx` supplies the
    # index to report for each element. The result for row i goes to
    # argmax_v0[i] (the maximum value) and argmax_v1[i] (the matching idx entry).
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    m = T.var("int32")
    n = T.var("int32")
    val = T.match_buffer(var_val, [m, n], dtype="float32")
    idx = T.match_buffer(var_idx, [m, n], dtype="int32")
    argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="float32")
    argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="int32")
    for i0, i1 in T.grid(m, n):
        with T.block("argmax"):
            # i is spatial ("S") and k is the reduction axis ("R").
            i, k = T.axis.remap("SR", [i0, i1])
            T.reads(val[i, k], idx[i, k])
            T.writes(argmax_v0[i], argmax_v1[i])
            # Reduction identity: the smallest float32 value, with index -1.
            with T.init():
                argmax_v0[i] = T.min_value("float32")
                argmax_v1[i] = T.int32(-1)
            # Both outputs are selected with the same predicate so the pair stays
            # consistent. On ties (>=) the current best is kept, so with in-order
            # iteration over k the first maximal index wins.
            v_argmax_v0: T.float32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v0[i], val[i, k])
            v_argmax_v1: T.int32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v1[i], idx[i, k])
            argmax_v0[i] = v_argmax_v0
            argmax_v1[i] = v_argmax_v1
Esempio n. 10
0
def dot_product_16x4_u8i8i32_vnni(
    A: T.Buffer((4,), "uint8", offset_factor=1),
    B: T.Buffer((16, 4), "int8", offset_factor=1),
    C: T.Buffer((16,), "int32", offset_factor=1),
) -> None:
    # Tensor-intrinsic implementation of C[i] += sum_k A[k] * B[i, k]
    # (i in 0..15, k in 0..3; A unsigned, B signed) using the AVX-512 VNNI
    # instruction exposed as llvm.x86.avx512.vpdpbusd.512.
    with T.block("root"):
        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
        T.writes(C[0:16])

        # Pack A's 4 bytes into one int32, which is broadcast to all 16 lanes below.
        A_u8x4 = A.vload([0], "uint8x4")
        A_i32 = T.reinterpret(A_u8x4, dtype="int32")

        # Load all of B (64 bytes, row-major) so that int32 lane i holds B[i, 0:4].
        B_i8x64 = B.vload([0, 0], dtype="int8x64")
        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")

        # The accumulator operand is zero; accumulation into C comes from `+=`.
        # NOTE(review): T.uint32(0) is presumably the signature-arg count (the
        # intrinsic is not overloaded) — confirm.
        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this is an update +=
            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
            T.uint32(0),
            T.int32x16(0),
            T.broadcast(A_i32, 16),
            B_i32x16,
            dtype="int32x16",
        )
Esempio n. 11
0
def dot_product_4x4_i8i8i32_sdot(
        A: T.Buffer((4, ), "int8", offset_factor=1),
        B: T.Buffer((4, 4), "int8", offset_factor=1),
        C: T.Buffer((4, ), "int32", offset_factor=1),
) -> None:
    # Tensor-intrinsic implementation of C[i] += sum_k A[k] * B[i, k]
    # (i, k in 0..3) using the AArch64 dot-product instruction exposed as
    # llvm.aarch64.neon.sdot.v4i32.v16i8.
    with T.block("root"):
        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
        T.writes(C[0:4])

        # Pack A's 4 bytes into one int32 and repeat it 4 times, giving
        # int8x16 = [a0 a1 a2 a3] x 4 (one copy per row of B).
        A_i8x4 = A.vload([0], "int8x4")
        A_i32 = T.reinterpret(A_i8x4, dtype="int32")
        vec_ai32 = T.broadcast(A_i32, 4)
        vec_a = T.reinterpret(vec_ai32, dtype="int8x16")

        # All 16 bytes of B, row-major: bytes 4i..4i+3 are row i.
        vec_b = B.vload([0, 0], dtype="int8x16")

        # The accumulator operand is zero; accumulation into C comes from `+=`.
        # NOTE(review): T.uint32(3) is presumably the signature-arg count for
        # the overloaded intrinsic — confirm.
        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
            T.uint32(3),
            T.int32x4(0),
            vec_a,
            vec_b,
            dtype="int32x4",
        )
Esempio n. 12
0
 def constant_binds_wrapped():
     # Binds typed constants to Python names and evaluates an expression that
     # uses them; the result is discarded (T.evaluate).
     # NOTE(review): presumably a TVMScript parse/print test for constants
     # bound to local names — confirm intent with the owning test.
     x = T.int32(1)
     y = T.float32(42.0)
     T.evaluate(T.cast(x, "float32") + y)
Esempio n. 13
0
def preflattened_buffer_map(A: T.handle, B: T.handle):
    # Declares two 1-element buffers, each with a pre-flattened buffer
    # annotation, and copies A_1[0] into B_1[0].
    A_1 = T.match_buffer(A, [1])
    # align / offset_factor are given as explicitly typed values (int32 and
    # int64). NOTE(review): presumably to exercise non-default
    # preflattened_buffer arguments — confirm.
    T.preflattened_buffer(A_1, [1], align=T.int32(1), offset_factor=T.int64(2))
    B_1 = T.match_buffer(B, [1])
    T.preflattened_buffer(B_1, [1])
    B_1[0] = A_1[0]