Esempio n. 1
0
    def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
        # Implementation body that declares a read of a shared-scope buffer and
        # a write of a warp-scope buffer, and performs it with one
        # T.ptx_ldmatrix call.
        # shmem_shape, dtype, shared_scope, row_dim, col_dim, local_size,
        # WARP_SIZE, ldmatrix_col_major, lift and shared_offset come from an
        # enclosing scope that is not visible in this excerpt.

        # Symbolic strides for the shared buffer; s0 is also forwarded to
        # shared_offset below.
        s0 = T.var("int32")
        s1 = T.var("int32")
        shared = T.match_buffer(
            shared_handle,
            shmem_shape,
            dtype,
            align=128,
            offset_factor=16,
            scope=shared_scope,
            strides=[s0, s1],
        )
        warp = T.match_buffer(warp_handle, (WARP_SIZE, local_size),
                              dtype,
                              align=128,
                              offset_factor=16,
                              scope="warp")

        with T.block("root"):
            T.reads(shared[0:row_dim, 0:col_dim])
            T.writes(warp[0:WARP_SIZE, 0:local_size])
            # Bind threadIdx.x with extent WARP_SIZE.
            tx = T.env_thread("threadIdx.x")
            T.launch_thread(tx, WARP_SIZE)

            T.evaluate(
                T.ptx_ldmatrix(
                    ldmatrix_col_major,
                    4,  # Always load 4 matrices
                    ".b16",
                    warp.data,
                    # Each thread addresses its own local_size-element slice
                    # of the warp buffer.
                    warp.elem_offset + lift(local_size) * tx,
                    shared.access_ptr("r"),
                    shared_offset(tx, s0),
                    dtype=dtype,
                ))
Esempio n. 2
0
def ptx_ldmatrix(A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(16, 16),
                                                               "float16"],
                 num: T.int32, trans: T.uint8) -> None:
    # Copies the 16x16 float16 buffer A into a shared-scope buffer, issues
    # T.ptx_ldmatrix from that buffer into an 8-element local buffer, and then
    # writes the local values out to B.
    # `num` and `trans` are forwarded unchanged as the second and first
    # arguments of T.ptx_ldmatrix.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    # One block of 32 threads.
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([16, 16], "float16", scope="shared")
        A_local = T.alloc_buffer([8], "float16", scope="local")

        # 32 threads x 8 iterations cover all 256 elements of A.
        for i in range(8):
            A_shared[i * 2 + tx // 16, tx % 16] = A[i * 2 + tx // 16, tx % 16]

        T.evaluate(
            T.ptx_ldmatrix(
                trans,
                num,
                ".b16",
                A_local.data,
                0,
                A_shared.data,
                # Per-thread element offset into A_shared.
                16 * (tx % 16) + 8 * (tx // 16),
                dtype="float16",
            ))

        # Write the 8 local values of each thread to B: k selects the 8-column
        # half, j the 8-row half, i the element within a pair.
        for k in range(2):
            for j in range(2):
                for i in range(2):
                    B[8 * j + tx // 4,
                      8 * k + (tx % 4) * 2 + i] = A_local[4 * k + 2 * j + i]
    def before(A: T.Buffer[(128, 16), "float32"], n: T.int32):
        # Binds threadIdx.x with extent 128 and writes 0.0 to A[i, j] for every
        # j in [0, 16), but only when the thread index i is below 32.
        # The `i < 32` check sits inside the serial j loop.
        # `n` is not used in this body.
        i = T.env_thread("threadIdx.x")
        T.launch_thread(i, 128)

        for j in T.serial(16):
            if i < 32:
                A[i, j] = 0.0
 def main(inputs: T.Buffer[(8192,), "float32"], weight: T.Buffer[(2097152,), "float32"], conv2d_transpose_nhwc: T.Buffer[(16384,), "float32"]) -> None:
     # Kernel over flat 1-D buffers; the T.preflattened_buffer calls below
     # record the multi-dimensional shapes the flat buffers stand for.
     # Accumulates into an 8-element local buffer from two shared-scope
     # staging buffers and finally writes the result to conv2d_transpose_nhwc.
     # NOTE(review): judging by the buffer names this is a transposed 2-D
     # convolution in NHWC layout -- confirm against the originating workload.
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # var definition
     threadIdx_x = T.env_thread("threadIdx.x")
     blockIdx_x = T.env_thread("blockIdx.x")
     T.preflattened_buffer(inputs, [1, 4, 4, 512], dtype="float32", data=inputs.data)
     T.preflattened_buffer(weight, [4, 4, 512, 256], dtype="float32", data=weight.data)
     T.preflattened_buffer(conv2d_transpose_nhwc, [1, 8, 8, 256], dtype="float32", data=conv2d_transpose_nhwc.data)
     # body
     # 64 blocks of 32 threads each.
     T.launch_thread(blockIdx_x, 64)
     conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local")
     PadInput_shared = T.allocate([768], "float32", "shared")
     weight_shared = T.allocate([4096], "float32", "shared")
     T.launch_thread(threadIdx_x, 32)
     # Zero the 8 local accumulators.
     for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2):
         conv2d_transpose_nhwc_local[i1_4_init * 4 + i2_3_init * 2 + i2_4_init] = T.float32(0)
     for i6_0 in T.serial(16):
         # Fill PadInput_shared; positions failing the bounds predicate get 0.
         for ax0_ax1_ax2_ax3_fused_0 in T.serial(24):
             PadInput_shared[ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x] = T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, inputs[blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560], T.float32(0), dtype="float32")
         # Fill weight_shared four consecutive elements at a time (T.ramp with
         # stride 1 and 4 lanes).
         for ax0_ax1_ax2_ax3_fused_0 in T.serial(32):
             weight_shared[T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4)] = weight[T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + ax0_ax1_ax2_ax3_fused_0 % 2 * 4096 + threadIdx_x // 2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4)]
         # Multiply-accumulate; the input term is replaced by 0 unless both
         # (i1_4 + i4_2) and (i2_4 + i5_2) are even.
         for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2):
             conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] = conv2d_transpose_nhwc_local[i1_4 * 4 + i2_3 * 2 + i2_4] + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, PadInput_shared[threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2], T.float32(0), dtype="float32") * weight_shared[i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024]
     # Write the 8 local accumulators back to the flat output buffer.
     for ax1, ax2 in T.grid(2, 4):
         conv2d_transpose_nhwc[threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8] = conv2d_transpose_nhwc_local[ax1 * 4 + ax2]
