# TensorIR description of a warp-level 16x16x16 float16 MMA: C += A * B, with
# A/B/C declared as (32 lanes, 8 elements) warp-scope fragments. The index
# arithmetic maps the logical (i, j, k) of the 16x16x16 GEMM onto the
# per-thread (lane, element) fragment layout.
def mma_sync_m16n16k16_desc_manual(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
        B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
        C = T.match_buffer(c, (32, 8), "float16", align=128, offset_factor=16, scope="warp")

        with T.block("root"):
            T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
            T.writes(C[0:32, 0:8])
            for i, j, k in T.grid(16, 16, 16):
                with T.block("C"):
                    i, j, k = T.axis.remap("SSR", [i, j, k])
                    # Same (lane, element) mapping expression is used for the
                    # read region, the write region, and the store below.
                    T.reads(
                        C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
                        A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
                        B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
                    )
                    T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
                    C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = (
                        C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2]
                        + A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2]
                        * B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2]
                    )
# Rfactor-style two-stage reduction, after reverse_compute_at: the "B_rf"
# block accumulates partial sums of A[j, i, k] over k into B_rf_local[i, j],
# then the "B" block reduces B_rf_local over i into B[j]. j is bound to
# blockIdx.x and the outer part of i to threadIdx.x.
def factorized_after_reverse_compute_at(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [16, 16, 16], "float32")
    B = T.match_buffer(b, [16], "float32")
    B_rf_local = T.alloc_buffer([16, 16], "float32", scope="local")
    for j in T.thread_binding(0, 16, thread="blockIdx.x"):
        for i_o in T.thread_binding(0, 4, thread="threadIdx.x"):
            for i_i, k in T.grid(4, 16):
                with T.block([16, 16, T.reduce_axis(0, 16)],
                             "B_rf") as [vi, vj, vk]:
                    T.bind(vi, i_o * 4 + i_i)
                    T.bind(vj, j)
                    T.bind(vk, k)
                    with T.init():
                        B_rf_local[vi, vj] = 0.0
                    B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk]
            # Second stage: reduce the 16 partial sums (split 4x4 across
            # i_o/k) into the final B[j].
            for k in T.serial(0, 4):
                with T.block([16, T.reduce_axis(0, 16)], "B") as [vi, vk]:
                    T.bind(vi, j)
                    T.bind(vk, i_o * 4 + k)
                    with T.init():
                        B[vi] = 0.0
                    B[vi] = B[vi] + B_rf_local[vk, vi]
# GEMM with symbolic (dynamic) shapes: C[N, M] = A[N, K] @ B[K, M], where
# N/M/K are int32 shape variables bound through match_buffer.
def gemm_dyn_shape(a: T.handle, b: T.handle, c: T.handle):
    N = T.var("int32")
    M = T.var("int32")
    K = T.var("int32")
    A = T.match_buffer(a, (N, K), "float32")
    B = T.match_buffer(b, (K, M), "float32")
    C = T.match_buffer(c, (N, M), "float32")
    for i, j, k in T.grid(N, M, K):
        with T.block("gemm"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
# Blockized elementwise pipeline after reverse_compute_at: B = A * 2 is
# computed as 8x8 outer tiles ("B_outer") of 16x16 inner blocks ("B_inner"),
# and C = B + 1 for the same 16x16 tile is computed immediately after each
# B tile, under the shared (i_o, j_o) loops.
def blockized_2_after_reverse_compute_at(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128], "float32")
    B = T.alloc_buffer([128, 128], "float32")
    C = T.match_buffer(c, [128, 128], "float32")
    for i_o, j_o in T.grid(8, 8):
        with T.block([8, 8], "B_outer") as [vio, vjo]:
            T.bind(vio, i_o)
            T.bind(vjo, j_o)
            T.reads([A[vio * 16:vio * 16 + 16, vjo * 16:vjo * 16 + 16, ]])
            T.writes([B[vio * 16:vio * 16 + 16, vjo * 16:vjo * 16 + 16]])
            for i_i, j_i in T.grid(16, 16):
                with T.block([128, 128], "B_inner") as [vi, vj]:
                    T.bind(vi, vio * 16 + i_i)
                    T.bind(vj, vjo * 16 + j_i)
                    B[vi, vj] = A[vi, vj] * 2.0
        # Consumer tile placed right after its producer tile.
        for ax0, ax1 in T.grid(16, 16):
            with T.block([128, 128], "C") as [vi, vj]:
                T.bind(vi, i_o * 16 + ax0)
                T.bind(vj, j_o * 16 + ax1)
                T.reads([B[vi, vj]])
                T.writes([C[vi, vj]])
                C[vi, vj] = B[vi, vj] + 1.0
