예제 #1
0
def buffer_opaque_access(b: ty.handle, c: ty.handle) -> None:
    # TensorIR fixture that combines raw load/store/intrinsic accesses with an
    # ordinary element-wise buffer copy.
    # `b` and `c` are matched below as 16x16 float32 buffers B and C.
    B = tir.match_buffer(b, [16, 16], "float32")
    C = tir.match_buffer(c, [16, 16], "float32")

    # Block without iteration variables.  Its access regions are declared by
    # hand: nothing is read and the whole of B is written.
    with tir.block([]):
        tir.reads([])
        tir.writes(B[0:16, 0:16])
        # Flat scratch allocation of 256 float32 elements in "global" scope.
        A = tir.allocate([256], "float32", "global")
        # Store the constant 1 at every flat index i * 16 + j (0..255).
        for i, j in tir.grid(16, 16):
            tir.store(A, i * 16 + j, 1)
        for i in range(0, 16):
            # Load every element of row i of the scratch; the loaded value is
            # only evaluated, not assigned anywhere.
            for j in range(0, 16):
                tir.evaluate(tir.load("float32", A, i * 16 + j))
            # The intrinsic call's arguments do not depend on i or j, so the
            # same call on B.data is issued on every iteration.
            # NOTE(review): the meaning of the 16/16/16/0 arguments (presumably
            # m, n, k and a fragment index) is not shown here -- confirm
            # against the intrinsic's definition.
            for j in range(0, 16):
                tir.evaluate(
                    tir.tvm_fill_fragment(B.data,
                                          16,
                                          16,
                                          16,
                                          0,
                                          tir.float32(0),
                                          dtype="handle"))

    # Second stage: one block instance per (i, j) copies B[vi, vj] to C[vi, vj].
    for i, j in tir.grid(16, 16):
        with tir.block([16, 16]) as [vi, vj]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            C[vi, vj] = B[vi, vj]
예제 #2
0
def opaque_access(a: ty.handle, b: ty.handle) -> None:
    # TensorIR fixture with two blocks that touch their buffers only through
    # the raw data handle (`A.data` / `B.data`), with read/write regions
    # declared explicitly.
    # `a` and `b` are matched below as 16x16 float32 buffers A and B.
    A = tir.match_buffer(a, [16, 16], "float32")
    B = tir.match_buffer(b, [16, 16], "float32")
    # Block "A": iterates a 16x16 domain (vi, vj) with no explicit outer loops.
    # It declares that it reads nothing and writes all of A, then stores the
    # constant 1 at flat index vi * 16 + vj of A's data.
    with tir.block([16, 16], "A") as [vi, vj]:
        tir.reads([])
        tir.writes([A[0:16, 0:16]])
        tir.store(A.data, vi * 16 + vj, 1)
    # Block "B": same 16x16 domain; declares that it writes all of B and calls
    # the fill-fragment intrinsic on B.data.
    # NOTE(review): the last positional argument here is the flat index
    # vi * 16 + vj rather than a fill constant -- presumably chosen only to
    # make the call depend on the block variables; confirm intent.
    with tir.block([16, 16], "B") as [vi, vj]:
        tir.reads([])
        tir.writes([B[0:16, 0:16]])
        tir.evaluate(
            tir.tvm_fill_fragment(B.data,
                                  16,
                                  16,
                                  16,
                                  0,
                                  vi * 16 + vj,
                                  dtype="handle"))