Esempio n. 5
0
def ptx_cp_async(A: T.Buffer[(32, 128), "float16"],
                 B: T.Buffer[(32, 128), "float16"]) -> None:
    # Copies A into a shared-scope buffer with T.ptx_cp_async, commits and
    # waits on the copy group, then copies the shared buffer into B.
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    # One block of 32 threads; thread tx handles row tx.
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared")
        T.reads(A[0:32, 0:128])
        T.writes(B[0:32, 0:128])

        # 16 calls per thread; source and destination use the same element
        # offset tx * 128 + 8 * i, and the fifth argument is 16 in every call.
        for i in range(16):
            T.evaluate(
                T.ptx_cp_async(A_shared.data,
                               tx * 128 + 8 * i,
                               A.data,
                               tx * 128 + 8 * i,
                               16,
                               dtype="float16"))

        # TODO(masahi): Remove dtype requirement from TVMScript parser
        T.evaluate(T.ptx_commit_group(dtype="float16"))
        T.evaluate(T.ptx_wait_group(0, dtype="float16"))

        # Copy this thread's row from shared memory to the output.
        for i in range(128):
            B[tx, i] = A_shared[tx, i]
Esempio n. 6
0
def ptx_global_to_shared_dyn_copy_fp16x8(
    A: T.Buffer[(32, 128), "float16"],
    B: T.Buffer[(32, 128), "float16"],
    C: T.Buffer[(32, 128), "float16"],
) -> None:
    # Copies A and B into two "shared.dyn" buffers under an "async_scope"
    # attribute, commits and waits on the copy group, then writes the
    # element-wise sum of the two shared buffers to C.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    # One block of 32 threads; thread tx handles row tx.
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
        B_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
        T.reads(A[0:32, 0:128], B[0:32, 0:128])
        T.writes(C[0:32, 0:128])

        T.attr("default", "async_scope", 1)
        # Plain buffer copies: 16 serial steps of 8 vectorized elements.
        for i in T.serial(16):
            for j in T.vectorized(8):
                A_shared[tx, i * 8 + j] = A[tx, i * 8 + j]
                B_shared[tx, i * 8 + j] = B[tx, i * 8 + j]

        T.evaluate(T.ptx_commit_group(dtype=""))
        T.evaluate(T.ptx_wait_group(0, dtype=""))

        for i in range(128):
            C[tx, i] = A_shared[tx, i] + B_shared[tx, i]