# Example 5
# Fixture for a function the scheduler must reject: inside an anonymous
# per-row block, each store B[i, j] = A[i, j] + 1 is interleaved with an
# opaque extern call on B.data, followed by C[i, :] = B[i, :] * 2.
def unschedulable_func(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i in range(0, 16):
        with T.block():
            T.reads(A[i, 0:16])
            T.writes(C[i, 0:16])
            B = T.alloc_buffer((16, 16), "float32")
            for j in range(0, 16):
                # Opaque access to B.data — this is what makes the
                # function unschedulable.
                T.evaluate(T.call_extern("dummy_extern_function", B.data, dtype="int32"))
                B[i, j] = A[i, j] + 1.0
            for j in range(0, 16):
                C[i, j] = B[i, j] * 2.0
# Example 6
# Fixture whose only loop is a reduction axis (no spatial loop): the scalar
# C[()] is initialized to 1.0 and then min-reduced with A[k] / B[k] over a
# length-2 reduction domain.
def reduction_loop_only(
    A: T.Buffer[2, "float32"],
    B: T.Buffer[2, "float32"],
    C: T.Buffer[(), "float32"],
) -> None:
    for i0 in T.serial(2):
        with T.block("C"):
            k0 = T.axis.reduce(2, i0)
            T.reads(A[k0], B[k0])
            T.writes(C[()])
            with T.init():
                C[()] = T.float32(1.0)
            C[()] = T.min(C[()], A[k0] / B[k0])
# Example 7
# Two reduction blocks under the same (k, i, j) loop nest, so each loop has
# multiple children: block "C" computes C = A @ B and block "D" computes
# D = B @ A. Note the reduction loop k is outermost.
def matmul_loop_multiple_children(a: T.handle, b: T.handle, c: T.handle,
                                  d: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    D = T.match_buffer(d, [128, 128])

    for k, i, j in T.grid(128, 128, 128):
        with T.block([T.reduce_axis(0, 128), 128, 128], "C") as [ck, ci, cj]:
            T.bind(ck, k)
            T.bind(ci, i)
            T.bind(cj, j)
            with T.init():
                C[ci, cj] = 0.0
            C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj]
        with T.block([T.reduce_axis(0, 128), 128, 128], "D") as [dk, di, dj]:
            T.bind(dk, k)
            T.bind(di, i)
            T.bind(dj, j)
            with T.init():
                D[di, dj] = 0.0
            D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj]
# Example 8
# Square-sum followed by square-root, after loop fusion: for each batch b,
# C[b] = sum_{i,j} A[b, i, j]^2 with the 256x256 (i, j) loops fused into a
# single 65536-iteration loop (inner fused loop has extent 1), then
# D[b] = sqrt(C[b]).
def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])

    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
        with T.block([16, T.reduce_axis(0, 256),
                      T.reduce_axis(0, 256)], "C") as [b, i, j]:
            T.bind(b, i0)
            # Recover (i, j) from the fused loop variable.
            T.bind(i, T.floordiv(i1_i2_fused_outer, 256))
            T.bind(j, T.floormod(i1_i2_fused_outer, 256))
            T.reads([C[b], A[b, i, j]])
            T.writes([C[b]])
            with T.init():
                C[b] = 0.0
            C[b] = C[b] + (A[b, i, j] * A[b, i, j])
    for i0_1 in T.serial(0, 16):
        with T.block([16], "D") as [b_1]:
            T.bind(b_1, i0_1)
            T.reads([C[b_1]])
            T.writes([D[b_1]])
            D[b_1] = T.sqrt(C[b_1], dtype="float32")
