Example no. 1
def lowered_zero_rank_buffer(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128], dtype="float32")
    B = T.match_buffer(b, [], dtype="float32")
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    for k in T.thread_binding(0, 128, thread="threadIdx.x"):
        with T.block("B_cross_thread_reduction"):
            vk = T.axis.reduce(128, k)
            T.reads([A[vk]])
            T.writes([reduce_temp0[0]])
            T.attr(
                T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                "reduce_scope",
                T.reinterpret(T.uint64(0), dtype="handle"),
            )
            T.evaluate(
                T.tvm_thread_allreduce(T.uint32(1),
                                       A[vk],
                                       True,
                                       reduce_temp0.data,
                                       k,
                                       dtype="handle"))
        with T.block("B_write_back"):
            T.reads([reduce_temp0[0]])
            T.writes([B[()]])
            B[()] = reduce_temp0[0]
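The function above is the expected result of lowering a cross-thread reduction. As a hedged sketch (not taken from the source), the un-lowered input below is illustrative, and the lowering is driven by TVM's LowerCrossThreadReduction pass:

import tvm
from tvm.script import tir as T

@T.prim_func
def zero_rank_reduction(a: T.handle, b: T.handle) -> None:
    # Illustrative input: reduce a 1-D buffer over a thread-bound axis
    # into a zero-rank output buffer.
    A = T.match_buffer(a, [128], dtype="float32")
    B = T.match_buffer(b, [], dtype="float32")
    for k in T.thread_binding(0, 128, thread="threadIdx.x"):
        with T.block("B"):
            vk = T.axis.reduce(128, k)
            with T.init():
                B[()] = T.float32(0)
            B[()] = B[()] + A[vk]

# Apply the pass and print the lowered TIR for comparison with the
# expected function above.
mod = tvm.IRModule.from_expr(zero_rank_reduction)
mod = tvm.tir.transform.LowerCrossThreadReduction()(mod)
print(mod.script())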
Example no. 2
 def main(A: T.handle, tensor: T.handle) -> None:
     # function attr dict
     T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
     # buffer definition
     tensor_2 = T.buffer_decl([1, 10, 12, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
     A_1 = T.match_buffer(A, [1, 12, 14, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
     tensor_1 = T.match_buffer(tensor, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1)
     # body
     T.realize(tensor_1[0:1, 0:8, 0:8, 0:16], "")
     for ax1_outer in T.serial(0, 2):
         T.realize(tensor_2[0:1, (ax1_outer*4):((ax1_outer*4) + 6), 0:12, 0:16], "")
         T.attr(tensor_2, "rolling_buffer_scope", True)
         for ax1 in T.serial(0, 6):
             for ax2 in T.serial(0, 12):
                 for ax3 in T.serial(0, 16):
                     tensor_2[0, (ax1 + (ax1_outer*4)), ax2, ax3] = T.int8(0)
                     for dh in T.serial(0, 3):
                         for dw in T.serial(0, 3):
                             tensor_2[0, (ax1 + (ax1_outer*4)), ax2, ax3] = T.max(tensor_2[0, (ax1 + (ax1_outer*4)), ax2, ax3], A_1[0, ((ax1 + (ax1_outer*4)) + dh), (ax2 + dw), ax3])
         for ax1_inner in T.serial(0, 4):
             for ax2_inner in T.serial(0, 8):
                 for ax3_inner in T.serial(0, 16):
                     tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = T.int8(0)
                     for dh_1 in T.serial(0, 3):
                         for dw_1 in T.serial(0, 5):
                             tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner] = T.max(tensor_1[0, (ax1_inner + (ax1_outer*4)), ax2_inner, ax3_inner], tensor_2[0, ((ax1_inner + (ax1_outer*4)) + dh_1), (ax2_inner + dw_1), ax3_inner])
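A hedged sketch (an assumption, not from the source): the "rolling_buffer_scope" attribute above is the hint consumed by TVM's InjectRollingBuffer pass, which rewrites the realize extents and indexing so the intermediate tensor is reused as a rolling window. Here `mod` is a hypothetical IRModule that already contains the function above as its "main" entry:

import tvm

# `mod` is hypothetical; it stands for an IRModule holding the function above.
mod = tvm.tir.transform.InjectRollingBuffer()(mod)
print(mod.script())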
Example no. 3
def ptx_global_to_shared_dyn_copy_fp16x8(
    A: T.Buffer[(32, 128), "float16"],
    B: T.Buffer[(32, 128), "float16"],
    C: T.Buffer[(32, 128), "float16"],
) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
        B_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
        T.reads(A[0:32, 0:128], B[0:32, 0:128])
        T.writes(C[0:32, 0:128])

        T.attr("default", "async_scope", 1)
        for i in T.serial(16):
            for j in T.vectorized(8):
                A_shared[tx, i * 8 + j] = A[tx, i * 8 + j]
                B_shared[tx, i * 8 + j] = B[tx, i * 8 + j]

        T.evaluate(T.ptx_commit_group(dtype=""))
        T.evaluate(T.ptx_wait_group(0, dtype=""))

        for i in range(128):
            C[tx, i] = A_shared[tx, i] + B_shared[tx, i]
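A hedged sketch (an assumption, not from the source): the "async_scope" attribute together with the ptx_commit_group / ptx_wait_group intrinsics above lets TVM emit cp.async instructions when the module is built for an sm_80+ CUDA target with the "tir.use_async_copy" option enabled. The name below refers to the function above and is assumed to be decorated with @T.prim_func; a CUDA-enabled TVM build is required:

import tvm

mod = tvm.IRModule.from_expr(ptx_global_to_shared_dyn_copy_fp16x8)
with tvm.transform.PassContext(config={"tir.use_async_copy": 1}):
    rt_mod = tvm.build(mod, target="cuda")
print(rt_mod.imported_modules[0].get_source())  # inspect the generated CUDA/PTX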
Example no. 4
def lowered_reducer_max(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    for i in T.serial(0, 128):
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B_cross_thread_reduction"):
                vi, vk = T.axis.remap("SR", [i, k])
                T.reads([A[vi, vk]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y),
                                   [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(T.uint32(1),
                                           A[vi, vk],
                                           True,
                                           reduce_temp0.data,
                                           k,
                                           dtype="handle"))
            with T.block("B_write_back"):
                vi = T.axis.spatial(128, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
Example no. 5
def lowered_two_bound_loops(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    for i in T.serial(0, 128):
        for ko in T.thread_binding(0, 4, thread="threadIdx.x"):
            for ki in T.thread_binding(0, 32, thread="threadIdx.y"):
                with T.block("B_cross_thread_reduction"):
                    vi = T.axis.spatial(128, i)
                    vk = T.axis.reduce(128, ko * 32 + ki)
                    T.reads([A[vi, vk]])
                    T.writes([reduce_temp0[0]])
                    T.attr(
                        T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(T.uint32(1),
                                               A[vi, vk],
                                               True,
                                               reduce_temp0.data,
                                               ko,
                                               ki,
                                               dtype="handle"))
                with T.block("B_write_back"):
                    vi = T.axis.spatial(128, i)
                    T.reads([reduce_temp0[0]])
                    T.writes([B[vi]])
                    B[vi] = reduce_temp0[0]
Example no. 6
 def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None:
     # function attr dict
     T.func_attr({
         "global_symbol": "tvmgen_default_run_model",
         "runner_function": True
     })
     # body
     T.attr("default", "device_id", 0)
     T.attr("default", "device_type", 1)
     sid_9 = T.allocate([301056], "int8", "global")
     sid_8 = T.allocate([802816], "int8", "global")
     T.evaluate(
         T.call_extern("tvmgen_default_fused_cast_subtract",
                       input,
                       T.lookup_param("p0", dtype="handle"),
                       sid_9,
                       dtype="int32"))
     T.evaluate(
         T.call_extern(
             "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast",
             sid_9,
             T.lookup_param("p1", dtype="handle"),
             T.lookup_param("p2", dtype="handle"),
             sid_8,
             dtype="int32"))
     T.evaluate(
         T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast",
                       sid_8,
                       output,
                       dtype="int32"))
Example no. 7
def undefined_buffer(a: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")

    T.attr(A, "realize_scope", "")
    T.realize(C[0:16, 0:16], "")  # error
    for i in T.serial(16):
        for j in T.serial(0, 16):
            A[i, j] = 0.0
Example no. 8
def range_missing_args(a: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")

    T.attr(A, "realize_scope", "")
    T.realize(A[0:16, 0:16], "")
    for i in T.serial(16):  # error
        for j in T.serial(0, 16):
            A[i, j] = 0.0
Example no. 9
def unsupported_function_call(a: T.handle) -> None:
    A = T.match_buffer(a, (16, 16), "float32")

    T.attr(A, "realize_scope", "")
    T.realize(A[0:16, 0:16], "")
    for i in T.const_range(16):  # error
        for j in T.serial(0, 16):
            A[i, j] = 0.0
Example no. 10
def lowered_multiple_blocks_under_reduction_loop(a: T.handle,
                                                 b: T.handle) -> None:
    A = T.match_buffer(a, [16, 16, 16], dtype="float32")
    B = T.match_buffer(b, [16], dtype="float32")
    B_rf_local = T.alloc_buffer([16, 16], dtype="float32", scope="local")
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1],
                                         dtype="float32",
                                         strides=[1],
                                         scope="local")
    for i in T.thread_binding(0, 16, thread="blockIdx.x"):
        for k0o in T.thread_binding(0, 4, thread="threadIdx.x"):
            with T.block("B_in_thread_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.float32(0)
            for k0i0, k1 in T.grid(4, 16):
                with T.block("B_rf"):
                    vk0 = T.axis.spatial(16, k0o * 4 + k0i0)
                    vi, vk1 = T.axis.remap("SR", [i, k1])
                    T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]])
                    T.writes([B_rf_local[vk0, vi]])
                    with T.init():
                        B_rf_local[vk0, vi] = T.float32(0)
                    B_rf_local[vk0, vi] = B_rf_local[vk0, vi] + A[vi, vk0, vk1]
            for k0i1 in T.serial(0, 4):
                with T.block("B_normal_reduction"):
                    vk0 = T.axis.reduce(16, k0o * 4 + k0i1)
                    vi = T.axis.spatial(16, i)
                    T.reads([B_rf_local[vk0, vi], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[
                        0] = normal_reduce_temp0[0] + B_rf_local[vk0, vi]
            with T.block("B_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        k0o,
                        dtype="handle",
                    ))
            with T.block("B_write_back"):
                vi = T.axis.spatial(16, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
Example no. 11
 def main(placeholder1: T.Buffer[(100, ), "int8"],
          placeholder2: T.Buffer[(100, ), "int8"]) -> None:
     T.attr("i0", "pragma_layout", "NHCWB16")
     for i0 in T.serial(0, 1):
         for i1 in T.serial(0, 1):
             for i2 in T.serial(0, 1):
                 for i3 in T.serial(0, 6):
                     for i4 in T.serial(0, 16):
                         placeholder1[((i3 * 16) +
                                       i4)] = placeholder2[((T.floormod(
                                           (i3 + 4), 6) * 16) + i4)]
Example no. 12
 def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None:
     # function attr dict
     T.func_attr({
         "global_symbol": "tvmgen_default_run_model",
         "runner_function": True
     })
     # body
     T.attr("default", "device_id", 0)
     T.attr("default", "device_type", 1)
     sid_2 = T.allocate([720000], "int8", "global")
     sid_6 = T.allocate([5760000], "int8", "global")
     sid_7 = T.allocate([720000], "int8", "global")
     sid_8 = T.allocate([720000], "int8", "global")
     T.evaluate(
         T.call_extern(
             "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast",
             input,
             T.lookup_param("p0", dtype="handle"),
             sid_2.data,
             dtype="int32"))
     T.evaluate(
         T.call_extern(
             "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast",
             sid_2.data,
             T.lookup_param("p3", dtype="handle"),
             T.lookup_param("p4", dtype="handle"),
             sid_8.data,
             dtype="int32"))
     T.evaluate(
         T.call_extern(
             "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1",
             sid_8.data,
             T.lookup_param("p5", dtype="handle"),
             T.lookup_param("p6", dtype="handle"),
             sid_7.data,
             dtype="int32"))
     T.evaluate(
         T.call_extern(
             "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_",
             sid_7.data,
             T.lookup_param("p7", dtype="handle"),
             T.lookup_param("p8", dtype="handle"),
             sid_6.data,
             dtype="int32"))
     T.evaluate(
         T.call_extern(
             "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_",
             sid_2.data,
             T.lookup_param("p1", dtype="handle"),
             T.lookup_param("p2", dtype="handle"),
             sid_6.data,
             output,
             dtype="int32"))
Example no. 13
 def main(placeholder1: T.Buffer[(100, ), "int8"],
          placeholder2: T.Buffer[(100, ), "int8"]) -> None:
     T.attr("i0", "pragma_layout", "NHWC")
     for i0 in T.serial(0, 1):
         for i1 in T.serial(0, 5):
             for i2 in T.serial(0, 6):
                 for i3 in T.serial(0, 4):
                     placeholder1[(
                         ((i1 * 24) + (i2 * 4)) +
                         i3)] = placeholder2[(((((T.floordiv(
                             (i1 - 1), 2) * 48) + (T.floormod(
                                 (i1 + 1), 2) * 24)) + (i2 * 4)) + i3) +
                                              96)]
Example no. 14
 def main(buffer2: T.Buffer[(160,), "uint8"], buffer4: T.Buffer[(144,), "uint8"], buffer6: T.Buffer[(144,), "uint8"], buffer8: T.Buffer[(144,), "uint8"]) -> None:
     # function attr dict
     T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
     v1a = T.var("int32")
     v1c = T.var("int32")
     v2a = T.var("int32")
     v2c = T.var("int32")
     v3a = T.var("int32")
     v3c = T.var("int32")
     v4a = T.var("int32")
     v4c = T.var("int32")
     buffer1 = T.buffer_decl([8192], "int8")
     buffer10 = T.buffer_decl([2048], "int8")
     # body
     p4 = T.allocate([160], "uint8", "global")
     p7 = T.allocate([144], "uint8", "global")
     p10 = T.allocate([144], "uint8", "global")
     p11 = T.allocate([144], "uint8", "global")
     with T.attr(T.iter_var(v1a, None, "DataPar", ""), "pragma_compute_cycles_hint", 201):
         T.evaluate(T.call_extern("ethosu_copy", buffer2[0], 160, p4[0], dtype="handle"))
     with T.attr(T.iter_var(v2a, None, "DataPar", ""), "pragma_compute_cycles_hint", 205):
         T.evaluate(T.call_extern("ethosu_copy", buffer4[0], 144, p7[0], dtype="handle"))
     with T.attr(T.iter_var(v1c, None, "DataPar", ""), "pragma_compute_cycles_hint", 300):
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p4[0], 128, 12, p4[128], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     with T.attr(T.iter_var(v3a, None, "DataPar", ""), "pragma_compute_cycles_hint", 209):
         T.evaluate(T.call_extern("ethosu_copy", buffer6[0], 144, p10[0], dtype="handle"))
     with T.attr(T.iter_var(v2c, None, "DataPar", ""), "pragma_compute_cycles_hint", 301):
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p7[0], 112, 12, p7[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     with T.attr(T.iter_var(v4a, None, "DataPar", ""), "pragma_compute_cycles_hint", 213):
         T.evaluate(T.call_extern("ethosu_copy", buffer8[0], 144, p11[0], dtype="handle"))
     with T.attr(T.iter_var(v3c, None, "DataPar", ""), "pragma_compute_cycles_hint", 302):
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p10[0], 112, 12, p10[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
     with T.attr(T.iter_var(v4c, None, "DataPar", ""), "pragma_compute_cycles_hint", 303):
         T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, buffer1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, buffer10[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p11[0], 112, 12, p11[112], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle"))
Example no. 15
def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 120], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1],
                                         dtype="float32",
                                         strides=[1],
                                         scope="local")
    for i in T.serial(0, 128):
        for ki in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("B_in_thread_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.float32(0)
            for ko in T.serial(0, 4):
                with T.block("B_normal_reduction"):
                    vi = T.axis.spatial(128, i)
                    vk = T.axis.reduce(120, ko * 32 + ki)
                    T.where(ko * 32 + ki < 120)
                    T.reads([A[vi, vk], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk]
            with T.block("B_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
reduce_temp0.data,
                        ki,
                        dtype="handle",
                    ))
            with T.block("B_write_back"):
                vi = T.axis.spatial(128, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
Example no. 16
def ptx_global_to_shared_copy_fp32x1(
        A: T.Buffer[(32, 128), "float32"], B: T.Buffer[(32, 128),
                                                       "float32"]) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    bx = T.env_thread("blockIdx.x")
    tx = T.env_thread("threadIdx.x")
    T.launch_thread(bx, 1)
    T.launch_thread(tx, 32)
    with T.block():
        A_shared = T.alloc_buffer([32, 128], "float32", scope="shared")
        T.reads(A[0:32, 0:128])
        T.writes(B[0:32, 0:128])

        T.attr("default", "async_scope", 1)
        for i in T.serial(128):
            A_shared[tx, i] = A[tx, i]

        T.evaluate(T.ptx_commit_group(dtype=""))
        T.evaluate(T.ptx_wait_group(0, dtype=""))

        for i in range(128):
            B[tx, i] = A_shared[tx, i]
Example no. 17
def main(A_param: T.handle, C_param: T.handle):
    A = T.match_buffer(A_param, (400,), "float32", strides=[1])
    C = T.match_buffer(C_param, (4,), "float32", strides=[1])
    T.func_attr({"from_legacy_te_schedule": True})
    threadIdx_x = T.env_thread("threadIdx.x")
    T.launch_thread(threadIdx_x, 1)
    for i in T.serial(0, 100):
        B = T.allocate([4], "float32", scope="shared", strides=[1])
        with T.attr(B.data, "double_buffer_scope", 1):
            for j in T.serial(0, 4):
                B[j] = A[4 * i + j]

        for j in T.serial(0, 4):
            C[j] = B[j] + 1.0
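A hedged sketch (an assumption, not from the source): the "double_buffer_scope" attribute above is the hint consumed by TVM's InjectDoubleBuffer pass, which expands the shared allocation so the copy for iteration i + 1 can overlap with the compute of iteration i. `mod` is a hypothetical IRModule holding the function above as "main":

import tvm

# `mod` is hypothetical; it stands for an IRModule holding the function above.
mod = tvm.tir.transform.InjectDoubleBuffer()(mod)
print(mod.script())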
Example no. 18
def lowered_single_reduction_loop_with_block_predicate(
        A: T.Buffer[(256, 256), "float32"],
        T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    T_softmax_maxelem_shared = T.alloc_buffer([256],
                                              dtype="float32",
                                              scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256],
                                             dtype="float32",
                                             scope="shared")
    cross_thread_0 = T.alloc_buffer([1],
                                    dtype="float32",
                                    strides=[1],
                                    scope="local")
    in_thread_0 = T.alloc_buffer([1],
                                 dtype="float32",
                                 strides=[1],
                                 scope="local")
    cross_thread_1 = T.alloc_buffer([1],
                                    dtype="float32",
                                    strides=[1],
                                    scope="local")
    in_thread_1 = T.alloc_buffer([1],
                                 dtype="float32",
                                 strides=[1],
                                 scope="local")
    for i0 in T.serial(256):
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.float32(-3.4028234663852886e38)
                with T.block("T_softmax_maxelem_in_thread"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_1, k], in_thread_0[0])
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k])
                with T.block("T_softmax_maxelem_cross_thread"):
                    T.reads(in_thread_0[0])
                    T.writes(cross_thread_0[0])
                    T.attr(
                        T.comm_reducer(lambda x, y: T.max(x, y),
                                       [T.float32(-3.4028234663852886e38)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_0[0],
                            True,
                            cross_thread_0.data,
                            ax1_1,
                            dtype="handle",
                        ))
                with T.block("T_softmax_maxelem_write_back"):
                    i0_2 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_0[0])
                    T.writes(T_softmax_maxelem_shared[i0_2])
                    T_softmax_maxelem_shared[i0_2] = cross_thread_0[0]
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = T.float32(0)
                with T.block("T_softmax_expsum_in_thread"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3],
                            in_thread_1[0])
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = in_thread_1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3],
                        dtype="float32")
                with T.block("T_softmax_expsum_cross_thread"):
                    T.reads(in_thread_1[0])
                    T.writes(cross_thread_1[0])
                    T.attr(
                        T.comm_reducer(lambda x_1, y_1: x_1 + y_1,
                                       [T.float32(0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_1[0],
                            True,
                            cross_thread_1.data,
                            ax1_1,
                            dtype="handle",
                        ))
                with T.block("T_softmax_expsum_write_back"):
                    i0_4 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_1[0])
                    T.writes(T_softmax_expsum_shared[i0_4])
                    T_softmax_expsum_shared[i0_4] = cross_thread_1[0]
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_5, i1], T_softmax_maxelem_shared[i0_5],
                            T_softmax_expsum_shared[i0_5])
                    T.writes(T_softmax_norm[i0_5, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (
                        T.exp(A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
                              dtype="float32") / T_softmax_expsum_shared[i0_5])
Example no. 19
def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256],
                                    dtype="float32")
    T_softmax_maxelem_shared = T.alloc_buffer([256],
                                              dtype="float32",
                                              scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256],
                                             dtype="float32",
                                             scope="shared")
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1],
                                         dtype="float32",
                                         strides=[1],
                                         scope="local")
    reduce_temp1 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    normal_reduce_temp1 = T.alloc_buffer([1],
                                         dtype="float32",
                                         strides=[1],
                                         scope="local")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_maxelem_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.min_value("float32")
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_maxelem_normal_reduction"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([A[i0_1, k], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = T.max(normal_reduce_temp0[0],
                                                   A[i0_1, k])
            with T.block("T_softmax_maxelem_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y),
                                   [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        ax0_1,
                        dtype="handle",
                    ))
            with T.block("T_softmax_maxelem_write_back"):
                i0_2 = T.axis.spatial(256, i0)
                T.reads([reduce_temp0[0]])
                T.writes([T_softmax_maxelem_shared[i0_2]])
                T_softmax_maxelem_shared[i0_2] = reduce_temp0[0]
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_expsum_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp1[0]])
                normal_reduce_temp1[0] = T.float32(0)
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_expsum_normal_reduction"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([
                        A[i0_3, k],
                        T_softmax_maxelem_shared[i0_3],
                        normal_reduce_temp1[0],
                    ])
                    T.writes([normal_reduce_temp1[0]])
                    normal_reduce_temp1[0] = normal_reduce_temp1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3],
                        dtype="float32")
            with T.block("T_softmax_expsum_cross_thread_reduction"):
                T.reads([normal_reduce_temp1[0]])
                T.writes([reduce_temp1[0]])
                T.attr(
                    T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp1[0],
                        True,
                        reduce_temp1.data,
                        ax0_1,
                        dtype="handle",
                    ))
            with T.block("T_softmax_expsum_write_back"):
                i0_4 = T.axis.spatial(256, i0)
                T.reads([reduce_temp1[0]])
                T.writes([T_softmax_expsum_shared[i0_4]])
                T_softmax_expsum_shared[i0_4] = reduce_temp1[0]
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads([
                        A[i0_5, i1],
                        T_softmax_maxelem_shared[i0_5],
                        T_softmax_expsum_shared[i0_5],
                    ])
                    T.writes([T_softmax_norm[i0_5, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (T.exp(
                        A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
                        dtype="float32",
                    ) / T_softmax_expsum_shared[i0_5])
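A hedged sketch (an assumption, not from the source): expected "lowered_*" functions such as the one above are usually checked by running the lowering pass on an un-lowered input and asserting structural equality against the expected form. `softmax_before` is a hypothetical un-lowered softmax PrimFunc, and `lowered_softmax` is assumed to be parsed as a @T.prim_func:

import tvm

mod = tvm.IRModule.from_expr(softmax_before)  # hypothetical un-lowered input
mod = tvm.tir.transform.LowerCrossThreadReduction()(mod)
tvm.ir.assert_structural_equal(mod["main"], lowered_softmax, map_free_vars=True)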