Exemplo n.º 1
0
 def main(A: T.Buffer[(256, 256), "float32"],
          T_softmax_norm: T.Buffer[(256, 256), "float32"]) -> None:
     # Row-wise softmax of the 256x256 float32 buffer `A`, written to
     # `T_softmax_norm`. The computation is split into three loop nests:
     #   1. T_softmax_maxelem[i] = max over k of A[i, k]
     #   2. T_softmax_expsum[i]  = sum over k of exp(A[i, k] - maxelem[i])
     #   3. T_softmax_norm[i, j] = exp(A[i, j] - maxelem[i]) / expsum[i]
     # Returns nothing; the result is stored into `T_softmax_norm`.

     # Per-row intermediates: one float32 element per row of `A`.
     T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
     T_softmax_expsum = T.alloc_buffer([256], dtype="float32")
     # Stage 1: row maximum.
     for i0, i1 in T.grid(256, 256):
         with T.block("T_softmax_maxelem"):
             # "SR": i0 is bound as a spatial axis, i1 as a reduction axis.
             i0_1, k = T.axis.remap("SR", [i0, i1])
             with T.init():
                 # Start from the smallest float32 so any element replaces it.
                 T_softmax_maxelem[i0_1] = T.min_value("float32")
             T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1],
                                             A[i0_1, k])
     # Stage 2: per-row sum of exponentials, each shifted by the row maximum.
     for i0, i1 in T.grid(256, 256):
         with T.block("T_softmax_expsum"):
             i0_2, k = T.axis.remap("SR", [i0, i1])
             with T.init():
                 T_softmax_expsum[i0_2] = T.float32(0)
             T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp(
                 A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32")
     # Stage 3: normalize. Both axes are spatial ("SS"); the shifted
     # exponential is recomputed here rather than read from a buffer.
     for i0_3, i1 in T.grid(256, 256):
         with T.block("T_softmax_norm"):
             i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1])
             T_softmax_norm[i0_4,
                            i1_1] = T.exp(
                                A[i0_4, i1_1] - T_softmax_maxelem[i0_4],
                                dtype="float32") / T_softmax_expsum[i0_4]
