def transformed_three_stage_compute(A: T.Buffer[(16, 16), "float32"],
                                    D: T.Buffer[(16, 16), "float32"]) -> None:
    # TVMScript fixture.  Tracing the three inner blocks below by hand, the
    # net effect per row `tx` is
    #     D[tx, j] = (A[tx, j] * 2 + 2) + 1      for j in 0..15
    # computed through two intermediate buffers B and C.
    #
    # NOTE(review): the prologue / steady-state / epilogue shape and the name
    # suggest this is the expected output of a software-pipelining pass for a
    # three-stage loop -- confirm against the test that consumes it.
    for tx in T.thread_binding(16, thread="threadIdx.x"):
        with T.block():
            T.reads(A[tx, 0:16])
            T.writes(D[tx, 0:16])
            # B and C each have two slots along their first axis; every access
            # below picks a slot with `i`, `i % 2` or `(i + 1) % 2`.
            B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
            C = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
            # --- First block: fills B from A[tx, 0] and A[tx, 1]; the C
            # update only runs for i == 1 (see the T.where guard).
            with T.block():
                T.reads(A[tx, 0:2], B[0:2, tx, 0])
                T.writes(B[0:2, tx, 0], C[0:2, tx, 0])
                for i in T.unroll(2):
                    with T.block():
                        T.reads(A[tx, i])
                        T.writes(B[0:2, tx, 0])
                        B[i, tx, 0] = A[tx, i] * T.float32(2)
                    with T.block():
                        # Skipped when i == 0.
                        T.where(1 <= i)
                        T.reads(B[0:2, tx, 0])
                        T.writes(C[0:2, tx, 0])
                        C[(i + 1) % 2, tx,
                          0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
            # --- Second block: 14 iterations, each running all three
            # statements.  In iteration i the B store uses column i + 2 of A,
            # the C store consumes the B slot written one iteration earlier,
            # and the D store consumes the C slot written one iteration
            # earlier, producing D[tx, i].
            with T.block():
                T.reads(A[tx, 2:16], B[0:2, tx, 0], C[0:2, tx, 0])
                T.writes(B[0:2, tx, 0], C[0:2, tx, 0], D[tx, 0:14])
                for i in T.serial(14):
                    with T.block():
                        T.reads(A[tx, i + 2])
                        T.writes(B[0:2, tx, 0])
                        B[i % 2, tx, 0] = A[tx, i + 2] * T.float32(2)
                    with T.block():
                        T.reads(B[0:2, tx, 0])
                        T.writes(C[0:2, tx, 0])
                        C[(i + 1) % 2, tx,
                          0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
                    with T.block():
                        T.reads(C[0:2, tx, 0])
                        T.writes(D[tx, i])
                        D[tx, i] = C[i % 2, tx, 0] + T.float32(1)
            # --- Third block: no more loads from A.  The remaining C update
            # runs only for i == 0; both iterations store to D, producing
            # D[tx, 14] and D[tx, 15].
            with T.block():
                T.reads(B[0:2, tx, 0], C[0:2, tx, 0])
                T.writes(C[0:2, tx, 0], D[tx, 14:16])
                for i in T.unroll(2):
                    with T.block():
                        # Skipped when i == 1.
                        T.where(i < 1)
                        T.reads(B[0:2, tx, 0])
                        T.writes(C[0:2, tx, 0])
                        C[(i + 1) % 2, tx,
                          0] = B[(i + 1) % 2, tx, 0] + T.float32(2)
                    with T.block():
                        T.reads(C[0:2, tx, 0])
                        T.writes(D[tx, i + 14])
                        D[tx, i + 14] = C[i, tx, 0] + T.float32(1)
def cuda_matmul_2(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
    # TVMScript fixture: 2048x2048 float32 reduction
    #     C[vi, vj] = sum over vk of A[vk, vi] * B[vk, vj]
    # (the reduction index is the FIRST index of both operands), staged
    # through "shared"- and "local"-scoped copies of the inputs.
    #
    # In this variant only the A_shared -> A_shared_local copy and the
    # reduction sit inside the thread-bound loop nest; the other three copies
    # are whole-buffer blocks at the top level.
    # NOTE(review): this looks like an intermediate scheduling state used by a
    # schedule-primitive test -- confirm against the test that consumes it.
    A = T.match_buffer(a, [2048, 2048], "float32")
    B = T.match_buffer(b, [2048, 2048], "float32")
    C = T.match_buffer(c, [2048, 2048], "float32")
    A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    # Element-wise copies over the full 2048x2048 domain.
    with T.block([2048, 2048], "A_shared") as [v0, v1]:
        A_shared[v0, v1] = A[v0, v1]
    with T.block([2048, 2048], "B_shared") as [v0, v1]:
        B_shared[v0, v1] = B[v0, v1]
    with T.block([2048, 2048], "B_shared_local") as [v0, v1]:
        B_shared_local[v0, v1] = B_shared[v0, v1]
    # Thread-bound nest: 32x32 blockIdx, 2x2 vthread, 8x8 threadIdx.  The
    # bindings below combine them as  by*64 + vy*32 + ty*4 + {0..3}  (and the
    # same with bx/vx/tx), i.e. a 4x4 patch of C per innermost (ty, tx) pair.
    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread="vthread.y"):
                for vx in T.thread_binding(0, 2, thread="vthread.x"):
                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
                            # Reduction axis split as k = k_0 * 8 + k_1
                            # (256 serial steps x 8 unrolled steps = 2048).
                            for k_0 in T.serial(0, 256):
                                for k_1 in T.unroll(0, 8):
                                    # Copy the 1x4 strip of A_shared needed for
                                    # this k and this patch's rows.
                                    for i, j in T.grid(1, 4):
                                        with T.block(
                                            [2048, 2048],
                                                "A_shared_local") as [v0, v1]:
                                            T.bind(v0, k_0 * 8 + k_1 + i)
                                            T.bind(
                                                v1,
                                                by * 64 + vy * 32 + ty * 4 + j)
                                            A_shared_local[v0,
                                                           v1] = A_shared[v0,
                                                                          v1]
                                    # Accumulate one k-slice into the 4x4
                                    # patch.  The leading extent-1 loop
                                    # variable `_` is not used in the bindings.
                                    for _, i, j in T.grid(1, 4, 4):
                                        with T.block([
                                                2048, 2048,
                                                T.reduce_axis(0, 2048)
                                        ], "C") as [vi, vj, vk]:
                                            T.bind(
                                                vi,
                                                by * 64 + vy * 32 + ty * 4 + i)
                                            T.bind(
                                                vj,
                                                bx * 64 + vx * 32 + tx * 4 + j)
                                            T.bind(vk, k_0 * 8 + k_1)
                                            with T.init():
                                                C_local[vi, vj] = T.float32(0)
                                            C_local[vi, vj] = C_local[
                                                vi, vj] + A_shared_local[
                                                    vk, vi] * B_shared_local[
                                                        vk, vj]
                            # After the full reduction, write the 4x4 patch
                            # from C_local back to the output buffer C.
                            for i, j in T.grid(4, 4):
                                with T.block([2048, 2048],
                                             "C_local") as [v0, v1]:
                                    T.bind(v0, by * 64 + vy * 32 + ty * 4 + i)
                                    T.bind(v1, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[v0, v1] = C_local[v0, v1]
def loop_syntax_sugar(a: T.handle) -> None:
    # TVMScript fixture that nests one loop of each kind using the
    # single-argument (extent-only) spelling: serial, parallel, vectorized,
    # unroll, and thread_binding.  thread_binding appears twice, once with the
    # thread tag passed positionally and once with the `thread=` keyword.
    # NOTE(review): presumably exists to exercise the parser's shorthand loop
    # forms rather than to be a meaningful kernel -- confirm with the test.
    A = T.match_buffer(a, (128, 128, 128, 128))
    for i in T.serial(128):
        for j in T.parallel(128):
            for k in T.vectorized(128):
                for x in T.unroll(128):
                    # Both `y` and `z` are bound to "threadIdx.x"; neither is
                    # used in the loop body below.
                    for y in T.thread_binding(128, "threadIdx.x"):
                        for z in T.thread_binding(128, thread="threadIdx.x"):
                            A[i, j, k, x] = A[i, j, k, x] * 2.0
Beispiel #4
0
def rowsum_unrolled(a: T.handle, b: T.handle) -> None:
    # TVMScript fixture: row-wise sum, B[vi] = sum over vk of A[vi, vk], for a
    # 128x128 input.  The outer (row) loop is declared with T.unroll and the
    # inner (reduction) loop with T.serial.
    # NOTE(review): a second definition with this same name appears later in
    # this file using the older block syntax; if both end up in one module the
    # later definition replaces this one.
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128,))
    for i0 in T.unroll(0, 128):
        for i1 in T.serial(0, 128):
            with T.block("B"):
                # "SR": first axis spatial (bound to i0), second axis
                # reduction (bound to i1).
                vi, vk = T.axis.remap("SR", [i0, i1])
                with T.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
def rowsum_unrolled(a: T.handle, b: T.handle) -> None:
    # TVMScript fixture: row-wise sum, B[vi] = sum over vk of A[vi, vk], for a
    # 128x128 input.  The outer (row) loop is declared with T.unroll and the
    # inner (reduction) loop with T.serial.
    # This version declares the block axes in the T.block(...) header and
    # binds them with explicit T.bind calls.
    # NOTE(review): an earlier definition with this same name appears in this
    # file; if both end up in one module this one replaces it.
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128,))
    for i0 in T.unroll(0, 128):
        for i1 in T.serial(0, 128):
            # vi: axis of extent 128; vk: reduction axis over [0, 128).
            with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
                T.bind(vi, i0)
                T.bind(vk, i1)
                with T.init():
                    B[vi] = 0.0
                B[vi] = B[vi] + A[vi, vk]
