Esempio n. 1
0
def multiple_reduction_blocks_rfactor(a: ty.handle, f: ty.handle) -> None:
    # Chain of reduction blocks C_rf -> C -> D -> E -> F over the 16x16x16
    # input A. C_rf, C, D and E are locally allocated intermediates; F is the
    # buffer matched to parameter `f`.
    # NOTE(review): the name suggests this is the expected result of applying
    # rfactor to block "C" of `multiple_reduction_blocks` -- confirm against
    # the test that uses it.
    A = tir.match_buffer(a, [16, 16, 16])
    C = tir.alloc_buffer([16, 16])
    D = tir.alloc_buffer([16, 16])
    E = tir.alloc_buffer([16, 16])
    F = tir.match_buffer(f, [16, 16])
    # Four partial sums per (i, j) position, selected by the last index.
    C_rf = tir.alloc_buffer([16, 16, 4])

    for i, j1, k1o, k1i in tir.grid(16, 16, 4, 4):
        # "C_rf": vk1i is the only reduce axis, so for every vk1o the block
        # accumulates A[ci, cj, vk1o * 4 + vk1i] into C_rf[ci, cj, vk1o].
        with tir.block([4, 16, 16, tir.reduce_axis(0, 4)],
                       "C_rf") as [vk1o, ci, cj, vk1i]:
            tir.bind(vk1o, k1o)
            tir.bind(ci, i)
            tir.bind(cj, j1)
            tir.bind(vk1i, k1i)
            with tir.init():
                C_rf[ci, cj, vk1o] = 0.0
            C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj,
                                                        ((vk1o * 4) + vk1i)]
    for i_1 in tir.serial(0, 16):
        for j1_1 in tir.serial(0, 16):
            for k1o_1 in tir.serial(0, 4):
                # "C": adds up the four partial sums held in C_rf.
                with tir.block([tir.reduce_axis(0, 4), 16, 16],
                               "C") as [vk1o_1, ci_1, cj_1]:
                    tir.bind(vk1o_1, k1o_1)
                    tir.bind(ci_1, i_1)
                    tir.bind(cj_1, j1_1)
                    with tir.init():
                        C[ci_1, cj_1] = 0.0
                    C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1]
            for k2o, k2i in tir.grid(4, 4):
                # "D": the reduce index dk is rebuilt from the split loops as
                # k2o * 4 + k2i; every step adds A[di, dj, dk] and C[di, dj].
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "D") as [di, dj, dk]:
                    tir.bind(di, i_1)
                    tir.bind(dj, j1_1)
                    tir.bind(dk, (k2o * 4) + k2i)
                    with tir.init():
                        D[di, dj] = 0.0
                    D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj]
        # Second j loop under the same i_1 loop; it hosts blocks "E" and "F".
        for j2 in tir.serial(0, 16):
            for k3o, k3i in tir.grid(4, 4):
                # "E": same shape as "D", but reads D instead of C.
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "E") as [ei, ej, ek]:
                    tir.bind(ei, i_1)
                    tir.bind(ej, j2)
                    tir.bind(ek, (k3o * 4) + k3i)
                    with tir.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej]
            for k4o, k4i in tir.grid(4, 4):
                # "F": same shape again, reading E and writing the output F.
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "F") as [fi, fj, fk]:
                    tir.bind(fi, i_1)
                    tir.bind(fj, j2)
                    tir.bind(fk, (k4o * 4) + k4i)
                    with tir.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]
Esempio n. 2
0
def rowsum_zero_dim_rfactor(a: ty.handle, b: ty.handle) -> None:
    # Two-stage sum of the 128-element buffer A into the zero-dimensional
    # buffer B, going through the intermediate buffer B_rf.
    # NOTE(review): the name suggests this is the expected rfactor result for
    # `rowsum_zero_dim` -- confirm against the test that uses it.
    A = tir.match_buffer(a, [128])
    B = tir.match_buffer(b, [])
    B_rf = tir.alloc_buffer([128])

    # "B_rf": declared with a single non-reduce axis; each B_rf[vi0] is
    # initialised to 0.0 and then has A[vi0] added to it.
    with tir.block([128], "B_rf") as [vi0]:
        with tir.init():
            B_rf[vi0] = 0.0
        B_rf[vi0] = B_rf[vi0] + A[vi0]

    # "B": reduces all 128 entries of B_rf into the scalar B[()].
    with tir.block([tir.reduce_axis(0, 128)], "B") as [vi0_1]:
        with tir.init():
            B[()] = 0.0
        B[()] = B[()] + B_rf[vi0_1]
