# Example 1
def _get_sorted_indices(data, data_buf, score_index, score_shape):
    """Extract a 1D score tensor from the packed input and do argsort on it."""
    # Destination buffer for the scores pulled out of the packed data tensor.
    score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8)
    score_tensor = te.extern(
        [score_shape],
        [data],
        lambda ins, outs: _fetch_score_ir(ins[0], outs[0], score_index),
        dtype=[data.dtype],
        in_buffers=[data_buf],
        out_buffers=[score_buf],
        name="fetch_score",
        tag="fetch_score",
    )

    # Prefer the (roc)thrust-backed argsort whenever the current target
    # supports it; otherwise fall back to the generic TOPI argsort.
    target = tvm.target.Target.current()
    thrust_ok = target and (
        can_use_thrust(target, "tvm.contrib.thrust.sort")
        or can_use_rocthrust(target, "tvm.contrib.thrust.sort")
    )
    sort_impl = argsort_thrust if thrust_ok else argsort
    return sort_impl(score_tensor, axis=1, is_ascend=False, dtype="int32")
# Example 2
    def do_scan(data, output_dtype):
        """Run the exclusive scan over *data*, delegating to thrust when possible.

        Uses the closure variables ``ndim``, ``binop`` and ``return_reduction``
        from the enclosing scope.  Returns the scanned tensor, plus the
        per-row reduction when ``return_reduction`` is set.
        """
        # Fast path: hand the whole scan to (roc)thrust if the target allows.
        target = tvm.target.Target.current()
        thrust_ok = target and (
            can_use_thrust(target, "tvm.contrib.thrust.sum_scan")
            or can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan")
        )
        if thrust_ok:
            return scan_thrust(
                data,
                output_dtype,
                exclusive=True,
                return_reduction=return_reduction,
                binop=binop,
            )

        # The TIR exclusive scan accepts only 2D or higher-rank inputs, so
        # lift a 1D input to 2D here and squeeze the axis back out below.
        if ndim == 1:
            data = expand_dims(data, axis=0)

        data_buf = tvm.tir.decl_buffer(
            data.shape, data.dtype, "data_buf", data_alignment=8
        )
        output_buf = tvm.tir.decl_buffer(
            data.shape, output_dtype, "output_buf", data_alignment=8
        )

        reduction = None
        if return_reduction:
            # Two outputs: the scan itself plus the last-axis reduction.
            output, reduction = te.extern(
                [data.shape, data.shape[:-1]],
                [data],
                lambda ins, outs: exclusive_scan_ir(
                    ins[0], outs[0], outs[1], binop=binop
                ),
                dtype=[data.dtype, output_dtype],
                in_buffers=[data_buf],
                name="exclusive_scan",
                tag="exclusive_scan_gpu",
            )
        else:
            output = te.extern(
                [data.shape],
                [data],
                lambda ins, outs: exclusive_scan_ir(ins[0], outs[0], binop=binop),
                dtype=[output_dtype],
                in_buffers=[data_buf],
                out_buffers=[output_buf],
                name="exclusive_scan",
                tag="exclusive_scan_gpu",
            )

        # Undo the rank lift applied for 1D inputs.
        if ndim == 1:
            output = squeeze(output, 0)
            if return_reduction:
                reduction = squeeze(reduction, 0)

        return (output, reduction) if return_reduction else output
# Example 3
def _dispatch_sort(scores, ret_type="indices"):
    """Argsort *scores* along axis 1, descending, via thrust when available."""
    # Both implementations take the exact same arguments; pick one and call it.
    sort_kwargs = dict(axis=1, is_ascend=False, dtype="int32", ret_type=ret_type)
    target = tvm.target.Target.current()
    thrust_ok = target and (
        can_use_thrust(target, "tvm.contrib.thrust.sort")
        or can_use_rocthrust(target, "tvm.contrib.thrust.sort")
    )
    if thrust_ok:
        return argsort_thrust(scores, **sort_kwargs)
    return argsort(scores, **sort_kwargs)
# Example 4
def sort_strategy_cuda(attrs, inputs, out_type, target):
    """sort cuda strategy"""
    strategy = _op.OpStrategy()
    # The generic CUDA kernel is always registered as a baseline.
    strategy.add_implementation(
        wrap_compute_sort(topi.cuda.sort),
        wrap_topi_schedule(topi.cuda.schedule_sort),
        name="sort.cuda",
    )
    # Register the thrust-backed kernel at a higher priority level so it is
    # preferred whenever the target supports it.
    if can_use_thrust(target, "tvm.contrib.thrust.sort"):
        strategy.add_implementation(
            wrap_compute_sort(topi.cuda.sort_thrust),
            wrap_topi_schedule(topi.cuda.schedule_sort),
            name="sort_thrust.cuda",
            plevel=15,
        )
    return strategy
# Example 5
def scatter_cuda(attrs, inputs, out_type, target):
    """scatter cuda strategy"""
    strategy = _op.OpStrategy()
    # Default scatter kernel, registered unconditionally.
    strategy.add_implementation(
        wrap_compute_scatter(topi.cuda.scatter),
        wrap_topi_schedule(topi.cuda.schedule_scatter),
        name="scatter.cuda",
        plevel=10,
    )

    input_rank = len(inputs[0].shape)

    # The sort-based variant only applies to rank-1 input, and only when the
    # target provides thrust's stable sort-by-key.
    with SpecializedCondition(input_rank == 1):
        if can_use_thrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
            strategy.add_implementation(
                wrap_compute_scatter(topi.cuda.scatter_via_sort),
                wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
                name="scatter_via_sort.cuda",
                plevel=9,  # use the sequential version by default
            )
    return strategy