Exemple #1
0
def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
    # TVMScript function with two stages:
    #   1. block "opaque": touches A and A_cache only through
    #      T.tvm_access_ptr, with explicitly annotated [0:512] regions;
    #   2. block "B": elementwise B[vi] = A_cache[vi] * 2.0 + 1.0 for
    #      vi in [0, 512).
    # NOTE(review): judging by the name, this looks like the expected result
    # of inlining block "BB" of access_opaque_ptr_then_elemwise into "B" --
    # confirm against the test that consumes it.
    #
    # a, b: handles matched below as 1-D float32 buffers of 1024 elements.
    A = T.match_buffer(a, [1024], dtype="float32")
    B = T.match_buffer(b, [1024], dtype="float32")
    # Intermediate buffer allocated inside the function (not a parameter).
    A_cache = T.alloc_buffer([1024], dtype="float32")
    with T.block("opaque"):
        # annotated opaque partial access should be kept
        # The body below reaches the buffers only via T.tvm_access_ptr; the
        # accessed regions (first 512 of 1024 elements) are stated here.
        T.reads(A[0:512])
        T.writes([A_cache[0:512]])
        # Pointer into A: arguments after the data handle are 0, 512 and the
        # access mask "r" (offset/extent per the TVM builtin documentation).
        T.evaluate(
            T.tvm_access_ptr(T.type_annotation(dtype="float32"),
                             A.data,
                             0,
                             512,
                             "r",
                             dtype="handle"))
        # Pointer into A_cache: same 0 / 512 arguments, access mask "w".
        T.evaluate(
            T.tvm_access_ptr(T.type_annotation(dtype="float32"),
                             A_cache.data,
                             0,
                             512,
                             "w",
                             dtype="handle"))
    for i in T.serial(0, 512):
        with T.block("B"):
            # One spatial block iteration per loop iteration.
            vi = T.axis.spatial(512, i)
            T.reads([A_cache[vi]])
            T.writes([B[vi]])
            # Scale-and-shift fused into a single expression.
            B[vi] = A_cache[vi] * 2.0 + 1.0
Exemple #2
0
def access_opaque_ptr_then_elemwise(a: T.handle, b: T.handle) -> None:
    # TVMScript function with three stages:
    #   1. block "opaque": touches A and A_cache only through
    #      T.tvm_access_ptr, with explicitly annotated [0:512] regions;
    #   2. block "BB": BB[vi] = A_cache[vi] * 2.0 for vi in [0, 512);
    #   3. block "B":  B[vi]  = BB[vi] + 1.0     for vi in [0, 512).
    #
    # a, b: handles matched below as 1-D buffers of 1024 elements; no dtype
    # is passed, so whatever T.match_buffer defaults to applies.
    A = T.match_buffer(a, [1024])
    B = T.match_buffer(b, [1024])
    # Two intermediates allocated inside the function: A_cache is written by
    # the opaque block, BB carries the result of the first elementwise stage.
    A_cache = T.alloc_buffer([1024])
    BB = T.alloc_buffer([1024])
    with T.block("opaque"):
        # annotated opaque partial access
        # Only the first 512 of 1024 elements are declared as accessed.
        T.reads(A[0:512])
        T.writes(A_cache[0:512])
        # Pointer into A: arguments after the data handle are 0, 512 and the
        # access mask "r" (offset/extent per the TVM builtin documentation).
        T.evaluate(
            T.tvm_access_ptr(T.type_annotation(dtype="float32"),
                             A.data,
                             0,
                             512,
                             "r",
                             dtype="handle"))
        # Pointer into A_cache: same 0 / 512 arguments, access mask "w".
        T.evaluate(
            T.tvm_access_ptr(T.type_annotation(dtype="float32"),
                             A_cache.data,
                             0,
                             512,
                             "w",
                             dtype="handle"))
    # Unlike the "opaque" block, the two blocks below carry no explicit
    # T.reads/T.writes annotations.
    for i in range(512):
        with T.block("BB"):
            vi = T.axis.remap("S", [i])
            BB[vi] = A_cache[vi] * 2.0
    for i in range(512):
        with T.block("B"):
            vi = T.axis.remap("S", [i])
            B[vi] = BB[vi] + 1.0