Esempio n. 7
0
def unified_element_wise_thread_x(a: T.handle, b: T.handle,
                                  c: T.handle) -> None:
    # Two element-wise loops that share the single thread variable thread_x:
    # B = A * 2.0 and C = A + 1.0, addressed through flat T.store / T.load
    # indices.
    thread_x = T.env_thread("threadIdx.x")
    block_x = T.env_thread("blockIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    T.launch_thread(block_x, 128)
    # First loop: thread_x is launched as a `with` scope of extent 4.
    with T.launch_thread(thread_x, 4):
        for j0_1 in T.serial(0, 32):
            T.store(
                B.data,
                block_x * 128 + thread_x * 32 + j0_1,
                T.load("float32", A.data, block_x * 128 + thread_x * 32 + j0_1)
                * 2.0,
                True,
            )
    # Second loop: the same thread_x variable is launched again, extent 4.
    T.launch_thread(thread_x, 4)
    for j1_1 in T.serial(0, 32):
        T.store(
            C.data,
            block_x * 128 + thread_x * 32 + j1_1,
            T.load("float32", A.data, block_x * 128 + thread_x * 32 + j1_1) +
            1.0,
            True,
        )
 def main(a: T.handle, b: T.handle) -> None:
     # Kernel written with explicit T.store / T.load on flat indices.
     # Stages slices of A through a 512-element shared buffer and an 8-element
     # local buffer, accumulates into B_local, and writes the result to B.
     # B_local is allocated with 6400000 float32 elements although only
     # indices 0..63 are accessed in this body.
     # NOTE(review): the oversized allocation looks deliberate (e.g. to trip a
     # memory-limit check) -- confirm before shrinking it.
     # function attr dict
     # The no-alias attribute key is "tir.noalias", as used by the other
     # kernels in this file; the previous "T.noalias" spelling is not a key
     # TVM looks up.
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # var definition
     threadIdx_x = T.env_thread("threadIdx.x")
     threadIdx_y = T.env_thread("threadIdx.y")
     blockIdx_x = T.env_thread("blockIdx.x")
     blockIdx_y = T.env_thread("blockIdx.y")
     blockIdx_z = T.env_thread("blockIdx.z")
     A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32")
     B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32")
     # body
     T.launch_thread(blockIdx_z, 196)
     B_local = T.allocate([6400000], "float32", "local")
     Apad_shared = T.allocate([512], "float32", "shared")
     Apad_shared_local = T.allocate([8], "float32", "local")
     T.launch_thread(blockIdx_y, 8)
     T.launch_thread(blockIdx_x, 4)
     T.launch_thread(threadIdx_y, 8)
     T.launch_thread(threadIdx_x, 8)
     # Zero the 8x8 accumulator tile.
     for ff_c_init, nn_c_init in T.grid(8, 8):
         T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True)
     for rc_outer, ry, rx in T.grid(32, 3, 3):
         # Load A into shared memory four lanes at a time; lanes failing the
         # bounds predicate are filled with 0.
         for ax3_inner_outer in T.serial(0, 2):
             T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4))
         for rc_inner in T.serial(0, 8):
             # Shared -> local staging of 8 values.
             for ax3 in T.serial(0, 8):
                 T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True)
             # Accumulate the staged values into the 8x8 tile.
             for ff_c, nn_c in T.grid(8, 8):
                 T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True)
     # Write the accumulator tile back to B.
     for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
         T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)  # fmt: on
Esempio n. 9
0
    def mma_store_impl(a: T.handle, c: T.handle) -> None:
        # Implementation body that declares a read of a warp-scope buffer and
        # a write of an M_DIM x N_DIM global-scope buffer, and performs it
        # with one T.mma_store call.
        # WARP_SIZE, local_size, M_DIM, N_DIM and dtype come from an enclosing
        # scope that is not visible in this excerpt.

        # Symbolic strides of the destination; s0 is passed to T.mma_store.
        s0 = T.var("int32")
        s1 = T.var("int32")

        C_warp = T.match_buffer(a, [WARP_SIZE, local_size],
                                dtype=dtype,
                                scope="warp",
                                offset_factor=1)
        C = T.match_buffer(c, [M_DIM, N_DIM],
                           dtype=dtype,
                           scope="global",
                           offset_factor=1,
                           strides=[s0, s1])

        with T.block("root"):
            T.reads(C_warp[0:WARP_SIZE, 0:local_size])
            T.writes(C[0:M_DIM, 0:N_DIM])
            # Bind threadIdx.x with extent WARP_SIZE.
            tx = T.env_thread("threadIdx.x")
            T.launch_thread(tx, WARP_SIZE)

            T.evaluate(
                T.mma_store(
                    M_DIM,
                    N_DIM,
                    C.access_ptr("w"),
                    C_warp.data,
                    C_warp.elem_offset,
                    s0,
                    dtype=dtype,
                ))
    def before(A: T.Buffer[(128, 16), "float32"], n: T.int32):
        # Launches the env thread thread_x on threadIdx.x (extent 128) and then
        # also binds loop variable i to threadIdx.x with the same extent.
        # Rows with i < 32 are filled with 0.0.
        # `n` is not used in this body.
        thread_x = T.env_thread("threadIdx.x")
        T.launch_thread(thread_x, 128)

        for i in T.thread_binding(0, 128, thread="threadIdx.x"):
            if i < 32:
                for j in T.serial(16):
                    A[i, j] = 0.0
    def expected(A: T.Buffer[(128, 16), "float32"], n: T.int32):
        # Launches thread_x on threadIdx.x (extent 128); the `n == 0` check
        # sits outside the thread-bound loop, so when it holds every row of A
        # is filled with 0.0.
        thread_x = T.env_thread("threadIdx.x")

        T.launch_thread(thread_x, 128)
        if n == 0:
            for i in T.thread_binding(0, 128, thread="threadIdx.x"):
                for j in T.serial(16):
                    A[i, j] = 0.0