Beispiel #6
0
def after_unrolled_loop(
    placeholder: T.Buffer[(1, 56, 56, 64), "float32"], ) -> None:
    # TVMScript fixture: reduces bgemm[r_a, r_b, p, co] over r_a, r_b (6x6)
    # into inverse[vh, vw, p, co] for every vh, vw (4x4), with the (p, co)
    # plane distributed over blockIdx.x (13) x threadIdx.x (1024).
    # `placeholder` is not referenced in the body, and `bgemm` is only read
    # (never written) in the code shown here.
    # NOTE(review): the name suggests this is the expected result of a
    # transformation applied to a matching "before" fixture -- confirm.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    bgemm = T.alloc_buffer([6, 6, 196, 64], dtype="float32")
    inverse = T.alloc_buffer([4, 4, 196, 64], dtype="float32")
    for i2_0_i3_0_i2_1_i3_1_fused_0 in T.thread_binding(13,
                                                        thread="blockIdx.x"):
        for i2_0_i3_0_i2_1_i3_1_fused_1 in T.thread_binding(
                1024, thread="threadIdx.x"):
            for i0 in T.unroll(4):
                for i1 in T.unroll(4):
                    for i4 in T.unroll(6):
                        for i5 in T.unroll(6):
                            with T.block("inverse"):
                                vh, vw = T.axis.remap("SS", [i0, i1])
                                # Let f = fused_0 * 1024 + fused_1 (the flat
                                # thread index).  The expressions below split
                                # f as a mixed-radix number with radices
                                # (98, 4, 2, 16):
                                #   f // 128, f % 128 // 32, f % 32 // 16, f % 16
                                # p  = (f // 128) * 2 + (f % 32 // 16)
                                p = T.axis.spatial(
                                    196,
                                    (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 +
                                     i2_0_i3_0_i2_1_i3_1_fused_1) // 128 * 2 +
                                    (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 +
                                     i2_0_i3_0_i2_1_i3_1_fused_1) % 32 // 16,
                                )
                                # co = (f % 128 // 32) * 16 + (f % 16)
                                co = T.axis.spatial(
                                    64,
                                    (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 +
                                     i2_0_i3_0_i2_1_i3_1_fused_1) % 128 // 32 *
                                    16 + (i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 +
                                          i2_0_i3_0_i2_1_i3_1_fused_1) % 16,
                                )
                                r_a, r_b = T.axis.remap("RR", [i4, i5])
                                # 13 * 1024 = 13312 flat indices are launched
                                # but only 196 * 64 = 12544 (p, co) pairs
                                # exist, so the excess indices are masked off.
                                T.where(i2_0_i3_0_i2_1_i3_1_fused_0 * 1024 +
                                        i2_0_i3_0_i2_1_i3_1_fused_1 < 12544)
                                T.reads(bgemm[r_a, r_b, p, co])
                                T.writes(inverse[vh, vw, p, co])
                                with T.init():
                                    inverse[vh, vw, p, co] = T.float32(0)
                                inverse[vh, vw, p,
                                        co] = (inverse[vh, vw, p, co] +
                                               bgemm[r_a, r_b, p, co])