Esempio n. 3
0
def rowsum_wrong_reduce_pattern2(a: ty.handle, b: ty.handle) -> None:
    # Row-wise reduction of the 128x128 buffer A into the 128-element buffer B
    # whose update step SUBTRACTS A[vi, vk] rather than adding it.
    # NOTE(review): judging by the name, the subtraction is a deliberately
    # invalid reduction pattern for a negative test -- do not "fix" it without
    # checking the test that uses this function.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))

    with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
        with tir.init():
            B[vi] = 0.0
        B[vi] = B[vi] - A[vi, vk]
Esempio n. 4
0
def rowsum_zero_dim(a: ty.handle, b: ty.handle) -> None:
    # Sums the 128-element buffer A into the zero-dimensional buffer B.
    A = tir.match_buffer(a, [128])
    B = tir.match_buffer(b, [])

    # "B": a single reduce axis k; B is addressed with the empty tuple
    # because it has no dimensions.
    with tir.block([tir.reduce_axis(0, 128)], "B") as [k]:
        with tir.init():
            B[()] = 0.0
        B[()] = B[()] + A[k]
Esempio n. 5
0
def multiple_reduction_blocks(a: ty.handle, f: ty.handle) -> None:
    # Chain of four reduction blocks C -> D -> E -> F over the 16x16x16 input
    # A. C, D and E are locally allocated intermediates; F is the buffer
    # matched to parameter `f`. In every block the 16-wide reduce index is
    # rebuilt from a 4x4 pair of loops as `outer * 4 + inner`.
    A = tir.match_buffer(a, (16, 16, 16))
    C = tir.alloc_buffer((16, 16))
    D = tir.alloc_buffer((16, 16))
    E = tir.alloc_buffer((16, 16))
    F = tir.match_buffer(f, (16, 16))

    for i in tir.serial(0, 16):
        for j1 in tir.serial(0, 16):
            for k1o, k1i in tir.grid(4, 4):
                # "C": accumulates A[ci, cj, ck] over ck.
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "C") as [ci, cj, ck]:
                    tir.bind(ci, i)
                    tir.bind(cj, j1)
                    tir.bind(ck, k1o * 4 + k1i)
                    with tir.init():
                        C[ci, cj] = 0.0
                    C[ci, cj] = C[ci, cj] + A[ci, cj, ck]
            for k2o, k2i in tir.grid(4, 4):
                # "D": every reduction step adds A[di, dj, dk] and C[di, dj].
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "D") as [di, dj, dk]:
                    tir.bind(di, i)
                    tir.bind(dj, j1)
                    tir.bind(dk, k2o * 4 + k2i)
                    with tir.init():
                        D[di, dj] = 0.0
                    D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj]
        # Second j loop under the same i loop; it hosts blocks "E" and "F".
        for j2 in tir.serial(0, 16):
            for k3o, k3i in tir.grid(4, 4):
                # "E": same shape as "D", but reads D instead of C.
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "E") as [ei, ej, ek]:
                    tir.bind(ei, i)
                    tir.bind(ej, j2)
                    tir.bind(ek, k3o * 4 + k3i)
                    with tir.init():
                        E[ei, ej] = 0.0
                    E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej]
            for k4o, k4i in tir.grid(4, 4):
                # "F": same shape again, reading E and writing the output F.
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "F") as [fi, fj, fk]:
                    tir.bind(fi, i)
                    tir.bind(fj, j2)
                    tir.bind(fk, k4o * 4 + k4i)
                    with tir.init():
                        F[fi, fj] = 0.0
                    F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj]
Esempio n. 6
0
def rowsum_not_dominant(a: ty.handle, b: ty.handle) -> None:
    # Block "B" declares vk as a reduce axis, yet the output B is 128x128 and
    # is indexed by vk as well, so each (vi, vk) pair writes its own element.
    # NOTE(review): the name suggests a negative fixture for a dominance
    # check on the written buffer -- confirm against the test that uses it.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, 128))

    with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
        with tir.init():
            B[vi, vk] = 0.0
        B[vi, vk] = B[vi, vk] + A[vi, vk]
Esempio n. 7
0
def tir_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # 128x128 matrix product written as a single unnamed reduction block:
    # C[i, j] accumulates A[i, k] * B[j, k] over k. Note that B is indexed
    # [j, k], i.e. both operands are read along their second dimension.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, 128))
    C = tir.match_buffer(c, (128, 128))

    with tir.block([128, 128, tir.reduce_axis(0, 128)]) as [i, j, k]:
        with tir.init():
            C[i, j] = 0.0
        # Written with augmented assignment, unlike most other functions in
        # this file, which spell the update out as `X = X + ...`.
        C[i, j] += A[i, k] * B[j, k]
