Code Example #1
def cuda_matmul_1(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
    A = T.match_buffer(a, [2048, 2048], "float32")
    B = T.match_buffer(b, [2048, 2048], "float32")
    C = T.match_buffer(c, [2048, 2048], "float32")
    A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    with T.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with T.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with T.block([2048, 2048], "A_shared_local") as [v0, v1]:
        A_shared_local[v0, v1] = A_shared[v0, v1]
    with T.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for k_0 in T.serial(0, 256):
                                for k_1 in T.unroll(0, 8):
                                    for _, i, j in T.grid(1, 4, 4):
                                        with T.block([
                                                2048, 2048,
                                                T.reduce_axis(0, 2048)
                                        ], "C") as [vi, vj, vk]:
                                            T.bind(
                                                vi,
                                                by * 64 + vy * 32 + ty * 4 + i)
                                            T.bind(
                                                vj,
                                                bx * 64 + vx * 32 + tx * 4 + j)
                                            T.bind(vk, k_0 * 8 + k_1)
                                            with T.init():
                                                C_local[vi, vj] = 0.0
                                            C_local[vi, vj] = C_local[
                                                vi, vj] + A_shared_local[
                                                    vk, vi] * B_shared_local[
                                                        vk, vj]
                            for i, j in T.grid(4, 4):
                                with T.block([2048, 2048],
                                             "C_local") as [vi, vj]:
                                    T.bind(vi, by * 64 + vy * 32 + ty * 4 + i)
                                    T.bind(vj, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[vi, vj] = C_local[vi, vj]
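Each listing on this page is the body of a TIR function written in TVMScript; in the original test files it sits under a @T.prim_func decorator, usually inside a @tvm.script.ir_module class, with `from tvm.script import tir as T` imported. A minimal sketch of that wrapping is shown below, using a deliberately tiny element-wise function (the ToyModule name and its body are illustrative stand-ins, not part of the original tests):

import tvm
from tvm.script import tir as T


# Sketch only: how the snippets in this section are wrapped so the TVMScript
# parser turns them into TIR. The module name and the tiny body are stand-ins.
@tvm.script.ir_module
class ToyModule:
    @T.prim_func
    def main(A: T.Buffer[(16,), "float32"], B: T.Buffer[(16,), "float32"]) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i in T.serial(16):
            with T.block("B"):
                vi = T.axis.spatial(16, i)
                B[vi] = A[vi] * T.float32(2)


print(ToyModule.script())                   # round-trip back to TVMScript text
lib = tvm.build(ToyModule, target="llvm")   # a wrapped module can be compiled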
Code Example #2
def cuda_matmul_4(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
    A = T.match_buffer(a, [2048, 2048], "float32")
    B = T.match_buffer(b, [2048, 2048], "float32")
    C = T.match_buffer(c, [2048, 2048], "float32")
    A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    for i, j in T.grid(2048, 2048):
        with T.block("B_shared"):
            v0, v1 = T.axis.remap("SS", [i, j])
            B_shared[v0, v1] = B[v0, v1]
    for by in T.thread_binding(0, 32, thread = "blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread = "blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread = "vthread.y"):
                for vx in T.thread_binding(0, 2, thread = "vthread.x"):
                    for ty in T.thread_binding(0, 8, thread = "threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread = "threadIdx.x"):
                            for k0 in T.serial(0, 256):
                                for i, j in T.grid(8, 64):
                                    with T.block("A_shared"):
                                        v0 = T.axis.S(2048, k0 * 8 + i)
                                        v1 = T.axis.S(2048, by * 64 + j)
                                        A_shared[v0, v1] = A[v0, v1]
                                for k1 in T.unroll(0, 8):
                                    for i, j in T.grid(1, 4):
                                        with T.block("A_shared_local"):
                                            v0 = T.axis.S(2048, k0 * 8 + k1 + i)
                                            v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j)
                                            A_shared_local[v0, v1] = A_shared[v0, v1]
                                    for i, j in T.grid(1, 4):
                                        with T.block("B_shared_local"):
                                            v0 = T.axis.S(2048, k0 * 8 + k1 + i)
                                            v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                            B_shared_local[v0, v1] = B_shared[v0, v1]
                                    for _, i, j in T.grid(1, 4, 4):
                                        with T.block("C"):
                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                            vk = T.axis.R(2048, k0 * 8 + k1)
                                            with T.init():
                                                C_local[vi, vj] = 0.0
                                            C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj]
                            for i, j in T.grid(4, 4):
                                with T.block("C_local"):
                                    v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
                                    v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[v0, v1] = C_local[v0, v1]
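Examples #1 and #2 are snapshots of the same tiled GPU matmul at different scheduling stages, written in two generations of TVMScript block syntax: #1 uses the older sugared form (extents on T.block plus T.bind), #2 the newer explicit T.axis form. Structures like these are normally produced with the TensorIR schedule API rather than written by hand; the sketch below reproduces the same caching-and-binding pattern on a small 128x128 matmul. All names, tile sizes, and the exact primitive sequence are illustrative assumptions, not taken from the original tests:

import tvm
from tvm.script import tir as T


# Illustrative starting point: an unscheduled matmul block named "C".
@tvm.script.ir_module
class SmallMatmul:
    @T.prim_func
    def main(A: T.Buffer[(128, 128), "float32"],
             B: T.Buffer[(128, 128), "float32"],
             C: T.Buffer[(128, 128), "float32"]) -> None:
        for i, j, k in T.grid(128, 128, 128):
            with T.block("C"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


sch = tvm.tir.Schedule(SmallMatmul)
blk = sch.get_block("C")
i, j, k = sch.get_loops(blk)
bi, ti, ii = sch.split(i, factors=[8, 4, 4])   # block / thread / inner tiles
bj, tj, jj = sch.split(j, factors=[8, 4, 4])
sch.reorder(bi, bj, ti, tj, k, ii, jj)
sch.bind(bi, "blockIdx.y")                     # like by / bx above
sch.bind(bj, "blockIdx.x")
sch.bind(ti, "threadIdx.y")                    # like ty / tx above
sch.bind(tj, "threadIdx.x")
A_sh = sch.cache_read(blk, 0, "shared")        # counterpart of A_shared
B_sh = sch.cache_read(blk, 1, "shared")        # counterpart of B_shared
sch.compute_at(A_sh, k)                        # stage the tiles under the k loop
sch.compute_at(B_sh, k)
C_loc = sch.cache_write(blk, 0, "local")       # counterpart of C_local
sch.reverse_compute_at(C_loc, tj)              # per-thread write-back, as in "C_local"
print(sch.mod.script())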
Code Example #3
 def test_tir_fma(A: T.handle, B: T.handle, C: T.handle,
                  d: T.handle) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "test_fma", "tir.noalias": True})
     n = T.var("int32")
     stride = T.var("int32")
     stride_1 = T.var("int32")
     stride_2 = T.var("int32")
     stride_3 = T.var("int32")
     A_1 = T.match_buffer(
         A,
         [n],
         strides=[stride],
         elem_offset=0,
         align=128,
         offset_factor=1,
         buffer_type="auto",
     )
     B_1 = T.match_buffer(
         B,
         [n],
         strides=[stride_1],
         elem_offset=0,
         align=128,
         offset_factor=1,
         buffer_type="auto",
     )
     C_1 = T.match_buffer(
         C,
         [n],
         strides=[stride_2],
         elem_offset=0,
         align=128,
         offset_factor=1,
         buffer_type="auto",
     )
     d_1 = T.match_buffer(
         d,
         [n],
         strides=[stride_3],
         elem_offset=0,
         align=128,
         offset_factor=1,
         buffer_type="auto",
     )
     # body
     for i in T.serial(0, n):
         d_1[(i * stride_3)] = (A_1[(i * stride)] *
                                B_1[(i * stride_1)]) + C_1[(i * stride_2)]
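Example #3 is a 1-D fused multiply-add over a symbolic length n with symbolic strides, the shape that buffer_type="auto" arguments take after lowering. Its semantics are simply d = A * B + C elementwise; a NumPy reference with an arbitrary concrete length for illustration:

import numpy as np

# Reference semantics of test_tir_fma; the kernel itself works for any n
# because shape and strides are symbolic. n = 64 is an arbitrary choice.
n = 64
A = np.random.rand(n).astype("float32")
B = np.random.rand(n).astype("float32")
C = np.random.rand(n).astype("float32")
d = A * B + C          # d_1[i] = A_1[i] * B_1[i] + C_1[i]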
Code Example #4
 def compacted_func(A: T.Buffer[(960, 770), "float32"], B: T.Buffer[(770, 2304), "float32"], C: T.Buffer[(960, 2304), "float32"]) -> None:
     for bx in T.thread_binding(144, thread="blockIdx.x"):
         for vx in T.thread_binding(2, thread="vthread.x"):
             for tx_p in T.thread_binding(256, thread="threadIdx.x"):
                 with T.block():
                     for k_0 in T.serial(193):
                         with T.block():
                             A_shared = T.alloc_buffer([128, 4], dtype="float32", scope="shared")
                             B_shared = T.alloc_buffer([4, 128], dtype="float32", scope="shared")
                             for v_u in T.serial(1):
                                 for tx in T.thread_binding(256, thread="threadIdx.x"):
                                     for vec in T.vectorized(3):
                                         with T.block("A_shared"):
                                             T.where(bx // 18 * 128 + (tx * 3 + vec) // 4 < 960 and k_0 * 4 + (tx * 3 + vec) % 4 < 770 and tx * 3 + vec < 512)
                                             A_shared[(tx * 3 + vec) // 4, (tx * 3 + vec) % 4] = A[bx // 18 * 128 + (tx * 3 + vec) // 4, k_0 * 4 + (tx * 3 + vec) % 4]
                             for v_u in T.serial(1):
                                 for tx in T.thread_binding(256, thread="threadIdx.x"):
                                     for vec in T.vectorized(4):
                                         with T.block("B_shared"):
                                             T.where(k_0 * 4 + tx // 32 < 770 and tx * 4 + vec < 512)
                                             B_shared[tx // 32, tx % 32 * 4 + vec] = B[k_0 * 4 + tx // 32, bx % 18 * 128 + tx % 32 * 4 + vec]
                             for k_1, i_3, j_3, k_2, i_4, j_4 in T.grid(1, 8, 1, 4, 2, 2):
                                 with T.block("update_update"):
                                     C[bx // 18 * 128 + tx_p // 32 * 16 + i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 + j_4] = C[bx // 18 * 128 + tx_p // 32 * 16 + i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 + j_4] + A_shared[tx_p // 32 * 16 + i_3 * 2 + i_4, k_2] * B_shared[k_2, vx * 64 + tx_p % 32 * 2 + j_4]
Code Example #5
 def func(A: T.Buffer[(960, 770), "float32"], B: T.Buffer[(770, 2304), "float32"], C: T.Buffer[(960, 2304), "float32"]) -> None:
     for bx in T.thread_binding(144, thread="blockIdx.x"):
         for vx in T.thread_binding(2, thread="vthread.x"):
             for tx_p in T.thread_binding(256, thread="threadIdx.x"):
                 with T.block():
                     for k_0 in T.serial(193):
                         with T.block():
                             A_shared = T.alloc_buffer([960, 770], dtype="float32", scope="shared")
                             B_shared = T.alloc_buffer([770, 2304], dtype="float32", scope="shared")
                             for _u in T.serial(1):
                                 for tx in T.thread_binding(256, thread="threadIdx.x"):
                                     for vec in T.vectorized(3):
                                         with T.block("A_shared"):
                                             T.where(bx // 18 * 128 + ((_u * 256 + tx) * 3 + vec) // 4 < 960 and k_0 * 4 + ((_u * 256 + tx) * 3 + vec) % 4 < 770 and (_u * 256 + tx) * 3 + vec < 512)
                                             A_shared[bx // 18 * 128 + (_u * 768 + tx * 3 + vec) // 4, k_0 * 4 + (_u * 768 + tx * 3 + vec) % 4] = A[bx // 18 * 128 + (_u * 768 + tx * 3 + vec) // 4, k_0 * 4 + (_u * 768 + tx * 3 + vec) % 4]
                             for _u in T.serial(1):
                                 for tx in T.thread_binding(256, thread="threadIdx.x"):
                                     for vec in T.vectorized(4):
                                         with T.block("B_shared"):
                                             T.where(k_0 * 4 + ((_u * 256 + tx) * 4 + vec) // 128 < 770 and (_u * 256 + tx) * 4 + vec < 512)
                                             B_shared[k_0 * 4 + (_u * 1024 + tx * 4 + vec) // 128, bx % 18 * 128 + (_u * 1024 + tx * 4 + vec) % 128] = B[k_0 * 4 + (_u * 1024 + tx * 4 + vec) // 128, bx % 18 * 128 + (_u * 1024 + tx * 4 + vec) % 128]
                             for k_1, i_3, j_3, k_2, i_4, j_4 in T.grid(1, 8, 1, 4, 2, 2):
                                 with T.block("update_update"):
                                     C[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] = C[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] + A_shared[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, (k_0 + k_1) * 4 + k_2] * B_shared[(k_0 + k_1) * 4 + k_2, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4]
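Example #4 is example #5 after buffer compaction: in #5 the shared staging buffers are still allocated at the full 960x770 and 770x2304 shapes and indexed with global coordinates, while in #4 they have been shrunk to the 128x4 and 4x128 tiles actually touched by one block, with the indices rebased to the tile. The sketch below applies the responsible pass, tvm.tir.transform.CompactBufferAllocation, to a deliberately small stand-in module (the Before module and its shapes are illustrative, not the original test fixture):

import tvm
from tvm.script import tir as T


# Illustrative "before" module: `cache` is allocated at the full length even
# though each outer iteration only touches a 16-element slice of it.
@tvm.script.ir_module
class Before:
    @T.prim_func
    def main(A: T.Buffer[(256,), "float32"], B: T.Buffer[(256,), "float32"]) -> None:
        for i in T.serial(16):
            with T.block():
                T.reads(A[i * 16:i * 16 + 16])
                T.writes(B[i * 16:i * 16 + 16])
                cache = T.alloc_buffer([256], dtype="float32")
                for j in T.serial(16):
                    with T.block("cache"):
                        T.reads(A[i * 16 + j])
                        T.writes(cache[i * 16 + j])
                        cache[i * 16 + j] = A[i * 16 + j]
                for j in T.serial(16):
                    with T.block("copy"):
                        T.reads(cache[i * 16 + j])
                        T.writes(B[i * 16 + j])
                        B[i * 16 + j] = cache[i * 16 + j] * T.float32(2)


# The pass shrinks `cache` to the 16-element tile that is actually accessed,
# the same way example #4 is derived from example #5.
After = tvm.tir.transform.CompactBufferAllocation()(Before)
print(After.script())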
Code Example #6
def element_wise_thread_x(a: T.handle, b: T.handle, c: T.handle) -> None:
    j1_0 = T.env_thread("threadIdx.x")
    j0_0 = T.env_thread("threadIdx.x")
    i = T.env_thread("blockIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    T.launch_thread(i, 128)
    with T.launch_thread(j0_0, 4):
        for j0_1 in T.serial(0, 32):
            T.store(
                B.data,
                i * 128 + j0_0 * 32 + j0_1,
                T.load("float32", A.data, i * 128 + j0_0 * 32 + j0_1) * 2.0,
                True,
            )
    T.launch_thread(j1_0, 4)
    for j1_1 in T.serial(0, 32):
        T.store(
            C.data,
            i * 128 + j1_0 * 32 + j1_1,
            T.load("float32", A.data, i * 128 + j1_0 * 32 + j1_1) + 1.0,
            True,
        )
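Example #6 is already flattened, low-level TIR: T.launch_thread plus raw T.load/T.store on the buffers' data pointers, with each 128-element row split into 4 threadIdx.x chunks of 32. Semantically it is just two independent element-wise maps over a 128x128 array; a NumPy reference:

import numpy as np

# Reference semantics of element_wise_thread_x.
A = np.random.rand(128, 128).astype("float32")
B = A * 2.0    # the first launch_thread(j0_0, 4) / serial(0, 32) nest
C = A + 1.0    # the second launch_thread(j1_0, 4) / serial(0, 32) nest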
Code Example #7
def non_perfect_tiling_cache(a: T.handle, b: T.handle) -> None:
    X = T.match_buffer(a, [224, 224], dtype="float32")
    Y = T.match_buffer(b, [224, 224], dtype="float32")
    cache = T.alloc_buffer([224, 224], dtype="float32")
    for hh_0, ww_0 in T.grid(28, 28):
        for ax0 in T.serial(0, 10):
            for ax1 in T.serial(0, 10):
                with T.block("cache"):
                    h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0)
                    w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1)
                    T.where(
                        1 <= hh_0 * 8 + ax0
                        and hh_0 * 8 + ax0 < 225
                        and 1 <= ww_0 * 8 + ax1
                        and ww_0 * 8 + ax1 < 225
                    )
                    cache[h, w] = X[h, w]
        for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
            with T.block("compute"):
                h = T.axis.spatial(224, hh_0 * 8 + hh_1)
                w = T.axis.spatial(224, ww_0 * 8 + ww_1)
                kh, kw = T.axis.remap("RR", [khh, kww])
                with T.init():
                    Y[h, w] = 0.0
                Y[h, w] = T.max(
                    Y[h, w],
                    T.if_then_else(
                        T.likely(1 <= h + kh, dtype="bool")
                        and T.likely(h + kh < 225, dtype="bool")
                        and T.likely(1 <= w + kw, dtype="bool")
                        and T.likely(w + kw < 225, dtype="bool"),
                        cache[h + kh - 1, w + kw - 1],
                        0.0,
                        dtype="float32",
                    ),
                )
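Example #7 stages a 10x10 input tile (one halo element on each side) into `cache` for every 8x8 output tile, guarded by T.where because the halo runs out of bounds at the image border, and then takes a 3x3 sliding-window maximum with zero padding. A NumPy reference for the computation:

import numpy as np

# Reference semantics of non_perfect_tiling_cache: 3x3 max over X with one
# element of zero padding on every side (stride 1, same output size).
X = np.random.rand(224, 224).astype("float32")
Xp = np.pad(X, 1)                                # zero-padded to 226x226
Y = np.zeros_like(X)                             # matches the T.init of 0.0
for kh in range(3):
    for kw in range(3):
        Y = np.maximum(Y, Xp[kh:kh + 224, kw:kw + 224])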
Code Example #8
def elementwise_not_affine_fused(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [127, 128])
    B = T.match_buffer(b, [127, 128])
    for i in T.grid(4):
        for j_k_fused in T.serial(0, T.min(31, 126 - i * 32) * 128 + 128):
            with T.block("B"):
                vi = T.axis.S(
                    127,
                    i * 32 + T.floormod(T.floordiv(j_k_fused, 128),
                                        T.min(31, 126 - i * 32) + 1),
                )
                vj = T.axis.S(128, T.floormod(j_k_fused, 128))
                T.reads([A[vi, vj]])
                T.writes([B[vi, vj]])
                B[vi, vj] = A[vi, vj]
Code Example #9
def tir_matmul(
    A: T.Buffer[(16384,), "float32"],
    B: T.Buffer[(16384,), "float32"],
    C: T.Buffer[(16384,), "float32"],
) -> None:
    # function attr dict
    T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
    T.preflattened_buffer(A, [128, 128], dtype="float32", data=A.data)
    T.preflattened_buffer(B, [128, 128], dtype="float32", data=B.data)
    T.preflattened_buffer(C, [128, 128], dtype="float32", data=C.data)
    # body
    for x, y in T.grid(128, 128):
        C[x * 128 + y] = T.float32(0)
        for k in T.serial(128):
            C[x * 128 + y] = C[x * 128 + y] + A[x * 128 + k] * B[y * 128 + k]
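Example #9 operates on flattened 1-D buffers; T.preflattened_buffer records that the logical shapes were 128x128. Note that B is indexed as B[y * 128 + k], so in 2-D terms the kernel computes C = A @ B^T. A NumPy reference:

import numpy as np

# Reference semantics of tir_matmul with the buffers viewed as 128x128:
# C[x, y] = sum_k A[x, k] * B[y, k]
A = np.random.rand(128, 128).astype("float32")
B = np.random.rand(128, 128).astype("float32")
C = A @ B.T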
Code Example #10
    def main(a: T.handle, b: T.handle) -> None:
        A = T.match_buffer(a, [64, 64, 64])
        B = T.match_buffer(b, [64])

        for i0, j0 in T.grid(64, 64):
            for k0 in T.serial(32, 64):
                with T.block():
                    i, j, k = T.axis.remap("SRR", [i0, j0, k0])
                    T.reads(A[i, j, k])
                    T.writes(B[i])
                    BB = T.match_buffer(B[i], ())
                    AA = T.match_buffer(A[i, 0:64, 0:64], (64, 64))
                    if (j == 0) and (k == 32):
                        BB[()] = T.float32(0)
                    BB[()] += AA[j, k]
Code Example #11
def element_wise_invalid_annotation(a: T.handle, c: T.handle) -> None:
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block([], "root"):
        T.reads([])
        T.writes([])
        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
        for i0 in T.serial(0, 128):
            for ax1 in T.serial(0, 128):
                with T.block([128, 128], "B") as [vi, vj]:
                    T.block_attr({"buffer_dim_align": [0]})
                    T.bind(vi, i0)
                    T.bind(vj, ax1)
                    T.reads([A[vi, vj]])
                    T.writes([B[vi, vj]])
                    B[vi, vj] = (A[vi, vj]*T.float32(2))
            for i1 in T.serial(0, 128):
                with T.block([128, 128], "C") as [vi_1, vj_1]:
                    T.bind(vi_1, i0)
                    T.bind(vj_1, i1)
                    T.reads([B[vi_1, vj_1]])
                    T.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1))
Code Example #12
def unified_element_wise_vthread_x(a: T.handle, b: T.handle) -> None:
    vthread_x = T.env_thread("vthread.x")
    thread_x = T.env_thread("threadIdx.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    T.launch_thread(vthread_x, 2)
    T.launch_thread(thread_x, 64)
    T.launch_thread(vthread_x, 2)
    for j_1 in T.serial(0, 64):
        T.store(
            B.data,
            vthread_x * 8256 + thread_x * 128 + j_1,
            T.load("float32", A.data, vthread_x * 8256 + thread_x * 128 + j_1)
            * 2.0,
            True,
        )
Code Example #13
def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [1024], dtype="float32")
    B = T.match_buffer(b, [1024], dtype="float32")
    A_cache = T.alloc_buffer([1024], dtype="float32")
    with T.block("opaque"):
        # annotated opaque partial access should be kept
        T.reads(A[0:512])
        T.writes([A_cache[0:512]])
        T.evaluate(A.access_ptr("r", extent=512))
        T.evaluate(A_cache.access_ptr("w", extent=512))
    for i in T.serial(0, 512):
        with T.block("B"):
            vi = T.axis.spatial(512, i)
            T.reads([A_cache[vi]])
            T.writes([B[vi]])
            B[vi] = A_cache[vi] * 2.0 + 1.0
Code Example #14
def rowsum_blockized(a: T.handle, b: T.handle) -> None:
    B = T.match_buffer(b, [32, 4])
    A = T.match_buffer(a, [32, 4, 128])
    for i0, i2_0 in T.grid(32, 16):
        with T.block("blockized_B"):
            io, ko = T.axis.remap("SR", [i0, i2_0])
            with T.init():
                for i1 in T.serial(0, 4):
                    with T.block("B_init"):
                        ii_init = T.axis.S(4, i1)
                        B[io, ii_init] = 0.0
            for i1_1, i2_1 in T.grid(4, 8):
                with T.block("B"):
                    ii = T.axis.S(4, i1_1)
                    k = T.axis.R(128, ko * 8 + i2_1)
                    B[io, ii] = B[io, ii] + A[io, ii, k]
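Example #14 is a row sum after blockize: the outer block "blockized_B" iterates over (io, ko) tiles, the nested init block clears each row of B once, and the inner block accumulates a 4x8 tile of partial sums. Functionally it is a reduction over the last axis; a NumPy reference:

import numpy as np

# Reference semantics of rowsum_blockized: B[io, ii] = sum_k A[io, ii, k].
A = np.random.rand(32, 4, 128).astype("float32")
B = A.sum(axis=2)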
Code Example #15
 def main(
     A: T.Buffer[(16384, ), "float32"],
     B: T.Buffer[(16384, ), "float32"],
     C: T.Buffer[(16384, ), "float32"],
 ) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     T.preflattened_buffer(A, [128, 128], data=A.data)
     T.preflattened_buffer(B, [128, 128], data=B.data)
     T.preflattened_buffer(C, [128, 128], data=C.data)
     # body
     for x, y in T.grid(128, 128):
         C[x * 128 + y] = 0.0
         for k in T.serial(0, 128):
             C[x * 128 +
               y] = C[x * 128 + y] + A[x * 128 + k] * B[y * 128 + k]
Code Example #16
 def after_rowsum_blockize(
     A: T.Buffer[(128, 128), "float32"],
     B: T.Buffer[(128, ), "float32"],
 ) -> None:
     with T.block("blockized_B"):
         vko = T.axis.R(1, 0)
         vio = T.axis.S(1, 0)
         with T.init():
             for i1 in T.serial(0, 128):
                 with T.block("B_init"):
                     vi_init = T.axis.S(128, i1)
                     B[vi_init] = T.float32(0)
         for i0, i1_1 in T.grid(128, 128):
             with T.block("B"):
                 vk, vi = T.axis.remap("RS", [i0, i1_1])
                 B[vi] = B[vi] + A[vi, vk]
Code Example #17
 def main(
     T_reshape: T.Buffer[(1, 12, 384, 384), "float32"],
     placeholder_1: T.Buffer[(T.int64(1), T.int64(12), T.int64(384),
                              384), "bool"],
     T_where: T.Buffer[(T.int64(1), T.int64(12), T.int64(384), 384),
                       "float32"]
 ) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(256),
                                                 thread="blockIdx.x"):
         for i0_i1_i2_i3_fused_2 in T.thread_binding(
                 T.int64(1024), thread="threadIdx.x"):
             for i0_i1_i2_i3_fused_0 in T.serial(T.int64(7)):
                 with T.block("T_where"):
                     ax0 = T.axis.spatial(T.int64(1), T.int64(0))
                     ax1 = T.axis.spatial(
                         T.int64(12),
                         ((i0_i1_i2_i3_fused_0 * T.int64(256) +
                           i0_i1_i2_i3_fused_1) * T.int64(1024) +
                          i0_i1_i2_i3_fused_2) % T.int64(1769472) //
                         T.int64(147456))
                     ax2 = T.axis.spatial(
                         T.int64(384),
                         ((i0_i1_i2_i3_fused_0 * T.int64(256) +
                           i0_i1_i2_i3_fused_1) * T.int64(1024) +
                          i0_i1_i2_i3_fused_2) % T.int64(147456) //
                         T.int64(384))
                     ax3 = T.axis.spatial(
                         384,
                         T.cast(((i0_i1_i2_i3_fused_0 * T.int64(256) +
                                  i0_i1_i2_i3_fused_1) * T.int64(1024) +
                                 i0_i1_i2_i3_fused_2) % T.int64(384),
                                "int32"))
                     T.where((i0_i1_i2_i3_fused_0 * T.int64(256) +
                              i0_i1_i2_i3_fused_1) * T.int64(1024) +
                             i0_i1_i2_i3_fused_2 < T.int64(1769472))
                     T.reads(placeholder_1[ax0, ax1, ax2, ax3],
                             T_reshape[ax0, ax1, ax2, ax3])
                     T.writes(T_where[ax0, ax1, ax2, ax3])
                     T_where[ax0, ax1, ax2, ax3] = T.Select(
                         T.cast(placeholder_1[ax0, ax1, ax2,
                                              ax3], "int32") != 0,
                         T.float32(-1000000000), T_reshape[ax0, ax1,
                                                           ax2, ax3])
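Example #17 fuses all four axes of a 1x12x384x384 masked fill into one loop and distributes it over 256 blocks x 1024 threads x 7 serial steps, hence the T.where guard (7 * 256 * 1024 > 1769472) and the mix of int64 and int32 index arithmetic. The computation itself is a standard attention-style mask fill; a NumPy reference:

import numpy as np

# Reference semantics of the T_where block: where the boolean mask is set,
# write -1e9, otherwise pass T_reshape through.
t_reshape = np.random.rand(1, 12, 384, 384).astype("float32")
mask = np.random.rand(1, 12, 384, 384) > 0.5
t_where = np.where(mask, np.float32(-1000000000), t_reshape)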
Code Example #18
def rowsum_blockized(a: T.handle, b: T.handle) -> None:
    B = T.match_buffer(b, [32, 4])
    A = T.match_buffer(a, [32, 4, 128])
    for i0, i2_0 in T.grid(32, 16):
        with T.block([32, T.reduce_axis(0, 16)], "blockized_B") as [io, ko]:
            T.bind(io, i0)
            T.bind(ko, i2_0)
            with T.init():
                for i1 in T.serial(0, 4):
                    with T.block([4], "B_init") as [ii_init]:
                        T.bind(ii_init, i1)
                        B[io, ii_init] = 0.0
            for i1_1, i2_1 in T.grid(4, 8):
                with T.block([4, T.reduce_axis(0, 128)], "B") as [ii, k]:
                    T.bind(ii, i1_1)
                    T.bind(k, ko * 8 + i2_1)
                    B[io, ii] = B[io, ii] + A[io, ii, k]
Code Example #19
def opaque_access_func() -> None:
    A = T.alloc_buffer([1024])
    B = T.alloc_buffer([1024])
    for i in T.serial(0, 8):
        with T.block():
            v = T.axis.S(8, i)
            T.reads([A[v * 128:v * 128 + 128]])
            T.writes([B[v * 128:v * 128 + 128]])
            T.evaluate(
                T.call_extern("test",
                              B.data,
                              v * 128,
                              128,
                              A.data,
                              v * 128,
                              128,
                              dtype="float32"))
Code Example #20
def exp_exp_opaque_access_with_tvm_access_ptr_inlined(
    lookup_table: T.Buffer[(1024,), "int8"],
    x: T.Buffer[(16,), "float16"],
    compute: T.Buffer[(16,), "float16"],
) -> None:
    for i0 in T.serial(16):
        with T.block("compute_1"):
            i0_1 = T.axis.spatial(16, i0)
            # Do not put the opaque access to new write region when opaque access
            # wrapped with a tvm_access_ptr and the access mask set to "read only"
            T.reads(x[i0_1], lookup_table[0:1024])
            T.writes(compute[i0_1])
            compute[i0_1] = T.exp(
                T.exp(x[i0_1], dtype="float16"),
                lookup_table.access_ptr("r"),
                dtype="float16",
            )
Code Example #21
def element_wise_vthread_x(a: T.handle, b: T.handle) -> None:
    i_0 = T.env_thread("vthread.x")
    i_1 = T.env_thread("threadIdx.x")
    j_0 = T.env_thread("vthread.x")
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    T.launch_thread(i_0, 2)
    T.launch_thread(i_1, 64)
    T.launch_thread(j_0, 2)
    for j_1 in T.serial(0, 64):
        T.store(
            B.data,
            i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1,
            T.load("float32", A.data, i_0 * 8192 + i_1 * 128 + j_0 * 64 + j_1)
            * 2.0,
            True,
        )
Code Example #22
def transformed_square_sum_square_root_factor_one_2(a: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, [16, 256, 256])
    D = T.match_buffer(d, [16])
    C = T.alloc_buffer([16])

    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 1, 65536):
        with T.block("C"):
            b = T.axis.S(16, i0)
            i = T.axis.R(256, T.floordiv(i1_i2_fused_inner, 256))
            j = T.axis.R(256, T.floormod(i1_i2_fused_inner, 256))
            with T.init():
                C[b] = 0.0
            C[b] = C[b] + (A[b, i, j] * A[b, i, j])
    for i0_1 in T.serial(0, 16):
        with T.block("D"):
            b_1 = T.axis.S(16, i0_1)
            D[b_1] = T.sqrt(C[b_1], dtype="float32")
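Example #22 fuses the 256x256 reduction into a single 65536-iteration loop (recovering i and j via floordiv/floormod), accumulates the sum of squares per batch element, and then takes a square root. A NumPy reference:

import numpy as np

# Reference semantics: C[b] = sum_{i,j} A[b, i, j]^2 and D[b] = sqrt(C[b]).
A = np.random.rand(16, 256, 256).astype("float32")
D = np.sqrt((A * A).sum(axis=(1, 2)))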
Code Example #23
def simple_compute_missing_annotation(A: T.Buffer[(16, 16), "float32"],
                                      C: T.Buffer[(16, 16), "float32"]):
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(0,
                          16,
                          annotations={"software_pipeline_stage": [0, 1]}):
            with T.block():
                T.reads(A[tx, i])
                T.writes(C[tx, i])
                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(B[tx, 0])
                    B[tx, 0] = A[tx, i] * T.float32(2)
                with T.block():
                    T.reads(B[tx, 0])
                    T.writes(C[tx, i])
                    C[tx, i] = B[tx, 0] + T.float32(1)
Code Example #24
File: test_lower_build.py  Project: zjppoet/tvm
 def main(a: T.handle, b: T.handle, c: T.handle) -> None:
     # function attr dict
     T.func_attr({
         "global_symbol": "main",
         "from_legacy_te_schedule": True,
         "tir.noalias": True
     })
     A = T.match_buffer(a, [128, 128])
     B = T.match_buffer(b, [128, 128])
     C = T.match_buffer(c, [128, 128])
     # body
     for x, y in T.grid(128, 128):
         C.data[x * 128 + y] = 0.0
         for k in T.serial(0, 128):
             C.data[x * 128 + y] = T.load(
                 "float32", C.data, x * 128 +
                 y) + T.load("float32", A.data, x * 128 + k) * T.load(
                     "float32", B.data, y * 128 + k)
Code Example #25
def concat_func_3(
    placeholder: T.Buffer[(50176,), "int8"],
    placeholder_1: T.Buffer[(25088,), "int8"],
    placeholder_2: T.Buffer[(25088,), "int8"],
    T_concat: T.Buffer[(100352,), "int8"],
) -> None:
    T.preflattened_buffer(placeholder, (1, 64, 28, 28), "int8", data=placeholder.data)
    T.preflattened_buffer(placeholder_1, (1, 32, 28, 28), "int8", data=placeholder_1.data)
    T.preflattened_buffer(placeholder_2, (1, 32, 28, 28), "int8", data=placeholder_2.data)
    T.preflattened_buffer(T_concat, (1, 128, 28, 28), "int8", data=T_concat.data)
    for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}):
        for i2, i3 in T.grid(28, 28):
            if 96 <= i1:
                T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_2[i1 * 784 + i2 * 28 + i3 - 75264]
            if 64 <= i1 and i1 < 96:
                T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_1[i1 * 784 + i2 * 28 + i3 - 50176]
            if i1 < 64:
                T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3]
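Example #25 is a flattened concatenation along the channel axis: channels 0-63 come from `placeholder`, 64-95 from `placeholder_1`, and 96-127 from `placeholder_2` (the subtracted offsets 50176 and 75264 are 64 * 784 and 96 * 784), and the pragma_loop_partition_hint annotation asks loop partitioning to split the i1 loop at exactly those boundaries. A NumPy reference with the tensors in their pre-flattened NCHW shapes:

import numpy as np

# Reference semantics of concat_func_3.
p0 = np.random.randint(-128, 128, (1, 64, 28, 28), dtype="int8")
p1 = np.random.randint(-128, 128, (1, 32, 28, 28), dtype="int8")
p2 = np.random.randint(-128, 128, (1, 32, 28, 28), dtype="int8")
t_concat = np.concatenate([p0, p1, p2], axis=1)   # shape (1, 128, 28, 28)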
Code Example #26
def block_predicate_cache_write_output_buf() -> None:
    A = T.alloc_buffer([120], dtype="float32")
    B = T.alloc_buffer([120], dtype="float32")
    B_shared = T.alloc_buffer([120], dtype="float32", scope="shared")
    for i, j in T.grid(16, 8):
        with T.block("producer"):
            ax = T.axis.spatial(120, i * 8 + j)
            T.where(i * 8 + j < 120)
            A[ax] = T.float32(0)
    for i, j in T.grid(16, 8):
        with T.block("consumer"):
            ax = T.axis.spatial(120, i * 8 + j)
            T.where(i * 8 + j < 120)
            B_shared[ax] = A[ax] + T.float32(1)
    for ax0 in T.serial(120):
        with T.block("B_shared"):
            v0 = T.axis.spatial(120, ax0)
            B[v0] = B_shared[v0]
Code Example #27
 def expected(A: T.Buffer[(4, 4), "float32"]):
     for i in T.serial(4):
         if i < 2:
             for j in T.serial(4):
                 if j < 3:
                     for k in T.serial(4):
                         A[i, j] = 0.0
                 else:
                     for k in T.serial(4):
                         A[i, j] = 2.0
         else:
             for j in T.serial(4):
                 if j < 3:
                     for k in T.serial(4):
                         A[i, j] = 1.0
                 else:
                     for k in T.serial(4):
                         A[i, j] = 3.0
Code Example #28
def func_3(
    C: T.Buffer[(1,), "float32"],
    A: T.Buffer[(16,), "float32"],
    D: T.Buffer[(2,), "float32"],
    E: T.Buffer[(16,), "float32"],
    F: T.Buffer[(16,), "float32"],
):
    for i in T.serial(
        0,
        16,
    ):
        with T.block():
            B = T.alloc_buffer((1,), dtype="float32")
            with T.block():
                B[0] = A[i] * T.float32(2)
            with T.block():
                E[i] = A[i]
                F[i] = E[i] + 1.0
                C[0] = C[0] + A[i] + B[0] + T.float32(1) + D[0]
                A[i] = B[0] + T.float32(1) + D[1]
Code Example #29
def dag_interleaving(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
) -> None:
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(
                0,
                16,
                annotations={
                    "software_pipeline_stage": [0, 0, 0, 0, 1],
                    "software_pipeline_order": [0, 2, 1, 3, 4],
                },
        ):
            with T.block():
                T.reads(A[tx, i])
                T.writes(C[tx, i])
                AS = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                BS = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                AL = T.alloc_buffer((1, 1), dtype="float32", scope="local")
                BL = T.alloc_buffer((1, 1), dtype="float32", scope="local")
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(AS[tx, 0])
                    AS[tx, 0] = A[tx, i] * T.float32(2)
                with T.block():
                    T.reads(AS[tx, 0])
                    T.writes(AL[0, 0])
                    AL[0, 0] = AS[tx, 0]
                with T.block():
                    T.reads(B[tx, i])
                    T.writes(BS[tx, 0])
                    BS[tx, 0] = B[tx, i] + T.float32(2)
                with T.block():
                    T.reads(BS[tx, 0])
                    T.writes(BL[0, 0])
                    BL[0, 0] = BS[tx, 0]
                with T.block():
                    T.reads(AL[0, 0], BL[0, 0])
                    T.writes(C[tx, i])
                    C[tx, i] = AL[0, 0] * BL[0, 0]
Code Example #30
def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None:
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block("root"):
        T.reads([])
        T.writes([])
        for i0_0 in T.serial(0, 16):
            for i0_1_init, i1_init in T.grid(8, 128):
                with T.block("update_init"):
                    vi_init = T.axis.S(128, i0_0 * 8 + i0_1_init)
                    vj_init = T.axis.S(128, i1_init)
                    C[vi_init, vj_init] = T.float32(0)
            for i0_1, i1, i2_0, i2_1 in T.grid(8, 128, 19, 7):
                with T.block("update_update"):
                    T.where((((i2_0 * 7) + i2_1) < 128))
                    vi = T.axis.S(128, i0_0 * 8 + i0_1)
                    vj = T.axis.S(128, i1)
                    vk = T.axis.R(128, i2_0 * 7 + i2_1)
                    C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
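Example #30 is what a 128x128x128 matmul looks like after splitting the reduction loop non-perfectly and decomposing the reduction: decompose_reduction hoists the init into its own "update_init" block, and because 19 * 7 = 133 > 128 the update block keeps a T.where predicate. Below is a sketch of a schedule that produces this shape; the starting module, the block name "update", and the exact loop chosen for decomposition are assumptions inferred from the listing:

import tvm
from tvm.script import tir as T


# Illustrative starting point: a plain matmul with B accessed as B[vj, vk],
# matching the update expression in matmul_decompose4.
@tvm.script.ir_module
class Matmul:
    @T.prim_func
    def main(A: T.Buffer[(128, 128), "float32"],
             B: T.Buffer[(128, 128), "float32"],
             C: T.Buffer[(128, 128), "float32"]) -> None:
        for i, j, k in T.grid(128, 128, 128):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


sch = tvm.tir.Schedule(Matmul)
blk = sch.get_block("update")
i, j, k = sch.get_loops(blk)
i0, i1 = sch.split(i, factors=[16, 8])   # i0_0 / i0_1 in the listing
k0, k1 = sch.split(k, factors=[19, 7])   # non-perfect: 19 * 7 = 133 > 128
sch.decompose_reduction(blk, i1)         # hoists the "update_init" block above the k loops
print(sch.mod.script())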