Exemplo n.º 2
0
def lowered_reducer_max(a: T.handle, b: T.handle) -> None:
    # Row-wise maximum of a 128x128 float32 buffer, with the reduction axis
    # bound to threadIdx.x and combined through `tvm_thread_allreduce`.
    #
    # Parameters:
    #   a: handle matched to the [128, 128] float32 input buffer `A`.
    #   b: handle matched to the [128] float32 output buffer `B`;
    #      B[i] receives the maximum of row i of `A`.
    #
    # NOTE(review): the name suggests this is the lowered form of a plain
    # reduction block (B[vi] = max(B[vi], A[vi, vk])) -- confirm against the
    # test or pass that consumes it.
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    # Single-element buffer in "local" scope that receives the allreduce result.
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    for i in T.serial(0, 128):
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B_cross_thread_reduction"):
                vi, vk = T.axis.remap("SR", [i, k])
                T.reads([A[vi, vk]])
                T.writes([reduce_temp0[0]])
                # Declares the combiner used by the allreduce below:
                # max(x, y) with the smallest float32 as its identity element.
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y),
                                   [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                # Arguments, following the TVM builtin's documented order:
                # number of reduced values (1), this thread's contribution,
                # predicate (always True), destination buffer, and the
                # thread-bound loop variable being reduced over.
                T.evaluate(
                    T.tvm_thread_allreduce(T.uint32(1),
                                           A[vi, vk],
                                           True,
                                           reduce_temp0.data,
                                           k,
                                           dtype="handle"))
            # Copy the reduced value for row `i` into the output buffer.
            with T.block("B_write_back"):
                vi = T.axis.spatial(128, i)
                T.reads([reduce_temp0[0]])
                T.writes([B[vi]])
                B[vi] = reduce_temp0[0]
Exemplo n.º 3
0
def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    # Row-wise softmax over a 256x256 float32 buffer with thread bindings:
    # rows are bound to blockIdx.x, and each 256-element row is covered by
    # 8 serial steps of 32 threadIdx.x lanes (index = outer * 32 + lane).
    #
    # Parameters:
    #   var_A: handle matched to the [256, 256] float32 input `A`.
    #   var_T_softmax_norm: handle matched to the [256, 256] float32 output.
    #
    # For each row i:
    #   maxelem[i] = max_k A[i, k]
    #   expsum[i]  = sum_k exp(A[i, k] - maxelem[i])
    #   out[i, j]  = exp(A[i, j] - maxelem[i]) / expsum[i]
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256], dtype="float32")
    # Per-row intermediates, allocated with scope="shared".
    T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        # Stage 1: row maximum; the reduction axis k is split across
        # the serial loop and the threadIdx.x-bound loop.
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_maxelem"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([T_softmax_maxelem_shared[i0_1], A[i0_1, k]])
                    T.writes([T_softmax_maxelem_shared[i0_1]])
                    with T.init():
                        T_softmax_maxelem_shared[i0_1] = T.min_value("float32")
                    T_softmax_maxelem_shared[i0_1] = T.max(
                        T_softmax_maxelem_shared[i0_1], A[i0_1, k]
                    )
        # Stage 2: sum of exponentials shifted by the row maximum, using the
        # same 8 x 32 split of the reduction axis.
        for ax0_0 in T.serial(0, 8):
            for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_expsum"):
                    i0_2 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads(
                        [
                            T_softmax_expsum_shared[i0_2],
                            A[i0_2, k],
                            T_softmax_maxelem_shared[i0_2],
                        ]
                    )
                    T.writes([T_softmax_expsum_shared[i0_2]])
                    with T.init():
                        T_softmax_expsum_shared[i0_2] = T.float32(0)
                    T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp(
                        A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32"
                    )
        # Stage 3: normalization; both block axes are spatial.
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_3 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads(
                        [
                            A[i0_3, i1],
                            T_softmax_maxelem_shared[i0_3],
                            T_softmax_expsum_shared[i0_3],
                        ]
                    )
                    T.writes([T_softmax_norm[i0_3, i1]])
                    # NOTE(review): presumably records the softmax axis as a
                    # block annotation -- confirm who consumes "axis".
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_3, i1] = (
                        T.exp(
                            A[i0_3, i1] - T_softmax_maxelem_shared[i0_3],
                            dtype="float32",
                        )
                        / T_softmax_expsum_shared[i0_3]
                    )
Exemplo n.º 4
0
def reducer_max(a: T.handle, b: T.handle) -> None:
    # Row-wise maximum: B[i] = max over k of A[i, k].
    #
    # Parameters:
    #   a: handle matched to the [128, 128] float32 input buffer `A`.
    #   b: handle matched to the [128] float32 output buffer `B`.
    #
    # The row loop is serial; the reduction loop is bound to threadIdx.x.
    A = T.match_buffer(a, [128, 128], dtype="float32")
    B = T.match_buffer(b, [128], dtype="float32")
    for i in T.serial(0, 128):
        for k in T.thread_binding(0, 128, thread="threadIdx.x"):
            with T.block("B"):
                # "SR": vi is a spatial axis, vk is the reduction axis.
                vi, vk = T.axis.remap("SR", [i, k])
                T.reads([B[vi], A[vi, vk]])
                T.writes([B[vi]])
                with T.init():
                    # Smallest float32, so the first element compared wins.
                    B[vi] = T.min_value("float32")
                B[vi] = T.max(B[vi], A[vi, vk])
