Example #1
 def pooling_decompose_2(
         x: T.Buffer[(1, 16, 225, 225), "int8"],
         tensor: T.Buffer[(1, 16, 225, 225), "int8"]) -> None:
     pad_temp = T.alloc_buffer([1, 16, 231, 231], dtype="int8")
     for i0, i2_0, i3_0, ax0, ax1, ax2 in T.grid(1, 3, 3, 16, 81, 81):
         with T.block("pad_temp_pad_const"):
             ax0_1 = T.axis.spatial(1, 0)
             ax1_1 = T.axis.spatial(16, ax0)
             ax2_1 = T.axis.spatial(231, i2_0 * 75 + ax1)
             ax3 = T.axis.spatial(231, i3_0 * 75 + ax2)
             T.reads()
             T.writes(pad_temp[ax0_1, ax1_1, ax2_1, ax3])
             pad_temp[ax0_1, ax1_1, ax2_1, ax3] = T.int8(0)
     for i0, i2_0, i3_0 in T.grid(1, 3, 3):
         for ax0, ax1, ax2 in T.grid(16, 81, 81):
             with T.block("pad_temp"):
                 ax0_2 = T.axis.spatial(1, 0)
                 ax1_2 = T.axis.spatial(16, ax0)
                 ax2_2 = T.axis.spatial(225, i2_0 * 75 + ax1 - 3)
                 ax3 = T.axis.spatial(225, i3_0 * 75 + ax2 - 3)
                 T.where(3 <= i2_0 * 75 + ax1 and i2_0 * 75 + ax1 < 228
                         and 3 <= i3_0 * 75 + ax2 and i3_0 * 75 + ax2 < 228)
                 T.reads(x[ax0_2, ax1_2, ax2_2, ax3])
                 T.writes(pad_temp[ax0_2, ax1_2, ax2_2 + 3, ax3 + 3])
                 pad_temp[ax0_2, ax1_2, ax2_2 + 3,
                          ax3 + 3] = x[ax0_2, ax1_2, ax2_2, ax3]
         for i1, i2_1, i3_1, i4, i5 in T.grid(16, 75, 75, 7, 7):
             with T.block("tensor"):
                 ax0_3, ax1_3 = T.axis.remap("SS", [i0, i1])
                 ax2_3 = T.axis.spatial(225, i2_0 * 75 + i2_1)
                 ax3 = T.axis.spatial(225, i3_0 * 75 + i3_1)
                 rv0, rv1 = T.axis.remap("RR", [i4, i5])
                 T.reads(pad_temp[ax0_3, ax1_3, ax2_3 + rv0, ax3 + rv1])
                 T.writes(tensor[ax0_3, ax1_3, ax2_3, ax3])
                 with T.init():
                     tensor[ax0_3, ax1_3, ax2_3, ax3] = T.int8(0)
                 tensor[ax0_3, ax1_3, ax2_3, ax3] = (
                     tensor[ax0_3, ax1_3, ax2_3, ax3] +
                     pad_temp[ax0_3, ax1_3, ax2_3 + rv0, ax3 + rv1])
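The pad_temp_pad_const / pad_temp pair above is the characteristic output of the decompose_padding schedule primitive, which splits a padded copy into a constant-fill stage and an in-bounds copy stage. A minimal sketch of producing such a form, assuming mod wraps the original (un-decomposed) padded pooling function:

import tvm

sch = tvm.tir.Schedule(mod)
pad = sch.get_block("pad_temp")
# Split the padding block into a const-fill block and a copy block,
# hoisting the fill above the given loop.
sch.decompose_padding(pad, sch.get_loops(pad)[0])
print(sch.mod.script())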
Example #2
def element_wise_invalid_annotation(a: T.handle, c: T.handle) -> None:
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block("root"):
        T.reads([])
        T.writes([])
        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
        for i0 in T.serial(0, 128):
            for ax1 in T.serial(0, 128):
                with T.block("B"):
                    T.block_attr({"buffer_dim_align": [0]})
                    vi, vj = T.axis.remap("SS", [i0, ax1])
                    T.reads([A[vi, vj]])
                    T.writes([B[vi, vj]])
                    B[vi, vj] = (A[vi, vj]*T.float32(2))
            for i1 in T.serial(0, 128):
                with T.block("C"):
                    vi_1, vj_1 = T.axis.remap("SS", [i0, i1])
                    T.reads([B[vi_1, vj_1]])
                    T.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1))
Example #3
def shared_mem_func(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i0 in T.thread_binding(0, 2, thread="blockIdx.x"):
        for i1 in T.thread_binding(0, 2, thread="vthread"):
            for i2 in T.thread_binding(0, 4, thread="threadIdx.x"):
                with T.block():
                    T.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
                    T.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
                    B = T.alloc_buffer((16, 16), "float32", scope="shared")
                    for j in range(0, 16):
                        with T.block() as []:
                            T.reads(A[i0 * 8 + i1 * 4 + i2, j])
                            T.writes(B[i0 * 8 + i1 * 4 + i2, j])
                            B[i0 * 8 + i1 * 4 + i2,
                              j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
                    for j in range(0, 16):
                        with T.block() as []:
                            T.reads(B[i0 * 8 + i1 * 4 + i2, j])
                            T.writes(C[i0 * 8 + i1 * 4 + i2, j])
                            C[i0 * 8 + i1 * 4 + i2,
                              j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0
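shared_mem_func stages B in "shared" scope (16 * 16 * 4 = 1024 bytes). A sketch of checking a kernel like this against a shared-memory budget with TVM's GPU verifier; treat it as an outline, since verify_gpu_code is normally run on lowered TIR and the listing would first need the @T.prim_func decorator:

import tvm

# Expected to pass with a 1 KiB budget and fail below it
# (assumed semantics of the constraint dictionary).
ok = tvm.tir.analysis.verify_gpu_code(
    shared_mem_func, {"max_shared_memory_per_block": 1024})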
Example #4
def cache_read_shape_int64(var_A: T.handle, var_C: T.handle) -> None:
    A = T.match_buffer(var_A, (T.int64(128), T.int64(128)), dtype="float32")
    C = T.match_buffer(var_C, (T.int64(128), T.int64(128)), dtype="float32")
    B = T.alloc_buffer([T.int64(128), T.int64(128)], dtype="float32")
    A_global = T.alloc_buffer([T.int64(128), T.int64(128)], dtype="float32")
    for ax0, ax1 in T.grid(T.int64(128), T.int64(128)):
        with T.block("A_global"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v0, v1])
            T.writes(A_global[v0, v1])
            A_global[v0, v1] = A[v0, v1]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_global[vi, vj])
            T.writes(B[vi, vj])
            B[vi, vj] = A_global[vi, vj] * T.float32(2)
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(B[vi, vj])
            T.writes(C[vi, vj])
            C[vi, vj] = B[vi, vj] + T.float32(1)
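The A_global staging block is what the cache_read schedule primitive emits; note that the int64 shapes of A are carried through to the cache buffer. A minimal sketch, assuming mod holds the original two-stage element-wise function without the cache:

import tvm

sch = tvm.tir.Schedule(mod)
block_b = sch.get_block("B")
# Read-buffer index 0 of block "B" is A; stage it through a "global" cache.
sch.cache_read(block_b, 0, "global")
print(sch.mod.script())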
Example #5
def simple_compute(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16),
                                                                 "float32"]):
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(
                0,
                16,
                annotations={
                    "software_pipeline_stage": [0, 1],
                    "software_pipeline_order": [0, 1]
                },
        ):
            with T.block():
                T.reads(A[tx, i])
                T.writes(C[tx, i])
                B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
                with T.block():
                    T.reads(A[tx, i])
                    T.writes(B[tx, 0])
                    B[tx, 0] = A[tx, i] * T.float32(2)
                with T.block():
                    T.reads(B[tx, 0])
                    T.writes(C[tx, i])
                    C[tx, i] = B[tx, 0] + T.float32(1)
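The software_pipeline_stage and software_pipeline_order annotations are consumed at lowering time by the InjectSoftwarePipeline pass, which rewrites the annotated loop into prologue, steady-state, and epilogue phases so the two inner blocks overlap across iterations. A sketch of invoking it, assuming the listing is decorated with @T.prim_func:

import tvm

mod = tvm.IRModule({"main": simple_compute})
mod = tvm.tir.transform.InjectSoftwarePipeline()(mod)
print(mod.script())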
Example #6
def element_wise(a: T.handle, c: T.handle) -> None:
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block([], "root"):
        T.reads([])
        T.writes([])
        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
        for i0 in T.serial(0, 128):
            for ax1 in T.serial(0, 128):
                with T.block([128, 128], "B") as [vi, vj]:
                    T.bind(vi, i0)
                    T.bind(vj, ax1)
                    T.reads([A[vi, vj]])
                    T.writes([B[vi, vj]])
                    B[vi, vj] = (A[vi, vj]*T.float32(2))
            for i1 in T.serial(0, 128):
                with T.block([128, 128], "C") as [vi_1, vj_1]:
                    T.bind(vi_1, i0)
                    T.bind(vj_1, i1)
                    T.reads([B[vi_1, vj_1]])
                    T.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1))
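This listing uses the older TVMScript block syntax: the iteration domain is declared on T.block itself and each block var is bound explicitly with T.bind. In later TVMScript (as in Example #2 above) the same block "B" would be written with T.axis.remap, roughly:

with T.block("B"):
    vi, vj = T.axis.remap("SS", [i0, ax1])
    T.reads([A[vi, vj]])
    T.writes([B[vi, vj]])
    B[vi, vj] = A[vi, vj] * T.float32(2)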
Example #7
def compacted_strided_buffer_func(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i0 in range(0, 4):
        with T.block():
            T.reads(A[i0 * 4:i0 * 4 + 4, 0:16])
            T.writes(C[i0 * 4:i0 * 4 + 4, 0:16])
            B = T.alloc_buffer([4, 16],
                               "float32",
                               strides=[17, 1],
                               scope="global")
            for i1 in range(0, 4):
                for j in range(0, 16):
                    with T.block() as []:
                        T.reads(A[i0 * 4 + i1, j])
                        T.writes(B[i1, j])
                        B[i1, j] = A[i0 * 4 + i1, j] + 1.0
            for i1 in range(0, 4):
                for j in range(0, 16):
                    with T.block() as []:
                        T.reads(B[i1, j])
                        T.writes(C[i0 * 4 + i1, j])
                        C[i0 * 4 + i1, j] = B[i1, j] * 2.0
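The [4, 16] allocation with strides=[17, 1] is the compacted, stride-padded form of what was presumably a larger scratch buffer before compaction. A sketch of the pass that performs this narrowing, with the name of the pre-compaction function assumed:

import tvm

mod = tvm.IRModule({"main": strided_buffer_func})  # hypothetical input
mod = tvm.tir.transform.CompactBufferAllocation()(mod)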
Example #8
 def before_blockize_rca(
     A: T.Buffer[(128, 128), "float32"],
     C: T.Buffer[(128, 128), "float32"],
 ) -> None:
     B = T.alloc_buffer([128, 128], dtype="float32")
     for i, j in T.grid(8, 8):
         with T.block("B_o"):
             vi, vj = T.axis.remap("SS", [i, j])
             T.reads(A[vi * 16:vi * 16 + 16, vj * 16:vj * 16 + 16])
             T.writes(B[vi * 16:vi * 16 + 16, vj * 16:vj * 16 + 16])
             for i_1, j_1 in T.grid(16, 16):
                 with T.block("B"):
                     vi_i, vj_i = T.axis.remap("SS", [i_1, j_1])
                     T.reads(A[vi * 16 + vi_i, vj * 16 + vj_i])
                     T.writes(B[vi * 16 + vi_i, vj * 16 + vj_i])
                     B[vi * 16 + vi_i, vj * 16 +
                       vj_i] = A[vi * 16 + vi_i, vj * 16 + vj_i] * 2.0
         for ax0, ax1 in T.grid(16, 16):
             with T.block("C"):
                 vi = T.axis.spatial(128, i * 16 + ax0)
                 vj = T.axis.spatial(128, j * 16 + ax1)
                 T.reads(B[vi, vj])
                 T.writes(C[vi, vj])
                 C[vi, vj] = B[vi, vj] + 1.0
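Judging by the name, the next step for this function is to blockize the 16x16 tile loops of block "C" so that it mirrors the already-blockized "B_o". A sketch with assumed names:

import tvm

sch = tvm.tir.Schedule(mod)
tile_outer = sch.get_loops(sch.get_block("C"))[-2]  # ax0 of the 16x16 tile
sch.blockize(tile_outer)  # wraps the tile loops in a new outer block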
Example #9
def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), dtype="float16")
    B = T.match_buffer(b, (128, 128), dtype="float16")
    C = T.match_buffer(c, (128, 128), dtype="float16")
    D = T.match_buffer(d, (128, 128), dtype="float16")

    for i, j in T.grid(128, 128):
        with T.block("load_store"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi, vj])
            T.writes(D[vi, vj])
            D[vi, vj] = A[vi, vj]
    for i, j in T.grid(8, 8):
        with T.block("opaque"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.evaluate(
                T.tvm_load_matrix_sync(
                    B.data,
                    16,
                    16,
                    16,
                    vi * 8 + vj,
                    T.tvm_access_ptr(
                        T.type_annotation(dtype="float16"),
                        A.data,
                        vi * 2048 + vj * 16,
                        128,
                        1,
                        dtype="handle",
                    ),
                    128,
                    "row_major",
                    dtype="handle",
                )
            )
    for i, j in T.grid(8, 8):
        with T.block("match_buffer"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
            A0 = T.match_buffer(
                A[
                    vi * 16 : vi * 16 + 16,
                    vj * 16 : vj * 16 + 16,
                ],
                (16, 16),
                "float16",
                strides=[128, 1],
                offset_factor=1,
            )
            C0 = T.match_buffer(
                C[
                    vi * 16 : vi * 16 + 16,
                    vj * 16 : vj * 16 + 16,
                ],
                (16, 16),
                "float16",
                strides=[128, 1],
                offset_factor=1,
            )
            T.evaluate(
                T.tvm_load_matrix_sync(
                    C0.data,
                    16,
                    16,
                    16,
                    vi * 8 + vj,
                    T.tvm_access_ptr(
                        T.type_annotation(dtype="float16"),
                        A0.data,
                        A0.elem_offset,
                        A0.strides[0],
                        1,
                        dtype="handle",
                    ),
                    128,
                    "row_major",
                    dtype="handle",
                )
            )
Example #10
def mismatch_args() -> None:
    A = T.alloc_buffer((128, 128), "float32")
    with T.block():
        T.reads(A[0, 0], A[1, 1])  # error
        T.evaluate(1.0)
Example #11
def implicit_root_has_read():
    T.reads([])  # error: implicit root does not support reads
    T.evaluate(0.0)
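mismatch_args and implicit_root_has_read are intentionally malformed, and the TVMScript parser is expected to reject them at the commented lines. A sketch of the usual negative-test harness, assuming pytest and that parsing is triggered by applying the T.prim_func decorator:

import pytest
import tvm
from tvm.script import tir as T

def check_rejected(py_func):
    # Parsing a malformed script should raise a diagnostic error
    # (assumed error type for TVMScript parser failures).
    with pytest.raises(tvm.error.DiagnosticError):
        T.prim_func(py_func)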
Example #12
    def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (WARP_SIZE, local_size),
                           in_dtype,
                           align=128,
                           offset_factor=16,
                           scope="warp")
        B = T.match_buffer(b, (WARP_SIZE, local_size),
                           in_dtype,
                           align=128,
                           offset_factor=16,
                           scope="warp")
        C = T.match_buffer(c, (WARP_SIZE, local_size_out),
                           out_dtype,
                           align=128,
                           offset_factor=16,
                           scope="warp")

        with T.block("root"):
            T.reads(
                C[0:WARP_SIZE, 0:local_size_out],
                A[0:WARP_SIZE, 0:local_size],
                B[0:WARP_SIZE, 0:local_size],
            )
            T.writes(C[0:WARP_SIZE, 0:local_size_out])
            tx = T.env_thread("threadIdx.x")
            T.launch_thread(tx, WARP_SIZE)

            T.evaluate(
                T.ptx_mma(
                    mma_prefix,
                    "row",
                    "col",
                    in_dtype_abbrv,
                    in_dtype_abbrv,
                    out_dtype_abbrv,
                    A.data,
                    A.elem_offset + tx * lift(local_size),
                    B.data,
                    B.elem_offset + tx * lift(local_size),
                    C.data,
                    C.elem_offset + tx * lift(local_size_out),
                    False,
                    dtype=out_dtype,
                ))

            T.evaluate(
                T.ptx_mma(
                    mma_prefix,
                    "row",
                    "col",
                    in_dtype_abbrv,
                    in_dtype_abbrv,
                    out_dtype_abbrv,
                    A.data,
                    A.elem_offset + tx * lift(local_size),
                    B.data,
                    B.elem_offset + tx * lift(local_size) +
                    lift(local_size) // 2,
                    C.data,
                    C.elem_offset + tx * lift(local_size_out) +
                    lift(local_size_out) // 2,
                    False,
                    dtype=out_dtype,
                ))
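WARP_SIZE, local_size, in_dtype, mma_prefix, lift, and the other free names are not defined in this listing; mma_sync_impl is evidently the body of an intrinsic-generator function that closes over them. Such a description/implementation pair is typically registered as a tensor intrinsic, sketched here with assumed names:

from tvm.tir import TensorIntrin

# mma_sync_desc is the matching computation description (not shown above).
TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl)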
Example #13
def duplicate_reads() -> None:
    A = T.alloc_buffer((128, 128), "float32")
    with T.block([16, 16]) as [vi, vj]:
        T.reads(A[0:8, 0:8])
        T.reads(A[0:16, 0:16])  # error
        T.evaluate(1.0)
Example #14
def GmmCuda2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128),
                                                                "float32"],
             Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
    Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
    X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"):
            for i0_2_i1_2_i2_2_fused in T.thread_binding(128,
                                                         thread="threadIdx.x"):
                for i1_3_init, i2_4_init in T.grid(4, 2):
                    with T.block("Z_init"):
                        b = T.axis.spatial(1, 0)
                        i = T.axis.spatial(
                            128, i0_0_i1_0_i2_0_fused // 4 * 32 +
                            i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init)
                        j = T.axis.spatial(
                            128, i0_0_i1_0_i2_0_fused % 4 * 32 +
                            i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init)
                        T.reads()
                        T.writes(Z_local[b, i, j])
                        Z_local[b, i, j] = T.float32(0)
                for i3_0 in T.serial(4):
                    for ax0_ax1_ax2_fused_0 in T.serial(4):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(
                                128, thread="threadIdx.x"):
                            for ax0_ax1_ax2_fused_2 in T.vectorized(2):
                                with T.block("X_shared"):
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(
                                        128, i0_0_i1_0_i2_0_fused // 4 * 32 +
                                        (ax0_ax1_ax2_fused_0 * 256 +
                                         ax0_ax1_ax2_fused_1 * 2 +
                                         ax0_ax1_ax2_fused_2) // 32)
                                    v2 = T.axis.spatial(
                                        128, i3_0 * 32 +
                                        (ax0_ax1_ax2_fused_0 * 256 +
                                         ax0_ax1_ax2_fused_1 * 2 +
                                         ax0_ax1_ax2_fused_2) % 32)
                                    T.reads(X[v0, v1, v2])
                                    T.writes(X_shared[v0, v1, v2])
                                    X_shared[v0, v1, v2] = X[v0, v1, v2]
                    for ax0_ax1_ax2_fused_0 in T.serial(8):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(
                                128, thread="threadIdx.x"):
                            with T.block("Y_shared"):
                                v0 = T.axis.spatial(1, 0)
                                v1 = T.axis.spatial(
                                    128,
                                    i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 +
                                                 ax0_ax1_ax2_fused_1) // 32)
                                v2 = T.axis.spatial(
                                    128, i0_0_i1_0_i2_0_fused % 4 * 32 +
                                    (ax0_ax1_ax2_fused_0 * 128 +
                                     ax0_ax1_ax2_fused_1) % 32)
                                T.reads(Y[v0, v1, v2])
                                T.writes(Y_shared[v0, v1, v2])
                                Y_shared[v0, v1, v2] = Y[v0, v1, v2]
                    for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(
                            1, 1, 4, 1, 32, 1, 1, 2):
                        with T.block("Z_update"):
                            b = T.axis.spatial(1, 0)
                            i = T.axis.spatial(
                                128, i0_0_i1_0_i2_0_fused // 4 * 32 +
                                i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3)
                            j = T.axis.spatial(
                                128, i0_0_i1_0_i2_0_fused % 4 * 32 +
                                i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4)
                            k = T.axis.reduce(128, i3_0 * 32 + i3_2)
                            T.block_attr({
                                "meta_schedule.thread_extent_low_inclusive":
                                1024,
                                "meta_schedule.thread_extent_high_inclusive":
                                1024,
                            })
                            T.reads(Z_local[b, i, j], X_shared[b, i, k],
                                    Y_shared[b, k, j])
                            T.writes(Z_local[b, i, j])
                            Z_local[b, i, j] = Z_local[
                                b, i,
                                j] + X_shared[b, i, k] * Y_shared[b, k, j]
                for ax0, ax1, ax2 in T.grid(1, 4, 2):
                    with T.block("Z_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(
                            128, i0_0_i1_0_i2_0_fused // 4 * 32 +
                            i0_2_i1_2_i2_2_fused // 16 * 4 + ax1)
                        v2 = T.axis.spatial(
                            128, i0_0_i1_0_i2_0_fused % 4 * 32 +
                            i0_2_i1_2_i2_2_fused % 16 * 2 + ax2)
                        T.reads(Z_local[v0, v1, v2])
                        T.writes(Z[v0, v1, v2])
                        Z[v0, v1, v2] = Z_local[v0, v1, v2]
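The meta_schedule.thread_extent_low_inclusive and _high_inclusive attributes on "Z_update" bound the legal threadIdx extents for this candidate; they are consulted by meta-schedule's GPU verification postprocessor rather than by codegen. A minimal sketch of instantiating that check:

from tvm import meta_schedule as ms

# VerifyGPUCode rejects candidates whose thread extents or resource usage
# violate the target constraints (assumed to read the attrs above).
postproc = ms.postproc.VerifyGPUCode()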
Example #15
 def main(
     placeholder: T.Buffer[(1, 384), "int64"],
     placeholder_1: T.Buffer[(30522, 768), "float32"],
     placeholder_2: T.Buffer[(1, 384, 768), "float32"],
     T_add: T.Buffer[(1, 384, 768), "float32"],
 ) -> None:
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     compile_engine_const = T.alloc_buffer([], dtype="int64")
     T_less = T.alloc_buffer([1, 384], dtype="bool")
     compile_engine_const_1 = T.alloc_buffer([], dtype="int64")
     T_add_1 = T.alloc_buffer([1, 384], dtype="int64")
     T_where = T.alloc_buffer([1, 384], dtype="int64")
     T_take = T.alloc_buffer([1, 384, 768], dtype="float32")
     with T.block("compile_engine_const"):
         vi = T.axis.spatial(1, 0)
         T.reads()
         T.writes(compile_engine_const[()])
         compile_engine_const[()] = T.int64(0)
     for i0, i1 in T.grid(1, 384):
         with T.block("T_less"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(placeholder[ax0, ax1], compile_engine_const[()])
             T.writes(T_less[ax0, ax1])
             T_less[ax0,
                    ax1] = placeholder[ax0, ax1] < compile_engine_const[()]
     with T.block("compile_engine_const_1"):
         vi = T.axis.spatial(1, 0)
         T.reads()
         T.writes(compile_engine_const_1[()])
         compile_engine_const_1[()] = T.int64(30522)
     for i0, i1 in T.grid(1, 384):
         with T.block("T_add"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(placeholder[ax0, ax1], compile_engine_const_1[()])
             T.writes(T_add_1[ax0, ax1])
             T_add_1[ax0,
                     ax1] = placeholder[ax0,
                                        ax1] + compile_engine_const_1[()]
     for i0, i1 in T.grid(1, 384):
         with T.block("T_where"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(T_less[ax0, ax1], T_add_1[ax0, ax1], placeholder[ax0,
                                                                      ax1])
             T.writes(T_where[ax0, ax1])
             T_where[ax0, ax1] = T.Select(
                 T.cast(T_less[ax0, ax1], "int32") != 0, T_add_1[ax0, ax1],
                 placeholder[ax0, ax1])
     for i0, i1, i2 in T.grid(1, 384, 768):
         with T.block("T_take"):
             ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
             T.reads(
                 placeholder_1[T.min(T.max(T.int64(0), T_where[
                     ax0, ax1]), T.int64(30521)), ax2],
                 T_where[ax0, ax1],
             )
             T.writes(T_take[ax0, ax1, ax2])
             T_take[ax0, ax1, ax2] = placeholder_1[
                 T.min(T.max(T.int64(0), T_where[ax0,
                                                 ax1]), T.int64(30521)),
                 ax2]
     for i0, i1, i2 in T.grid(1, 384, 768):
         with T.block("T_add_1"):
             ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
             T.reads(T_take[ax0, ax1, ax2], placeholder_2[ax0, ax1, ax2])
             T.writes(T_add[ax0, ax1, ax2])
             T_add[ax0, ax1,
                   ax2] = T_take[ax0, ax1, ax2] + placeholder_2[ax0, ax1,
                                                                ax2]
Example #16
 def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024),
                                                            "float32"],
          C: T.Buffer[(1024, 1024), "float32"]) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
     # body
     # with T.block("root")
     for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"):
         for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"):
             for threadIdx_y in T.thread_binding(2, thread="threadIdx.y"):
                 for threadIdx_x in T.thread_binding(2,
                                                     thread="threadIdx.x"):
                     for k_0 in T.serial(32):
                         with T.block():
                             T.reads(
                                 A[blockIdx_y * 32:blockIdx_y * 32 + 32,
                                   k_0 * 32:k_0 * 32 + 32],
                                 B[k_0 * 32:k_0 * 32 + 32,
                                   blockIdx_x * 32:blockIdx_x * 32 + 32])
                             T.writes(
                                 C[blockIdx_y * 32:blockIdx_y * 32 + 32,
                                   blockIdx_x * 32:blockIdx_x * 32 + 32])
                             A_shared = T.alloc_buffer([1024, 1024],
                                                       dtype="float32",
                                                       scope="shared")
                             B_shared = T.alloc_buffer([1024, 1024],
                                                       dtype="float32",
                                                       scope="shared")
                             for ax0_ax1_fused_0 in T.serial(64):
                                 for ax0_ax1_fused_3 in T.vectorized(4):
                                     with T.block("A_shared"):
                                         T.reads(A[blockIdx_y * 32 +
                                                   (ax0_ax1_fused_0 * 16 +
                                                    threadIdx_y * 8 +
                                                    threadIdx_x * 4 +
                                                    ax0_ax1_fused_3) // 32,
                                                   k_0 * 32 +
                                                   (ax0_ax1_fused_0 * 16 +
                                                    threadIdx_y * 8 +
                                                    threadIdx_x * 4 +
                                                    ax0_ax1_fused_3) % 32])
                                         T.writes(A_shared[
                                             blockIdx_y * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) // 32,
                                             k_0 * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) % 32])
                                         T.block_attr({
                                             "tir.manifest_shared_memory_local_stage":
                                             1
                                         })
                                         A_shared[
                                             blockIdx_y * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) // 32,
                                             k_0 * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) % 32] = A[
                                                  blockIdx_y * 32 +
                                                  (ax0_ax1_fused_0 * 16 +
                                                   threadIdx_y * 8 +
                                                   threadIdx_x * 4 +
                                                   ax0_ax1_fused_3) // 32,
                                                  k_0 * 32 +
                                                  (ax0_ax1_fused_0 * 16 +
                                                   threadIdx_y * 8 +
                                                   threadIdx_x * 4 +
                                                   ax0_ax1_fused_3) % 32]
                             for ax0_ax1_fused_0 in T.serial(64):
                                 for ax0_ax1_fused_3 in T.vectorized(4):
                                     with T.block("B_shared"):
                                         T.reads(B[k_0 * 32 +
                                                   (ax0_ax1_fused_0 * 16 +
                                                    threadIdx_y * 8 +
                                                    threadIdx_x * 4 +
                                                    ax0_ax1_fused_3) // 32,
                                                   blockIdx_x * 32 +
                                                   (ax0_ax1_fused_0 * 16 +
                                                    threadIdx_y * 8 +
                                                    threadIdx_x * 4 +
                                                    ax0_ax1_fused_3) % 32])
                                         T.writes(B_shared[
                                             k_0 * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) // 32,
                                             blockIdx_x * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) % 32])
                                         T.block_attr({
                                             "tir.manifest_shared_memory_local_stage":
                                             1
                                         })
                                         B_shared[
                                             k_0 * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) // 32,
                                             blockIdx_x * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) % 32] = B[
                                                  k_0 * 32 +
                                                  (ax0_ax1_fused_0 * 16 +
                                                   threadIdx_y * 8 +
                                                   threadIdx_x * 4 +
                                                   ax0_ax1_fused_3) // 32,
                                                  blockIdx_x * 32 +
                                                  (ax0_ax1_fused_0 * 16 +
                                                   threadIdx_y * 8 +
                                                   threadIdx_x * 4 +
                                                   ax0_ax1_fused_3) % 32]
                             for k_1, i_2, j_2, k_2 in T.grid(
                                     2, 16, 16, 16):
                                 with T.block("C"):
                                     T.reads(
                                         A_shared[blockIdx_y * 32 +
                                                  threadIdx_y * 16 + i_2,
                                                  k_0 * 32 + k_1 * 16 +
                                                  k_2],
                                         B_shared[k_0 * 32 + k_1 * 16 + k_2,
                                                  blockIdx_x * 32 +
                                                  threadIdx_x * 16 + j_2])
                                     T.writes(C[blockIdx_y * 32 +
                                                threadIdx_y * 16 + i_2,
                                                blockIdx_x * 32 +
                                                threadIdx_x * 16 + j_2])
                                     if k_0 * 32 + k_1 * 16 + k_2 == 0:
                                         C[blockIdx_y * 32 +
                                           threadIdx_y * 16 + i_2,
                                           blockIdx_x * 32 +
                                           threadIdx_x * 16 +
                                           j_2] = T.float32(0)
                                     C[
                                         blockIdx_y * 32 +
                                         threadIdx_y * 16 + i_2,
                                         blockIdx_x * 32 +
                                         threadIdx_x * 16 + j_2] = C[
                                             blockIdx_y * 32 +
                                             threadIdx_y * 16 + i_2,
                                             blockIdx_x * 32 + threadIdx_x *
                                             16 + j_2] + A_shared[
                                                 blockIdx_y * 32 +
                                                 threadIdx_y * 16 + i_2,
                                                 k_0 * 32 + k_1 * 16 +
                                                 k_2] * B_shared[
                                                     k_0 * 32 + k_1 * 16 +
                                                     k_2, blockIdx_x * 32 +
                                                     threadIdx_x * 16 + j_2]
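The tir.manifest_shared_memory_local_stage annotations mark the A_shared and B_shared copies for the ManifestSharedMemoryLocalStage pass, which inserts an intermediate local-register stage into each global-to-shared copy. A sketch of running it, with the kernel name assumed:

import tvm

mod = tvm.IRModule({"main": matmul_gpu})  # hypothetical name for the kernel above
mod = tvm.tir.transform.ManifestSharedMemoryLocalStage()(mod)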
Example #17
def opaque_access(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), dtype="float16")
    B = T.match_buffer(b, (128, 128), dtype="float16")
    C = T.match_buffer(c, (128, 128), dtype="float16")
    D = T.match_buffer(d, (128, 128), dtype="float16")

    with T.block([128, 128], "load_store") as [vi, vj]:
        T.reads(A[vi, vj])
        T.writes(D[vi, vj])
        D.data[vi * 128 + vj] = T.load("float16", A.data, vi * 128 + vj)
    with T.block([8, 8], "opaque") as [vi, vj]:
        T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        T.evaluate(
            T.tvm_load_matrix_sync(
                B.data,
                16,
                16,
                16,
                vi * 8 + vj,
                T.tvm_access_ptr(
                    T.type_annotation(dtype="float16"),
                    A.data,
                    vi * 2048 + vj * 16,
                    128,
                    1,
                    dtype="handle",
                ),
                128,
                "row_major",
                dtype="handle",
            )
        )
    with T.block([8, 8], "match_buffer") as [vi, vj]:
        T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
        A0 = T.match_buffer(
            A[
                vi * 16 : vi * 16 + 16,
                vj * 16 : vj * 16 + 16,
            ],
            (16, 16),
            "float16",
            strides=[128, 1],
            offset_factor=1,
        )
        C0 = T.match_buffer(
            C[
                vi * 16 : vi * 16 + 16,
                vj * 16 : vj * 16 + 16,
            ],
            (16, 16),
            "float16",
            strides=[128, 1],
            offset_factor=1,
        )
        T.evaluate(
            T.tvm_load_matrix_sync(
                C0.data,
                16,
                16,
                16,
                vi * 8 + vj,
                T.tvm_access_ptr(
                    T.type_annotation(dtype="float16"),
                    A0.data,
                    A0.elem_offset,
                    A0.strides[0],
                    1,
                    dtype="handle",
                ),
                128,
                "row_major",
                dtype="handle",
            )
        )
Example #18
def lowered_single_reduction_loop_with_block_predicate(
        A: T.Buffer[(256, 256), "float32"],
        T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    T_softmax_maxelem_shared = T.alloc_buffer([256],
                                              dtype="float32",
                                              scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256],
                                             dtype="float32",
                                             scope="shared")
    cross_thread_0 = T.alloc_buffer([1],
                                    dtype="float32",
                                    strides=[1],
                                    scope="local")
    in_thread_0 = T.alloc_buffer([1],
                                 dtype="float32",
                                 strides=[1],
                                 scope="local")
    cross_thread_1 = T.alloc_buffer([1],
                                    dtype="float32",
                                    strides=[1],
                                    scope="local")
    in_thread_1 = T.alloc_buffer([1],
                                 dtype="float32",
                                 strides=[1],
                                 scope="local")
    for i0 in T.serial(256):
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.float32(-3.4028234663852886e38)
                with T.block("T_softmax_maxelem_in_thread"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_1, k], in_thread_0[0])
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k])
                with T.block("T_softmax_maxelem_cross_thread"):
                    T.reads(in_thread_0[0])
                    T.writes(cross_thread_0[0])
                    T.attr(
                        T.comm_reducer(lambda x, y: T.max(x, y),
                                       [T.float32(-3.4028234663852886e38)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_0[0],
                            True,
                            cross_thread_0.data,
                            ax1_1,
                            dtype="handle",
                        ))
                with T.block("T_softmax_maxelem_write_back"):
                    i0_2 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_0[0])
                    T.writes(T_softmax_maxelem_shared[i0_2])
                    T_softmax_maxelem_shared[i0_2] = cross_thread_0[0]
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = T.float32(0)
                with T.block("T_softmax_expsum_in_thread"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3],
                            in_thread_1[0])
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = in_thread_1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3],
                        dtype="float32")
                with T.block("T_softmax_expsum_cross_thread"):
                    T.reads(in_thread_1[0])
                    T.writes(cross_thread_1[0])
                    T.attr(
                        T.comm_reducer(lambda x_1, y_1: x_1 + y_1,
                                       [T.float32(0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_1[0],
                            True,
                            cross_thread_1.data,
                            ax1_1,
                            dtype="handle",
                        ))
                with T.block("T_softmax_expsum_write_back"):
                    i0_4 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_1[0])
                    T.writes(T_softmax_expsum_shared[i0_4])
                    T_softmax_expsum_shared[i0_4] = cross_thread_1[0]
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_5, i1], T_softmax_maxelem_shared[i0_5],
                            T_softmax_expsum_shared[i0_5])
                    T.writes(T_softmax_norm[i0_5, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (
                        T.exp(A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
                              dtype="float32") / T_softmax_expsum_shared[i0_5])
Example #19
 def main(
     X: T.Buffer[(128, 128), "int8"],
     W: T.Buffer[(128, 128), "int8"],
     compute: T.Buffer[(128, 128), "int32"],
 ) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     compute_local = T.alloc_buffer([128, 128],
                                    dtype="int32",
                                    scope="local")
     X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
     W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
     for i0_0_i1_0_fused in T.thread_binding(16, thread="blockIdx.x"):
         for i0_1_i1_1_fused in T.thread_binding(2, thread="vthread.x"):
             for i0_2_i1_2_fused in T.thread_binding(2,
                                                     thread="threadIdx.x"):
                 for i0_3_init, i1_3_init, i0_4_init in T.grid(4, 16, 4):
                     with T.block("compute_o_init"):
                         i = T.axis.spatial(
                             128, i0_0_i1_0_fused // 2 * 16 +
                             i0_3_init * 4 + i0_4_init)
                         j = T.axis.spatial(
                             128,
                             i0_0_i1_0_fused % 2 * 64 +
                             i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 +
                             i1_3_init,
                         )
                         T.reads()
                         T.writes(compute_local[i, j])
                         T.block_attr(
                             {"meta_schedule.auto_tensorize": "dp4a"})
                         with T.block("compute_init"):
                             T.reads()
                             T.writes(compute_local[i, j])
                             compute_local[i, j] = 0
                 for i2_0_0 in T.serial(2):
                     for ax0_ax1_fused in T.serial(1024):
                         with T.block("X_shared"):
                             v0 = T.axis.spatial(
                                 128, i0_0_i1_0_fused // 2 * 16 +
                                 ax0_ax1_fused // 64)
                             v1 = T.axis.spatial(
                                 128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                             T.reads(X[v0, v1])
                             T.writes(X_shared[v0, v1])
                             T.block_attr(
                                 {"meta_schedule.cooperative_fetch": 4})
                             X_shared[v0, v1] = X[v0, v1]
                     for ax0_ax1_fused in T.serial(4096):
                         with T.block("W_shared"):
                             v0 = T.axis.spatial(
                                 128, i0_0_i1_0_fused % 2 * 64 +
                                 ax0_ax1_fused // 64)
                             v1 = T.axis.spatial(
                                 128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                             T.reads(W[v0, v1])
                             T.writes(W_shared[v0, v1])
                             T.block_attr(
                                 {"meta_schedule.cooperative_fetch": 1})
                             W_shared[v0, v1] = W[v0, v1]
                     for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(
                             2, 4, 16, 8, 4, 1):
                         with T.block("compute_o_update"):
                             i = T.axis.spatial(
                                 128, i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 +
                                 i0_4)
                             j = T.axis.spatial(
                                 128,
                                 i0_0_i1_0_fused % 2 * 64 +
                                 i0_1_i1_1_fused * 32 +
                                 i0_2_i1_2_fused * 16 + i1_3,
                             )
                             k_o = T.axis.reduce(
                                 32, i2_0_0 * 16 + i2_0_1 * 8 + i2_0_2)
                             T.reads(
                                 compute_local[i, j],
                                 X_shared[i, k_o * 4:k_o * 4 + 4],
                                 W_shared[j, k_o * 4:k_o * 4 + 4],
                             )
                             T.writes(compute_local[i, j])
                             A = T.match_buffer(
                                 X_shared[i, k_o * 4:k_o * 4 + 4],
                                 [4],
                                 dtype="int8",
                                 scope="shared",
                                 align=4,
                                 offset_factor=1,
                             )
                             B = T.match_buffer(
                                 W_shared[j, k_o * 4:k_o * 4 + 4],
                                 [4],
                                 dtype="int8",
                                 scope="shared",
                                 align=4,
                                 offset_factor=1,
                             )
                             C = T.match_buffer(
                                 compute_local[i, j],
                                 [1],
                                 dtype="int32",
                                 scope="local",
                                 align=4,
                                 offset_factor=1,
                             )
                             C[0] = C[0] + T.call_pure_extern(
                                 "__dp4a",
                                 A[T.ramp(0, 1, 4)],
                                 B[T.ramp(0, 1, 4)],
                                 0,
                                 dtype="int32",
                             )
                 for ax0, ax1 in T.grid(16, 16):
                     with T.block("compute_local"):
                         v0 = T.axis.spatial(
                             128, i0_0_i1_0_fused // 2 * 16 + ax0)
                         v1 = T.axis.spatial(
                             128,
                             i0_0_i1_0_fused % 2 * 64 +
                             i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 +
                             ax1,
                         )
                         T.reads(compute_local[v0, v1])
                         T.writes(compute[v0, v1])
                         compute[v0, v1] = compute_local[v0, v1]
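Here the inner reduction of "compute_o_update" has already been replaced by match_buffer views over 4-element slices plus a __dp4a extern call. Starting from the un-tensorized form (see the next example), this is the kind of body sch.tensorize produces once a 4-element dot-product intrinsic is registered; a sketch with the registered name assumed:

import tvm

sch = tvm.tir.Schedule(mod)
k_inner = sch.get_loops(sch.get_block("compute"))[-1]  # innermost reduction loop
sch.tensorize(k_inner, "dp4a")  # "dp4a" must name a registered TensorIntrin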
Example #20
 def main(
     X: T.Buffer[(128, 128), "int8"],
     W: T.Buffer[(128, 128), "int8"],
     compute: T.Buffer[(128, 128), "int32"],
 ) -> None:
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     compute_local = T.alloc_buffer([128, 128],
                                    dtype="int32",
                                    scope="local")
     X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
     W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
     for i0_0_i1_0_fused in T.thread_binding(16, thread="blockIdx.x"):
         for i0_1_i1_1_fused in T.thread_binding(2, thread="vthread.x"):
             for i0_2_i1_2_fused in T.thread_binding(2,
                                                     thread="threadIdx.x"):
                 for i2_0_0 in T.serial(2):
                     for ax0_ax1_fused in T.serial(1024):
                         with T.block("X_shared"):
                             v0 = T.axis.spatial(
                                 128, i0_0_i1_0_fused // 2 * 16 +
                                 ax0_ax1_fused // 64)
                             v1 = T.axis.spatial(
                                 128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                             T.reads(X[v0, v1])
                             T.writes(X_shared[v0, v1])
                             T.block_attr(
                                 {"meta_schedule.cooperative_fetch": 4})
                             X_shared[v0, v1] = X[v0, v1]
                     for ax0_ax1_fused in T.serial(4096):
                         with T.block("W_shared"):
                             v0 = T.axis.spatial(
                                 128, i0_0_i1_0_fused % 2 * 64 +
                                 ax0_ax1_fused // 64)
                             v1 = T.axis.spatial(
                                 128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                             T.reads(W[v0, v1])
                             T.writes(W_shared[v0, v1])
                             T.block_attr(
                                 {"meta_schedule.cooperative_fetch": 1})
                             W_shared[v0, v1] = W[v0, v1]
                     for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(
                             2, 4, 16, 8, 4, 1):
                         with T.block("compute_o"):
                             i = T.axis.spatial(
                                 128, i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 +
                                 i0_4)
                             j = T.axis.spatial(
                                 128,
                                 i0_0_i1_0_fused % 2 * 64 +
                                 i0_1_i1_1_fused * 32 +
                                 i0_2_i1_2_fused * 16 + i1_3,
                             )
                             k_o = T.axis.reduce(
                                 32, i2_0_0 * 16 + i2_0_1 * 8 + i2_0_2)
                             T.reads(
                                 X_shared[i, k_o * 4:k_o * 4 + 4],
                                 W_shared[j, k_o * 4:k_o * 4 + 4],
                             )
                             T.writes(compute_local[i, j])
                             T.block_attr(
                                 {"meta_schedule.auto_tensorize": "dp4a"})
                             with T.init():
                                 with T.block("compute_init"):
                                     T.reads()
                                     T.writes(compute_local[i, j])
                                     compute_local[i, j] = 0
                             for i2_1 in T.serial(4):
                                 with T.block("compute"):
                                     k = T.axis.reduce(4, i2_1)
                                     T.reads(
                                         compute_local[i, j],
                                         X_shared[i, k_o * 4 + k],
                                         W_shared[j, k_o * 4 + k],
                                     )
                                     T.writes(compute_local[i, j])
                                     T.block_attr({
                                         "meta_schedule.tiling_structure":
                                         "SSSRRSRS"
                                     })
                                     compute_local[
                                         i,
                                         j] = compute_local[i, j] + T.cast(
                                             X_shared[i, k_o * 4 + k],
                                             "int32") * T.cast(
                                                 W_shared[j, k_o * 4 + k],
                                                 "int32")
                 for ax0, ax1 in T.grid(16, 16):
                     with T.block("compute_local"):
                         v0 = T.axis.spatial(
                             128, i0_0_i1_0_fused // 2 * 16 + ax0)
                         v1 = T.axis.spatial(
                             128,
                             i0_0_i1_0_fused % 2 * 64 +
                             i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 +
                             ax1,
                         )
                         T.reads(compute_local[v0, v1])
                         T.writes(compute[v0, v1])
                         compute[v0, v1] = compute_local[v0, v1]
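In this version block "compute" still carries the scalar multiply-accumulate; the meta_schedule.auto_tensorize: "dp4a" annotation only marks "compute_o" for later rewriting. Within meta-schedule that rewriting is done by the RewriteTensorize postprocessor, sketched below:

from tvm import meta_schedule as ms

# RewriteTensorize walks blocks annotated with "meta_schedule.auto_tensorize"
# and applies the named tensor intrinsic to each (assumed behavior).
postproc = ms.postproc.RewriteTensorize()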
Example #21
 def main(
     placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
     placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
     conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
 ) -> None:
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     for (
             i0_0,
             i1_0,
             i2_0,
             i3_0,
             i4_0_0,
             i0_1,
             i1_1,
             i2_1,
             i3_1,
             i4_0_1,
             i5_0,
             i6_0,
             i7_0,
             i8_0,
             i9_0_0,
             i0_2,
             i1_2,
             i2_2,
             i3_2,
             i4_0_2,
             i5_1,
             i6_1,
             i7_1,
             i8_1,
             i9_0_1,
             i0_3,
             i1_3,
             i2_3,
             i3_3,
             i4_0_3,
     ) in T.grid(
             1,
             1,
             2,
             1,
             1,
             1,
             4,
             1,
             14,
             1,
             1,
             1,
             4,
             1,
             1,
             1,
             4,
             7,
             1,
             1,
             1,
             1,
             1,
             4,
             1,
             1,
             1,
             4,
             4,
             1,
     ):
         with T.block("conv2d_NCHWc_int8_o"):
             n = T.axis.spatial(1, 0)
             oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2)
             oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3)
             ow = T.axis.spatial(56, i3_1 * 4 + i3_3)
             oc_block_o = T.axis.spatial(1, 0)
             kh = T.axis.reduce(1, 0)
             kw = T.axis.reduce(1, 0)
             ic_outer, ic_f_inner = T.axis.remap("RR", [i7_0, i8_1])
             ic_s_inner_o = T.axis.reduce(1, 0)
             T.reads(
                 placeholder[n, ic_outer, oh + kh, ow + kw,
                             ic_f_inner * 4:ic_f_inner * 4 + 4],
                 placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16,
                               0:4],
             )
             T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
             T.block_attr({"meta_schedule.auto_tensorize": "dot_16x4_vnni"})
             with T.init():
                 for i4_1 in T.serial(16):
                     with T.block("conv2d_NCHWc_int8_init"):
                         oc_block_init = T.axis.spatial(16, i4_1)
                         T.reads()
                         T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow,
                                                    oc_block_init])
                         conv2d_NCHWc_int8[n, oc_chunk, oh, ow,
                                           oc_block_init] = 0
             for i4_1, i9_1 in T.grid(16, 4):
                 with T.block("conv2d_NCHWc_int8"):
                     oc_block, ic_s_inner = T.axis.remap("SR", [i4_1, i9_1])
                     T.reads(
                         conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block],
                         placeholder[n, ic_outer, oh + kh, ow + kw,
                                     ic_f_inner * 4 + ic_s_inner],
                         placeholder_1[oc_chunk, ic_outer, kh, kw,
                                       ic_f_inner, oc_block, ic_s_inner],
                     )
                     T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow,
                                                oc_block])
                     T.block_attr(
                         {"meta_schedule.tiling_structure": "SSRSRS"})
                     conv2d_NCHWc_int8[
                         n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
                             n, oc_chunk, oh, ow, oc_block] + T.cast(
                                 placeholder[n, ic_outer, oh + kh, ow + kw,
                                             ic_f_inner * 4 + ic_s_inner],
                                 "int32",
                             ) * T.cast(
                                 placeholder_1[oc_chunk, ic_outer, kh, kw,
                                               ic_f_inner, oc_block,
                                               ic_s_inner],
                                 "int32",
                             )
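
Example #21's outer block is annotated with meta_schedule.auto_tensorize = "dot_16x4_vnni" and its inner block with the tiling structure "SSRSRS". A hedged sketch of inspecting those blocks with tvm.tir.Schedule, assuming the main above is pasted in as a @T.prim_func (tvm is the only dependency):

import tvm

# Assumes `main` is the @T.prim_func shown above.
mod = tvm.IRModule({"main": main})
sch = tvm.tir.Schedule(mod)
outer = sch.get_block("conv2d_NCHWc_int8_o")
print(sch.get(outer))  # the printed block includes its T.block_attr annotations
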
Example #22
 def main(
     placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
     placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
     conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
 ) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0 in T.grid(
             1, 1, 2, 1, 1, 1, 4, 1, 14, 1, 1, 1):
         for i1_2_init, i2_2_init, i2_3_init, i3_3_init in T.grid(
                 4, 7, 4, 4):
             with T.block("conv2d_NCHWc_int8_o_init"):
                 n = T.axis.spatial(1, 0)
                 oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2_init)
                 oh = T.axis.spatial(56,
                                     i2_0 * 28 + i2_2_init * 4 + i2_3_init)
                 ow = T.axis.spatial(56, i3_1 * 4 + i3_3_init)
                 oc_block_o = T.axis.spatial(1, 0)
                 T.reads()
                 T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
                 for i4_1 in T.vectorized(16):
                     with T.block("conv2d_NCHWc_int8_init"):
                         oc_block_init = T.axis.spatial(16, i4_1)
                         T.reads()
                         T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow,
                                                    oc_block_init])
                         conv2d_NCHWc_int8[n, oc_chunk, oh, ow,
                                           oc_block_init] = 0
         for (i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1,
              i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3,
              i4_0_3) in T.grid(4, 1, 1, 1, 4, 7, 1, 1, 1, 1, 1, 4, 1, 1,
                                1, 4, 4, 1):
             with T.block("conv2d_NCHWc_int8_o_update"):
                 n = T.axis.spatial(1, 0)
                 oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2)
                 oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3)
                 ow = T.axis.spatial(56, i3_1 * 4 + i3_3)
                 oc_block_o = T.axis.spatial(1, 0)
                 kh = T.axis.reduce(1, 0)
                 kw = T.axis.reduce(1, 0)
                 ic_outer, ic_f_inner = T.axis.remap("RR", [i7_0, i8_1])
                 ic_s_inner_o = T.axis.reduce(1, 0)
                 T.reads(
                     conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16],
                     placeholder[n, ic_outer, oh + kh, ow + kw,
                                 ic_f_inner * 4:ic_f_inner * 4 + 4],
                     placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner,
                                   0:16, 0:4],
                 )
                 T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
                 A = T.match_buffer(
                     placeholder[n, ic_outer, oh + kh, ow + kw,
                                 ic_f_inner * 4:ic_f_inner * 4 + 4],
                     [4],
                     dtype="uint8",
                     offset_factor=1,
                 )
                 B = T.match_buffer(
                     placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner,
                                   0:16, 0:4],
                     [16, 4],
                     dtype="int8",
                     offset_factor=1,
                 )
                 C = T.match_buffer(
                     conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16],
                     [16],
                     dtype="int32",
                     offset_factor=1,
                 )
                 A_u8x4 = A.vload([0], "uint8x4")
                 A_i32 = T.reinterpret(A_u8x4, dtype="int32")
                 B_i8x64 = B.vload([0, 0], dtype="int8x64")
                 B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
                 C[T.ramp(0, 1, 16)] = C[T.ramp(0, 1, 16)] + T.call_llvm_pure_intrin(
                     T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
                     T.uint32(0),
                     T.broadcast(0, 16),
                     T.broadcast(A_i32, 16),
                     B_i32x16,
                     dtype="int32x16",
                 )
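
Example #22 is the tensorized form of #21: four uint8 values of A are reinterpreted as one int32 lane, 64 int8 values of B as 16 int32 lanes, and llvm.x86.avx512.vpdpbusd.512 performs 16 independent 4-element dot products. A self-contained NumPy sketch of that semantics (vpdpbusd accumulates without saturation, so plain int32 math matches):

import numpy as np

rng = np.random.default_rng(0)
a = rng.integers(0, 256, size=4, dtype=np.uint8)           # A fragment: 4 x uint8
b = rng.integers(-128, 128, size=(16, 4), dtype=np.int8)   # B fragment: 16x4 int8
c = np.zeros(16, dtype=np.int32)                           # C accumulator

# vpdpbusd lane j: c[j] += sum over k of a[k] * b[j, k], computed in int32
c += (a.astype(np.int32)[None, :] * b.astype(np.int32)).sum(axis=1)
print(c)
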
Example #23
def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256],
                                    dtype="float32")
    T_softmax_maxelem_shared = T.alloc_buffer([256],
                                              dtype="float32",
                                              scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256],
                                             dtype="float32",
                                             scope="shared")
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1],
                                         dtype="float32",
                                         strides=[1],
                                         scope="local")
    reduce_temp1 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    normal_reduce_temp1 = T.alloc_buffer([1],
                                         dtype="float32",
                                         strides=[1],
                                         scope="local")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_maxelem_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.min_value("float32")
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_maxelem_normal_reduction"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([A[i0_1, k], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = T.max(normal_reduce_temp0[0],
                                                   A[i0_1, k])
            with T.block("T_softmax_maxelem_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y),
                                   [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        ax0_1,
                        dtype="handle",
                    ))
            with T.block("T_softmax_maxelem_write_back"):
                i0_2 = T.axis.spatial(256, i0)
                T.reads([reduce_temp0[0]])
                T.writes([T_softmax_maxelem_shared[i0_2]])
                T_softmax_maxelem_shared[i0_2] = reduce_temp0[0]
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_expsum_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp1[0]])
                normal_reduce_temp1[0] = T.float32(0)
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_expsum_normal_reduction"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([
                        A[i0_3, k],
                        T_softmax_maxelem_shared[i0_3],
                        normal_reduce_temp1[0],
                    ])
                    T.writes([normal_reduce_temp1[0]])
                    normal_reduce_temp1[0] = normal_reduce_temp1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3],
                        dtype="float32")
            with T.block("T_softmax_expsum_cross_thread_reduction"):
                T.reads([normal_reduce_temp1[0]])
                T.writes([reduce_temp1[0]])
                T.attr(
                    T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp1[0],
                        True,
                        reduce_temp1.data,
                        ax0_1,
                        dtype="handle",
                    ))
            with T.block("T_softmax_expsum_write_back"):
                i0_4 = T.axis.spatial(256, i0)
                T.reads([reduce_temp1[0]])
                T.writes([T_softmax_expsum_shared[i0_4]])
                T_softmax_expsum_shared[i0_4] = reduce_temp1[0]
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads([
                        A[i0_5, i1],
                        T_softmax_maxelem_shared[i0_5],
                        T_softmax_expsum_shared[i0_5],
                    ])
                    T.writes([T_softmax_norm[i0_5, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (T.exp(
                        A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
                        dtype="float32",
                    ) / T_softmax_expsum_shared[i0_5])
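
lowered_softmax splits each 256-element row reduction into a serial per-thread part (normal_reduce_temp*) followed by a 32-thread tvm_thread_allreduce, first for the row max and then for the exp-sum. A NumPy reference for what the whole kernel computes:

import numpy as np

def softmax_ref(A: np.ndarray) -> np.ndarray:
    m = A.max(axis=1, keepdims=True)         # T_softmax_maxelem
    e = np.exp(A - m)
    return e / e.sum(axis=1, keepdims=True)  # divide by T_softmax_expsum

A = np.random.rand(256, 256).astype("float32")
out = softmax_ref(A)
assert np.allclose(out.sum(axis=1), 1.0, atol=1e-5)
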
Example #24
 def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     A = T.match_buffer(var_A, [512, 512], dtype="float32")
     B = T.match_buffer(var_B, [512, 512], dtype="float32")
     C = T.match_buffer(var_C, [512, 512], dtype="float32")
     # body
     # with T.block("root")
     C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
     A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
     B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
     for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
         for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
             for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.x"):
                 for i2_0 in T.serial(0, 1):
                     for ax0_ax1_fused_0 in T.serial(0, 32768):
                         for ax0_ax1_fused_1 in T.thread_binding(
                                 0, 8, thread="threadIdx.x"):
                             with T.block("A_shared"):
                                 v0 = T.axis.spatial(
                                     512,
                                     (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1)
                                     // 512)
                                 v1 = T.axis.spatial(
                                     512,
                                     (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1)
                                     % 512)
                                 T.reads([A[v0, v1]])
                                 T.writes([A_shared[v0, v1]])
                                 T.block_attr(
                                     {"meta_schedule.cooperative_fetch": 1})
                                 A_shared[v0, v1] = A[v0, v1]
                     for ax0_ax1_fused_0 in T.serial(0, 1024):
                         for ax0_ax1_fused_1 in T.thread_binding(
                                 0, 8, thread="threadIdx.x"):
                             for ax0_ax1_fused_2 in T.vectorized(0, 2):
                                 with T.block("B_shared"):
                                     v0 = T.axis.spatial(
                                         512, (ax0_ax1_fused_0 * 16 +
                                               ax0_ax1_fused_1 * 2 +
                                               ax0_ax1_fused_2) // 32)
                                     v1 = T.axis.spatial(
                                         512, i0_0_i1_0_fused * 32 +
                                         (ax0_ax1_fused_0 * 16 +
                                          ax0_ax1_fused_1 * 2 +
                                          ax0_ax1_fused_2) % 32)
                                     T.reads([B[v0, v1]])
                                     T.writes([B_shared[v0, v1]])
                                     T.block_attr({
                                         "meta_schedule.cooperative_fetch":
                                         2
                                     })
                                     B_shared[v0, v1] = B[v0, v1]
                     for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(
                             16, 2, 2, 32, 16, 2):
                         with T.block("C"):
                             i = T.axis.spatial(
                                 512,
                                 i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
                             j = T.axis.spatial(
                                 512, i0_0_i1_0_fused * 32 +
                                 i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4)
                             k = T.axis.reduce(512, i2_1 * 32 + i2_2)
                             T.reads([
                                 C_local[i, j], A_shared[i, k], B_shared[k,
                                                                         j]
                             ])
                             T.writes([C_local[i, j]])
                             with T.init():
                                 C_local[i, j] = T.float32(0)
                             C_local[i, j] = C_local[
                                 i, j] + A_shared[i, k] * B_shared[k, j]
                 for ax0, ax1 in T.grid(32, 4):
                     with T.block("C_local"):
                         v0 = T.axis.spatial(512,
                                             i0_1_i1_1_fused * 32 + ax0)
                         v1 = T.axis.spatial(
                             512, i0_0_i1_0_fused * 32 +
                             i0_2_i1_2_fused * 4 + ax1)
                         T.reads([C_local[v0, v1]])
                         T.writes([C[v0, v1]])
                         C[v0, v1] = C_local[v0, v1]
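
Example #24 is a 512x512 matmul scheduled for CUDA with cooperative fetching into A_shared/B_shared and a local accumulator. A hedged sketch of how such a fixture might be built and numerically checked, assuming the function is available as a @T.prim_func named main and a CUDA device is present:

import numpy as np
import tvm

# Assumes `main` is the @T.prim_func shown above and a CUDA GPU is available.
rt_mod = tvm.build(main, target="cuda")
dev = tvm.cuda(0)
a = tvm.nd.array(np.random.rand(512, 512).astype("float32"), dev)
b = tvm.nd.array(np.random.rand(512, 512).astype("float32"), dev)
c = tvm.nd.array(np.zeros((512, 512), dtype="float32"), dev)
rt_mod(a, b, c)
np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), rtol=1e-4)
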
Example #25
def tensorcore_gemm(a: T.handle, b: T.handle, c: T.handle) -> None:
    # match buffer
    A = T.match_buffer(a, [1024, 1024], "float16")
    B = T.match_buffer(b, [1024, 1024], "float16")
    C = T.match_buffer(c, [1024, 1024], "float32")

    # body
    for blockIdx_x in T.thread_binding(0, 16, "blockIdx.x"):
        for blockIdx_y in T.thread_binding(0, 8, "blockIdx.y"):
            with T.block([16, 8]) as [bx, by]:
                T.bind(bx, blockIdx_x)
                T.bind(by, blockIdx_y)
                shared_A = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                shared_B = T.alloc_buffer([1024, 1024], "float16", scope="shared")
                wmma_A = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_a")
                wmma_B = T.alloc_buffer([1024, 1024], "float16", scope="wmma.matrix_b")
                wmma_C = T.alloc_buffer([1024, 1024], "float32", scope="wmma.accumulator")
                for ty in T.thread_binding(0, 2, "threadIdx.y"):
                    for tz in T.thread_binding(0, 2, "threadIdx.z"):
                        for i, j in T.grid(2, 4):
                            with T.block([64, 64]) as [vi, vj]:
                                T.bind(vi, bx * 4 + ty * 2 + i)
                                T.bind(vj, by * 8 + tz * 4 + j)
                                T.reads([])
                                T.writes(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                C0 = T.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_fill_fragment(
                                        C0.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        T.float32(0),
                                        dtype="handle",
                                    )
                                )

                        for ko in range(0, 32):
                            # copy data from global to shared
                            for tx in T.thread_binding(0, 32, "threadIdx.x"):
                                for i0, j0 in T.grid(1, 4):
                                    for j1 in T.vectorized(0, 4):
                                        with T.block([1024, 1024]) as [vi, vj]:
                                            T.bind(vi, bx * 64 + ty * 32 + tx + i0)
                                            T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_A[vi, vj + 8] = A[vi, vj]

                                for i0, j0 in T.grid(2, 4):
                                    for j1 in T.vectorized(0, 4):
                                        with T.block([1024, 1024]) as [vi, vj]:
                                            T.bind(vi, by * 128 + ty * 64 + tx * 2 + i0)
                                            T.bind(vj, ko * 32 + tz * 16 + j0 * 4 + j1)
                                            shared_B[vi, vj + 8] = B[vi, vj]

                            for ki in range(0, 2):
                                for i in range(0, 2):
                                    with T.block([64, 64]) as [vi, vk]:
                                        T.bind(vi, bx * 4 + ty * 2 + i)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        T.writes(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = T.var("int32")
                                        s1 = T.var("int32")
                                        A0 = T.match_buffer(
                                            shared_A[
                                                vi * 16 : vi * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_A0 = T.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_A0.data,
                                                16,
                                                16,
                                                16,
                                                i,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(dtype="float16"),
                                                    A0.data,
                                                    A0.elem_offset + 8,
                                                    A0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                A0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            )
                                        )
                                for j in range(0, 4):
                                    with T.block([64, 64]) as [vj, vk]:
                                        T.bind(vj, by * 8 + tz * 4 + j)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ]
                                        )
                                        T.writes(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16]
                                        )
                                        s0 = T.var("int32")
                                        s1 = T.var("int32")
                                        B0 = T.match_buffer(
                                            shared_B[
                                                vj * 16 : vj * 16 + 16,
                                                vk * 16 : vk * 16 + 16 + 8,
                                            ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[s0, s1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_B0 = T.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_B0.data,
                                                16,
                                                16,
                                                16,
                                                j,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(dtype="float16"),
                                                    B0.data,
                                                    B0.elem_offset + 8,
                                                    B0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                B0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            )
                                        )
                                for i, j in T.grid(2, 4):
                                    with T.block([64, 64, T.reduce_axis(0, 64)]) as [
                                        vi,
                                        vj,
                                        vk,
                                    ]:
                                        T.bind(vi, bx * 4 + ty * 2 + i)
                                        T.bind(vj, by * 8 + tz * 4 + j)
                                        T.bind(vk, ko * 2 + ki)
                                        T.reads(
                                            [
                                                wmma_A[
                                                    vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_B[
                                                    vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16
                                                ],
                                                wmma_C[
                                                    vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16
                                                ],
                                            ]
                                        )
                                        T.writes(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]
                                        )
                                        wmma_A1 = T.match_buffer(
                                            wmma_A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_B1 = T.match_buffer(
                                            wmma_B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_C1 = T.match_buffer(
                                            wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_mma_sync(
                                                wmma_C1.data,
                                                i * 4 + j,
                                                wmma_A1.data,
                                                i,
                                                wmma_B1.data,
                                                j,
                                                wmma_C1.data,
                                                i * 4 + j,
                                                dtype="handle",
                                            )
                                        )
                        for i, j in T.grid(2, 4):
                            with T.block([64, 64]) as [vi, vj]:
                                T.bind(vi, bx * 4 + ty * 2 + i)
                                T.bind(vj, by * 8 + tz * 4 + j)
                                T.reads(wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
                                s0 = T.var("int32")
                                s1 = T.var("int32")
                                wmma_C2 = T.match_buffer(
                                    wmma_C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                C1 = T.match_buffer(
                                    C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
                                    (16, 16),
                                    "float32",
                                    strides=[s0, s1],
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_store_matrix_sync(
                                        wmma_C2.data,
                                        16,
                                        16,
                                        16,
                                        i * 4 + j,
                                        T.tvm_access_ptr(
                                            T.type_annotation(dtype="float32"),
                                            C1.data,
                                            C1.elem_offset,
                                            C1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        C1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    )
                                )
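
In tensorcore_gemm, both shared tiles are written with a column offset of 8 (shared_A[vi, vj + 8]) and read back through elem_offset + 8, a skew that staggers rows in shared memory; shared_B is laid out [n, k] and loaded "col_major", so each tvm_mma_sync updates one 16x16 accumulator fragment. A NumPy sketch of that per-fragment update (a sketch of the math, not the listing's code):

import numpy as np

# One fragment update: fp16 inputs, fp32 accumulate, B fragment stored [n, k].
A = np.random.rand(16, 16).astype("float16")  # wmma.matrix_a
B = np.random.rand(16, 16).astype("float16")  # wmma.matrix_b, row j holds B[j, :]
C = np.zeros((16, 16), dtype="float32")       # wmma.accumulator
C += A.astype("float32") @ B.astype("float32").T
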
Example #26
 def main(
     A: T.Buffer[(512, 512), "float32"],
     B: T.Buffer[(512, 512), "float32"],
     C: T.Buffer[(512, 512), "float32"],
 ) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
     A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
     B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
     for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
         for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
             for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.y"):
                 for i2_0 in T.serial(0, 1):
                     for ax0_ax1_fused_0 in T.serial(0, 1024):
                         for ax0_ax1_fused_1 in T.thread_binding(
                                 0, 8, thread="threadIdx.y"):
                             for ax0_ax1_fused_2 in T.thread_binding(
                                     0, 32, thread="threadIdx.x"):
                                 with T.block("A_shared"):
                                     v0 = T.axis.spatial(
                                         512,
                                         (ax0_ax1_fused_0 * 256 +
                                          ax0_ax1_fused_1 * 32 +
                                          ax0_ax1_fused_2) // 512,
                                     )
                                     v1 = T.axis.spatial(
                                         512,
                                         (ax0_ax1_fused_0 * 256 +
                                          ax0_ax1_fused_1 * 32 +
                                          ax0_ax1_fused_2) % 512,
                                     )
                                     T.reads([A[v0, v1]])
                                     T.writes([A_shared[v0, v1]])
                                     A_shared[v0, v1] = A[v0, v1]
                     for ax0_ax1_fused_0 in T.serial(0, 32):
                         for ax0_ax1_fused_1 in T.thread_binding(
                                 0, 8, thread="threadIdx.y"):
                             for ax0_ax1_fused_2 in T.thread_binding(
                                     0, 32, thread="threadIdx.x"):
                                 for ax0_ax1_fused_3 in T.vectorized(0, 2):
                                     with T.block("B_shared"):
                                         v0 = T.axis.spatial(
                                             512,
                                             (ax0_ax1_fused_0 * 512 +
                                              ax0_ax1_fused_1 * 64 +
                                              ax0_ax1_fused_2 * 2 +
                                              ax0_ax1_fused_3) // 32,
                                         )
                                         v1 = T.axis.spatial(
                                             512,
                                             i0_0_i1_0_fused * 32 +
                                             (ax0_ax1_fused_0 * 512 +
                                              ax0_ax1_fused_1 * 64 +
                                              ax0_ax1_fused_2 * 2 +
                                              ax0_ax1_fused_3) % 32,
                                         )
                                         T.reads([B[v0, v1]])
                                         T.writes([B_shared[v0, v1]])
                                         B_shared[v0, v1] = B[v0, v1]
                     for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(
                             16, 2, 2, 32, 16, 2):
                         with T.block("C"):
                             i = T.axis.spatial(
                                 512,
                                 i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
                             j = T.axis.spatial(
                                 512,
                                 i0_0_i1_0_fused * 32 +
                                 i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4,
                             )
                             k = T.axis.reduce(
                                 512, i2_0 * 512 + i2_1 * 32 + i2_2)
                             T.reads([A_shared[i, k], B_shared[k, j]])
                             T.writes([C_local[i, j]])
                             T.block_attr({"warp_execution": 1})
                             with T.init():
                                 C_local[i, j] = T.float32(0)
                             C_local[i, j] = C_local[
                                 i, j] + A_shared[i, k] * B_shared[k, j]
                 for ax0, ax1 in T.grid(32, 4):
                     with T.block("C_local"):
                         v0 = T.axis.spatial(512,
                                             i0_1_i1_1_fused * 32 + ax0)
                         v1 = T.axis.spatial(
                             512, i0_0_i1_0_fused * 32 +
                             i0_2_i1_2_fused * 4 + ax1)
                         T.reads([C_local[v0, v1]])
                         T.writes([C[v0, v1]])
                         C[v0, v1] = C_local[v0, v1]
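
Fixtures like the two 512x512 matmuls above usually come in pairs inside TVM's unit tests and are compared with structural equality. A minimal self-contained sketch of that pattern; the tiny expected function here is purely illustrative:

import tvm
from tvm.script import tir as T

@T.prim_func
def expected(a: T.handle) -> None:
    A = T.match_buffer(a, [4], "float32")
    for i in T.serial(0, 4):
        with T.block("clear"):
            vi = T.axis.spatial(4, i)
            A[vi] = T.float32(0)

actual = expected  # stand-in; in a real test this comes from applying a schedule
tvm.ir.assert_structural_equal(actual, expected)
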
Example #27
def tensorcore_gemm(handle_a: T.handle, handle_b: T.handle,
                    handle_c: T.handle) -> None:
    # pylint: disable=missing-function-docstring
    # match buffer
    match_buffer_a = T.match_buffer(handle_a, [1024, 1024], "float16")
    match_buffer_b = T.match_buffer(handle_b, [1024, 1024], "float16")
    match_buffer_c = T.match_buffer(handle_c, [1024, 1024], "float32")

    # body
    for block_idx_x in T.thread_binding(0, 16, "blockIdx.x"):
        for block_idx_y in T.thread_binding(0, 8, "blockIdx.y"):
            with T.block():
                axis_bx, axis_by = T.axis.remap("SS",
                                                [block_idx_x, block_idx_y])
                shared_a = T.alloc_buffer([1024, 1024],
                                          "float16",
                                          scope="shared")
                shared_b = T.alloc_buffer([1024, 1024],
                                          "float16",
                                          scope="shared")
                wmma_a = T.alloc_buffer([1024, 1024],
                                        "float16",
                                        scope="wmma.matrix_a")
                wmma_b = T.alloc_buffer([1024, 1024],
                                        "float16",
                                        scope="wmma.matrix_b")
                wmma_c = T.alloc_buffer([1024, 1024],
                                        "float32",
                                        scope="wmma.accumulator")

                # pylint: disable=too-many-nested-blocks
                for thread_ty in T.thread_binding(0, 2, "threadIdx.y"):
                    for thread_tz in T.thread_binding(0, 2, "threadIdx.z"):
                        for index_i, index_jj in T.grid(2, 4):
                            with T.block():
                                new_axis_vi = T.axis.S(
                                    64, axis_bx * 4 + thread_ty * 2 + index_i)
                                new_axis_vj = T.axis.S(
                                    64, axis_by * 8 + thread_tz * 4 + index_jj)
                                T.reads([])
                                T.writes(wmma_c[new_axis_vi *
                                                16:new_axis_vi * 16 + 16,
                                                new_axis_vj *
                                                16:new_axis_vj * 16 + 16, ])
                                match_buffer_c0 = T.match_buffer(
                                    wmma_c[new_axis_vi * 16:new_axis_vi * 16 +
                                           16, new_axis_vj *
                                           16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_fill_fragment(
                                        match_buffer_c0.data,
                                        16,
                                        16,
                                        16,
                                        index_i * 4 + index_jj,
                                        T.float32(0),  # pylint: disable=not-callable
                                        dtype="handle",
                                    ))

                        for k_o in range(0, 32):
                            # copy data from global to shared
                            for thread_tx in T.thread_binding(
                                    0, 32, "threadIdx.x"):
                                for index_i0, index_j0 in T.grid(1, 4):
                                    for index_j1 in T.vectorized(0, 4):
                                        with T.block():
                                            new_axis_vi = T.axis.S(
                                                1024,
                                                axis_bx * 64 + thread_ty * 32 +
                                                thread_tx + index_i0,
                                            )
                                            new_axis_vj = T.axis.S(
                                                1024,
                                                k_o * 32 + thread_tz * 16 +
                                                index_j0 * 4 + index_j1,
                                            )
                                            shared_a[new_axis_vi, new_axis_vj +
                                                     8] = match_buffer_a[
                                                         new_axis_vi,
                                                         new_axis_vj]

                                for index_i0, index_j0 in T.grid(2, 4):
                                    for index_j1 in T.vectorized(0, 4):
                                        with T.block():
                                            new_axis_vi = T.axis.S(
                                                1024,
                                                axis_by * 128 +
                                                thread_ty * 64 +
                                                thread_tx * 2 + index_i0,
                                            )
                                            new_axis_vj = T.axis.S(
                                                1024,
                                                k_o * 32 + thread_tz * 16 +
                                                index_j0 * 4 + index_j1,
                                            )
                                            shared_b[new_axis_vi, new_axis_vj +
                                                     8] = match_buffer_b[
                                                         new_axis_vi,
                                                         new_axis_vj]

                            for k_i in range(0, 2):
                                for index_i in range(0, 2):
                                    with T.block():
                                        new_axis_vi = T.axis.S(
                                            64, axis_bx * 4 + thread_ty * 2 +
                                            index_i)
                                        axis_vk = T.axis.S(64, k_o * 2 + k_i)
                                        T.reads(shared_a[new_axis_vi *
                                                         16:new_axis_vi * 16 +
                                                         16, axis_vk *
                                                         16:axis_vk * 16 + 16 +
                                                         8, ])
                                        T.writes(
                                            wmma_a[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ])
                                        stride0 = T.var("int32")
                                        stride1 = T.var("int32")
                                        match_buffer_a0 = T.match_buffer(
                                            shared_a[new_axis_vi *
                                                     16:new_axis_vi * 16 + 16,
                                                     axis_vk *
                                                     16:axis_vk * 16 + 16 +
                                                     8, ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[stride0, stride1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_a0 = T.match_buffer(
                                            wmma_a[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_a0.data,
                                                16,
                                                16,
                                                16,
                                                index_i,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(
                                                        dtype="float16"),
                                                    match_buffer_a0.data,
                                                    match_buffer_a0.elem_offset
                                                    + 8,
                                                    match_buffer_a0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                match_buffer_a0.strides[0],
                                                "row_major",
                                                dtype="handle",
                                            ))
                                for index_jj in range(0, 4):
                                    with T.block():
                                        new_axis_vj = T.axis.S(
                                            64, axis_by * 8 + thread_tz * 4 +
                                            index_jj)
                                        axis_vk = T.axis.S(64, k_o * 2 + k_i)
                                        T.reads(shared_b[new_axis_vj *
                                                         16:new_axis_vj * 16 +
                                                         16, axis_vk *
                                                         16:axis_vk * 16 + 16 +
                                                         8, ])
                                        T.writes(
                                            wmma_b[new_axis_vj *
                                                   16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ])
                                        stride0 = T.var("int32")
                                        stride1 = T.var("int32")
                                        match_buffer_b0 = T.match_buffer(
                                            shared_b[new_axis_vj *
                                                     16:new_axis_vj * 16 + 16,
                                                     axis_vk *
                                                     16:axis_vk * 16 + 16 +
                                                     8, ],
                                            (16, 16 + 8),
                                            "float16",
                                            strides=[stride0, stride1],
                                            scope="shared",
                                            offset_factor=1,
                                        )
                                        wmma_b0 = T.match_buffer(
                                            wmma_b[new_axis_vj *
                                                   16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        T.evaluate(
                                            T.tvm_load_matrix_sync(
                                                wmma_b0.data,
                                                16,
                                                16,
                                                16,
                                                index_jj,
                                                T.tvm_access_ptr(
                                                    T.type_annotation(
                                                        dtype="float16"),
                                                    match_buffer_b0.data,
                                                    match_buffer_b0.elem_offset
                                                    + 8,
                                                    match_buffer_b0.strides[0],
                                                    1,
                                                    dtype="handle",
                                                ),
                                                match_buffer_b0.strides[0],
                                                "col_major",
                                                dtype="handle",
                                            ))
                                for index_i, index_jj in T.grid(2, 4):
                                    with T.block():
                                        new_axis_vi = T.axis.S(
                                            64, axis_bx * 4 + thread_ty * 2 +
                                            index_i)
                                        new_axis_vj = T.axis.S(
                                            64, axis_by * 8 + thread_tz * 4 +
                                            index_jj)
                                        axis_vk = T.axis.R(64, k_o * 2 + k_i)
                                        T.reads([
                                            wmma_a[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            wmma_b[new_axis_vj *
                                                   16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            wmma_c[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   new_axis_vj *
                                                   16:new_axis_vj * 16 + 16, ],
                                        ])
                                        T.writes(
                                            wmma_c[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   new_axis_vj *
                                                   16:new_axis_vj * 16 + 16, ])
                                        wmma_a1 = T.match_buffer(
                                            wmma_a[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_a",
                                            offset_factor=1,
                                        )
                                        wmma_b1 = T.match_buffer(
                                            wmma_b[new_axis_vj *
                                                   16:new_axis_vj * 16 + 16,
                                                   axis_vk * 16:axis_vk * 16 +
                                                   16, ],
                                            (16, 16),
                                            "float16",
                                            strides=[16, 1],
                                            scope="wmma.matrix_b",
                                            offset_factor=1,
                                        )
                                        wmma_c1 = T.match_buffer(
                                            wmma_c[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   new_axis_vj *
                                                   16:new_axis_vj * 16 + 16, ],
                                            (16, 16),
                                            "float32",
                                            strides=[16 * 4, 1],
                                            scope="wmma.accumulator",
                                            offset_factor=1,
                                        )
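                                        # Tensor-core fragment FMA: wmma_c1 += wmma_a1 * wmma_b1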
                                        T.evaluate(
                                            T.tvm_mma_sync(
                                                wmma_c1.data,
                                                index_i * 4 + index_jj,
                                                wmma_a1.data,
                                                index_i,
                                                wmma_b1.data,
                                                index_jj,
                                                wmma_c1.data,
                                                index_i * 4 + index_jj,
                                                dtype="handle",
                                            ))
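                        # Epilogue: store each 16x16 accumulator fragment back to global memory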
                        for index_i, index_jj in T.grid(2, 4):
                            with T.block():
                                new_axis_vi = T.axis.S(
                                    64, axis_bx * 4 + thread_ty * 2 + index_i)
                                new_axis_vj = T.axis.S(
                                    64, axis_by * 8 + thread_tz * 4 + index_jj)
                                T.reads(wmma_c[new_axis_vi *
                                               16:new_axis_vi * 16 + 16,
                                               new_axis_vj *
                                               16:new_axis_vj * 16 + 16, ])
                                T.writes(
                                    match_buffer_c[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   new_axis_vj *
                                                   16:new_axis_vj * 16 + 16, ])
                                stride0 = T.var("int32")
                                stride1 = T.var("int32")
                                wmma_c2 = T.match_buffer(
                                    wmma_c[new_axis_vi * 16:new_axis_vi * 16 +
                                           16, new_axis_vj *
                                           16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[16 * 4, 1],
                                    scope="wmma.accumulator",
                                    offset_factor=1,
                                )
                                match_buffer_c1 = T.match_buffer(
                                    match_buffer_c[new_axis_vi *
                                                   16:new_axis_vi * 16 + 16,
                                                   new_axis_vj *
                                                   16:new_axis_vj * 16 + 16, ],
                                    (16, 16),
                                    "float32",
                                    strides=[stride0, stride1],
                                    offset_factor=1,
                                )
                                T.evaluate(
                                    T.tvm_store_matrix_sync(
                                        wmma_c2.data,
                                        16,
                                        16,
                                        16,
                                        index_i * 4 + index_jj,
                                        T.tvm_access_ptr(
                                            T.type_annotation(dtype="float32"),
                                            match_buffer_c1.data,
                                            match_buffer_c1.elem_offset,
                                            match_buffer_c1.strides[0],
                                            1,
                                            dtype="handle",
                                        ),
                                        match_buffer_c1.strides[0],
                                        "row_major",
                                        dtype="handle",
                                    ))
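Hand-written TIR like the function above usually serves as the expected output of a schedule transformation, compared by structural equality. A minimal sketch of that test pattern, assuming hypothetical @T.prim_func fixtures named `before` (unscheduled) and `expected` (the fixture above):

import tvm

# Minimal check pattern (hypothetical names): apply schedule primitives
# to `before`, then compare the result against the expected fixture.
def check(before, expected):
    sch = tvm.tir.Schedule(before, debug_mask="all")
    # ... split / bind / tensorize steps would go here ...
    tvm.ir.assert_structural_equal(sch.mod["main"], expected)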
Exemple #28
def transformed_nested_pipeline_double_buffer(
        A: T.Buffer[(16, 16, 16), "float32"], C: T.Buffer[(16, 16, 16),
                                                          "float32"]) -> None:
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        with T.block():
            T.reads([A[tx, 0:16, 0:16]])
            T.writes([C[tx, 0:16, 0:16]])
            A_shared = T.alloc_buffer([16, 1, 16],
                                      dtype="float32",
                                      scope="shared")
            A_local = T.alloc_buffer([2, 1, 1, 16],
                                     dtype="float32",
                                     scope="local")
            B = T.alloc_buffer([2, 16, 1, 1], dtype="float32", scope="shared")
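            # Pipeline prologue: preload iteration 0 (global -> shared -> local) and seed B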
            with T.block():
                T.reads([
                    A[tx, 0, 0:16], A_shared[tx, 0, 0:16], A_local[0, tx, 0, 0]
                ])
                T.writes([
                    A_shared[tx, 0, 0:16], A_local[0, 0, 0, 0:16], B[0, tx, 0,
                                                                     0]
                ])
                with T.block():
                    T.reads([A[tx, 0, 0:16]])
                    T.writes([A_shared[tx, 0, 0:16]])
                    for j in T.serial(0, 16):
                        with T.block():
                            T.reads([A[tx, 0, j]])
                            T.writes([A_shared[tx, 0, j]])
                            A_shared[tx, 0, j] = A[tx, 0, j]
                with T.block():
                    T.reads([A_shared[tx, 0, 0:16]])
                    T.writes([A_local[0, 0, 0, 0:16]])
                    for j in T.serial(0, 16):
                        with T.block():
                            T.reads([A_shared[tx, 0, j]])
                            T.writes([A_local[0, 0, 0, j]])
                            T.block_attr({"double_buffer_scope": 0})
                            A_local[0, 0, 0, j] = A_shared[tx, 0, j]
                with T.block():
                    T.reads([A_local[0, tx, 0, 0]])
                    T.writes([B[0, tx, 0, 0]])
                    B[0, tx, 0, 0] = A_local[0, 0, 0, 0] * T.float32(2)
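            # Steady-state body: overlap the copy for iteration i + 1 with compute on
            # iteration i, alternating between the two versions of the local buffers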
            with T.block():
                T.reads([
                    A[tx, 1:16, 0:16],
                    A_local[0:2, tx, 0:16, 0:16],
                    B[0:2, tx, 0:15, 0],
                    A_shared[tx, 0, 0:16],
                ])
                T.writes([
                    A_shared[tx, 0, 0:16],
                    B[0:2, tx, 0:16, 0],
                    C[tx, 0:15, 0:16],
                    A_local[0:2, 0, 0, 0:16],
                ])
                for i in T.serial(0, 15):
                    with T.block():
                        T.reads([A[tx, i + 1, 0:16]])
                        T.writes([A_shared[tx, 0, 0:16]])
                        for j in T.serial(0, 16):
                            with T.block():
                                T.reads([A[tx, i + 1, j]])
                                T.writes([A_shared[tx, 0, j]])
                                A_shared[tx, 0, j] = A[tx, i + 1, j]
                    with T.block():
                        T.reads(
                            [A_local[i % 2, tx, i, 1:16], B[0:2, tx, i, 0]])
                        T.writes([B[0:2, tx, i, 0], C[tx, i, 0:15]])
                        for j in T.serial(0, 15):
                            with T.block():
                                T.reads([A_local[i % 2, tx, i, j + 1]])
                                T.writes([B[(j + 1) % 2, tx, i, 0]])
                                B[(j + 1) % 2, tx, i,
                                  0] = A_local[i % 2, 0, 0,
                                               j + 1] * T.float32(2)
                            with T.block():
                                T.reads([B[j % 2, tx, i, 0]])
                                T.writes([C[tx, i, j]])
                                C[tx, i, j] = B[j % 2, tx, i, 0] + T.float32(1)
                    with T.block():
                        T.reads([A_shared[tx, 0, 0:16]])
                        T.writes([A_local[(i + 1) % 2, 0, 0, 0:16]])
                        for j in T.serial(0, 16):
                            with T.block():
                                T.reads([A_shared[tx, 0, j]])
                                T.writes([A_local[(i + 1) % 2, 0, 0, j]])
                                T.block_attr({"double_buffer_scope": 0})
                                A_local[(i + 1) % 2, 0, 0,
                                        j] = A_shared[tx, 0, j]
                    with T.block():
                        T.reads([A_local[(i + 1) % 2, tx, i + 1, 0]])
                        T.writes([B[0, tx, i + 1, 0]])
                        B[0, tx, i + 1,
                          0] = A_local[(i + 1) % 2, 0, 0, 0] * T.float32(2)
                    with T.block():
                        T.reads([B[1, tx, i, 0]])
                        T.writes([C[tx, i, 15]])
                        C[tx, i, 15] = B[1, tx, i, 0] + T.float32(1)
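            # Pipeline epilogue: drain the inner pipeline for the last iteration (i = 15)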
            with T.block():
                T.reads([A_local[1, tx, 15, 1:16], B[0:2, tx, 15, 0]])
                T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:16]])
                with T.block():
                    T.reads([A_local[1, tx, 15, 1:16], B[0:2, tx, 15, 0]])
                    T.writes([B[0:2, tx, 15, 0], C[tx, 15, 0:15]])
                    for j in T.serial(0, 15):
                        with T.block():
                            T.reads([A_local[1, tx, 15, j + 1]])
                            T.writes([B[(j + 1) % 2, tx, 15, 0]])
                            B[(j + 1) % 2, tx, 15,
                              0] = A_local[1, 0, 0, j + 1] * T.float32(2)
                        with T.block():
                            T.reads([B[j % 2, tx, 15, 0]])
                            T.writes([C[tx, 15, j]])
                            C[tx, 15, j] = B[j % 2, tx, 15, 0] + T.float32(1)
                with T.block():
                    T.reads([B[1, tx, 15, 0]])
                    T.writes([C[tx, 15, 15]])
                    C[tx, 15, 15] = B[1, tx, 15, 0] + T.float32(1)
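A minimal sketch of how such a transformed_* fixture is typically exercised, assuming the un-transformed, annotated counterpart is named nested_pipeline_double_buffer (a hypothetical name here):

import tvm

# Run the software-pipeline injection pass on the annotated input and
# compare against the transformed fixture above.
mod = tvm.IRModule.from_expr(nested_pipeline_double_buffer)
mod = tvm.tir.transform.InjectSoftwarePipeline()(mod)
tvm.ir.assert_structural_equal(mod["main"], transformed_nested_pipeline_double_buffer)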
Exemple #29
def conv2d_winograd_cuda(  # type: ignore
    placeholder: T.Buffer[(1, 14, 14, 128), "float32"],  # type: ignore
    placeholder_1: T.Buffer[(6, 6, 128, 128), "float32"],  # type: ignore
    conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"],  # type: ignore
) -> None:
    # type: ignore
    data_pad = T.alloc_buffer([1, 16, 16, 128])
    input_tile = T.alloc_buffer([6, 6, 9, 128])
    B = T.alloc_buffer([6, 6])
    data_pack = T.alloc_buffer([6, 6, 9, 128])
    bgemm = T.alloc_buffer([6, 6, 9, 128])
    A = T.alloc_buffer([6, 4])
    inverse = T.alloc_buffer([4, 4, 9, 128])
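    # Zero-pad the 14x14 spatial input to 16x16 (reads outside the image yield 0)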
    for i0, i1, i2, i3 in T.grid(1, 16, 16, 128):
        with T.block("data_pad"):
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.block_attr({"schedule_rule": "None"})
            T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]])
            T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]])
            data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14,  # type: ignore
                placeholder[i0_1, i1_1, i2_1, i3_1],
                T.float32(0),
                dtype="float32",
            )
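    # Gather 9 overlapping 6x6 input tiles per image (tile origins 4 apart)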
    for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128):
        with T.block("input_tile"):
            eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2])
            T.block_attr({"schedule_rule": "None"})
            T.reads(
                [
                    data_pad[
                        T.floordiv(p, 9),  # type: ignore
                        ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                        ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                        ci,
                    ]
                ]
            )
            T.writes([input_tile[eps, nu, p, ci]])
            input_tile[eps, nu, p, ci] = data_pad[
                T.floordiv(p, 9),  # type: ignore
                ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                ci,
            ]
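    # B: 6x6 input-transform matrix of Winograd F(4x4, 3x3), materialized via nested Selects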
    for i0_3, i1_3 in T.grid(6, 6):
        with T.block("B"):
            i, j = T.axis.remap("SS", [i0_3, i1_3])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([B[i, j]])
            # fmt: off
            B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0)))))))))))))))))))))))))))))))))))))  # type: ignore
            # fmt: on
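    # Input transform per tile and channel: data_pack = B^T * d * B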
    for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6):
        with T.block("data_pack"):
            eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap(
                "SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cuda"})
            T.reads(
                [
                    data_pack[eps_1, nu_1, p_1, ci_1],
                    input_tile[r_a, r_b, p_1, ci_1],
                    B[
                        T.min(r_a, r_b) : (  # type: ignore
                            T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b))  # type: ignore
                        ),
                        T.min(eps_1, nu_1) : (  # type: ignore
                            T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1))  # type: ignore
                        ),
                    ],
                ]
            )
            T.writes([data_pack[eps_1, nu_1, p_1, ci_1]])
            with T.init():
                data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0)
            data_pack[eps_1, nu_1, p_1, ci_1] = data_pack[eps_1, nu_1, p_1, ci_1] + (
                (input_tile[r_a, r_b, p_1, ci_1] * B[r_a, eps_1]) * B[r_b, nu_1]
            )
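    # Batched GEMM in the transformed domain: one 9x128 by 128x128 matmul per (eps, nu)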
    for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128):
        with T.block("bgemm"):
            eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1])
            T.block_attr({"meta_schedule.write_cache_level": [3]})
            T.reads(
                [
                    bgemm[eps_2, nu_2, p_2, co],
                    data_pack[eps_2, nu_2, p_2, ci_2],
                    placeholder_1[eps_2, nu_2, co, ci_2],
                ]
            )
            T.writes([bgemm[eps_2, nu_2, p_2, co]])
            with T.init():
                bgemm[eps_2, nu_2, p_2, co] = T.float32(0)
            bgemm[eps_2, nu_2, p_2, co] = bgemm[eps_2, nu_2, p_2, co] + (
                data_pack[eps_2, nu_2, p_2, ci_2] * placeholder_1[eps_2, nu_2, co, ci_2]
            )
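    # A: 6x4 output (inverse) transform matrix of Winograd F(4x4, 3x3)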
    for i0_6, i1_6 in T.grid(6, 4):
        with T.block("A"):
            i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([A[i_1, j_1]])
            # fmt: off
            A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.float32(0)))))))))))))))))))))))))  # type: ignore
            # fmt: on
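    # Output transform per tile and channel: inverse = A^T * bgemm * A, yielding 4x4 tiles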
    for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6):
        with T.block("inverse"):
            vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap(
                "SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse"})
            T.reads(
                [
                    inverse[vh, vw, p_3, co_1],
                    bgemm[r_a_1, r_b_1, p_3, co_1],
                    A[
                        T.min(r_a_1, r_b_1) : (  # type: ignore
                            T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1))  # type: ignore
                        ),
                        T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))),  # type: ignore
                    ],
                ]
            )
            T.writes([inverse[vh, vw, p_3, co_1]])
            with T.init():
                inverse[vh, vw, p_3, co_1] = T.float32(0)
            inverse[vh, vw, p_3, co_1] = inverse[vh, vw, p_3, co_1] + (
                (bgemm[r_a_1, r_b_1, p_3, co_1] * A[r_a_1, vh]) * A[r_b_1, vw]
            )
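    # Write-back: scatter the 3x3 grid of 4x4 output tiles into the 12x12 result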
    for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128):
        with T.block("conv2d_winograd"):
            n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6])
            T.reads(
                [
                    inverse[
                        T.floormod(h, 4),  # type: ignore
                        T.floormod(w, 4),  # type: ignore
                        (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                        co_2,
                    ]
                ]
            )
            T.writes([conv2d_winograd[n, h, w, co_2]])
            conv2d_winograd[n, h, w, co_2] = inverse[
                T.floormod(h, 4),  # type: ignore
                T.floormod(w, 4),  # type: ignore
                (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                co_2,
            ]
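A minimal smoke test, assuming a TVM build with the LLVM backend: the prim_func above carries no thread bindings yet, so despite its name it can be compiled for CPU and run on random data before any GPU scheduling:

import numpy as np
import tvm

# Build the unscheduled function for CPU and run one functional check.
func = tvm.build(conv2d_winograd_cuda, target="llvm")
dev = tvm.cpu()
x = tvm.nd.array(np.random.rand(1, 14, 14, 128).astype("float32"), dev)
w = tvm.nd.array(np.random.rand(6, 6, 128, 128).astype("float32"), dev)
y = tvm.nd.array(np.zeros((1, 12, 12, 128), dtype="float32"), dev)
func(x, w, y)  # y now holds the convolution output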
Exemple #30
def main(placeholder: T.Buffer[(1, 13, 13, 3, 85), "float32"], placeholder_1: T.Buffer[(1, 26, 26, 3, 85), "float32"], placeholder_2: T.Buffer[(1, 52, 52, 3, 85), "float32"], T_expand_dims: T.Buffer[(1, 80, 10647), "float32"]) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     T_strided_slice_with_axes = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32")
     T_sigmoid = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32")
     T_strided_slice_with_axes_1 = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
     T_sigmoid_1 = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
     T_multiply = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32")
     T_reshape = T.alloc_buffer([8112, 80], dtype="float32")
     T_strided_slice_with_axes_2 = T.alloc_buffer([1, 26, 26, 3, 1], dtype="float32")
     T_sigmoid_2 = T.alloc_buffer([1, 26, 26, 3, 1], dtype="float32")
     T_strided_slice_with_axes_3 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32")
     T_sigmoid_3 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32")
     T_multiply_1 = T.alloc_buffer([1, 26, 26, 3, 80], dtype="float32")
     T_reshape_1 = T.alloc_buffer([2028, 80], dtype="float32")
     T_strided_slice_with_axes_4 = T.alloc_buffer([1, 13, 13, 3, 1], dtype="float32")
     T_sigmoid_4 = T.alloc_buffer([1, 13, 13, 3, 1], dtype="float32")
     T_strided_slice_with_axes_5 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32")
     T_sigmoid_5 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32")
     T_multiply_2 = T.alloc_buffer([1, 13, 13, 3, 80], dtype="float32")
     T_reshape_2 = T.alloc_buffer([507, 80], dtype="float32")
     T_concat = T.alloc_buffer([10647, 80], dtype="float32")
     T_transpose = T.alloc_buffer([80, 10647], dtype="float32")
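     # Per-scale pattern (repeated for 52x52, 26x26, 13x13): slice objectness (channel 4)
     # and class scores (channels 5:85), sigmoid both, multiply, then flatten to (boxes, 80)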
     for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 1):
         with T.block("T_strided_slice_with_axes"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)])
             T.writes(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4])
             T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4] = placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]
     for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 1):
         with T.block("T_sigmoid"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_sigmoid[ax0, ax1, ax2, ax3, ax4])
             T_sigmoid[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes[ax0, ax1, ax2, ax3, ax4], dtype="float32")
     for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80):
         with T.block("T_strided_slice_with_axes_1"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)])
             T.writes(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4])
             T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4] = placeholder_2[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]
     for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80):
         with T.block("T_sigmoid_1"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_sigmoid_1[ax0, ax1, ax2, ax3, ax4])
             T_sigmoid_1[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_1[ax0, ax1, ax2, ax3, ax4], dtype="float32")
     for i0, i1, i2, i3, i4 in T.grid(1, 52, 52, 3, 80):
         with T.block("T_multiply"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_sigmoid[ax0, ax1, ax2, ax3, 0], T_sigmoid_1[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_multiply[ax0, ax1, ax2, ax3, ax4])
             T_multiply[ax0, ax1, ax2, ax3, ax4] = T_sigmoid[ax0, ax1, ax2, ax3, 0] * T_sigmoid_1[ax0, ax1, ax2, ax3, ax4]
     for i0, i1 in T.grid(8112, 80):
         with T.block("T_reshape"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(T_multiply[0, (ax1 // 80 + ax0) % 8112 // 156, (ax1 // 80 + ax0) % 156 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80])
             T.writes(T_reshape[ax0, ax1])
             T_reshape[ax0, ax1] = T_multiply[0, (ax1 // 80 + ax0) % 8112 // 156, (ax1 // 80 + ax0) % 156 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]
     for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 1):
         with T.block("T_strided_slice_with_axes_2"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)])
             T.writes(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4])
             T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4] = placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]
     for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 1):
         with T.block("T_sigmoid_2"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_sigmoid_2[ax0, ax1, ax2, ax3, ax4])
             T_sigmoid_2[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_2[ax0, ax1, ax2, ax3, ax4], dtype="float32")
     for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80):
         with T.block("T_strided_slice_with_axes_3"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)])
             T.writes(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4])
             T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4] = placeholder_1[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]
     for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80):
         with T.block("T_sigmoid_3"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_sigmoid_3[ax0, ax1, ax2, ax3, ax4])
             T_sigmoid_3[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_3[ax0, ax1, ax2, ax3, ax4], dtype="float32")
     for i0, i1, i2, i3, i4 in T.grid(1, 26, 26, 3, 80):
         with T.block("T_multiply_1"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_sigmoid_2[ax0, ax1, ax2, ax3, 0], T_sigmoid_3[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_multiply_1[ax0, ax1, ax2, ax3, ax4])
             T_multiply_1[ax0, ax1, ax2, ax3, ax4] = T_sigmoid_2[ax0, ax1, ax2, ax3, 0] * T_sigmoid_3[ax0, ax1, ax2, ax3, ax4]
     for i0, i1 in T.grid(2028, 80):
         with T.block("T_reshape_1"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(T_multiply_1[0, (ax1 // 80 + ax0) % 2028 // 78, (ax1 // 80 + ax0) % 78 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80])
             T.writes(T_reshape_1[ax0, ax1])
             T_reshape_1[ax0, ax1] = T_multiply_1[0, (ax1 // 80 + ax0) % 2028 // 78, (ax1 // 80 + ax0) % 78 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]
     for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 1):
         with T.block("T_strided_slice_with_axes_4"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)])
             T.writes(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4])
             T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4] = placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(4)]
     for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 1):
         with T.block("T_sigmoid_4"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_sigmoid_4[ax0, ax1, ax2, ax3, ax4])
             T_sigmoid_4[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_4[ax0, ax1, ax2, ax3, ax4], dtype="float32")
     for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80):
         with T.block("T_strided_slice_with_axes_5"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)])
             T.writes(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4])
             T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4] = placeholder[ax0, ax1, ax2, ax3, T.cast(ax4, "int64") + T.int64(5)]
     for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80):
         with T.block("T_sigmoid_5"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_sigmoid_5[ax0, ax1, ax2, ax3, ax4])
             T_sigmoid_5[ax0, ax1, ax2, ax3, ax4] = T.sigmoid(T_strided_slice_with_axes_5[ax0, ax1, ax2, ax3, ax4], dtype="float32")
     for i0, i1, i2, i3, i4 in T.grid(1, 13, 13, 3, 80):
         with T.block("T_multiply_2"):
             ax0, ax1, ax2, ax3, ax4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
             T.reads(T_sigmoid_4[ax0, ax1, ax2, ax3, 0], T_sigmoid_5[ax0, ax1, ax2, ax3, ax4])
             T.writes(T_multiply_2[ax0, ax1, ax2, ax3, ax4])
             T_multiply_2[ax0, ax1, ax2, ax3, ax4] = T_sigmoid_4[ax0, ax1, ax2, ax3, 0] * T_sigmoid_5[ax0, ax1, ax2, ax3, ax4]
     for i0, i1 in T.grid(507, 80):
         with T.block("T_reshape_2"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(T_multiply_2[0, (ax1 // 80 + ax0) % 507 // 39, (ax1 // 80 + ax0) % 39 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80])
             T.writes(T_reshape_2[ax0, ax1])
             T_reshape_2[ax0, ax1] = T_multiply_2[0, (ax1 // 80 + ax0) % 507 // 39, (ax1 // 80 + ax0) % 39 // 3, (ax1 // 80 + ax0) % 3, ax1 % 80]
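     # Row-wise concat of the three scales: rows [0, 507) from 13x13, [507, 2535) from
     # 26x26, [2535, 10647) from 52x52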
     for i0, i1 in T.grid(10647, 80):
         with T.block("T_concat"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(T_reshape[ax0 - 2535, ax1], T_reshape_1[ax0 - 507, ax1], T_reshape_2[ax0, ax1])
             T.writes(T_concat[ax0, ax1])
             T_concat[ax0, ax1] = T.if_then_else(2535 <= ax0, T_reshape[ax0 - 2535, ax1], T.if_then_else(507 <= ax0, T_reshape_1[ax0 - 507, ax1], T_reshape_2[ax0, ax1], dtype="float32"), dtype="float32")
     for i0, i1 in T.grid(80, 10647):
         with T.block("T_transpose"):
             ax0, ax1 = T.axis.remap("SS", [i0, i1])
             T.reads(T_concat[ax1, ax0])
             T.writes(T_transpose[ax0, ax1])
             T_transpose[ax0, ax1] = T_concat[ax1, ax0]
     for i0, i1, i2 in T.grid(1, 80, 10647):
         with T.block("T_expand_dims"):
             ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
             T.reads(T_transpose[ax1, ax2])
             T.writes(T_expand_dims[ax0, ax1, ax2])
             T_expand_dims[ax0, ax1, ax2] = T_transpose[ax1, ax2]
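A quick sanity check of the box counts used above:

# grid_h * grid_w * anchors_per_cell, per YOLO scale
assert 52 * 52 * 3 == 8112
assert 26 * 26 * 3 == 2028
assert 13 * 13 * 3 == 507
assert 507 + 2028 + 8112 == 10647  # total rows in T_concat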