Esempio n. 8
0
def square_sum(a: ty.handle, c: ty.handle) -> None:
    # For each of the 16 leading indices b, sums the squares of the 256x256
    # slice A[b, :, :] into C[b].
    A = tir.match_buffer(a, [16, 256, 256])
    C = tir.match_buffer(c, [16])

    # "C": one non-reduce axis (b) and two reduce axes (i, j).
    with tir.block([16, tir.reduce_axis(0, 256),
                    tir.reduce_axis(0, 256)], "C") as [b, i, j]:
        with tir.init():
            C[b] = 0.0
        C[b] = C[b] + A[b, i, j] * A[b, i, j]
Esempio n. 9
0
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # 128x128 matrix product in one reduction block named "update":
    # C[vi, vj] accumulates A[vi, vk] * B[vj, vk] over vk (B is indexed
    # [vj, vk]).
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])

    with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
        with tir.init():
            # The zero is spelled as an explicit tir.float32 constant here.
            C[vi, vj] = tir.float32(0)
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
Esempio n. 10
0
def matmul_m_128(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Same computation pattern as a fixed-size matmul, but the non-reduce
    # extent is the symbolic int32 variable `m`: A and B are m x 128, C is
    # m x m, and only the reduce axis has the constant extent 128.
    m = tir.var("int32")
    A = tir.match_buffer(a, [m, 128])
    B = tir.match_buffer(b, [m, 128])
    C = tir.match_buffer(c, [m, m])

    # "update": C[vi, vj] accumulates A[vi, vk] * B[vj, vk] over vk.
    with tir.block([m, m, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
        with tir.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
Esempio n. 11
0
def matmul_not_same_buffer_access(a: ty.handle, b: ty.handle,
                                  c: ty.handle) -> None:
    # Matmul-style reduction block in which the init statement and the update
    # statement address C with different index orders.
    # NOTE(review): judging by the name, the mismatch is a deliberate negative
    # fixture -- do not "fix" it without checking the test that uses it.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, 128))
    C = tir.match_buffer(c, (128, 128))

    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
        with tir.init():
            # init writes C[vi, vj] ...
            C[vi, vj] = 0.0
        # ... while the update reads and writes the transposed C[vj, vi].
        C[vj, vi] = C[vj, vi] + A[vi, vk] * B[vk, vj]
Esempio n. 12
0
    def main(a: ty.handle, b: ty.handle) -> None:
        # Reduces the 64x64x64 buffer A into the 64-element buffer B using one
        # unnamed block with a non-reduce axis i and two reduce axes j and k.
        A = tir.match_buffer(a, [64, 64, 64])
        B = tir.match_buffer(b, [64])

        # NOTE(review): the second reduce axis is declared as
        # reduce_axis(32, 64), unlike every other reduce axis in this file,
        # which starts at 0. Presumably this is a non-zero start for k --
        # confirm the argument meaning against the TVMScript reference.
        with tir.block([64,
                        tir.reduce_axis(0, 64),
                        tir.reduce_axis(32, 64)]) as [i, j, k]:
            with tir.init():
                B[i] = tir.float32(0)
            B[i] += A[i, j, k]
Esempio n. 13
0
 def batch_matmul(  # pylint: disable=no-self-argument
         a: ty.handle, b: ty.handle, c: ty.handle) -> None:
     # Batched matrix product over three 16x128x128 buffers.
     # The function attributes set the global symbol name and tir.noalias.
     tir.func_attr({"global_symbol": "batch_matmul", "tir.noalias": True})
     A = tir.match_buffer(a, [16, 128, 128])
     B = tir.match_buffer(b, [16, 128, 128])
     C = tir.match_buffer(c, [16, 128, 128])
     # "update": for every batch index vn, C[vn, vi, vj] accumulates
     # A[vn, vi, vk] * B[vn, vj, vk] over vk (B is indexed [vn, vj, vk]).
     with tir.block([16, 128, 128, tir.reduce_axis(0, 128)],
                    "update") as [vn, vi, vj, vk]:
         with tir.init():
             C[vn, vi, vj] = 0.0
         C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk]
Esempio n. 14
0
 def matmul(  # pylint: disable=no-self-argument
         a: ty.handle, b: ty.handle, c: ty.handle) -> None:
     # 1024x1024 float32 matrix product in one reduction block.
     # The function attributes set the global symbol name and tir.noalias.
     tir.func_attr({"global_symbol": "matmul", "tir.noalias": True})
     A = tir.match_buffer(a, (1024, 1024), "float32")
     B = tir.match_buffer(b, (1024, 1024), "float32")
     C = tir.match_buffer(c, (1024, 1024), "float32")
     # "matmul": C[vi, vj] accumulates A[vi, vk] * B[vk, vj] over vk.
     with tir.block([1024, 1024, tir.reduce_axis(0, 1024)],
                    "matmul") as [vi, vj, vk]:
         with tir.init():
             C[vi, vj] = 0.0
         C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
Esempio n. 15
0
def rowsum_transformed(a: ty.handle, b: ty.handle) -> None:
    # Row sum of the 128x128 buffer A into the 128-element buffer B, driven by
    # a 32 x 128 x 4 loop nest whose middle loop carries parts of both block
    # iterators.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))

    for io, ii_ko_fused, ki in tir.grid(32, 128, 4):
        with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
            # ii_ko_fused is decoded with floordiv/floormod by 32: the
            # quotient (0..3) is the low part of vi, the remainder (0..31)
            # is the high part of vk.
            tir.bind(vi, io * 4 + tir.floordiv(ii_ko_fused, 32))
            tir.bind(vk, tir.floormod(ii_ko_fused, 32) * 4 + ki)
            with tir.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
Esempio n. 16
0
def rowsum_not_quasi_affine(a: ty.handle, b: ty.handle) -> None:
    # Row-sum block whose reduce iterator is bound to an expression that is
    # quadratic in the loop variable k.
    # NOTE(review): judging by the name, this is a deliberate negative fixture
    # for a quasi-affine binding check -- confirm against the test using it.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))

    for i, k in tir.grid(128, 16):
        with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
            tir.bind(vi, i)
            # k only ranges over 0..15; vk is bound to floordiv(k * k, 2).
            tir.bind(vk, tir.floordiv(k * k, 2))
            with tir.init():
                B[vi] = 0.0
            B[vi] = B[vi] + A[vi, vk]
Esempio n. 17
0
    def main(a: ty.handle, b: ty.handle) -> None:
        # Reduces the 64x64x64 buffer A into the 64-element buffer B, but the
        # block body accesses both through sub-buffers created with
        # tir.match_buffer on buffer regions.
        A = tir.match_buffer(a, [64, 64, 64])
        B = tir.match_buffer(b, [64])

        # NOTE(review): the second reduce axis is declared as
        # reduce_axis(32, 64), unlike every other reduce axis in this file,
        # which starts at 0. Presumably this is a non-zero start for k --
        # confirm the argument meaning against the TVMScript reference.
        with tir.block([64,
                        tir.reduce_axis(0, 64),
                        tir.reduce_axis(32, 64)]) as [i, j, k]:
            # BB is a zero-dimensional view of the single element B[i];
            # AA is a 64x64 view of the slice A[i, 0:64, 0:64].
            BB = tir.match_buffer(B[i], ())
            AA = tir.match_buffer(A[i, 0:64, 0:64], (64, 64))
            with tir.init():
                BB[()] = tir.float32(0)
            BB[()] += AA[j, k]
Esempio n. 18
0
def cuda_matmul_1(a: ty.handle, b: ty.handle, c: ty.handle) -> None:  # pylint: disable=undefined-loop-variable
    # 2048x2048 float32 matmul staged through "shared"- and "local"-scope
    # copies of A and B, with the reduction block "C" and the write-back block
    # "C_local" placed under a nest of thread-bound loops.
    A = tir.match_buffer(a, [2048, 2048], "float32")
    B = tir.match_buffer(b, [2048, 2048], "float32")
    C = tir.match_buffer(c, [2048, 2048], "float32")
    A_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = tir.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = tir.alloc_buffer([2048, 2048], "float32", scope="local")
    # Four element-wise copy blocks: A -> A_shared -> A_shared_local and
    # B -> B_shared -> B_shared_local. They sit at the function root, outside
    # the thread-bound loop nest below.
    with tir.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with tir.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with tir.block([2048, 2048], "A_shared_local") as [v0, v1]:
        A_shared_local[v0, v1] = A_shared[v0, v1]
    with tir.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    # Thread-bound loops: blockIdx.{y,x} (32 each), vthread.{y,x} (2 each),
    # threadIdx.{y,x} (8 each).
    for by in tir.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in tir.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in tir.thread_binding(0, 2, thread="vthread.y"):
                for vx in tir.thread_binding(0, 2, thread="vthread.x"):
                    for ty in tir.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in tir.thread_binding(0,
                                                     8,
                                                     thread="threadIdx.x"):
                            # Reduction loops: 256 serial x 8 unrolled steps
                            # give vk = k_0 * 8 + k_1 in 0..2047.
                            for k_0 in tir.serial(0, 256):
                                for k_1 in tir.unroll(0, 8):
                                    # 4x4 tile per (by, vy, ty) / (bx, vx, tx);
                                    # the leading loop `_` has extent 1.
                                    for _, i, j in tir.grid(1, 4, 4):
                                        with tir.block([
                                                2048, 2048,
                                                tir.reduce_axis(0, 2048)
                                        ], "C") as [vi, vj, vk]:
                                            tir.bind(
                                                vi,
                                                by * 64 + vy * 32 + ty * 4 + i)
                                            tir.bind(
                                                vj,
                                                bx * 64 + vx * 32 + tx * 4 + j)
                                            tir.bind(vk, k_0 * 8 + k_1)
                                            with tir.init():
                                                C_local[vi, vj] = 0.0
                                            # Both operands are indexed with vk
                                            # first: A_shared_local[vk, vi] and
                                            # B_shared_local[vk, vj].
                                            C_local[vi, vj] = C_local[
                                                vi, vj] + A_shared_local[
                                                    vk, vi] * B_shared_local[
                                                        vk, vj]
                            # Write-back: copies the same 4x4 tile from
                            # C_local to the output C, using the same vi/vj
                            # bindings as block "C".
                            for i, j in tir.grid(4, 4):
                                with tir.block([2048, 2048],
                                               "C_local") as [vi, vj]:
                                    tir.bind(vi,
                                             by * 64 + vy * 32 + ty * 4 + i)
                                    tir.bind(vj,
                                             bx * 64 + vx * 32 + tx * 4 + j)
                                    C[vi, vj] = C_local[vi, vj]
Esempio n. 19
0
def rowsum_not_serial(a: ty.handle, b: ty.handle) -> None:
    # Row sum of the 128x128 buffer A into B in which the loop bound to the
    # reduce iterator vk is declared with tir.parallel instead of tir.serial.
    # NOTE(review): judging by the name, this is a deliberate negative fixture
    # for a "loop must be serial" check -- confirm against the test using it.
    A = tir.match_buffer(a, (128, 128))
    B = tir.match_buffer(b, (128, ))

    for i in tir.serial(0, 128):
        for k in tir.parallel(0, 128):
            with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]:
                tir.bind(vi, i)
                tir.bind(vk, k)
                with tir.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
Esempio n. 20
0
def matmul_rfactor(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # Two-stage 128x128 matmul: block "update_rf" builds four partial products
    # per output element in C_rf, and block "update" adds them up into C.
    # NOTE(review): the name suggests this is the expected result of applying
    # rfactor to the innermost split loop of `transformed_matmul` -- confirm
    # against the test that uses it.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    C_rf = tir.alloc_buffer([4, 128, 128])

    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(
            128, 128, 4, 8, 4):
        # "update_rf": vi2_inner_inner is a non-reduce axis here and selects
        # the C_rf slice; vi2_outer and vi2_inner_outer are the reduce axes.
        with tir.block(
            [4, 128, 128,
             tir.reduce_axis(0, 4),
             tir.reduce_axis(0, 8)], "update_rf") as [
                 vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer
             ]:
            tir.bind(vi2_inner_inner, i2_inner_inner)
            tir.bind(vi, i0)
            tir.bind(vj, i1)
            tir.bind(vi2_outer, i2_outer)
            tir.bind(vi2_inner_outer, i2_inner_outer)
            with tir.init():
                C_rf[vi2_inner_inner, vi, vj] = 0.0
            # The reduction index into A and B is rebuilt as
            # vi2_outer * 32 + vi2_inner_outer * 4 + vi2_inner_inner.
            C_rf[vi2_inner_inner, vi,
                 vj] = C_rf[vi2_inner_inner, vi, vj] + (A[vi, (
                     ((vi2_outer * 32) +
                      (vi2_inner_outer * 4)) + vi2_inner_inner)] * B[vj, (
                          ((vi2_outer * 32) +
                           (vi2_inner_outer * 4)) + vi2_inner_inner)])

    for i0_1, i1_1, i2_inner_inner_1 in tir.grid(128, 128, 4):
        # "update": reduces over the leading dimension of C_rf.
        with tir.block([tir.reduce_axis(0, 4), 128, 128], "update") as [
                vi2_inner_inner_1,
                vi_1,
                vj_1,
        ]:
            tir.bind(vi2_inner_inner_1, i2_inner_inner_1)
            tir.bind(vi_1, i0_1)
            tir.bind(vj_1, i1_1)
            with tir.init():
                C[vi_1, vj_1] = 0.0
            C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1]
Esempio n. 21
0
def factorized_after_reverse_compute_at(a: ty.handle, b: ty.handle) -> None:
    # Two-stage reduction of the 16x16x16 float32 buffer A into the 16-element
    # buffer B. Both stages live under the same blockIdx.x / threadIdx.x
    # thread-bound loops, and the intermediate buffer has "local" scope.
    # NOTE(review): the name suggests this is the expected IR after a
    # reverse_compute_at schedule step -- confirm against the test using it.
    A = tir.match_buffer(a, [16, 16, 16], "float32")
    B = tir.match_buffer(b, [16], "float32")
    B_rf_local = tir.alloc_buffer([16, 16], "float32", scope="local")
    for j in tir.thread_binding(0, 16, thread="blockIdx.x"):
        for i_o in tir.thread_binding(0, 4, thread="threadIdx.x"):
            for i_i, k in tir.grid(4, 16):
                # "B_rf": B_rf_local[vi, vj] accumulates A[vj, vi, vk] over
                # vk; note the swapped vi/vj order when indexing A.
                with tir.block([16, 16, tir.reduce_axis(0, 16)],
                               "B_rf") as [vi, vj, vk]:
                    tir.bind(vi, i_o * 4 + i_i)
                    tir.bind(vj, j)
                    tir.bind(vk, k)
                    with tir.init():
                        B_rf_local[vi, vj] = 0.0
                    B_rf_local[vi, vj] = B_rf_local[vi, vj] + A[vj, vi, vk]
            for k in tir.serial(0, 4):
                # "B": here vi is bound to the blockIdx.x loop variable j and
                # the reduce iterator vk to i_o * 4 + k, so B[vi] accumulates
                # B_rf_local[vk, vi].
                with tir.block([16, tir.reduce_axis(0, 16)], "B") as [vi, vk]:
                    tir.bind(vi, j)
                    tir.bind(vk, i_o * 4 + k)
                    with tir.init():
                        B[vi] = 0.0
                    B[vi] = B[vi] + B_rf_local[vk, vi]
Esempio n. 22
0
def square_sum_rfactor(a: ty.handle, c: ty.handle) -> None:
    # Two-stage sum of squares of the 16x256x256 buffer A into the 16-element
    # buffer C, going through the 16x256 intermediate buffer C_rf.
    # NOTE(review): the name suggests this is the expected rfactor result for
    # `square_sum` -- confirm against the test that uses it.
    A = tir.match_buffer(a, [16, 256, 256])
    C = tir.match_buffer(c, [16])
    C_rf = tir.alloc_buffer([16, 256])

    for i0, i1, i2 in tir.grid(16, 256, 256):
        # "C_rf": vi2 (bound to the innermost loop i2) is a non-reduce axis
        # here; only i is reduced, giving one partial sum per (b, vi2).
        with tir.block([256, 16, tir.reduce_axis(0, 256)],
                       "C_rf") as [vi2, b, i]:
            tir.bind(vi2, i2)
            tir.bind(b, i0)
            tir.bind(i, i1)
            with tir.init():
                C_rf[b, vi2] = 0.0
            C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2])

    for i0_1, i2_1 in tir.grid(16, 256):
        # "C": reduces the 256 partial sums of each row of C_rf into C[b_1].
        with tir.block([tir.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]:
            tir.bind(vi2_1, i2_1)
            tir.bind(b_1, i0_1)
            with tir.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[b_1, vi2_1]
Esempio n. 23
0
def matmul_not_stage_pipeline(a: ty.handle, b: ty.handle,
                              d: ty.handle) -> None:
    # Reduction block "C" followed by a copy block "D". All four buffers are
    # 256x256, but block "C" only iterates over a 128x128(x128) domain while
    # block "D" reads the full 256x256 extent of C.
    # NOTE(review): judging by the name, the extent mismatch is a deliberate
    # negative fixture for a stage-pipeline check -- do not "fix" it without
    # checking the test that uses it.
    A = tir.match_buffer(a, [256, 256])
    B = tir.match_buffer(b, [256, 256])
    D = tir.match_buffer(d, [256, 256])
    C = tir.alloc_buffer([256, 256])

    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
        with tir.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    with tir.block([256, 256], "D") as [vi, vj]:
        D[vi, vj] = C[vi, vj]
Esempio n. 24
0
def square_sum_square_root_rfactor(a: ty.handle, d: ty.handle) -> None:
    # Three stages: "C_rf" sums squares of A into the 1x16 buffer C_rf, "C"
    # reduces C_rf over its extent-1 leading dimension into C, and "D" takes
    # the element-wise square root of C.
    # NOTE(review): the name suggests this is the expected rfactor result for
    # `transformed_square_sum_square_root` -- confirm against the test.
    A = tir.match_buffer(a, [16, 256, 256])
    D = tir.match_buffer(d, [16])
    C = tir.alloc_buffer([16])
    C_rf = tir.alloc_buffer([1, 16])

    # The middle loop has extent 65536 = 256 * 256 and is decoded below into
    # the two reduce iterators i and j; the innermost loop has extent 1.
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1):
        with tir.block(
            [1, 16, tir.reduce_axis(0, 256),
             tir.reduce_axis(0, 256)], "C_rf") as [
                 vi1_i2_fused_inner,
                 b,
                 i,
                 j,
             ]:
            tir.bind(vi1_i2_fused_inner, i1_i2_fused_inner)
            tir.bind(b, i0)
            tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256))
            tir.bind(j, tir.floormod(i1_i2_fused_outer, 256))
            with tir.init():
                C_rf[vi1_i2_fused_inner, b] = 0.0
            C_rf[vi1_i2_fused_inner,
                 b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j])

    for i0_1, i1_i2_fused_inner_1 in tir.grid(16, 1):
        # "C": the reduce axis has extent 1, matching C_rf's first dimension.
        with tir.block([tir.reduce_axis(0, 1), 16],
                       "C") as [vi1_i2_fused_inner_1, b_1]:
            tir.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1)
            tir.bind(b_1, i0_1)
            with tir.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]

    for i0_2 in tir.serial(0, 16):
        # "D": element-wise float32 square root of C.
        with tir.block([16], "D") as [b_2]:
            tir.bind(b_2, i0_2)
            D[b_2] = tir.sqrt(C[b_2], dtype="float32")
