def transformed_func() -> None:
    # Zero-fills a 128x128 buffer A, then runs a 32x32 block grid with a
    # reduction axis k of extent 32 that accumulates into a buffer B which is
    # allocated inside that block.
    # NOTE(review): the name suggests this is the expected output of a
    # transformation pass -- confirm against the test that uses it.
    A = tir.alloc_buffer([128, 128])
    with tir.block([128, 128], "") as [i, j]:
        A[i, j] = tir.float32(0)
    with tir.block([32, 32, tir.reduce_axis(0, 32)], "") as [i, j, k]:
        B = tir.alloc_buffer([128, 128])
        # On the first reduction step, seed the 4x4 tile of B from A.
        if k == 0:
            for ii, jj in tir.grid(4, 4):
                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
        for ii, jj in tir.grid(4, 4):
            with tir.block([], ""):
                tir.reads([B[((i * 4) + ii), ((j * 4) + jj)]])
                tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
                # C is allocated here and read below; nothing in this function
                # writes to it.
                C = tir.alloc_buffer([128, 128])
                # First pass: add four C values (one per kk) to the B element.
                for kk in tir.serial(0, 4):
                    B[((i * 4) + ii),
                      ((j * 4) + jj)] = (B[((i * 4) + ii),
                                           ((j * 4) + jj)] + C[((i * 4) + ii),
                                                               ((k * 4) + kk)])
                # Second pass: add D * C products, one nested block per kk.
                for kk in tir.serial(0, 4):
                    with tir.block([], ""):
                        tir.reads([
                            B[((i * 4) + ii), ((j * 4) + jj)],
                            C[((i * 4) + ii), ((k * 4) + kk)],
                        ])
                        tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
                        # D is allocated inside this block and read without
                        # being written in this function.
                        D = tir.alloc_buffer([128, 128])
                        B[((i * 4) + ii),
                          ((j * 4) +
                           jj)] = B[((i * 4) + ii),
                                    ((j * 4) + jj)] + (D[((j * 4) + jj), (
                                        (k * 4) + kk)] * C[((i * 4) + ii),
                                                           ((k * 4) + kk)])
Ejemplo n.º 2
0
def get_valid_counts(
    data: ty.handle,
    valid_count: ty.handle,
    out: ty.handle,
    out_indices: ty.handle,
    score_threshold: ty.float32,
    id_index: ty.int32,
    score_index: ty.int32,
) -> None:
    # Compacts the rows of `data` that pass a score/id test to the front of
    # `out`, records their source row in `out_indices`, and keeps the running
    # count in `valid_count`. Rows at or beyond the current count are filled
    # with -1.
    #
    # data:            matched as (1, 2500, 6) float32
    # valid_count:     matched as (1,) int32
    # out:             matched as (1, 2500, 6) float32
    # out_indices:     matched as (1, 2500) int32
    # score_threshold: a row is kept only if its score is strictly greater
    # id_index:        column holding the id; a negative value skips the id test
    # score_index:     column holding the score

    data_buf = tir.match_buffer(data, (1, 2500, 6), "float32")
    valid_count_buf = tir.match_buffer(valid_count, (1, ), "int32")
    out_buf = tir.match_buffer(out, (1, 2500, 6), "float32")
    out_indices_buf = tir.match_buffer(out_indices, (1, 2500), "int32")

    with tir.block([1], "init") as [vi]:
        valid_count_buf[vi] = tir.int32(0)
        with tir.block([2500], "update") as [vj]:
            # NOTE(review): the declared regions index the last axis at 6,
            # while the buffers are matched with extent 6 on that axis --
            # confirm this is intended.
            tir.reads([data_buf[vi, vj, 6]])
            tir.writes([
                valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj,
                                                                      6]
            ])
            # Keep the row when its score exceeds the threshold and either the
            # id test is disabled (id_index < 0) or the id value is >= 0.
            if (data_buf[vi, vj, score_index] > score_threshold) and (
                (id_index < 0) or
                (data_buf[vi, vj, id_index] >= tir.float32(0))):
                # Copy all 6 columns into the next free output slot.
                for k in tir.serial(0, 6):
                    out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k]
                out_indices_buf[vi, valid_count_buf[vi]] = vj
                valid_count_buf[vi] = valid_count_buf[vi] + 1
            # Row vj is not (yet) occupied by a kept row: mark it with -1.
            if vj >= valid_count_buf[vi]:
                for k in tir.serial(0, 6):
                    out_buf[vi, vj, k] = tir.float32(-1)
                out_indices_buf[vi, vj] = tir.int32(-1)