# Example 9
# Fixture with opaque (whole-buffer) access regions: block "A" stores ones
# element-wise but declares a full-buffer write region; block "B" writes via
# the tvm_fill_fragment intrinsic, which touches B.data opaquely.
def opaque_access(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [16, 16], "float32")
    B = T.match_buffer(b, [16, 16], "float32")
    for i, j in T.grid(16, 16):
        with T.block("A"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([])
            T.writes([A[0:16, 0:16]])
            A[vi, vj] = 1
    for i, j in T.grid(16, 16):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads([])
            T.writes([B[0:16, 0:16]])
            T.evaluate(
                T.tvm_fill_fragment(B.data,
                                    16,
                                    16,
                                    16,
                                    0,
                                    vi * 16 + vj,
                                    dtype="handle"))
# Example 10
# Elementwise B = A * 2 after splitting each 128-extent loop into non-exact
# factors (1000x2x3, 1x129, 3x43); the T.where predicate masks out the
# iterations whose reconstructed indices fall outside [0, 128).
def elementwise_split_with_predicate(a: T.handle, b: T.handle) -> None:
    B = T.match_buffer(b, [128, 128, 128])
    A = T.match_buffer(a, [128, 128, 128])
    for i0, i1, i2, j0, j1, k0, k1 in T.grid(1000, 2, 3, 1, 129, 3, 43):
        with T.block("B"):
            T.where((i0 * 2 + i1) * 3 + i2 < 128 and j1 < 128
                    and k0 * 43 + k1 < 128)
            vi = T.axis.S(128, i0 * 6 + i1 * 3 + i2)
            vj = T.axis.S(128, j1)
            vk = T.axis.S(128, k0 * 43 + k1)
            T.reads([A[vi, vj, vk]])
            T.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
# Example 11
 # Two-stage elementwise pipeline over 1024x1024 buffers: B = A * 2, then
 # D = (B + 3) * 5, each under its own 16x64x64x16 tiled loop nest.
 def main(a: T.handle, d: T.handle) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main"})
     A = T.match_buffer(a, [1024, 1024], dtype="float32")
     D = T.match_buffer(d, [1024, 1024], dtype="float32")
     # body
     # with tir.block("root")
     B = T.alloc_buffer([1024, 1024], dtype="float32")
     for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16):
         with T.block("A"):
             vi = T.axis.S(1024, i0_0 * 64 + i0_1)
             vj = T.axis.S(1024, i1_0 * 16 + i1_1)
             T.reads([A[vi, vj]])
             T.writes([B[vi, vj]])
             B[vi, vj] = A[vi, vj] * T.float32(2)
     for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16):
         with T.block("C"):
             vi = T.axis.S(1024, i0_0 * 64 + i0_1)
             vj = T.axis.S(1024, i1_0 * 16 + i1_1)
             T.reads([B[vi, vj]])
             T.writes([D[vi, vj]])
             D[vi, vj] = (B[vi, vj] + T.float32(3)) * T.float32(5)
# Fixture carrying an invalid block annotation: block "B" (B = A * 2) has a
# "buffer_dim_align" block_attr — the annotation under test — followed by
# block "C" (C = B + 1) sharing the outer i0 loop.
def element_wise_invalid_annotation(a: T.handle, c: T.handle) -> None:
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block("root"):
        T.reads([])
        T.writes([])
        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
        for i0 in T.serial(0, 128):
            for ax1 in T.serial(0, 128):
                with T.block("B"):
                    # The annotation this fixture exists to exercise.
                    T.block_attr({"buffer_dim_align": [0]})
                    vi, vj = T.axis.remap("SS", [i0, ax1])
                    T.reads([A[vi, vj]])
                    T.writes([B[vi, vj]])
                    B[vi, vj] = (A[vi, vj]*T.float32(2))
            for i1 in T.serial(0, 128):
                with T.block("C"):
                    vi_1, vj_1 = T.axis.remap("SS", [i0, i1])
                    T.reads([B[vi_1, vj_1]])
                    T.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1))