Esempio n. 25
0
def matmul_loop_multiple_children(a: ty.handle, b: ty.handle, c: ty.handle,
                                  d: ty.handle) -> None:
    # Two matmul reduction blocks, "C" and "D", placed side by side as
    # children of the same (k, i, j) loop nest, with the reduction loop k
    # outermost.
    # NOTE(review): judging by the name, having two blocks under one loop is
    # the property being tested -- confirm against the test that uses it.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    D = tir.match_buffer(d, [128, 128])

    for k, i, j in tir.grid(128, 128, 128):
        # "C": C[ci, cj] accumulates A[ci, ck] * B[ck, cj].
        with tir.block([tir.reduce_axis(0, 128), 128, 128],
                       "C") as [ck, ci, cj]:
            tir.bind(ck, k)
            tir.bind(ci, i)
            tir.bind(cj, j)
            with tir.init():
                C[ci, cj] = 0.0
            C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj]
        # "D": operands swapped -- D[di, dj] accumulates B[di, dk] * A[dk, dj].
        with tir.block([tir.reduce_axis(0, 128), 128, 128],
                       "D") as [dk, di, dj]:
            tir.bind(dk, k)
            tir.bind(di, i)
            tir.bind(dj, j)
            with tir.init():
                D[di, dj] = 0.0
            D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj]