Ejemplo n.º 3
0
def element_wise_invalid_annotation(a: ty.handle, c: ty.handle) -> None:
    # Two-stage element-wise computation over 128x128 buffers:
    # B = A * 2 followed by C = B + 1, both nested under the shared loop i0.
    # Block "B" carries a "buffer_dim_align" block attribute.
    # NOTE(review): per the function name the annotation value is presumably
    # malformed on purpose -- confirm with the test that consumes it.
    C = tir.match_buffer(c, [128, 128],
                         elem_offset=0,
                         align=128,
                         offset_factor=1)
    A = tir.match_buffer(a, [128, 128],
                         elem_offset=0,
                         align=128,
                         offset_factor=1)
    # body
    with tir.block([], "root"):
        tir.reads([])
        tir.writes([])
        # Intermediate buffer, allocated inside the root block.
        B = tir.alloc_buffer([128, 128],
                             elem_offset=0,
                             align=128,
                             offset_factor=1)
        for i0 in tir.serial(0, 128):
            # Producer: fills row i0 of B.
            for ax1 in tir.serial(0, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    tir.block_attr({"buffer_dim_align": [0]})
                    tir.bind(vi, i0)
                    tir.bind(vj, ax1)
                    tir.reads([A[vi, vj]])
                    tir.writes([B[vi, vj]])
                    B[vi, vj] = (A[vi, vj] * tir.float32(2))
            # Consumer: reads row i0 of B and writes row i0 of C.
            for i1 in tir.serial(0, 128):
                with tir.block([128, 128], "C") as [vi_1, vj_1]:
                    tir.bind(vi_1, i0)
                    tir.bind(vj_1, i1)
                    tir.reads([B[vi_1, vj_1]])
                    tir.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1))
Ejemplo n.º 4
0
def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None:
    # Processes two 1024-element buffers in 8 chunks of 128 elements. For each
    # chunk an extern call named "test" is issued with A_cache/A data pointers,
    # then A_cache is copied element by element into B.
    # NOTE(review): the name suggests this is the expected output of a
    # transformation pass -- confirm against the test that uses it.
    A = tir.match_buffer(a, [1024])
    B = tir.match_buffer(b, [1024])
    for i in tir.serial(0, 8):
        with tir.block([8]) as [vi]:
            # Declared regions cover the 128-element chunk selected by vi.
            tir.reads(A[vi * 128:vi * 128 + 128])
            tir.writes(B[vi * 128:vi * 128 + 128])
            # Cache buffer is allocated inside the outer block.
            A_cache = tir.alloc_buffer([1024])
            with tir.block([8]) as [v]:
                tir.bind(v, vi)
                tir.reads([A[v * 128:v * 128 + 128]])
                tir.writes([A_cache[v * 128:v * 128 + 128]])
                # Opaque access: buffers are passed as raw data pointers plus
                # (offset, length) pairs, so the regions above are declared
                # explicitly.
                tir.evaluate(
                    tir.call_extern("test",
                                    A_cache.data,
                                    v * 128,
                                    128,
                                    A.data,
                                    v * 128,
                                    128,
                                    dtype="float32"))
            # Copy the cached chunk into B; v is the flat element index.
            for j in tir.serial(0, 128):
                with tir.block([1024]) as [v]:
                    tir.bind(v, ((vi * 128) + j))
                    tir.reads([A_cache[v]])
                    tir.writes([B[v]])
                    B[v] = A_cache[v]
Ejemplo n.º 5
0
    def range_missing_args(a: ty.handle) -> None:
        # Realizes the 16x16 float32 buffer A and writes 0.0 to every element.
        # The outer loop calls tir.serial with a single argument, unlike the
        # inner loop which passes two.
        # NOTE(review): per the function name, the single-argument call is
        # presumably the deliberate error this fixture exercises -- do not
        # "fix" it without checking the test that uses it.
        A = tir.match_buffer(a, (16, 16), "float32")

        tir.attr(A, "realize_scope", "")
        tir.realize(A[0:16, 0:16])
        for i in tir.serial(16):
            for j in tir.serial(0, 16):
                A[i, j] = 0.0
Ejemplo n.º 6
0
    def undefined_buffer(a: ty.handle) -> None:
        # Matches the 16x16 float32 buffer A, then realizes a region of `C`,
        # a name that is never defined inside this function, and finally
        # writes 0.0 to every element of A.
        # NOTE(review): per the function name, the reference to the undefined
        # buffer C is presumably the deliberate error this fixture exercises
        # -- confirm with the test that uses it.
        A = tir.match_buffer(a, (16, 16), "float32")

        tir.attr(A, "realize_scope", "")
        tir.realize(C[0:16, 0:16])
        for i in tir.serial(16):
            for j in tir.serial(0, 16):
                A[i, j] = 0.0