예제 #3
0
def opaque_access_split(a: ty.handle, b: ty.handle) -> None:
    # TensorIR fixture: blocks "A" and "B" with raw-handle accesses, driven by
    # explicit loops in which the second axis is expressed as two loops of
    # extent 4 (j0, j1) recombined as j0 * 4 + j1.
    # NOTE(review): the name suggests this is the expected result of splitting
    # a loop of a simpler function -- confirm against the calling test.
    # `a` and `b` are matched as 16x16 buffers; no dtype is given, so the
    # match_buffer default applies.
    A = tir.match_buffer(a, (16, 16))
    B = tir.match_buffer(b, (16, 16))
    for i, j0, j1 in tir.grid(16, 4, 4):
        with tir.block([16, 16], "A") as [vi, vj]:
            tir.bind(vi, i)
            # Recombine the two inner loops into a single index in 0..15.
            tir.bind(vj, ((j0 * 4) + j1))
            # Regions are declared by hand: reads nothing, writes all of A.
            tir.reads([])
            tir.writes([A[0:16, 0:16]])
            # Store 1 at flat index vi * 16 + vj of A's data.
            # NOTE(review): the trailing 1 is presumably the store predicate
            # (always true) -- confirm.
            tir.store(A.data, ((vi * 16) + vj), 1, 1)
    for i, j0, j1 in tir.grid(16, 4, 4):
        with tir.block([16, 16], "B") as [vi, vj]:
            tir.bind(vi, i)
            tir.bind(vj, ((j0 * 4) + j1))
            # Regions are declared by hand: reads nothing, writes all of B.
            tir.reads([])
            tir.writes([B[0:16, 0:16]])
            # Intrinsic call on B.data; its last positional argument is the
            # flat index vi * 16 + vj.
            tir.evaluate(
                tir.tvm_fill_fragment(B.data,
                                      16,
                                      16,
                                      16,
                                      0, ((vi * 16) + vj),
                                      dtype="handle"))
예제 #4
0
def opaque_access_fused(a: ty.handle, b: ty.handle) -> None:
    # TensorIR fixture: blocks "A" and "B" with raw-handle accesses, each
    # driven by a single serial loop of extent 256 whose index is decomposed
    # into (vi, vj) with floordiv / floormod by 16.
    # NOTE(review): the name suggests this is the expected result of fusing
    # the two loops of a simpler function -- confirm against the calling test.
    # `a` and `b` are matched as 16x16 buffers; no dtype is given, so the
    # match_buffer default applies.
    A = tir.match_buffer(a, [16, 16])
    B = tir.match_buffer(b, [16, 16])
    for i_j_fused in tir.serial(0, 256):
        with tir.block([16, 16], "A") as [vi, vj]:
            # Recover the 2-D coordinates from the single loop index.
            tir.bind(vi, tir.floordiv(i_j_fused, 16))
            tir.bind(vj, tir.floormod(i_j_fused, 16))
            # Regions are declared by hand: reads nothing, writes all of A.
            tir.reads([])
            tir.writes([A[0:16, 0:16]])
            # Store 1 at flat index vi * 16 + vj of A's data.
            # NOTE(review): the trailing 1 is presumably the store predicate
            # (always true) -- confirm.
            tir.store(A.data, ((vi * 16) + vj), 1, 1)
    for i_j_fused in tir.serial(0, 256):
        with tir.block([16, 16], "B") as [vi, vj]:
            tir.bind(vi, tir.floordiv(i_j_fused, 16))
            tir.bind(vj, tir.floormod(i_j_fused, 16))
            # Regions are declared by hand: reads nothing, writes all of B.
            tir.reads([])
            tir.writes([B[0:16, 0:16]])
            # Intrinsic call on B.data; its last positional argument is the
            # flat index vi * 16 + vj.
            tir.evaluate(
                tir.tvm_fill_fragment(B.data,
                                      16,
                                      16,
                                      16,
                                      0, ((vi * 16) + vj),
                                      dtype="handle"))