Beispiel #7
0
def cuda_matmul_5(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
    # TVMScript fixture: 2048x2048 float32 reduction
    #     C[vi, vj] = sum over vk of A[vk, vi] * B[vk, vj]
    # (the reduction index is the FIRST index of both operands), staged
    # through "shared"- and "local"-scoped copies of the inputs.
    #
    # In this variant all four staging copies sit inside the thread-bound
    # loop nest: the shared copies under the k0 loop, the local copies under
    # the unrolled k1 loop.
    # NOTE(review): this looks like a scheduling state used by a
    # schedule-primitive test -- confirm against the test that consumes it.
    A = T.match_buffer(a, [2048, 2048], "float32")
    B = T.match_buffer(b, [2048, 2048], "float32")
    C = T.match_buffer(c, [2048, 2048], "float32")
    A_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    B_shared = T.alloc_buffer([2048, 2048], "float32", scope="shared")
    A_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    B_shared_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    C_local = T.alloc_buffer([2048, 2048], "float32", scope="local")
    # Thread-bound nest: 32x32 blockIdx, 2x2 vthread, 8x8 threadIdx.  The
    # axis bindings below combine them as  by*64 + vy*32 + ty*4 + {0..3}
    # (and the same with bx/vx/tx), i.e. a 4x4 patch of C per (ty, tx) pair.
    for by in T.thread_binding(0, 32, thread = "blockIdx.y"):
        for bx in T.thread_binding(0, 32, thread = "blockIdx.x"):
            for vy in T.thread_binding(0, 2, thread = "vthread.y"):
                for vx in T.thread_binding(0, 2, thread = "vthread.x"):
                    for ty in T.thread_binding(0, 8, thread = "threadIdx.y"):
                        for tx in T.thread_binding(0, 8, thread = "threadIdx.x"):
                            # Reduction axis split as k = k0 * 8 + k1.
                            for k0 in T.serial(0, 256):
                                # Copy the 8x64 tile of A addressed by
                                # (k0, by) into A_shared.
                                for i, j in T.grid(8, 64):
                                    with T.block("A_shared"):
                                        v0 = T.axis.S(2048, k0 * 8 + i)
                                        v1 = T.axis.S(2048, by * 64 + j)
                                        A_shared[v0, v1] = A[v0, v1]
                                # Copy the 8x64 tile of B addressed by
                                # (k0, bx) into B_shared.
                                for i, j in T.grid(8, 64):
                                    with T.block("B_shared"):
                                        v0 = T.axis.S(2048, k0 * 8 + i)
                                        v1 = T.axis.S(2048, bx * 64 + j)
                                        B_shared[v0, v1] = B[v0, v1]
                                for k1 in T.unroll(0, 8):
                                    # 1x4 strips of the shared tiles for the
                                    # current k and this patch's rows/columns.
                                    for i, j in T.grid(1, 4):
                                        with T.block("A_shared_local"):
                                            v0 = T.axis.S(2048, k0 * 8 + k1 + i)
                                            v1 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + j)
                                            A_shared_local[v0, v1] = A_shared[v0, v1]
                                    for i, j in T.grid(1, 4):
                                        with T.block("B_shared_local"):
                                            v0 = T.axis.S(2048, k0 * 8 + k1 + i)
                                            v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                            B_shared_local[v0, v1] = B_shared[v0, v1]
                                    # Accumulate one k-slice into the 4x4
                                    # patch.  The leading extent-1 loop
                                    # variable `_` is not used in the bindings.
                                    for _, i, j in T.grid(1, 4, 4):
                                        with T.block("C"):
                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                            vk = T.axis.R(2048, k0 * 8 + k1)
                                            with T.init():
                                                C_local[vi, vj] = 0.0
                                            C_local[vi, vj] = C_local[vi, vj] + A_shared_local[vk, vi] * B_shared_local[vk, vj]
                            # After the full reduction, write the 4x4 patch
                            # from C_local back to the output buffer C.
                            for i, j in T.grid(4, 4):
                                with T.block("C_local"):
                                    v0 = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
                                    v1 = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
                                    C[v0, v1] = C_local[v0, v1]
