def warp_memory(a: ty.handle, c: ty.handle) -> None:
    # TIR-script function: computes C = A * 2.0 + 1.0 element-wise over a
    # 128x128 buffer, staging the intermediate result in buffer B, which is
    # allocated with scope="warp".
    A = tir.match_buffer(a, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    # B is indexed below as [vj, warp_id, lane_id] (column first).
    B = tir.alloc_buffer([128, 4, 32], scope="warp")
    # 4 (threadIdx.y) x 32 (threadIdx.x) bindings; the row index used below,
    # warp_id * 32 + lane_id, therefore spans 0..127.
    for i_o in tir.thread_binding(0, 4, thread="threadIdx.y"):
        for i_i in tir.thread_binding(0, 32, thread="threadIdx.x"):
            # Producer: B = A * 2.0.
            # NOTE(review): no tir.bind calls here, so the block vars are
            # presumably bound positionally to the enclosing loops
            # (i_o, i_i, j) -- confirm against the script parser.
            for j in tir.serial(0, 128):
                with tir.block([4, 32, 128], "B") as [warp_id, lane_id, vj]:
                    B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0
            # Consumer: C = B + 1.0, reading B with the same
            # [vj, warp_id, lane_id] index that the producer wrote.
            for j in tir.serial(0, 128):
                with tir.block([4, 32, 128], "C") as [warp_id, lane_id, vj]:
                    C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0
def equal_ranked_threads(a: ty.handle, c: ty.handle) -> None:
    # TIR-script function: a two-stage 128x128 computation through a
    # "shared"-scope buffer B, with the row index split across two thread
    # bindings (threadIdx.x extent 16, threadIdx.y extent 8).
    A = tir.match_buffer(a, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    B = tir.alloc_buffer([128, 128], scope="shared")
    for i_o in tir.thread_binding(0, 16, thread="threadIdx.x"):
        for i_i in tir.thread_binding(0, 8, thread="threadIdx.y"):
            # Stage 1: B[vi, vj] = A[vi, vj] * 2.0, where vi comes from the
            # two thread-bound loops and vj from the serial loop.
            for j in tir.serial(0, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    tir.bind(vi, i_o * 8 + i_i)
                    tir.bind(vj, j)
                    B[vi, vj] = A[vi, vj] * 2.0
            # Stage 2: same bindings, but B and C are indexed as [vj, vi],
            # i.e. transposed relative to stage 1's write pattern, so the
            # row read here comes from the serial index rather than the
            # thread-bound one.
            # NOTE(review): the transposed access looks deliberate (this
            # reads like a test fixture) -- confirm before "fixing" it.
            for j in tir.serial(0, 128):
                with tir.block([128, 128], "C") as [vi, vj]:
                    tir.bind(vi, i_o * 8 + i_i)
                    tir.bind(vj, j)
                    C[vj, vi] = B[vj, vi] + 1.0
Exemple #3
0
def compacted_gpu_func(a: ty.handle, c: ty.handle) -> None:
    # TIR-script function: computes C = (A + 1.0) * 2.0 over a 16x16 float32
    # buffer, one row per (i0, i1, i2) combination, staging each row in a
    # [1, 16] "local"-scope buffer.
    A = tir.match_buffer(a, (16, 16), "float32")
    C = tir.match_buffer(c, (16, 16), "float32")
    # Row index used throughout is i0 * 4 + i1 * 2 + i2, which covers 0..15
    # for extents 4 (blockIdx.x), 2 (threadIdx.x) and 2 (vthread).
    for i0 in tir.thread_binding(0, 4, thread="blockIdx.x"):
        for i1 in tir.thread_binding(0, 2, thread="threadIdx.x"):
            for i2 in tir.thread_binding(0, 2, thread="vthread"):
                # Block declared without block vars; its read/write regions
                # are spelled out explicitly as one full row of A and C.
                with tir.block([]):
                    tir.reads(A[i0 * 4 + i1 * 2 + i2, 0:16])
                    tir.writes(C[i0 * 4 + i1 * 2 + i2, 0:16])
                    # Single-row staging buffer allocated inside the block.
                    B = tir.alloc_buffer([1, 16], "float32", scope="local")
                    # Stage 1: B[0, j] = A[row, j] + 1.0.
                    for j in range(0, 16):
                        with tir.block() as []:
                            tir.reads(A[i0 * 4 + i1 * 2 + i2, j])
                            tir.writes(B[0, j])
                            B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
                    # Stage 2: C[row, j] = B[0, j] * 2.0.
                    for j in range(0, 16):
                        with tir.block() as []:
                            tir.reads(B[0, j])
                            tir.writes(C[i0 * 4 + i1 * 2 + i2, j])
                            C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
def bound_to_thread(a: ty.handle, c: ty.handle) -> None:
    # TIR-script function: a two-stage 128x128 computation through a
    # "shared"-scope buffer B, with the outer loop bound to threadIdx.x
    # (extent 128) and an inner serial loop per stage.
    A = tir.match_buffer(a, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    B = tir.alloc_buffer([128, 128], scope="shared")
    for i in tir.thread_binding(0, 128, thread="threadIdx.x"):
        # Stage 1: B[vi, vj] = A[vi, vj] * 2.0.
        # NOTE(review): no tir.bind calls, so vi/vj are presumably bound
        # positionally to the enclosing loops (i, j) -- confirm against the
        # script parser.
        for j in tir.serial(0, 128):
            with tir.block([128, 128], "B") as [vi, vj]:
                B[vi, vj] = A[vi, vj] * 2.0
        # Stage 2: B and C are indexed as [vj, vi], transposed relative to
        # stage 1's write pattern.
        # NOTE(review): the transposed access looks deliberate (this reads
        # like a test fixture) -- confirm before changing it.
        for j in tir.serial(0, 128):
            with tir.block([128, 128], "C") as [vi, vj]:
                C[vj, vi] = B[vj, vi] + 1.0
Exemple #5
0
def elementwise_with_thread_binding(a: ty.handle, b: ty.handle) -> None:
    # TIR-script function: computes B = A * 2.0 element-wise over a
    # 128x128x128 buffer. The two outer loops come from tir.grid; the
    # innermost loop is bound to threadIdx.x.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        for k in tir.thread_binding(0, 128, thread="threadIdx.x"):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                # Each block var maps one-to-one onto a loop variable.
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                # Explicit access regions: one element read, one written.
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
Exemple #6
0
def factorized_after_reverse_compute_at(a: ty.handle, b: ty.handle) -> None:
    # TIR-script function: a two-step sum reduction of the 16x16x16 buffer A
    # into the length-16 buffer B, going through the 16x16 "local"-scope
    # intermediate B_rf_local.
    A = tir.match_buffer(a, [16, 16, 16], "float32")
    B = tir.match_buffer(b, [16], "float32")
    B_rf_local = tir.alloc_buffer([16, 16], "float32", scope="local")
    for j in tir.thread_binding(0, 16, thread="blockIdx.x"):
        for i_o in tir.thread_binding(0, 4, thread="threadIdx.x"):
            # Step 1: B_rf_local[vi, vj] accumulates A[vj, vi, vk] over the
            # reduce axis vk. vi = i_o * 4 + i_i spans 0..15; vj follows the
            # blockIdx.x loop.
            for i_i, k in tir.grid(4, 16):
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "B_rf") as [vi, vj, vk]:
                    tir.bind(vi, i_o * 4 + i_i)
                    tir.bind(vj, j)
                    tir.bind(vk, k)
                    # Accumulator is zeroed in the block's init section.
                    with tir.init():
                        B_rf_local[vi, vj] = 0.0
                    B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk]
            # Step 2: B[vi] accumulates B_rf_local[vk, vi] over vk. The
            # reduce axis is bound to i_o * 4 + k, so it is split between
            # the threadIdx.x-bound loop i_o and this serial loop of 4.
            for k in tir.serial(0, 4):
                with tir.block([16, tir.reduce_axis(0, 16)], "B") as [vi, vk]:
                    tir.bind(vi, j)
                    tir.bind(vk, i_o * 4 + k)
                    with tir.init():
                        B[vi] = 0.0
                    B[vi] = B[vi] + B_rf_local[vk, vi]
Exemple #7
0
def cuda_matmul_1(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # pylint: disable=undefined-loop-variable
    # TIR-script function: 2048x2048 float32 matrix product written as
    # C[i, j] = sum_k A[k, i] * B[k, j] (both operands are indexed with the
    # reduction index first). Operands are copied global -> "shared" ->
    # "local" by four copy blocks, and the result is accumulated in the
    # "local"-scope buffer C_local before being written back to C.
    A = tir.match_buffer(a, [2048, 2048], "float32")
    B = tir.match_buffer(b, [2048, 2048], "float32")
    C = tir.match_buffer(c, [2048, 2048], "float32")
    A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    # Copy blocks. They sit at function level with no enclosing loops and no
    # tir.bind calls.
    # NOTE(review): how the parser binds v0/v1 for loop-less blocks is not
    # visible here -- confirm before relying on it.
    with tir.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with tir.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with tir.block([2048, 2048], "A_shared_local") as [v0, v1]:
        A_shared_local[v0, v1] = A_shared[v0, v1]
    with tir.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    # Spatial tiling: each output index is decomposed as
    #   block * 64 + vthread * 32 + thread * 4 + inner   (32 * 2 * 8 * 4 = 2048)
    # for both the row (by, vy, ty, i) and the column (bx, vx, tx, j).
    for by in tir.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in tir.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in tir.thread_binding(0, 2, thread="vthread.y"):
                for vx in tir.thread_binding(0, 2, thread="vthread.x"):
                    for ty in tir.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in tir.thread_binding(0,
                                                     8,
                                                     thread="threadIdx.x"):
                            # Reduction index is k_0 * 8 + k_1 (256 * 8 =
                            # 2048); the inner factor of 8 is an unrolled loop.
                            for k_0 in tir.serial(0, 256):
                                for k_1 in tir.unroll(0, 8):
                                    # `_` is a unit-extent loop that is not
                                    # used in any binding below.
                                    for _, i, j in tir.grid(1, 4, 4):
                                        with tir.block([
                                                2048, 2048,
                                                tir.reduce_axis(0, 2048)
                                        ], "C") as [vi, vj, vk]:
                                            tir.bind(
                                                vi,
                                                by * 64 + vy * 32 + ty * 4 + i)
                                            tir.bind(
                                                vj,
                                                bx * 64 + vx * 32 + tx * 4 + j)
                                            tir.bind(vk, k_0 * 8 + k_1)
                                            # Accumulator is zeroed in the
                                            # block's init section.
                                            with tir.init():
                                                C_local[vi, vj] = 0.0
                                            C_local[vi, vj] = C_local[
                                                vi, vj] + A_shared_local[
                                                    vk, vi] * B_shared_local[
                                                        vk, vj]
                            # Write-back: copy the 4x4 tile accumulated above
                            # from C_local to the output buffer C, using the
                            # same row/column decomposition.
                            for i, j in tir.grid(4, 4):
                                with tir.block([2048, 2048],
                                               "C_local") as [vi, vj]:
                                    tir.bind(vi,
                                             by * 64 + vy * 32 + ty * 4 + i)
                                    tir.bind(vj,
                                             bx * 64 + vx * 32 + tx * 4 + j)
                                    C[vi, vj] = C_local[vi, vj]
Exemple #8
0
def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # TIR-script function: 1024x1024 GEMM with float16 inputs (A, B) and a
    # float32 output (C), expressed with the tvm_fill_fragment /
    # tvm_load_matrix_sync / tvm_mma_sync / tvm_store_matrix_sync intrinsics
    # on 16x16 tiles. Tiles are addressed on a 64x64 grid (64 * 16 = 1024).
    # Stages: (1) zero the accumulator tiles, (2) for each ko: copy A/B from
    # global to shared, load shared tiles into wmma_A / wmma_B, run mma_sync,
    # (3) store accumulator tiles to C.
    # match buffer
    A = tir.match_buffer(a, [1024, 1024], "float16")
    B = tir.match_buffer(b, [1024, 1024], "float16")
    C = tir.match_buffer(c, [1024, 1024], "float32")

    # body
    for blockIdx_x in tir.thread_binding(0, 16, "blockIdx.x"):
        for blockIdx_y in tir.thread_binding(0, 8, "blockIdx.y"):
            with tir.block([16, 8]) as [bx, by]:
                tir.bind(bx, blockIdx_x)
                tir.bind(by, blockIdx_y)
                # Staging buffers are allocated inside the per-(bx, by) block.
                shared_A = tir.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_B = tir.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_A = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_B = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_C = tir.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                for ty in tir.thread_binding(0, 2, "threadIdx.y"):
                    for tz in tir.thread_binding(0, 2, "threadIdx.z"):
                        # Stage 1: fill each of the 2x4 accumulator tiles
                        # owned by this (ty, tz) with float32 zero.
                        # Tile coordinates: vi = bx * 4 + ty * 2 + i and
                        # vj = by * 8 + tz * 4 + j, both within 0..63.
                        for i, j in tir.grid(2, 4):
                            with tir.block([64, 64]) as [vi, vj]:
                                tir.bind(vi, bx * 4 + ty * 2 + i)
                                tir.bind(vj, by * 8 + tz * 4 + j)
                                tir.reads([])
                                tir.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                # 16x16 view of the accumulator tile.
                                C0 = tir.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                # i * 4 + j is the per-thread tile index; the
                                # same expression is passed to mma_sync and
                                # store_matrix_sync below.
                                # NOTE(review): the three 16 literals are
                                # presumably the fragment m/n/k shape --
                                # confirm against the intrinsic's signature.
                                tir.evaluate(
                                    tir.tvm_fill_fragment(
                                        C0.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        tir.float32(0),
                                        dtype="handle",
                                    )
                                )

                        # Stage 2: reduction loop; each ko covers 32 columns
                        # of A and B (two 16-wide tiles, see ki below).
                        for ko in range(0, 32):
                            # copy data from global to shared
                            for tx in tir.thread_binding(0, 32, "threadIdx.x"):
                                # A -> shared_A. i0 has extent 1, so each tx
                                # handles one row and 16 columns (4 x 4, the
                                # inner 4 vectorized).
                                for i0, j0 in tir.grid(1, 4):
                                    for j1 in tir.vectorized(0, 4):
                                        with tir.block([1024, 1024]) as [vi, vj]:
                                            tir.bind(vi, bx * 64 + ty * 32 + tx + i0)
                                            tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            # Stored with a +8 column offset;
                                            # the loads below compensate with
                                            # elem_offset + 8.
                                            # NOTE(review): vj + 8 can reach
                                            # 1031 while shared_A is declared
                                            # with 1024 columns -- confirm
                                            # this is intended.
                                            shared_A[vi, vj + 8] = A[vi, vj]

                                # B -> shared_B. Each tx handles two rows
                                # (tx * 2 + i0) and 16 columns, with the same
                                # +8 column offset on the store.
                                for i0, j0 in tir.grid(2, 4):
                                    for j1 in tir.vectorized(0, 4):
                                        with tir.block([1024, 1024]) as [vi, vj]:
                                            tir.bind(vi, by * 128 + ty * 64 + tx * 2 + i0)
                                            tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_B[vi, vj + 8] = B[vi, vj]

                            # Reduction tile index is vk = ko * 2 + ki (0..63).
                            for ki in range(0, 2):
                                # Load the two A tiles (i = 0, 1) from
                                # shared_A into wmma_A.
                                for i in range(0, 2):
                                    with tir.block([64, 64]) as [vi, vk]:
                                        tir.bind(vi, bx * 4 + ty * 2 + i)
                                        tir.bind(vk, ko * 2 + ki)
                                        # The declared read region is 16 rows
                                        # by 24 columns: the 16-wide tile plus
                                        # the 8-column offset.
                                        tir.reads(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        tir.writes(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        # Symbolic strides for the shared
                                        # sub-buffer matched below.
                                        s0 = tir.var("int32")
                                        s1 = tir.var("int32")
                                        A0 = tir.match_buffer(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_A0 = tir.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        # Load into tile index i, "row_major",
                                        # starting at A0.elem_offset + 8 and
                                        # using A0's row stride.
                                        tir.evaluate(
                                            tir.tvm_load_matrix_sync(
                                                wmma_A0.data,
                                                16,
                                                16,
                                                16,
                                                i,
                                                tir.tvm_access_ptr(
                                                    tir.type_annotation(dtype="float16"),
                                                    A0.data,
                                                    A0.elem_offset + 8,
                                                    A0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                A0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            )
                                        )
                                # Load the four B tiles (j = 0..3) from
                                # shared_B into wmma_B; this mirrors the A
                                # load but passes "col_major".
                                for j in range(0, 4):
                                    with tir.block([64, 64]) as [vj, vk]:
                                        tir.bind(vj, by * 8 + tz * 4 + j)
                                        tir.bind(vk, ko * 2 + ki)
                                        tir.reads(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        tir.writes(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = tir.var("int32")
                                        s1 = tir.var("int32")
                                        B0 = tir.match_buffer(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_B0 = tir.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        tir.evaluate(
                                            tir.tvm_load_matrix_sync(
                                                wmma_B0.data,
                                                16,
                                                16,
                                                16,
                                                j,
                                                tir.tvm_access_ptr(
                                                    tir.type_annotation(dtype="float16"),
                                                    B0.data,
                                                    B0.elem_offset + 8,
                                                    B0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                B0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            )
                                        )
                                # Multiply-accumulate: for each of the 2x4
                                # output tiles, combine A tile i and B tile j
                                # into accumulator tile i * 4 + j. vk is
                                # declared as a reduce axis.
                                for i, j in tir.grid(2, 4):
                                    with tir.block([64, 64, tir.reduce_axis(0, 64)]) as [
                                        vi,
                                        vj,
                                        vk,
                                    ]:
                                        tir.bind(vi, bx * 4 + ty * 2 + i)
                                        tir.bind(vj, by * 8 + tz * 4 + j)
                                        tir.bind(vk, ko * 2 + ki)
                                        # wmma_C appears in both reads and
                                        # writes because it is accumulated
                                        # in place.
                                        tir.reads(
                                            [
                                                wmma_A[
                                                    vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_B[
                                                    vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_C[
                                                    vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16
                                                ],
                                            ]
                                        )
                                        tir.writes(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]
                                        )
                                        wmma_A1 = tir.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_B1 = tir.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_C1 = tir.match_buffer(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        # The accumulator (wmma_C1, index
                                        # i * 4 + j) is passed as both the
                                        # first and last operand pair.
                                        tir.evaluate(
                                            tir.tvm_mma_sync(
                                                wmma_C1.data,
                                                i * 4 + j,
                                                wmma_A1.data,
                                                i,
                                                wmma_B1.data,
                                                j,
                                                wmma_C1.data,
                                                i * 4 + j,
                                                dtype="handle",
                                            )
                                        )
                        # Stage 3: store each of the 2x4 accumulator tiles to
                        # the matching 16x16 region of the output buffer C.
                        for i, j in tir.grid(2, 4):
                            with tir.block([64, 64]) as [vi, vj]:
                                tir.bind(vi, bx * 4 + ty * 2 + i)
                                tir.bind(vj, by * 8 + tz * 4 + j)
                                tir.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                # Symbolic strides for the output sub-buffer
                                # matched below.
                                s0 = tir.var("int32")
                                s1 = tir.var("int32")
                                wmma_C2 = tir.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                C1 = tir.match_buffer(
                                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[s0, s1],
                                    offset_factor=1,
                                )
                                # Store tile index i * 4 + j, "row_major",
                                # using C1's row stride; no extra offset is
                                # added here, unlike the shared-memory loads.
                                tir.evaluate(
                                    tir.tvm_store_matrix_sync(
                                        wmma_C2.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        tir.tvm_access_ptr(
                                            tir.type_annotation(dtype="float32"),
                                            C1.data,
                                            C1.elem_offset,
                                            C1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        C1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    )
                                )