def element_wise_invalid_annotation(a: T.handle, c: T.handle) -> None:
    # Element-wise TIR function on 128x128 buffers: B = A * 2, then C = B + 1.
    # Both inner loops sit under the same outer i0 loop, so each row of B is
    # produced immediately before block "C" consumes it.
    #
    # NOTE(review): the "buffer_dim_align" annotation on block "B" is [0], not
    # a nested four-integer list like the [[0, 0, 128, 127]] used by
    # element_wise_storage_align in this file. Given the function name this
    # looks intentionally malformed (presumably a negative-test input) —
    # confirm before "fixing" it.
    C = T.match_buffer(c, [128, 128],
                       elem_offset=0,
                       align=128,
                       offset_factor=1)
    A = T.match_buffer(a, [128, 128],
                       elem_offset=0,
                       align=128,
                       offset_factor=1)
    # body
    with T.block("root"):
        T.reads([])
        T.writes([])
        # Intermediate buffer holding A * 2.
        B = T.alloc_buffer([128, 128],
                           elem_offset=0,
                           align=128,
                           offset_factor=1)
        for i0 in T.serial(0, 128):
            for ax1 in T.serial(0, 128):
                with T.block("B"):
                    T.block_attr({"buffer_dim_align": [0]})
                    vi, vj = T.axis.remap("SS", [i0, ax1])
                    T.reads([A[vi, vj]])
                    T.writes([B[vi, vj]])
                    B[vi, vj] = (A[vi, vj] * T.float32(2))
            for i1 in T.serial(0, 128):
                with T.block("C"):
                    vi_1, vj_1 = T.axis.remap("SS", [i0, i1])
                    T.reads([B[vi_1, vj_1]])
                    T.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1))
# Example #2
def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    # Row-wise softmax over a 256x256 float32 input, written as three blocks:
    #   1. "T_softmax_maxelem": per-row maximum of A;
    #   2. "T_softmax_expsum":  per-row sum of exp(A - max);
    #   3. "T_softmax_norm":    exp(A - max) / sum.
    # Each row i0 is bound to blockIdx.x. Within a row, the 256 columns are
    # covered by 8 serial steps x 32 threadIdx.x lanes (index ax0_0 * 32 + ax0_1).
    # The per-row max and sum are kept in buffers allocated with scope="shared".
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        # Stage 1: max-reduction over k, initialised to the float32 minimum.
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([T_softmax_maxelem_shared[i0_1], A[i0_1, k]])
                    T.writes([T_softmax_maxelem_shared[i0_1]])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.min_value("float32")
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k]
                    )
        # Stage 2: sum-reduction of exp(A - max) over k, initialised to 0.
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads(
                        [
                            T_softmax_expsum_shared[i0_2],
                            A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2],
                        ]
                    )
                    T.writes([T_softmax_expsum_shared[i0_2]])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
                    )
        # Stage 3: element-wise normalisation into the output buffer.
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads(
                        [
                            A[i0_3, i1],
                            T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3],
                        ]
                    )
                    T.writes([T_softmax_norm[i0_3, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(
                            A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
                            dtype="float32",
                        )
                        / T_softmax_expsum_shared[i0_3]
                    )
def GmmCuda2(X: T.Buffer[(1, 128, 128), "float32"], Y: T.Buffer[(1, 128, 128), "float32"], Z: T.Buffer[(1, 128, 128), "float32"]) -> None:
    # Batched matmul Z[b, i, j] = sum_k X[b, i, k] * Y[b, k, j] with batch 1 and
    # 128x128 operands, in an already-scheduled GPU form:
    #   - 16 blockIdx.x values, each owning a 32x32 tile of Z (rows start at
    #     fused // 4 * 32, columns at fused % 4 * 32);
    #   - 128 threadIdx.x values per block, each accumulating a 4x2 sub-tile in
    #     Z_local (scope="local");
    #   - k is split into 4 chunks of 32. For each chunk, the 32x32 slices of X
    #     and Y that the tile needs are copied into scope="shared" buffers
    #     before the update.
    #
    # NOTE(review): block "Z_update" is annotated with
    # meta_schedule.thread_extent_{low,high}_inclusive = 1024, while the
    # threadIdx.x extent here is 128. Presumably deliberate (e.g. a fixture for
    # a thread-extent check) — confirm against the test that uses it.
    Z_local = T.alloc_buffer([1, 128, 128], dtype="float32", scope="local")
    X_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    Y_shared = T.alloc_buffer([1, 128, 128], dtype="float32", scope="shared")
    for i0_0_i1_0_i2_0_fused in T.thread_binding(16, thread="blockIdx.x"):
        for i0_1_i1_1_i2_1_fused in T.thread_binding(1, thread="vthread.x"):
            for i0_2_i1_2_i2_2_fused in T.thread_binding(128, thread="threadIdx.x"):
                # Zero this thread's 4x2 accumulator tile.
                for i1_3_init, i2_4_init in T.grid(4, 2):
                    with T.block("Z_init"):
                        b = T.axis.spatial(1, 0)
                        i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3_init)
                        j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4_init)
                        T.reads()
                        T.writes(Z_local[b, i, j])
                        Z_local[b, i, j] = T.float32(0)
                for i3_0 in T.serial(4):
                    # Copy a 32x32 slice of X: 4 steps x 128 threads x 2-wide vector = 1024 elements.
                    for ax0_ax1_ax2_fused_0 in T.serial(4):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            for ax0_ax1_ax2_fused_2 in T.vectorized(2):
                                with T.block("X_shared"):
                                    v0 = T.axis.spatial(1, 0)
                                    v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) // 32)
                                    v2 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 256 + ax0_ax1_ax2_fused_1 * 2 + ax0_ax1_ax2_fused_2) % 32)
                                    T.reads(X[v0, v1, v2])
                                    T.writes(X_shared[v0, v1, v2])
                                    X_shared[v0, v1, v2] = X[v0, v1, v2]
                    # Copy a 32x32 slice of Y: 8 steps x 128 threads = 1024 elements.
                    for ax0_ax1_ax2_fused_0 in T.serial(8):
                        for ax0_ax1_ax2_fused_1 in T.thread_binding(128, thread="threadIdx.x"):
                            with T.block("Y_shared"):
                                v0 = T.axis.spatial(1, 0)
                                v1 = T.axis.spatial(128, i3_0 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) // 32)
                                v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + (ax0_ax1_ax2_fused_0 * 128 + ax0_ax1_ax2_fused_1) % 32)
                                T.reads(Y[v0, v1, v2])
                                T.writes(Y_shared[v0, v1, v2])
                                Y_shared[v0, v1, v2] = Y[v0, v1, v2]
                    # Accumulate this k-chunk (32 steps) into the 4x2 local tile.
                    for i3_1, i0_3, i1_3, i2_3, i3_2, i0_4, i1_4, i2_4 in T.grid(1, 1, 4, 1, 32, 1, 1, 2):
                        with T.block("Z_update"):
                            b = T.axis.spatial(1, 0)
                            i = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + i1_3)
                            j = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + i2_4)
                            k = T.axis.reduce(128, i3_0 * 32 + i3_2)
                            T.block_attr({
                                "meta_schedule.thread_extent_low_inclusive": 1024,
                                "meta_schedule.thread_extent_high_inclusive": 1024,
                            })
                            T.reads(Z_local[b, i, j], X_shared[b, i, k], Y_shared[b, k, j])
                            T.writes(Z_local[b, i, j])
                            Z_local[b, i, j] = Z_local[b, i, j] + X_shared[b, i, k] * Y_shared[b, k, j]
                # Write the 4x2 accumulator tile back to Z.
                for ax0, ax1, ax2 in T.grid(1, 4, 2):
                    with T.block("Z_local"):
                        v0 = T.axis.spatial(1, ax0)
                        v1 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused // 4 * 32 + i0_2_i1_2_i2_2_fused // 16 * 4 + ax1)
                        v2 = T.axis.spatial(128, i0_0_i1_0_i2_0_fused % 4 * 32 + i0_2_i1_2_i2_2_fused % 16 * 2 + ax2)
                        T.reads(Z_local[v0, v1, v2])
                        T.writes(Z[v0, v1, v2])
                        Z[v0, v1, v2] = Z_local[v0, v1, v2]
def matmul_with_annotation(a: T.handle, b: T.handle, c: T.handle) -> None:
    # 128x128x128 matmul with B indexed transposed:
    #   C[i, j] = sum_k A[i, k] * B[j, k]
    # Block "update" carries a "test_annotation" block attribute, presumably so
    # a test can check that the annotation survives a transformation.
    A = T.match_buffer(a, [128, 128])
    B = T.match_buffer(b, [128, 128])
    C = T.match_buffer(c, [128, 128])
    for i, j, k in T.grid(128, 128, 128):
        with T.block("update"):
            T.block_attr({"test_annotation": 1})
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            # Initial value of the reduction over vk.
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
# Example #5
def square_sum_with_annotation(a: T.handle, c: T.handle) -> None:
    # Per-batch sum of squares: C[b] = sum_{i, j} A[b, i, j] ** 2, with A of
    # shape (16, 256, 256). The i and j axes are both reductions ("SRR").
    # Block "C" carries a "test_annotation" block attribute.
    A = T.match_buffer(a, [16, 256, 256])
    C = T.match_buffer(c, [16])

    for b0, i0, j0 in T.grid(16, 256, 256):
        with T.block("C"):
            T.block_attr({"test_annotation": 1})
            b, i, j = T.axis.remap("SRR", [b0, i0, j0])
            # Initial value of the reduction over (i, j).
            with T.init():
                C[b] = 0.0
            C[b] = C[b] + A[b, i, j] * A[b, i, j]