Beispiel #8
0
def before_unrolled_loop(
    placeholder: T.Buffer[(1, 56, 56, 64), "float32"], ) -> None:
    # TVMScript fixture: reduces bgemm[r_a, r_b, p, co] over r_a, r_b (6x6)
    # into inverse[vh, vw, p, co] for every vh, vw (4x4), iterating the
    # (p, co) plane with a plain 4-level T.grid; the 4/4/6/6 loops are
    # declared with T.unroll.
    # `placeholder` is not referenced in the body, and `bgemm` is only read
    # (never written) in the code shown here.
    # NOTE(review): the name suggests this is the input to a transformation
    # whose expected result is a matching "after" fixture -- confirm.
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    bgemm = T.alloc_buffer([6, 6, 196, 64], dtype="float32")
    inverse = T.alloc_buffer([4, 4, 196, 64], dtype="float32")
    # (i2_0, i2_1) tile the 196-wide p axis as 98 x 2;
    # (i3_0, i3_1) tile the 64-wide co axis as 4 x 16.
    for i2_0, i3_0, i2_1, i3_1 in T.grid(98, 4, 2, 16):
        for i0 in T.unroll(4):
            for i1 in T.unroll(4):
                for i4 in T.unroll(6):
                    for i5 in T.unroll(6):
                        with T.block("inverse"):
                            vh, vw = T.axis.remap("SS", [i0, i1])
                            p = T.axis.spatial(196, i2_0 * 2 + i2_1)
                            co = T.axis.spatial(64, i3_0 * 16 + i3_1)
                            r_a, r_b = T.axis.remap("RR", [i4, i5])
                            T.reads(bgemm[r_a, r_b, p, co])
                            T.writes(inverse[vh, vw, p, co])
                            with T.init():
                                inverse[vh, vw, p, co] = T.float32(0)
                            inverse[vh, vw, p,
                                    co] = inverse[vh, vw, p,
                                                  co] + bgemm[r_a, r_b, p, co]
 def main():
     # TVMScript fixture: a 2-iteration unrolled loop whose body allocates a
     # 16-element float32 buffer in "global" scope and stores 0.0 to its
     # first element.  The loop variable `i` is not used in the body.
     # NOTE(review): presumably a minimal input for a pass that handles
     # allocations inside unrolled loops -- confirm with the consuming test.
     for i in T.unroll(2):
         with T.allocate([16], "float32", "global") as buf:
             buf[0] = 0.0