# Intermediate CUDA matmul schedule: A/B are staged into shared buffers and
# B additionally into a local buffer by top-level blocks; inside the bound
# thread hierarchy (blockIdx.y/x, vthread.y/x, threadIdx.y/x) the k loop is
# split 256x8, A_shared is loaded into A_shared_local per k step, block "C"
# accumulates C_local += A_shared_local[vk, vi] * B_shared_local[vk, vj],
# and C_local is finally written back to C.
def cuda_matmul_2(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
    A = T.match_buffer(a, [2048, 2048], "float32")
    B = T.match_buffer(b, [2048, 2048], "float32")
    C = T.match_buffer(c, [2048, 2048], "float32")
    A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    with T.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with T.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with T.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for k_0 in T.serial(0, 256):
                                for k_1 in T.unroll(0, 8):
                                    for i, j in T.grid(1, 4):
                                        with T.block(
                                            [2048, 2048],
                                                "A_shared_local") as [v0, v1]:
                                            T.bind(v0, k_0 * 8 + k_1 + i)
                                            T.bind(
                                                v1,
                                                by * 64 + vy * 32 + ty * 4 + j)
                                            A_shared_local[v0,
                                                           v1] = A_shared[v0,
                                                                          v1]
                                    for _, i, j in T.grid(1, 4, 4):
                                        with T.block([
                                                2048, 2048,
                                                T.reduce_axis(0, 2048)
                                        ], "C") as [vi, vj, vk]:
                                            T.bind(
                                                vi,
                                                by * 64 + vy * 32 + ty * 4 + i)
                                            T.bind(
                                                vj,
                                                bx * 64 + vx * 32 + tx * 4 + j)
                                            T.bind(vk, k_0 * 8 + k_1)
                                            with T.init():
                                                C_local[vi, vj] = T.float32(0)
                                            C_local[vi, vj] = C_local[
                                                vi, vj] + A_shared_local[
                                                    vk, vi] * B_shared_local[
                                                        vk, vj]
                            # Write-back of the 4x4 per-thread tile.
                            for i, j in T.grid(4, 4):
                                with T.block([2048, 2048],
                                             "C_local") as [v0, v1]:
                                    T.bind(v0, by * 64 + vy * 32 + ty * 4 + i)
                                    T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[v0, v1] = C_local[v0, v1]
# Example 14 — from file cuda.py, project were/tvm
    # Parametric warp-level MMA description: C += cast(A) * cast(B) over an
    # (M_DIM, N_DIM, k_dim) grid, with fragment (thread, element) coordinates
    # produced by index_map_A/B/C and B's indices optionally swapped by
    # maybe_swap. WARP_SIZE, local_size(_out), dtypes, and the index-map /
    # cast helpers are captured from the enclosing scope (not visible here).
    def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(
            a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
        )
        B = T.match_buffer(
            b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
        )
        C = T.match_buffer(
            c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
        )

        with T.block("root"):
            T.reads(
                C[0:WARP_SIZE, 0:local_size_out],
                A[0:WARP_SIZE, 0:local_size],
                B[0:WARP_SIZE, 0:local_size],
            )
            T.writes(C[0:WARP_SIZE, 0:local_size_out])

            for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
                with T.block("C"):
                    i, j, k = T.axis.remap("SSR", [i, j, k])
                    b_row_ind, b_col_ind = maybe_swap(k, j)

                    thread_id_C, local_id_C = index_map_C(i, j)
                    thread_id_A, local_id_A = index_map_A(i, k)
                    thread_id_B, local_id_B = index_map_B(b_row_ind, b_col_ind)

                    T.reads(
                        C[thread_id_C, local_id_C],
                        A[thread_id_A, local_id_A],
                        B[thread_id_B, local_id_B],
                    )
                    T.writes(C[thread_id_C, local_id_C])

                    C[thread_id_C, local_id_C] += maybe_cast(
                        A[thread_id_A, local_id_A]
                    ) * maybe_cast(B[thread_id_B, local_id_B])
 # Tiled pad + max pooling: per (h_o, w_o) tile, a 9x9 input patch (guarded
 # by T.where for the implicit padding at the top/left border) is cached
 # into X_cache, then a 3x3 stride-2 max-pool over that cache writes a 4x4
 # output tile, using if_then_else for out-of-range (padded) taps.
 # NOTE(review): Y is indexed as [h, w, c] while declared (64, 56, 56),
 # i.e. channel-first — confirm the intended layout against callers.
 def spatial_tiled_pad_and_pooling(
     X: T.Buffer[(64, 112, 112), "int32"], Y: T.Buffer[(64, 56, 56), "int32"]
 ) -> None:
     for h_o, w_o in T.grid(14, 14):
         with T.block():
             X_cache = T.alloc_buffer([112, 112, 64], dtype="int32")
             for ax0, ax1, ax2 in T.grid(64, 9, 9):
                 with T.block("cache"):
                     T.where(1 <= h_o * 8 + ax1 and 1 <= w_o * 8 + ax2)
                     T.reads(X[ax0, h_o * 8 - 1 + ax1, w_o * 8 - 1 + ax2])
                     T.writes(X_cache[h_o * 8 - 1 + ax1, w_o * 8 - 1 + ax2, ax0])
                     X_cache[h_o * 8 - 1 + ax1, w_o * 8 - 1 + ax2, ax0] = X[
                         ax0, h_o * 8 - 1 + ax1, w_o * 8 - 1 + ax2
                     ]
             for h_i, w_i, kh, kw, c in T.grid(4, 4, 3, 3, 64):
                 with T.block("compute"):
                     T.reads(
                         X_cache[(h_o * 4 + h_i) * 2 + kh - 1, (w_o * 4 + w_i) * 2 + kw - 1, c]
                     )
                     T.writes(Y[h_o * 4 + h_i, w_o * 4 + w_i, c])
                     # First tap initializes the running maximum to 0.
                     if kh == 0 and kw == 0:
                         Y[h_o * 4 + h_i, w_o * 4 + w_i, c] = 0
                     Y[h_o * 4 + h_i, w_o * 4 + w_i, c] = T.max(
                         Y[h_o * 4 + h_i, w_o * 4 + w_i, c],
                         T.if_then_else(
                             T.likely(1 <= (h_o * 4 + h_i) * 2 + kh, dtype="bool")
                             and T.likely((h_o * 4 + h_i) * 2 + kh < 113, dtype="bool")
                             and T.likely(1 <= (w_o * 4 + w_i) * 2 + kw, dtype="bool")
                             and T.likely((w_o * 4 + w_i) * 2 + kw < 113, dtype="bool"),
                             X_cache[
                                 (h_o * 4 + h_i) * 2 + kh - 1,
                                 (w_o * 4 + w_i) * 2 + kw - 1,
                                 c,
                             ],
                             0,
                             dtype="int32",
                         ),
                     )
# Example 16
# Four chained 16-element reductions with split (4x4) reduction loops:
# C sums A over k; D, E, F each sum A over k plus the previous stage's
# result. Blocks C/D share the inner loop j1, E/F share j2, and all share
# the outer loop i — used to exercise scheduling with multiple reduction
# blocks under common loops.
def multiple_reduction_blocks(a: T.handle, f: T.handle) -> None:
    A = T.match_buffer(a, (16, 16, 16))
    C = T.alloc_buffer((16, 16))
    D = T.alloc_buffer((16, 16))
    E = T.alloc_buffer((16, 16))
    F = T.match_buffer(f, (16, 16))

    for i in T.serial(0, 16):
        for j1 in T.serial(0, 16):
            for k1o, k1i in T.grid(4, 4):
                with T.block("C"):
                    ci, cj = T.axis.remap("SS", [i, j1])
                    ck = T.axis.R(16, k1o * 4 + k1i)
                    with T.init():
                        C[ci, cj] = 0.0
                    C[ci, cj] = C[ci, cj] + A[ci, cj, ck]
            for k2o, k2i in T.grid(4, 4):
                with T.block("D"):
                    di, dj = T.axis.remap("SS", [i, j1])
                    dk = T.axis.R(16, k2o * 4 + k2i)
                    with T.init():
                        D[di, dj] = 0.0
                    D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj]
        for j2 in T.serial(0, 16):
            for k3o, k3i in T.grid(4, 4):
                with T.block("E"):
                    ei, ej = T.axis.remap("SS", [i, j2])
                    ek = T.axis.R(16, k3o * 4 + k3i)
                    with T.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej]
            for k4o, k4i in T.grid(4, 4):
                with T.block("F"):
                    fi, fj = T.axis.remap("SS", [i, j2])
                    fk = T.axis.R(16, k4o * 4 + k4i)
                    with T.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj]
