Example 1
def after_matmul_vectorize(
    placeholder: T.Buffer[(64, 768), "float32"],
    placeholder_1: T.Buffer[(768, 768), "float32"],
    T_matmul_NT: T.Buffer[(64, 768), "float32"],
) -> None:
    T.func_attr({
        "global_symbol": "main",
        "tir.noalias": True,
        "layout_free_placeholders": [1]
    })
    T_matmul_NT_global = T.alloc_buffer([64, 768], dtype="float32")
    for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 16, 1, 3):
        for i2_0, i0_2, i1_2, i2_1, i0_3 in T.grid(48, 8, 1, 16, 8):
            for i1_3_fused in T.vectorized(16):
                with T.block("T_matmul_NT"):
                    i = T.axis.spatial(64, i0_2 * 8 + i0_3)
                    j = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + i1_3_fused)
                    k = T.axis.reduce(768, i2_0 * 16 + i2_1)
                    T.reads(placeholder[i, k], placeholder_1[j, k])
                    T.writes(T_matmul_NT_global[i, j])
                    with T.init():
                        T_matmul_NT_global[i, j] = T.float32(0)
                    T_matmul_NT_global[i, j] = T_matmul_NT_global[
                        i, j] + placeholder[i, k] * placeholder_1[j, k]
        for ax0 in T.serial(64):
            for ax1_fused in T.vectorized(16):
                with T.block("T_matmul_NT_global"):
                    v0 = T.axis.spatial(64, ax0)
                    v1 = T.axis.spatial(768, i1_0 * 48 + i1_1 * 16 + ax1_fused)
                    T.reads(T_matmul_NT_global[v0, v1])
                    T.writes(T_matmul_NT[v0, v1])
                    T_matmul_NT[v0, v1] = T_matmul_NT_global[v0, v1]
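The split-and-vectorize structure above (an inner T.vectorized(16) lane loop plus a vectorized write-back from the `_global` cache buffer) is the kind of loop nest `tvm.tir.Schedule` primitives produce. A minimal sketch, assuming a TVM version contemporary with these examples (the bracketed `T.Buffer[...]` annotation style), of obtaining such a vectorized inner loop on a toy copy workload:

import tvm
from tvm.script import tir as T

@tvm.script.ir_module
class Copy:  # hypothetical toy workload, not the matmul above
    @T.prim_func
    def main(A: T.Buffer[(64, 768), "float32"], B: T.Buffer[(64, 768), "float32"]) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i, j in T.grid(64, 768):
            with T.block("copy"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj]

sch = tvm.tir.Schedule(Copy)
i, j = sch.get_loops(sch.get_block("copy"))
_, j_inner = sch.split(j, factors=[None, 16])  # peel off a 16-wide lane loop
sch.vectorize(j_inner)                         # mark it T.vectorized(16)
print(sch.mod.script())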
Example 2
def gemm_mma_m8n8k4_row_row_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle):
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 4], dtype="float16")
    B = T.match_buffer(b, [4, 16], dtype="float16")
    C = T.match_buffer(c, [16, 16], dtype="float32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([4], "float16", scope="local")
    MultiB = T.allocate([4], "float16", scope="local")
    Accum = T.allocate([8], "float32", scope="local")
    for i in range(8):
        Accum[i] = T.float32(0)

    for mma_multi_a_col in T.vectorized(4):
        MultiA[mma_multi_a_col] = A[
            ((tx % 32) % 4) + (4 * ((((tx % 32) // 16 + (tx % 32) % 16 // 4 * 2)) % 4)),
            mma_multi_a_col,
        ]
    for mma_multi_b_col in T.vectorized(4):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) % 4,
            mma_multi_b_col + (4 * ((tx % 32) // 8)),
        ]
    T.evaluate(
        T.ptx_mma(
            "m8n8k4",
            "row",
            "row",
            "fp16",
            "fp16",
            "fp32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="float32",
        )
    )
    for mma_accum_c_id in range(8):
        C[
            ((tx % 32) % 2)
            + ((mma_accum_c_id // 2 % 2) * 2)
            + 4 * ((tx % 32) // 16)
            + ((tx % 32) % 16 // 4) % 2 * 8,
            (tx % 32) % 4 // 2 * 2
            + (tx % 32) % 16 // 8 * 4
            + mma_accum_c_id % 2
            + mma_accum_c_id // 4 * 8,
        ] = T.load("float32", Accum, mma_accum_c_id)
Example 3
def colsum_decompose_with_vectorization(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 32], dtype="float32")
    B = T.match_buffer(b, [32], dtype="float32")
    for i in T.vectorized(0, 32):
        with T.block("B_init"):
            vi = T.axis.S(32, i)
            B[vi] = T.float32(0)
    for k in T.serial(0, 128):
        for i in T.vectorized(0, 32):
            with T.block("B"):
                vk, vi = T.axis.remap("RS", [k, i])
                B[vi] = B[vi] + A[vk, vi]
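This decomposed form (a vectorized "B_init" block followed by the reduction proper) is what `decompose_reduction` yields. A minimal sketch under the same version assumption, starting from the fused reduction block:

import tvm
from tvm.script import tir as T

@tvm.script.ir_module
class ColSum:
    @T.prim_func
    def main(A: T.Buffer[(128, 32), "float32"], B: T.Buffer[(32,), "float32"]) -> None:
        for k, i in T.grid(128, 32):
            with T.block("B"):
                vk, vi = T.axis.remap("RS", [k, i])
                with T.init():
                    B[vi] = T.float32(0)
                B[vi] = B[vi] + A[vk, vi]

sch = tvm.tir.Schedule(ColSum)
k, i = sch.get_loops(sch.get_block("B"))
sch.vectorize(i)                                # vectorize the spatial lane loop
sch.decompose_reduction(sch.get_block("B"), k)  # split out the "B_init" block
print(sch.mod.script())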
Example 4
def ptx_global_to_shared_dyn_copy_fp16x8(
    A: T.Buffer[(32, 128), "float16"],
    B: T.Buffer[(32, 128), "float16"],
    C: T.Buffer[(32, 128), "float16"],
) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
        B_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
        T.reads(A[0:32, 0:128], B[0:32, 0:128])
        T.writes(C[0:32, 0:128])

        T.attr("default", "async_scope", 1)
        for i in T.serial(16):
            for j in T.vectorized(8):
                A_shared[tx, i * 8 + j] = A[tx, i * 8 + j]
                B_shared[tx, i * 8 + j] = B[tx, i * 8 + j]

        T.evaluate(T.ptx_commit_group(dtype=""))
        T.evaluate(T.ptx_wait_group(0, dtype=""))

        for i in range(128):
            C[tx, i] = A_shared[tx, i] + B_shared[tx, i]
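A hedged driver sketch for the function above: once it is wrapped in @T.prim_func, building for CUDA with the async-copy config enabled is what rewrites the `async_scope` attribute plus `ptx_commit_group`/`ptx_wait_group` into cp.async instructions. This requires a CUDA-enabled TVM build and an sm_80+ GPU, and the config key is assumed from the InjectPTXAsyncCopy pass:

import tvm

# `ptx_global_to_shared_dyn_copy_fp16x8` is assumed to be the @T.prim_func above
with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
    rt_mod = tvm.build(ptx_global_to_shared_dyn_copy_fp16x8, target="cuda")
print(rt_mod.imported_modules[0].get_source())  # inspect the generated CUDA source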
Example 5
def decomposed_gemm_parallelize_init(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
) -> None:
    local = T.alloc_buffer([16, 16], dtype="float32")
    for i, j in T.grid(4, 4):
        for ii in T.serial(4):
            for jj in T.vectorized(4):
                with T.block("init"):
                    vi = T.axis.spatial(16, i * 4 + ii)
                    vj = T.axis.spatial(16, j * 4 + jj)
                    T.reads()
                    T.writes(local[vi, vj])
                    local[vi, vj] = 0
        for k, ii, jj in T.grid(16, 4, 4):
            with T.block("update"):
                vi = T.axis.spatial(16, i * 4 + ii)
                vj = T.axis.spatial(16, j * 4 + jj)
                vk = T.axis.reduce(16, k)
                T.reads(local[vi, vj], A[vi, vk], B[vj, vk])
                T.writes(local[vi, vj])
                local[vi, vj] = local[vi, vj] + A[vi, vk] * B[vj, vk]
        for ii, jj in T.grid(4, 4):
            with T.block("C"):
                vi = T.axis.spatial(16, i * 4 + ii)
                vj = T.axis.spatial(16, j * 4 + jj)
                T.reads(local[vi, vj])
                T.writes(C[vi, vj])
                C[vi, vj] = local[vi, vj]
Example 6
def gemm_mma_m16n8k8_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle):
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 8], dtype="float16")
    B = T.match_buffer(b, [8, 8], dtype="float16")
    C = T.match_buffer(c, [16, 8], dtype="float32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([4], "float16", scope="local")
    MultiB = T.allocate([2], "float16", scope="local")
    Accum = T.allocate([4], "float32", scope="local")
    for i in range(4):
        Accum[i] = T.float32(0)

    for mma_multi_a_col in T.vectorized(4):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col // 2 * 8, (tx % 32) % 4 * 2 + mma_multi_a_col % 2
        ]
    # each thread holds 2 consecutive k-elements of one B column (PTX m16n8k8
    # "col" B-fragment layout), matching the 2-element MultiB buffer
    for mma_multi_b_col in T.vectorized(2):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) % 4 * 2 + mma_multi_b_col, (tx % 32) // 4
        ]
    T.evaluate(
        T.ptx_mma(
            "m16n8k8",
            "row",
            "col",
            "fp16",
            "fp16",
            "fp32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="float32",
        )
    )
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8, (tx % 32) % 4 * 2 + mma_accum_c_id % 2
        ] = T.load("float32", Accum, mma_accum_c_id)
Example 7
def loop_syntax_sugar(a: T.handle) -> None:
    A = T.match_buffer(a, (128, 128, 128, 128))
    for i in T.serial(128):
        for j in T.parallel(128):
            for k in T.vectorized(128):
                for x in T.unroll(128):
                    for y in T.thread_binding(128, "threadIdx.x"):
                        for z in T.thread_binding(128, thread="threadIdx.x"):
                            A[i, j, k, x] = A[i, j, k, x] * 2.0
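Each of these constructors is syntactic sugar for a TIR For node with a different ForKind (serial, parallel, vectorized, unrolled, thread_binding). A minimal round-trip sketch, under the same version assumption:

import tvm
from tvm.script import tir as T

@T.prim_func
def doubled(A: T.Buffer[(128,), "float32"]) -> None:
    for k in T.vectorized(128):  # sugar for a For node with ForKind.VECTORIZED
        A[k] = A[k] * 2.0

print(doubled.script())  # the printer re-emits the same T.vectorized sugar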
Example 8
def vector_func(a: T.handle, b: T.handle):
    n = T.var("int32")
    m = 128
    A = T.match_buffer(a, (n, m))
    B = T.match_buffer(b, (n, m))

    for i in T.serial(n):
        for j in T.vectorized(m):
            A[i, j] = A[i, j] + B[i, j]
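Note the contrast between the two extents: the row count n stays a symbolic runtime variable, while the fixed extent m = 128 is what makes the inner loop legal to vectorize. A hedged build sketch, assuming `vector_func` is wrapped in @T.prim_func:

import tvm

# Lowering for llvm turns the T.vectorized(128) loop into 128-wide vector ops;
# the outer loop over the symbolic n remains a scalar loop.
lib = tvm.build(vector_func, target="llvm")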
Example 9
def GmmCuda2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
    Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
    X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"):
            for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"):
                for i1_3_init, i2_4_init in T.grid(4, 2):
                    with T.block("Z_init"):
                        b = T.axis.spatial(1, 0)
                        i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init)
                        j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init)
                        T.reads()
                        T.writes(Z_local[b, i, j])
                        Z_local[b, i, j] = T.float32(0)
                for i3_0 in T.serial(4):
                    for ax0_ax1_ax2_fused_0 in T.serial(4):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            for ax0_ax1_ax2_fused_2 in T.vectorized(2):
                                with T.block("X_shared"):
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32)
                                    v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32)
                                    T.reads(X[v0, v1, v2])
                                    T.writes(X_shared[v0, v1, v2])
                                    X_shared[v0, v1, v2] = X[v0, v1, v2]
                    for ax0_ax1_ax2_fused_0 in T.serial(8):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            with T.block("Y_shared"):
                                v0 = T.axis.spatial(1, 0)
                                v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32)
                                v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32)
                                T.reads(Y[v0, v1, v2])
                                T.writes(Y_shared[v0, v1, v2])
                                Y_shared[v0, v1, v2] = Y[v0, v1, v2]
                    for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2):
                        with T.block("Z_update"):
                            b = T.axis.spatial(1, 0)
                            i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3)
                            j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4)
                            k = T.axis.reduce(128, i3_0 * 32 + i3_2)
                            T.block_attr({
                                "meta_schedule.thread_extent_low_inclusive": 1024,
                                "meta_schedule.thread_extent_high_inclusive": 1024,
                            })
                            T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j])
                            T.writes(Z_local[b, i, j])
                            Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
                for ax0, ax1, ax2 in T.grid(1, 4, 2):
                    with T.block("Z_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1)
                        v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2)
                        T.reads(Z_local[v0, v1, v2])
                        T.writes(Z[v0, v1, v2])
                        Z[v0, v1, v2] = Z_local[v0, v1, v2]
Example 10
def element_wise_split_predicate_vectorized(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for i in T.vectorized(0, 128):
        for j_0, j_1 in T.grid(13, 10):
            with T.block("B"):
                T.where(j_0 * 10 + j_1 < 128)
                vi = T.axis.S(128, i)
                vj = T.axis.S(128, j_0 * 10 + j_1)
                B[vi, vj] = A[vi, vj] * 2.0
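The T.where guard appears because 128 was split by a non-divisible factor: 13 * 10 = 130 > 128, so the last two iterations must be masked off. A minimal sketch of producing this form, under the same version assumption:

import tvm
from tvm.script import tir as T

@tvm.script.ir_module
class Elemwise:
    @T.prim_func
    def main(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]) -> None:
        for i, j in T.grid(128, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0

sch = tvm.tir.Schedule(Elemwise)
i, j = sch.get_loops(sch.get_block("B"))
sch.split(j, factors=[13, 10])  # 13 * 10 = 130 > 128 forces a T.where predicate
sch.vectorize(i)
print(sch.mod.script())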
Example 11
def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle):
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [8, 32], dtype="int4")
    B = T.match_buffer(b, [8, 32], dtype="uint4")
    C = T.match_buffer(c, [8, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([8], "int4", scope="local")
    MultiB = T.allocate([8], "uint4", scope="local")
    Accum = T.allocate([2], "int32", scope="local")
    for i in range(2):
        Accum[i] = T.int32(0)

    for mma_multi_a_col in T.vectorized(8):
        MultiA[mma_multi_a_col] = A[(tx % 32) // 4, mma_multi_a_col + (tx % 32) % 4 * 8]
    for mma_multi_b_col in T.vectorized(8):
        MultiB[mma_multi_b_col] = B[(tx % 32) // 4, mma_multi_b_col + (tx % 32) % 4 * 8]
    T.evaluate(
        T.ptx_mma(
            "m8n8k32",
            "row",
            "col",
            "int4",
            "uint4",
            "int32",
            MultiA,
            0,
            MultiB,
            0,
            Accum,
            0,
            False,
            dtype="int32",
        )
    )
    for mma_accum_c_id in range(2):
        C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = T.load(
            "int32", Accum, mma_accum_c_id
        )
Example 12
def element_wise_split_predicate_vectorized(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    for i in T.vectorized(0, 128):
        for j_0, j_1 in T.grid(13, 10):
            with T.block([128, 128], "B") as [vi, vj]:
                T.where(j_0 * 10 + j_1 < 128)
                T.bind(vi, i)
                T.bind(vj, j_0 * 10 + j_1)
                B[vi, vj] = A[vi, vj] * 2.0
Example 13
def gemm_mma_m16n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle):
    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
    A = T.match_buffer(a, [16, 16], dtype="int8")
    B = T.match_buffer(b, [8, 16], dtype="uint8")
    C = T.match_buffer(c, [16, 8], dtype="int32")
    brow = T.env_thread("blockIdx.y")
    bcol = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(brow, 1)
    T.launch_thread(bcol, 1)
    T.launch_thread(tx, 32)
    MultiA = T.allocate([8], "int8", scope="local")
    MultiB = T.allocate([4], "uint8", scope="local")
    Accum = T.allocate([4], "int32", scope="local")
    for i in range(4):
        Accum[i] = T.int32(0)

    for mma_multi_a_col in range(8):
        MultiA[mma_multi_a_col] = A[
            (tx % 32) // 4 + mma_multi_a_col // 4 * 8,
            (tx % 32) % 4 * 4 + mma_multi_a_col % 4,
        ]
    for mma_multi_b_col in T.vectorized(4):
        MultiB[mma_multi_b_col] = B[
            (tx % 32) // 4,
            (tx % 32) % 4 * 4 + mma_multi_b_col,
        ]
    T.evaluate(
        T.ptx_mma(
            "m16n8k16",
            "row",
            "col",
            "int8",
            "uint8",
            "int32",
            MultiA.data,
            0,
            MultiB.data,
            0,
            Accum.data,
            0,
            False,
            dtype="int32",
        )
    )
    for mma_accum_c_id in range(4):
        C[
            (tx % 32) // 4 + mma_accum_c_id // 2 * 8,
            (tx % 32) % 4 * 2 + mma_accum_c_id % 2,
        ] = Accum[mma_accum_c_id]
Example 14
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    A = T.match_buffer(var_A, [512, 512], dtype="float32")
    B = T.match_buffer(var_B, [512, 512], dtype="float32")
    C = T.match_buffer(var_C, [512, 512], dtype="float32")
    # body
    # with T.block("root")
    C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
    A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
    for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
        for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
            for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"):
                for i2_0 in T.serial(0, 1):
                    for ax0_ax1_fused_0 in T.serial(0, 32768):
                        for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"):
                            with T.block("A_shared"):
                                v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) // 512)
                                v1 = T.axis.spatial(512, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1) % 512)
                                T.reads([A[v0, v1]])
                                T.writes([A_shared[v0, v1]])
                                A_shared[v0, v1] = A[v0, v1]
                    for ax0_ax1_fused_0 in T.serial(0, 1024):
                        for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.x"):
                            for ax0_ax1_fused_2 in T.vectorized(0, 2):
                                with T.block("B_shared"):
                                    v0 = T.axis.spatial(512, (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) // 32)
                                    v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + (ax0_ax1_fused_0 * 16 + ax0_ax1_fused_1 * 2 + ax0_ax1_fused_2) % 32)
                                    T.reads([B[v0, v1]])
                                    T.writes([B_shared[v0, v1]])
                                    B_shared[v0, v1] = B[v0, v1]
                    for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2):
                        with T.block("C"):
                            i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
                            j = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4)
                            k = T.axis.reduce(512, i2_1 * 32 + i2_2)
                            T.reads([A_shared[i, k], B_shared[k, j]])
                            T.writes([C_local[i, j]])
                            with T.init():
                                C_local[i, j] = T.float32(0)
                            C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j]
                for ax0, ax1 in T.grid(32, 4):
                    with T.block("C_local"):
                        v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0)
                        v1 = T.axis.spatial(512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1)
                        T.reads([C_local[v0, v1]])
                        T.writes([C[v0, v1]])
                        C[v0, v1] = C_local[v0, v1]
Example 15
def func(A: T.Buffer[(960, 770), "float32"], B: T.Buffer[(770, 2304), "float32"], C: T.Buffer[(960, 2304), "float32"]) -> None:
    for bx in T.thread_binding(144, thread="blockIdx.x"):
        for vx in T.thread_binding(2, thread="vthread.x"):
            for tx_p in T.thread_binding(256, thread="threadIdx.x"):
                with T.block():
                    for k_0 in T.serial(193):
                        with T.block():
                            A_shared = T.alloc_buffer([960, 770], dtype="float32", scope="shared")
                            B_shared = T.alloc_buffer([770, 2304], dtype="float32", scope="shared")
                            for _u in T.serial(1):
                                for tx in T.thread_binding(256, thread="threadIdx.x"):
                                    for vec in T.vectorized(3):
                                        with T.block("A_shared"):
                                            T.where(bx // 18 * 128 + ((_u * 256 + tx) * 3 + vec) // 4 < 960 and k_0 * 4 + ((_u * 256 + tx) * 3 + vec) % 4 < 770 and (_u * 256 + tx) * 3 + vec < 512)
                                            A_shared[bx // 18 * 128 + (_u * 768 + tx * 3 + vec) // 4, k_0 * 4 + (_u * 768 + tx * 3 + vec) % 4] = A[bx // 18 * 128 + (_u * 768 + tx * 3 + vec) // 4, k_0 * 4 + (_u * 768 + tx * 3 + vec) % 4]
                            for _u in T.serial(1):
                                for tx in T.thread_binding(256, thread="threadIdx.x"):
                                    for vec in T.vectorized(4):
                                        with T.block("B_shared"):
                                            T.where(k_0 * 4 + ((_u * 256 + tx) * 4 + vec) // 128 < 770 and (_u * 256 + tx) * 4 + vec < 512)
                                            B_shared[k_0 * 4 + (_u * 1024 + tx * 4 + vec) // 128, bx % 18 * 128 + (_u * 1024 + tx * 4 + vec) % 128] = B[k_0 * 4 + (_u * 1024 + tx * 4 + vec) // 128, bx % 18 * 128 + (_u * 1024 + tx * 4 + vec) % 128]
                            for k_1, i_3, j_3, k_2, i_4, j_4 in T.grid(1, 8, 1, 4, 2, 2):
                                with T.block("update_update"):
                                    C[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] = C[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4] + A_shared[(((bx // 18 + 0) * 8 + tx_p // 32) * 8 + i_3) * 2 + i_4, (k_0 + k_1) * 4 + k_2] * B_shared[(k_0 + k_1) * 4 + k_2, ((bx % 18 * 2 + vx % 2) * 32 + tx_p % 32 + j_3) * 2 + j_4]
Example 16
def compacted_func(A: T.Buffer[(960, 770), "float32"], B: T.Buffer[(770, 2304), "float32"], C: T.Buffer[(960, 2304), "float32"]) -> None:
    for bx in T.thread_binding(144, thread="blockIdx.x"):
        for vx in T.thread_binding(2, thread="vthread.x"):
            for tx_p in T.thread_binding(256, thread="threadIdx.x"):
                with T.block():
                    for k_0 in T.serial(193):
                        with T.block():
                            A_shared = T.alloc_buffer([128, 4], dtype="float32", scope="shared")
                            B_shared = T.alloc_buffer([4, 128], dtype="float32", scope="shared")
                            for v_u in T.serial(1):
                                for tx in T.thread_binding(256, thread="threadIdx.x"):
                                    for vec in T.vectorized(3):
                                        with T.block("A_shared"):
                                            T.where(bx // 18 * 128 + (tx * 3 + vec) // 4 < 960 and k_0 * 4 + (tx * 3 + vec) % 4 < 770 and tx * 3 + vec < 512)
                                            A_shared[(tx * 3 + vec) // 4, (tx * 3 + vec) % 4] = A[bx // 18 * 128 + (tx * 3 + vec) // 4, k_0 * 4 + (tx * 3 + vec) % 4]
                            for v_u in T.serial(1):
                                for tx in T.thread_binding(256, thread="threadIdx.x"):
                                    for vec in T.vectorized(4):
                                        with T.block("B_shared"):
                                            T.where(k_0 * 4 + tx // 32 < 770 and tx * 4 + vec < 512)
                                            B_shared[tx // 32, tx % 32 * 4 + vec] = B[k_0 * 4 + tx // 32, bx % 18 * 128 + tx % 32 * 4 + vec]
                            for k_1, i_3, j_3, k_2, i_4, j_4 in T.grid(1, 8, 1, 4, 2, 2):
                                with T.block("update_update"):
                                    C[bx // 18 * 128 + tx_p // 32 * 16 + i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 + j_4] = C[bx // 18 * 128 + tx_p // 32 * 16 + i_3 * 2 + i_4, bx % 18 * 128 + vx * 64 + tx_p % 32 * 2 + j_4] + A_shared[tx_p // 32 * 16 + i_3 * 2 + i_4, k_2] * B_shared[k_2, vx * 64 + tx_p % 32 * 2 + j_4]
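Example 16 reads like the output of buffer-region compaction applied to Example 15's `func`: A_shared shrinks from [960, 770] to the [128, 4] tile each block actually touches, B_shared to [4, 128], and the store indices are rebased accordingly. A hedged sketch of applying the pass (pass name per tvm.tir.transform):

import tvm

# `func` is Example 15's function, assumed wrapped in @T.prim_func
mod = tvm.IRModule({"main": func})
compacted = tvm.tir.transform.CompactBufferAllocation()(mod)
print(compacted.script())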
Example 17
def element_wise_compute_at_split_vectorized(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    B = T.alloc_buffer((128, 128))
    for i in T.serial(0, 128):
        for j0 in T.serial(0, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j0])
                B[vi, vj] = A[vi, vj] * 2.0
        for j1o in T.serial(0, 32):
            for j1i in T.vectorized(0, 4):
                with T.block("C"):
                    vi = T.axis.S(128, i)
                    vj = T.axis.S(128, j1o * 4 + j1i)
                    C[vi, vj] = B[vi, vj] + 1.0
Example 18
def element_wise_compute_at_split_vectorized(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    C = T.match_buffer(c, (128, 128))
    B = T.alloc_buffer((128, 128))
    for i in T.serial(0, 128):
        for j0 in T.serial(0, 128):
            with T.block([128, 128], "B") as [vi, vj]:
                T.bind(vi, i)
                T.bind(vj, j0)
                B[vi, vj] = A[vi, vj] * 2.0
        for j1o in T.serial(0, 32):
            for j1i in T.vectorized(0, 4):
                with T.block([128, 128], "C") as [vi, vj]:
                    T.bind(vi, i)
                    T.bind(vj, j1o * 4 + j1i)
                    C[vi, vj] = B[vi, vj] + 1.0
Example 19
def Move_PUV0(a: T.handle, b: T.handle) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main"})
    A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32")
    B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32")
    # body
    with T.block("root"):
        for i0_j0_fused in T.parallel(0, 8192):
            for i1, j1, k0, i2, j2 in T.grid(4, 4, 64, 4, 8):
                for k1_fused in T.vectorized(0, 32):
                    with T.block("move"):
                        vi = T.axis.spatial(
                            1024, i0_j0_fused // 64 * 16 + i1 * 4 + i2)
                        vj = T.axis.spatial(
                            1024, i0_j0_fused % 64 * 32 + j1 * 8 + j2)
                        vk = T.axis.spatial(1024, k0 * 32 + k1_fused)
                        T.where(i0_j0_fused // 64 * 16 + i1 * 4 + i2 < 1024
                                and i0_j0_fused % 64 * 32 + j1 * 8 + j2 < 1024
                                and k0 * 32 + k1_fused < 1024)
                        T.reads([A[vi, vj, vk]])
                        T.writes([B[vi, vj, vk]])
                        B[vi, vj, vk] = A[vi, vj, vk]
Example 20
def ptx_global_to_shared_copy(A: T.Buffer[(32, 128), dtype],
                              B: T.Buffer[(32, 128), dtype]) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], dtype, scope="shared")
        T.reads(A[0:32, 0:128])
        T.writes(B[0:32, 0:128])

        T.attr("default", "async_scope", 1)
        for i in T.serial(num_iters):
            for j in T.vectorized(vector_size):
                A_shared[tx, i * vector_size_expr + j] = A[tx, i * vector_size_expr + j]

        T.evaluate(T.ptx_commit_group(dtype=""))
        T.evaluate(T.ptx_wait_group(0, dtype=""))

        for i in range(128):
            B[tx, i] = A_shared[tx, i]
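The free names here (dtype, num_iters, vector_size, vector_size_expr) are not defined in the snippet; they read like template parameters captured from an enclosing generator, since the TVMScript parser resolves such names at parse time. A hedged sketch of such a closure (the generator name and exact parameterization are assumptions):

import tvm
from tvm.script import tir as T

def make_global_to_shared_copy(dtype: str, vector_size: int):
    num_iters = 128 // vector_size
    vector_size_expr = tvm.runtime.convert(vector_size)

    # The @T.prim_func above would be defined here, closing over dtype,
    # num_iters, vector_size, and vector_size_expr at parse time.
    @T.prim_func
    def copy(A: T.Buffer[(32, 128), dtype], B: T.Buffer[(32, 128), dtype]) -> None:
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)
        for i in T.serial(num_iters):
            for j in T.vectorized(vector_size):
                B[tx, i * vector_size_expr + j] = A[tx, i * vector_size_expr + j]

    return copy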
Example 21
def decomposed_gemm_after_vectorize(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
):
    local = T.alloc_buffer((16, 16), "float32")
    for i, j in T.grid(4, 4):
        for ii, jj in T.grid(4, 4):
            with T.block("init"):
                vi = T.axis.S(16, i * 4 + ii)
                vj = T.axis.S(16, j * 4 + jj)
                local[vi, vj] = 0
        for k, ii, jj in T.grid(16, 4, 4):
            with T.block("update"):
                vi = T.axis.S(16, i * 4 + ii)
                vj = T.axis.S(16, j * 4 + jj)
                vk = T.axis.R(16, k)
                local[vi, vj] = local[vi, vj] + A[vi, vk] * B[vj, vk]
        for ii in range(4):
            for jj in T.vectorized(4):
                with T.block("C"):
                    vi = T.axis.S(16, i * 4 + ii)
                    vj = T.axis.S(16, j * 4 + jj)
                    C[vi, vj] = local[vi, vj]
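Note that the update block indexes B[vj, vk], so the kernel computes C = A @ B.T rather than a plain A @ B. A plain NumPy cross-check of that semantics:

import numpy as np

A = np.random.rand(16, 16).astype("float32")
B = np.random.rand(16, 16).astype("float32")
C_ref = A @ B.T  # matches local[vi, vj] = local[vi, vj] + A[vi, vk] * B[vj, vk]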
Example 22
def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle,
                    handle_c: T.handle) -> None:
    # pylint: disable=missing-function-docstring
    # match buffer
    match_buffer_a = T.match_buffer(handle_a, [1024, 1024], "float16")
    match_buffer_b = T.match_buffer(handle_b, [1024, 1024], "float16")
    match_buffer_c = T.match_buffer(handle_c, [1024, 1024], "float32")

    # body
    for block_idx_x in T.thread_binding(0, 16, "blockIdx.x"):
        for block_idx_y in T.thread_binding(0, 8, "blockIdx.y"):
            with T.block():
                axis_bx, axis_by = T.axis.remap("SS",
                                                [block_idx_x, block_idx_y])
                shared_a = T.alloc_buffer([1024, 1024],
                                          "float16",
                                          scope="shared")
                shared_b = T.alloc_buffer([1024, 1024],
                                          "float16",
                                          scope="shared")
                wmma_a = T.alloc_buffer([1024, 1024],
                                        "float16",
                                        scope="wmma.matrix_a")
                wmma_b = T.alloc_buffer([1024, 1024],
                                        "float16",
                                        scope="wmma.matrix_b")
                wmma_c = T.alloc_buffer([1024, 1024],
                                        "float32",
                                        scope="wmma.accumulator")

                # pylint: disable=too-many-nested-blocks
                for thread_ty in T.thread_binding(0, 2, "threadIdx.y"):
                    for thread_tz in T.thread_binding(0, 2, "threadIdx.z"):
                        for index_i, index_jj in T.grid(2, 4):
                            with T.block():
                                new_axis_vi = T.axis.S(
                                    64, axis_bx * 4 + thread_ty * 2 + index_i)
                                new_axis_vj = T.axis.S(
                                    64, axis_by * 8 + thread_tz * 4 + index_jj)
                                T.reads([])
                                T.writes(wmma_c[new_axis_vi *
                                                16:new_axis_vi * 16 + 16,
                                                new_axis_vj *
                                                16:new_axis_vj * 16 + 16, ])
                                match_buffer_c0 = T.match_buffer(
                                    wmma_c[new_axis_vi * 16:new_axis_vi * 16 +
                                           16, new_axis_vj *
                                           16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_fill_fragment(
                                        match_buffer_c0.data,
                                        16,
                                        16,
                                        16,
                                        index_i * 4 + index_jj,
                                        T.float32(0),  # pylint: disable=not-callable
                                        dtype="handle",
                                    ))

                        for k_o in range(0, 32):
                            # copy data from global to shared
                            for thread_tx in T.thread_binding(
                                    0, 32, "threadIdx.x"):
                                for index_i0, index_j0 in T.grid(1, 4):
                                    for index_j1 in T.vectorized(0, 4):
                                        with T.block():
                                            new_axis_vi = T.axis.S(
                                                1024,
                                                axis_bx * 64 + thread_ty * 32 +
                                                thread_tx + index_i0,
                                            )
                                            new_axis_vj = T.axis.S(
                                                1024,
                                                k_o * 32 + thread_tz * 16 +
                                                index_j0 * 4 + index_j1,
                                            )
                                            shared_a[new_axis_vi, new_axis_vj +
                                                     8] = match_buffer_a[
                                                         new_axis_vi,
                                                         new_axis_vj]

                                for index_i0, index_j0 in T.grid(2, 4):
                                    for index_j1 in T.vectorized(0, 4):
                                        with T.block():
                                            new_axis_vi = T.axis.S(
                                                1024,
                                                axis_by * 128 +
                                                thread_ty * 64 +
                                                thread_tx * 2 + index_i0,
                                            )
                                            new_axis_vj = T.axis.S(
                                                1024,
                                                k_o * 32 + thread_tz * 16 +
                                                index_j0 * 4 + index_j1,
                                            )
                                            shared_b[new_axis_vi, new_axis_vj +
                                                     8] = match_buffer_b[
                                                         new_axis_vi,
                                                         new_axis_vj]

                            for k_i in range(0, 2):
                                for index_i in range(0, 2):
                                    with T.block():
                                        new_axis_vi = T.axis.S(
                                            64, axis_bx * 4 + thread_ty * 2 +
                                            index_i)
                                        axis_vk = T.axis.S(64, k_o * 2 + k_i)
                                        T.reads(shared_a[new_axis_vi *
                                                         16:new_axis_vi * 16 +
                                                         16, axis_vk *
                                                         16:axis_vk * 16 + 16 +
                                                         8, ])
                                        T.writes(
                                            wmma_a[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ])
                                        stride0 = T.var("int32")
                                        stride1 = T.var("int32")
                                        match_buffer_a0 = T.match_buffer(
                                            shared_a[new_axis_vi *
                                                     16:new_axis_vi * 16 + 16,
                                                     axis_vk *
                                                     16:axis_vk * 16 + 16 +
                                                     8, ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[stride0, stride1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_a0 = T.match_buffer(
                                            wmma_a[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_a0.data,
                                                16,
                                                16,
                                                16,
                                                index_i,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(
                                                        dtype="float16"),
                                                    match_buffer_a0.data,
                                                    match_buffer_a0.elem_offset
                                                    + 8,
                                                    match_buffer_a0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                match_buffer_a0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            ))
                                for index_jj in range(0, 4):
                                    with T.block():
                                        new_axis_vj = T.axis.S(
                                            64, axis_by * 8 + thread_tz * 4 +
                                            index_jj)
                                        axis_vk = T.axis.S(64, k_o * 2 + k_i)
                                        T.reads(shared_b[new_axis_vj *
                                                         16:new_axis_vj * 16 +
                                                         16, axis_vk *
                                                         16:axis_vk * 16 + 16 +
                                                         8, ])
                                        T.writes(
                                            wmma_b[new_axis_vj *
                                                   16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ])
                                        stride0 = T.var("int32")
                                        stride1 = T.var("int32")
                                        match_buffer_b0 = T.match_buffer(
                                            shared_b[new_axis_vj *
                                                     16:new_axis_vj * 16 + 16,
                                                     axis_vk *
                                                     16:axis_vk * 16 + 16 +
                                                     8, ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[stride0, stride1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_b0 = T.match_buffer(
                                            wmma_b[new_axis_vj *
                                                   16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_b0.data,
                                                16,
                                                16,
                                                16,
                                                index_jj,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(
                                                        dtype="float16"),
                                                    match_buffer_b0.data,
                                                    match_buffer_b0.elem_offset
                                                    + 8,
                                                    match_buffer_b0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                match_buffer_b0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            ))
                                for index_i, index_jj in T.grid(2, 4):
                                    with T.block():
                                        new_axis_vi = T.axis.S(
                                            64, axis_bx * 4 + thread_ty * 2 +
                                            index_i)
                                        new_axis_vj = T.axis.S(
                                            64, axis_by * 8 + thread_tz * 4 +
                                            index_jj)
                                        axis_vk = T.axis.R(64, k_o * 2 + k_i)
                                        T.reads([
                                            wmma_a[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            wmma_b[new_axis_vj *
                                                   16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            wmma_c[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   new_axis_vj *
                                                   16:new_axis_vj * 16 + 16, ],
                                        ])
                                        T.writes(
                                            wmma_c[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   new_axis_vj *
                                                   16:new_axis_vj * 16 + 16, ])
                                        wmma_a1 = T.match_buffer(
                                            wmma_a[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_b1 = T.match_buffer(
                                            wmma_b[new_axis_vj *
                                                   16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_c1 = T.match_buffer(
                                            wmma_c[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   new_axis_vj *
                                                   16:new_axis_vj * 16 + 16, ],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_mma_sync(
                                                wmma_c1.data,
                                                index_i * 4 + index_jj,
                                                wmma_a1.data,
                                                index_i,
                                                wmma_b1.data,
                                                index_jj,
                                                wmma_c1.data,
                                                index_i * 4 + index_jj,
                                                dtype="handle",
                                            ))
                        for index_i, index_jj in T.grid(2, 4):
                            with T.block():
                                new_axis_vi = T.axis.S(
                                    64, axis_bx * 4 + thread_ty * 2 + index_i)
                                new_axis_vj = T.axis.S(
                                    64, axis_by * 8 + thread_tz * 4 + index_jj)
                                T.reads(wmma_c[new_axis_vi *
                                               16:new_axis_vi * 16 + 16,
                                               new_axis_vj *
                                               16:new_axis_vj * 16 + 16, ])
                                T.writes(
                                    match_buffer_c[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   new_axis_vj *
                                                   16:new_axis_vj * 16 + 16, ])
                                stride0 = T.var("int32")
                                stride1 = T.var("int32")
                                wmma_c2 = T.match_buffer(
                                    wmma_c[new_axis_vi * 16:new_axis_vi * 16 +
                                           16, new_axis_vj *
                                           16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                match_buffer_c1 = T.match_buffer(
                                    match_buffer_c[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   new_axis_vj *
                                                   16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[stride0, stride1],
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_store_matrix_sync(
                                        wmma_c2.data,
                                        16,
                                        16,
                                        16,
                                        index_i * 4 + index_jj,
                                        T.tvm_access_ptr(
                                            T.type_annotation(dtype="float32"),
                                            match_buffer_c1.data,
                                            match_buffer_c1.elem_offset,
                                            match_buffer_c1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        match_buffer_c1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    ))
Example 23
def main(
    placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
    placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
    conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0 in T.grid(
            1, 1, 2, 1, 1, 1, 4, 1, 14, 1, 1, 1):
        for i1_2_init, i2_2_init, i2_3_init, i3_3_init in T.grid(
                4, 7, 4, 4):
            with T.block("conv2d_NCHWc_int8_o_init"):
                n = T.axis.spatial(1, 0)
                oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2_init)
                oh = T.axis.spatial(56,
                                    i2_0 * 28 + i2_2_init * 4 + i2_3_init)
                ow = T.axis.spatial(56, i3_1 * 4 + i3_3_init)
                oc_block_o = T.axis.spatial(1, 0)
                T.reads()
                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
                for i4_1 in T.vectorized(16):
                    with T.block("conv2d_NCHWc_int8_init"):
                        oc_block_init = T.axis.spatial(16, i4_1)
                        T.reads()
                        T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow,
                                                   oc_block_init])
                        conv2d_NCHWc_int8[n, oc_chunk, oh, ow,
                                          oc_block_init] = 0
        for (
                i7_0,
                i8_0,
                i9_0_0,
                i0_2,
                i1_2,
                i2_2,
                i3_2,
                i4_0_2,
                i5_1,
                i6_1,
                i7_1,
                i8_1,
                i9_0_1,
                i0_3,
                i1_3,
                i2_3,
                i3_3,
                i4_0_3,
        ) in T.grid(4, 1, 1, 1, 4, 7, 1, 1, 1, 1, 1, 4, 1, 1, 1, 4, 4, 1):
            with T.block("conv2d_NCHWc_int8_o_update"):
                n = T.axis.spatial(1, 0)
                oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2)
                oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3)
                ow = T.axis.spatial(56, i3_1 * 4 + i3_3)
                oc_block_o = T.axis.spatial(1, 0)
                kh = T.axis.reduce(1, 0)
                kw = T.axis.reduce(1, 0)
                ic_outer, ic_f_inner = T.axis.remap("RR", [i7_0, i8_1])
                ic_s_inner_o = T.axis.reduce(1, 0)
                T.reads(
                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16],
                    placeholder[n, ic_outer, oh + kh, ow + kw,
                                ic_f_inner * 4:ic_f_inner * 4 + 4],
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner,
                                  0:16, 0:4],
                )
                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
                A = T.match_buffer(
                    placeholder[n, ic_outer, oh + kh, ow + kw,
                                ic_f_inner * 4:ic_f_inner * 4 + 4],
                    [4],
                    dtype="uint8",
                    offset_factor=1,
                )
                B = T.match_buffer(
                    placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner,
                                  0:16, 0:4],
                    [16, 4],
                    dtype="int8",
                    offset_factor=1,
                )
                C = T.match_buffer(
                    conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16],
                    [16],
                    dtype="int32",
                    offset_factor=1,
                )
                A_u8x4 = A.vload([0], "uint8x4")
                A_i32 = T.reinterpret(A_u8x4, dtype="int32")
                B_i8x64 = B.vload([0, 0], dtype="int8x64")
                B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
                C[T.ramp(
                    0, 1,
                    16)] = C[T.ramp(0, 1, 16)] + T.call_llvm_pure_intrin(
                        T.llvm_lookup_intrinsic_id(
                            "llvm.x86.avx512.vpdpbusd.512"),
                        T.uint32(0),
                        T.broadcast(0, 16),
                        T.broadcast(A_i32, 16),
                        B_i32x16,
                        dtype="int32x16",
                    )
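Per output lane, the llvm.x86.avx512.vpdpbusd.512 intrinsic above performs a 4-way uint8 x int8 dot product accumulated into int32 (the VNNI vpdpbusd instruction). A plain NumPy sketch of one lane's semantics:

import numpy as np

a = np.array([1, 2, 3, 4], dtype=np.uint8)    # the broadcast uint8x4 activation
b = np.array([-1, 5, -2, 7], dtype=np.int8)   # one int8x4 slice of the weights
acc = np.int32(0)
acc += np.sum(a.astype(np.int32) * b.astype(np.int32), dtype=np.int32)
print(acc)  # 1*(-1) + 2*5 + 3*(-2) + 4*7 = 31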
Example 24
def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None:
    # match buffer
    A = T.match_buffer(a, [1024, 1024], "float16")
    B = T.match_buffer(b, [1024, 1024], "float16")
    C = T.match_buffer(c, [1024, 1024], "float32")

    # body
    for blockIdx_x in T.thread_binding(0, 16, "blockIdx.x"):
        for blockIdx_y in T.thread_binding(0, 8, "blockIdx.y"):
            with T.block([16, 8]) as [bx, by]:
                T.bind(bx, blockIdx_x)
                T.bind(by, blockIdx_y)
                shared_A = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_B = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_A = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_B = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_C = T.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                for ty in T.thread_binding(0, 2, "threadIdx.y"):
                    for tz in T.thread_binding(0, 2, "threadIdx.z"):
                        for i, j in T.grid(2, 4):
                            with T.block([64, 64]) as [vi, vj]:
                                T.bind(vi, bx * 4 + ty * 2 + i)
                                T.bind(vj, by * 8 + tz * 4 + j)
                                T.reads([])
                                T.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                C0 = T.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_fill_fragment(
                                        C0.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        T.float32(0),
                                        dtype="handle",
                                    )
                                )

                        for ko in range(0, 32):
                            # copy data from global to shared
                            for tx in T.thread_binding(0, 32, "threadIdx.x"):
                                for i0, j0 in T.grid(1, 4):
                                    for j1 in T.vectorized(0, 4):
                                        with T.block([1024, 1024]) as [vi, vj]:
                                            T.bind(vi, bx * 64 + ty * 32 + tx + i0)
                                            T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_A[vi, vj + 8] = A[vi, vj]

                                for i0, j0 in T.grid(2, 4):
                                    for j1 in T.vectorized(0, 4):
                                        with T.block([1024, 1024]) as [vi, vj]:
                                            T.bind(vi, by * 128 + ty * 64 + tx * 2 + i0)
                                            T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_B[vi, vj + 8] = B[vi, vj]

                            for ki in range(0, 2):
                                for i in range(0, 2):
                                    with T.block([64, 64]) as [vi, vk]:
                                        T.bind(vi, bx * 4 + ty * 2 + i)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        T.writes(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = T.var("int32")
                                        s1 = T.var("int32")
                                        A0 = T.match_buffer(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_A0 = T.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_A0.data,
                                                16,
                                                16,
                                                16,
                                                i,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(dtype="float16"),
                                                    A0.data,
                                                    A0.elem_offset + 8,
                                                    A0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                A0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            )
                                        )
                                for j in range(0, 4):
                                    with T.block([64, 64]) as [vj, vk]:
                                        T.bind(vj, by * 8 + tz * 4 + j)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        T.writes(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = T.var("int32")
                                        s1 = T.var("int32")
                                        B0 = T.match_buffer(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_B0 = T.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_B0.data,
                                                16,
                                                16,
                                                16,
                                                j,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(dtype="float16"),
                                                    B0.data,
                                                    B0.elem_offset + 8,
                                                    B0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                B0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            )
                                        )
                                for i, j in T.grid(2, 4):
                                    with T.block([64, 64, T.reduce_axis(0, 64)]) as [
                                        vi,
                                        vj,
                                        vk,
                                    ]:
                                        T.bind(vi, bx * 4 + ty * 2 + i)
                                        T.bind(vj, by * 8 + tz * 4 + j)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            [
                                                wmma_A[
                                                    vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_B[
                                                    vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_C[
                                                    vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16
                                                ],
                                            ]
                                        )
                                        T.writes(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]
                                        )
                                        wmma_A1 = T.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_B1 = T.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_C1 = T.match_buffer(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_mma_sync(
                                                wmma_C1.data,
                                                i * 4 + j,
                                                wmma_A1.data,
                                                i,
                                                wmma_B1.data,
                                                j,
                                                wmma_C1.data,
                                                i * 4 + j,
                                                dtype="handle",
                                            )
                                        )
                        for i, j in T.grid(2, 4):
                            with T.block([64, 64]) as [vi, vj]:
                                T.bind(vi, bx * 4 + ty * 2 + i)
                                T.bind(vj, by * 8 + tz * 4 + j)
                                T.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                s0 = T.var("int32")
                                s1 = T.var("int32")
                                wmma_C2 = T.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                C1 = T.match_buffer(
                                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[s0, s1],
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_store_matrix_sync(
                                        wmma_C2.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        T.tvm_access_ptr(
                                            T.type_annotation(dtype="float32"),
                                            C1.data,
                                            C1.elem_offset,
                                            C1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        C1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    )
                                )
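
A hedged usage sketch for the kernel above, assuming it is decorated with @T.prim_func and a CUDA GPU with tensor cores is available (whether this exact fixture still builds depends on the TVM version its syntax targets):

import numpy as np
import tvm

f = tvm.build(tensorcore_gemm, target="cuda")
dev = tvm.cuda(0)
a_np = np.random.uniform(size=(1024, 1024)).astype("float16")
b_np = np.random.uniform(size=(1024, 1024)).astype("float16")
a = tvm.nd.array(a_np, dev)
b = tvm.nd.array(b_np, dev)
c = tvm.nd.array(np.zeros((1024, 1024), dtype="float32"), dev)
f(a, b, c)
# B is consumed as [N, K] (col_major fragment load), hence the transpose.
np.testing.assert_allclose(
    c.numpy(),
    a_np.astype("float32") @ b_np.astype("float32").T,
    rtol=1e-3,
)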
Example #25
def main(var_X: T.handle, var_W: T.handle, var_B: T.handle,
         var_bn_scale: T.handle, var_bn_offset: T.handle,
         var_compute: T.handle) -> None:
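    # 3x3 conv2d over a 1x512x56x56 input fused with bias add, batch-norm
    # scale/offset, and ReLU, scheduled for GPU: explicit padding, shared
    # memory cooperative fetch of the input and weights, and a per-thread
    # local accumulator written back at the end.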
     X = T.match_buffer(var_X, [1, 512, 56, 56], dtype="float32")
     W = T.match_buffer(var_W, [512, 512, 3, 3], dtype="float32")
     B = T.match_buffer(var_B, [512, 1, 1], dtype="float32")
     bn_scale = T.match_buffer(var_bn_scale, [512, 1, 1], dtype="float32")
     bn_offset = T.match_buffer(var_bn_offset, [512, 1, 1], dtype="float32")
     compute = T.match_buffer(var_compute, [1, 512, 56, 56],
                              dtype="float32")
     pad_temp = T.alloc_buffer([1, 512, 58, 58], dtype="float32")
     compute_1 = T.alloc_buffer([1, 512, 56, 56], dtype="float32")
     compute_local = T.alloc_buffer([1, 512, 56, 56],
                                    dtype="float32",
                                    scope="local")
     pad_temp_shared = T.alloc_buffer([1, 512, 58, 58],
                                      dtype="float32",
                                      scope="shared")
     W_shared = T.alloc_buffer([512, 512, 3, 3],
                               dtype="float32",
                               scope="shared")
     for i0, i1, i2, i3 in T.grid(1, 512, 58, 58):
         with T.block("pad_temp"):
             i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
             pad_temp[i0_1, i1_1, i2_1,
                      i3_1] = T.if_then_else(i2_1 >= 1 and i2_1 < 57
                                             and i3_1 >= 1 and i3_1 < 57,
                                             X[i0_1, i1_1, i2_1 - 1,
                                               i3_1 - 1],
                                             T.float32(0),
                                             dtype="float32")
     for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(0,
                                                       224,
                                                       thread="blockIdx.x"):
         for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(
                 0, 2, thread="vthread.x"):
             for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(
                     0, 8, thread="threadIdx.x"):
                 for i4_0, i5_0, i6_0 in T.grid(1, 3, 1):
                     for ax0_ax1_ax2_ax3_fused_0 in T.serial(
                             0,
                             40960,
                             annotations={
                                 "meta_schedule.cooperative_fetch": 1
                             }):
                         for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 3):
                             with T.block("pad_temp_shared"):
                                 v0 = T.axis.spatial(1, 0)
                                 v1 = T.axis.spatial(
                                     512, (ax0_ax1_ax2_ax3_fused_0 * 3 +
                                           ax0_ax1_ax2_ax3_fused_1) // 30 //
                                     8 % 512)
                                 v2 = T.axis.spatial(
                                     58,
                                     i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8
                                     + i5_0 +
                                     (ax0_ax1_ax2_ax3_fused_0 * 3 +
                                      ax0_ax1_ax2_ax3_fused_1) // 30 % 8)
                                 v3 = T.axis.spatial(
                                     58,
                                     i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 +
                                     (ax0_ax1_ax2_ax3_fused_0 * 3 +
                                      ax0_ax1_ax2_ax3_fused_1) % 30)
                                 pad_temp_shared[v0, v1, v2,
                                                 v3] = pad_temp[v0, v1, v2,
                                                                v3]
                     for ax0_ax1_ax2_ax3_fused_0 in T.serial(
                             0,
                             12288,
                             annotations={
                                 "meta_schedule.cooperative_fetch": 1
                             }):
                         for ax0_ax1_ax2_ax3_fused_1 in T.vectorized(0, 4):
                             with T.block("W_shared"):
                                 v0 = T.axis.spatial(
                                     512,
                                     i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 +
                                     (ax0_ax1_ax2_ax3_fused_0 * 4 +
                                      ax0_ax1_ax2_ax3_fused_1) // 1536)
                                 v1 = T.axis.spatial(
                                     512,
                                     (ax0_ax1_ax2_ax3_fused_0 * 4 +
                                      ax0_ax1_ax2_ax3_fused_1) // 3 % 512)
                                 v2 = T.axis.spatial(3, i5_0)
                                 v3 = T.axis.spatial(
                                     3, (ax0_ax1_ax2_ax3_fused_0 * 4 +
                                         ax0_ax1_ax2_ax3_fused_1) % 3)
                                 W_shared[v0, v1, v2, v3] = W[v0, v1, v2,
                                                              v3]
                     for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(
                             32, 1, 1, 1, 1, 1, 1, 16, 1, 3, 1, 8, 2, 28):
                         with T.block("compute"):
                             nn = T.axis.spatial(1, 0)
                             ff = T.axis.spatial(
                                 512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 +
                                 i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + i1_4)
                             yy = T.axis.spatial(
                                 56,
                                 i0_0_i1_0_i2_0_i3_0_fused // 2 % 7 * 8 +
                                 i0_1_i1_1_i2_1_i3_1_fused * 4 +
                                 i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + i2_4)
                             xx = T.axis.spatial(
                                 56,
                                 i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + i3_4)
                             rc = T.axis.reduce(512, i4_1 * 16 + i4_2)
                             ry, rx = T.axis.remap("RR", [i5_0, i6_2])
                             with T.init():
                                 compute_local[nn, ff, yy,
                                               xx] = T.float32(0)
                             compute_local[nn, ff, yy, xx] = compute_local[
                                 nn, ff, yy, xx] + pad_temp_shared[
                                     nn, rc, yy + ry,
                                     xx + rx] * W_shared[ff, rc, ry, rx]
                 for ax0, ax1, ax2, ax3 in T.grid(1, 8, 2, 28):
                     with T.block("compute_local"):
                         v0 = T.axis.spatial(1, ax0)
                         v1 = T.axis.spatial(
                             512, i0_0_i1_0_i2_0_i3_0_fused // 14 * 32 +
                             i0_2_i1_2_i2_2_i3_2_fused // 2 * 8 + ax1)
                         v2 = T.axis.spatial(
                             56, i0_0_i1_0_i2_0_i3_0_fused % 14 // 2 * 8 +
                             i0_1_i1_1_i2_1_i3_1_fused * 4 +
                             i0_2_i1_2_i2_2_i3_2_fused % 2 * 2 + ax2)
                         v3 = T.axis.spatial(
                             56, i0_0_i1_0_i2_0_i3_0_fused % 2 * 28 + ax3)
                         compute_1[v0, v1, v2, v3] = compute_local[v0, v1,
                                                                   v2, v3]
     for i0, i1, i2, i3 in T.grid(1, 512, 56, 56):
         with T.block("compute_1"):
             i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3])
             compute[i0_2, i1_2, i2_2, i3_2] = T.max(
                 (compute_1[i0_2, i1_2, i2_2, i3_2] + B[i1_2, 0, 0]) *
                 bn_scale[i1_2, 0, 0] + bn_offset[i1_2, 0, 0], T.float32(0))
Example #26
def loops() -> None:
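    # Minimal fixture exercising the three TIR loop kinds in one nest:
    # parallel, serial, and vectorized.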
    for i in T.parallel(0, 2):
        for j in T.serial(0, 1):
            for z in T.vectorized(3, 4):
                T.evaluate(0)
Example #27
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None:
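    # 512x512x512 GEMM for GPU: blockIdx/vthread/threadIdx tiling, shared
    # memory staging of A and B via cooperative fetch, and a per-thread
    # local accumulator flushed to C at the end.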
     A = T.match_buffer(var_A, [512, 512], dtype="float32")
     B = T.match_buffer(var_B, [512, 512], dtype="float32")
     C = T.match_buffer(var_C, [512, 512], dtype="float32")
     C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
     A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
     B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
     for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
         for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
             for i0_2_i1_2_fused in T.thread_binding(0,
                                                     8,
                                                     thread="threadIdx.x"):
                 for i2_0 in T.serial(0, 1):
                     for ax0_ax1_fused_0 in T.serial(0, 32768):
                         for ax0_ax1_fused_1 in T.thread_binding(
                                 0, 8, thread="threadIdx.x"):
                             with T.block("A_shared"):
                                 v0 = T.axis.spatial(
                                     512,
                                     (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1)
                                     // 512)
                                 v1 = T.axis.spatial(
                                     512,
                                     (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1)
                                     % 512)
                                 T.reads([A[v0, v1]])
                                 T.writes([A_shared[v0, v1]])
                                 T.block_attr(
                                     {"meta_schedule.cooperative_fetch": 1})
                                 A_shared[v0, v1] = A[v0, v1]
                     for ax0_ax1_fused_0 in T.serial(0, 1024):
                         for ax0_ax1_fused_1 in T.thread_binding(
                                 0, 8, thread="threadIdx.x"):
                             for ax0_ax1_fused_2 in T.vectorized(0, 2):
                                 with T.block("B_shared"):
                                     v0 = T.axis.spatial(
                                         512, (ax0_ax1_fused_0 * 16 +
                                               ax0_ax1_fused_1 * 2 +
                                               ax0_ax1_fused_2) // 32)
                                     v1 = T.axis.spatial(
                                         512, i0_0_i1_0_fused * 32 +
                                         (ax0_ax1_fused_0 * 16 +
                                          ax0_ax1_fused_1 * 2 +
                                          ax0_ax1_fused_2) % 32)
                                     T.reads([B[v0, v1]])
                                     T.writes([B_shared[v0, v1]])
                                     T.block_attr({
                                         "meta_schedule.cooperative_fetch":
                                         2
                                     })
                                     B_shared[v0, v1] = B[v0, v1]
                     for i0_3_init, i1_3_init, i0_4_init, i1_4_init in T.grid(
                             2, 2, 16, 2):
                         with T.block("C_init"):
                             i = T.axis.spatial(
                                 512, i0_1_i1_1_fused * 32 +
                                 i0_3_init * 16 + i0_4_init)
                             j = T.axis.spatial(
                                 512, i0_0_i1_0_fused * 32 +
                                 i0_2_i1_2_fused * 4 + i1_3_init * 2 +
                                 i1_4_init)
                             T.reads([])
                             T.writes([C_local[i, j]])
                             C_local[i, j] = T.float32(0)
                     for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(
                             16, 2, 2, 32, 16, 2):
                         with T.block("C_update"):
                             i = T.axis.spatial(
                                 512,
                                 i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
                             j = T.axis.spatial(
                                 512, i0_0_i1_0_fused * 32 +
                                 i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4)
                             k = T.axis.reduce(512, i2_1 * 32 + i2_2)
                             T.reads([
                                 C_local[i, j], A_shared[i, k], B_shared[k,
                                                                         j]
                             ])
                             T.writes([C_local[i, j]])
                             C_local[i, j] = C_local[
                                 i, j] + A_shared[i, k] * B_shared[k, j]
                 for ax0, ax1 in T.grid(32, 4):
                     with T.block("C_local"):
                         v0 = T.axis.spatial(512,
                                             i0_1_i1_1_fused * 32 + ax0)
                         v1 = T.axis.spatial(
                             512, i0_0_i1_0_fused * 32 +
                             i0_2_i1_2_fused * 4 + ax1)
                         T.reads([C_local[v0, v1]])
                         T.writes([C[v0, v1]])
                         C[v0, v1] = C_local[v0, v1]
def main(A: T.Buffer[(1024, 1024), "float32"],
         B: T.Buffer[(1024, 1024), "float32"],
         C: T.Buffer[(1024, 1024), "float32"]) -> None:
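    # 1024x1024x1024 GEMM with shared-memory staging of A and B; the
    # "tir.manifest_shared_memory_local_stage" block attribute marks each
    # copy block so a later pass can insert a local staging buffer.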
     # function attr dict
     T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
     # body
     # with T.block("root")
     for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"):
         for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"):
             for threadIdx_y in T.thread_binding(2, thread="threadIdx.y"):
                 for threadIdx_x in T.thread_binding(2,
                                                     thread="threadIdx.x"):
                     for k_0 in T.serial(32):
                         with T.block():
                             T.reads(
                                 A[blockIdx_y * 32:blockIdx_y * 32 + 32,
                                   k_0 * 32:k_0 * 32 + 32],
                                 B[k_0 * 32:k_0 * 32 + 32,
                                   blockIdx_x * 32:blockIdx_x * 32 + 32])
                             T.writes(
                                 C[blockIdx_y * 32:blockIdx_y * 32 + 32,
                                   blockIdx_x * 32:blockIdx_x * 32 + 32])
                             A_shared = T.alloc_buffer([1024, 1024],
                                                       dtype="float32",
                                                       scope="shared")
                             B_shared = T.alloc_buffer([1024, 1024],
                                                       dtype="float32",
                                                       scope="shared")
                             for ax0_ax1_fused_0 in T.serial(64):
                                 for ax0_ax1_fused_3 in T.vectorized(4):
                                     with T.block("A_shared"):
                                         T.reads(A[blockIdx_y * 32 +
                                                   (ax0_ax1_fused_0 * 16 +
                                                    threadIdx_y * 8 +
                                                    threadIdx_x * 4 +
                                                    ax0_ax1_fused_3) // 32,
                                                   k_0 * 32 +
                                                   (ax0_ax1_fused_0 * 16 +
                                                    threadIdx_y * 8 +
                                                    threadIdx_x * 4 +
                                                    ax0_ax1_fused_3) % 32])
                                         T.writes(A_shared[
                                             blockIdx_y * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) // 32,
                                             k_0 * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) % 32])
                                         T.block_attr({
                                             "tir.manifest_shared_memory_local_stage":
                                             1
                                         })
                                         A_shared[
                                             blockIdx_y * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) // 32,
                                             k_0 * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) % 32] = A[
                                                  blockIdx_y * 32 +
                                                  (ax0_ax1_fused_0 * 16 +
                                                   threadIdx_y * 8 +
                                                   threadIdx_x * 4 +
                                                   ax0_ax1_fused_3) // 32,
                                                  k_0 * 32 +
                                                  (ax0_ax1_fused_0 * 16 +
                                                   threadIdx_y * 8 +
                                                   threadIdx_x * 4 +
                                                   ax0_ax1_fused_3) % 32]
                             for ax0_ax1_fused_0 in T.serial(64):
                                 for ax0_ax1_fused_3 in T.vectorized(4):
                                     with T.block("B_shared"):
                                         T.reads(B[k_0 * 32 +
                                                   (ax0_ax1_fused_0 * 16 +
                                                    threadIdx_y * 8 +
                                                    threadIdx_x * 4 +
                                                    ax0_ax1_fused_3) // 32,
                                                   blockIdx_x * 32 +
                                                   (ax0_ax1_fused_0 * 16 +
                                                    threadIdx_y * 8 +
                                                    threadIdx_x * 4 +
                                                    ax0_ax1_fused_3) % 32])
                                         T.writes(B_shared[
                                             k_0 * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) // 32,
                                             blockIdx_x * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) % 32])
                                         T.block_attr({
                                             "tir.manifest_shared_memory_local_stage":
                                             1
                                         })
                                         B_shared[
                                             k_0 * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) // 32,
                                             blockIdx_x * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) % 32] = B[
                                                  k_0 * 32 +
                                                  (ax0_ax1_fused_0 * 16 +
                                                   threadIdx_y * 8 +
                                                   threadIdx_x * 4 +
                                                   ax0_ax1_fused_3) // 32,
                                                  blockIdx_x * 32 +
                                                  (ax0_ax1_fused_0 * 16 +
                                                   threadIdx_y * 8 +
                                                   threadIdx_x * 4 +
                                                   ax0_ax1_fused_3) % 32]
                             for k_1, i_2, j_2, k_2 in T.grid(
                                     2, 16, 16, 16):
                                 with T.block("C"):
                                     T.reads(
                                         A_shared[blockIdx_y * 32 +
                                                  threadIdx_y * 16 + i_2,
                                                  k_0 * 32 + k_1 * 16 +
                                                  k_2],
                                         B_shared[k_0 * 32 + k_1 * 16 + k_2,
                                                  blockIdx_x * 32 +
                                                  threadIdx_x * 16 + j_2])
                                     T.writes(C[blockIdx_y * 32 +
                                                threadIdx_y * 16 + i_2,
                                                blockIdx_x * 32 +
                                                threadIdx_x * 16 + j_2])
                                     if k_0 * 32 + k_1 * 16 + k_2 == 0:
                                         C[blockIdx_y * 32 +
                                           threadIdx_y * 16 + i_2,
                                           blockIdx_x * 32 +
                                           threadIdx_x * 16 +
                                           j_2] = T.float32(0)
                                     C[
                                         blockIdx_y * 32 +
                                         threadIdx_y * 16 + i_2,
                                         blockIdx_x * 32 +
                                         threadIdx_x * 16 + j_2] = C[
                                             blockIdx_y * 32 +
                                             threadIdx_y * 16 + i_2,
                                             blockIdx_x * 32 + threadIdx_x *
                                             16 + j_2] + A_shared[
                                                 blockIdx_y * 32 +
                                                 threadIdx_y * 16 + i_2,
                                                 k_0 * 32 + k_1 * 16 +
                                                 k_2] * B_shared[
                                                     k_0 * 32 + k_1 * 16 +
                                                     k_2, blockIdx_x * 32 +
                                                     threadIdx_x * 16 + j_2]
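
The tir.manifest_shared_memory_local_stage attribute above is consumed by a lowering pass that routes each annotated global-to-shared copy through an intermediate local (register) stage. A hedged sketch of applying it, assuming the prim_func is wrapped in an IRModule named Module and that the pass name ManifestSharedMemoryLocalStage matches your TVM version:

import tvm

# Rewrite the annotated copies to stage through local registers before
# writing shared memory, then inspect the transformed module.
mod = tvm.tir.transform.ManifestSharedMemoryLocalStage()(Module)
print(mod.script())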