# Imports assumed by these excerpts (module paths follow mainline TVM's
# layout; _fetch_score_ir is a module-local IR builder not reproduced here).
import tvm
from tvm import te, topi
from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust
from tvm.te import SpecializedCondition
from tvm.relay.op import op as _op
from tvm.relay.op.strategy.generic import (
    wrap_compute_scatter,
    wrap_compute_sort,
    wrap_topi_schedule,
)
from tvm.topi import expand_dims, squeeze
from tvm.topi.cuda.scan import exclusive_scan_ir, scan_thrust
from tvm.topi.cuda.sort import argsort, argsort_thrust


def _get_sorted_indices(data, data_buf, score_index, score_shape):
    """Extract a 1D score tensor from the packed input and do argsort on it."""
    score_buf = tvm.tir.decl_buffer(score_shape, data.dtype, "score_buf", data_alignment=8)
    score_tensor = te.extern(
        [score_shape],
        [data],
        lambda ins, outs: _fetch_score_ir(
            ins[0],
            outs[0],
            score_index,
        ),
        dtype=[data.dtype],
        in_buffers=[data_buf],
        out_buffers=[score_buf],
        name="fetch_score",
        tag="fetch_score",
    )

    target = tvm.target.Target.current()
    if target and (
        can_use_thrust(target, "tvm.contrib.thrust.sort")
        or can_use_rocthrust(target, "tvm.contrib.thrust.sort")
    ):
        sort_tensor = argsort_thrust(score_tensor, axis=1, is_ascend=False, dtype="int32")
    else:
        sort_tensor = argsort(score_tensor, axis=1, is_ascend=False, dtype="int32")
    return sort_tensor
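
# Hedged usage sketch (the wrapper below and its shapes are illustrative, not
# part of the original module): the packed detector output is assumed to be
# laid out as (batch, num_anchors, elem_length) with the class score at
# score_index, so the extracted scores form a (batch, num_anchors) tensor.
def _example_get_sorted_indices():
    data = te.placeholder((1, 2500, 6), name="data", dtype="float32")
    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
    score_shape = (data.shape[0], data.shape[1])  # one score per anchor
    # Anchor indices sorted by descending score: shape (1, 2500), dtype int32.
    return _get_sorted_indices(data, data_buf, 1, score_shape)
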
def do_scan(data, output_dtype):
    # Note: this is a closure inside exclusive_scan; ndim, binop and
    # return_reduction come from the enclosing scope.
    target = tvm.target.Target.current()
    # Thrust only provides a sum scan, so gate the fast path on the binop too.
    if (
        target
        and binop == tvm.tir.generic.add
        and (
            can_use_thrust(target, "tvm.contrib.thrust.sum_scan")
            or can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan")
        )
    ):
        return scan_thrust(
            data, output_dtype, exclusive=True, return_reduction=return_reduction, binop=binop
        )

    if ndim == 1:
        # TIR exclusive scan accepts only 2D or higher-rank inputs.
        data = expand_dims(data, axis=0)

    data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
    output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8)

    if return_reduction:
        output, reduction = te.extern(
            [data.shape, data.shape[:-1]],
            [data],
            lambda ins, outs: exclusive_scan_ir(ins[0], outs[0], outs[1], binop=binop),
            dtype=[data.dtype, output_dtype],
            in_buffers=[data_buf],
            name="exclusive_scan",
            tag="exclusive_scan_gpu",
        )
    else:
        output = te.extern(
            [data.shape],
            [data],
            lambda ins, outs: exclusive_scan_ir(ins[0], outs[0], binop=binop),
            dtype=[output_dtype],
            in_buffers=[data_buf],
            out_buffers=[output_buf],
            name="exclusive_scan",
            tag="exclusive_scan_gpu",
        )
        reduction = None

    if ndim == 1:
        output = squeeze(output, 0)
        if return_reduction:
            reduction = squeeze(reduction, 0)

    if return_reduction:
        return output, reduction

    return output
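
# What do_scan computes, sketched with NumPy (semantics only; the real kernel
# is generated TIR or a Thrust call): an exclusive scan shifts the running
# total right by one step, and the optional reduction output is the full sum
# of each row. The wrapper name is illustrative.
def _example_exclusive_scan_semantics():
    import numpy as np

    row = np.array([1, 2, 3, 4])
    exclusive = np.concatenate(([0], np.cumsum(row)[:-1]))  # [0, 1, 3, 6]
    reduction = row.sum()  # 10
    return exclusive, reduction
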
def _dispatch_sort(scores, ret_type="indices"):
    target = tvm.target.Target.current()
    if target and (
        can_use_thrust(target, "tvm.contrib.thrust.sort")
        or can_use_rocthrust(target, "tvm.contrib.thrust.sort")
    ):
        return argsort_thrust(scores, axis=1, is_ascend=False, dtype="int32", ret_type=ret_type)
    return argsort(scores, axis=1, is_ascend=False, dtype="int32", ret_type=ret_type)
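
# Hedged usage sketch: with a plain "cuda" target (no -libs=thrust) the
# Thrust checks fail and the TIR argsort is used; the target string and the
# wrapper below are assumptions, not part of the original module.
def _example_dispatch_sort():
    scores = te.placeholder((1, 2500), name="scores", dtype="float32")
    with tvm.target.Target("cuda"):
        return _dispatch_sort(scores, ret_type="indices")
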
def sort_strategy_cuda(attrs, inputs, out_type, target):
    """sort rocm strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_sort(topi.cuda.sort),
        wrap_topi_schedule(topi.cuda.schedule_sort),
        name="sort.rocm",
    )
    if can_use_rocthrust(target, "tvm.contrib.thrust.sort"):
        strategy.add_implementation(
            wrap_compute_sort(topi.cuda.sort_thrust),
            wrap_topi_schedule(topi.cuda.schedule_sort),
            name="sort_thrust.rocm",
            plevel=15,
        )
    return strategy
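
# Hedged sketch of the gating above: can_use_rocthrust only returns True when
# the target is rocm with -libs=thrust and the packed function is actually
# registered (i.e. TVM was built with rocThrust); the plevel=15 entry then
# outranks the default-priority (10) TIR sort. The wrapper is illustrative.
def _example_check_rocthrust():
    target = tvm.target.Target("rocm -libs=thrust")
    return can_use_rocthrust(target, "tvm.contrib.thrust.sort")
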
def scatter_cuda(attrs, inputs, out_type, target):
    """scatter rocm strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_scatter(topi.cuda.scatter),
        wrap_topi_schedule(topi.cuda.schedule_scatter),
        name="scatter.rocm",
        plevel=10,
    )
    rank = len(inputs[0].shape)
    with SpecializedCondition(rank == 1):
        if can_use_rocthrust(target, "tvm.contrib.thrust.stable_sort_by_key"):
            strategy.add_implementation(
                wrap_compute_scatter(topi.cuda.scatter_via_sort),
                wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
                name="scatter_via_sort.rocm",
                plevel=9,  # use the sequential version by default
            )
    return strategy
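
# Hedged sketch of how SpecializedCondition scopes the rocThrust path:
# implementations added inside the `with` block are only candidates when the
# predicate holds (here, rank-1 inputs), and because plevel=9 is below the
# sequential kernel's plevel=10, scatter_via_sort is only picked when chosen
# explicitly (e.g. by an auto-tuning search). The wrapper is illustrative.
def _example_specialized_condition(rank):
    strategy = _op.OpStrategy()
    with SpecializedCondition(rank == 1):
        # This entry only competes for inputs satisfying the predicate above.
        strategy.add_implementation(
            wrap_compute_scatter(topi.cuda.scatter_via_sort),
            wrap_topi_schedule(topi.cuda.schedule_scatter_via_sort),
            name="scatter_via_sort.rocm",
            plevel=9,
        )
    return strategy
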