# Example 17
# get_valid_counts (NMS preprocessing) in TIR: scans the 2500 boxes of the
# single batch; boxes whose score exceeds score_threshold (and, when
# id_index >= 0, whose class id is non-negative) are compacted to the front
# of out_buf / out_indices_buf while valid_count_buf is incremented; slots
# at or past the running count are filled with -1.
# NOTE(review): the declared T.reads/T.writes use index 6 on a size-6 axis
# (e.g. data_buf[vi, vj, 6]) — looks like a full-row region shorthand or an
# off-by-one in the fixture; confirm against the TVM test it came from.
def get_valid_counts(
    data: T.handle,
    valid_count: T.handle,
    out: T.handle,
    out_indices: T.handle,
    score_threshold: T.float32,
    id_index: T.int32,
    score_index: T.int32,
) -> None:

    data_buf = T.match_buffer(data, (1, 2500, 6), "float32")
    valid_count_buf = T.match_buffer(valid_count, (1, ), "int32")
    out_buf = T.match_buffer(out, (1, 2500, 6), "float32")
    out_indices_buf = T.match_buffer(out_indices, (1, 2500), "int32")

    with T.block("init"):
        vi = T.axis.S(1, 0)
        valid_count_buf[vi] = T.int32(0)
        for j in range(2500):
            with T.block("update"):
                vj = T.axis.S(2500, j)
                T.reads([data_buf[vi, vj, 6]])
                T.writes([
                    valid_count_buf[vi], out_indices_buf[vi, vj],
                    out_buf[vi, vj, 6]
                ])
                if (data_buf[vi, vj, score_index] > score_threshold) and (
                    (id_index < 0) or
                    (data_buf[vi, vj, id_index] >= T.float32(0))):
                    # Compact this valid box to position valid_count.
                    for k in T.serial(0, 6):
                        out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj,
                                                                       k]
                    out_indices_buf[vi, valid_count_buf[vi]] = vj
                    valid_count_buf[vi] = valid_count_buf[vi] + 1
                # Pad the tail beyond the running count with -1.
                if vj >= valid_count_buf[vi]:
                    for k in T.serial(0, 6):
                        out_buf[vi, vj, k] = T.float32(-1)
                    out_indices_buf[vi, vj] = T.int32(-1)
