Esempio n. 1
0
File: cuda.py Progetto: were/tvm
    def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
        # Issues two T.ptx_mma calls on the warp-scoped buffers A, B and C.
        # WARP_SIZE, local_size, local_size_out, in_dtype, out_dtype, mma_prefix,
        # in_dtype_abbrv, out_dtype_abbrv and lift all come from the enclosing
        # scope, which is not visible in this excerpt.
        A = T.match_buffer(
            a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
        )
        B = T.match_buffer(
            b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
        )
        C = T.match_buffer(
            c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
        )

        with T.block("root"):
            # C is declared both as read and as written; A and B are read only.
            T.reads(
                C[0:WARP_SIZE, 0:local_size_out],
                A[0:WARP_SIZE, 0:local_size],
                B[0:WARP_SIZE, 0:local_size],
            )
            T.writes(C[0:WARP_SIZE, 0:local_size_out])
            # Bind tx to threadIdx.x with extent WARP_SIZE.
            tx = T.env_thread("threadIdx.x")
            T.launch_thread(tx, WARP_SIZE)

            # First call: every operand offset is elem_offset + tx * <row length>,
            # presumably the start of row tx assuming a compact row-major layout.
            # NOTE(review): the positional False is presumably the saturate flag of
            # T.ptx_mma -- confirm against the intrinsic's signature.
            T.evaluate(
                T.ptx_mma(
                    mma_prefix,
                    "row",
                    "col",
                    in_dtype_abbrv,
                    in_dtype_abbrv,
                    out_dtype_abbrv,
                    A.data,
                    A.elem_offset + tx * lift(local_size),
                    B.data,
                    B.elem_offset + tx * lift(local_size),
                    C.data,
                    C.elem_offset + tx * lift(local_size_out),
                    False,
                    dtype=out_dtype,
                )
            )

            # Second call: same A offset as above, while the B and C offsets are
            # advanced by half a row (local_size // 2 and local_size_out // 2).
            T.evaluate(
                T.ptx_mma(
                    mma_prefix,
                    "row",
                    "col",
                    in_dtype_abbrv,
                    in_dtype_abbrv,
                    out_dtype_abbrv,
                    A.data,
                    A.elem_offset + tx * lift(local_size),
                    B.data,
                    B.elem_offset + tx * lift(local_size) + lift(local_size) // 2,
                    C.data,
                    C.elem_offset + tx * lift(local_size_out) + lift(local_size_out) // 2,
                    False,
                    dtype=out_dtype,
                )
            )
