Exemplo n.º 1
0
def expected_recursive_bufferslice_indices(data: ty.handle,
                                           index: ty.handle) -> None:
    # Fills a locally allocated 16x16 buffer with one element of `data`
    # whose coordinates are looked up through `index`.
    # NOTE(review): the "expected_" prefix suggests this is the reference
    # output of a test -- confirm against the caller.
    #
    # data:  handle matched to a [16, 16] buffer.
    # index: handle matched to a one-element int32 buffer.
    index_buf = tir.match_buffer(index, [1],
                                 dtype="int32",
                                 elem_offset=0,
                                 align=128,
                                 offset_factor=1)
    data_buf = tir.match_buffer(data, [16, 16],
                                elem_offset=0,
                                align=128,
                                offset_factor=1)
    with tir.block([], "root"):
        tir.reads([])
        tir.writes([])
        # Scratch output buffer; nothing in this function copies it back to a
        # parameter.
        out_buf = tir.alloc_buffer([16, 16],
                                   elem_offset=0,
                                   align=128,
                                   offset_factor=1)
        for i0, i1 in tir.grid(16, 16):
            with tir.block([16, 16], "") as [vi, vj]:
                tir.bind(vi, i0)
                tir.bind(vj, i1)
                # The declared read region spans all of data_buf; the element
                # actually read is selected by the values stored in index_buf.
                tir.reads([data_buf[0:16, 0:16], index_buf[0]])
                tir.writes([out_buf[vi, vj]])
                # Nested lookup: index_buf[0] is itself used as a subscript
                # into index_buf to obtain the row; the column is index_buf[0].
                # The right-hand side does not depend on vi or vj.
                out_buf[vi, vj] = data_buf[index_buf[index_buf[0]],
                                           index_buf[0]]
Exemplo n.º 2
0
def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None:
    # Processes two 1024-element buffers in 8 chunks of 128 elements: each
    # chunk goes through an extern call targeting a block-local cache buffer
    # and is then copied element-wise from that cache into B.
    #
    # a: handle matched to buffer A, shape [1024].
    # b: handle matched to buffer B, shape [1024].
    A = tir.match_buffer(a, [1024])
    B = tir.match_buffer(b, [1024])
    for i in tir.serial(0, 8):
        # No tir.bind is written for vi in this outer block.
        with tir.block([8]) as [vi]:
            tir.reads(A[vi * 128:vi * 128 + 128])
            tir.writes(B[vi * 128:vi * 128 + 128])
            # The cache is allocated inside the outer block, not at function
            # scope.
            A_cache = tir.alloc_buffer([1024])
            with tir.block([8]) as [v]:
                tir.bind(v, vi)
                tir.reads([A[v * 128:v * 128 + 128]])
                tir.writes([A_cache[v * 128:v * 128 + 128]])
                # Opaque access: the buffers are handed to the extern function
                # "test" as raw data pointers, each followed by v * 128 and
                # 128. The reads/writes above are the only description of the
                # touched regions.
                tir.evaluate(
                    tir.call_extern("test",
                                    A_cache.data,
                                    v * 128,
                                    128,
                                    A.data,
                                    v * 128,
                                    128,
                                    dtype="float32"))
            for j in tir.serial(0, 128):
                with tir.block([1024]) as [v]:
                    # Flat element index of position j within chunk vi.
                    tir.bind(v, ((vi * 128) + j))
                    tir.reads([A_cache[v]])
                    tir.writes([B[v]])
                    B[v] = A_cache[v]
Exemplo n.º 3
0
def buffer_opaque_access(b: ty.handle, c: ty.handle) -> None:
    # Exercises low-level (opaque) accesses: raw tir.store / tir.load on an
    # allocated variable and an intrinsic call that receives B.data, followed
    # by an ordinary element-wise copy of B into C.
    #
    # b: handle matched to buffer B, shape [16, 16], float32.
    # c: handle matched to buffer C, shape [16, 16], float32.
    B = tir.match_buffer(b, [16, 16], "float32")
    C = tir.match_buffer(c, [16, 16], "float32")

    # Block with no block variables; it declares no reads and a write of the
    # whole of B.
    with tir.block([]):
        tir.reads([])
        tir.writes(B[0:16, 0:16])
        # 256 float32 elements requested with the "global" scope string; A is
        # then addressed with flat indices i * 16 + j.
        A = tir.allocate([256], "float32", "global")
        for i, j in tir.grid(16, 16):
            tir.store(A, i * 16 + j, 1)
        for i in range(0, 16):
            for j in range(0, 16):
                # The loaded value is evaluated and discarded.
                tir.evaluate(tir.load("float32", A, i * 16 + j))
            for j in range(0, 16):
                # Intrinsic call on B's data pointer; its arguments do not
                # depend on i or j.
                tir.evaluate(
                    tir.tvm_fill_fragment(B.data,
                                          16,
                                          16,
                                          16,
                                          0,
                                          tir.float32(0),
                                          dtype="handle"))

    for i, j in tir.grid(16, 16):
        with tir.block([16, 16]) as [vi, vj]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            C[vi, vj] = B[vi, vj]