# Example 18
# Cache-write fixture with multiple consumers: A_global is the staging
# buffer (filled with 1.0), copied into A tile by tile; A is then consumed
# both by block "B" (B = A + 1, same tiling) and by block "C" (plain copy
# over a flat 128 loop).
def cache_write_multi_consumer() -> None:
    A = T.alloc_buffer((128))
    B = T.alloc_buffer((128))
    C = T.alloc_buffer((128))
    A_global = T.alloc_buffer((128))
    for i in T.grid(8):
        for j in T.grid(16):
            with T.block("A_global"):
                vi = T.axis.S(128, i * 16 + j)
                A_global[vi] = 1.0
        for j in T.grid(16):
            with T.block("A"):
                vi = T.axis.S(128, i * 16 + j)
                A[vi] = A_global[vi]
        for j in T.grid(16):
            with T.block("B"):
                vi = T.axis.S(128, i * 16 + j)
                B[vi] = A[vi] + 1.0

    for i in T.grid(128):
        with T.block("C"):
            vi = T.axis.S(128, i)
            C[vi] = A[vi]
# Example 19
# Cache-read fixture with int64 buffer shapes: A is staged into A_global
# (loops use int64 extents), then B = A_global * 2 and C = B + 1 over
# ordinary int32 128x128 loops.
def cache_read_shape_int64(var_A: T.handle, var_C: T.handle) -> None:
    A = T.match_buffer(var_A, (T.int64(128), T.int64(128)), dtype="float32")
    C = T.match_buffer(var_C, (T.int64(128), T.int64(128)), dtype="float32")
    B = T.alloc_buffer([T.int64(128), T.int64(128)], dtype="float32")
    A_global = T.alloc_buffer([T.int64(128), T.int64(128)], dtype="float32")
    for ax0, ax1 in T.grid(T.int64(128), T.int64(128)):
        with T.block("A_global"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v0, v1])
            T.writes(A_global[v0, v1])
            A_global[v0, v1] = A[v0, v1]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_global[vi, vj])
            T.writes(B[vi, vj])
            B[vi, vj] = A_global[vi, vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi, vj])
            T.writes(C[vi, vj])
            C[vi, vj] = B[vi, vj] + T.float32(1)
