def war_dependency(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # TIR fixture: inside each (i, j) iteration, block "C" reads B[vi, vj]
    # before block "B" writes it, i.e. a write-after-read ordering between
    # the two blocks. a, b, c are matched to 128x128 buffers A, B, C.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, 128))
    C = tir.match_buffer(c, (128, 128))

    for i, j in tir.grid(128, 128):
        # No tir.bind here: block vars are presumably bound implicitly to the
        # enclosing (i, j) loops -- confirm for the TVMScript version in use.
        # Consumer first: reads the current value of B ...
        with tir.block([128, 128], "C") as [vi, vj]:
            C[vi, vj] = B[vi, vj] + 1.0
        # ... then the producer overwrites the same B element from A.
        with tir.block([128, 128], "B") as [vi, vj]:
            B[vi, vj] = A[vi, vj] * 2.0
# Example #2
# 0
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # TIR fixture: C[i, j] = sum_k A[i, k] * B[j, k] over 128x128 buffers.
    # B is indexed [vj, vk], so this multiplies A by the transpose of B.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    for i, j in tir.grid(128, 128):
        # Separate "init" block zeroes C[vi, vj] before the k loop.
        with tir.block([128, 128], "init") as [vi, vj]:
            C[vi, vj] = tir.float32(0)
        for k in range(0, 128):
            # Third block var is declared via tir.reduce_axis over [0, 128);
            # each step accumulates one product into C[vi, vj].
            with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
def compacted_predicate_func(a: ty.handle, c: ty.handle) -> None:
    # TIR fixture: C[x] = A[x] + 1.0 with x = i * 7 + j over a 5x7 loop nest.
    # 5 * 7 = 35 iterations cover a 32-element buffer, so tir.where masks off
    # the iterations with x >= 32.
    # NOTE: (32) below is the plain int 32, not a 1-tuple.
    A = tir.match_buffer(a, (32), "float32")
    C = tir.match_buffer(c, (32), "float32")

    for i, j in tir.grid(5, 7):
        # Anonymous block with no block vars; loop vars i, j are used directly
        # and the access regions are declared explicitly.
        with tir.block([]) as []:
            tir.reads(A[i * 7 + j])
            tir.writes(C[i * 7 + j])
            tir.where(i * 7 + j < 32)
            C[i * 7 + j] = A[i * 7 + j] + 1.0
def elementwise_predicate(a: ty.handle, c: ty.handle) -> None:
    # TIR fixture: B = A * 2.0, then C = B + 1.0 guarded by a predicate.
    # The tir.where on block "C" reads buffer data (B[i, j] < 10.0) rather than
    # only loop variables -- presumably the case the consuming test exercises;
    # confirm against that test.
    A = tir.match_buffer(a, (128, 128))
    B = tir.alloc_buffer((128, 128))
    C = tir.match_buffer(c, (128, 128))
    # Block "B" is written without an explicit surrounding loop nest.
    with tir.block([128, 128], "B") as [vi, vj]:
        B[vi, vj] = A[vi, vj] * 2.0
    for i, j in tir.grid(128, 128):
        with tir.block([128, 128], "C") as [vi, vj]:
            tir.where(B[i, j] < 10.0)
            C[vi, vj] = B[vi, vj] + 1.0
# Example #5
# 0
def elementwise_split_case0(a: ty.handle, b: ty.handle) -> None:
    # TIR fixture: B = A * 2.0 over 128x128x128 with every axis split:
    #   vi = i1 * 64 + i3   (i1 in [0, 2), i3 in [0, 64); i2 has extent 1 and
    #                        does not appear in the binding)
    #   vj = j1 * 32 + j2   (j1 in [0, 4), j2 in [0, 32))
    #   vk = k1 * 8 + k2    (k1 in [0, 16), k2 in [0, 8))
    # Presumably the expected output of a split schedule -- confirm with the test.
    A = tir.match_buffer(a, [128, 128, 128])
    B = tir.match_buffer(b, [128, 128, 128])
    for i1, i2, i3, j1, j2, k1, k2 in tir.grid(2, 1, 64, 4, 32, 16, 8):
        with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
            tir.bind(vi, ((i1 * 64) + i3))
            tir.bind(vj, ((j1 * 32) + j2))
            tir.bind(vk, ((k1 * 8) + k2))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
# Example #6
# 0
def rowsum_transformed(a: ty.handle, b: ty.handle) -> None:
    # TIR fixture: row sum B[vi] = sum_vk A[vi, vk] over a 128x128 input.
    # The (32, 128, 4) loop nest has 16384 = 128 * 128 iterations, mapped as
    #   vi = io * 4 + ii_ko_fused // 32      -> in [0, 128)
    #   vk = (ii_ko_fused % 32) * 4 + ki     -> in [0, 128)
    # i.e. the middle loop appears to fuse a 4-extent row loop with a
    # 32-extent reduction loop.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))

    for io, ii_ko_fused, ki in tir.grid(32, 128, 4):
        with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
            tir.bind(vi, io * 4 + tir.floordiv(ii_ko_fused, 32))
            tir.bind(vk, tir.floormod(ii_ko_fused, 32) * 4 + ki)
            # Reduction initializer for B[vi].
            with tir.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
# Example #7
# 0
def rowsum_not_quasi_affine(a: ty.handle, b: ty.handle) -> None:
    # TIR fixture: row sum whose reduction binding vk = (k * k) // 2 multiplies
    # a loop variable by itself, so it is not an affine (nor quasi-affine)
    # function of the loops. k only ranges over [0, 16), so vk takes at most
    # 16 distinct values (max 112) instead of covering the declared [0, 128).
    # Presumably an invalid-input case for the consuming test -- confirm.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))

    for i, k in tir.grid(128, 16):
        with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
            tir.bind(vi, i)
            tir.bind(vk, tir.floordiv(k * k, 2))
            with tir.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
# Example #8
# 0
def elementwise_with_seq(a: ty.handle, b: ty.handle) -> None:
    # TIR fixture: two elementwise stages over 128^3 that share the outer
    # (i, j) loops but have sibling k loops run in sequence:
    #   block "C": C = A * 2.0   (C is an intermediate tir.alloc_buffer)
    #   block "B": B = C * 2.0
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    C = tir.alloc_buffer((128, 128, 128))
    for i, j in tir.grid(128, 128):
        for k in tir.serial(0, 128):
            with tir.block([128, 128, 128], "C") as [vi, vj, vk]:
                C[vi, vj, vk] = A[vi, vj, vk] * 2.0
        for k in tir.serial(0, 128):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                B[vi, vj, vk] = C[vi, vj, vk] * 2.0
# Example #9
# 0
def buffer_shape_mismatch(a: ty.handle) -> None:
    # Intentionally malformed TIR fixture: the region A[i, j*4 : j*4+4] holds
    # 4 elements, but it is matched to a sub-buffer declared with shape (5);
    # see the inline "error" note on the match_buffer call.
    # NOTE(review): sub_A is declared 1-D yet is indexed with two subscripts
    # below; leave as-is if the test only checks the shape-mismatch error.
    A = tir.match_buffer(a, (8, 8))
    for i, j in tir.grid(8, 2):
        with tir.block([]):
            tir.reads([])
            tir.writes([A[i, j * 4:j * 4 + 4]])
            sub_A = tir.match_buffer(
                A[i, j * 4:j * 4 + 4],
                (5))  # error: shape mismatched between 4 and 5
            for jj in range(0, 4):
                sub_A[i, j * 4 + jj] = 1
# Example #10
# 0
def tiled(a: ty.handle, c: ty.handle) -> None:
    # TIR fixture: B = A * 2.0 computed as 8x8 tiles of 16x16 over 128x128:
    #   vi = i_0 * 16 + i_1, vj = j_0 * 16 + j_1.
    # Block "C" (C = B + 1.0) is written without an explicit loop nest.
    A = tir.match_buffer(a, [128, 128], "float32")
    B = tir.alloc_buffer([128, 128], "float32")
    C = tir.match_buffer(c, [128, 128], "float32")
    for i_0, j_0, i_1, j_1 in tir.grid(8, 8, 16, 16):
        with tir.block([128, 128], "B") as [vi, vj]:
            tir.bind(vi, i_0 * 16 + i_1)
            tir.bind(vj, j_0 * 16 + j_1)
            B[vi, vj] = A[vi, vj] * 2.0
    with tir.block([128, 128], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
def elementwise_affine_producer(a: ty.handle, c: ty.handle) -> None:
    # TIR fixture: producer "B" (B = A * 2.0) iterates a (16, 2, 32, 16) nest,
    # 16384 points, mapped onto 128x128 with bindings that divide / take the
    # remainder by constants:
    #   vi = i * 8 + j * 4 + k // 8   -> in [0, 128)
    #   vj = (k % 8) * 16 + l         -> in [0, 128)
    # Consumer "C" (C = B + 1.0) is written without an explicit loop nest.
    A = tir.match_buffer(a, (128, 128), "float32")
    C = tir.match_buffer(c, (128, 128), "float32")
    B = tir.alloc_buffer((128, 128), "float32")
    for i, j, k, l in tir.grid(16, 2, 32, 16):
        with tir.block([128, 128], "B") as [vi, vj]:
            tir.bind(vi, i * 8 + j * 4 + k // 8)
            tir.bind(vj, k % 8 * 16 + l)
            B[vi, vj] = A[vi, vj] * 2.0
    with tir.block([128, 128], "C") as [vi, vj]:
        C[vi, vj] = B[vi, vj] + 1.0
# Example #12
# 0
def blockized_2(a: ty.handle, c: ty.handle) -> None:
    # TIR fixture: the B = A * 2.0 stage is wrapped in an outer block
    # "B_outer" over 8x8 tiles. Each outer instance declares that it reads the
    # 16x16 tile of A and writes the matching tile of B, and the inner block
    # "B_inner" computes that tile. Consumer "C" (C = B + 1.0) uses a
    # different (4, 4, 32, 32) tiling.
    A = tir.match_buffer(a, [128, 128], "float32")
    B = tir.alloc_buffer([128, 128], "float32")
    C = tir.match_buffer(c, [128, 128], "float32")
    for i_o, j_o in tir.grid(8, 8):
        with tir.block([8, 8], "B_outer") as [vio, vjo]:
            tir.bind(vio, i_o)
            tir.bind(vjo, j_o)
            # Tile region: rows vio*16 : vio*16+16, cols vjo*16 : vjo*16+16.
            tir.reads([A[vio * 16:vio * 16 + 16, vjo * 16:vjo * 16 + 16, ]])
            tir.writes([B[vio * 16:vio * 16 + 16, vjo * 16:vjo * 16 + 16]])
            for i_i, j_i in tir.grid(16, 16):
                with tir.block([128, 128], "B_inner") as [vi, vj]:
                    tir.bind(vi, vio * 16 + i_i)
                    tir.bind(vj, vjo * 16 + j_i)
                    B[vi, vj] = A[vi, vj] * 2.0
    for i_o, j_o, i_i, j_i in tir.grid(4, 4, 32, 32):
        with tir.block([128, 128], "C") as [vi, vj]:
            tir.bind(vi, i_o * 32 + i_i)
            tir.bind(vj, j_o * 32 + j_i)
            C[vi, vj] = B[vi, vj] + 1.0
def elementwise_reordered_with_predicate(a: ty.handle, b: ty.handle) -> None:
    # TIR fixture: 4-D B = A * 2.0 whose loops run in order (l, j, k, i)
    # while the block vars bind vi=i, vj=j, vk=k, vl=l.
    # Predicate: 2097152 = 128**3 and 16384 = 128**2, so the condition is the
    # row-major linear index of (i, j, k, l) being < 100. Only the points with
    # i = j = k = 0 and l < 100 satisfy it.
    A = tir.match_buffer(a, (128, 128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128, 128))
    for l, j, k, i in tir.grid(128, 128, 128, 128):
        with tir.block([128, 128, 128, 128], "B") as [vi, vj, vk, vl]:
            tir.where(i * 2097152 + j * 16384 + k * 128 + l < 100)
            tir.bind(vi, i)
            tir.bind(vj, j)
            tir.bind(vk, k)
            tir.bind(vl, l)
            B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
def elementwise_with_wrong_block_var_type(a: ty.handle, b: ty.handle) -> None:
    # TIR fixture: B = A * 2.0 where the third block var is declared with
    # tir.scan_axis instead of a plain extent like the other two. Per the
    # function name this is deliberately an invalid input -- presumably used
    # to check that a schedule primitive rejects it; confirm with the test.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j, k in tir.grid(128, 128, 128):
        with tir.block([128, 128, tir.scan_axis(0, 128)], "B") as [vi, vj, vk]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            tir.bind(vk, k)
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
def func() -> None:
    # TIR fixture with hand-written access regions on locally allocated buffers.
    # The unnamed outer block declares reads B[0, 0], C[0:16, 0:16],
    # A[4:12, 4:12] and writes A[0:12, 0:12]. Its body:
    #   1. fills A[0:8, 0:8] with B[0, 0] + C[0, 0];
    #   2. runs a 2x2 inner block that accumulates C[12:16, 12:16] into the
    #      4x4 tiles of A[4:12, 4:12];
    #   3. references D only through tir.evaluate(D.data).
    A = tir.alloc_buffer((128, 128), "float32")
    B = tir.alloc_buffer((128, 128), "float32")
    C = tir.alloc_buffer((128, 128), "float32")
    D = tir.alloc_buffer((128, 128), "float32")
    with tir.block([]):
        # Need add read/write region manually to avoid triggering block access region detector
        tir.reads([B[0, 0], C[0:16, 0:16], A[4:12, 4:12]])
        tir.writes([A[0:12, 0:12]])
        for i, j in tir.grid(8, 8):
            A[i, j] = B[0, 0] + C[0, 0]
        with tir.block([2, 2]) as [vi, vj]:
            # Tile (vi, vj) covers A rows vi*4+4 : vi*4+8, cols vj*4+4 : vj*4+8.
            tir.reads([
                A[vi * 4 + 4:vi * 4 + 8, vj * 4 + 4:vj * 4 + 8], C[12:16,
                                                                   12:16]
            ])
            tir.writes([A[vi * 4 + 4:vi * 4 + 8, vj * 4 + 4:vj * 4 + 8]])
            for i, j in tir.grid(4, 4):
                A[vi * 4 + 4 + i, vj * 4 + 4 + j] += C[i + 12, j + 12]
        tir.evaluate(D.data)
# Example #16
# 0
def elementwise_with_starting_point(a: ty.handle, b: ty.handle) -> None:
    # TIR fixture: B = A * 2.0 over 128^3, except that the k loop is
    # tir.serial(10, 128), so it starts at 10 instead of 0 and vk = k never
    # takes the values below 10. Presumably used to check how a schedule
    # primitive treats a loop with a non-zero start -- confirm with the test.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        for k in tir.serial(10, 128):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
# Example #17
# 0
def elementwise_with_thread_binding(a: ty.handle, b: ty.handle) -> None:
    # TIR fixture: B = A * 2.0 over 128^3 where the innermost k loop is a
    # tir.thread_binding loop tagged "threadIdx.x" rather than a serial loop.
    # Presumably used to check how a schedule primitive handles a
    # thread-bound loop -- confirm with the test.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        for k in tir.thread_binding(0, 128, thread="threadIdx.x"):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
# Example #18
# 0
def elementwise_with_anno(a: ty.handle, b: ty.handle) -> None:
    # TIR fixture: B = A * 2.0 over 128^3 where the k loop carries a loop
    # annotation ({"useless_annotation": True}). Presumably used to check
    # whether a schedule primitive preserves or tolerates loop annotations --
    # confirm with the test.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        for k in tir.serial(0, 128, annotations={"useless_annotation": True}):
            with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
                tir.bind(vi, i)
                tir.bind(vj, j)
                tir.bind(vk, k)
                tir.reads([A[vi, vj, vk]])
                tir.writes([B[vi, vj, vk]])
                B[vi, vj, vk] = A[vi, vj, vk] * 2.0
# Example #19
# 0
def blockized_after_compute_at(a: ty.handle, c: ty.handle) -> None:
    # TIR fixture: for each (i0_0, i1_0) 16x16 tile, first compute that tile
    # of B = A * 2.0, then run outer block "C_outer". C_outer declares that it
    # reads the B tile and writes the matching C tile, and its inner block
    # "C_inner" computes C = B + 1.0 on that tile. Per the name, presumably
    # the expected result of compute_at into a blockized consumer -- confirm
    # with the test.
    A = tir.match_buffer(a, [128, 128], "float32")
    B = tir.alloc_buffer([128, 128], "float32")
    C = tir.match_buffer(c, [128, 128], "float32")
    for i0_0, i1_0 in tir.grid(8, 8):
        for ax0, ax1 in tir.grid(16, 16):
            with tir.block([128, 128], "B") as [vi, vj]:
                tir.bind(vi, i0_0 * 16 + ax0)
                tir.bind(vj, i1_0 * 16 + ax1)
                B[vi, vj] = A[vi, vj] * 2.0
        with tir.block([8, 8], "C_outer") as [vi_o, vj_o]:
            tir.bind(vi_o, i0_0)
            tir.bind(vj_o, i1_0)
            tir.reads(
                [B[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16, ]])
            tir.writes([C[vi_o * 16:vi_o * 16 + 16, vj_o * 16:vj_o * 16 + 16]])
            for i0_1, i1_1 in tir.grid(16, 16):
                with tir.block([128, 128], "C_inner") as [vi, vj]:
                    tir.bind(vi, vi_o * 16 + i0_1)
                    tir.bind(vj, vj_o * 16 + i1_1)
                    C[vi, vj] = B[vi, vj] + 1.0
# Example #20
# 0
def elementwise_split_case1(a: ty.handle, b: ty.handle) -> None:
    # TIR fixture: B = A * 2.0 over 128^3 with each axis split into loops of
    # extents (2, 1, 64). Each binding is x1 * 64 + x3; the extent-1 middle
    # loops (i2, j2, k2) do not appear in any binding.
    A = tir.match_buffer(a, [128, 128, 128])
    B = tir.match_buffer(b, [128, 128, 128])
    for i1, i2, i3, j1, j2, j3, k1, k2, k3 in tir.grid(2, 1, 64, 2, 1, 64, 2,
                                                       1, 64):
        with tir.block([128, 128, 128], "B") as [vi, vj, vk]:
            tir.bind(vi, i1 * 64 + i3)
            tir.bind(vj, j1 * 64 + j3)
            tir.bind(vk, k1 * 64 + k3)
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
# Example #21
# 0
def matmul_rfactor(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # TIR fixture: matmul C[i, j] = sum_k A[i, k] * B[j, k] computed in two
    # stages. The reduction index is split with extents (4, 8, 4) as
    #   k = vi2_outer * 32 + vi2_inner_outer * 4 + vi2_inner_inner.
    # Stage "update_rf": partial sums C_rf[vi2_inner_inner, i, j], reducing
    #   over vi2_outer and vi2_inner_outer (the two tir.reduce_axis vars).
    # Stage "update": C[i, j] = sum over vi2_inner_inner of C_rf[.., i, j].
    # As in the plain matmul, B is indexed [vj, k] (transposed operand).
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    C_rf = tir.alloc_buffer([4, 128, 128])

    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(
            128, 128, 4, 8, 4):
        # Block vars in declaration order: vi2_inner_inner (4, spatial),
        # vi (128), vj (128), vi2_outer (reduce 4), vi2_inner_outer (reduce 8).
        with tir.block(
            [4, 128, 128,
             tir.reduce_axis(0, 4),
             tir.reduce_axis(0, 8)], "update_rf") as [
                 vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer
             ]:
            tir.bind(vi2_inner_inner, i2_inner_inner)
            tir.bind(vi, i0)
            tir.bind(vj, i1)
            tir.bind(vi2_outer, i2_outer)
            tir.bind(vi2_inner_outer, i2_inner_outer)
            with tir.init():
                C_rf[vi2_inner_inner, vi, vj] = 0.0
            C_rf[vi2_inner_inner, vi,
                 vj] = C_rf[vi2_inner_inner, vi, vj] + (A[vi, (
                     ((vi2_outer * 32) +
                      (vi2_inner_outer * 4)) + vi2_inner_inner)] * B[vj, (
                          ((vi2_outer * 32) +
                           (vi2_inner_outer * 4)) + vi2_inner_inner)])

    # Final stage: reduce the 4 partial sums in C_rf into C.
    for i0_1, i1_1, i2_inner_inner_1 in tir.grid(128, 128, 4):
        with tir.block([tir.reduce_axis(0, 4), 128, 128], "update") as [
                vi2_inner_inner_1,
                vi_1,
                vj_1,
        ]:
            tir.bind(vi2_inner_inner_1, i2_inner_inner_1)
            tir.bind(vi_1, i0_1)
            tir.bind(vj_1, i1_1)
            with tir.init():
                C[vi_1, vj_1] = 0.0
            C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1]
def cache_write_multi_consumer() -> None:
    # TIR fixture on local 128-element buffers. A_global[x] = 1.0 is copied
    # into A, and A is then consumed by two stages:
    #   "B" (B = A + 1.0), in the same i-loop as the producer, processing
    #       8 chunks of 16 elements (vi = i * 16 + j);
    #   "C" (C = A), in a separate 128-iteration loop.
    # Per the name, presumably the expected output of cache_write on A with
    # multiple consumers -- confirm with the test.
    # NOTE: (128) is the plain int 128, not a 1-tuple.
    A = tir.alloc_buffer((128))
    B = tir.alloc_buffer((128))
    C = tir.alloc_buffer((128))
    A_global = tir.alloc_buffer((128))
    for i in tir.grid(8):
        for j in tir.grid(16):
            with tir.block([128], "A_global") as [vi]:
                tir.bind(vi, i * 16 + j)
                A_global[vi] = 1.0
        for j in tir.grid(16):
            with tir.block([128], "A") as [vi]:
                tir.bind(vi, i * 16 + j)
                A[vi] = A_global[vi]
        for j in tir.grid(16):
            with tir.block([128], "B") as [vi]:
                tir.bind(vi, i * 16 + j)
                B[vi] = A[vi] + 1.0

    for i in tir.grid(128):
        with tir.block([128], "C") as [vi]:
            C[vi] = A[vi]
# Example #23
# 0
 def main(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
     # Lowered TIR fixture: flat matmul C[x, y] = sum_k A[x, k] * B[y, k],
     # addressing the buffers' data pointers with row-major offsets
     # (row * 128 + col). B is read at y * 128 + k, i.e. transposed.
     # NOTE(review): this definition is indented one column; Python rejects a
     # module-level def with leading indentation (IndentationError), so dedent
     # it if this file is meant to be importable.
     # function attr dict
     tir.func_attr({"global_symbol": "main", "tir.noalias": True})
     A = tir.match_buffer(a, [128, 128])
     B = tir.match_buffer(b, [128, 128])
     C = tir.match_buffer(c, [128, 128])
     # body
     for x, y in tir.grid(128, 128):
         # Zero the output element, then accumulate over k via explicit loads.
         C.data[x * 128 + y] = 0.0
         for k in tir.serial(0, 128):
             C.data[x * 128 + y] = tir.load("float32", C.data, x * 128 + y) + tir.load(
                 "float32", A.data, x * 128 + k
             ) * tir.load("float32", B.data, y * 128 + k)
def elementwise_with_loops_not_same_scope(a: ty.handle, b: ty.handle) -> None:
    # TIR fixture: B = A * 2.0 over 128^3. The (i, j) loops sit outside block
    # "A", while the k loop sits inside it under a nested block "B", which
    # binds only vk and reuses the outer block's vi, vj. The i/j and k loops
    # are therefore in different block scopes. Per the name, presumably an
    # invalid input for a primitive that needs its loops in one scope --
    # confirm with the test.
    A = tir.match_buffer(a, (128, 128, 128))
    B = tir.match_buffer(b, (128, 128, 128))
    for i, j in tir.grid(128, 128):
        with tir.block([128, 128], "A") as [vi, vj]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            for k in tir.serial(0, 128):
                with tir.block([128], "B") as [vk]:
                    tir.bind(vk, k)
                    tir.reads([A[vi, vj, vk]])
                    tir.writes([B[vi, vj, vk]])
                    B[vi, vj, vk] = A[vi, vj, vk] * 2.0
# Example #25
# 0
def factorized(a: ty.handle, b: ty.handle) -> None:
    # TIR fixture: B[x] = sum over (i, k) of A[x, i, k], computed in two stages.
    # "B_rf": B_rf_local[vi, vj] = sum_vk A[vj, vi, vk], one partial sum per
    #   (i, j). j is a loop bound to "blockIdx.x"; the outer part of i
    #   (extent 4) is bound to "threadIdx.x".
    # "B": B[vi] = sum_vk B_rf_local[vk, vi].
    # B_rf_local is allocated with scope="local".
    A = tir.match_buffer(a, [16, 16, 16], "float32")
    B = tir.match_buffer(b, [16], "float32")
    B_rf_local = tir.alloc_buffer([16, 16], "float32", scope="local")
    for j in tir.thread_binding(0, 16, thread="blockIdx.x"):
        for i_o in tir.thread_binding(0, 4, thread="threadIdx.x"):
            for i_i, k in tir.grid(4, 16):
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "B_rf") as [vi, vj, vk]:
                    tir.bind(vi, i_o * 4 + i_i)
                    tir.bind(vj, j)
                    tir.bind(vk, k)
                    with tir.init():
                        B_rf_local[vi, vj] = 0.0
                    B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk]
    for i, k in tir.grid(16, 16):
        with tir.block([16, tir.reduce_axis(0, 16)], "B") as [vi, vk]:
            tir.bind(vi, i)
            tir.bind(vk, k)
            with tir.init():
                B[vi] = 0.0
            B[vi] = B[vi] + B_rf_local[vk, vi]
# Example #26
# 0
def tir_multi_output(a0: ty.handle, a1: ty.handle, b0: ty.handle, b1: ty.handle) -> None:
    # TIR fixture with symbolic shape m x n. A single (i0, i1) loop nest
    # contains two independent blocks producing two outputs:
    #   "B.v0": B0 = A0 + 2.0
    #   "B.v1": B1 = A1 * 3.0
    m = tir.var("int32")
    n = tir.var("int32")
    A0 = tir.match_buffer(a0, (m, n))
    A1 = tir.match_buffer(a1, (m, n))
    B0 = tir.match_buffer(b0, (m, n))
    B1 = tir.match_buffer(b1, (m, n))

    for i0, i1 in tir.grid(m, n):
        with tir.block([m, n], "B.v0") as [i, j]:
            B0[i, j] = A0[i, j] + 2.0
        with tir.block([m, n], "B.v1") as [i, j]:
            B1[i, j] = A1[i, j] * 3.0
# Example #27
# 0
def two_elementwise_after_compute_at(a: ty.handle, c: ty.handle) -> None:
    # TIR fixture: for each row i, first compute row i of B = A * 2.0
    # (ax0 in [0, 1), ax1 in [0, 128)), then row i of C = B + 1.0.
    # Per the name, presumably the expected result of compute_at of the
    # producer under the consumer's i loop -- confirm with the test.
    A = tir.match_buffer(a, (128, 128), "float32")
    B = tir.alloc_buffer((128, 128), "float32")
    C = tir.match_buffer(c, (128, 128), "float32")
    for i in range(0, 128):
        for ax0, ax1 in tir.grid(1, 128):
            with tir.block([128, 128], "B") as [vi, vj]:
                tir.bind(vi, i + ax0)
                tir.bind(vj, ax1)
                B[vi, vj] = A[vi, vj] * 2.0
        for j in range(0, 128):
            # NOTE(review): this block writes C but is also named "B",
            # duplicating the producer's block name. If the consuming test
            # looks blocks up by name or compares names, "C" may have been
            # intended -- confirm before changing.
            with tir.block([128, 128], "B") as [vi, vj]:
                C[vi, vj] = B[vi, vj] + 1.0
# Example #28
# 0
def square_sum_rfactor(a: ty.handle, c: ty.handle) -> None:
    # TIR fixture: C[b] = sum over (i, j) of A[b, i, j]^2, computed in two
    # stages that factor out the last axis:
    #   "C_rf": C_rf[b, vi2] = sum_i A[b, i, vi2] * A[b, i, vi2]
    #   "C":    C[b] = sum_vi2 C_rf[b, vi2]
    # The block-var order differs from the loop order. For example, vi2 is
    # bound to the innermost loop i2 but listed first in "C_rf".
    A = tir.match_buffer(a, [16, 256, 256])
    C = tir.match_buffer(c, [16])
    C_rf = tir.alloc_buffer([16, 256])

    for i0, i1, i2 in tir.grid(16, 256, 256):
        with tir.block([256, 16, tir.reduce_axis(0, 256)],
                       "C_rf") as [vi2, b, i]:
            tir.bind(vi2, i2)
            tir.bind(b, i0)
            tir.bind(i, i1)
            with tir.init():
                C_rf[b, vi2] = 0.0
            C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2])

    for i0_1, i2_1 in tir.grid(16, 256):
        with tir.block([tir.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]:
            tir.bind(vi2_1, i2_1)
            tir.bind(b_1, i0_1)
            with tir.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[b_1, vi2_1]
# Example #29
# 0
def elementwise_symbolic_split(a: ty.handle, b: ty.handle,
                               n: ty.int32) -> None:
    # TIR fixture: B = A * 2.0 over (128, 128, n) with symbolic n. The last
    # axis is split into k0 (extent 10) and k1 (extent (n + 9) // 10, i.e.
    # ceil(n / 10) for non-negative n), with vk = k0 * ceil(n / 10) + k1.
    # Because 10 * ceil(n / 10) can exceed n, tir.where drops the iterations
    # with vk >= n.
    A = tir.match_buffer(a, (128, 128, n))
    B = tir.match_buffer(b, (128, 128, n))
    for i, j, k0, k1 in tir.grid(128, 128, 10, tir.floordiv((n + 9), 10)):
        with tir.block([128, 128, n], "B") as [vi, vj, vk]:
            tir.where((((k0 * tir.floordiv((n + 9), 10)) + k1) < n))
            tir.bind(vi, i)
            tir.bind(vj, j)
            tir.bind(vk, ((k0 * tir.floordiv((n + 9), 10)) + k1))
            tir.reads([A[vi, vj, vk]])
            tir.writes([B[vi, vj, vk]])
            B[vi, vj, vk] = A[vi, vj, vk] * 2.0
# Example #30
# 0
def buffer_opaque_access(b: ty.handle, c: ty.handle) -> None:
    # TIR fixture with opaque (non-buffer-indexed) accesses. An anonymous
    # block that declares no reads and writes B[0:16, 0:16]:
    #   - allocates a raw 256-element float32 buffer A ("global");
    #   - fills A via tir.store and reads it back via tir.load in tir.evaluate;
    #   - touches B only through the intrinsic tir.tvm_fill_fragment on B.data.
    # A second 16x16 block then copies B into C with ordinary indexing.
    B = tir.match_buffer(b, [16, 16], "float32")
    C = tir.match_buffer(c, [16, 16], "float32")

    with tir.block([]):
        tir.reads([])
        tir.writes(B[0:16, 0:16])
        A = tir.allocate([256], "float32", "global")
        for i, j in tir.grid(16, 16):
            tir.store(A, i * 16 + j, 1)
        for i in range(0, 16):
            for j in range(0, 16):
                tir.evaluate(tir.load("float32", A, i * 16 + j))
            for j in range(0, 16):
                tir.evaluate(
                    tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, tir.float32(0), dtype="handle")
                )

    for i, j in tir.grid(16, 16):
        with tir.block([16, 16]) as [vi, vj]:
            tir.bind(vi, i)
            tir.bind(vj, j)
            C[vi, vj] = B[vi, vj]