Esempio n. 2
0
def gemm_mma_m8n8k4_row_row_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle):
    # Kernel that runs one "m8n8k4" T.ptx_mma ("row"/"row", fp16 x fp16 -> fp32).
    #   a: handle matched as A, a [16, 4] float16 buffer
    #   b: handle matched as B, a [4, 16] float16 buffer
    #   c: handle matched as C, a [16, 16] float32 buffer (written at the end)
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 4], dtype="float16")
    B = T.match_buffer(b, [4, 16], dtype="float16")
    C = T.match_buffer(c, [16, 16], dtype="float32")
    # Launch configuration: blockIdx.y and blockIdx.x extent 1, threadIdx.x extent 32.
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    # "local"-scope staging buffers: 4 A values, 4 B values and 8 accumulators.
    MultiA = T.allocate([4], "float16", scope="local")
    MultiB = T.allocate([4], "float16", scope="local")
    Accum = T.allocate([8], "float32", scope="local")
    # Zero the accumulators before the multiply-accumulate.
    for i in range(8):
        Accum[i] = T.float32(0)

    # Copy one 4-element row of A into MultiA; the row index is derived from tx.
    for mma_multi_a_col in T.vectorized(4):
        MultiA[mma_multi_a_col] = A[
            ((tx % 32) % 4) + (4 * ((((tx % 32) // 16 + (tx % 32) % 16 // 4 * 2)) % 4)),
            mma_multi_a_col,
        ]
    # Copy 4 consecutive elements of B row (tx % 4), starting at column 4 * (tx // 8).
    for mma_multi_b_col in T.vectorized(4):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) % 4,
            mma_multi_b_col + (4 * ((tx % 32) // 8)),
        ]
    # All three operand offsets are 0.
    # NOTE(review): the positional False is presumably the saturate flag -- confirm.
    T.evaluate(
        T.ptx_mma(
            "m8n8k4",
            "row",
            "row",
            "fp16",
            "fp16",
            "fp32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="float32",
        )
    )
    # Scatter the 8 accumulator values into C; the row depends on tx and
    # (mma_accum_c_id // 2 % 2), the column on tx, mma_accum_c_id % 2 and
    # mma_accum_c_id // 4.
    for mma_accum_c_id in range(8):
        C[
            ((tx % 32) % 2)
            + ((mma_accum_c_id // 2 % 2) * 2)
            + 4 * ((tx % 32) // 16)
            + ((tx % 32) % 16 // 4) % 2 * 8,
            (tx % 32) % 4 // 2 * 2
            + (tx % 32) % 16 // 8 * 4
            + mma_accum_c_id % 2
            + mma_accum_c_id // 4 * 8,
        ] = T.load("float32", Accum, mma_accum_c_id)
Esempio n. 3
0
def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle):
    # Kernel that runs one "m16n8k32" T.ptx_mma ("row"/"col", int8 x int8 -> int32).
    #   a: handle matched as A, a [16, 32] int8 buffer
    #   b: handle matched as B, a [8, 32] int8 buffer
    #   c: handle matched as C, a [16, 8] int32 buffer (written at the end)
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 32], dtype="int8")
    B = T.match_buffer(b, [8, 32], dtype="int8")
    C = T.match_buffer(c, [16, 8], dtype="int32")
    # Launch configuration: blockIdx.y and blockIdx.x extent 1, threadIdx.x extent 32.
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    # "local"-scope staging buffers: 16 A values, 8 B values and 4 accumulators.
    MultiA = T.allocate([16], "int8", scope="local")
    MultiB = T.allocate([8], "int8", scope="local")
    Accum = T.allocate([4], "int32", scope="local")
    # Zero the accumulators before the multiply-accumulate.
    for i in range(4):
        Accum[i] = T.int32(0)

    # Gather 16 A elements: groups of 4 consecutive columns, alternating between
    # row tx // 4 and row tx // 4 + 8, with the upper 8 entries shifted by 16 columns.
    for mma_multi_a_col in range(16):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 8 // 4 * 8,
            (tx % 32) % 4 * 4 + mma_multi_a_col % 4 + mma_multi_a_col // 8 * 16,
        ]
    # Gather 8 B elements from row tx // 4: two groups of 4 columns, 16 columns apart.
    for mma_multi_b_col in range(8):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 4 + mma_multi_b_col % 4 + mma_multi_b_col // 4 * 16,
        ]
    # All three operand offsets are 0.
    # NOTE(review): the positional False is presumably the saturate flag -- confirm.
    T.evaluate(
        T.ptx_mma(
            "m16n8k32",
            "row",
            "col",
            "int8",
            "int8",
            "int32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="int32",
        )
    )
    # Write the 4 accumulators to C: rows tx // 4 and tx // 4 + 8,
    # columns (tx % 4) * 2 and (tx % 4) * 2 + 1.
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = T.load("int32", Accum, mma_accum_c_id)
Esempio n. 4
0
def gemm_mma_m16n8k16_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle):
    # Kernel that runs one "m16n8k16" T.ptx_mma ("row"/"col", fp16 x fp16 -> fp32).
    #   a: handle matched as A, a [16, 16] float16 buffer
    #   b: handle matched as B, a [8, 16] float16 buffer
    #   c: handle matched as C, a [16, 8] float32 buffer (written at the end)
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 16], dtype="float16")
    B = T.match_buffer(b, [8, 16], dtype="float16")
    C = T.match_buffer(c, [16, 8], dtype="float32")
    # Launch configuration: blockIdx.y and blockIdx.x extent 1, threadIdx.x extent 32.
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    # "local"-scope staging buffers: 8 A values, 4 B values and 4 accumulators.
    MultiA = T.allocate([8], "float16", scope="local")
    MultiB = T.allocate([4], "float16", scope="local")
    Accum = T.allocate([4], "float32", scope="local")
    # Zero the accumulators before the multiply-accumulate.
    for i in range(4):
        Accum[i] = T.float32(0)

    # Gather 8 A elements: pairs of consecutive columns, alternating between
    # row tx // 4 and row tx // 4 + 8, with the upper 4 entries shifted by 8 columns.
    for mma_multi_a_col in range(8):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 4 // 2 * 8,
            (tx % 32) % 4 * 2 + mma_multi_a_col % 2 + mma_multi_a_col // 4 * 8,
        ]
    # Gather 4 B elements from row tx // 4: two pairs of columns, 8 columns apart.
    for mma_multi_b_col in T.vectorized(4):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 2 + mma_multi_b_col % 2 + mma_multi_b_col // 2 * 8,
        ]
    # Operands are passed via their .data handles, each with offset 0.
    # NOTE(review): the positional False is presumably the saturate flag -- confirm.
    T.evaluate(
        T.ptx_mma(
            "m16n8k16",
            "row",
            "col",
            "fp16",
            "fp16",
            "fp32",
            MultiA.data,
            0,
            MultiB.data,
            0,
            Accum.data,
            0,
            False,
            dtype="float32",
        )
    )
    # Write the 4 accumulators to C: rows tx // 4 and tx // 4 + 8,
    # columns (tx % 4) * 2 and (tx % 4) * 2 + 1.
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = Accum[mma_accum_c_id]
Esempio n. 5
0
def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle):
    # Kernel that runs one "m16n8k256" T.ptx_mma ("row"/"col", int1 x int1 -> int32).
    #   a: handle matched as A, a [16, 256] int1 buffer
    #   b: handle matched as B, a [8, 256] int1 buffer
    #   c: handle matched as C, a [16, 8] int32 buffer (written at the end)
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 256], dtype="int1")
    B = T.match_buffer(b, [8, 256], dtype="int1")
    C = T.match_buffer(c, [16, 8], dtype="int32")
    # Launch configuration: blockIdx.y and blockIdx.x extent 1, threadIdx.x extent 32.
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    # "local"-scope staging buffers: 128 A bits, 64 B bits and 4 accumulators.
    MultiA = T.allocate([128], "int1", scope="local")
    MultiB = T.allocate([64], "int1", scope="local")
    Accum = T.allocate([4], "int32", scope="local")
    # Zero the accumulators before the multiply-accumulate.
    for i in range(4):
        Accum[i] = T.int32(0)

    # Gather 128 A elements: runs of 32 consecutive columns, alternating between
    # row tx // 4 and row tx // 4 + 8, with the upper 64 entries shifted by 128 columns.
    for mma_multi_a_col in range(128):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 64 // 32 * 8,
            (tx % 32) % 4 * 32 + mma_multi_a_col % 32 + mma_multi_a_col // 64 * 128,
        ]
    # Gather all 64 B elements from row tx // 4: two runs of 32 columns, 128 columns
    # apart. The loop must cover the whole MultiB allocation; iterating fewer entries
    # would hand uninitialised values to ptx_mma and make the `% 32` / `// 32` terms
    # of the column index dead.
    for mma_multi_b_col in range(64):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 32 + mma_multi_b_col % 32 + mma_multi_b_col // 32 * 128,
        ]
    # Operands are passed via their .data handles, each with offset 0.
    # NOTE(review): the positional False is presumably the saturate flag -- confirm.
    T.evaluate(
        T.ptx_mma(
            "m16n8k256",
            "row",
            "col",
            "int1",
            "int1",
            "int32",
            MultiA.data,
            0,
            MultiB.data,
            0,
            Accum.data,
            0,
            False,
            dtype="int32",
        )
    )
    # Write the 4 accumulators to C: rows tx // 4 and tx // 4 + 8,
    # columns (tx % 4) * 2 and (tx % 4) * 2 + 1.
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = Accum[mma_accum_c_id]
Esempio n. 6
0
def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle):
    # Kernel that runs one "m8n8k32" T.ptx_mma ("row"/"col", int4 x uint4 -> int32).
    #   a: handle matched as A, an [8, 32] int4 buffer
    #   b: handle matched as B, an [8, 32] uint4 buffer
    #   c: handle matched as C, an [8, 8] int32 buffer (written at the end)
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [8, 32], dtype="int4")
    B = T.match_buffer(b, [8, 32], dtype="uint4")
    C = T.match_buffer(c, [8, 8], dtype="int32")
    # Launch configuration: blockIdx.y and blockIdx.x extent 1, threadIdx.x extent 32.
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    # "local"-scope staging buffers: 8 A values, 8 B values and 2 accumulators.
    MultiA = T.allocate([8], "int4", scope="local")
    MultiB = T.allocate([8], "uint4", scope="local")
    Accum = T.allocate([2], "int32", scope="local")
    # Zero the accumulators before the multiply-accumulate.
    for i in range(2):
        Accum[i] = T.int32(0)

    # A and B use the same addressing: row tx // 4, and 8 consecutive columns
    # starting at (tx % 4) * 8.
    for mma_multi_a_col in T.vectorized(8):
        MultiA[mma_multi_a_col] = A[(tx % 32) // 4, mma_multi_a_col + (tx % 32) % 4 * 8]
    for mma_multi_b_col in T.vectorized(8):
        MultiB[mma_multi_b_col] = B[(tx % 32) // 4, mma_multi_b_col + (tx % 32) % 4 * 8]
    # All three operand offsets are 0.
    # NOTE(review): the positional False is presumably the saturate flag -- confirm.
    T.evaluate(
        T.ptx_mma(
            "m8n8k32",
            "row",
            "col",
            "int4",
            "uint4",
            "int32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="int32",
        )
    )
    # Write the 2 accumulators to C row tx // 4, columns (tx % 4) * 2 and (tx % 4) * 2 + 1.
    for mma_accum_c_id in range(2):
        C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load(
            "int32", Accum, mma_accum_c_id
        )
Esempio n. 7
0
def gemm_mma_m8n8k4_row_col_fp64pf64fp64(a: T.handle, b: T.handle, c: T.handle):
    # Kernel that runs one "m8n8k4" T.ptx_mma ("row"/"col", fp64 x fp64 -> fp64).
    #   a: handle matched as A, an [8, 4] float64 buffer
    #   b: handle matched as B, an [8, 4] float64 buffer
    #   c: handle matched as C, an [8, 8] float64 buffer (written at the end)
    # NOTE(review): "pf64" in the function name looks like a typo for "fp64"; it is
    # left unchanged because the name is the function's public identifier.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [8, 4], dtype="float64")
    B = T.match_buffer(b, [8, 4], dtype="float64")
    C = T.match_buffer(c, [8, 8], dtype="float64")
    # Launch configuration: blockIdx.y and blockIdx.x extent 1, threadIdx.x extent 32.
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    # "local"-scope staging buffers: 1 A value, 1 B value and 2 accumulators.
    MultiA = T.allocate([1], "float64", scope="local")
    MultiB = T.allocate([1], "float64", scope="local")
    Accum = T.allocate([2], "float64", scope="local")
    # Zero the accumulators before the multiply-accumulate.
    for i in range(2):
        Accum[i] = T.float64(0)

    # A and B are each read at [tx // 4, tx % 4]: one element per operand.
    MultiA[0] = A[(tx % 32) // 4, (tx % 32) % 4]
    MultiB[0] = B[(tx % 32) // 4, (tx % 32) % 4]
    # All three operand offsets are 0.
    # NOTE(review): the positional False is presumably the saturate flag -- confirm.
    T.evaluate(
        T.ptx_mma(
            "m8n8k4",
            "row",
            "col",
            "fp64",
            "fp64",
            "fp64",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="float64",
        )
    )
    # Write the 2 accumulators to C row tx // 4, columns (tx % 4) * 2 and (tx % 4) * 2 + 1.
    for mma_accum_c_id in range(2):
        C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load(
            "float64", Accum, mma_accum_c_id
        )