# Example 20
# GEMM with the reduction decomposed into separate init / update / write-back
# blocks per 4x4 output tile, accumulating into a local buffer. Note B is
# read as B[vj, vk], i.e. this computes C = A @ B^T.
def decomposed_gemm(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
):
    local = T.alloc_buffer((16, 16), "float32")
    for i, j in T.grid(4, 4):
        for ii, jj in T.grid(4, 4):
            with T.block("init"):
                vi = T.axis.S(16, i * 4 + ii)
                vj = T.axis.S(16, j * 4 + jj)
                local[vi, vj] = 0
        for k, ii, jj in T.grid(16, 4, 4):
            with T.block("update"):
                vi = T.axis.S(16, i * 4 + ii)
                vj = T.axis.S(16, j * 4 + jj)
                vk = T.axis.R(16, k)
                local[vi, vj] += A[vi, vk] * B[vj, vk]
        for ii, jj in T.grid(4, 4):
            with T.block("C"):
                vi = T.axis.S(16, i * 4 + ii)
                vj = T.axis.S(16, j * 4 + jj)
                C[vi, vj] = local[vi, vj]
# Example 21
# Matmul after splitting the reduction loop into 4x8x4 factors; B is read as
# B[vj, vk], so this computes C = A @ B^T with init inside T.init().
def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])

    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4):
        with T.block("update"):
            vi, vj = T.axis.remap("SS", [i0, i1])
            vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner)
            T.reads([A[vi, vk], B[vj, vk]])
            T.writes([C[vi, vj]])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
# Example 22
# Sub-buffer (region match_buffer) fixture: inside each anonymous block,
# sub_A / sub_C are matched against slices of A and C, and the computation
# sub_A[ii, 0, kk] += sub_C[ii, kk] runs over the 4x2 sub-region.
def buffer_load_store(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16, 16))
    C = T.match_buffer(c, (16, 16))
    for i, j, k in T.grid(4, 16, 8):
        with T.block():
            T.reads(C[i * 4:i * 4 + 4, k * 2:k * 2 + 2])
            T.writes(A[i * 4:i * 4 + 4, j, k * 2:k * 2 + 2])
            sub_A = T.match_buffer(A[i * 4:i * 4 + 4, j, k * 2:k * 2 + 2],
                                   (4, 1, 2),
                                   offset_factor=1)
            sub_C = T.match_buffer(C[i * 4:i * 4 + 4, k * 2:k * 2 + 2], (4, 2),
                                   offset_factor=1)
            for ii, kk in T.grid(4, 2):
                sub_A[ii, 0, kk] += sub_C[ii, kk]
# Compacted buffer fixture with an explicit stride: per 4-row strip of A,
# a 4x16 intermediate B (row stride 17, i.e. one element of padding per
# row) holds A + 1, which is then written to C as B * 2.
def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i0 in range(0, 4):
        with T.block():
            T.reads(A[i0 * 4:i0 * 4 + 4, 0:16])
            T.writes(C[i0 * 4:i0 * 4 + 4, 0:16])
            B = T.alloc_buffer([4, 16],
                               "float32",
                               strides=[17, 1],
                               scope="global")
            for i1 in range(0, 4):
                for j in range(0, 16):
                    with T.block() as []:
                        T.reads(A[i0 * 4 + i1, j])
                        T.writes(B[i1, j])
                        B[i1, j] = A[i0 * 4 + i1, j] + 1.0
            for i1 in range(0, 4):
                for j in range(0, 16):
                    with T.block() as []:
                        T.reads(B[i1, j])
                        T.writes(C[i0 * 4 + i1, j])
                        C[i0 * 4 + i1, j] = B[i1, j] * 2.0
# 16x16x16 matmul C = A @ B with the i and k loops each split 4x4; the
# "layout_free_buffers" attr marks buffer index 1 (B) as layout-rewritable.
def tir_matmul(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
) -> None:
    T.func_attr({"layout_free_buffers": [1]})
    for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
        with T.block("matmul"):
            vi = T.axis.S(16, i0 * 4 + i1)
            vj = T.axis.S(16, j)
            vk = T.axis.R(16, k0 * 4 + k1)
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
# Example 25
 # Tiled 1024^3 copy B = A under meta-schedule annotations: the root block
 # carries parallel/vectorize hints, and the non-exact split loops are
 # masked by a T.where predicate over the reconstructed indices.
 def main(a: T.handle, b: T.handle) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main"})
     A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32")
     B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32")
     # body
     with T.block("root"):
         T.block_attr({
             "meta_schedule.parallel": 128,
             "meta_schedule.vectorize": 32
         })
         for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(
                 128, 64, 4, 4, 64, 4, 8, 32):
             with T.block("move"):
                 vi = T.axis.spatial(1024, i0 * 16 + i1 * 4 + i2)
                 vj = T.axis.spatial(1024, j0 * 32 + j1 * 8 + j2)
                 vk = T.axis.spatial(1024, k0 * 32 + k1)
                 T.where((i0 * 4 + i1) * 4 + i2 < 1024
                         and (j0 * 4 + j1) * 8 + j2 < 1024
                         and k0 * 32 + k1 < 1024)
                 T.reads([A[vi, vj, vk]])
                 T.writes([B[vi, vj, vk]])
                 B[vi, vj, vk] = A[vi, vj, vk]