def annotated_matmul(
    A: T.Buffer[(128, 128), "float32"],
    B: T.Buffer[(128, 128), "float32"],
    C: T.Buffer[(128, 128), "float32"],
) -> None:
    # 128x128x128 float32 matmul with B indexed transposed:
    #   C[i, j] = sum_k A[i, k] * B[j, k]
    # Block "update" carries a boolean "test_annotation" block attribute.
    for i, j, k in T.grid(128, 128, 128):
        with T.block("update"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            T.block_attr({"test_annotation": True})
            # Initial value of the reduction over vk.
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
# Example #7
def nested_pipeline_double_buffer(A: T.Buffer[(16, 16, 16), "float32"],
                                  C: T.Buffer[(16, 16, 16), "float32"]):
    # Computes C[tx, i, j] = A[tx, i, j] * 2 + 1 through a chain of staging
    # buffers (A -> A_shared -> A_local -> B -> C). The outer i loop and the
    # inner j loop both carry "software_pipeline_stage" /
    # "software_pipeline_order" loop annotations, and the A_shared -> A_local
    # copy block carries a "double_buffer_scope" block attribute.
    #
    # NOTE(review): several index expressions disagree with the declared
    # regions or buffer shapes:
    #   - the A_local copy reads A_shared[tx, i, j], but it declares
    #     T.reads(A_shared[tx, 0, j]), and axis 1 of A_shared has extent 1
    #     while i ranges over 0..15;
    #   - the innermost producer declares T.reads(A_local[tx, i, j]), but its
    #     body reads A_local[0, 0, j] (A_local is (1, 1, 16));
    #   - B is (16, 1, 1) but is indexed B[tx, i, 0].
    # These are left untouched because this function may be paired with an
    # expected-output fixture that depends on its exact form — confirm before
    # changing them.
    for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
        for i in T.serial(
                0,
                16,
                annotations={
                    "software_pipeline_stage": [0, 0, 0, 1, 1],
                    "software_pipeline_order": [0, 2, 3, 1, 4],
                },
        ):
            with T.block():
                T.reads(A[tx, i, 0:16])
                T.writes(C[tx, i, 0:16])
                A_shared = T.alloc_buffer((16, 1, 16),
                                          dtype="float32",
                                          scope="shared")
                A_local = T.alloc_buffer((1, 1, 16),
                                         dtype="float32",
                                         scope="local")
                # Stage A[tx, i, :] into A_shared.
                for j in T.serial(0, 16):
                    with T.block():
                        T.reads(A[tx, i, j])
                        T.writes(A_shared[tx, 0, j])
                        A_shared[tx, 0, j] = A[tx, i, j]
                # Copy into A_local (see NOTE above about the A_shared index).
                for j in T.serial(0, 16):
                    with T.block():
                        T.block_attr({"double_buffer_scope": 0})
                        T.reads(A_shared[tx, 0, j])
                        T.writes(A_local[0, 0, j])
                        A_local[0, 0, j] = A_shared[tx, i, j]
                # Inner pipelined loop: B = A_local * 2, then C = B + 1.
                for j in T.serial(
                        0,
                        16,
                        annotations={
                            "software_pipeline_stage": [0, 1],
                            "software_pipeline_order": [0, 1],
                        },
                ):
                    with T.block():
                        T.reads(A_local[0, 0, j])
                        T.writes(C[tx, i, j])
                        B = T.alloc_buffer((16, 1, 1),
                                           dtype="float32",
                                           scope="shared")
                        with T.block():
                            T.reads(A_local[tx, i, j])
                            T.writes(B[tx, i, 0])
                            B[tx, i, 0] = A_local[0, 0, j] * T.float32(2)
                        with T.block():
                            T.reads(B[tx, i, 0])
                            T.writes(C[tx, i, j])
                            C[tx, i, j] = B[tx, i, 0] + T.float32(1)
def annotated_mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    # 16x16x16 multiply-accumulate with B indexed transposed:
    #   C[i, j] += sum_k A[i, k] * B[j, k]
    # There is no T.init(), so C is accumulated into rather than zeroed. All
    # three buffers are matched with align=128 and offset_factor=1. The name
    # suggests a tensor-intrinsic description carrying a "test_annotation"
    # block attribute — confirm against its registration site.
    A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
    B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
    C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)

    with T.block("root"):
        T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16])
        T.writes(C[0:16, 0:16])
        for i, j, k in T.grid(16, 16, 16):
            with T.block("update"):
                T.block_attr({"test_annotation": True})
                vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
                C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]
# Example #9
def single_reduction_loop_with_block_predicate(
        A: T.Buffer[(256, 256), "float32"],
        T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    # Row-wise softmax over a 256x256 float32 input (max, exp-sum, normalise),
    # where each reduction or normalisation loop is bound to 512 threadIdx.x
    # lanes even though a row has only 256 columns. The block predicate
    # T.where(ax1_0 * 512 + ax1_1 < 256) masks out the extra lanes.
    T_softmax_maxelem_shared = T.alloc_buffer([256],
                                              dtype="float32",
                                              scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256],
                                             dtype="float32",
                                             scope="shared")
    for i0 in T.serial(256):
        # Per-row maximum, initialised to -3.4028234663852886e38 (the most
        # negative finite float32 value).
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k])
                    T.writes(T_softmax_maxelem_shared[i0_1])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.float32(
                            -3.4028234663852886e38)
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k])
        # Per-row sum of exp(A - max).
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_expsum_shared[i0_2], A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2])
                    T.writes(T_softmax_expsum_shared[i0_2])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[
                        i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                            A[i0_2, k] - T_softmax_maxelem_shared[i0_2],
                            dtype="float32")
        # Element-wise normalisation: exp(A - max) / sum.
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_3, i1], T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3])
                    T.writes(T_softmax_norm[i0_3, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
                              dtype="float32") / T_softmax_expsum_shared[i0_3])
 def main(a: T.handle, b: T.handle, c: T.handle) -> None:
     # Unscheduled 1024x1024x1024 float32 matmul: C[i, j] = sum_k A[i, k] * B[k, j].
     # The root block carries "meta_schedule.parallel" (128),
     # "meta_schedule.vectorize" (16) and "meta_schedule.unroll_explicit" (2)
     # annotations, presumably read by a meta-schedule post-processing step —
     # confirm against the test that uses this module.
     T.func_attr({"global_symbol": "main"})
     A = T.match_buffer(a, (1024, 1024), "float32")
     B = T.match_buffer(b, (1024, 1024), "float32")
     C = T.match_buffer(c, (1024, 1024), "float32")
     with T.block("root"):
         T.reads([])
         T.writes([])
         T.block_attr({"meta_schedule.parallel": 128, "meta_schedule.vectorize": 16, "meta_schedule.unroll_explicit": 2})
         for i, j, k in T.grid(1024, 1024, 1024):
             with T.block("matmul"):
                 vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                 # Initial value of the reduction over vk.
                 with T.init():
                     C[vi, vj] = 0.0
                 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
# Example #11
 def main(a: T.handle, b: T.handle, c: T.handle) -> None:
     # Unscheduled 1024x1024x1024 float32 matmul: C[i, j] = sum_k A[i, k] * B[k, j].
     # Block "matmul" is annotated with
     # "schedule_rule" = "tvm.meta_schedule.test.custom_search_space",
     # presumably naming a custom schedule rule registered by the test that
     # uses this module — confirm there.
     T.func_attr({"global_symbol": "main"})
     A = T.match_buffer(a, (1024, 1024), "float32")
     B = T.match_buffer(b, (1024, 1024), "float32")
     C = T.match_buffer(c, (1024, 1024), "float32")
     with T.block("root"):
         for i, j, k in T.grid(1024, 1024, 1024):
             with T.block("matmul"):
                 T.block_attr({
                     "schedule_rule":
                     "tvm.meta_schedule.test.custom_search_space"
                 })
                 vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                 # Initial value of the reduction over vk.
                 with T.init():
                     C[vi, vj] = 0.0
                 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