Esempio n. 12
0
File: cuda.py Progetto: were/tvm
    def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
        # Implementation body that issues two T.ptx_mma calls on warp-scope
        # buffers A, B and C. C is declared as both read and written.
        # WARP_SIZE, local_size, local_size_out, in_dtype, out_dtype,
        # mma_prefix, in_dtype_abbrv, out_dtype_abbrv and lift come from an
        # enclosing scope that is not visible in this excerpt.
        A = T.match_buffer(
            a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
        )
        B = T.match_buffer(
            b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
        )
        C = T.match_buffer(
            c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
        )

        with T.block("root"):
            T.reads(
                C[0:WARP_SIZE, 0:local_size_out],
                A[0:WARP_SIZE, 0:local_size],
                B[0:WARP_SIZE, 0:local_size],
            )
            T.writes(C[0:WARP_SIZE, 0:local_size_out])
            # Bind threadIdx.x with extent WARP_SIZE.
            tx = T.env_thread("threadIdx.x")
            T.launch_thread(tx, WARP_SIZE)

            # First call: every operand starts at the beginning of the
            # thread's slice.
            T.evaluate(
                T.ptx_mma(
                    mma_prefix,
                    "row",
                    "col",
                    in_dtype_abbrv,
                    in_dtype_abbrv,
                    out_dtype_abbrv,
                    A.data,
                    A.elem_offset + tx * lift(local_size),
                    B.data,
                    B.elem_offset + tx * lift(local_size),
                    C.data,
                    C.elem_offset + tx * lift(local_size_out),
                    False,
                    dtype=out_dtype,
                )
            )

            # Second call: same A offset, but B and C are advanced by half of
            # their per-thread slice.
            T.evaluate(
                T.ptx_mma(
                    mma_prefix,
                    "row",
                    "col",
                    in_dtype_abbrv,
                    in_dtype_abbrv,
                    out_dtype_abbrv,
                    A.data,
                    A.elem_offset + tx * lift(local_size),
                    B.data,
                    B.elem_offset + tx * lift(local_size) + lift(local_size) // 2,
                    C.data,
                    C.elem_offset + tx * lift(local_size_out) + lift(local_size_out) // 2,
                    False,
                    dtype=out_dtype,
                )
            )
 def main(a: T.handle, b: T.handle) -> None:
     # Kernel over flat 1-D buffers A and B, written with buffer indexing.
     # Stages slices of A through the shared buffer Apad_shared and an
     # 8-element local buffer, accumulates into the 64-element B_local, and
     # writes the result to B.
     # Apad_shared is allocated with 512000 float32 elements; see the comment
     # at the store to its last element below.
     # function attr dict
     # The no-alias attribute key is "tir.noalias", as used by the other
     # kernels in this file; the previous "T.noalias" spelling is not a key
     # TVM looks up.
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # var definition
     threadIdx_x = T.env_thread("threadIdx.x")
     threadIdx_y = T.env_thread("threadIdx.y")
     blockIdx_x = T.env_thread("blockIdx.x")
     blockIdx_y = T.env_thread("blockIdx.y")
     blockIdx_z = T.env_thread("blockIdx.z")
     A = T.match_buffer(a, [14 * 14 * 256 * 256], dtype="float32")
     B = T.match_buffer(b, [14 * 14 * 512 * 256], dtype="float32")
     # body
     T.launch_thread(blockIdx_z, 196)
     B_local = T.allocate([64], "float32", "local")
     Apad_shared = T.allocate([512000], "float32", "shared")
     Apad_shared_local = T.allocate([8], "float32", "local")
     T.launch_thread(blockIdx_y, 8)
     T.launch_thread(blockIdx_x, 4)
     T.launch_thread(threadIdx_y, 8)
     T.launch_thread(threadIdx_x, 8)
     # Zero the 8x8 accumulator tile.
     for ff_c_init, nn_c_init in T.grid(8, 8):
         B_local[ff_c_init * 8 + nn_c_init] = T.float32(0)
     for rc_outer, ry, rx in T.grid(32, 3, 3):
         # Load A into shared memory four lanes at a time; lanes failing the
         # bounds predicate are filled with 0.
         for ax3_inner_outer in T.serial(0, 2):
             Apad_shared[T.ramp(
                 threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4,
                 1, 4)] = T.if_then_else(
                     1 <= blockIdx_z // 14 + ry
                     and blockIdx_z // 14 + ry < 15
                     and 1 <= rx + blockIdx_z % 14
                     and rx + blockIdx_z % 14 < 15,
                     A[T.ramp(
                         ry * 917504 + blockIdx_z * 65536 + rx * 65536 +
                         rc_outer * 2048 + threadIdx_y * 256 +
                         blockIdx_x * 64 + threadIdx_x * 8 +
                         ax3_inner_outer * 4 - 983040, 1, 4)],
                     T.broadcast(T.float32(0), 4),
                     dtype="float32x4",
                 )
             # Access of the last element of Apad_shared prevents
             # buffer compacting from reducing the amount of shared
             # memory used.
             Apad_shared[512000 - 1] = 0.0
         for rc_inner in T.serial(0, 8):
             # Shared -> local staging of 8 values.
             for ax3 in T.serial(0, 8):
                 Apad_shared_local[ax3] = Apad_shared[rc_inner * 64 +
                                                      threadIdx_x * 8 + ax3]
             # Accumulate the staged values into the 8x8 tile.
             for ff_c, nn_c in T.grid(8, 8):
                 B_local[ff_c * 8 +
                         nn_c] = B_local[ff_c * 8 +
                                         nn_c] + Apad_shared_local[nn_c]
     # Write the accumulator tile back to B.
     for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8):
         B[blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 +
           ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 +
           nn_inner_inner_inner] = B_local[ff_inner_inner_inner * 8 +
                                           nn_inner_inner_inner]  # fmt: on