Exemple #3
0
def opaque_access_store(a: T.handle, c: T.handle) -> None:
    # TVMScript function with two stages over a 128x128 grid:
    #   1. block "B": B[vi, vj] = A[vi, vj] * 2.0;
    #   2. block "C": takes access pointers to B (read) and C (write) and
    #      then also performs the ordinary store C[vi, vj] = B[vi, vj] + 1.0.
    #
    # a, c: handles matched below as (128, 128) buffers; B is an
    # intermediate allocated inside the function.
    A = T.match_buffer(a, (128, 128))
    B = T.alloc_buffer((128, 128))
    C = T.match_buffer(c, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            # The annotated regions span the full 128x128 extent of both
            # buffers, not just the single (vi, vj) element stored below.
            T.reads(B[0:128, 0:128])
            T.writes(C[0:128, 0:128])
            # Pointer into B: arguments after the data handle are 0, 128 and
            # the access mask "r".
            T.evaluate(
                T.tvm_access_ptr(T.type_annotation(dtype="float32"),
                                 B.data,
                                 0,
                                 128,
                                 "r",
                                 dtype="handle"))
            # Pointer into C: same 0 / 128 arguments, access mask "w".
            T.evaluate(
                T.tvm_access_ptr(T.type_annotation(dtype="float32"),
                                 C.data,
                                 0,
                                 128,
                                 "w",
                                 dtype="handle"))
            # Regular buffer store that follows the two opaque accesses in
            # the same block.
            C[vi, vj] = B[vi, vj] + 1.0
Exemple #4
0
def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle,
                    handle_c: T.handle) -> None:
    # TVMScript description of a GEMM-style kernel (per the function name)
    # assembled from 16x16 fragment intrinsics: tvm_fill_fragment,
    # tvm_load_matrix_sync, tvm_mma_sync and tvm_store_matrix_sync.
    #
    # Parameters:
    #   handle_a, handle_b: handles matched below as [1024, 1024] float16.
    #   handle_c: handle matched below as [1024, 1024] float32; it is only
    #       written, by the final store stage.
    #
    # Stages, in order:
    #   1. fill the accumulator tiles of wmma_c with T.float32(0);
    #   2. for each k_o: copy elements of A and B into shared_a / shared_b,
    #      load 16x16 tiles into wmma_a / wmma_b, and run tvm_mma_sync;
    #   3. store the accumulator tiles into match_buffer_c.
    # pylint: disable=missing-function-docstring
    # match buffer
    match_buffer_a = T.match_buffer(handle_a, [1024, 1024], "float16")
    match_buffer_b = T.match_buffer(handle_b, [1024, 1024], "float16")
    match_buffer_c = T.match_buffer(handle_c, [1024, 1024], "float32")

    # body
    # Outer loops are bound to blockIdx.x (extent 16) and blockIdx.y
    # (extent 8).
    for block_idx_x in T.thread_binding(0, 16, "blockIdx.x"):
        for block_idx_y in T.thread_binding(0, 8, "blockIdx.y"):
            with T.block():
                axis_bx, axis_by = T.axis.remap("SS",
                                                [block_idx_x, block_idx_y])
                # Buffers allocated inside this block: two "shared"-scope
                # staging buffers and three wmma-scope fragment buffers.
                shared_a = T.alloc_buffer([1024, 1024],
                                          "float16",
                                          scope="shared")
                shared_b = T.alloc_buffer([1024, 1024],
                                          "float16",
                                          scope="shared")
                wmma_a = T.alloc_buffer([1024, 1024],
                                        "float16",
                                        scope="wmma.matrix_a")
                wmma_b = T.alloc_buffer([1024, 1024],
                                        "float16",
                                        scope="wmma.matrix_b")
                wmma_c = T.alloc_buffer([1024, 1024],
                                        "float32",
                                        scope="wmma.accumulator")

                # pylint: disable=too-many-nested-blocks
                for thread_ty in T.thread_binding(0, 2, "threadIdx.y"):
                    for thread_tz in T.thread_binding(0, 2, "threadIdx.z"):
                        # Stage 1: initialise the 2x4 accumulator tiles
                        # handled by this (thread_ty, thread_tz) pair.
                        for index_i, index_jj in T.grid(2, 4):
                            with T.block():
                                # Tile coordinates in a 64x64 grid of 16x16
                                # tiles.
                                new_axis_vi = T.axis.S(
                                    64, axis_bx * 4 + thread_ty * 2 + index_i)
                                new_axis_vj = T.axis.S(
                                    64, axis_by * 8 + thread_tz * 4 + index_jj)
                                T.reads([])
                                T.writes(wmma_c[new_axis_vi *
                                                16:new_axis_vi * 16 + 16,
                                                new_axis_vj *
                                                16:new_axis_vj * 16 + 16, ])
                                # View the 16x16 tile of wmma_c as its own
                                # buffer with fixed strides [16 * 4, 1].
                                match_buffer_c0 = T.match_buffer(
                                    wmma_c[new_axis_vi * 16:new_axis_vi * 16 +
                                           16, new_axis_vj *
                                           16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                # Fill fragment number index_i * 4 + index_jj
                                # with T.float32(0). The three 16s are
                                # presumably the fragment m/n/k shape -- see
                                # the TVM builtin documentation.
                                T.evaluate(
                                    T.tvm_fill_fragment(
                                        match_buffer_c0.data,
                                        16,
                                        16,
                                        16,
                                        index_i * 4 + index_jj,
                                        T.float32(0),  # pylint: disable=not-callable
                                        dtype="handle",
                                    ))

                        # Stage 2: iterate over 32 outer steps of the
                        # k dimension (k_o), each covering 32 columns.
                        for k_o in range(0, 32):
                            # copy data from global to shared
                            for thread_tx in T.thread_binding(
                                    0, 32, "threadIdx.x"):
                                # Copy from match_buffer_a into shared_a; the
                                # innermost 4-wide loop is marked vectorized.
                                for index_i0, index_j0 in T.grid(1, 4):
                                    for index_j1 in T.vectorized(0, 4):
                                        with T.block():
                                            new_axis_vi = T.axis.S(
                                                1024,
                                                axis_bx * 64 + thread_ty * 32 +
                                                thread_tx + index_i0,
                                            )
                                            new_axis_vj = T.axis.S(
                                                1024,
                                                k_o * 32 + thread_tz * 16 +
                                                index_j0 * 4 + index_j1,
                                            )
                                            # Destination column is shifted
                                            # by 8 relative to the source.
                                            # NOTE(review): new_axis_vj has
                                            # extent 1024, so vj + 8 can
                                            # exceed the 1024 columns of
                                            # shared_a -- confirm intended.
                                            shared_a[new_axis_vi, new_axis_vj +
                                                     8] = match_buffer_a[
                                                         new_axis_vi,
                                                         new_axis_vj]

                                # Same copy pattern for match_buffer_b into
                                # shared_b, with a 2x4 outer grid.
                                for index_i0, index_j0 in T.grid(2, 4):
                                    for index_j1 in T.vectorized(0, 4):
                                        with T.block():
                                            new_axis_vi = T.axis.S(
                                                1024,
                                                axis_by * 128 +
                                                thread_ty * 64 +
                                                thread_tx * 2 + index_i0,
                                            )
                                            new_axis_vj = T.axis.S(
                                                1024,
                                                k_o * 32 + thread_tz * 16 +
                                                index_j0 * 4 + index_j1,
                                            )
                                            # Same +8 column shift as for
                                            # shared_a (same review note
                                            # about the column bound).
                                            shared_b[new_axis_vi, new_axis_vj +
                                                     8] = match_buffer_b[
                                                         new_axis_vi,
                                                         new_axis_vj]

                            # Two inner k steps per k_o; together they form
                            # the tile index k_o * 2 + k_i used below.
                            for k_i in range(0, 2):
                                # Load the A tiles from shared_a into wmma_a.
                                for index_i in range(0, 2):
                                    with T.block():
                                        new_axis_vi = T.axis.S(
                                            64, axis_bx * 4 + thread_ty * 2 +
                                            index_i)
                                        axis_vk = T.axis.S(64, k_o * 2 + k_i)
                                        # The read window is 16 rows by
                                        # 16 + 8 columns, covering the
                                        # 8-column shift applied on copy.
                                        T.reads(shared_a[new_axis_vi *
                                                         16:new_axis_vi * 16 +
                                                         16, axis_vk *
                                                         16:axis_vk * 16 + 16 +
                                                         8, ])
                                        T.writes(
                                            wmma_a[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ])
                                        # Symbolic strides for the matched
                                        # shared-memory window; only
                                        # strides[0] is read back below.
                                        stride0 = T.var("int32")
                                        stride1 = T.var("int32")
                                        match_buffer_a0 = T.match_buffer(
                                            shared_a[new_axis_vi *
                                                     16:new_axis_vi * 16 + 16,
                                                     axis_vk *
                                                     16:axis_vk * 16 + 16 +
                                                     8, ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[stride0, stride1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_a0 = T.match_buffer(
                                            wmma_a[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        # Load into fragment index_i with
                                        # layout "row_major". The source
                                        # pointer starts at elem_offset + 8,
                                        # i.e. past the 8-element shift.
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_a0.data,
                                                16,
                                                16,
                                                16,
                                                index_i,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(
                                                        dtype="float16"),
                                                    match_buffer_a0.data,
                                                    match_buffer_a0.elem_offset
                                                    + 8,
                                                    match_buffer_a0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                match_buffer_a0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            ))
                                # Load the B tiles from shared_b into wmma_b;
                                # mirrors the A load but uses "col_major".
                                for index_jj in range(0, 4):
                                    with T.block():
                                        new_axis_vj = T.axis.S(
                                            64, axis_by * 8 + thread_tz * 4 +
                                            index_jj)
                                        axis_vk = T.axis.S(64, k_o * 2 + k_i)
                                        T.reads(shared_b[new_axis_vj *
                                                         16:new_axis_vj * 16 +
                                                         16, axis_vk *
                                                         16:axis_vk * 16 + 16 +
                                                         8, ])
                                        T.writes(
                                            wmma_b[new_axis_vj *
                                                   16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ])
                                        # Symbolic strides, as in the A load.
                                        stride0 = T.var("int32")
                                        stride1 = T.var("int32")
                                        match_buffer_b0 = T.match_buffer(
                                            shared_b[new_axis_vj *
                                                     16:new_axis_vj * 16 + 16,
                                                     axis_vk *
                                                     16:axis_vk * 16 + 16 +
                                                     8, ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[stride0, stride1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_b0 = T.match_buffer(
                                            wmma_b[new_axis_vj *
                                                   16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        # Load into fragment index_jj.
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_b0.data,
                                                16,
                                                16,
                                                16,
                                                index_jj,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(
                                                        dtype="float16"),
                                                    match_buffer_b0.data,
                                                    match_buffer_b0.elem_offset
                                                    + 8,
                                                    match_buffer_b0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                match_buffer_b0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            ))
                                # Multiply-accumulate over every (A tile,
                                # B tile) pair loaded above.
                                for index_i, index_jj in T.grid(2, 4):
                                    with T.block():
                                        new_axis_vi = T.axis.S(
                                            64, axis_bx * 4 + thread_ty * 2 +
                                            index_i)
                                        new_axis_vj = T.axis.S(
                                            64, axis_by * 8 + thread_tz * 4 +
                                            index_jj)
                                        # The k tile index is declared with
                                        # T.axis.R here, whereas the load
                                        # blocks declare it with T.axis.S.
                                        axis_vk = T.axis.R(64, k_o * 2 + k_i)
                                        # wmma_c appears in both reads and
                                        # writes: the tile is updated in
                                        # place.
                                        T.reads([
                                            wmma_a[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            wmma_b[new_axis_vj *
                                                   16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            wmma_c[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   new_axis_vj *
                                                   16:new_axis_vj * 16 + 16, ],
                                        ])
                                        T.writes(
                                            wmma_c[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   new_axis_vj *
                                                   16:new_axis_vj * 16 + 16, ])
                                        wmma_a1 = T.match_buffer(
                                            wmma_a[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_b1 = T.match_buffer(
                                            wmma_b[new_axis_vj *
                                                   16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_c1 = T.match_buffer(
                                            wmma_c[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   new_axis_vj *
                                                   16:new_axis_vj * 16 + 16, ],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        # Arguments are (buffer, fragment
                                        # index) pairs; wmma_c1 is passed as
                                        # both the first and the last pair.
                                        T.evaluate(
                                            T.tvm_mma_sync(
                                                wmma_c1.data,
                                                index_i * 4 + index_jj,
                                                wmma_a1.data,
                                                index_i,
                                                wmma_b1.data,
                                                index_jj,
                                                wmma_c1.data,
                                                index_i * 4 + index_jj,
                                                dtype="handle",
                                            ))
                        # Stage 3: write each accumulator tile out to the
                        # matching 16x16 region of match_buffer_c.
                        for index_i, index_jj in T.grid(2, 4):
                            with T.block():
                                new_axis_vi = T.axis.S(
                                    64, axis_bx * 4 + thread_ty * 2 + index_i)
                                new_axis_vj = T.axis.S(
                                    64, axis_by * 8 + thread_tz * 4 + index_jj)
                                T.reads(wmma_c[new_axis_vi *
                                               16:new_axis_vi * 16 + 16,
                                               new_axis_vj *
                                               16:new_axis_vj * 16 + 16, ])
                                T.writes(
                                    match_buffer_c[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   new_axis_vj *
                                                   16:new_axis_vj * 16 + 16, ])
                                # Symbolic strides for the destination tile
                                # of match_buffer_c.
                                stride0 = T.var("int32")
                                stride1 = T.var("int32")
                                wmma_c2 = T.match_buffer(
                                    wmma_c[new_axis_vi * 16:new_axis_vi * 16 +
                                           16, new_axis_vj *
                                           16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                match_buffer_c1 = T.match_buffer(
                                    match_buffer_c[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   new_axis_vj *
                                                   16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[stride0, stride1],
                                    offset_factor=1,
                                )
                                # Store fragment index_i * 4 + index_jj
                                # through a pointer at the destination tile's
                                # elem_offset (no +8 here), "row_major".
                                T.evaluate(
                                    T.tvm_store_matrix_sync(
                                        wmma_c2.data,
                                        16,
                                        16,
                                        16,
                                        index_i * 4 + index_jj,
                                        T.tvm_access_ptr(
                                            T.type_annotation(dtype="float32"),
                                            match_buffer_c1.data,
                                            match_buffer_c1.elem_offset,
                                            match_buffer_c1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        match_buffer_c1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    ))
Exemple #5
0
def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None:
    # TVMScript function working on 16x16 tiles with the tvm_* WMMA intrinsics.
    # Visible phases, per (blockIdx, threadIdx.y, threadIdx.z) combination:
    #   1. fill the wmma_C tiles with float32 zero (tvm_fill_fragment),
    #   2. for each ko: copy A/B into shared-scope buffers, load 16x16 tiles
    #      into wmma_A / wmma_B (tvm_load_matrix_sync), and combine them into
    #      wmma_C (tvm_mma_sync),
    #   3. write the wmma_C tiles out to C (tvm_store_matrix_sync).
    # Arguments:
    #   a, b: handles bound below to float16 [1024, 1024] buffers A and B.
    #   c:    handle bound below to the float32 [1024, 1024] buffer C.
    # Bind the opaque handles to typed, shaped buffers.
    A = T.match_buffer(a, [1024, 1024], "float16")
    B = T.match_buffer(b, [1024, 1024], "float16")
    C = T.match_buffer(c, [1024, 1024], "float32")

    # Outer loops are bound to blockIdx.x (extent 16) and blockIdx.y (extent 8).
    for blockIdx_x in T.thread_binding(0, 16, "blockIdx.x"):
        for blockIdx_y in T.thread_binding(0, 8, "blockIdx.y"):
            with T.block([16, 8]) as [bx, by]:
                T.bind(bx, blockIdx_x)
                T.bind(by, blockIdx_y)
                # Intermediate buffers are allocated inside this block, each
                # declared with the full 1024x1024 shape: two "shared"-scope
                # staging buffers and three WMMA-scope buffers.
                shared_A = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_B = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_A = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_B = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_C = T.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                for ty in T.thread_binding(0, 2, "threadIdx.y"):
                    for tz in T.thread_binding(0, 2, "threadIdx.z"):
                        # Phase 1: zero-initialise the 2x4 grid of 16x16
                        # wmma_C tiles handled by this (ty, tz).
                        for i, j in T.grid(2, 4):
                            with T.block([64, 64]) as [vi, vj]:
                                # vi / vj are tile indices in [0, 64).
                                T.bind(vi, bx * 4 + ty * 2 + i)
                                T.bind(vj, by * 8 + tz * 4 + j)
                                T.reads([])
                                T.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                # 16x16 view of the tile being initialised.
                                C0 = T.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                # i * 4 + j selects one of the 8 per-(ty, tz)
                                # tiles; the same index expression is reused
                                # by the mma and store calls further down.
                                T.evaluate(
                                    T.tvm_fill_fragment(
                                        C0.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        T.float32(0),
                                        dtype="handle",
                                    )
                                )

                        # Phase 2: 32 outer steps; each ko covers 32 columns
                        # (two 16-wide vk tiles, see ki below).
                        for ko in range(0, 32):
                            # Copy elements of A and B into the shared-scope
                            # buffers; the loop over tx is bound to threadIdx.x.
                            for tx in T.thread_binding(0, 32, "threadIdx.x"):
                                for i0, j0 in T.grid(1, 4):
                                    for j1 in T.vectorized(0, 4):
                                        with T.block([1024, 1024]) as [vi, vj]:
                                            T.bind(vi, bx * 64 + ty * 32 + tx + i0)
                                            T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            # The destination column is shifted
                                            # by 8 relative to the source.
                                            # NOTE(review): vj + 8 can reach
                                            # 1031 (vj max is 1023), beyond the
                                            # declared 1024-column extent of
                                            # shared_A - confirm intended.
                                            shared_A[vi, vj + 8] = A[vi, vj]

                                for i0, j0 in T.grid(2, 4):
                                    for j1 in T.vectorized(0, 4):
                                        with T.block([1024, 1024]) as [vi, vj]:
                                            T.bind(vi, by * 128 + ty * 64 + tx * 2 + i0)
                                            T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            # Same +8 column shift as for
                                            # shared_A (same review note about
                                            # the 1024-column extent applies).
                                            shared_B[vi, vj + 8] = B[vi, vj]

                            for ki in range(0, 2):
                                # Load the two A tiles (i = 0, 1) for this vk
                                # from shared_A into wmma_A.
                                for i in range(0, 2):
                                    with T.block([64, 64]) as [vi, vk]:
                                        T.bind(vi, bx * 4 + ty * 2 + i)
                                        T.bind(vk, ko * 2 + ki)
                                        # The read region is 16 + 8 columns
                                        # wide so that it covers the data that
                                        # was written with the +8 column shift.
                                        T.reads(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        T.writes(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        # s0 / s1: fresh int32 variables used
                                        # as the strides of A0 below.
                                        s0 = T.var("int32")
                                        s1 = T.var("int32")
                                        A0 = T.match_buffer(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_A0 = T.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        # The source pointer starts 8 elements
                                        # past A0's own offset (matching the
                                        # +8 column shift above); A0.strides[0]
                                        # is passed both inside the access
                                        # pointer and as the argument after it.
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_A0.data,
                                                16,
                                                16,
                                                16,
                                                i,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(dtype="float16"),
                                                    A0.data,
                                                    A0.elem_offset + 8,
                                                    A0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                A0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            )
                                        )
                                # Load the four B tiles (j = 0..3) for this vk
                                # from shared_B into wmma_B. Same structure as
                                # the A load, but the layout string passed is
                                # "col_major" instead of "row_major".
                                for j in range(0, 4):
                                    with T.block([64, 64]) as [vj, vk]:
                                        T.bind(vj, by * 8 + tz * 4 + j)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        T.writes(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        # s0 / s1: fresh int32 variables used
                                        # as the strides of B0 below.
                                        s0 = T.var("int32")
                                        s1 = T.var("int32")
                                        B0 = T.match_buffer(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_B0 = T.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_B0.data,
                                                16,
                                                16,
                                                16,
                                                j,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(dtype="float16"),
                                                    B0.data,
                                                    B0.elem_offset + 8,
                                                    B0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                B0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            )
                                        )
                                # Combine every (i, j) pair of loaded tiles
                                # into the matching wmma_C tile. vk is declared
                                # as a reduction axis of this block.
                                for i, j in T.grid(2, 4):
                                    with T.block([64, 64, T.reduce_axis(0, 64)]) as [
                                        vi,
                                        vj,
                                        vk,
                                    ]:
                                        T.bind(vi, bx * 4 + ty * 2 + i)
                                        T.bind(vj, by * 8 + tz * 4 + j)
                                        T.bind(vk, ko * 2 + ki)
                                        # wmma_C is listed in both reads and
                                        # writes: the same tile is passed to
                                        # tvm_mma_sync as first and last buffer.
                                        T.reads(
                                            [
                                                wmma_A[
                                                    vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_B[
                                                    vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_C[
                                                    vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16
                                                ],
                                            ]
                                        )
                                        T.writes(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]
                                        )
                                        wmma_A1 = T.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_B1 = T.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_C1 = T.match_buffer(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        # Buffer/index pairs, in order:
                                        # (C, i*4+j), (A, i), (B, j), (C, i*4+j).
                                        T.evaluate(
                                            T.tvm_mma_sync(
                                                wmma_C1.data,
                                                i * 4 + j,
                                                wmma_A1.data,
                                                i,
                                                wmma_B1.data,
                                                j,
                                                wmma_C1.data,
                                                i * 4 + j,
                                                dtype="handle",
                                            )
                                        )
                        # Phase 3: after the ko loop, write each wmma_C tile
                        # to the corresponding 16x16 region of C.
                        for i, j in T.grid(2, 4):
                            with T.block([64, 64]) as [vi, vj]:
                                T.bind(vi, bx * 4 + ty * 2 + i)
                                T.bind(vj, by * 8 + tz * 4 + j)
                                T.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                # s0 / s1: fresh int32 variables used as the
                                # strides of the C1 view below.
                                s0 = T.var("int32")
                                s1 = T.var("int32")
                                wmma_C2 = T.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                C1 = T.match_buffer(
                                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[s0, s1],
                                    offset_factor=1,
                                )
                                # NOTE(review): the access pointer into C1 is
                                # built with mask value 1, the same value used
                                # for the load pointers above, although C is
                                # the region declared in T.writes - confirm.
                                T.evaluate(
                                    T.tvm_store_matrix_sync(
                                        wmma_C2.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        T.tvm_access_ptr(
                                            T.type_annotation(dtype="float32"),
                                            C1.data,
                                            C1.elem_offset,
                                            C1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        C1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    )
                                )
Exemple #6
0
def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    # TVMScript function with three separate loop nests over 128x128 float16
    # buffers, each touching buffer data in a different way:
    #   "load_store"   - ordinary element indexing (A[vi, vj] -> D[vi, vj]),
    #   "opaque"       - raw .data pointers with a hand-computed offset,
    #   "match_buffer" - 16x16 sub-views created with T.match_buffer.
    # Arguments:
    #   a, b, c, d: handles bound below to the buffers A, B, C and D.
    A = T.match_buffer(a, (128, 128), dtype="float16")
    B = T.match_buffer(b, (128, 128), dtype="float16")
    C = T.match_buffer(c, (128, 128), dtype="float16")
    D = T.match_buffer(d, (128, 128), dtype="float16")

    # Element-wise copy of A into D; both axes are remapped as spatial ("SS").
    for i, j in T.grid(128, 128):
        with T.block("load_store"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi, vj])
            T.writes(D[vi, vj])
            D[vi, vj] = A[vi, vj]
    # One block instance per 16x16 tile (8 x 8 tiles). The intrinsic only
    # receives data pointers, so the touched regions are stated explicitly
    # through T.reads / T.writes.
    for i, j in T.grid(8, 8):
        with T.block("opaque"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            # vi * 2048 + vj * 16 == (vi * 16) * 128 + vj * 16, i.e. the flat
            # index of element (vi * 16, vj * 16) assuming a compact row-major
            # 128-column layout. vi * 8 + vj numbers the 64 tiles.
            T.evaluate(
                T.tvm_load_matrix_sync(
                    B.data,
                    16,
                    16,
                    16,
                    vi * 8 + vj,
                    T.tvm_access_ptr(
                        T.type_annotation(dtype="float16"),
                        A.data,
                        vi * 2048 + vj * 16,
                        128,
                        1,
                        dtype="handle",
                    ),
                    128,
                    "row_major",
                    dtype="handle",
                )
            )
    # Same tiling as the "opaque" block, but the tile of A and the tile of C
    # are first bound to 16x16 sub-buffers, and the pointer offset / stride
    # are taken from the sub-buffer A0 instead of being written out by hand.
    for i, j in T.grid(8, 8):
        with T.block("match_buffer"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            # Source tile view; strides [128, 1] match the 128-column parent.
            A0 = T.match_buffer(
                A[
                    vi * 16 : vi * 16 + 16,
                    vj * 16 : vj * 16 + 16,
                ],
                (16, 16),
                "float16",
                strides=[128, 1],
                offset_factor=1,
            )
            # Destination tile view with the same shape and strides.
            C0 = T.match_buffer(
                C[
                    vi * 16 : vi * 16 + 16,
                    vj * 16 : vj * 16 + 16,
                ],
                (16, 16),
                "float16",
                strides=[128, 1],
                offset_factor=1,
            )
            # The argument after the access pointer is the literal 128, which
            # equals the declared A0.strides[0].
            T.evaluate(
                T.tvm_load_matrix_sync(
                    C0.data,
                    16,
                    16,
                    16,
                    vi * 8 + vj,
                    T.tvm_access_ptr(
                        T.type_annotation(dtype="float16"),
                        A0.data,
                        A0.elem_offset,
                        A0.strides[0],
                        1,
                        dtype="handle",
                    ),
                    128,
                    "row_major",
                    dtype="handle",
                )
            )
Exemple #7
0
def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    # TVMScript function with three blocks over 128x128 float16 buffers, each
    # touching buffer data in a different way:
    #   "load_store"   - flat indexing through the raw .data pointers,
    #   "opaque"       - an intrinsic call fed a hand-computed pointer offset,
    #   "match_buffer" - 16x16 sub-views created with T.match_buffer.
    # Block iteration variables are declared directly on T.block([...]).
    # Arguments:
    #   a, b, c, d: handles bound below to the buffers A, B, C and D.
    A = T.match_buffer(a, (128, 128), dtype="float16")
    B = T.match_buffer(b, (128, 128), dtype="float16")
    C = T.match_buffer(c, (128, 128), dtype="float16")
    D = T.match_buffer(d, (128, 128), dtype="float16")

    # Copy one element of A to D per block instance. The access goes through
    # A.data / D.data with the flat index vi * 128 + vj, so the read and write
    # regions are declared explicitly.
    with T.block([128, 128], "load_store") as [vi, vj]:
        T.reads(A[vi, vj])
        T.writes(D[vi, vj])
        D.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj)
    # One block instance per 16x16 tile (8 x 8 tiles); the intrinsic only
    # receives data pointers, so T.reads / T.writes state the touched regions.
    with T.block([8, 8], "opaque") as [vi, vj]:
        T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        # vi * 2048 + vj * 16 == (vi * 16) * 128 + vj * 16, i.e. the flat
        # index of element (vi * 16, vj * 16) assuming a compact row-major
        # 128-column layout. vi * 8 + vj numbers the 64 tiles.
        T.evaluate(
            T.tvm_load_matrix_sync(
                B.data,
                16,
                16,
                16,
                vi * 8 + vj,
                T.tvm_access_ptr(
                    T.type_annotation(dtype="float16"),
                    A.data,
                    vi * 2048 + vj * 16,
                    128,
                    1,
                    dtype="handle",
                ),
                128,
                "row_major",
                dtype="handle",
            )
        )
    # Same tiling as the "opaque" block, but the tile of A and the tile of C
    # are first bound to 16x16 sub-buffers, and the pointer offset / stride
    # are taken from the sub-buffer A0 instead of being written out by hand.
    with T.block([8, 8], "match_buffer") as [vi, vj]:
        T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        # Source tile view; strides [128, 1] match the 128-column parent.
        A0 = T.match_buffer(
            A[
                vi * 16 : vi * 16 + 16,
                vj * 16 : vj * 16 + 16,
            ],
            (16, 16),
            "float16",
            strides=[128, 1],
            offset_factor=1,
        )
        # Destination tile view with the same shape and strides.
        C0 = T.match_buffer(
            C[
                vi * 16 : vi * 16 + 16,
                vj * 16 : vj * 16 + 16,
            ],
            (16, 16),
            "float16",
            strides=[128, 1],
            offset_factor=1,
        )
        # The argument after the access pointer is the literal 128, which
        # equals the declared A0.strides[0].
        T.evaluate(
            T.tvm_load_matrix_sync(
                C0.data,
                16,
                16,
                16,
                vi * 8 + vj,
                T.tvm_access_ptr(
                    T.type_annotation(dtype="float16"),
                    A0.data,
                    A0.elem_offset,
                    A0.strides[0],
                    1,
                    dtype="handle",
                ),
                128,
                "row_major",
                dtype="handle",
            )
        )