Esempio n. 26
0
def buffer_load_store_func(a: ty.handle, b: ty.handle) -> None:
    # Exercises buffer loads and stores inside blocks: one block that fills A
    # with zeros and one 32x32x32 reduction block that works on 4x4 tiles of
    # B through explicit inner loops.
    # NOTE(review): C and D are allocated here and read below, but nothing in
    # this function writes to them; presumably only the access pattern matters
    # for the test -- confirm.
    A = tir.match_buffer(a, (128, 128), "float32")
    B = tir.match_buffer(b, (128, 128), "float32")
    C = tir.alloc_buffer((128, 128), "float32")
    D = tir.alloc_buffer((128, 128), "float32")
    with tir.block([128, 128]) as [i, j]:
        A[i, j] = tir.float32(0)
    with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]:
        # init: copies the 4x4 tile of A at (i * 4, j * 4) into B.
        with tir.init():
            for ii, jj in tir.grid(4, 4):
                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
        # update: for each tile element, two consecutive kk loops add first
        # C[...] and then D[...] * C[...] over the 4-wide slice k * 4 + kk.
        for ii, jj in tir.grid(4, 4):
            for kk in range(0, 4):
                B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk]
            for kk in range(0, 4):
                B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk]
Esempio n. 27
0
def transformed_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    # 128x128 matmul whose reduction loop is split three ways (4 x 8 x 4) and
    # whose block carries explicit read/write region annotations.
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])

    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(
            128, 128, 4, 8, 4):
        with tir.block([128, 128, tir.reduce_axis(0, 128)],
                       "update") as [vi, vj, vk]:
            tir.bind(vi, i0)
            tir.bind(vj, i1)
            # vk is rebuilt from the three split loops:
            # i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner.
            tir.bind(vk, (((i2_outer * 32) +
                           (i2_inner_outer * 4)) + i2_inner_inner))
            # Declared regions: C[vi, vj] appears as both a read and a write.
            tir.reads([C[vi, vj], A[vi, vk], B[vj, vk]])
            tir.writes([C[vi, vj]])
            with tir.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