Esempio n. 14
0
File: cuda.py Progetto: were/tvm
    def mma_fill_impl(a: T.handle) -> None:
        # Implementation body that declares a write of the whole warp-scope
        # buffer (and no reads) and performs it with one T.mma_fill call.
        # WARP_SIZE, local_size and dtype come from an enclosing scope that is
        # not visible in this excerpt.
        C_warp = T.match_buffer(
            a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1
        )

        with T.block("root"):
            T.reads()
            T.writes(C_warp[0:WARP_SIZE, 0:local_size])
            # Bind threadIdx.x with extent WARP_SIZE.
            tx = T.env_thread("threadIdx.x")
            T.launch_thread(tx, WARP_SIZE)

            T.evaluate(T.mma_fill(local_size, C_warp.data, C_warp.elem_offset, dtype=dtype))
        def main(A_param: T.handle, C_param: T.handle):
            # For each of 100 iterations: allocate a 4-element shared-scope
            # buffer B, fill it from A[4 * i : 4 * i + 4] under a
            # "double_buffer_scope" attribute, then write B[j] + 1.0 to C[j].
            A = T.match_buffer(A_param, (400,), "float32", strides=[1])
            C = T.match_buffer(C_param, (4,), "float32", strides=[1])
            T.func_attr({"from_legacy_te_schedule": True})
            threadIdx_x = T.env_thread("threadIdx.x")
            T.launch_thread(threadIdx_x, 1)
            for i in T.serial(0, 100):
                B = T.allocate([4], "float32", scope="shared", strides=[1])
                # Only the producer loop is wrapped in the attribute scope.
                with T.attr(B.data, "double_buffer_scope", 1):
                    for j in T.serial(0, 4):
                        B[j] = A[4 * i + j]

                # Consumer loop, outside the attribute scope.
                for j in T.serial(0, 4):
                    C[j] = B[j] + 1.0
Esempio n. 16
0
def element_wise_two_thread_x_in_same_kernel_not_equal(a: T.handle,
                                                       b: T.handle,
                                                       c: T.handle) -> None:
    # Two element-wise stores inside one blockIdx.x launch, each using its own
    # threadIdx.x variable with a different extent (128 for j0, 64 for j1).
    # NOTE(review): the function name suggests the mismatched extents are the
    # point of this case -- do not "fix" them without checking the caller.
    i = T.env_thread("blockIdx.x")
    j0 = T.env_thread("threadIdx.x")
    j1 = T.env_thread("threadIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 64])
    T.launch_thread(i, 128)
    # B gets A * 2.0 under a threadIdx.x extent of 128.
    with T.launch_thread(j0, 128):
        T.store(B.data, i * 64 + j0,
                T.load("float32", A.data, i * 128 + j0) * 2.0, True)
    # C gets A + 1.0 under a threadIdx.x extent of 64.
    T.launch_thread(j1, 64)
    T.store(C.data, i * 64 + j1,
            T.load("float32", A.data, i * 128 + j1) + 1.0, True)
Esempio n. 17
0
def unified_element_wise_kernels_with_different_size(a: T.handle, b: T.handle,
                                                     c: T.handle,
                                                     d: T.handle) -> None:
    # Two independent launches with different extents, each with its own
    # blockIdx.x / threadIdx.x variables:
    #   128 x 128 launch: B = A * 2.0
    #   256 x 256 launch: D = C + 1.0
    block_x = T.env_thread("blockIdx.x")
    thread_x = T.env_thread("threadIdx.x")
    block_x_1 = T.env_thread("blockIdx.x")
    thread_x_1 = T.env_thread("threadIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [256, 256])
    D = T.match_buffer(d, [256, 256])
    # First launch, scoped by the `with` block.
    with T.launch_thread(block_x, 128):
        T.launch_thread(thread_x, 128)
        T.store(
            B.data,
            block_x * 128 + thread_x,
            T.load("float32", A.data, block_x * 128 + thread_x) * 2.0,
            True,
        )
    # Second launch, using the *_1 thread variables.
    T.launch_thread(block_x_1, 256)
    T.launch_thread(thread_x_1, 256)
    T.store(
        D.data,
        block_x_1 * 256 + thread_x_1,
        T.load("float32", C.data, block_x_1 * 256 + thread_x_1) + 1.0,
        True,
    )
Esempio n. 18
0
def ptx_global_to_shared_copy_fp32x1(
        A: T.Buffer[(32, 128), "float32"], B: T.Buffer[(32, 128),
                                                       "float32"]) -> None:
    # Copies A into a shared-scope buffer one float32 element at a time under
    # an "async_scope" attribute, commits and waits on the copy group, then
    # copies the shared buffer into B.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    # One block of 32 threads; thread tx handles row tx.
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], "float32", scope="shared")
        T.reads(A[0:32, 0:128])
        T.writes(B[0:32, 0:128])

        T.attr("default", "async_scope", 1)
        # Plain scalar buffer copy of one row.
        for i in T.serial(128):
            A_shared[tx, i] = A[tx, i]

        T.evaluate(T.ptx_commit_group(dtype=""))
        T.evaluate(T.ptx_wait_group(0, dtype=""))

        for i in range(128):
            B[tx, i] = A_shared[tx, i]