# Example #12
 def main(
     placeholder: T.Buffer[(1, 1, 16, 16, 3), "float32"],
     placeholder_1: T.Buffer[(2, 1, 5, 5, 3, 4), "float32"],
     conv2d_NCHWc: T.Buffer[(1, 2, 16, 16, 4), "float32"]
 ) -> None:  # type: ignore
     # Unscheduled blocked-layout convolution in two steps:
     #   1. "data_pad": copy the 16x16 spatial input into a 20x20 buffer,
     #      writing 0 wherever the source index (h - 2, w - 2) falls outside
     #      [0, 16), i.e. a 2-element zero border on H and W;
     #   2. "conv2d_NCHWc": a 5x5 kernel applied with unit stride (oh + kh,
     #      ow + kw), reducing over ic (3 values) and the kernel window.
     # The "workload" block attribute is an opaque descriptor copied verbatim;
     # its fields are not interpreted here.
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     data_pad = T.alloc_buffer([1, 1, 20, 20, 3], dtype="float32")
     for i0, i1, i2, i3, i4 in T.grid(1, 1, 20, 20, 3):
         with T.block("data_pad"):
             i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap(
                 "SSSSS", [i0, i1, i2, i3, i4])
             T.reads(placeholder[i0_1, i1_1, i2_1 - 2, i3_1 - 2, i4_1])
             T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1])
             data_pad[i0_1, i1_1, i2_1, i3_1,
                      i4_1] = T.if_then_else(2 <= i2_1 and i2_1 < 18
                                             and 2 <= i3_1 and i3_1 < 18,
                                             placeholder[i0_1, i1_1,
                                                         i2_1 - 2, i3_1 - 2,
                                                         i4_1],
                                             T.float32(0),
                                             dtype="float32")  # type: ignore # pylint: disable=R1716
     for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(1, 2, 16, 16, 4, 3, 5, 5):
         with T.block("conv2d_NCHWc"):
             n, oc_chunk, oh, ow, oc_block, ic, kh, kw = T.axis.remap(
                 "SSSSSRRR", [i0, i1, i2, i3, i4, i5, i6, i7])
             T.reads(data_pad[n, ic // 3, oh + kh, ow + kw, ic % 3],
                     placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3,
                                   oc_block])  # type: ignore
             T.writes(conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block])
             T.block_attr({
                 "workload": [
                     "conv2d_NCHWc.x86",
                     ["TENSOR", [1, 1, 16, 16, 3], "float32"],
                     ["TENSOR", [2, 1, 5, 5, 3, 4], "float32"], [1, 1],
                     [2, 2, 2, 2], [1, 1], "NCHW3c", "NCHW4c", "float32"
                 ]
             })
             # Initial value of the reduction over (ic, kh, kw).
             with T.init():
                 conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = T.float32(0)
             conv2d_NCHWc[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc[
                 n, oc_chunk, oh, ow, oc_block] + data_pad[
                     n, ic // 3, oh + kh, ow + kw, ic %
                     3] * placeholder_1[oc_chunk, ic // 3, kh, kw, ic % 3,
                                        oc_block]  # type: ignore
 def main(a: T.handle, b: T.handle) -> None:
     # Element-wise copy B = A over 1024x1024x1024 float32 buffers through a
     # tiled loop nest. The tiling covers 2048 indices per axis (128*4*4,
     # 64*4*8 and 64*32), so the T.where predicate, which is equivalent to
     # vi < 1024 and vj < 1024 and vk < 1024, masks out the excess iterations.
     # The root block carries "meta_schedule.parallel" (128) and
     # "meta_schedule.vectorize" (32) annotations.
     # function attr dict
     T.func_attr({"global_symbol": "main"})
     A = T.match_buffer(a, [1024, 1024, 1024], dtype="float32")
     B = T.match_buffer(b, [1024, 1024, 1024], dtype="float32")
     # body
     with T.block("root"):
         T.block_attr({"meta_schedule.parallel":128, "meta_schedule.vectorize":32})
         for i0, j0, i1, j1, k0, i2, j2, k1 in T.grid(128, 64, 4, 4, 64, 4, 8, 32):
             with T.block("move"):
                 vi = T.axis.spatial(1024, i0 * 16 + i1 * 4 + i2)
                 vj = T.axis.spatial(1024, j0 * 32 + j1 * 8 + j2)
                 vk = T.axis.spatial(1024, k0 * 32 + k1)
                 T.where((i0 * 4 + i1) * 4 + i2 < 1024 and (j0 * 4 + j1) * 8 + j2 < 1024 and k0 * 32 + k1 < 1024)
                 T.reads([A[vi, vj, vk]])
                 T.writes([B[vi, vj, vk]])
                 B[vi, vj, vk] = A[vi, vj, vk]
# Example #14
def matmul_relu_ann2(a: T.handle, b: T.handle, d: T.handle) -> None:
    # 1024x1024x1024 matmul followed by ReLU:
    #   C[i, j] = sum_k A[i, k] * B[k, j];  D[i, j] = max(C[i, j], 0.0)
    # C is an intermediate buffer. Both blocks carry test block attributes of
    # several value kinds: a string ("test1"), a float ("test2") and a mixed
    # list ("test3").
    A = T.match_buffer(a, (1024, 1024))
    B = T.match_buffer(b, (1024, 1024))
    C = T.alloc_buffer((1024, 1024))
    D = T.match_buffer(d, (1024, 1024))
    for i, j, k in T.grid(1024, 1024, 1024):
        with T.block("matmul"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            # Initial value of the reduction over vk.
            with T.init():
                C[vi, vj] = 0.0
            T.block_attr({"test1": "aaa"})
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
    for i, j in T.grid(1024, 1024):
        with T.block("relu"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.block_attr({"test2": 0.22, "test3": ["aa", 1]})
            D[vi, vj] = T.max(C[vi, vj], 0.0)
# Example #15
def compacted_storage_align_func(a: T.handle, c: T.handle) -> None:
    # Two-stage element-wise TIR function over 16x16 float32 buffers, computed
    # one row i at a time: B = A + 1.0, then C = B * 2.0.
    # The intermediate B is allocated per row as a 1x16 buffer with strides
    # (31, 1). The producer block's "buffer_dim_align" annotation is
    # [[0, 0, 16, 15]]. Presumably its fields are (buffer index, axis, factor,
    # offset), which the row stride 31 satisfies (31 % 16 == 15) — confirm
    # against the pass that produces this fixture.
    A = T.match_buffer(a, (16, 16), "float32")
    C = T.match_buffer(c, (16, 16), "float32")
    for i in range(0, 16):
        with T.block():
            T.reads(A[i, 0:16])
            T.writes(C[i, 0:16])
            # `dtype` is the keyword T.alloc_buffer takes, matching the other
            # T.alloc_buffer calls in this file.
            B = T.alloc_buffer((1, 16), strides=(31, 1), dtype="float32")
            for j in range(0, 16):
                with T.block() as []:
                    T.reads(A[i, j])
                    T.writes(B[0, j])
                    T.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]})
                    B[0, j] = A[i, j] + 1.0
            for j in range(0, 16):
                with T.block() as []:
                    T.reads(B[0, j])
                    T.writes(C[i, j])
                    C[i, j] = B[0, j] * 2.0
# Example #16
def square_sum_with_annotation_rfactor(a: T.handle, c: T.handle) -> None:
    # Per-batch sum of squares C[b] = sum_{i, j} A[b, i, j] ** 2, split into
    # two reductions through the intermediate C_rf of shape (16, 256):
    #   C_rf[b, j] = sum_i A[b, i, j] ** 2   (block "C_rf", j spatial)
    #   C[b]       = sum_j C_rf[b, j]        (block "C", j reduced)
    # Both blocks keep the "test_annotation" block attribute.
    A = T.match_buffer(a, [16, 256, 256])
    C = T.match_buffer(c, [16])
    C_rf = T.alloc_buffer([16, 256])

    for i0, i1, i2 in T.grid(16, 256, 256):
        with T.block("C_rf"):
            T.block_attr({"test_annotation": 1})
            # vi2 <- i2 (spatial), b <- i0 (spatial), i <- i1 (reduction).
            vi2, b, i = T.axis.remap("SSR", [i2, i0, i1])
            with T.init():
                C_rf[b, vi2] = 0.0
            C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2])

    for i0_1, i2_1 in T.grid(16, 256):
        with T.block("C"):
            T.block_attr({"test_annotation": 1})
            # vi2_1 <- i2_1 (reduction), b_1 <- i0_1 (spatial).
            vi2_1, b_1 = T.axis.remap("RS", [i2_1, i0_1])
            with T.init():
                C[b_1] = 0.0
            C[b_1] = C[b_1] + C_rf[b_1, vi2_1]