Ejemplo n.º 7
0
def multiple_reduction_blocks_rfactor(a: ty.handle, f: ty.handle) -> None:
    # Chain of reductions over the last axis of the 16x16x16 buffer A,
    # producing C, D, E and finally F (each 16x16). The reduction for C is
    # split in two: partial sums go to C_rf (16x16x4), which a second block
    # then sums into C.
    # NOTE(review): the name suggests this is the expected result of applying
    # rfactor to a sibling fixture -- confirm against the test that uses it.
    A = tir.match_buffer(a, [16, 16, 16])
    C = tir.alloc_buffer([16, 16])
    D = tir.alloc_buffer([16, 16])
    E = tir.alloc_buffer([16, 16])
    F = tir.match_buffer(f, [16, 16])
    C_rf = tir.alloc_buffer([16, 16, 4])

    # Partial reduction: for each vk1o, sum A[ci, cj, vk1o * 4 + vk1i] over
    # the reduce axis vk1i (extent 4).
    for i, j1, k1o, k1i in tir.grid(16, 16, 4, 4):
        with tir.block([4, 16, 16, tir.reduce_axis(0, 4)],
                       "C_rf") as [vk1o, ci, cj, vk1i]:
            tir.bind(vk1o, k1o)
            tir.bind(ci, i)
            tir.bind(cj, j1)
            tir.bind(vk1i, k1i)
            with tir.init():
                C_rf[ci, cj, vk1o] = 0.0
            C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj,
                                                        ((vk1o * 4) + vk1i)]
    for i_1 in tir.serial(0, 16):
        for j1_1 in tir.serial(0, 16):
            # Final reduction for C: sum the 4 partial results in C_rf.
            for k1o_1 in tir.serial(0, 4):
                with tir.block([tir.reduce_axis(0, 4), 16, 16],
                               "C") as [vk1o_1, ci_1, cj_1]:
                    tir.bind(vk1o_1, k1o_1)
                    tir.bind(ci_1, i_1)
                    tir.bind(cj_1, j1_1)
                    with tir.init():
                        C[ci_1, cj_1] = 0.0
                    C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1]
            # D accumulates A[di, dj, dk] + C[di, dj] on every reduction step.
            for k2o, k2i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "D") as [di, dj, dk]:
                    tir.bind(di, i_1)
                    tir.bind(dj, j1_1)
                    tir.bind(dk, (k2o * 4) + k2i)
                    with tir.init():
                        D[di, dj] = 0.0
                    D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj]
        # Second column loop, still under i_1: E consumes D, then F consumes E.
        for j2 in tir.serial(0, 16):
            for k3o, k3i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "E") as [ei, ej, ek]:
                    tir.bind(ei, i_1)
                    tir.bind(ej, j2)
                    tir.bind(ek, (k3o * 4) + k3i)
                    with tir.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej]
            for k4o, k4i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "F") as [fi, fj, fk]:
                    tir.bind(fi, i_1)
                    tir.bind(fj, j2)
                    tir.bind(fk, (k4o * 4) + k4i)
                    with tir.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]
Ejemplo n.º 8
0
def flattened_elementwise_func(a: ty.handle, c: ty.handle) -> None:
    # Element-wise C = (A + 1.0) * 2.0 over 16x16 float32 buffers, written
    # with flat (one-dimensional) indexing i * 16 + j on the data pointers.
    # A 16-element float32 scratch allocation B_new holds one row at a time.
    A = tir.match_buffer(a, (16, 16), "float32")
    C = tir.match_buffer(c, (16, 16), "float32")
    for i in tir.serial(0, 16):
        # Scratch storage is allocated once per outer iteration.
        B_new = tir.allocate([16], "float32", "global")
        # Stage 1: B_new[j] = A[i, j] + 1.0
        for j in tir.serial(0, 16):
            B_new[j] = tir.load("float32", A.data, ((i * 16) + j)) + 1.0
        # Stage 2: C[i, j] = B_new[j] * 2.0
        for j in tir.serial(0, 16):
            C.data[((i * 16) + j)] = tir.load("float32", B_new, j) * 2.0