Esempio n. 28
0
def tir_conv2d(a: ty.handle, w: ty.handle, b: ty.handle) -> None:
    # 3x3 convolution expressed as two blocks: "Apad" builds a zero-padded
    # 16x16 copy of each 14x14 plane of A, and "B" accumulates products of
    # Apad and W over the reduce axes rc (16), ry (3) and rx (3).
    A = tir.match_buffer(a, [16, 16, 14, 14])
    W = tir.match_buffer(w, [16, 3, 3, 32])
    B = tir.match_buffer(b, [16, 32, 14, 14])
    Apad = tir.alloc_buffer([16, 16, 16, 16])

    with tir.block([16, 16, 16, 16], "Apad") as [nn, cc, yy, xx]:
        # Positions with 1 <= yy <= 14 and 1 <= xx <= 14 take
        # A[nn, cc, yy - 1, xx - 1]; the one-element border gets 0.0.
        Apad[nn, cc, yy, xx] = tir.if_then_else(
            yy >= 1 and yy - 1 < 14 and xx >= 1 and xx - 1 < 14,
            A[nn, cc, yy - 1, xx - 1],
            0.0,
            dtype="float32",
        )
    with tir.block(
        [16, 32, 14, 14, tir.reduce_axis(0, 16), tir.reduce_axis(0, 3), tir.reduce_axis(0, 3)], "B"
    ) as [nn, ff, yy, xx, rc, ry, rx]:
        with tir.init():
            B[nn, ff, yy, xx] = 0.0
        # The window is addressed as Apad[.., yy + ry, xx + rx], so the padded
        # extent 16 covers 14 output positions plus the 3-wide kernel.
        B[nn, ff, yy, xx] += Apad[nn, rc, yy + ry, xx + rx] * W[rc, ry, rx, ff]