def rewritten_tir_matmul(
    A: T.Buffer[(16, 16), "float32"],
    B: T.Buffer[(16, 16), "float32"],
    C: T.Buffer[(16, 16), "float32"],
) -> None:
    # 16x16x16 float32 matmul C = A @ B after a layout rewrite of B:
    #   B_reindex[col, row // 4, row % 4] = B[row, col]
    # so the matmul reads B_reindex[vj, vk // 4, vk % 4], which equals
    # B[vk, vj]. The "layout_free_buffers": [1] function attribute presumably
    # refers to parameter index 1 (B) — confirm against the pass that consumes
    # it.
    T.func_attr({"layout_free_buffers": [1]})
    B_reindex = T.alloc_buffer([16, 4, 4], dtype="float32")
    for ax0, ax1 in T.grid(16, 16):
        with T.block("layout_rewrite"):
            i0, i1 = T.axis.remap("SS", [ax0, ax1])
            T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
            B_reindex[i1, i0 // 4, i0 % 4] = B[i0, i1]
    for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
        with T.block("matmul"):
            vi = T.axis.spatial(16, i0 * 4 + i1)
            vj = T.axis.spatial(16, j)
            vk = T.axis.reduce(16, k0 * 4 + k1)
            # Initial value of the reduction over vk.
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B_reindex[vj, vk // 4, vk % 4]
def element_wise_storage_align(a: T.handle, c: T.handle) -> None:
    # Element-wise TIR function on 128x128 buffers: B = A * 2, then C = B + 1,
    # with block "C" nested under the same outer i0 loop as block "B".
    # Block "B" carries "buffer_dim_align": [[0, 0, 128, 127]].
    # NOTE(review): this function uses the older block syntax
    # (T.block(iter_extents, name) as [...] plus T.bind), unlike the
    # T.axis.remap style used elsewhere in this file, so it presumably targets
    # an older TVMScript parser — confirm before mixing it with newer fixtures.
    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with T.block([], "root"):
        T.reads([])
        T.writes([])
        B = T.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1)
        for i0 in T.serial(0, 128):
            for ax1 in T.serial(0, 128):
                with T.block([128, 128], "B") as [vi, vj]:
                    T.bind(vi, i0)
                    T.bind(vj, ax1)
                    T.reads([A[vi, vj]])
                    T.writes([B[vi, vj]])
                    T.block_attr({"buffer_dim_align":[[0, 0, 128, 127]]})
                    B[vi, vj] = (A[vi, vj]*T.float32(2))
            for i1 in T.serial(0, 128):
                with T.block([128, 128], "C") as [vi_1, vj_1]:
                    T.bind(vi_1, i0)
                    T.bind(vj_1, i1)
                    T.reads([B[vi_1, vj_1]])
                    T.writes([C[vi_1, vj_1]])
                    C[vi_1, vj_1] = (B[vi_1, vj_1] + T.float32(1))
# Example #19
 def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None:
     # Scheduled 512x512x512 float32 matmul C[i, j] = sum_k A[i, k] * B[k, j]:
     #   - 16 blockIdx.x values, each owning 32 columns of C
     #     (j starts at i0_0_i1_0_fused * 32);
     #   - 16 vthread.x values, each owning 32 rows (i starts at
     #     i0_1_i1_1_fused * 32);
     #   - 8 threadIdx.x values, each owning 4 of the block's columns;
     #   - partial sums accumulate in C_local (scope="local") and are written
     #     back to C at the end.
     # Operands are staged in scope="shared" buffers: the A_shared copy covers
     # 32768 * 8 = 262144 = 512 * 512 elements (all of A), and the B_shared copy
     # covers 1024 * 8 * 2 = 16384 = 512 * 32 elements (this block's column
     # slice of B). The copy blocks carry "meta_schedule.cooperative_fetch"
     # annotations (1 and 2).
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     A = T.match_buffer(var_A, [512, 512], dtype="float32")
     B = T.match_buffer(var_B, [512, 512], dtype="float32")
     C = T.match_buffer(var_C, [512, 512], dtype="float32")
     # body
     # with T.block("root")
     C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
     A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
     B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
     for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
         for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
             for i0_2_i1_2_fused in T.thread_binding(0,
                                                     8,
                                                     thread="threadIdx.x"):
                 for i2_0 in T.serial(0, 1):
                     # Stage all of A (512 x 512) into A_shared.
                     for ax0_ax1_fused_0 in T.serial(0, 32768):
                         for ax0_ax1_fused_1 in T.thread_binding(
                                 0, 8, thread="threadIdx.x"):
                             with T.block("A_shared"):
                                 v0 = T.axis.spatial(
                                     512,
                                     (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1)
                                     // 512)
                                 v1 = T.axis.spatial(
                                     512,
                                     (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1)
                                     % 512)
                                 T.reads([A[v0, v1]])
                                 T.writes([A_shared[v0, v1]])
                                 T.block_attr(
                                     {"meta_schedule.cooperative_fetch": 1})
                                 A_shared[v0, v1] = A[v0, v1]
                     # Stage this block's 512 x 32 column slice of B into B_shared.
                     for ax0_ax1_fused_0 in T.serial(0, 1024):
                         for ax0_ax1_fused_1 in T.thread_binding(
                                 0, 8, thread="threadIdx.x"):
                             for ax0_ax1_fused_2 in T.vectorized(0, 2):
                                 with T.block("B_shared"):
                                     v0 = T.axis.spatial(
                                         512, (ax0_ax1_fused_0 * 16 +
                                               ax0_ax1_fused_1 * 2 +
                                               ax0_ax1_fused_2) // 32)
                                     v1 = T.axis.spatial(
                                         512, i0_0_i1_0_fused * 32 +
                                         (ax0_ax1_fused_0 * 16 +
                                          ax0_ax1_fused_1 * 2 +
                                          ax0_ax1_fused_2) % 32)
                                     T.reads([B[v0, v1]])
                                     T.writes([B_shared[v0, v1]])
                                     T.block_attr({
                                         "meta_schedule.cooperative_fetch":
                                         2
                                     })
                                     B_shared[v0, v1] = B[v0, v1]
                     # Accumulate a 32 x 4 tile of C over k = i2_1 * 32 + i2_2.
                     for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(
                             16, 2, 2, 32, 16, 2):
                         with T.block("C"):
                             i = T.axis.spatial(
                                 512,
                                 i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
                             j = T.axis.spatial(
                                 512, i0_0_i1_0_fused * 32 +
                                 i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4)
                             k = T.axis.reduce(512, i2_1 * 32 + i2_2)
                             T.reads([
                                 C_local[i, j], A_shared[i, k], B_shared[k,
                                                                         j]
                             ])
                             T.writes([C_local[i, j]])
                             with T.init():
                                 C_local[i, j] = T.float32(0)
                             C_local[i, j] = C_local[
                                 i, j] + A_shared[i, k] * B_shared[k, j]
                 # Write the 32 x 4 accumulator tile back to C.
                 for ax0, ax1 in T.grid(32, 4):
                     with T.block("C_local"):
                         v0 = T.axis.spatial(512,
                                             i0_1_i1_1_fused * 32 + ax0)
                         v1 = T.axis.spatial(
                             512, i0_0_i1_0_fused * 32 +
                             i0_2_i1_2_fused * 4 + ax1)
                         T.reads([C_local[v0, v1]])
                         T.writes([C[v0, v1]])
                         C[v0, v1] = C_local[v0, v1]
# Example #20
 # Int8 x int8 -> int32 kernel in which compute[i, j] accumulates over
 # X[i, k] and W[j, k] (note W is indexed by [j, k]). The reduction has been
 # split into a zero-init block ("compute_o_init") and an update block
 # ("compute_o_update") that calls the `__dp4a` extern on 4-element int8
 # slices (presumably the CUDA 4-way int8 dot-product intrinsic).
 # Presumably the expected IR of a meta-schedule tensorization test for dp4a
 # -- TODO confirm against the test that uses this fixture.
 #   X:       (128, 128) int8 input, read as X[i, k].
 #   W:       (128, 128) int8 input, read as W[j, k].
 #   compute: (128, 128) int32 output.
 def main(
     X: T.Buffer[(128, 128), "int8"],
     W: T.Buffer[(128, 128), "int8"],
     compute: T.Buffer[(128, 128), "int32"],
 ) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     # Accumulator (scope="local") plus shared-scope staging copies of X and W.
     compute_local = T.alloc_buffer([128, 128],
                                    dtype="int32",
                                    scope="local")
     X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
     W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
     # blockIdx.x (16): `// 2` selects a 16-row tile and `% 2` a 64-column tile.
     # vthread.x (2) and threadIdx.x (2) split those 64 columns further, so each
     # (vthread.x, threadIdx.x) pair owns a 16x16 tile of the output.
     for i0_0_i1_0_fused in T.thread_binding(16, thread="blockIdx.x"):
         for i0_1_i1_1_fused in T.thread_binding(2, thread="vthread.x"):
             for i0_2_i1_2_fused in T.thread_binding(2,
                                                     thread="threadIdx.x"):
                 # Zero-initialise this 16x16 tile of compute_local. The outer
                 # block carries the dp4a auto-tensorize annotation.
                 for i0_3_init, i1_3_init, i0_4_init in T.grid(4, 16, 4):
                     with T.block("compute_o_init"):
                         i = T.axis.spatial(
                             128, i0_0_i1_0_fused // 2 * 16 +
                             i0_3_init * 4 + i0_4_init)
                         j = T.axis.spatial(
                             128,
                             i0_0_i1_0_fused % 2 * 64 +
                             i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 +
                             i1_3_init,
                         )
                         T.reads()
                         T.writes(compute_local[i, j])
                         T.block_attr(
                             {"meta_schedule.auto_tensorize": "dp4a"})
                         with T.block("compute_init"):
                             T.reads()
                             T.writes(compute_local[i, j])
                             compute_local[i, j] = 0
                 # Outer reduction loop: 2 chunks of 64 along k.
                 for i2_0_0 in T.serial(2):
                     # Stage a 16x64 tile of X (1024 elements) into X_shared.
                     for ax0_ax1_fused in T.serial(1024):
                         with T.block("X_shared"):
                             v0 = T.axis.spatial(
                                 128, i0_0_i1_0_fused // 2 * 16 +
                                 ax0_ax1_fused // 64)
                             v1 = T.axis.spatial(
                                 128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                             T.reads(X[v0, v1])
                             T.writes(X_shared[v0, v1])
                             T.block_attr(
                                 {"meta_schedule.cooperative_fetch": 4})
                             X_shared[v0, v1] = X[v0, v1]
                     # Stage a 64x64 tile of W (4096 elements) into W_shared.
                     for ax0_ax1_fused in T.serial(4096):
                         with T.block("W_shared"):
                             v0 = T.axis.spatial(
                                 128, i0_0_i1_0_fused % 2 * 64 +
                                 ax0_ax1_fused // 64)
                             v1 = T.axis.spatial(
                                 128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                             T.reads(W[v0, v1])
                             T.writes(W_shared[v0, v1])
                             T.block_attr(
                                 {"meta_schedule.cooperative_fetch": 1})
                             W_shared[v0, v1] = W[v0, v1]
                     # k_o ranges over 32 groups of 4 consecutive k values.
                     # i1_4 has extent 1 and does not appear in any index.
                     for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(
                             2, 4, 16, 8, 4, 1):
                         with T.block("compute_o_update"):
                             i = T.axis.spatial(
                                 128, i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 +
                                 i0_4)
                             j = T.axis.spatial(
                                 128,
                                 i0_0_i1_0_fused % 2 * 64 +
                                 i0_1_i1_1_fused * 32 +
                                 i0_2_i1_2_fused * 16 + i1_3,
                             )
                             k_o = T.axis.reduce(
                                 32, i2_0_0 * 16 + i2_0_1 * 8 + i2_0_2)
                             T.reads(
                                 compute_local[i, j],
                                 X_shared[i, k_o * 4:k_o * 4 + 4],
                                 W_shared[j, k_o * 4:k_o * 4 + 4],
                             )
                             T.writes(compute_local[i, j])
                             # A and B are 4-element int8 views of the current
                             # X/W slices. C is a 1-element view of the
                             # accumulator that the dp4a call updates.
                             A = T.match_buffer(
                                 X_shared[i, k_o * 4:k_o * 4 + 4],
                                 [4],
                                 dtype="int8",
                                 scope="shared",
                                 align=4,
                                 offset_factor=1,
                             )
                             B = T.match_buffer(
                                 W_shared[j, k_o * 4:k_o * 4 + 4],
                                 [4],
                                 dtype="int8",
                                 scope="shared",
                                 align=4,
                                 offset_factor=1,
                             )
                             C = T.match_buffer(
                                 compute_local[i, j],
                                 [1],
                                 dtype="int32",
                                 scope="local",
                                 align=4,
                                 offset_factor=1,
                             )
                             # T.ramp(0, 1, 4) indexes elements 0..3 as one
                             # vector operand; the extern's result (with a 0
                             # third argument) is added to C[0].
                             C[0] = C[0] + T.call_pure_extern(
                                 "__dp4a",
                                 A[T.ramp(0, 1, 4)],
                                 B[T.ramp(0, 1, 4)],
                                 0,
                                 dtype="int32",
                             )
                 # Copy this 16x16 tile from compute_local back to compute.
                 for ax0, ax1 in T.grid(16, 16):
                     with T.block("compute_local"):
                         v0 = T.axis.spatial(
                             128, i0_0_i1_0_fused // 2 * 16 + ax0)
                         v1 = T.axis.spatial(
                             128,
                             i0_0_i1_0_fused % 2 * 64 +
                             i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 +
                             ax1,
                         )
                         T.reads(compute_local[v0, v1])
                         T.writes(compute[v0, v1])
                         compute[v0, v1] = compute_local[v0, v1]
 # 1024x1024 float32 matmul: C[i, j] is zeroed when the global k index is 0
 # and then accumulates A_shared[i, k] * B_shared[k, j]. 32x32 tiles of A and
 # B are copied into shared-scope buffers for each 32-wide k step.
 # The copy blocks carry the "tir.manifest_shared_memory_local_stage"
 # annotation. Presumably this is the input or expected output of a test for
 # the pass that handles that annotation -- TODO confirm.
 # NOTE(review): the function is named `main`, but its global_symbol is
 # "default_function".
 def main(A: T.Buffer[(1024, 1024), "float32"], B: T.Buffer[(1024, 1024),
                                                            "float32"],
          C: T.Buffer[(1024, 1024), "float32"]) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
     # body
     # with T.block("root")
     # A 32x32 grid of thread blocks, each owning a 32x32 tile of C. The 2x2
     # threads per block each own a 16x16 sub-tile.
     for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"):
         for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"):
             for threadIdx_y in T.thread_binding(2, thread="threadIdx.y"):
                 for threadIdx_x in T.thread_binding(2,
                                                     thread="threadIdx.x"):
                     # Outer reduction loop: 32 steps of 32 along k.
                     for k_0 in T.serial(32):
                         # Anonymous block whose declared regions are the
                         # 32x32 tiles of A, B and C touched in this k_0 step.
                         with T.block():
                             T.reads(
                                 A[blockIdx_y * 32:blockIdx_y * 32 + 32,
                                   k_0 * 32:k_0 * 32 + 32],
                                 B[k_0 * 32:k_0 * 32 + 32,
                                   blockIdx_x * 32:blockIdx_x * 32 + 32])
                             T.writes(
                                 C[blockIdx_y * 32:blockIdx_y * 32 + 32,
                                   blockIdx_x * 32:blockIdx_x * 32 + 32])
                             # The shared buffers are declared full-size
                             # (1024x1024), but only the current 32x32 tile
                             # is written below.
                             A_shared = T.alloc_buffer([1024, 1024],
                                                       dtype="float32",
                                                       scope="shared")
                             B_shared = T.alloc_buffer([1024, 1024],
                                                       dtype="float32",
                                                       scope="shared")
                             # Cooperative copy of the 32x32 A tile. Each of
                             # the 64 iterations covers 16 elements (2x2
                             # threads x 4 vector lanes), 1024 in total. The
                             # flat index splits into row (// 32) and
                             # column (% 32).
                             for ax0_ax1_fused_0 in T.serial(64):
                                 for ax0_ax1_fused_3 in T.vectorized(4):
                                     with T.block("A_shared"):
                                         T.reads(A[blockIdx_y * 32 +
                                                   (ax0_ax1_fused_0 * 16 +
                                                    threadIdx_y * 8 +
                                                    threadIdx_x * 4 +
                                                    ax0_ax1_fused_3) // 32,
                                                   k_0 * 32 +
                                                   (ax0_ax1_fused_0 * 16 +
                                                    threadIdx_y * 8 +
                                                    threadIdx_x * 4 +
                                                    ax0_ax1_fused_3) % 32])
                                         T.writes(A_shared[
                                             blockIdx_y * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) // 32,
                                             k_0 * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) % 32])
                                         T.block_attr({
                                             "tir.manifest_shared_memory_local_stage":
                                             1
                                         })
                                         A_shared[
                                             blockIdx_y * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) // 32,
                                             k_0 * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) % 32] = A[
                                                  blockIdx_y * 32 +
                                                  (ax0_ax1_fused_0 * 16 +
                                                   threadIdx_y * 8 +
                                                   threadIdx_x * 4 +
                                                   ax0_ax1_fused_3) // 32,
                                                  k_0 * 32 +
                                                  (ax0_ax1_fused_0 * 16 +
                                                   threadIdx_y * 8 +
                                                   threadIdx_x * 4 +
                                                   ax0_ax1_fused_3) % 32]
                             # Same cooperative copy for the 32x32 B tile.
                             for ax0_ax1_fused_0 in T.serial(64):
                                 for ax0_ax1_fused_3 in T.vectorized(4):
                                     with T.block("B_shared"):
                                         T.reads(B[k_0 * 32 +
                                                   (ax0_ax1_fused_0 * 16 +
                                                    threadIdx_y * 8 +
                                                    threadIdx_x * 4 +
                                                    ax0_ax1_fused_3) // 32,
                                                   blockIdx_x * 32 +
                                                   (ax0_ax1_fused_0 * 16 +
                                                    threadIdx_y * 8 +
                                                    threadIdx_x * 4 +
                                                    ax0_ax1_fused_3) % 32])
                                         T.writes(B_shared[
                                             k_0 * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) // 32,
                                             blockIdx_x * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) % 32])
                                         T.block_attr({
                                             "tir.manifest_shared_memory_local_stage":
                                             1
                                         })
                                         B_shared[
                                             k_0 * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) // 32,
                                             blockIdx_x * 32 +
                                             (ax0_ax1_fused_0 * 16 +
                                              threadIdx_y * 8 +
                                              threadIdx_x * 4 +
                                              ax0_ax1_fused_3) % 32] = B[
                                                  k_0 * 32 +
                                                  (ax0_ax1_fused_0 * 16 +
                                                   threadIdx_y * 8 +
                                                   threadIdx_x * 4 +
                                                   ax0_ax1_fused_3) // 32,
                                                  blockIdx_x * 32 +
                                                  (ax0_ax1_fused_0 * 16 +
                                                   threadIdx_y * 8 +
                                                   threadIdx_x * 4 +
                                                   ax0_ax1_fused_3) % 32]
                             # Each thread accumulates its 16x16 sub-tile of
                             # C over this step's 32 k values
                             # (k_1 * 16 + k_2). C is zeroed when the global
                             # k index is 0.
                             for k_1, i_2, j_2, k_2 in T.grid(
                                     2, 16, 16, 16):
                                 with T.block("C"):
                                     T.reads(
                                         A_shared[blockIdx_y * 32 +
                                                  threadIdx_y * 16 + i_2,
                                                  k_0 * 32 + k_1 * 16 +
                                                  k_2],
                                         B_shared[k_0 * 32 + k_1 * 16 + k_2,
                                                  blockIdx_x * 32 +
                                                  threadIdx_x * 16 + j_2])
                                     T.writes(C[blockIdx_y * 32 +
                                                threadIdx_y * 16 + i_2,
                                                blockIdx_x * 32 +
                                                threadIdx_x * 16 + j_2])
                                     if k_0 * 32 + k_1 * 16 + k_2 == 0:
                                         C[blockIdx_y * 32 +
                                           threadIdx_y * 16 + i_2,
                                           blockIdx_x * 32 +
                                           threadIdx_x * 16 +
                                           j_2] = T.float32(0)
                                     C[
                                         blockIdx_y * 32 +
                                         threadIdx_y * 16 + i_2,
                                         blockIdx_x * 32 +
                                         threadIdx_x * 16 + j_2] = C[
                                             blockIdx_y * 32 +
                                             threadIdx_y * 16 + i_2,
                                             blockIdx_x * 32 + threadIdx_x *
                                             16 + j_2] + A_shared[
                                                 blockIdx_y * 32 +
                                                 threadIdx_y * 16 + i_2,
                                                 k_0 * 32 + k_1 * 16 +
                                                 k_2] * B_shared[
                                                     k_0 * 32 + k_1 * 16 +
                                                     k_2, blockIdx_x * 32 +
                                                     threadIdx_x * 16 + j_2]
示例#22
0
# Intentionally invalid TVMScript, presumably used as a negative parser test.
# T.block_attr is called directly in the function body (the implicit root
# block), which the inline comment marks as an expected error.
def implicit_root_has_attrs():
    T.block_attr({})  # error: implicit root does not support block_attr
    T.evaluate(0.0)
示例#23
0
# Intentionally invalid TVMScript, presumably used as a negative parser test.
# One block declares T.block_attr twice, and the second call is marked as the
# expected error.
def duplicate_annotations() -> None:
    for i, j in T.grid(16, 16):
        with T.block():
            vi, vj = T.axis.remap("SS", [i, j])
            T.block_attr({})
            T.block_attr({})  # error
示例#24
0
# Intentionally invalid TVMScript: the same duplicate-T.block_attr case,
# written with the older `with T.block([extents]) as [vars]` block syntax.
# The second T.block_attr call is marked as the expected error.
def duplicate_annotations() -> None:
    with T.block([16, 16]) as [vi, vj]:
        T.block_attr({})
        T.block_attr({})  # error
示例#25
0
 # Int8 x int8 -> int32 kernel: compute[i, j] accumulates
 # int32(X[i, k]) * int32(W[j, k]) over k, with X/W tiles staged in shared
 # memory. The reduction is split so that each "compute_o" block (annotated
 # for dp4a auto-tensorization) covers 4 consecutive k values. Those 4 values
 # are reduced by an inner scalar "compute" block.
 # Presumably the schedule state before tensorization in a dp4a test
 # -- TODO confirm against the test that uses this fixture.
 #   X:       (128, 128) int8 input, read as X[i, k].
 #   W:       (128, 128) int8 input, read as W[j, k].
 #   compute: (128, 128) int32 output.
 def main(
     X: T.Buffer[(128, 128), "int8"],
     W: T.Buffer[(128, 128), "int8"],
     compute: T.Buffer[(128, 128), "int32"],
 ) -> None:
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # Accumulator (scope="local") plus shared-scope staging copies of X and W.
     compute_local = T.alloc_buffer([128, 128],
                                    dtype="int32",
                                    scope="local")
     X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
     W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared")
     # blockIdx.x (16): `// 2` selects a 16-row tile and `% 2` a 64-column tile.
     # vthread.x (2) and threadIdx.x (2) split those 64 columns further, so each
     # (vthread.x, threadIdx.x) pair owns a 16x16 tile of the output.
     for i0_0_i1_0_fused in T.thread_binding(16, thread="blockIdx.x"):
         for i0_1_i1_1_fused in T.thread_binding(2, thread="vthread.x"):
             for i0_2_i1_2_fused in T.thread_binding(2,
                                                     thread="threadIdx.x"):
                 # Outer reduction loop: 2 chunks of 64 along k.
                 for i2_0_0 in T.serial(2):
                     # Stage a 16x64 tile of X (1024 elements) into X_shared.
                     for ax0_ax1_fused in T.serial(1024):
                         with T.block("X_shared"):
                             v0 = T.axis.spatial(
                                 128, i0_0_i1_0_fused // 2 * 16 +
                                 ax0_ax1_fused // 64)
                             v1 = T.axis.spatial(
                                 128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                             T.reads(X[v0, v1])
                             T.writes(X_shared[v0, v1])
                             T.block_attr(
                                 {"meta_schedule.cooperative_fetch": 4})
                             X_shared[v0, v1] = X[v0, v1]
                     # Stage a 64x64 tile of W (4096 elements) into W_shared.
                     for ax0_ax1_fused in T.serial(4096):
                         with T.block("W_shared"):
                             v0 = T.axis.spatial(
                                 128, i0_0_i1_0_fused % 2 * 64 +
                                 ax0_ax1_fused // 64)
                             v1 = T.axis.spatial(
                                 128, i2_0_0 * 64 + ax0_ax1_fused % 64)
                             T.reads(W[v0, v1])
                             T.writes(W_shared[v0, v1])
                             T.block_attr(
                                 {"meta_schedule.cooperative_fetch": 1})
                             W_shared[v0, v1] = W[v0, v1]
                     # k_o ranges over 32 groups of 4 consecutive k values.
                     # i1_4 has extent 1 and does not appear in any index.
                     for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(
                             2, 4, 16, 8, 4, 1):
                         with T.block("compute_o"):
                             i = T.axis.spatial(
                                 128, i0_0_i1_0_fused // 2 * 16 + i0_3 * 4 +
                                 i0_4)
                             j = T.axis.spatial(
                                 128,
                                 i0_0_i1_0_fused % 2 * 64 +
                                 i0_1_i1_1_fused * 32 +
                                 i0_2_i1_2_fused * 16 + i1_3,
                             )
                             k_o = T.axis.reduce(
                                 32, i2_0_0 * 16 + i2_0_1 * 8 + i2_0_2)
                             T.reads(
                                 X_shared[i, k_o * 4:k_o * 4 + 4],
                                 W_shared[j, k_o * 4:k_o * 4 + 4],
                             )
                             T.writes(compute_local[i, j])
                             T.block_attr(
                                 {"meta_schedule.auto_tensorize": "dp4a"})
                             # Zero-initialise the accumulator for this (i, j).
                             with T.init():
                                 with T.block("compute_init"):
                                     T.reads()
                                     T.writes(compute_local[i, j])
                                     compute_local[i, j] = 0
                             # Scalar multiply-accumulate over the 4 k values
                             # k_o * 4 + k, widening both int8 operands to
                             # int32.
                             for i2_1 in T.serial(4):
                                 with T.block("compute"):
                                     k = T.axis.reduce(4, i2_1)
                                     T.reads(
                                         compute_local[i, j],
                                         X_shared[i, k_o * 4 + k],
                                         W_shared[j, k_o * 4 + k],
                                     )
                                     T.writes(compute_local[i, j])
                                     T.block_attr({
                                         "meta_schedule.tiling_structure":
                                         "SSSRRSRS"
                                     })
                                     compute_local[
                                         i,
                                         j] = compute_local[i, j] + T.cast(
                                             X_shared[i, k_o * 4 + k],
                                             "int32") * T.cast(
                                                 W_shared[j, k_o * 4 + k],
                                                 "int32")
                 # Copy this 16x16 tile from compute_local back to compute.
                 for ax0, ax1 in T.grid(16, 16):
                     with T.block("compute_local"):
                         v0 = T.axis.spatial(
                             128, i0_0_i1_0_fused // 2 * 16 + ax0)
                         v1 = T.axis.spatial(
                             128,
                             i0_0_i1_0_fused % 2 * 64 +
                             i0_1_i1_1_fused * 32 + i0_2_i1_2_fused * 16 +
                             ax1,
                         )
                         T.reads(compute_local[v0, v1])
                         T.writes(compute[v0, v1])
                         compute[v0, v1] = compute_local[v0, v1]
 # 512x512 float32 matmul: C_local[i, j] is zero-initialised and then
 # accumulates A_shared[i, k] * B_shared[k, j] over all 512 k values. The
 # result is then copied into C.
 # All of A is staged into A_shared, and a 512x32 column strip of B per thread
 # block is staged into B_shared.
 # The "C" block carries {"warp_execution": 1}. Presumably this is a fixture
 # for a test that checks that annotation -- TODO confirm.
 def main(
     A: T.Buffer[(512, 512), "float32"],
     B: T.Buffer[(512, 512), "float32"],
     C: T.Buffer[(512, 512), "float32"],
 ) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # body
     # with T.block("root")
     C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
     A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
     B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
     # blockIdx.x (16) selects a 32-column strip, vthread.x (16) a 32-row band,
     # and threadIdx.y (8) 4 columns within the strip.
     for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
         for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
             for i0_2_i1_2_fused in T.thread_binding(0,
                                                     8,
                                                     thread="threadIdx.y"):
                 # i2_0 has extent 1, so the whole reduction happens in a
                 # single staging round.
                 for i2_0 in T.serial(0, 1):
                     # Copy all of A (1024 * 8 * 32 = 262144 = 512 * 512
                     # elements) into A_shared. threadIdx.y/threadIdx.x are
                     # bound here for the cooperative copy.
                     for ax0_ax1_fused_0 in T.serial(0, 1024):
                         for ax0_ax1_fused_1 in T.thread_binding(
                                 0, 8, thread="threadIdx.y"):
                             for ax0_ax1_fused_2 in T.thread_binding(
                                     0, 32, thread="threadIdx.x"):
                                 with T.block("A_shared"):
                                     v0 = T.axis.spatial(
                                         512,
                                         (ax0_ax1_fused_0 * 256 +
                                          ax0_ax1_fused_1 * 32 +
                                          ax0_ax1_fused_2) // 512,
                                     )
                                     v1 = T.axis.spatial(
                                         512,
                                         (ax0_ax1_fused_0 * 256 +
                                          ax0_ax1_fused_1 * 32 +
                                          ax0_ax1_fused_2) % 512,
                                     )
                                     T.reads([A[v0, v1]])
                                     T.writes([A_shared[v0, v1]])
                                     A_shared[v0, v1] = A[v0, v1]
                     # Copy B[0:512, i0_0_i1_0_fused * 32 : +32]
                     # (32 * 8 * 32 * 2 = 16384 = 512 * 32 elements) into
                     # B_shared, using 2-wide vectorized lanes.
                     for ax0_ax1_fused_0 in T.serial(0, 32):
                         for ax0_ax1_fused_1 in T.thread_binding(
                                 0, 8, thread="threadIdx.y"):
                             for ax0_ax1_fused_2 in T.thread_binding(
                                     0, 32, thread="threadIdx.x"):
                                 for ax0_ax1_fused_3 in T.vectorized(0, 2):
                                     with T.block("B_shared"):
                                         v0 = T.axis.spatial(
                                             512,
                                             (ax0_ax1_fused_0 * 512 +
                                              ax0_ax1_fused_1 * 64 +
                                              ax0_ax1_fused_2 * 2 +
                                              ax0_ax1_fused_3) // 32,
                                         )
                                         v1 = T.axis.spatial(
                                             512,
                                             i0_0_i1_0_fused * 32 +
                                             (ax0_ax1_fused_0 * 512 +
                                              ax0_ax1_fused_1 * 64 +
                                              ax0_ax1_fused_2 * 2 +
                                              ax0_ax1_fused_3) % 32,
                                         )
                                         T.reads([B[v0, v1]])
                                         T.writes([B_shared[v0, v1]])
                                         B_shared[v0, v1] = B[v0, v1]
                     # Each (vthread.x, threadIdx.y) pair accumulates a 32x4
                     # tile of C_local over k = i2_1 * 32 + i2_2. This loop
                     # nest has no threadIdx.x binding.
                     for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(
                             16, 2, 2, 32, 16, 2):
                         with T.block("C"):
                             i = T.axis.spatial(
                                 512,
                                 i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
                             j = T.axis.spatial(
                                 512,
                                 i0_0_i1_0_fused * 32 +
                                 i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4,
                             )
                             k = T.axis.reduce(
                                 512, i2_0 * 512 + i2_1 * 32 + i2_2)
                             T.reads([A_shared[i, k], B_shared[k, j]])
                             T.writes([C_local[i, j]])
                             T.block_attr({"warp_execution": 1})
                             with T.init():
                                 C_local[i, j] = T.float32(0)
                             C_local[i, j] = C_local[
                                 i, j] + A_shared[i, k] * B_shared[k, j]
                 # Write this 32x4 tile from C_local back to C.
                 for ax0, ax1 in T.grid(32, 4):
                     with T.block("C_local"):
                         v0 = T.axis.spatial(512,
                                             i0_1_i1_1_fused * 32 + ax0)
                         v1 = T.axis.spatial(
                             512, i0_0_i1_0_fused * 32 +
                             i0_2_i1_2_fused * 4 + ax1)
                         T.reads([C_local[v0, v1]])
                         T.writes([C[v0, v1]])
                         C[v0, v1] = C_local[v0, v1]
示例#27
0
# Row-wise softmax over a (256, 256) float32 input. Both the per-row max and
# the per-row exp-sum are computed in two stages. First, each of 32 threads
# folds 8 elements into a per-thread local scalar. Second, those partials are
# combined with tvm_thread_allreduce.
# Presumably the expected output of lowering the cross-thread reductions in
# the `softmax` fixture -- TODO confirm against the test that uses it.
#   var_A:              handle to the (256, 256) float32 input A.
#   var_T_softmax_norm: handle to the (256, 256) float32 output.
def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256],
                                    dtype="float32")
    # Per-row max and exp-sum, kept in shared scope so that every thread of
    # the row can read them.
    T_softmax_maxelem_shared = T.alloc_buffer([256],
                                              dtype="float32",
                                              scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256],
                                             dtype="float32",
                                             scope="shared")
    # One-element local scalars: normal_reduce_temp* hold each thread's
    # partial result, and reduce_temp* receive the cross-thread result.
    # Suffix 0 is used for the max reduction and suffix 1 for the exp-sum.
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1],
                                         dtype="float32",
                                         strides=[1],
                                         scope="local")
    reduce_temp1 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    normal_reduce_temp1 = T.alloc_buffer([1],
                                         dtype="float32",
                                         strides=[1],
                                         scope="local")
    # One thread block per row.
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        # Stage 1: row max.
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            # Start the per-thread partial max at the smallest float32 value.
            with T.block("T_softmax_maxelem_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.min_value("float32")
            # Each thread folds columns k = ax0_0 * 32 + ax0_1 (8 per thread).
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_maxelem_normal_reduction"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([A[i0_1, k], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = T.max(normal_reduce_temp0[0],
                                                   A[i0_1, k])
            # Combine the 32 partial maxima. The comm_reducer attached via
            # "reduce_scope" is max with identity min_value("float32"), and
            # the result lands in reduce_temp0.
            with T.block("T_softmax_maxelem_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y),
                                   [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                # NOTE(review): the argument order appears to be value count,
                # source value, predicate, destination, reducing thread var.
                # Confirm against the tvm_thread_allreduce builtin.
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        ax0_1,
                        dtype="handle",
                    ))
            with T.block("T_softmax_maxelem_write_back"):
                i0_2 = T.axis.spatial(256, i0)
                T.reads([reduce_temp0[0]])
                T.writes([T_softmax_maxelem_shared[i0_2]])
                T_softmax_maxelem_shared[i0_2] = reduce_temp0[0]
        # Stage 2: row sum of exp(A - max). It uses the same two-step pattern
        # with identity 0 and an additive reducer.
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_expsum_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp1[0]])
                normal_reduce_temp1[0] = T.float32(0)
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_expsum_normal_reduction"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([
                        A[i0_3, k],
                        T_softmax_maxelem_shared[i0_3],
                        normal_reduce_temp1[0],
                    ])
                    T.writes([normal_reduce_temp1[0]])
                    normal_reduce_temp1[0] = normal_reduce_temp1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3],
                        dtype="float32")
            with T.block("T_softmax_expsum_cross_thread_reduction"):
                T.reads([normal_reduce_temp1[0]])
                T.writes([reduce_temp1[0]])
                T.attr(
                    T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp1[0],
                        True,
                        reduce_temp1.data,
                        ax0_1,
                        dtype="handle",
                    ))
            with T.block("T_softmax_expsum_write_back"):
                i0_4 = T.axis.spatial(256, i0)
                T.reads([reduce_temp1[0]])
                T.writes([T_softmax_expsum_shared[i0_4]])
                T_softmax_expsum_shared[i0_4] = reduce_temp1[0]
        # Stage 3: normalize each element,
        # out = exp(A - row_max) / row_expsum, for columns i1 = i1_0 * 32 + i1_1.
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads([
                        A[i0_5, i1],
                        T_softmax_maxelem_shared[i0_5],
                        T_softmax_expsum_shared[i0_5],
                    ])
                    T.writes([T_softmax_norm[i0_5, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (T.exp(
                        A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
                        dtype="float32",
                    ) / T_softmax_expsum_shared[i0_5])
示例#28
0
def lowered_single_reduction_loop_with_block_predicate(
        A: T.Buffer[(256, 256), "float32"],
        T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    # Row-wise softmax of a 256x256 float32 buffer, written in TVMScript:
    #   T_softmax_norm[i, j] = exp(A[i, j] - rowmax[i]) / rowsum[i]
    #   rowmax[i] = max_k A[i, k]
    #   rowsum[i] = sum_k exp(A[i, k] - rowmax[i])
    # Both row reductions run as cross-thread reductions over a threadIdx.x
    # binding of extent 512, although a row has only 256 columns. The in-thread
    # reduction blocks and the normalize block therefore carry a T.where
    # predicate that disables threads with ax1_1 >= 256.
    # NOTE(review): judging by the name, this looks like the expected IR after
    # lowering cross-thread reductions for a single reduction loop with a block
    # predicate. If it is a structural-equality fixture, change the code only
    # together with the transform it is compared against.
    #
    # Per-row reduction results, one slot per row, declared with scope "shared".
    T_softmax_maxelem_shared = T.alloc_buffer([256],
                                              dtype="float32",
                                              scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256],
                                             dtype="float32",
                                             scope="shared")
    # Single-element scratch buffers declared with scope "local".
    # in_thread_* holds a thread's partial result; cross_thread_* receives the
    # all-reduced value. Suffix 0 serves the max reduction, suffix 1 the
    # exp-sum reduction.
    cross_thread_0 = T.alloc_buffer([1],
                                    dtype="float32",
                                    strides=[1],
                                    scope="local")
    in_thread_0 = T.alloc_buffer([1],
                                 dtype="float32",
                                 strides=[1],
                                 scope="local")
    cross_thread_1 = T.alloc_buffer([1],
                                    dtype="float32",
                                    strides=[1],
                                    scope="local")
    in_thread_1 = T.alloc_buffer([1],
                                 dtype="float32",
                                 strides=[1],
                                 scope="local")
    # One iteration per row i0; the three stages below run in order for that row.
    for i0 in T.serial(256):
        # Stage 1: row maximum. ax0 and ax1_0 both have extent 1, so the
        # predicate `ax1_0 * 512 + ax1_1 < 256` reduces to `ax1_1 < 256`.
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem_in_thread_init"):
                    # Seed the accumulator with -3.4028234663852886e38 (the
                    # lowest finite float32), the identity element of max.
                    # This block has no T.where, so all 512 threads run it.
                    T.reads()
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.float32(-3.4028234663852886e38)
                with T.block("T_softmax_maxelem_in_thread"):
                    # Each active thread (ax1_1 < 256) folds column ax1_1 of
                    # row i0 into its private accumulator.
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_1, k], in_thread_0[0])
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k])
                with T.block("T_softmax_maxelem_cross_thread"):
                    # Combine the per-thread partials with the max reducer
                    # attached as the "reduce_scope" attribute.
                    # tvm_thread_allreduce takes ax1_1 as its thread variable
                    # and writes the result through cross_thread_0.data.
                    T.reads(in_thread_0[0])
                    T.writes(cross_thread_0[0])
                    T.attr(
                        T.comm_reducer(lambda x, y: T.max(x, y),
                                       [T.float32(-3.4028234663852886e38)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_0[0],
                            True,
                            cross_thread_0.data,
                            ax1_1,
                            dtype="handle",
                        ))
                with T.block("T_softmax_maxelem_write_back"):
                    # Publish the reduced row maximum to the per-row buffer.
                    i0_2 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_0[0])
                    T.writes(T_softmax_maxelem_shared[i0_2])
                    T_softmax_maxelem_shared[i0_2] = cross_thread_0[0]
        # Stage 2: row sum of exp(A - row max). Same structure as stage 1, but
        # with a (+, 0) reducer and the in_thread_1 / cross_thread_1 buffers.
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum_in_thread_init"):
                    # Additive identity; no T.where, so all 512 threads run it.
                    T.reads()
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = T.float32(0)
                with T.block("T_softmax_expsum_in_thread"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3],
                            in_thread_1[0])
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = in_thread_1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3],
                        dtype="float32")
                with T.block("T_softmax_expsum_cross_thread"):
                    T.reads(in_thread_1[0])
                    T.writes(cross_thread_1[0])
                    T.attr(
                        T.comm_reducer(lambda x_1, y_1: x_1 + y_1,
                                       [T.float32(0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_1[0],
                            True,
                            cross_thread_1.data,
                            ax1_1,
                            dtype="handle",
                        ))
                with T.block("T_softmax_expsum_write_back"):
                    i0_4 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_1[0])
                    T.writes(T_softmax_expsum_shared[i0_4])
                    T_softmax_expsum_shared[i0_4] = cross_thread_1[0]
        # Stage 3: normalize. Same 512-thread binding and `< 256` predicate;
        # each active thread writes one element of row i0.
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_5, i1], T_softmax_maxelem_shared[i0_5],
                            T_softmax_expsum_shared[i0_5])
                    T.writes(T_softmax_norm[i0_5, i1])
                    # Block annotation carried through unchanged; its meaning
                    # is not visible from this function.
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (
                        T.exp(A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
                              dtype="float32") / T_softmax_expsum_shared[i0_5])
