Example #1
def main(A: T.Buffer[(256, 256), "float32"],
         T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
    T_softmax_expsum = T.alloc_buffer([256], dtype="float32")
    for i0, i1 in T.grid(256, 256):
        with T.block("T_softmax_maxelem"):
            i0_1, k = T.axis.remap("SR", [i0, i1])
            with T.init():
                T_softmax_maxelem[i0_1] = T.min_value("float32")
            T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k])
    for i0, i1 in T.grid(256, 256):
        with T.block("T_softmax_expsum"):
            i0_2, k = T.axis.remap("SR", [i0, i1])
            with T.init():
                T_softmax_expsum[i0_2] = T.float32(0)
            T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp(
                A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32")
    for i0_3, i1 in T.grid(256, 256):
        with T.block("T_softmax_norm"):
            i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1])
            T_softmax_norm[i0_4, i1_1] = T.exp(
                A[i0_4, i1_1] - T_softmax_maxelem[i0_4],
                dtype="float32") / T_softmax_expsum[i0_4]
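This is the standard three-block softmax in TensorIR: a row-max reduction, a sum of shifted exponentials, and an elementwise normalization. A minimal driver sketch for checking it against NumPy, assuming the function above is decorated with @T.prim_func and bound to the name main (the decorator is not shown in the snippet):

import numpy as np
import tvm

# Hedged sketch: `main` is assumed to be the @T.prim_func defined above.
rt_mod = tvm.build(main, target="llvm")
a = np.random.rand(256, 256).astype("float32")
out = tvm.nd.array(np.zeros((256, 256), dtype="float32"))
rt_mod(tvm.nd.array(a), out)

# NumPy reference: row-wise softmax with the max-subtraction trick.
e = np.exp(a - a.max(axis=1, keepdims=True))
np.testing.assert_allclose(out.numpy(), e / e.sum(axis=1, keepdims=True), rtol=1e-5)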
Example #2
def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([T_softmax_maxelem_shared[i0_1], A[i0_1, k]])
                    T.writes([T_softmax_maxelem_shared[i0_1]])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.min_value("float32")
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k]
                    )
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads(
                        [
                            T_softmax_expsum_shared[i0_2],
                            A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2],
                        ]
                    )
                    T.writes([T_softmax_expsum_shared[i0_2]])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
                    )
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads(
                        [
                            A[i0_3, i1],
                            T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3],
                        ]
                    )
                    T.writes([T_softmax_norm[i0_3, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(
                            A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
                            dtype="float32",
                        )
                        / T_softmax_expsum_shared[i0_3]
                    )
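Example #2 is the same softmax scheduled for a GPU: each row maps to blockIdx.x, each 256-element loop is covered by 8 serial steps of 32 threadIdx.x lanes, and the two reduction results are staged in shared memory. A hedged sketch of a TensorIR schedule that yields this loop shape (block and buffer names match the snippet, but the exact schedule is an assumption):

import tvm

sch = tvm.tir.Schedule(mod)  # `mod` is assumed to hold the unscheduled softmax
i, j = sch.get_loops(sch.get_block("T_softmax_norm"))
sch.bind(i, "blockIdx.x")
jo, ji = sch.split(j, factors=[8, 32])
sch.bind(ji, "threadIdx.x")
for name in ["T_softmax_expsum", "T_softmax_maxelem"]:
    blk = sch.get_block(name)
    sch.set_scope(blk, 0, "shared")  # buffers become *_shared, as printed above
    sch.compute_at(blk, i)           # nest each reduction under the row loop
    ko, ki = sch.split(sch.get_loops(blk)[-1], factors=[8, 32])
    sch.bind(ki, "threadIdx.x")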
Example #3
def single_reduction_loop_with_block_predicate(
        A: T.Buffer[(256, 256), "float32"],
        T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    T_softmax_maxelem_shared = T.alloc_buffer([256],
                                              dtype="float32",
                                              scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256],
                                             dtype="float32",
                                             scope="shared")
    for i0 in T.serial(256):
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k])
                    T.writes(T_softmax_maxelem_shared[i0_1])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.float32(
                            -3.4028234663852886e38)
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k])
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(T_softmax_expsum_shared[i0_2], A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2])
                    T.writes(T_softmax_expsum_shared[i0_2])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[
                        i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                            A[i0_2, k] - T_softmax_maxelem_shared[i0_2],
                            dtype="float32")
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_3, i1], T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3])
                    T.writes(T_softmax_norm[i0_3, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
                              dtype="float32") / T_softmax_expsum_shared[i0_3])
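Example #3 keeps a single reduction loop and guards it with a block predicate: 512 threads cover a 256-wide extent, and T.where masks the lanes that fall outside it. Such a predicate typically appears when a loop is split by a factor that exceeds (or does not divide) its extent; a hedged schedule sketch, with sch and mod assumed:

sch = tvm.tir.Schedule(mod)
blk = sch.get_block("T_softmax_maxelem")
_, k = sch.get_loops(blk)
ko, ki = sch.split(k, factors=[None, 512])  # extent 256 split by 512: outer=1 plus a T.where guard
sch.bind(ki, "threadIdx.x")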
Example #4
def exp_exp_opaque_access_with_tvm_access_ptr_inlined(
    lookup_table: T.Buffer[(1024,), "int8"],
    x: T.Buffer[(16,), "float16"],
    compute: T.Buffer[(16,), "float16"],
) -> None:
    for i0 in T.serial(16):
        with T.block("compute_1"):
            i0_1 = T.axis.spatial(16, i0)
            # Do not add the opaque access to a new write region when the
            # access is wrapped in a tvm_access_ptr whose mask is "read only".
            T.reads(x[i0_1], lookup_table[0:1024])
            T.writes(compute[i0_1])
            compute[i0_1] = T.exp(
                T.exp(x[i0_1], dtype="float16"),
                lookup_table.access_ptr("r"),
                dtype="float16",
            )
Example #5
def exp_exp_opaque_access_with_tvm_access_ptr(
    lookup_table: T.Buffer[(1024,), "int8"],
    x: T.Buffer[(16,), "float16"],
    compute: T.Buffer[(16,), "float16"],
) -> None:
    compute_1 = T.alloc_buffer([16], dtype="float16")
    for i0 in T.serial(16):
        with T.block("compute"):
            i0_1 = T.axis.spatial(16, i0)
            T.reads(x[i0_1])
            T.writes(compute_1[i0_1])
            compute_1[i0_1] = T.exp(x[i0_1], dtype="float16")
    for i0 in T.serial(16):
        with T.block("compute_1"):
            i0_2 = T.axis.spatial(16, i0)
            T.reads(compute_1[i0_2], lookup_table[0:1024])
            T.writes(compute[i0_2])
            compute[i0_2] = T.exp(
                compute_1[i0_2],
                lookup_table.access_ptr("r"),
                dtype="float16",
            )
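Examples #4 and #5 are a before/after pair: #5 materializes the inner exp into compute_1 and then applies the outer exp together with an opaque read of lookup_table through access_ptr, while #4 shows the first block inlined into the second. A hedged sketch of producing #4 from #5 with the schedule API:

sch = tvm.tir.Schedule(mod)  # `mod` assumed to hold exp_exp_opaque_access_with_tvm_access_ptr
sch.compute_inline(sch.get_block("compute"))
# compute_1 disappears; the opaque lookup_table read stays in T.reads only,
# since the access_ptr mask is read-only (see the comment in Example #4).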
Example #6
def different_access_indices(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [128, 128, 128], dtype="float32")
    B = T.match_buffer(b, [128, 128], dtype="float32")
    for i, j in T.grid(128, 128):
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                T.reads([B[vi, vj], A[vi, vj, vk]])
                T.writes([
                    B[T.min(vj, vi):T.min(vj, vi) + (T.max(vj, vi) + 1 - T.min(vj, vi)),
                      T.min(vi, vj):T.min(vi, vj) + (T.max(vi, vj) + 1 - T.min(vi, vj))]
                ])
                with T.init():
                    B[vj, vi] = T.exp(B[vj, vi], dtype="float32")
                B[vi, vj] = B[vi, vj] + A[vi, vj, vk]
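In Example #6 the reduce axis vk is bound to threadIdx.x, so a whole thread group cooperates on one output element; the oversized T.writes region is the union of the init write to B[vj, vi] and the update write to B[vi, vj]. Blocks like this are the input to tir.transform.LowerCrossThreadReduction, which rewrites them into tvm_thread_allreduce calls, as in Examples #9 and #10. A hedged invocation sketch (the module wrapping is an assumption):

import tvm

mod = tvm.IRModule({"main": different_access_indices})  # assumes the @T.prim_func above
mod = tvm.tir.transform.LowerCrossThreadReduction()(mod)
print(mod.script())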
Example #7
def func_rewritten(A: T.Buffer[(8,), "float32"]) -> None:
    B = T.allocate((1,), "float32", "global")
    for i in range(8):
        B[0] = 3.14
        x: T.float32 = T.exp(B[0], dtype="float32")
        A[i] = (x + 1.0) / (x - 1.0)
Example #8
def func(A: T.Buffer[(8,), "float32"]):
    for i in range(8):
        B = T.allocate((1,), "float32", "global")
        B[0] = 3.14
        x: T.float32 = T.exp(B[0], dtype="float32")
        A[i] = (x + 1.0) / (x - 1.0)
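Examples #7 and #8 differ only in where B lives: func in #8 allocates it inside the loop body, and func_rewritten in #7 shows the allocation hoisted above the loop, which is safe because B carries no value across iterations. Both compute the same result; a hedged numeric check, assuming func is a @T.prim_func:

import numpy as np
import tvm

rt = tvm.build(func, target="llvm")
a = tvm.nd.array(np.zeros(8, dtype="float32"))
rt(a)
x = np.exp(np.float32(3.14))  # every element is (e^3.14 + 1) / (e^3.14 - 1)
np.testing.assert_allclose(a.numpy(), (x + 1.0) / (x - 1.0), rtol=1e-6)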
Example #9
def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256],
                                    dtype="float32")
    T_softmax_maxelem_shared = T.alloc_buffer([256],
                                              dtype="float32",
                                              scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256],
                                             dtype="float32",
                                             scope="shared")
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1],
                                         dtype="float32",
                                         strides=[1],
                                         scope="local")
    reduce_temp1 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    normal_reduce_temp1 = T.alloc_buffer([1],
                                         dtype="float32",
                                         strides=[1],
                                         scope="local")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_maxelem_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.min_value("float32")
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_maxelem_normal_reduction"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([A[i0_1, k], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = T.max(normal_reduce_temp0[0],
                                                   A[i0_1, k])
            with T.block("T_softmax_maxelem_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y),
                                   [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        ax0_1,
                        dtype="handle",
                    ))
            with T.block("T_softmax_maxelem_write_back"):
                i0_2 = T.axis.spatial(256, i0)
                T.reads([reduce_temp0[0]])
                T.writes([T_softmax_maxelem_shared[i0_2]])
                T_softmax_maxelem_shared[i0_2] = reduce_temp0[0]
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            with T.block("T_softmax_expsum_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp1[0]])
                normal_reduce_temp1[0] = T.float32(0)
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_expsum_normal_reduction"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([
                        A[i0_3, k],
                        T_softmax_maxelem_shared[i0_3],
                        normal_reduce_temp1[0],
                    ])
                    T.writes([normal_reduce_temp1[0]])
                    normal_reduce_temp1[0] = normal_reduce_temp1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3],
                        dtype="float32")
            with T.block("T_softmax_expsum_cross_thread_reduction"):
                T.reads([normal_reduce_temp1[0]])
                T.writes([reduce_temp1[0]])
                T.attr(
                    T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp1[0],
                        True,
                        reduce_temp1.data,
                        ax0_1,
                        dtype="handle",
                    ))
            with T.block("T_softmax_expsum_write_back"):
                i0_4 = T.axis.spatial(256, i0)
                T.reads([reduce_temp1[0]])
                T.writes([T_softmax_expsum_shared[i0_4]])
                T_softmax_expsum_shared[i0_4] = reduce_temp1[0]
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads([
                        A[i0_5, i1],
                        T_softmax_maxelem_shared[i0_5],
                        T_softmax_expsum_shared[i0_5],
                    ])
                    T.writes([T_softmax_norm[i0_5, i1]])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (T.exp(
                        A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
                        dtype="float32",
                    ) / T_softmax_expsum_shared[i0_5])
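Example #9 looks like Example #2 after cross-thread reduction lowering: each shared-memory reduction splits into a thread-local accumulation (normal_reduce_temp*), a tvm_thread_allreduce across the 32 threadIdx.x lanes, and a write-back block. A hedged sketch of the pass invocation (the pairing with Example #2 is inferred from the matching block and buffer names):

import tvm

mod = tvm.IRModule({"main": softmax})  # the scheduled softmax from Example #2
mod = tvm.tir.transform.LowerCrossThreadReduction()(mod)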
Example #10
def lowered_single_reduction_loop_with_block_predicate(
        A: T.Buffer[(256, 256), "float32"],
        T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
    T_softmax_maxelem_shared = T.alloc_buffer([256],
                                              dtype="float32",
                                              scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256],
                                             dtype="float32",
                                             scope="shared")
    cross_thread_0 = T.alloc_buffer([1],
                                    dtype="float32",
                                    strides=[1],
                                    scope="local")
    in_thread_0 = T.alloc_buffer([1],
                                 dtype="float32",
                                 strides=[1],
                                 scope="local")
    cross_thread_1 = T.alloc_buffer([1],
                                    dtype="float32",
                                    strides=[1],
                                    scope="local")
    in_thread_1 = T.alloc_buffer([1],
                                 dtype="float32",
                                 strides=[1],
                                 scope="local")
    for i0 in T.serial(256):
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.float32(-3.4028234663852886e38)
                with T.block("T_softmax_maxelem_in_thread"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_1, k], in_thread_0[0])
                    T.writes(in_thread_0[0])
                    in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k])
                with T.block("T_softmax_maxelem_cross_thread"):
                    T.reads(in_thread_0[0])
                    T.writes(cross_thread_0[0])
                    T.attr(
                        T.comm_reducer(lambda x, y: T.max(x, y),
                                       [T.float32(-3.4028234663852886e38)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_0[0],
                            True,
                            cross_thread_0.data,
                            ax1_1,
                            dtype="handle",
                        ))
                with T.block("T_softmax_maxelem_write_back"):
                    i0_2 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_0[0])
                    T.writes(T_softmax_maxelem_shared[i0_2])
                    T_softmax_maxelem_shared[i0_2] = cross_thread_0[0]
        for ax0, ax1_0 in T.grid(1, 1):
            for ax1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_expsum_in_thread_init"):
                    T.reads()
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = T.float32(0)
                with T.block("T_softmax_expsum_in_thread"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax1_1)
                    T.where(ax1_0 * 512 + ax1_1 < 256)
                    T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3],
                            in_thread_1[0])
                    T.writes(in_thread_1[0])
                    in_thread_1[0] = in_thread_1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3],
                        dtype="float32")
                with T.block("T_softmax_expsum_cross_thread"):
                    T.reads(in_thread_1[0])
                    T.writes(cross_thread_1[0])
                    T.attr(
                        T.comm_reducer(lambda x_1, y_1: x_1 + y_1,
                                       [T.float32(0)]),
                        "reduce_scope",
                        T.reinterpret(T.uint64(0), dtype="handle"),
                    )
                    T.evaluate(
                        T.tvm_thread_allreduce(
                            T.uint32(1),
                            in_thread_1[0],
                            True,
                            cross_thread_1.data,
                            ax1_1,
                            dtype="handle",
                        ))
                with T.block("T_softmax_expsum_write_back"):
                    i0_4 = T.axis.spatial(256, i0)
                    T.reads(cross_thread_1[0])
                    T.writes(T_softmax_expsum_shared[i0_4])
                    T_softmax_expsum_shared[i0_4] = cross_thread_1[0]
        for i1_0 in T.serial(1):
            for i1_1 in T.thread_binding(512, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_1)
                    T.where(i1_0 * 512 + i1_1 < 256)
                    T.reads(A[i0_5, i1], T_softmax_maxelem_shared[i0_5],
                            T_softmax_expsum_shared[i0_5])
                    T.writes(T_softmax_norm[i0_5, i1])
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (
                        T.exp(A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
                              dtype="float32") / T_softmax_expsum_shared[i0_5])
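Example #10 is the lowered counterpart of Example #3: the predicated reductions become an in_thread accumulation, a cross-thread allreduce, and a write-back, with the T.where guard preserved on the in-thread step. A hedged sketch of verifying the pairing the way TVM's tests usually do:

import tvm

mod = tvm.IRModule({"main": single_reduction_loop_with_block_predicate})
mod = tvm.tir.transform.LowerCrossThreadReduction()(mod)
tvm.ir.assert_structural_equal(
    mod["main"], lowered_single_reduction_loop_with_block_predicate)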