Esempio n. 29
0
def transformed_square_sum_square_root(a: ty.handle, d: ty.handle) -> None:
    # Two stages with explicit read/write region annotations: "C" sums the
    # squares of each 256x256 slice of A into C, and "D" takes the
    # element-wise square root of C.
    A = tir.match_buffer(a, [16, 256, 256])
    D = tir.match_buffer(d, [16])
    C = tir.alloc_buffer([16])

    # The middle loop has extent 65536 = 256 * 256 and is decoded below into
    # the two reduce iterators i and j; the innermost loop has extent 1 and
    # is not used in any binding.
    for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1):
        with tir.block(
            [16, tir.reduce_axis(0, 256),
             tir.reduce_axis(0, 256)], "C") as [b, i, j]:
            tir.bind(b, i0)
            tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256))
            tir.bind(j, tir.floormod(i1_i2_fused_outer, 256))
            tir.reads([C[b], A[b, i, j]])
            tir.writes([C[b]])
            with tir.init():
                C[b] = 0.0
            C[b] = C[b] + (A[b, i, j] * A[b, i, j])
    for i0_1 in tir.serial(0, 16):
        # "D": element-wise float32 square root of C.
        with tir.block([16], "D") as [b_1]:
            tir.bind(b_1, i0_1)
            tir.reads([C[b_1]])
            tir.writes([D[b_1]])
            D[b_1] = tir.sqrt(C[b_1], dtype="float32")
Esempio n. 30
0
def duplicate_init() -> None:
    # A 16x16 block that contains two `tir.init()` sections.
    # NOTE(review): the trailing "# error" marker on the second init suggests
    # this is a negative fixture whose second init is expected to be rejected
    # -- keep the duplication and confirm against the test that uses it.
    with tir.block([16, 16]) as [vi, vj]:
        with tir.init():
            tir.evaluate(1.0)
        with tir.init():  # error
            tir.evaluate(1.0)