Esempio n. 19
0
def gemm_mma_m8n8k4_row_row_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle):
    # Single-block, 32-thread kernel: each thread loads 4 float16 values of A
    # (16x4) and of B (4x16) into local buffers, issues one "m8n8k4" row/row
    # T.ptx_mma with float32 accumulation, and stores its 8 accumulators to C
    # (16x16).
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 4], dtype="float16")
    B = T.match_buffer(b, [4, 16], dtype="float16")
    C = T.match_buffer(c, [16, 16], dtype="float32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([4], "float16", scope="local")
    MultiB = T.allocate([4], "float16", scope="local")
    Accum = T.allocate([8], "float32", scope="local")
    # Zero the accumulators.
    for i in range(8):
        Accum[i] = T.float32(0)

    # Load this thread's 4 values of A; the row depends only on tx.
    for mma_multi_a_col in T.vectorized(4):
        MultiA[mma_multi_a_col] = A[
            ((tx % 32) % 4) + (4 * ((((tx % 32) // 16 + (tx % 32) % 16 // 4 * 2)) % 4)),
            mma_multi_a_col,
        ]
    # Load this thread's 4 values of B.
    for mma_multi_b_col in T.vectorized(4):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) % 4,
            mma_multi_b_col + (4 * ((tx % 32) // 8)),
        ]
    T.evaluate(
        T.ptx_mma(
            "m8n8k4",
            "row",
            "row",
            "fp16",
            "fp16",
            "fp32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="float32",
        )
    )
    # Write the 8 accumulators of each thread to their positions in C.
    for mma_accum_c_id in range(8):
        C[
            ((tx % 32) % 2)
            + ((mma_accum_c_id // 2 % 2) * 2)
            + 4 * ((tx % 32) // 16)
            + ((tx % 32) % 16 // 4) % 2 * 8,
            (tx % 32) % 4 // 2 * 2
            + (tx % 32) % 16 // 8 * 4
            + mma_accum_c_id % 2
            + mma_accum_c_id // 4 * 8,
        ] = T.load("float32", Accum, mma_accum_c_id)
Esempio n. 20
0
def element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Two element-wise loops inside one blockIdx.x launch, each with its own
    # threadIdx.x variable of extent 4 (j0_0 and j1_0):
    # B = A * 2.0 and C = A + 1.0, addressed through flat T.store / T.load
    # indices.
    j1_0 = T.env_thread("threadIdx.x")
    j0_0 = T.env_thread("threadIdx.x")
    i = T.env_thread("blockIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    T.launch_thread(i, 128)
    # First loop under j0_0, scoped by the `with` block.
    with T.launch_thread(j0_0, 4):
        for j0_1 in T.serial(0, 32):
            T.store(
                B.data,
                i * 128 + j0_0 * 32 + j0_1,
                T.load("float32", A.data, i * 128 + j0_0 * 32 + j0_1) * 2.0,
                True,
            )
    # Second loop under the separate variable j1_0.
    T.launch_thread(j1_0, 4)
    for j1_1 in T.serial(0, 32):
        T.store(
            C.data,
            i * 128 + j1_0 * 32 + j1_1,
            T.load("float32", A.data, i * 128 + j1_0 * 32 + j1_1) + 1.0,
            True,
        )
Esempio n. 21
0
def mma_sp_m16n8k16_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle):
    # Single-block, 32-thread kernel: each thread loads 4 float16 values of A
    # and B plus one uint32 metadata word into local buffers, issues one
    # "m16n8k16" row/col T.ptx_mma_sp with float32 accumulation, and stores
    # its 4 accumulators to C (16x8).
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 8], dtype="float16")
    B = T.match_buffer(b, [16, 8], dtype="float16")
    C = T.match_buffer(c, [16, 8], dtype="float32")
    metadata = T.match_buffer(_metadata, [8], dtype="uint32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    multi_a = T.allocate([4], "float16", scope="local")
    multi_b = T.allocate([4], "float16", scope="local")
    accum = T.allocate([4], "float32", scope="local")
    meta_local = T.allocate([1], "uint32", scope="local")
    # Zero the accumulators. accum is a float32 buffer, so the initial value
    # is a float32 constant (it was previously written as T.float16(0)).
    for i in range(4):
        accum[i] = T.float32(0)

    # Load this thread's 4 values of A.
    for i in range(4):
        multi_a[i] = A[tx // 4 + i // 2 * 8, tx % 4 * 2 + i % 2]

    # Load this thread's 4 values of B.
    for i in range(4):
        multi_b[i] = B[tx % 4 * 2 + i % 2 + i // 2 * 8, tx // 4]

    # One metadata word per group of 4 threads.
    meta_local[0] = metadata[tx // 4]

    T.evaluate(
        T.ptx_mma_sp(
            "m16n8k16",
            "row",
            "col",
            "fp16",
            "fp16",
            "fp32",
            multi_a.data,
            0,
            multi_b.data,
            0,
            accum.data,
            0,
            meta_local.data,
            0,
            0,
            False,
            dtype="float32",
        )
    )

    # Write the 4 accumulators of each thread to their positions in C.
    for i in range(4):
        C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i]
Esempio n. 22
0
def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle):
    # Single-block, 32-thread kernel: each thread loads 16 int8 values of A
    # (16x32) and 8 int8 values of B (8x32) into local buffers, issues one
    # "m16n8k32" row/col T.ptx_mma with int32 accumulation, and stores its 4
    # accumulators to C (16x8).
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 32], dtype="int8")
    B = T.match_buffer(b, [8, 32], dtype="int8")
    C = T.match_buffer(c, [16, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([16], "int8", scope="local")
    MultiB = T.allocate([8], "int8", scope="local")
    Accum = T.allocate([4], "int32", scope="local")
    # Zero the accumulators.
    for i in range(4):
        Accum[i] = T.int32(0)

    # Load this thread's 16 values of A.
    for mma_multi_a_col in range(16):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 8 // 4 * 8,
            (tx % 32) % 4 * 4 + mma_multi_a_col % 4 + mma_multi_a_col // 8 * 16,
        ]
    # Load this thread's 8 values of B.
    for mma_multi_b_col in range(8):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 4 + mma_multi_b_col % 4 + mma_multi_b_col // 4 * 16,
        ]
    T.evaluate(
        T.ptx_mma(
            "m16n8k32",
            "row",
            "col",
            "int8",
            "int8",
            "int32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="int32",
        )
    )
    # Write the 4 accumulators of each thread to their positions in C.
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = T.load("int32", Accum, mma_accum_c_id)
Esempio n. 23
0
def gemm_mma_m16n8k16_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle):
    # Single-block, 32-thread kernel: each thread loads 8 float16 values of A
    # (16x16) and 4 float16 values of B (8x16) into local buffers, issues one
    # "m16n8k16" row/col T.ptx_mma with float32 accumulation, and stores its
    # 4 accumulators to C (16x8).
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 16], dtype="float16")
    B = T.match_buffer(b, [8, 16], dtype="float16")
    C = T.match_buffer(c, [16, 8], dtype="float32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([8], "float16", scope="local")
    MultiB = T.allocate([4], "float16", scope="local")
    Accum = T.allocate([4], "float32", scope="local")
    # Zero the accumulators.
    for i in range(4):
        Accum[i] = T.float32(0)

    # Load this thread's 8 values of A.
    for mma_multi_a_col in range(8):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 4 // 2 * 8,
            (tx % 32) % 4 * 2 + mma_multi_a_col % 2 + mma_multi_a_col // 4 * 8,
        ]
    # Load this thread's 4 values of B.
    for mma_multi_b_col in T.vectorized(4):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 2 + mma_multi_b_col % 2 + mma_multi_b_col // 2 * 8,
        ]
    T.evaluate(
        T.ptx_mma(
            "m16n8k16",
            "row",
            "col",
            "fp16",
            "fp16",
            "fp32",
            MultiA.data,
            0,
            MultiB.data,
            0,
            Accum.data,
            0,
            False,
            dtype="float32",
        )
    )
    # Write the 4 accumulators of each thread to their positions in C.
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = Accum[mma_accum_c_id]
Esempio n. 24
0
def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle):
    # Single-block, 32-thread kernel: each thread loads 128 int1 values of A
    # (16x256) and 64 int1 values of B (8x256) into local buffers, issues one
    # "m16n8k256" row/col T.ptx_mma with int32 accumulation, and stores its 4
    # accumulators to C (16x8).
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 256], dtype="int1")
    B = T.match_buffer(b, [8, 256], dtype="int1")
    C = T.match_buffer(c, [16, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([128], "int1", scope="local")
    MultiB = T.allocate([64], "int1", scope="local")
    Accum = T.allocate([4], "int32", scope="local")
    # Zero the accumulators.
    for i in range(4):
        Accum[i] = T.int32(0)

    # Load this thread's 128 values of A.
    for mma_multi_a_col in range(128):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col % 64 // 32 * 8,
            (tx % 32) % 4 * 32 + mma_multi_a_col % 32 + mma_multi_a_col // 64 * 128,
        ]
    # Load all 64 values of this thread's B fragment. The loop previously ran
    # over range(16), which filled only a quarter of MultiB and made the
    # `% 32` / `// 32 * 128` terms of the column index no-ops; 32 threads x 64
    # elements covers the full 8x256 operand.
    for mma_multi_b_col in range(64):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 32 + mma_multi_b_col % 32 + mma_multi_b_col // 32 * 128,
        ]
    T.evaluate(
        T.ptx_mma(
            "m16n8k256",
            "row",
            "col",
            "int1",
            "int1",
            "int32",
            MultiA.data,
            0,
            MultiB.data,
            0,
            Accum.data,
            0,
            False,
            dtype="int32",
        )
    )
    # Write the 4 accumulators of each thread to their positions in C.
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = Accum[mma_accum_c_id]
Esempio n. 25
0
def element_wise_kernels_with_different_size(a: T.handle, b: T.handle,
                                             c: T.handle, d: T.handle) -> None:
    # Two independent launches with different extents, each with its own
    # blockIdx.x / threadIdx.x variables:
    #   i0, j0 (128 x 128): B = A * 2.0
    #   i1, j1 (256 x 256): D = C + 1.0
    i0 = T.env_thread("blockIdx.x")
    j0 = T.env_thread("threadIdx.x")
    i1 = T.env_thread("blockIdx.x")
    j1 = T.env_thread("threadIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [256, 256])
    D = T.match_buffer(d, [256, 256])
    # First launch, scoped by the `with` block.
    with T.launch_thread(i0, 128):
        T.launch_thread(j0, 128)
        T.store(B.data, i0 * 128 + j0,
                T.load("float32", A.data, i0 * 128 + j0) * 2.0, True)
    # Second launch.
    T.launch_thread(i1, 256)
    T.launch_thread(j1, 256)
    T.store(D.data, i1 * 256 + j1,
            T.load("float32", C.data, i1 * 256 + j1) + 1.0, True)
Esempio n. 26
0
def gpu_func(a: T.handle, c: T.handle) -> None:
    # Computes C = (A + 1.0) * 2.0 on 16x16 float32 buffers, one row per
    # (blockIdx.x, threadIdx.x, vthread) combination, staging each row in a
    # 1x16 local-scope buffer B. Buffers are indexed in 2-D form.
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")

    i0 = T.env_thread("blockIdx.x")
    i1 = T.env_thread("threadIdx.x")
    i2 = T.env_thread("vthread")

    # 4 x 2 x 2 = 16 rows.
    T.launch_thread(i0, 4)
    T.launch_thread(i1, 2)
    T.launch_thread(i2, 2)
    B = T.allocate([1, 16], "float32", "local")
    for j in range(0, 16):
        B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
    for j in range(0, 16):
        C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
Esempio n. 27
0
def flattened_gpu_func(a: T.handle, c: T.handle) -> None:
    # Computes C = (A + 1.0) * 2.0 on 16x16 float32 buffers using flat 1-D
    # indices: the row offset is i0 * 64 + i1 * 32 + i2 * 16, and each row is
    # staged in a 16-element local-scope buffer B.
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")

    i0 = T.env_thread("blockIdx.x")
    i1 = T.env_thread("threadIdx.x")
    i2 = T.env_thread("vthread")

    # 4 x 2 x 2 = 16 rows.
    T.launch_thread(i0, 4)
    T.launch_thread(i1, 2)
    T.launch_thread(i2, 2)
    B = T.allocate([16], "float32", "local")
    for j in range(0, 16):
        B[j] = T.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + 1.0
    for j in range(0, 16):
        C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = T.load("float32", B, j) * 2.0
Esempio n. 28
0
def unified_element_wise_vthread_x(a: T.handle, b: T.handle) -> None:
    # Element-wise B = A * 2.0 through flat T.store / T.load indices, using a
    # single vthread.x variable that is launched twice (extent 2 each time)
    # around a threadIdx.x launch of extent 64.
    vthread_x = T.env_thread("vthread.x")
    thread_x = T.env_thread("threadIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    T.launch_thread(vthread_x, 2)
    T.launch_thread(thread_x, 64)
    T.launch_thread(vthread_x, 2)
    for j_1 in T.serial(0, 64):
        T.store(
            B.data,
            # 8256 = 8192 + 64, i.e. vthread_x contributes both a row-block
            # term and a column-block term to the flat index.
            vthread_x * 8256 + thread_x * 128 + j_1,
            T.load("float32", A.data, vthread_x * 8256 + thread_x * 128 + j_1)
            * 2.0,
            True,
        )
Esempio n. 29
0
def element_wise_vthread_x(a: T.handle, b: T.handle) -> None:
    # Element-wise B = A * 2.0 through flat T.store / T.load indices, using
    # two separate vthread.x variables (i_0 and j_0, extent 2 each) around a
    # threadIdx.x launch of extent 64.
    i_0 = T.env_thread("vthread.x")
    i_1 = T.env_thread("threadIdx.x")
    j_0 = T.env_thread("vthread.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    T.launch_thread(i_0, 2)
    T.launch_thread(i_1, 64)
    T.launch_thread(j_0, 2)
    for j_1 in T.serial(0, 64):
        T.store(
            B.data,
            # Flat index: row = i_0 * 64 + i_1, column = j_0 * 64 + j_1.
            i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1,
            T.load("float32", A.data, i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1)
            * 2.0,
            True,
        )
def element_wise_env_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Two element-wise loops written as TIR blocks with 2-D buffer indexing:
    # B = A * 2.0 followed by C = B + 1.0.
    # Both threadIdx.x variables (j0_0 and j1_0, extent 4 each) are launched
    # up front, before either loop.
    j1_0 = T.env_thread("threadIdx.x")
    j0_0 = T.env_thread("threadIdx.x")
    i = T.env_thread("blockIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    T.launch_thread(i, 128)
    T.launch_thread(j0_0, 4)
    T.launch_thread(j1_0, 4)

    # Producer loop, indexed by j0_0.
    for j0_1 in T.serial(0, 32):
        with T.block(""):
            B[i, j0_0 * 32 + j0_1] = A[i, j0_0 * 32 + j0_1] * 2.0
    # Consumer loop, indexed by j1_0; reads the B values written above.
    for j1_1 in T.serial(0, 32):
        with T.block(""):
            C[i, j1_0 * 32 + j1_1] = B[i, j1_0 * 32 + j1_1] + 1.0