예제 #5
0
def tensorcore_gemm(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # TensorIR fixture for a 1024x1024 GEMM expressed with wmma-style
    # intrinsics (fill_fragment / load_matrix_sync / mma_sync /
    # store_matrix_sync) on 16x16 tiles.
    #
    # Stages, in program order inside each (blockIdx.x, blockIdx.y) block:
    #   1. zero-fill the accumulator tiles of wmma_C,
    #   2. for each ko: copy slices of A and B into shared_A / shared_B,
    #   3. for each ki: load A and B tiles into wmma_A / wmma_B and issue
    #      the mma on every (i, j) tile pair,
    #   4. store the accumulator tiles into C.
    #
    # NOTE(review): the loop variable `ty` below shadows the `ty` name used in
    # this signature's annotations; kept as-is since the parser consumes the
    # source verbatim.
    # match buffer
    A = tir.match_buffer(a, [1024, 1024], "float16")
    B = tir.match_buffer(b, [1024, 1024], "float16")
    C = tir.match_buffer(c, [1024, 1024], "float32")

    # body
    for blockIdx_x in tir.thread_binding(0, 16, "blockIdx.x"):
        for blockIdx_y in tir.thread_binding(0, 8, "blockIdx.y"):
            with tir.block([16, 8]) as [bx, by]:
                tir.bind(bx, blockIdx_x)
                tir.bind(by, blockIdx_y)
                # Staging buffers, all declared at full 1024x1024 extent: two
                # in "shared" scope and three in the wmma fragment scopes.
                shared_A = tir.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_B = tir.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_A = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_B = tir.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_C = tir.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                for ty in tir.thread_binding(0, 2, "threadIdx.y"):
                    for tz in tir.thread_binding(0, 2, "threadIdx.z"):
                        # Stage 1: initialise the 2x4 accumulator tiles handled
                        # by this (ty, tz) pair.  (vi, vj) are tile coordinates
                        # in a 64x64 grid of 16x16 tiles.
                        for i, j in tir.grid(2, 4):
                            with tir.block([64, 64]) as [vi, vj]:
                                tir.bind(vi, bx * 4 + ty * 2 + i)
                                tir.bind(vj, by * 8 + tz * 4 + j)
                                tir.reads([])
                                tir.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                # View of the 16x16 accumulator tile with
                                # explicit strides [16 * 4, 1].
                                C0 = tir.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                # Fill with float32(0); i * 4 + j is passed as
                                # the per-tile index (0..7).
                                tir.evaluate(
                                    tir.tvm_fill_fragment(
                                        C0.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        tir.float32(0),
                                        dtype="handle",
                                    )
                                )

                        # Stage 2/3: ko runs over 32 steps; together with
                        # ki (2 steps) it forms the tile index vk = ko * 2 + ki
                        # in 0..63.
                        for ko in range(0, 32):
                            # copy data from global to shared
                            for tx in tir.thread_binding(0, 32, "threadIdx.x"):
                                # A copy: the i0 loop has extent 1; columns are
                                # covered by j0 (4) x vectorized j1 (4).
                                for i0, j0 in tir.grid(1, 4):
                                    for j1 in tir.vectorized(0, 4):
                                        with tir.block([1024, 1024]) as [vi, vj]:
                                            tir.bind(vi, bx * 64 + ty * 32 + tx + i0)
                                            tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            # Destination column is shifted by 8
                                            # relative to the source column.
                                            shared_A[vi, vj + 8] = A[vi, vj]

                                # B copy: two rows per tx (i0 in 0..1), same
                                # column decomposition and +8 shift as for A.
                                for i0, j0 in tir.grid(2, 4):
                                    for j1 in tir.vectorized(0, 4):
                                        with tir.block([1024, 1024]) as [vi, vj]:
                                            tir.bind(vi, by * 128 + ty * 64 + tx * 2 + i0)
                                            tir.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_B[vi, vj + 8] = B[vi, vj]

                            for ki in range(0, 2):
                                # Load the two A tiles for this step from
                                # shared_A into wmma_A.
                                for i in range(0, 2):
                                    with tir.block([64, 64]) as [vi, vk]:
                                        tir.bind(vi, bx * 4 + ty * 2 + i)
                                        tir.bind(vk, ko * 2 + ki)
                                        # The declared read region is 16 rows by
                                        # 24 columns (16 + 8) to cover the +8
                                        # column shift used by the copy stage.
                                        tir.reads(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        tir.writes(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        # Symbolic strides for the shared-memory
                                        # view; only strides[0] is used below.
                                        s0 = tir.var("int32")
                                        s1 = tir.var("int32")
                                        A0 = tir.match_buffer(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_A0 = tir.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        # Load into tile index i from a pointer
                                        # that starts 8 elements past the view's
                                        # element offset, using the view's row
                                        # stride, with "row_major" layout.
                                        tir.evaluate(
                                            tir.tvm_load_matrix_sync(
                                                wmma_A0.data,
                                                16,
                                                16,
                                                16,
                                                i,
                                                tir.tvm_access_ptr(
                                                    tir.type_annotation(dtype="float16"),
                                                    A0.data,
                                                    A0.elem_offset + 8,
                                                    A0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                A0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            )
                                        )
                                # Load the four B tiles for this step from
                                # shared_B into wmma_B; same structure as the A
                                # load but with "col_major" layout.
                                for j in range(0, 4):
                                    with tir.block([64, 64]) as [vj, vk]:
                                        tir.bind(vj, by * 8 + tz * 4 + j)
                                        tir.bind(vk, ko * 2 + ki)
                                        tir.reads(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        tir.writes(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = tir.var("int32")
                                        s1 = tir.var("int32")
                                        B0 = tir.match_buffer(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_B0 = tir.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        tir.evaluate(
                                            tir.tvm_load_matrix_sync(
                                                wmma_B0.data,
                                                16,
                                                16,
                                                16,
                                                j,
                                                tir.tvm_access_ptr(
                                                    tir.type_annotation(dtype="float16"),
                                                    B0.data,
                                                    B0.elem_offset + 8,
                                                    B0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                B0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            )
                                        )
                                # Multiply-accumulate every (i, j) tile pair.
                                # vk is declared as a reduction axis of
                                # extent 64.
                                for i, j in tir.grid(2, 4):
                                    with tir.block([64, 64, tir.reduce_axis(0, 64)]) as [
                                        vi,
                                        vj,
                                        vk,
                                    ]:
                                        tir.bind(vi, bx * 4 + ty * 2 + i)
                                        tir.bind(vj, by * 8 + tz * 4 + j)
                                        tir.bind(vk, ko * 2 + ki)
                                        # The accumulator tile appears in both
                                        # reads and writes because it is updated
                                        # in place.
                                        tir.reads(
                                            [
                                                wmma_A[
                                                    vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_B[
                                                    vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_C[
                                                    vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16
                                                ],
                                            ]
                                        )
                                        tir.writes(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]
                                        )
                                        wmma_A1 = tir.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_B1 = tir.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_C1 = tir.match_buffer(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        # The accumulator (wmma_C1, index
                                        # i * 4 + j) is passed as both the first
                                        # and the last operand; A uses index i
                                        # and B uses index j.
                                        tir.evaluate(
                                            tir.tvm_mma_sync(
                                                wmma_C1.data,
                                                i * 4 + j,
                                                wmma_A1.data,
                                                i,
                                                wmma_B1.data,
                                                j,
                                                wmma_C1.data,
                                                i * 4 + j,
                                                dtype="handle",
                                            )
                                        )
                        # Stage 4: write each accumulator tile to the matching
                        # 16x16 region of C.
                        for i, j in tir.grid(2, 4):
                            with tir.block([64, 64]) as [vi, vj]:
                                tir.bind(vi, bx * 4 + ty * 2 + i)
                                tir.bind(vj, by * 8 + tz * 4 + j)
                                tir.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                tir.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                # Symbolic strides for the view of C; only
                                # strides[0] is used below.
                                s0 = tir.var("int32")
                                s1 = tir.var("int32")
                                wmma_C2 = tir.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                C1 = tir.match_buffer(
                                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[s0, s1],
                                    offset_factor=1,
                                )
                                # Store tile index i * 4 + j through a pointer
                                # at the view's element offset, using the
                                # view's row stride, with "row_major" layout.
                                tir.evaluate(
                                    tir.tvm_store_matrix_sync(
                                        wmma_C2.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        tir.tvm_access_ptr(
                                            tir.type_annotation(dtype="float32"),
                                            C1.data,
                                            C1.elem_offset,
                                            C1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        C1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    )
                                )