def transformed_match_buffer_func() -> None:
    # Demonstrates nested tir.match_buffer on buffer regions: a row of C is
    # matched as C0, then a single element of C0 is matched as the
    # zero-dimensional buffer C1, which is set to 0.
    for i in range(0, 128):
        with tir.block([128]) as [vi]:
            tir.bind(vi, i)
            # C is allocated inside the block.
            C = tir.alloc_buffer((128, 128))
            # Note: `(128)` is the plain integer 128, not a one-element tuple.
            C0 = tir.match_buffer(C[vi, 0:128], (128))
            # No loop or tir.bind is written for jj.
            with tir.block([128]) as [jj]:
                # Empty shape `()`: C1 is indexed with the empty tuple.
                C1 = tir.match_buffer(C0[jj], ())
                C1[()] = 0
Exemplo n.º 5
0
def rowsum_transformed(a: ty.handle, b: ty.handle) -> None:
    # Row sum B[vi] = sum over vk of A[vi, vk], written with a loop nest in
    # which the row and reduction indices are spread over three loops, one of
    # them carrying parts of both indices.
    #
    # a: handle matched to buffer A, shape (128, 128).
    # b: handle matched to buffer B, shape (128,).
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))

    for io, ii_ko_fused, ki in tir.grid(32, 128, 4):
        with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
            # ii_ko_fused (0..127) carries two indices: its quotient by 32
            # (0..3) is the inner part of the row, its remainder (0..31) the
            # outer part of the reduction index.
            tir.bind(vi, io * 4 + tir.floordiv(ii_ko_fused, 32))
            tir.bind(vk, tir.floormod(ii_ko_fused, 32) * 4 + ki)
            with tir.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
Exemplo n.º 6
0
def rowsum_not_quasi_affine(a: ty.handle, b: ty.handle) -> None:
    # Row-sum style reduction whose reduction index is bound to an expression
    # that multiplies the loop variable by itself.
    # NOTE(review): by its name this is a negative fixture for a
    # quasi-affine binding check -- confirm against the test using it.
    #
    # a: handle matched to buffer A, shape (128, 128).
    # b: handle matched to buffer B, shape (128,).
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))

    for i, k in tir.grid(128, 16):
        with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
            tir.bind(vi, i)
            # k * k is a product of the loop variable with itself; for k in
            # 0..15 the bound value ranges over 0..112 and skips many indices.
            tir.bind(vk, tir.floordiv(k * k, 2))
            with tir.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