# Example 26
# Reduction block with two buffer stores per iteration (the property under
# test): the scalar staging buffer C[()] receives A[vi, vk], then B[vi]
# accumulates C[()]; the reduction axis is bound to threadIdx.x.
def multiple_bufferstore(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    C = T.alloc_buffer([], dtype="float32")
    for i in T.serial(0, 128):
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B"):
                vi, vk = T.axis.remap("SR", [i, k])
                T.reads([A[vi, vk], B[vi], C[()]])
                T.writes([B[vi], C[()]])
                with T.init():
                    B[vi] = T.float32(0)
                C[()] = A[vi, vk]
                B[vi] = B[vi] + C[()]
# Example 27
# Row-sum over 120 columns with the reduction axis split 4x32, the inner
# part bound to threadIdx.x, and a T.where predicate masking the tail
# iterations where ko * 32 + ki >= 120.
def with_block_predicate(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 120], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    for i, ko in T.grid(128, 4):
        for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("B"):
                vi = T.axis.spatial(128, i)
                vk = T.axis.reduce(120, ko * 32 + ki)
                T.where(ko * 32 + ki < 120)
                T.reads([B[vi], A[vi, vk]])
                T.writes([B[vi]])
                with T.init():
                    B[vi] = T.float32(0)
                B[vi] = B[vi] + A[vi, vk]
# 16x16x16 GEMM over internally allocated buffers, output tiled 4x4; B is
# read as B[vj, vk], i.e. C = A @ B^T.
def gemm() -> None:
    A = T.alloc_buffer([16, 16], "float32")
    B = T.alloc_buffer([16, 16], "float32")
    C = T.alloc_buffer([16, 16], "float32")
    for i, j, k, ii, jj in T.grid(4, 4, 16, 4, 4):
        with T.block("update"):
            vi = T.axis.S(16, i * 4 + ii)
            vj = T.axis.S(16, j * 4 + jj)
            vk = T.axis.R(16, k)
            T.reads(A[vi, vk], B[vj, vk])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = 0
            C[vi, vj] += A[vi, vk] * B[vj, vk]
# Example 29
# Expected IR for indirect indexing: every element of each row vi of
# out_buf is filled with data_buf[vi, index_buf[0]]; the read region
# declares both the gathered element and the index buffer itself.
def expected_bufferslice_indices(data: T.handle, index: T.handle) -> None:
    index_buf = T.match_buffer(index, [1],
                               dtype="int32",
                               elem_offset=0,
                               align=128,
                               offset_factor=1)
    data_buf = T.match_buffer(data, [16, 16],
                              elem_offset=0,
                              align=128,
                              offset_factor=1)
    with T.block("root"):
        T.reads([])
        T.writes([])
        out_buf = T.alloc_buffer([16, 16],
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
        for i0, i1 in T.grid(16, 16):
            with T.block():
                vi, vj = T.axis.remap("SS", [i0, i1])
                T.reads([data_buf[vi, index_buf[0]], index_buf[0]])
                T.writes([out_buf[vi, vj]])
                out_buf[vi, vj] = data_buf[vi, index_buf[0]]
# Example 30
# Row-sum with the 128-element reduction axis split across two bound thread
# loops: ko on threadIdx.x (4) and ki on threadIdx.y (32).
def two_bound_loops(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    for i in T.serial(0, 128):
        for ko in T.thread_binding(0, 4, thread="threadIdx.x"):
            for ki in T.thread_binding(0, 32, thread="threadIdx.y"):
                with T.block("B"):
                    vi = T.axis.spatial(128, i)
                    vk = T.axis.reduce(128, ko * 32 + ki)
                    T.reads([B[vi], A[vi, vk]])
                    T.writes([B[vi]])
                    with T.init():
                        B[vi] = T.float32(0)
                    B[vi] = B[vi] + A[vi, vk]