def bound_to_thread(a: ty.handle, c: ty.handle) -> None:
    # Two-stage computation over 128x128 buffers with the outer loop bound to
    # threadIdx.x: block "B" writes B = A * 2.0 into a buffer allocated with
    # scope="shared"; block "C" then writes C = B + 1.0.
    A = tir.match_buffer(a, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    B = tir.alloc_buffer([128, 128], scope="shared")
    for i in tir.thread_binding(0, 128, thread="threadIdx.x"):
        for j in tir.serial(0, 128):
            with tir.block([128, 128], "B") as [vi, vj]:
                B[vi, vj] = A[vi, vj] * 2.0
        for j in tir.serial(0, 128):
            with tir.block([128, 128], "C") as [vi, vj]:
                # Note the swapped index order [vj, vi] compared with the
                # [vi, vj] order used by the producer block above.
                C[vj, vi] = B[vj, vi] + 1.0
Ejemplo n.º 10
0
def elementwise_with_seq(a: ty.handle, b: ty.handle) -> None:
    # Two element-wise stages over 128x128x128 buffers: C = A * 2.0 followed
    # by B = C * 2.0. Both k loops sit sequentially under the same (i, j)
    # grid, so the body of the j loop is a sequence of two loops.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    C = tir.alloc_buffer((128, 128, 128))
    for i, j in tir.grid(128, 128):
        for k in tir.serial(0, 128):
            with tir.block([128, 128, 128], "C") as [vi, vj, vk]:
                C[vi, vj, vk] = A[vi, vj, vk] * 2.0
        for k in tir.serial(0, 128):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                B[vi, vj, vk] = C[vi, vj, vk] * 2.0
def warp_memory(a: ty.handle, c: ty.handle) -> None:
    # Two-stage computation C = A * 2.0 + 1.0 over 128x128 buffers, staged
    # through a [128, 4, 32] buffer B allocated with scope="warp".
    # The loops are bound to threadIdx.y (extent 4) and threadIdx.x
    # (extent 32); the row index is rebuilt as warp_id * 32 + lane_id.
    A = tir.match_buffer(a, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    B = tir.alloc_buffer([128, 4, 32], scope="warp")
    for i_o in tir.thread_binding(0, 4, thread="threadIdx.y"):
        for i_i in tir.thread_binding(0, 32, thread="threadIdx.x"):
            # Producer: B is indexed as [column, warp_id, lane_id].
            for j in tir.serial(0, 128):
                with tir.block([4, 32, 128], "B") as [warp_id, lane_id, vj]:
                    B[vj, warp_id, lane_id] = A[warp_id * 32 + lane_id, vj] * 2.0
            # Consumer: reads back the same B element and writes C.
            for j in tir.serial(0, 128):
                with tir.block([4, 32, 128], "C") as [warp_id, lane_id, vj]:
                    C[warp_id * 32 + lane_id, vj] = B[vj, warp_id, lane_id] + 1.0
Ejemplo n.º 12
0
def read_out_of_bound(a: ty.handle, c: ty.handle) -> None:
    # Copies the 16-element buffer A into B, then computes
    # C[v] = max(B[v], B[v + 1]) for v < 15 and C[v] = B[v] otherwise.
    A = tir.match_buffer(a, [16], "float32")
    B = tir.alloc_buffer([16], "float32")
    C = tir.match_buffer(c, [16], "float32")
    for i in tir.serial(0, 16):
        with tir.block([16], "B") as [v]:
            B[v] = A[v]
    for j in tir.serial(0, 16):
        with tir.block([16], "C") as [v]:
            # The declared read region spans two elements; at v == 15 it
            # reaches index 16, one past B's extent of 16. The actual access
            # to B[v + 1] is guarded by the v < 15 condition below.
            tir.reads(B[v:v + 2])
            C[v] = tir.if_then_else(v < 15,
                                    tir.max(B[v], B[v + 1]),
                                    B[v],
                                    dtype="float32")
Ejemplo n.º 13
0
def tiled_after_reverse_compute_at(a: ty.handle, c: ty.handle) -> None:
    # Tiled two-stage computation over 128x128 float32 buffers:
    # B = A * 2.0 then C = B + 1.0. Both j_1 loops share the outer
    # (i_0, j_0, i_1) loop nest of extents (8, 8, 16); indices are rebuilt as
    # i_0 * 16 + i_1 and j_0 * 16 + j_1.
    # NOTE(review): the name suggests this is the expected result of a
    # reverse_compute_at schedule step -- confirm against the test using it.
    A = tir.match_buffer(a, [128, 128], "float32")
    B = tir.alloc_buffer([128, 128], "float32")
    C = tir.match_buffer(c, [128, 128], "float32")
    for i_0, j_0, i_1 in tir.grid(8, 8, 16):
        # Producer for one 16-element row segment of B.
        for j_1 in tir.serial(0, 16):
            with tir.block([128, 128], "B") as [vi, vj]:
                tir.bind(vi, i_0 * 16 + i_1)
                tir.bind(vj, j_0 * 16 + j_1)
                B[vi, vj] = A[vi, vj] * 2.0
        # Consumer for the same row segment.
        for j_1 in tir.serial(0, 16):
            with tir.block([128, 128], "C") as [vi, vj]:
                tir.bind(vi, i_0 * 16 + i_1)
                tir.bind(vj, j_0 * 16 + j_1)
                C[vi, vj] = B[vi, vj] + 1.0
Ejemplo n.º 14
0
def elementwise_under_loop(a: ty.handle, c: ty.handle) -> None:
    # Two-stage computation over 128x128 buffers: B = A * 2.0 then
    # C = B + 1.0. Both j loops are nested under the same outer loop i, so
    # one row of B is produced and then consumed per outer iteration.
    A = tir.match_buffer(a, (128, 128))
    C = tir.match_buffer(c, (128, 128))
    B = tir.alloc_buffer((128, 128))
    for i in tir.serial(0, 128):
        for j in tir.serial(0, 128):
            with tir.block([128, 128], "B") as [vi, vj]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                B[vi, vj] = A[vi, vj] * 2.0
        for j in tir.serial(0, 128):
            with tir.block([128, 128], "C") as [vi, vj]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                C[vi, vj] = B[vi, vj] + 1.0
def elementwise_dependent_loop(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 over four-dimensional 128^4 buffers, where the
    # extent of the k loop (and of the block's third axis) is the outer loop
    # variable i rather than a constant.
    A = tir.match_buffer(a, (128, 128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128, 128))
    for i in tir.serial(0, 128):
        # The middle grid extent depends on the enclosing loop variable i.
        for j, k, l in tir.grid(128, i, 128):
            with tir.block([128, 128, i, 128], "B") as [vi, vj, vk, vl]:
                B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
Ejemplo n.º 16
0
def read_out_of_bound_after_compute_at(a: ty.handle, c: ty.handle) -> None:
    # For each j, first copies the elements of A that block "C" needs into B,
    # then computes C[v] = max(B[v], B[v + 1]) for v < 15, or B[v] otherwise.
    # NOTE(review): the name suggests this is the expected result of a
    # compute_at schedule step -- confirm against the test that uses it.
    A = tir.match_buffer(a, [16], "float32")
    B = tir.alloc_buffer([16], "float32")
    C = tir.match_buffer(c, [16], "float32")
    for j in tir.serial(0, 16):
        # Extent is min(1, 15 - j) + 1: two iterations for j <= 14 and a
        # single iteration for j == 15, so j + i never exceeds 15.
        for i in tir.serial(0, tir.min(1, 15 - j) + 1):
            with tir.block([16], "B") as [v]:
                tir.bind(v, j + i)
                B[v] = A[v]
        with tir.block([16], "C") as [v]:
            tir.bind(v, j)
            # Declared read region spans two elements starting at v.
            tir.reads([B[v:v + 2]])
            C[v] = tir.if_then_else(v < 15,
                                    tir.max(B[v], B[v + 1]),
                                    B[v],
                                    dtype="float32")
def equal_ranked_threads(a: ty.handle, c: ty.handle) -> None:
    # Two-stage computation over 128x128 buffers staged through a buffer B
    # allocated with scope="shared". The row index is split across two thread
    # bindings: threadIdx.x (extent 16, outer) and threadIdx.y (extent 8,
    # inner), recombined as i_o * 8 + i_i.
    A = tir.match_buffer(a, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    B = tir.alloc_buffer([128, 128], scope="shared")
    for i_o in tir.thread_binding(0, 16, thread="threadIdx.x"):
        for i_i in tir.thread_binding(0, 8, thread="threadIdx.y"):
            for j in tir.serial(0, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    tir.bind(vi, i_o * 8 + i_i)
                    tir.bind(vj, j)
                    B[vi, vj] = A[vi, vj] * 2.0
            for j in tir.serial(0, 128):
                with tir.block([128, 128], "C") as [vi, vj]:
                    tir.bind(vi, i_o * 8 + i_i)
                    tir.bind(vj, j)
                    # Swapped index order [vj, vi] compared with the producer.
                    C[vj, vi] = B[vj, vi] + 1.0
Ejemplo n.º 18
0
def multiple_reduction_blocks(a: ty.handle, f: ty.handle) -> None:
    # Chain of four reduction blocks over the last axis of the 16x16x16
    # buffer A. Each block reduces over 16 steps (split into a 4x4 loop grid)
    # and, apart from "C", also adds the previous stage's result on every
    # step: C -> D -> E -> F. C, D and E are intermediate 16x16 allocations;
    # F is the output buffer.
    A = tir.match_buffer(a, (16, 16, 16))
    C = tir.alloc_buffer((16, 16))
    D = tir.alloc_buffer((16, 16))
    E = tir.alloc_buffer((16, 16))
    F = tir.match_buffer(f, (16, 16))

    for i in tir.serial(0, 16):
        # First column loop: produces C and D.
        for j1 in tir.serial(0, 16):
            for k1o, k1i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "C") as [ci, cj, ck]:
                    tir.bind(ci, i)
                    tir.bind(cj, j1)
                    tir.bind(ck, k1o * 4 + k1i)
                    with tir.init():
                        C[ci, cj] = 0.0
                    C[ci, cj] = C[ci, cj] + A[ci, cj, ck]
            for k2o, k2i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "D") as [di, dj, dk]:
                    tir.bind(di, i)
                    tir.bind(dj, j1)
                    tir.bind(dk, k2o * 4 + k2i)
                    with tir.init():
                        D[di, dj] = 0.0
                    D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj]
        # Second column loop: produces E (from D) and F (from E).
        for j2 in tir.serial(0, 16):
            for k3o, k3i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "E") as [ei, ej, ek]:
                    tir.bind(ei, i)
                    tir.bind(ej, j2)
                    tir.bind(ek, k3o * 4 + k3i)
                    with tir.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej]
            for k4o, k4i in tir.grid(4, 4):
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "F") as [fi, fj, fk]:
                    tir.bind(fi, i)
                    tir.bind(fj, j2)
                    tir.bind(fk, k4o * 4 + k4i)
                    with tir.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj]