def elementwise_affine_producer(a: ty.handle, c: ty.handle) -> None:
    # Two element-wise stages, B = A * 2.0 then C = B + 1.0, where the
    # producer block "B" is driven by a four-loop nest with split/fused style
    # index expressions.
    #
    # a: handle matched to buffer A, shape (128, 128), float32.
    # c: handle matched to buffer C, shape (128, 128), float32.
    A = tir.match_buffer(a, (128, 128), "float32")
    C = tir.match_buffer(c, (128, 128), "float32")
    B = tir.alloc_buffer((128, 128), "float32")
    for i, j, k, l in tir.grid(16, 2, 32, 16):
        with tir.block([128, 128], "B") as [vi, vj]:
            # Loop k (0..31) contributes to both indices: k // 8 (0..3) to the
            # row and k % 8 (0..7) to the column. Both bindings span 0..127.
            tir.bind(vi, i * 8 + j * 4 + k // 8)
            tir.bind(vj, k % 8 * 16 + l)
            B[vi, vj] = A[vi, vj] * 2.0
    # No loops or tir.bind calls are written for the consumer block.
    with tir.block([128, 128], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
def opaque_access_func() -> None:
    # Calls the extern function "test" once per 128-element chunk of two
    # locally allocated 1024-element buffers, with the touched regions
    # declared explicitly through tir.reads / tir.writes.
    A = tir.alloc_buffer([1024])
    B = tir.alloc_buffer([1024])
    for i in tir.serial(0, 8):
        with tir.block([8]) as [v]:
            tir.bind(v, i)
            tir.reads([A[v * 128 : v * 128 + 128]])
            tir.writes([B[v * 128 : v * 128 + 128]])
            # Opaque access: the buffers are passed as raw data pointers, each
            # followed by v * 128 and 128.
            tir.evaluate(
                tir.call_extern("test", B.data, v * 128, 128, A.data, v * 128, 128, dtype="float32")
            )
Exemplo n.º 9
0
def tiled(a: ty.handle, c: ty.handle) -> None:
    # Two element-wise stages, B = A * 2.0 then C = B + 1.0, with the
    # producer block "B" iterated as an 8x8 grid of 16x16 tiles.
    #
    # a: handle matched to buffer A, shape [128, 128], float32.
    # c: handle matched to buffer C, shape [128, 128], float32.
    A = tir.match_buffer(a, [128, 128], "float32")
    B = tir.alloc_buffer([128, 128], "float32")
    C = tir.match_buffer(c, [128, 128], "float32")
    # Loop order is (tile row, tile column, row in tile, column in tile).
    for i_0, j_0, i_1, j_1 in tir.grid(8, 8, 16, 16):
        with tir.block([128, 128], "B") as [vi, vj]:
            tir.bind(vi, i_0 * 16 + i_1)
            tir.bind(vj, j_0 * 16 + j_1)
            B[vi, vj] = A[vi, vj] * 2.0
    # No loops or tir.bind calls are written for the consumer block.
    with tir.block([128, 128], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
Exemplo n.º 10
0
def rowsum_not_serial(a: ty.handle, b: ty.handle) -> None:
    # Row sum B[vi] = sum over vk of A[vi, vk] in which the loop bound to the
    # reduce axis is declared with tir.parallel instead of tir.serial.
    # NOTE(review): by its name this is a negative fixture for a check that
    # reduction loops are serial -- confirm against the test using it.
    #
    # a: handle matched to buffer A, shape (128, 128).
    # b: handle matched to buffer B, shape (128,).
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))

    for i in tir.serial(0, 128):
        # The parallel loop k is the one bound to the reduce axis vk below.
        for k in tir.parallel(0, 128):
            with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
                tir.bind(vi, i)
                tir.bind(vk, k)
                with tir.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
Exemplo n.º 11
0
def cuda_matmul_1(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # pylint: disable=undefined-loop-variable
    # 2048x2048 reduction C[vi, vj] = sum over vk of A[vk, vi] * B[vk, vj],
    # staged through "shared"- and "local"-scope copies of the inputs and a
    # "local"-scope accumulator, with the compute loops bound to block,
    # virtual-thread and thread indices.
    #
    # a, b: handles matched to input buffers A and B, [2048, 2048] float32.
    # c:    handle matched to output buffer C, [2048, 2048] float32.
    A = tir.match_buffer(a, [2048, 2048], "float32")
    B = tir.match_buffer(b, [2048, 2048], "float32")
    C = tir.match_buffer(c, [2048, 2048], "float32")
    A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    # Four whole-buffer copy stages (input -> shared -> local); no loops or
    # tir.bind calls are written for these blocks.
    with tir.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with tir.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with tir.block([2048, 2048], "A_shared_local") as [v0, v1]:
        A_shared_local[v0, v1] = A_shared[v0, v1]
    with tir.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    # Output tiling: 32 blocks x 2 vthreads x 8 threads x 4 elements = 2048
    # along each of the two output dimensions.
    for by in tir.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in tir.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in tir.thread_binding(0, 2, thread="vthread.y"):
                for vx in tir.thread_binding(0, 2, thread="vthread.x"):
                    for ty in tir.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in tir.thread_binding(0,
                                                     8,
                                                     thread="threadIdx.x"):
                            # Reduction split as 256 serial steps x 8 unrolled
                            # steps = 2048.
                            for k_0 in tir.serial(0, 256):
                                for k_1 in tir.unroll(0, 8):
                                    # `_` is an extent-1 loop not used in any
                                    # binding.
                                    for _, i, j in tir.grid(1, 4, 4):
                                        with tir.block([
                                                2048, 2048,
                                                tir.reduce_axis(0, 2048)
                                        ], "C") as [vi, vj, vk]:
                                            tir.bind(
                                                vi,
                                                by * 64 + vy * 32 + ty * 4 + i)
                                            tir.bind(
                                                vj,
                                                bx * 64 + vx * 32 + tx * 4 + j)
                                            tir.bind(vk, k_0 * 8 + k_1)
                                            with tir.init():
                                                C_local[vi, vj] = 0.0
                                            # Both operands are indexed with
                                            # the reduction variable first.
                                            C_local[vi, vj] = C_local[
                                                vi, vj] + A_shared_local[
                                                    vk, vi] * B_shared_local[
                                                        vk, vj]
                            # Write-back of this thread's 4x4 accumulator tile,
                            # using the same vi/vj expressions as the compute
                            # block.
                            for i, j in tir.grid(4, 4):
                                with tir.block([2048, 2048],
                                               "C_local") as [vi, vj]:
                                    tir.bind(vi,
                                             by * 64 + vy * 32 + ty * 4 + i)
                                    tir.bind(vj,
                                             bx * 64 + vx * 32 + tx * 4 + j)
                                    C[vi, vj] = C_local[vi, vj]
def concatenate_multi_producer_uncovered(a: ty.handle, b: ty.handle) -> None:
    # Two producer blocks write disjoint parts of A and one consumer block
    # reads all 128 elements, but index 63 is written by neither producer.
    # NOTE(review): the gap appears deliberate given the "_uncovered" name --
    # confirm against the test using it.
    #
    # a: handle matched to buffer A, shape (128,).
    # b: handle matched to buffer B, shape (128,).
    A = tir.match_buffer(a, (128, ))
    B = tir.match_buffer(b, (128, ))
    # First producer: extent 63 for both loop and block (no tir.bind written).
    for i in range(0, 63):
        with tir.block([63], "A_0") as vi:
            A[vi] = vi + 1
    # Second producer: indices 64..127 through the explicit binding i + 64.
    for i in range(0, 64):
        with tir.block([64], "A_1") as vi:
            tir.bind(vi, i + 64)
            A[vi] = vi + 2
    # Consumer over the full extent 128; no loop or tir.bind is written.
    with tir.block([128], "B") as vi:
        B[vi] = A[vi] * 2.0
Exemplo n.º 13
0
def two_elementwise_after_compute_at(a: ty.handle, c: ty.handle) -> None:
    # Two element-wise stages, B = A * 2.0 then C = B + 1.0, with both stages
    # nested under one shared outer loop i so that one row of B is produced
    # right before the second stage runs.
    # NOTE(review): the name suggests this is the expected result of a
    # compute_at schedule step -- confirm against the test using it.
    #
    # a: handle matched to buffer A, shape (128, 128), float32.
    # c: handle matched to buffer C, shape (128, 128), float32.
    A = tir.match_buffer(a, (128, 128), "float32")
    B = tir.alloc_buffer((128, 128), "float32")
    C = tir.match_buffer(c, (128, 128), "float32")
    for i in range(0, 128):
        # ax0 has extent 1, so i + ax0 always equals i.
        for ax0, ax1 in tir.grid(1, 128):
            with tir.block([128, 128], "B") as [vi, vj]:
                tir.bind(vi, i + ax0)
                tir.bind(vj, ax1)
                B[vi, vj] = A[vi, vj] * 2.0
        for j in range(0, 128):
            # NOTE(review): this block writes C but is also named "B", and no
            # tir.bind calls are written for it -- confirm this is intentional
            # before relying on block names.
            with tir.block([128, 128], "B") as [vi, vj]:
                C[vi, vj] = B[vi, vj] + 1.0
Exemplo n.º 14
0
def element_wise_invalid_annotation(a: ty.handle, c: ty.handle) -> None:
    # Two element-wise stages, B = A * 2 then C = B + 1, sharing the outer
    # loop i0, where the producer block carries a "buffer_dim_align" block
    # attribute.
    # NOTE(review): the name suggests the attribute value [0] is deliberately
    # malformed for a negative test -- confirm against the test using it.
    #
    # a: handle matched to buffer A, shape [128, 128].
    # c: handle matched to buffer C, shape [128, 128].
    C = tir.match_buffer(c, [128, 128],
                         elem_offset=0,
                         align=128,
                         offset_factor=1)
    A = tir.match_buffer(a, [128, 128],
                         elem_offset=0,
                         align=128,
                         offset_factor=1)
    # body
    with tir.block([], "root"):
        tir.reads([])
        tir.writes([])
        # Intermediate buffer allocated inside the root block.
        B = tir.alloc_buffer([128, 128],
                             elem_offset=0,
                             align=128,
                             offset_factor=1)
        for i0 in tir.serial(0, 128):
            for ax1 in tir.serial(0, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    # The annotation under test is attached to the producer.
                    tir.block_attr({"buffer_dim_align": [0]})
                    tir.bind(vi, i0)
                    tir.bind(vj, ax1)
                    tir.reads([A[vi, vj]])
                    tir.writes([B[vi, vj]])
                    B[vi, vj] = (A[vi, vj] * tir.float32(2))
            for i1 in tir.serial(0, 128):
                with tir.block([128, 128], "C") as [vi_1, vj_1]:
                    tir.bind(vi_1, i0)
                    tir.bind(vj_1, i1)
                    tir.reads([B[vi_1, vj_1]])
                    tir.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1))
Exemplo n.º 15
0
def blockized_1(a: ty.handle, c: ty.handle) -> None:
    # Two element-wise stages, B = A * 2.0 then C = B + 1.0, where the second
    # stage is wrapped in an outer block over an 8x8 grid of 16x16 tiles.
    # NOTE(review): the name suggests this is the expected result of a
    # blockize schedule step -- confirm against the test using it.
    #
    # a: handle matched to buffer A, shape [128, 128], float32.
    # c: handle matched to buffer C, shape [128, 128], float32.
    A = tir.match_buffer(a, [128, 128], "float32")
    B = tir.alloc_buffer([128, 128], "float32")
    C = tir.match_buffer(c, [128, 128], "float32")
    # No loops or tir.bind calls are written for the producer block.
    with tir.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    # Outer block: one instance per tile, with the tile's regions declared
    # explicitly. No loops or tir.bind calls are written for vi_o / vj_o.
    with tir.block([8, 8], "C_outer") as [vi_o, vj_o]:
        tir.reads([B[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16, ]])
        tir.writes([C[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16]])
        for i_i, j_i in tir.grid(16, 16):
            with tir.block([128, 128], "C_inner") as [vi, vj]:
                # Inner bindings combine the outer block variables with the
                # in-tile loop variables.
                tir.bind(vi, vi_o * 16 + i_i)
                tir.bind(vj, vj_o * 16 + j_i)
                C[vi, vj] = B[vi, vj] + 1.0
def func_multi_consumer() -> None:
    # One producer block "A" with two consumer blocks: "B" shares the outer
    # loop with the producer, while "C" sits in a separate top-level loop.
    # Note: `(128)` is the plain integer 128, not a one-element tuple.
    A = tir.alloc_buffer((128))
    B = tir.alloc_buffer((128))
    C = tir.alloc_buffer((128))
    for i in tir.grid(8):
        # Producer: fills the 16-element chunk i of A with 1.0.
        for j in tir.grid(16):
            with tir.block([128], "A") as [vi]:
                tir.bind(vi, i * 16 + j)
                A[vi] = 1.0
        # First consumer: reads the same chunk under the same outer loop.
        for j in tir.grid(16):
            with tir.block([128], "B") as [vi]:
                tir.bind(vi, i * 16 + j)
                B[vi] = A[vi] + 1.0
    # Second consumer: reads all of A; no tir.bind is written for vi here.
    for i in tir.grid(128):
        with tir.block([128], "C") as [vi]:
            C[vi] = A[vi]
Exemplo n.º 17
0
def read_out_of_bound_after_compute_at(a: ty.handle, c: ty.handle) -> None:
    # Computes C[v] = max(B[v], B[v + 1]) for v < 15 and C[15] = B[15], where
    # B is a copy of A produced under the consumer's loop j.
    # NOTE(review): the name suggests this is the expected result of a
    # compute_at schedule step -- confirm against the test using it.
    #
    # a: handle matched to buffer A, shape [16], float32.
    # c: handle matched to buffer C, shape [16], float32.
    A = tir.match_buffer(a, [16], "float32")
    B = tir.alloc_buffer([16], "float32")
    C = tir.match_buffer(c, [16], "float32")
    for j in tir.serial(0, 16):
        # Extent is 2 for j < 15 and 1 for j == 15, so v = j + i stays <= 15.
        for i in tir.serial(0, tir.min(1, 15 - j) + 1):
            with tir.block([16], "B") as [v]:
                tir.bind(v, j + i)
                B[v] = A[v]
        with tir.block([16], "C") as [v]:
            tir.bind(v, j)
            # The declared read region is B[v:v + 2], which for v == 15
            # extends one element past B's extent of 16.
            tir.reads([B[v:v + 2]])
            # The v < 15 condition selects the branch without B[v + 1] for the
            # last element.
            C[v] = tir.if_then_else(v < 15,
                                    tir.max(B[v], B[v + 1]),
                                    B[v],
                                    dtype="float32")
def multi_producer_consumer(a: ty.handle, b: ty.handle) -> None:
    # Two producer blocks write the lower and upper halves of A, and two
    # consumer blocks read those halves to write the matching halves of B.
    #
    # a: handle matched to buffer A, shape (128,).
    # b: handle matched to buffer B, shape (128,).
    A = tir.match_buffer(a, (128, ))
    B = tir.match_buffer(b, (128, ))
    # Lower-half producer; no tir.bind is written for vi.
    for i in range(0, 64):
        with tir.block([64], "A_0") as vi:
            A[vi] = vi + 1
    # Upper-half producer: indices 64..127 through the binding i + 64.
    for i in range(0, 64):
        with tir.block([64], "A_1") as vi:
            tir.bind(vi, i + 64)
            A[vi] = vi + 2
    # Lower-half consumer; no tir.bind is written for vi.
    for i in range(0, 64):
        with tir.block([64], "B_0") as vi:
            B[vi] = A[vi] + 2.0
    # Upper-half consumer: indices 64..127 through the binding i + 64.
    for i in range(0, 64):
        with tir.block([64], "B_1") as vi:
            tir.bind(vi, i + 64)
            B[vi] = A[vi] + 3.0
def elementwise_reordered2(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 over a 4-D buffer, with the loops written in
    # the order k, j, i, l while each block variable is still bound to the
    # loop variable of the same letter.
    #
    # a: handle matched to buffer A, shape (128, 128, 128, 128).
    # b: handle matched to buffer B, shape (128, 128, 128, 128).
    A = tir.match_buffer(a, (128, 128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128, 128))
    for k, j, i, l in tir.grid(128, 128, 128, 128):
        with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            tir.bind(vk, k)
            tir.bind(vl, l)
            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
def elementwise_not_affine(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 over a 4-D buffer in which the last block
    # variable is bound to l * 16 with l in 0..7, so it only takes the values
    # 0, 16, ..., 112 out of its declared extent of 128.
    # NOTE(review): by its name this is a negative fixture for a binding
    # check -- confirm against the test using it.
    #
    # a: handle matched to buffer A, shape (128, 128, 128, 128).
    # b: handle matched to buffer B, shape (128, 128, 128, 128).
    A = tir.match_buffer(a, (128, 128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128, 128))
    for i, j, k, l in tir.grid(128, 128, 128, 8):
        with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            tir.bind(vk, k)
            # Strided binding: the loop extent is 8, the block extent is 128.
            tir.bind(vl, l * 16)
            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
def elementwise_reordered_with_predicate(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 over a 4-D buffer with the loops written in
    # the order l, j, k, i and a block predicate on the flattened index.
    #
    # a: handle matched to buffer A, shape (128, 128, 128, 128).
    # b: handle matched to buffer B, shape (128, 128, 128, 128).
    A = tir.match_buffer(a, (128, 128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128, 128))
    for l, j, k, i in tir.grid(128, 128, 128, 128):
        with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
            # 2097152 = 128**3 and 16384 = 128**2, so the left-hand side is
            # the row-major flat index of (i, j, k, l); the predicate keeps
            # flat indices below 100.
            tir.where(i * 2097152 + j * 16384 + k * 128 + l < 100)
            tir.bind(vi, i)
            tir.bind(vj, j)
            tir.bind(vk, k)
            tir.bind(vl, l)
            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
Exemplo n.º 22
0
def factorized_after_reverse_compute_at(a: ty.handle, b: ty.handle) -> None:
    # Two-stage reduction of a [16, 16, 16] input down to a [16] output: block
    # "B_rf" reduces the last axis into a "local"-scope intermediate, and
    # block "B" then reduces the intermediate's first axis. Both stages share
    # the thread-bound loops j and i_o.
    # NOTE(review): the name suggests this is the expected result of a
    # reverse_compute_at schedule step -- confirm against the test using it.
    #
    # a: handle matched to buffer A, shape [16, 16, 16], float32.
    # b: handle matched to buffer B, shape [16], float32.
    A = tir.match_buffer(a, [16, 16, 16], "float32")
    B = tir.match_buffer(b, [16], "float32")
    B_rf_local = tir.alloc_buffer([16, 16], "float32", scope="local")
    for j in tir.thread_binding(0, 16, thread="blockIdx.x"):
        for i_o in tir.thread_binding(0, 4, thread="threadIdx.x"):
            for i_i, k in tir.grid(4, 16):
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "B_rf") as [vi, vj, vk]:
                    tir.bind(vi, i_o * 4 + i_i)
                    tir.bind(vj, j)
                    tir.bind(vk, k)
                    with tir.init():
                        B_rf_local[vi, vj] = 0.0
                    # A is indexed as [vj, vi, vk]: the first two subscripts
                    # are swapped relative to the intermediate's [vi, vj].
                    B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk]
            # Second stage: for this i_o, reduce the four intermediate rows
            # i_o * 4 .. i_o * 4 + 3 into B[j].
            for k in tir.serial(0, 4):
                with tir.block([16, tir.reduce_axis(0, 16)], "B") as [vi, vk]:
                    tir.bind(vi, j)
                    tir.bind(vk, i_o * 4 + k)
                    with tir.init():
                        B[vi] = 0.0
                    B[vi] = B[vi] + B_rf_local[vk, vi]
Exemplo n.º 23
0
def square_sum_rfactor(a: ty.handle, c: ty.handle) -> None:
    # Sum of squares C[b] = sum over (i, vi2) of A[b, i, vi2] ** 2, computed
    # in two stages through the intermediate C_rf[b, vi2].
    # NOTE(review): the name suggests this is the expected result of an
    # rfactor schedule step -- confirm against the test using it.
    #
    # a: handle matched to buffer A, shape [16, 256, 256].
    # c: handle matched to buffer C, shape [16].
    A = tir.match_buffer(a, [16, 256, 256])
    C = tir.match_buffer(c, [16])
    C_rf = tir.alloc_buffer([16, 256])

    # Stage 1: reduce over A's middle axis only, keeping the last axis vi2 as
    # a non-reduce block variable.
    for i0, i1, i2 in tir.grid(16, 256, 256):
        # Block variable order is (vi2, b, i), which differs from the loop
        # order (i0, i1, i2); the bindings below pair them up.
        with tir.block([256, 16, tir.reduce_axis(0, 256)],
                       "C_rf") as [vi2, b, i]:
            tir.bind(vi2, i2)
            tir.bind(b, i0)
            tir.bind(i, i1)
            with tir.init():
                C_rf[b, vi2] = 0.0
            C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2])

    # Stage 2: reduce the intermediate over vi2 to produce C. The reduce axis
    # is listed first among the block variables.
    for i0_1, i2_1 in tir.grid(16, 256):
        with tir.block([tir.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]:
            tir.bind(vi2_1, i2_1)
            tir.bind(b_1, i0_1)
            with tir.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[b_1, vi2_1]
Exemplo n.º 24
0
def elementwise_under_loop(a: ty.handle, c: ty.handle) -> None:
    # Two element-wise stages, B = A * 2.0 then C = B + 1.0, with both inner
    # loops nested under the same outer loop i, so row i of B is produced and
    # then consumed within one outer iteration.
    #
    # a: handle matched to buffer A, shape (128, 128).
    # c: handle matched to buffer C, shape (128, 128).
    A = tir.match_buffer(a, (128, 128))
    C = tir.match_buffer(c, (128, 128))
    B = tir.alloc_buffer((128, 128))
    for i in tir.serial(0, 128):
        # Producer of row i.
        for j in tir.serial(0, 128):
            with tir.block([128, 128], "B") as [vi, vj]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                B[vi, vj] = A[vi, vj] * 2.0
        # Consumer of row i.
        for j in tir.serial(0, 128):
            with tir.block([128, 128], "C") as [vi, vj]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                C[vi, vj] = B[vi, vj] + 1.0
Exemplo n.º 25
0
def tiled_after_reverse_compute_at(a: ty.handle, c: ty.handle) -> None:
    # Two element-wise stages, B = A * 2.0 then C = B + 1.0, over an 8x8 grid
    # of 16x16 tiles, with both stages sharing the three outer loops and each
    # having its own innermost j_1 loop.
    # NOTE(review): the name suggests this is the expected result of a
    # reverse_compute_at schedule step -- confirm against the test using it.
    #
    # a: handle matched to buffer A, shape [128, 128], float32.
    # c: handle matched to buffer C, shape [128, 128], float32.
    A = tir.match_buffer(a, [128, 128], "float32")
    B = tir.alloc_buffer([128, 128], "float32")
    C = tir.match_buffer(c, [128, 128], "float32")
    for i_0, j_0, i_1 in tir.grid(8, 8, 16):
        # Producer: 16 consecutive elements of one row of the current tile.
        for j_1 in tir.serial(0, 16):
            with tir.block([128, 128], "B") as [vi, vj]:
                tir.bind(vi, i_0 * 16 + i_1)
                tir.bind(vj, j_0 * 16 + j_1)
                B[vi, vj] = A[vi, vj] * 2.0
        # Consumer: the same 16 elements, with identical binding expressions.
        for j_1 in tir.serial(0, 16):
            with tir.block([128, 128], "C") as [vi, vj]:
                tir.bind(vi, i_0 * 16 + i_1)
                tir.bind(vj, j_0 * 16 + j_1)
                C[vi, vj] = B[vi, vj] + 1.0
def equal_ranked_threads(a: ty.handle, c: ty.handle) -> None:
    # Producer/consumer pair under loops bound to threadIdx.x and
    # threadIdx.y, communicating through a "shared"-scope buffer; the consumer
    # accesses B with its two indices swapped relative to the producer.
    #
    # a: handle matched to buffer A, shape [128, 128].
    # c: handle matched to buffer C, shape [128, 128].
    A = tir.match_buffer(a, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    B = tir.alloc_buffer([128, 128], scope="shared")
    # 16 * 8 = 128 values of i_o * 8 + i_i cover the first block variable.
    for i_o in tir.thread_binding(0, 16, thread="threadIdx.x"):
        for i_i in tir.thread_binding(0, 8, thread="threadIdx.y"):
            for j in tir.serial(0, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    tir.bind(vi, i_o * 8 + i_i)
                    tir.bind(vj, j)
                    B[vi, vj] = A[vi, vj] * 2.0
            for j in tir.serial(0, 128):
                with tir.block([128, 128], "C") as [vi, vj]:
                    tir.bind(vi, i_o * 8 + i_i)
                    tir.bind(vj, j)
                    # Transposed access: [vj, vi], whereas the producer wrote
                    # B[vi, vj] under the same bindings.
                    C[vj, vi] = B[vj, vi] + 1.0
Exemplo n.º 27
0
def fail_subtree_compact_dataflow(a: ty.handle, c: ty.handle) -> None:
    # Two producer blocks under one outer loop each write half of the columns
    # of B (B = A * 2.0), and a consumer block outside that loop computes
    # C = B + 1.0.
    # NOTE(review): by its name this is a negative fixture for a
    # compact-dataflow check on a subtree -- confirm against the test.
    #
    # a: handle matched to buffer A, shape (128, 128), float32.
    # c: handle matched to buffer C, shape (128, 128), float32.
    A = tir.match_buffer(a, (128, 128), "float32")
    B = tir.alloc_buffer((128, 128), "float32")
    C = tir.match_buffer(c, (128, 128), "float32")
    for i in range(0, 128):
        # Columns 0..63 of row i.
        for j in range(0, 64):
            with tir.block([128, 128], "B_0") as [vi, vj]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                B[vi, vj] = A[vi, vj] * 2.0
        # Columns 64..127 of row i, through the binding j + 64.
        for j in range(0, 64):
            with tir.block([128, 128], "B_1") as [vi, vj]:
                tir.bind(vi, i)
                tir.bind(vj, j + 64)
                B[vi, vj] = A[vi, vj] * 2.0
    # No loops or tir.bind calls are written for the consumer block.
    with tir.block([128, 128], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
def elementwise_func(a: ty.handle, c: ty.handle) -> None:
    # Computes C = (A + 1.0) * 2.0 row by row: each row is handled inside an
    # anonymous block that allocates its own intermediate buffer B.
    #
    # a: handle matched to buffer A, shape (16, 16), float32.
    # c: handle matched to buffer C, shape (16, 16), float32.
    A = tir.match_buffer(a, (16, 16), "float32")
    C = tir.match_buffer(c, (16, 16), "float32")
    for i in range(0, 16):
        # Block with no block variables; its regions are expressed directly in
        # terms of the loop variable i.
        with tir.block([]):
            tir.reads(A[i, 0:16])
            tir.writes(C[i, 0:16])
            # B is allocated at full (16, 16) shape inside the block, although
            # only row i is touched within it.
            B = tir.alloc_buffer((16, 16), "float32")
            for j in range(0, 16):
                with tir.block([16, 16]) as [vi, vj]:
                    tir.bind(vi, i)
                    tir.bind(vj, j)
                    B[vi, vj] = A[vi, vj] + 1.0
            for j in range(0, 16):
                with tir.block([16, 16]) as [vi, vj]:
                    tir.bind(vi, i)
                    tir.bind(vj, j)
                    C[vi, vj] = B[vi, vj] * 2.0
Exemplo n.º 29
0
def elementwise_fused(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 over a 3-D buffer driven by a single loop of
    # extent 2097152 (= 128**3), from which the three block variables are
    # recovered with floordiv / floormod.
    #
    # a: handle matched to buffer A, shape (128, 128, 128).
    # b: handle matched to buffer B, shape (128, 128, 128).
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for fused in tir.serial(0, 2097152):
        with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
            # Row-major decomposition of the flat index: 16384 = 128**2.
            tir.bind(vi, tir.floordiv(fused, 16384))
            tir.bind(vj, tir.floormod(tir.floordiv(fused, 128), 128))
            tir.bind(vk, tir.floormod(fused, 128))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
Exemplo n.º 30
0
def elementwise_split_case0(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 over a 3-D buffer in which each of the three
    # dimensions is iterated through several nested loops: extents (2, 1, 64)
    # for the first, (4, 32) for the second and (16, 8) for the third.
    #
    # a: handle matched to buffer A, shape [128, 128, 128].
    # b: handle matched to buffer B, shape [128, 128, 128].
    A = tir.match_buffer(a, [128, 128, 128])
    B = tir.match_buffer(b, [128, 128, 128])
    for i1, i2, i3, j1, j2, k1, k2 in tir.grid(2, 1, 64, 4, 32, 16, 8):
        with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
            # i2 has extent 1 and does not appear in the binding for vi.
            tir.bind(vi, ((i1 * 64) + i3))
            tir.bind(vj, ((j1 * 32) + j2))
            tir.bind(vk, ((k1 * 8) + k2))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0