示例#29
0
def conv2d_winograd_cuda(  # type: ignore
    placeholder: T.Buffer[(1, 14, 14, 128), "float32"],  # type: ignore
    placeholder_1: T.Buffer[(6, 6, 128, 128), "float32"],  # type: ignore
    conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"],  # type: ignore
) -> None:
    # type: ignore
    # Winograd convolution in TVMScript. It maps an NHWC input (1, 14, 14, 128)
    # to a (1, 12, 12, 128) output through a chain of blocks:
    #   data_pad        -> extend the input to 16x16; rows/cols >= 14 become 0
    #   input_tile      -> gather 9 overlapping 6x6 tiles (stride 4) per channel
    #   B               -> constant 6x6 input-transform table
    #   data_pack       -> sum_{r_a, r_b} tile[r_a, r_b] * B[r_a, eps] * B[r_b, nu]
    #   bgemm           -> for each (eps, nu), contract data_pack with
    #                      placeholder_1 over the input channel
    #   A               -> constant 6x4 output-transform table
    #   inverse         -> sum_{r_a, r_b} bgemm[r_a, r_b] * A[r_a, vh] * A[r_b, vw]
    #   conv2d_winograd -> scatter each 4x4 inverse tile into the 12x12 output
    # 6x6 tiles yielding 4x4 outputs is consistent with Winograd F(4x4, 3x3).
    # NOTE(review): placeholder_1 is presumably the kernel already transformed
    # into the 6x6 Winograd domain, indexed [eps, nu, co, ci] as read in bgemm.
    # Confirm with the caller.
    # The "schedule_rule" and "meta_schedule.*" block attributes only record
    # annotations for the scheduler; they do not affect the computed values.
    data_pad = T.alloc_buffer([1, 16, 16, 128])
    input_tile = T.alloc_buffer([6, 6, 9, 128])
    B = T.alloc_buffer([6, 6])
    data_pack = T.alloc_buffer([6, 6, 9, 128])
    bgemm = T.alloc_buffer([6, 6, 9, 128])
    A = T.alloc_buffer([6, 4])
    inverse = T.alloc_buffer([4, 4, 9, 128])
    for i0, i1, i2, i3 in T.grid(1, 16, 16, 128):
        with T.block("data_pad"):
            # Copy the input where row and col are < 14, otherwise write 0.
            # NOTE(review): T.reads declares the unguarded placeholder access,
            # which falls outside placeholder's 14x14 extent for rows/cols 14-15.
            i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.block_attr({"schedule_rule": "None"})
            T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]])
            T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]])
            data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
                0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14,  # type: ignore
                placeholder[i0_1, i1_1, i2_1, i3_1],
                T.float32(0),
                dtype="float32",
            )
    for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128):
        with T.block("input_tile"):
            # Tile p (0..8) starts at row 4 * ((p % 9) // 3) and col 4 * (p % 3),
            # so neighbouring 6x6 tiles overlap by 2. The batch index is p // 9,
            # which is always 0 here because p < 9.
            eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2])
            T.block_attr({"schedule_rule": "None"})
            T.reads(
                [
                    data_pad[
                        T.floordiv(p, 9),  # type: ignore
                        ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                        ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                        ci,
                    ]
                ]
            )
            T.writes([input_tile[eps, nu, p, ci]])
            input_tile[eps, nu, p, ci] = data_pad[
                T.floordiv(p, 9),  # type: ignore
                ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps),  # type: ignore
                ((T.floormod(p, 3) * 4) + nu),  # type: ignore
                ci,
            ]
    for i0_3, i1_3 in T.grid(6, 6):
        with T.block("B"):
            # Constant table. The nested Select chain covers all 36
            # (i % 6, j % 6) pairs, so the innermost fallback 0 is never taken
            # for i, j in [0, 6).
            i, j = T.axis.remap("SS", [i0_3, i1_3])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([B[i, j]])
            # fmt: off
            B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), 