Exemplo n.º 5
0
def tir_argmax_val_idx(
    var_val: T.handle, var_idx: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle
) -> None:
    # Row-wise argmax over (value, index) pairs with symbolic shape [m, n].
    #
    # Parameters:
    #   var_val: handle matched to `val`, float32 [m, n] -- the compared values.
    #   var_idx: handle matched to `idx`, int32 [m, n] -- the index carried
    #            along with each value.
    #   var_argmax_v0: handle matched to `argmax_v0`, float32 [m] -- receives
    #                  the largest value of each row.
    #   var_argmax_v1: handle matched to `argmax_v1`, int32 [m] -- receives
    #                  the `idx` entry paired with that largest value.
    #
    # Ties: the comparison is `>=`, so when the running maximum equals the
    # new value the existing (value, index) pair is kept.
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # Symbolic row and column extents shared by all four buffers.
    m = T.var("int32")
    n = T.var("int32")
    val = T.match_buffer(var_val, [m, n], dtype="float32")
    idx = T.match_buffer(var_idx, [m, n], dtype="int32")
    argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="float32")
    argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="int32")
    for i0, i1 in T.grid(m, n):
        with T.block("argmax"):
            # "SR": i is a spatial axis, k is the reduction axis.
            i, k = T.axis.remap("SR", [i0, i1])
            T.reads(val[i, k], idx[i, k])
            T.writes(argmax_v0[i], argmax_v1[i])
            with T.init():
                # Initial pair: smallest float32 value and index -1.
                argmax_v0[i] = T.min_value("float32")
                argmax_v1[i] = T.int32(-1)
            # Both selections are computed before either store so that the
            # index selection compares against the previous running maximum,
            # not the value that is about to be written to argmax_v0[i].
            v_argmax_v0: T.float32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v0[i], val[i, k])
            v_argmax_v1: T.int32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v1[i], idx[i, k])
            argmax_v0[i] = v_argmax_v0
            argmax_v1[i] = v_argmax_v1