Ejemplo n.º 19
0
    def unsupported_function_call(a: ty.handle) -> None:
        # Realizes the 16x16 float32 buffer A and writes 0.0 to every element.
        # The outer loop is built with tir.const_range(16) instead of the
        # tir.serial used by the inner loop.
        # NOTE(review): per the function name, tir.const_range is presumably
        # the unsupported call this fixture exercises on purpose -- confirm
        # with the test that uses it before changing it.
        A = tir.match_buffer(a, (16, 16), "float32")

        tir.attr(A, "realize_scope", "")
        tir.realize(A[0:16, 0:16])
        for i in tir.const_range(16):
            for j in tir.serial(0, 16):
                A[i, j] = 0.0
def elementwise_non_single_branch(a: ty.handle, b: ty.handle) -> None:
    # Two element-wise stages over 128x128x128 buffers: C = A * 2.0 followed
    # by B = C * 2.0. The (i, j) grid has two child k loops, so the loop nest
    # branches below j rather than forming a single chain.
    A = tir.match_buffer(a, (128, 128, 128))
    C = tir.alloc_buffer((128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        # First branch: producer of C.
        for k in tir.serial(0, 128):
            with tir.block([128, 128, 128], "C") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                C[vi, vj, vk] = A[vi, vj, vk] * 2.0
        # Second branch: consumer of C, producer of B.
        for k in tir.serial(0, 128):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                B[vi, vj, vk] = C[vi, vj, vk] * 2.0
def transformed_element_func(a: ty.handle, c: ty.handle) -> None:
    # Row-by-row computation C = (A + 1.0) * 2.0 over 16x16 buffers. Each
    # outer iteration is wrapped in a block with no iteration variables that
    # declares the row it reads and writes and owns the intermediate buffer B.
    # NOTE(review): the name suggests this is the expected output of a
    # transformation pass -- confirm against the test that uses it.
    A = tir.match_buffer(a, [16, 16])
    C = tir.match_buffer(c, [16, 16])

    # The outer loop uses Python's range rather than tir.serial.
    for i_0 in range(0, 16):
        with tir.block([]):
            tir.reads([A[i_0, 0:16]])
            tir.writes([C[i_0, 0:16]])
            # Intermediate buffer allocated inside the per-row block.
            B = tir.alloc_buffer([16, 16])
            for j_0 in tir.serial(0, 16):
                with tir.block([16, 16], "") as [i, j]:
                    tir.bind(i, i_0)
                    tir.bind(j, j_0)
                    B[i, j] = A[i, j] + 1.0
            for j_0 in tir.serial(0, 16):
                with tir.block([16, 16], "") as [i, j]:
                    tir.bind(i, i_0)
                    tir.bind(j, j_0)
                    C[i, j] = B[i, j] * 2.0
Ejemplo n.º 22
0
def elementwise_fused(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 over 128x128x128 buffers, driven by a single
    # loop of extent 2097152 (= 128 * 128 * 128). The three block indices are
    # recovered from the flat loop variable with floordiv/floormod.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for fused in tir.serial(0, 2097152):
        with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
            # 16384 = 128 * 128 elements per step of the first index.
            tir.bind(vi, tir.floordiv(fused, 16384))
            tir.bind(vj, tir.floormod(tir.floordiv(fused, 128), 128))
            tir.bind(vk, tir.floormod(fused, 128))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def opaque_access_func() -> None:
    # Issues 8 extern calls named "test", one per 128-element chunk of two
    # locally allocated 1024-element buffers. Each call receives B's data
    # pointer, A's data pointer and an (offset, length) pair for each.
    A = tir.alloc_buffer([1024])
    B = tir.alloc_buffer([1024])
    for i in tir.serial(0, 8):
        with tir.block([8]) as [v]:
            tir.bind(v, i)
            # The buffers are only touched through raw data pointers below,
            # so the accessed regions are declared explicitly.
            tir.reads([A[v * 128 : v * 128 + 128]])
            tir.writes([B[v * 128 : v * 128 + 128]])
            tir.evaluate(
                tir.call_extern("test", B.data, v * 128, 128, A.data, v * 128, 128, dtype="float32")
            )
Ejemplo n.º 24
0
def cuda_matmul_1(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # pylint: disable=undefined-loop-variable
    # 2048x2048 float32 matrix multiply with staged copies of both inputs
    # (scope="shared", then scope="local") and a scope="local" accumulator
    # C_local that is copied out to C after the reduction loops.
    # The compute loops are bound to blockIdx, vthread and threadIdx axes;
    # each innermost iteration space covers a 4x4 tile of the output.
    A = tir.match_buffer(a, [2048, 2048], "float32")
    B = tir.match_buffer(b, [2048, 2048], "float32")
    C = tir.match_buffer(c, [2048, 2048], "float32")
    A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    # Staging copies: A -> A_shared -> A_shared_local, likewise for B.
    with tir.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with tir.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with tir.block([2048, 2048], "A_shared_local") as [v0, v1]:
        A_shared_local[v0, v1] = A_shared[v0, v1]
    with tir.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    for by in tir.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in tir.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in tir.thread_binding(0, 2, thread="vthread.y"):
                for vx in tir.thread_binding(0, 2, thread="vthread.x"):
                    # Note: the loop variable `ty` has the same name as the
                    # `ty` used for type annotations in the signature.
                    for ty in tir.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in tir.thread_binding(0,
                                                     8,
                                                     thread="threadIdx.x"):
                            # Reduction: k is split as k_0 (256, serial) by
                            # k_1 (8, unrolled), i.e. vk = k_0 * 8 + k_1.
                            for k_0 in tir.serial(0, 256):
                                for k_1 in tir.unroll(0, 8):
                                    for _, i, j in tir.grid(1, 4, 4):
                                        with tir.block([
                                                2048, 2048,
                                                tir.reduce_axis(0, 2048)
                                        ], "C") as [vi, vj, vk]:
                                            tir.bind(
                                                vi,
                                                by * 64 + vy * 32 + ty * 4 + i)
                                            tir.bind(
                                                vj,
                                                bx * 64 + vx * 32 + tx * 4 + j)
                                            tir.bind(vk, k_0 * 8 + k_1)
                                            with tir.init():
                                                C_local[vi, vj] = 0.0
                                            # Both operands are indexed with
                                            # the reduction index vk first.
                                            C_local[vi, vj] = C_local[
                                                vi, vj] + A_shared_local[
                                                    vk, vi] * B_shared_local[
                                                        vk, vj]
                            # Write the accumulated 4x4 tile back to C.
                            for i, j in tir.grid(4, 4):
                                with tir.block([2048, 2048],
                                               "C_local") as [vi, vj]:
                                    tir.bind(vi,
                                             by * 64 + vy * 32 + ty * 4 + i)
                                    tir.bind(vj,
                                             bx * 64 + vx * 32 + tx * 4 + j)
                                    C[vi, vj] = C_local[vi, vj]
Ejemplo n.º 25
0
def rowsum_not_serial(a: ty.handle, b: ty.handle) -> None:
    # Row sum: B[vi] accumulates A[vi, vk] over the reduce axis vk for each
    # of the 128 rows. The loop bound to the reduce axis is declared with
    # tir.parallel instead of tir.serial.
    # NOTE(review): per the function name the non-serial reduction loop is
    # presumably the property this fixture exists to exercise -- confirm.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))

    for i in tir.serial(0, 128):
        for k in tir.parallel(0, 128):
            with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
                tir.bind(vi, i)
                tir.bind(vk, k)
                with tir.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
Ejemplo n.º 26
0
def elementwise_with_anno(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 over 128x128x128 buffers. The innermost loop
    # carries a loop annotation {"useless_annotation": True}.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        # Only this loop is annotated; the (i, j) grid loops are not.
        for k in tir.serial(0, 128, annotations={"useless_annotation": True}):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
Ejemplo n.º 27
0
def elementwise_symbolic_fused(a: ty.handle, b: ty.handle,
                               n: ty.int32) -> None:
    # Element-wise B = A * 2.0 over (128, 128, n) buffers, where the last
    # extent n is a function parameter. A single loop of extent n * 16384
    # (= 128 * 128 * n) drives the block; the three indices are recovered
    # from the flat loop variable with floordiv/floormod.
    A = tir.match_buffer(a, (128, 128, n))
    B = tir.match_buffer(b, (128, 128, n))
    for i_j_k_fused in tir.serial(0, (n * 16384)):
        with tir.block([128, 128, n], "B") as [vi, vj, vk]:
            # n * 128 elements per step of the first index.
            tir.bind(vi, tir.floordiv(i_j_k_fused, (n * 128)))
            tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, n), 128))
            tir.bind(vk, tir.floormod(i_j_k_fused, n))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
Ejemplo n.º 28
0
def elementwise_with_starting_point(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 over 128x128x128 buffers, where the innermost
    # loop is declared as tir.serial(10, 128), i.e. its first argument is 10
    # rather than 0.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        # Loop with a non-zero first argument.
        for k in tir.serial(10, 128):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
Ejemplo n.º 29
0
 def test_tir_fma(A: ty.handle, B: ty.handle, C: ty.handle,
                  d: ty.handle) -> None:
     # Element-wise multiply-add d[i] = A[i] * B[i] + C[i] over four
     # one-dimensional buffers of symbolic length n. Each buffer has its own
     # symbolic stride, and elements are addressed as i * stride on the raw
     # data pointers.
     # function attr dict
     tir.func_attr({"global_symbol": "test_fma", "tir.noalias": True})
     # Symbolic length shared by all four buffers, plus one stride per buffer.
     n = tir.var("int32")
     stride = tir.var("int32")
     stride_1 = tir.var("int32")
     stride_2 = tir.var("int32")
     stride_3 = tir.var("int32")
     A_1 = tir.match_buffer(
         A,
         [n],
         strides=[stride],
         elem_offset=0,
         align=128,
         offset_factor=1,
         type="auto",
     )
     B_1 = tir.match_buffer(
         B,
         [n],
         strides=[stride_1],
         elem_offset=0,
         align=128,
         offset_factor=1,
         type="auto",
     )
     C_1 = tir.match_buffer(
         C,
         [n],
         strides=[stride_2],
         elem_offset=0,
         align=128,
         offset_factor=1,
         type="auto",
     )
     d_1 = tir.match_buffer(
         d,
         [n],
         strides=[stride_3],
         elem_offset=0,
         align=128,
         offset_factor=1,
         type="auto",
     )
     # body
     # The product A * B is formed first and C is added to it as float32 loads.
     for i in tir.serial(0, n):
         d_1.data[(i * stride_3)] = (tir.load("float32", A_1.data,
                                              (i * stride)) *
                                     tir.load("float32", B_1.data,
                                              (i * stride_1))) + tir.load(
                                                  "float32", C_1.data,
                                                  (i * stride_2))
def elementwise_with_loops_not_same_scope(a: ty.handle, b: ty.handle) -> None:
    # Element-wise B = A * 2.0 over 128x128x128 buffers, structured as an
    # outer block "A" over (vi, vj) that contains the k loop and an inner
    # block "B" over vk. The k loop therefore lives inside block "A", while
    # the i and j loops are outside it.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        with tir.block([128, 128], "A") as [vi, vj]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            # This loop is nested inside block "A", not directly under i/j.
            for k in tir.serial(0, 128):
                with tir.block([128], "B") as [vk]:
                    tir.bind(vk, k)
                    tir.reads([A[vi, vj, vk]])
                    tir.writes([B[vi, vj, vk]])
                    B[vi, vj, vk] = A[vi, vj, vk] * 2.0