T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0)))))))))))))))))))))))))))))))))))))  # type: ignore
            # fmt: on
    for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6):
        with T.block("data_pack"):
            # Input transform: reduction over (r_a, r_b) of
            # input_tile[r_a, r_b] * B[r_a, eps] * B[r_b, nu].
            # The declared B read region is the bounding box
            # [min(r_a, r_b) : max(r_a, r_b) + 1, min(eps, nu) : max(eps, nu) + 1],
            # which covers both B accesses.
            eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap(
                "SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cuda"})
            T.reads(
                [
                    data_pack[eps_1, nu_1, p_1, ci_1],
                    input_tile[r_a, r_b, p_1, ci_1],
                    B[
                        T.min(r_a, r_b) : (  # type: ignore
                            T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b))  # type: ignore
                        ),
                        T.min(eps_1, nu_1) : (  # type: ignore
                            T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1))  # type: ignore
                        ),
                    ],
                ]
            )
            T.writes([data_pack[eps_1, nu_1, p_1, ci_1]])
            with T.init():
                data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0)
            data_pack[eps_1, nu_1, p_1, ci_1] = data_pack[eps_1, nu_1, p_1, ci_1] + (
                (input_tile[r_a, r_b, p_1, ci_1] * B[r_a, eps_1]) * B[r_b, nu_1]
            )
    for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128):
        with T.block("bgemm"):
            # Batched matmul over the 6x6 (eps, nu) positions:
            # bgemm[eps, nu, p, co] = sum_ci data_pack[eps, nu, p, ci] * placeholder_1[eps, nu, co, ci]
            eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1])
            T.block_attr({"meta_schedule.write_cache_level": [3]})
            T.reads(
                [
                    bgemm[eps_2, nu_2, p_2, co],
                    data_pack[eps_2, nu_2, p_2, ci_2],
                    placeholder_1[eps_2, nu_2, co, ci_2],
                ]
            )
            T.writes([bgemm[eps_2, nu_2, p_2, co]])
            with T.init():
                bgemm[eps_2, nu_2, p_2, co] = T.float32(0)
            bgemm[eps_2, nu_2, p_2, co] = bgemm[eps_2, nu_2, p_2, co] + (
                data_pack[eps_2, nu_2, p_2, ci_2] * placeholder_1[eps_2, nu_2, co, ci_2]
            )
    for i0_6, i1_6 in T.grid(6, 4):
        with T.block("A"):
            # Constant table. The nested Select chain covers all 24
            # (i % 6, j % 4) pairs, so the innermost fallback 0 is never taken
            # for i in [0, 6), j in [0, 4).
            i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6])
            T.block_attr({"schedule_rule": "meta_schedule.compute_inline"})
            T.writes([A[i_1, j_1]])
            # fmt: off
            A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), 