Exemplo n.º 6
0
def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None:
    # Row-wise softmax over a 256x256 float32 buffer in which both reductions
    # (row maximum and sum of shifted exponentials) are written out as:
    #   1. a per-thread partial reduction over 8 serial steps,
    #   2. a `tvm_thread_allreduce` across the 32 threadIdx.x lanes,
    #   3. a write-back of the combined value into a shared-scope buffer.
    # The final normalization stage is an ordinary spatial block.
    #
    # Parameters:
    #   var_A: handle matched to the [256, 256] float32 input `A`.
    #   var_T_softmax_norm: handle matched to the [256, 256] float32 output.
    #
    # NOTE(review): the name suggests this is the lowered counterpart of a
    # block-level softmax -- confirm against the test or pass that uses it.
    A = T.match_buffer(var_A, [256, 256], dtype="float32")
    T_softmax_norm = T.match_buffer(var_T_softmax_norm, [256, 256],
                                    dtype="float32")
    # Per-row results, allocated with scope="shared".
    T_softmax_maxelem_shared = T.alloc_buffer([256],
                                              dtype="float32",
                                              scope="shared")
    T_softmax_expsum_shared = T.alloc_buffer([256],
                                             dtype="float32",
                                             scope="shared")
    # Single-element "local" scope temporaries. Suffix 0 belongs to the max
    # reduction and suffix 1 to the exp-sum reduction; `normal_reduce_temp*`
    # holds the per-thread partial result, `reduce_temp*` the allreduce output.
    reduce_temp0 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    normal_reduce_temp0 = T.alloc_buffer([1],
                                         dtype="float32",
                                         strides=[1],
                                         scope="local")
    reduce_temp1 = T.alloc_buffer([1],
                                  dtype="float32",
                                  strides=[1],
                                  scope="local")
    normal_reduce_temp1 = T.alloc_buffer([1],
                                         dtype="float32",
                                         strides=[1],
                                         scope="local")
    for i0 in T.thread_binding(0, 256, thread="blockIdx.x"):
        # ---- Row maximum ----
        # The thread-bound loop is outermost here, with the serial loop inside.
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            # Reset the partial maximum to the smallest float32.
            with T.block("T_softmax_maxelem_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp0[0]])
                normal_reduce_temp0[0] = T.min_value("float32")
            # Partial maximum over the 8 elements at column ax0_0 * 32 + ax0_1.
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_maxelem_normal_reduction"):
                    i0_1 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([A[i0_1, k], normal_reduce_temp0[0]])
                    T.writes([normal_reduce_temp0[0]])
                    normal_reduce_temp0[0] = T.max(normal_reduce_temp0[0],
                                                   A[i0_1, k])
            # Combine the partial maxima across the threadIdx.x loop variable.
            with T.block("T_softmax_maxelem_cross_thread_reduction"):
                T.reads([normal_reduce_temp0[0]])
                T.writes([reduce_temp0[0]])
                # Combiner: max(x, y) with the smallest float32 as identity.
                T.attr(
                    T.comm_reducer(lambda x, y: T.max(x, y),
                                   [T.min_value("float32")]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                # Arguments, following the TVM builtin's documented order:
                # value count, contributed value, predicate, destination
                # buffer, thread-bound loop variable.
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp0[0],
                        True,
                        reduce_temp0.data,
                        ax0_1,
                        dtype="handle",
                    ))
            # Store the combined maximum for row i0.
            with T.block("T_softmax_maxelem_write_back"):
                i0_2 = T.axis.spatial(256, i0)
                T.reads([reduce_temp0[0]])
                T.writes([T_softmax_maxelem_shared[i0_2]])
                T_softmax_maxelem_shared[i0_2] = reduce_temp0[0]
        # ---- Sum of shifted exponentials ----
        for ax0_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
            # Reset the partial sum to zero.
            with T.block("T_softmax_expsum_normal_reduction_init"):
                T.reads([])
                T.writes([normal_reduce_temp1[0]])
                normal_reduce_temp1[0] = T.float32(0)
            # Partial sum of exp(A - row max) over the same 8 columns.
            for ax0_0 in T.serial(0, 8):
                with T.block("T_softmax_expsum_normal_reduction"):
                    i0_3 = T.axis.spatial(256, i0)
                    k = T.axis.reduce(256, ax0_0 * 32 + ax0_1)
                    T.reads([
                        A[i0_3, k],
                        T_softmax_maxelem_shared[i0_3],
                        normal_reduce_temp1[0],
                    ])
                    T.writes([normal_reduce_temp1[0]])
                    normal_reduce_temp1[0] = normal_reduce_temp1[0] + T.exp(
                        A[i0_3, k] - T_softmax_maxelem_shared[i0_3],
                        dtype="float32")
            # Combine the partial sums across the threadIdx.x loop variable.
            with T.block("T_softmax_expsum_cross_thread_reduction"):
                T.reads([normal_reduce_temp1[0]])
                T.writes([reduce_temp1[0]])
                # Combiner: addition with 0 as identity.
                T.attr(
                    T.comm_reducer(lambda x_1, y_1: x_1 + y_1, [T.float32(0)]),
                    "reduce_scope",
                    T.reinterpret(T.uint64(0), dtype="handle"),
                )
                T.evaluate(
                    T.tvm_thread_allreduce(
                        T.uint32(1),
                        normal_reduce_temp1[0],
                        True,
                        reduce_temp1.data,
                        ax0_1,
                        dtype="handle",
                    ))
            # Store the combined sum for row i0.
            with T.block("T_softmax_expsum_write_back"):
                i0_4 = T.axis.spatial(256, i0)
                T.reads([reduce_temp1[0]])
                T.writes([T_softmax_expsum_shared[i0_4]])
                T_softmax_expsum_shared[i0_4] = reduce_temp1[0]
        # ---- Normalization ----
        # out[i, j] = exp(A[i, j] - maxelem[i]) / expsum[i]; both axes spatial.
        for i1_0 in T.serial(0, 8):
            for i1_1 in T.thread_binding(0, 32, thread="threadIdx.x"):
                with T.block("T_softmax_norm"):
                    i0_5 = T.axis.spatial(256, i0)
                    i1 = T.axis.spatial(256, i1_0 * 32 + i1_1)
                    T.reads([
                        A[i0_5, i1],
                        T_softmax_maxelem_shared[i0_5],
                        T_softmax_expsum_shared[i0_5],
                    ])
                    T.writes([T_softmax_norm[i0_5, i1]])
                    # NOTE(review): presumably records the softmax axis as a
                    # block annotation -- confirm who consumes "axis".
                    T.block_attr({"axis": 1})
                    T_softmax_norm[i0_5, i1] = (T.exp(
                        A[i0_5, i1] - T_softmax_maxelem_shared[i0_5],
                        dtype="float32",
                    ) / T_softmax_expsum_shared[i0_5])