T.float32(1), T.float32(0)))))))))))))))))))))))))  # type: ignore
            # fmt: on
    for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6):
        with T.block("inverse"):
            # Output transform: reduction over (r_a, r_b) of
            # bgemm[r_a, r_b] * A[r_a, vh] * A[r_b, vw], giving a 4x4 tile per
            # (p, co). The A read region is declared as a bounding box, as in
            # data_pack.
            vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap(
                "SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1]
            )
            T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse"})
            T.reads(
                [
                    inverse[vh, vw, p_3, co_1],
                    bgemm[r_a_1, r_b_1, p_3, co_1],
                    A[
                        T.min(r_a_1, r_b_1) : (  # type: ignore
                            T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1))  # type: ignore
                        ),
                        T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))),  # type: ignore
                    ],
                ]
            )
            T.writes([inverse[vh, vw, p_3, co_1]])
            with T.init():
                inverse[vh, vw, p_3, co_1] = T.float32(0)
            inverse[vh, vw, p_3, co_1] = inverse[vh, vw, p_3, co_1] + (
                (bgemm[r_a_1, r_b_1, p_3, co_1] * A[r_a_1, vh]) * A[r_b_1, vw]
            )
    for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128):
        with T.block("conv2d_winograd"):
            # Output (h, w) comes from tile n * 9 + (h // 4) * 3 + (w // 4),
            # at position (h % 4, w % 4) within that tile.
            n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6])
            T.reads(
                [
                    inverse[
                        T.floormod(h, 4),  # type: ignore
                        T.floormod(w, 4),  # type: ignore
                        (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                        co_2,
                    ]
                ]
            )
            T.writes([conv2d_winograd[n, h, w, co_2]])
            conv2d_winograd[n, h, w, co_2] = inverse[
                T.floormod(h, 4),  # type: ignore
                T.floormod(w, 4),  # type: ignore
                (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)),  # type: ignore
                co_2,
            ]
示例#30
0
 def main(
     placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"],
     placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"],
     conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"],
 ) -> None:
     # Tiled int8 conv2d in NCHWc layout with a 1x1 kernel (kh and kw are reduce
     # axes of extent 1, bound to 0), written in TVMScript:
     #   out[n, oc_chunk, oh, ow, oc_block] +=
     #       int32(placeholder[n, ic_outer, oh, ow, ic_f_inner * 4 + ic_s_inner])
     #     * int32(placeholder_1[oc_chunk, ic_outer, 0, 0, ic_f_inner, oc_block, ic_s_inner])
     # The sum runs over ic_outer (4), ic_f_inner (4) and ic_s_inner (4).
     T.func_attr({"global_symbol": "main", "tir.noalias": True})
     # Flattened, already-tiled nest of 30 loops. Only the loops with extent > 1
     # feed the block bindings:
     #   oc_chunk   = i1_1 * 4 + i1_2              (4 * 4 = 16)
     #   oh         = i2_0 * 28 + i2_2 * 4 + i2_3  (2 * 7 * 4 = 56)
     #   ow         = i3_1 * 4 + i3_3              (14 * 4 = 56)
     #   ic_outer   = i7_0                         (4)
     #   ic_f_inner = i8_1                         (4)
     # Every other loop variable has extent 1 and is not used by the block.
     for (
             i0_0,
             i1_0,
             i2_0,
             i3_0,
             i4_0_0,
             i0_1,
             i1_1,
             i2_1,
             i3_1,
             i4_0_1,
             i5_0,
             i6_0,
             i7_0,
             i8_0,
             i9_0_0,
             i0_2,
             i1_2,
             i2_2,
             i3_2,
             i4_0_2,
             i5_1,
             i6_1,
             i7_1,
             i8_1,
             i9_0_1,
             i0_3,
             i1_3,
             i2_3,
             i3_3,
             i4_0_3,
     ) in T.grid(
             1,
             1,
             2,
             1,
             1,
             1,
             4,
             1,
             14,
             1,
             1,
             1,
             4,
             1,
             1,
             1,
             4,
             7,
             1,
             1,
             1,
             1,
             1,
             4,
             1,
             1,
             1,
             4,
             4,
             1,
     ):
         with T.block("conv2d_NCHWc_int8_o"):
             # Outer block for one (oc_chunk, oh, ow) position and one
             # (ic_outer, ic_f_inner) reduction step. It writes the full 16-lane
             # oc_block slice (0:16) and reads 4 input lanes plus a 16x4 weight
             # slice. Its "meta_schedule.auto_tensorize" annotation names
             # "dot_16x4_vnni".
             # NOTE(review): presumably the intrinsic this block is meant to be
             # tensorized with; confirm against the registered intrinsic.
             n = T.axis.spatial(1, 0)
             oc_chunk = T.axis.spatial(16, i1_1 * 4 + i1_2)
             oh = T.axis.spatial(56, i2_0 * 28 + i2_2 * 4 + i2_3)
             ow = T.axis.spatial(56, i3_1 * 4 + i3_3)
             oc_block_o = T.axis.spatial(1, 0)
             kh = T.axis.reduce(1, 0)
             kw = T.axis.reduce(1, 0)
             ic_outer, ic_f_inner = T.axis.remap("RR", [i7_0, i8_1])
             ic_s_inner_o = T.axis.reduce(1, 0)
             T.reads(
                 placeholder[n, ic_outer, oh + kh, ow + kw,
                             ic_f_inner * 4:ic_f_inner * 4 + 4],
                 placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:16,
                               0:4],
             )
             T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:16])
             T.block_attr({"meta_schedule.auto_tensorize": "dot_16x4_vnni"})
             with T.init():
                 # Zero the 16 output lanes before the first reduction step.
                 for i4_1 in T.serial(16):
                     with T.block("conv2d_NCHWc_int8_init"):
                         oc_block_init = T.axis.spatial(16, i4_1)
                         T.reads()
                         T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow,
                                                    oc_block_init])
                         conv2d_NCHWc_int8[n, oc_chunk, oh, ow,
                                           oc_block_init] = 0
             # Scalar form of the 16 (oc_block) x 4 (ic_s_inner)
             # multiply-accumulate. The uint8 input and int8 weight are both
             # cast to int32 before multiplying.
             for i4_1, i9_1 in T.grid(16, 4):
                 with T.block("conv2d_NCHWc_int8"):
                     oc_block, ic_s_inner = T.axis.remap("SR", [i4_1, i9_1])
                     T.reads(
                         conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block],
                         placeholder[n, ic_outer, oh + kh, ow + kw,
                                     ic_f_inner * 4 + ic_s_inner],
                         placeholder_1[oc_chunk, ic_outer, kh, kw,
                                       ic_f_inner, oc_block, ic_s_inner],
                     )
                     T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow,
                                                oc_block])
                     T.block_attr(
                         {"meta_schedule.tiling_structure": "SSRSRS"})
                     conv2d_NCHWc_int8[
                         n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[
                             n, oc_chunk, oh, ow, oc_block] + T.cast(
                                 placeholder[n, ic_outer, oh + kh, ow + kw,
                                             ic_f_inner * 4 + ic_s_inner],
                                 "int32",
                             ) * T.cast(
                                 placeholder_1[oc_chunk, ic_outer, kh, kw,
                                               ic_f_inner, oc_block,
                                               ic_s_inner],
                                 "int32",
                             )