def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
    """Get the top k elements in an input tensor along the given axis.

    The selection itself is delegated to the packed function
    ``tvm.contrib.sort.topk`` through an extern op.

    Parameters
    ----------
    data : tvm.Tensor
        The input tensor.

    k : int, optional
        Number of top elements to select. When k < 1 the output keeps the
        full extent of the input along ``axis`` (all elements are returned).

    axis : int, optional
        Axis along which to sort the input tensor.

    ret_type : str, optional
        The return type, one of [both, values, indices].
        "both": return both top k data and indices.
        "values": return top k data only.
        "indices": return top k indices only.

    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.

    dtype : string, optional
        The data type of the indices output.

    Returns
    -------
    out : tvm.Tensor or List[tvm.Tensor]
        The computed result.
    """
    assert ret_type in ["both", "values", "indices"]

    input_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)

    # Output shape equals the input shape, except that the sorted axis is
    # shrunk to k when a positive k is requested.
    result_shape = list(get_const_tuple(data.shape))
    if k >= 1:
        result_shape[axis] = k

    # Collect (dtype, buffer name) for every requested output, values first.
    requested = []
    if ret_type in ("both", "values"):
        requested.append((data.dtype, "value_buf"))
    if ret_type in ("both", "indices"):
        requested.append((dtype, "indices_buf"))
    result_bufs = [
        api.decl_buffer(result_shape, buf_dtype, buf_name, data_alignment=8)
        for buf_dtype, buf_name in requested
    ]

    def _call_topk(ins, outs):
        return tvm.call_packed(
            "tvm.contrib.sort.topk", ins[0], *outs, k, axis, ret_type, is_ascend)

    return tvm.extern(
        [result_shape for _ in result_bufs],
        [data],
        _call_topk,
        in_buffers=[input_buf],
        out_buffers=result_bufs,
        name="topk_cpu",
        tag="topk_cpu")
def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01,
                               variances=(0.1, 0.1, 0.2, 0.2)):
    """Location transformation for multibox detection.

    Wraps ``transform_loc_ir`` in a single extern op with two outputs.

    Parameters
    ----------
    cls_prob : tvm.Tensor
        Class probabilities.

    loc_pred : tvm.Tensor
        Location regression predictions.

    anchor : tvm.Tensor
        Prior anchor boxes.

    clip : boolean
        Whether to clip out-of-boundary boxes.

    threshold : float
        Threshold to be a positive prediction.

    variances : tuple of float
        Variances to be decoded from box regression output.

    Returns
    -------
    ret : list of tvm.Tensor, composed of

    out : tvm.Tensor
        3-D tensor with shape (batch_size, num_anchors, 6)

    valid_count : tvm.Tensor
        1-D tensor with shape (batch_size,), number of valid anchor boxes.
    """
    batch = cls_prob.shape[0]
    anchors = anchor.shape[1]
    count_dtype = "int32"
    count_shape = (batch,)
    box_shape = (batch, anchors, 6)

    # Explicit buffers pin the data alignment of both outputs.
    count_buf = api.decl_buffer(
        count_shape, count_dtype, "valid_count_buf", data_alignment=4)
    box_buf = api.decl_buffer(
        box_shape, cls_prob.dtype, "out_buf", data_alignment=8)

    def _emit(ins, outs):
        cls_in, loc_in, anchor_in = ins
        count_out, box_out = outs
        return transform_loc_ir(
            cls_in, loc_in, anchor_in, count_out, box_out, clip, threshold, variances)

    valid_count, out = tvm.extern(
        [count_shape, box_shape],
        [cls_prob, loc_pred, anchor],
        _emit,
        dtype=[count_dtype, cls_prob.dtype],
        out_buffers=[count_buf, box_buf],
        tag="multibox_transform_loc")

    # The extern produces (valid_count, out); callers receive them swapped.
    return [out, valid_count]
def argsort_gpu(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
    """Sort along the given axis and return the indices, in sorted order,
    as an array with the same shape as the input.

    Parameters
    ----------
    data : tvm.Tensor
        The input array.

    valid_count : tvm.Tensor, optional
        The number of valid elements to be sorted. When given, the
        ``sort_nms_ir`` path is used and the result dtype is int32.

    axis : int, optional
        Axis along which to sort the input tensor.

    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.

    dtype : string, optional
        DType of the output indices (only used when ``valid_count`` is None).

    Returns
    -------
    out : tvm.Tensor
        The output of this function.
    """
    if valid_count is None:
        # Plain argsort: the extern yields (sorted values, indices) and only
        # the indices tensor is handed back.
        values_buf = api.decl_buffer(
            data.shape, data.dtype, "value_buf", data_alignment=8)
        index_buf = api.decl_buffer(
            data.shape, dtype, "out_buf", data_alignment=8)

        def _emit_sort(ins, outs):
            return sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1])

        results = tvm.extern(
            [data.shape, data.shape],
            [data],
            _emit_sort,
            out_buffers=[values_buf, index_buf],
            name="argsort_gpu",
            tag="argsort_gpu")
        return results[1]

    # NMS flavour: the extern's data input is identity(data), not data itself.
    data_copy = identity(data)
    data_copy_buf = api.decl_buffer(
        data.shape, data.dtype, "sorted_data_buf", data_alignment=8)
    count_buf = api.decl_buffer(
        valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4)
    index_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4)

    def _emit_sort_nms(ins, outs):
        return sort_nms_ir(ins[0], ins[1], outs[0], axis, is_ascend)

    return tvm.extern(
        [data.shape],
        [data_copy, valid_count],
        _emit_sort_nms,
        dtype="int32",
        in_buffers=[data_copy_buf, count_buf],
        out_buffers=[index_buf],
        name="argsort_nms_gpu",
        tag="argsort_nms_gpu")
def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
    """Sort along the given axis and return the indices, in sorted order,
    as an array with the same shape as the input.

    The extern ops take ``identity(data)`` as their data input rather than
    ``data`` itself.

    Parameters
    ----------
    data : tvm.Tensor
        The input array.

    valid_count : tvm.Tensor
        The number of valid elements to be sorted. Only read when ``flag``
        is truthy.

    axis : int
        Axis along which to sort the input tensor.

    is_ascend : boolean
        Whether to sort in ascending or descending order.

    dtype : string
        DType of the output when ``flag`` is falsy; the nms path always
        produces int32.

    flag : boolean
        Whether this argsort is used in nms operator.

    Returns
    -------
    out : tvm.Tensor
        The output of this function.
    """
    data_copy_buf = api.decl_buffer(
        data.shape, data.dtype, "sorted_data_buf", data_alignment=8)
    data_copy = identity(data)

    if not flag:
        result_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)

        def _emit_sort(ins, outs):
            return sort_ir(ins[0], outs[0], axis, is_ascend)

        return tvm.extern(
            [data.shape],
            [data_copy],
            _emit_sort,
            dtype=dtype,
            in_buffers=[data_copy_buf],
            out_buffers=[result_buf],
            name="argsort_gpu",
            tag="argsort_gpu")

    count_buf = api.decl_buffer(
        valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4)
    result_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4)

    def _emit_sort_nms(ins, outs):
        return sort_nms_ir(ins[0], ins[1], outs[0], axis, is_ascend)

    return tvm.extern(
        [data.shape],
        [data_copy, valid_count],
        _emit_sort_nms,
        dtype="int32",
        in_buffers=[data_copy_buf, count_buf],
        out_buffers=[result_buf],
        name="argsort_nms_gpu",
        tag="argsort_nms_gpu")
def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
    """Sort along the given axis and return the indices, in sorted order,
    as an array with the same shape as the input.

    Parameters
    ----------
    data : tvm.Tensor
        The input array.

    valid_count : tvm.Tensor
        The number of valid elements to be sorted. Only read when ``flag``
        is truthy.

    axis : int
        Axis along which to sort the input tensor.

    is_ascend : boolean
        Whether to sort in ascending or descending order.

    dtype : string
        DType of the output when ``flag`` is falsy; the nms path always
        produces int32.

    flag : boolean
        Whether this argsort is used in nms operator.

    Returns
    -------
    out : tvm.Tensor
        The output of this function.
    """
    input_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)

    if flag:
        # NMS flavour: valid_count is a second input, result is int32.
        count_buf = api.decl_buffer(
            valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4)
        result_buf = api.decl_buffer(
            data.shape, "int32", "out_buf", data_alignment=4)
        sources = [data, valid_count]
        source_bufs = [input_buf, count_buf]
        result_dtype = "int32"
        op_name = "argsort_nms_gpu"

        def _emit(ins, outs):
            return sort_nms_ir(ins[0], ins[1], outs[0], axis, is_ascend)
    else:
        result_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
        sources = [data]
        source_bufs = [input_buf]
        result_dtype = dtype
        op_name = "argsort_gpu"

        def _emit(ins, outs):
            return sort_ir(ins[0], outs[0], axis, is_ascend)

    return tvm.extern(
        [data.shape],
        sources,
        _emit,
        dtype=result_dtype,
        in_buffers=source_bufs,
        out_buffers=[result_buf],
        name=op_name,
        tag=op_name)
def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True, threshold=0.01,
                           variances=(0.1, 0.1, 0.2, 0.2)):
    """Location transformation for multibox detection.

    Parameters
    ----------
    cls_prob : tvm.Tensor
        Class probabilities.

    loc_pred : tvm.Tensor
        Location regression predictions.

    anchor : tvm.Tensor
        Prior anchor boxes.

    clip : boolean
        Whether to clip out-of-boundary boxes.

    threshold : float
        Threshold to be a positive prediction.

    variances : tuple of float
        Variances to be decoded from box regression output.

    Returns
    -------
    ret : list of tvm.Tensor
        ``[out, valid_count]`` where ``out`` has shape
        (batch_size, num_anchors, 6) and ``valid_count`` has shape
        (batch_size,) with dtype int32.
    """
    n_batch = cls_prob.shape[0]
    n_anchor = anchor.shape[1]
    counts_dtype = "int32"
    counts_shape = (n_batch,)
    boxes_shape = (n_batch, n_anchor, 6)

    # Define data alignment for the two output buffers.
    counts_buf = api.decl_buffer(
        counts_shape, counts_dtype, "valid_count_buf", data_alignment=4)
    boxes_buf = api.decl_buffer(
        boxes_shape, cls_prob.dtype, "out_buf", data_alignment=8)

    def _build_ir(ins, outs):
        return transform_loc_ir(
            ins[0], ins[1], ins[2], outs[0], outs[1], clip, threshold, variances)

    extern_outputs = tvm.extern(
        [counts_shape, boxes_shape],
        [cls_prob, loc_pred, anchor],
        _build_ir,
        dtype=[counts_dtype, cls_prob.dtype],
        out_buffers=[counts_buf, boxes_buf],
        tag="multibox_transform_loc")
    valid_count, out = extern_outputs
    return [out, valid_count]
def multibox_transform_loc(cls_prob, loc_pred, anchor, clip=True,
                           threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)):
    """Location transformation for multibox detection.

    Runs in two extern phases: ``transform_loc_pre`` reads ``cls_prob`` and
    produces the valid count plus three per-anchor temporaries, then
    ``transform_loc_ir`` combines those temporaries with ``loc_pred`` and
    ``anchor`` to produce the output boxes.

    Parameters
    ----------
    cls_prob : tvm.Tensor
        Class probabilities.

    loc_pred : tvm.Tensor
        Location regression predictions.

    anchor : tvm.Tensor
        Prior anchor boxes.

    clip : boolean
        Whether to clip out-of-boundary boxes.

    threshold : float
        Threshold to be a positive prediction.

    variances : tuple of float
        Variances to be decoded from box regression output.

    Returns
    -------
    ret : list of tvm.Tensor, composed of

    out : tvm.Tensor
        3-D tensor with shape (batch_size, num_anchors, 6)

    valid_count : tvm.Tensor
        1-D tensor with shape (batch_size,), number of valid anchor boxes.
    """
    n_batch = cls_prob.shape[0]
    n_anchor = cls_prob.shape[2]
    count_dtype = "int32"
    box_dtype = loc_pred.dtype
    count_shape = (n_batch,)
    per_anchor_shape = (n_batch, n_anchor)
    box_shape = (n_batch, n_anchor, 6)

    def _buf8(shape, buf_dtype, buf_name):
        # All buffers except valid_count use 8-byte data alignment.
        return api.decl_buffer(shape, buf_dtype, buf_name, data_alignment=8)

    count_buf = api.decl_buffer(
        count_shape, count_dtype, "valid_count_buf", data_alignment=4)
    loc_buf = _buf8(loc_pred.shape, loc_pred.dtype, "loc_pred_buf")
    prior_buf = _buf8(anchor.shape, anchor.dtype, "anchor_buf")
    tmp_count_buf = _buf8(per_anchor_shape, count_dtype, "temp_valid_count")
    tmp_cls_buf = _buf8(per_anchor_shape, count_dtype, "temp_cls_id")
    tmp_score_buf = _buf8(per_anchor_shape, cls_prob.dtype, "temp_score")

    def _phase_one(ins, outs):
        return transform_loc_pre(ins[0], *outs, threshold)

    valid_count, tmp_count, tmp_cls, tmp_score = tvm.extern(
        [count_shape, per_anchor_shape, per_anchor_shape, per_anchor_shape],
        [cls_prob],
        _phase_one,
        dtype=[count_dtype, count_dtype, count_dtype, cls_prob.dtype],
        out_buffers=[count_buf, tmp_count_buf, tmp_cls_buf, tmp_score_buf],
        tag="multibox_transform_loc_phase_one")

    def _phase_two(ins, outs):
        return transform_loc_ir(
            *ins, outs[0], clip, variances, n_batch, n_anchor)

    # The phase-one temporaries are consumed through the very same buffers
    # they were produced into.
    boxes = tvm.extern(
        [box_shape],
        [loc_pred, anchor, tmp_count, tmp_cls, tmp_score],
        _phase_two,
        in_buffers=[loc_buf, prior_buf, tmp_count_buf, tmp_cls_buf, tmp_score_buf],
        dtype=[box_dtype],
        tag="multibox_transform_loc")

    return [boxes, valid_count]
def sort_gpu(data, data_buf, index, index_buf, output_buf, axis, is_descend):
    """Build the chain of extern ops that sorts ``data`` along ``axis``.

    Four extern stages are chained: ``sort_pre_ir`` (tag "sorting_prepare"),
    ``sort_pre_ir_data`` ("sorting_data"), ``sort_oet_ir`` ("sorting_oet")
    and ``sort_ir_out`` ("sorting_output").

    Parameters
    ----------
    data : tvm.Tensor
        Tensor to sort. Its dtype must be float32 (asserted).

    data_buf : Buffer
        Buffer bound to ``data`` in the extern stages.

    index : tvm.Tensor
        1-D tensor for valid number of boxes.

    index_buf : Buffer
        Buffer bound to ``index`` in the extern stages.

    output_buf : Buffer
        Output buffer of indices of sorted tensor.

    axis : int
        The axis used for sorting. A negative value is counted from the end.

    is_descend : bool
        If the sorted data is in descending order.

    Returns
    -------
    out : tvm.Tensor
        Tensor with the shape of ``data`` and the dtype of ``index``.
    """
    ndim = len(data.shape)
    assert data.dtype == "float32", "Currently only supports input dtype to be float32"
    assert axis < ndim, "Axis out of boundary for input ndim %d" % ndim

    if axis < 0:
        axis = ndim + axis

    # Product of the extents before / after the sort axis.
    outer = 1
    inner = 1
    for dim_idx in range(ndim):
        if dim_idx == axis:
            continue
        extent = data.shape[dim_idx]
        if dim_idx < axis:
            outer *= extent
        else:
            inner *= extent

    n_rows = outer * inner
    n_total = data.shape[axis] * n_rows
    idx_dtype = index.dtype

    sizes_buf = api.decl_buffer(n_rows, idx_dtype, "sizes", data_alignment=8)
    staged_index_buf = api.decl_buffer(
        n_total, idx_dtype, "index_new", data_alignment=8)
    sorted_index_buf = api.decl_buffer(
        n_total, idx_dtype, "index_out", data_alignment=8)
    staged_data_buf = api.decl_buffer(
        n_rows, data.dtype, "data_new", data_alignment=8)

    def _prepare(ins, outs):
        return sort_pre_ir(ins[0], outs[0], outer, inner)

    loc = tvm.extern(
        [(n_rows,)],
        [index],
        _prepare,
        dtype=[idx_dtype],
        in_buffers=index_buf,
        out_buffers=[sizes_buf],
        tag="sorting_prepare")

    def _stage_data(ins, outs):
        return sort_pre_ir_data(*ins, outs[0], outs[1], axis, outer, inner)

    data_new, index_new = tvm.extern(
        [(n_rows,), (n_total,)],
        [data, index, loc],
        _stage_data,
        dtype=[data.dtype, idx_dtype],
        in_buffers=[data_buf, index_buf, sizes_buf],
        out_buffers=[staged_data_buf, staged_index_buf],
        tag="sorting_data")

    def _sort(ins, outs):
        return sort_oet_ir(*ins, outs[0], outer, inner, axis, is_descend)

    index_out = tvm.extern(
        [(n_total,)],
        [data, index, data_new, index_new, loc],
        _sort,
        dtype=[idx_dtype],
        in_buffers=[data_buf, index_buf, staged_data_buf, staged_index_buf, sizes_buf],
        out_buffers=[sorted_index_buf],
        tag="sorting_oet")

    def _write_out(ins, outs):
        return sort_ir_out(*ins, outs[0], outer, inner, axis)

    return tvm.extern(
        [data.shape],
        [data, index, index_out, loc],
        _write_out,
        dtype=[idx_dtype],
        in_buffers=[data_buf, index_buf, sorted_index_buf, sizes_buf],
        out_buffers=output_buf,
        tag="sorting_output")
def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False, nms_topk=-1):
    """Non-maximum suppression operator for object detection.

    Scores are extracted from ``data[:, :, 1]``, sorted with ``sort_gpu`` in
    descending order, and the sorted indices are fed to ``nms_ir``.

    Parameters
    ----------
    data : tvm.Tensor
        3-D tensor with shape [batch_size, num_anchors, 6].
        The last dimension should be in format of
        [class_id, score, box_left, box_top, box_right, box_bottom].

    valid_count : tvm.Tensor
        1-D tensor for valid number of boxes.

    nms_threshold : float
        Non-maximum suppression threshold.

    force_suppress : boolean
        Whether to suppress all detections regardless of class_id.

    nms_topk : int
        Keep maximum top k detections before nms, -1 for no limit.

    Returns
    -------
    out : tvm.Tensor
        3-D tensor with shape [batch_size, num_anchors, 6].

    Example
    --------
    .. code-block:: python

        # An example to use nms
        dshape = (1, 5, 6)
        data = tvm.placeholder(dshape, name="data")
        valid_count = tvm.placeholder(
            (dshape[0],), dtype="int32", name="valid_count")
        nms_threshold = 0.7
        force_suppress = True
        nms_topk = -1
        out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
        np_data = np.random.uniform(size=dshape)
        np_valid_count = np.array([4])
        s = topi.generic.schedule_nms(out)
        f = tvm.build(s, [data, valid_count, out], "llvm")
        ctx = tvm.cpu()
        tvm_data = tvm.nd.array(np_data, ctx)
        tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
        tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
        f(tvm_data, tvm_valid_count, tvm_out)
    """
    n_batch = data.shape[0]
    n_anchor = data.shape[1]

    count_buf = api.decl_buffer(
        valid_count.shape, "int32", "valid_count_buf", data_alignment=4)
    boxes_buf = api.decl_buffer(
        data.shape, data.dtype, "data_buf", data_alignment=8)

    # Slice out the score column (index 1 of the last dimension).
    score_axis = 1
    score_shape = (n_batch, n_anchor)
    scores = tvm.compute(
        score_shape, lambda i, j: data[i, j, score_axis], name="score_tensor")
    scores_buf = api.decl_buffer(
        scores.shape, data.dtype, "score_tensor_buf", data_alignment=8)

    order_buf = api.decl_buffer(
        score_shape, "int32", "sort_tensor_buf", data_alignment=8)
    order = sort_gpu(
        scores, scores_buf, valid_count, count_buf, order_buf, score_axis, True)

    def _emit_nms(ins, outs):
        boxes_in, order_in, count_in = ins
        return nms_ir(
            boxes_in, order_in, count_in, outs[0],
            nms_threshold, force_suppress, nms_topk)

    return tvm.extern(
        data.shape,
        [data, order, valid_count],
        _emit_nms,
        dtype="float32",
        in_buffers=[boxes_buf, order_buf, count_buf],
        tag="nms")
def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1):
    """Get valid count of bounding boxes given a score threshold.
    Also moves valid boxes to the top of input data.

    Implemented as five chained extern phases (pre, upsweep, scan,
    downsweep, final). Phases two to four write into the same index /
    partial buffers that earlier phases declared.

    Parameters
    ----------
    data : tvm.Tensor
        Input data. 3-D tensor with shape [batch_size, num_anchors, elem_length].

    score_threshold : optional, float
        Lower limit of score for valid bounding boxes.

    id_index : optional, int
        Index of the class categories, -1 to disable.

    score_index : optional, int
        Index of the scores/confidence of boxes.

    Returns
    -------
    valid_count : tvm.Tensor
        1-D tensor for valid number of boxes.

    out_tensor : tvm.Tensor
        Rearranged data tensor.
    """
    n_batch = data.shape[0]
    n_anchor = data.shape[1]

    # Size of the partial-result array, derived from the current target's
    # thread limit (a target must be in scope: allow_none=False).
    thread_limit = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    per_thread = n_anchor // thread_limit + 1
    n_partial = n_anchor // per_thread + 1

    anchor_grid = (n_batch, n_anchor)
    partial_grid = (n_batch, n_partial)

    flag_buf = api.decl_buffer(anchor_grid, "int32", "temp_flag", data_alignment=8)
    idx_buf = api.decl_buffer(anchor_grid, "int32", "temp_idx", data_alignment=8)
    partial_buf = api.decl_buffer(
        partial_grid, "int32", "temp_partial", data_alignment=8)
    input_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)

    def _phase_one(ins, outs):
        return get_valid_counts_pre(
            ins[0], outs[0], outs[1], score_threshold, id_index, score_index)

    flags, idx = tvm.extern(
        [anchor_grid, anchor_grid],
        [data],
        _phase_one,
        dtype=["int32", "int32"],
        out_buffers=[flag_buf, idx_buf],
        name="get_valid_counts_phase_one")

    def _phase_two(ins, outs):
        return get_valid_counts_upsweep(ins[0], ins[1], outs[0], outs[1])

    idx_swept, partial = tvm.extern(
        [anchor_grid, partial_grid],
        [data, idx],
        _phase_two,
        dtype=["int32", "int32"],
        out_buffers=[idx_buf, partial_buf],
        name="get_valid_counts_phase_two")

    def _phase_three(ins, outs):
        return get_valid_counts_scan(ins[0], ins[1], outs[0])

    partial_scanned = tvm.extern(
        [partial_grid],
        [data, partial],
        _phase_three,
        dtype=["int32"],
        out_buffers=[partial_buf],
        name="get_valid_counts_phase_three")

    def _phase_four(ins, outs):
        return get_valid_counts_downsweep(ins[0], ins[1], ins[2], outs[0])

    idx_final = tvm.extern(
        [anchor_grid],
        [data, idx_swept, partial_scanned],
        _phase_four,
        dtype=["int32"],
        out_buffers=[idx_buf],
        name="get_valid_counts_phase_four")

    def _phase_five(ins, outs):
        return get_valid_counts_ir(ins[0], ins[1], ins[2], outs[0], outs[1])

    valid_count, out_tensor = tvm.extern(
        [(n_batch,), data.shape],
        [data, flags, idx_final],
        _phase_five,
        dtype=["int32", data.dtype],
        in_buffers=[input_buf, flag_buf, idx_buf],
        name="get_valid_counts_phase_five",
        tag="get_valid_counts_gpu")

    return [valid_count, out_tensor]
def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
                            iou_threshold=0.5, force_suppress=False, top_k=-1,
                            coord_start=2, score_index=1, id_index=0,
                            return_indices=True, invalid_to_bottom=False):
    """Non-maximum suppression operator for object detection.

    Parameters
    ----------
    data : tvm.Tensor
        3-D tensor with shape [batch_size, num_anchors, elem_length].
        The last dimension should be in format of
        [class_id, score, box_left, box_top, box_right, box_bottom].

    valid_count : tvm.Tensor
        1-D tensor for valid number of boxes.

    max_output_size : optional, int
        Max number of output valid boxes for each instance.
        By default all valid boxes are returned.

    iou_threshold : optional, float
        Non-maximum suppression threshold.

    force_suppress : optional, boolean
        Whether to suppress all detections regardless of class_id.

    top_k : optional, int
        Keep maximum top k detections before nms, -1 for no limit.

    coord_start : optional, int
        Start index of the consecutive 4 coordinates.

    score_index : optional, int
        Index of the scores/confidence of boxes.

    id_index : optional, int
        Index of the class categories, -1 to disable.

    return_indices : optional, boolean
        Whether to return box indices in input data. Takes precedence over
        ``invalid_to_bottom``.

    invalid_to_bottom : optional, boolean
        Whether to move all valid bounding boxes to the top.

    Returns
    -------
    out : tvm.Tensor
        When ``return_indices`` is true, the int32 box-index tensor with
        shape [batch_size, num_anchors]; otherwise a 3-D tensor with shape
        [batch_size, num_anchors, elem_length].

    Example
    --------
    .. code-block:: python

        # An example to use nms
        dshape = (1, 5, 6)
        data = tvm.placeholder(dshape, name="data")
        valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
        iou_threshold = 0.7
        force_suppress = True
        top_k = -1
        out = non_max_suppression(data=data, valid_count=valid_count,
                                  iou_threshold=iou_threshold,
                                  force_suppress=force_suppress, top_k=top_k,
                                  return_indices=False)
        np_data = np.random.uniform(size=dshape)
        np_valid_count = np.array([4])
        s = topi.generic.schedule_nms(out)
        f = tvm.build(s, [data, valid_count, out], "cuda")
        ctx = tvm.gpu(0)
        tvm_data = tvm.nd.array(np_data, ctx)
        tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
        tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
        f(tvm_data, tvm_valid_count, tvm_out)
    """
    n_batch = data.shape[0]
    n_anchor = data.shape[1]
    flag_dtype = "int32"

    count_buf = api.decl_buffer(
        valid_count.shape, flag_dtype, "valid_count_buf", data_alignment=4)

    # Extract the score column and argsort it in descending order.
    score_axis = score_index
    score_shape = (n_batch, n_anchor)
    scores = tvm.compute(
        score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
    order = argsort(scores, valid_count=valid_count, axis=1, is_ascend=False)

    order_buf = api.decl_buffer(
        order.shape, order.dtype, "sort_tensor_buf", data_alignment=8)
    boxes_buf = api.decl_buffer(
        data.shape, data.dtype, "data_buf", data_alignment=8)
    kept_buf = api.decl_buffer(
        data.shape, data.dtype, "out_buf", data_alignment=8)

    def _emit_nms(ins, outs):
        boxes_in, order_in, count_in = ins
        return nms_ir(
            boxes_in, order_in, count_in, outs[0], outs[1],
            max_output_size, iou_threshold, force_suppress, top_k,
            coord_start, id_index, score_index)

    out, box_indices = tvm.extern(
        [data.shape, score_shape],
        [data, order, valid_count],
        _emit_nms,
        dtype=[data.dtype, "int32"],
        in_buffers=[boxes_buf, order_buf, count_buf],
        name="nms",
        tag="nms")

    if return_indices:
        return box_indices
    if not invalid_to_bottom:
        return out

    # Optional post-pass that rearranges the nms result in two extern steps.
    final_buf = api.decl_buffer(
        data.shape, data.dtype, "output_buf", data_alignment=8)
    flag_buf = api.decl_buffer(
        score_shape, flag_dtype, "temp_flag", data_alignment=8)
    idx_buf = api.decl_buffer(
        score_shape, flag_dtype, "temp_idx", data_alignment=8)

    def _emit_pre(ins, outs):
        return invalid_to_bottom_pre(ins[0], outs[0], outs[1])

    flags, idx = tvm.extern(
        [score_shape, score_shape],
        [out],
        _emit_pre,
        dtype=["int32", "int32"],
        in_buffers=[kept_buf],
        out_buffers=[flag_buf, idx_buf],
        name="invalid_to_bottom_phase_one")

    def _emit_rearrange(ins, outs):
        return invalid_to_bottom_ir(ins[0], ins[1], ins[2], outs[0])

    return tvm.extern(
        [data.shape],
        [out, flags, idx],
        _emit_rearrange,
        dtype=[data.dtype],
        in_buffers=[kept_buf, flag_buf, idx_buf],
        out_buffers=[final_buf],
        name="invalid_to_bottom",
        tag="invalid_to_bottom")
def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
    """Sort along the given axis and return an array of indices, with the
    same shape as the input, that index the data in sorted order.

    Dispatches to the packed function ``tvm.contrib.sort.argsort_nms`` when
    ``flag`` is truthy and to ``tvm.contrib.sort.argsort`` otherwise.

    Parameters
    ----------
    data : tvm.Tensor
        The input tensor.

    valid_count : tvm.Tensor
        1-D tensor for valid number of boxes only for ssd.

    axis : optional, int
        Axis along which to sort the input tensor. Defaults to -1.

    is_ascend : optional, boolean
        Whether to sort in ascending or descending order.

    dtype : optional, string
        DType of the output indices (used when ``flag`` is falsy; the
        nms path always produces int32).

    flag : optional, boolean
        Whether valid_count is valid.

    Returns
    -------
    out : tvm.Tensor
        Sorted index tensor.

    Example
    --------
    .. code-block:: python

        # An example to use argsort
        dshape = (1, 5, 6)
        data = tvm.placeholder(dshape, name="data")
        valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
        axis = 0
        is_ascend = False
        flag = False
        out = argsort(data, valid_count, axis, is_ascend, flag=flag)
        np_data = np.random.uniform(size=dshape)
        np_valid_count = np.array([4])
        s = topi.generic.schedule_argsort(out)
        f = tvm.build(s, [data, valid_count, out], "llvm")
        ctx = tvm.cpu()
        tvm_data = tvm.nd.array(np_data, ctx)
        tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
        tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
        f(tvm_data, tvm_valid_count, tvm_out)
    """
    input_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)

    if not flag:
        index_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)

        def _call_argsort(ins, outs):
            return tvm.call_packed(
                "tvm.contrib.sort.argsort", ins[0], outs[0], axis, is_ascend)

        return tvm.extern(
            data.shape,
            [data],
            _call_argsort,
            dtype=dtype,
            in_buffers=[input_buf],
            out_buffers=index_buf,
            name="argsort_cpu",
            tag="argsort_cpu")

    count_buf = api.decl_buffer(
        valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4)
    index_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=8)

    def _call_argsort_nms(ins, outs):
        return tvm.call_packed(
            "tvm.contrib.sort.argsort_nms", ins[0], ins[1], outs[0], axis, is_ascend)

    return tvm.extern(
        data.shape,
        [data, valid_count],
        _call_argsort_nms,
        dtype="int32",
        in_buffers=[input_buf, count_buf],
        out_buffers=index_buf,
        name="argsort_nms_cpu",
        tag="argsort_nms_cpu")
def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True,
                               threshold=0.01, variances=(0.1, 0.1, 0.2, 0.2)):
    """Location transformation for multibox detection.

    Two extern steps: ``transform_loc_pre`` derives the valid count and three
    per-anchor temporaries from ``cls_prob``; ``transform_loc_ir`` then
    produces the output boxes from ``loc_pred``, ``anchor`` and those
    temporaries.

    Parameters
    ----------
    cls_prob : tvm.Tensor
        Class probabilities.

    loc_pred : tvm.Tensor
        Location regression predictions.

    anchor : tvm.Tensor
        Prior anchor boxes.

    clip : boolean
        Whether to clip out-of-boundary boxes.

    threshold : float
        Threshold to be a positive prediction.

    variances : tuple of float
        Variances to be decoded from box regression output.

    Returns
    -------
    ret : list of tvm.Tensor, composed of

    out : tvm.Tensor
        3-D tensor with shape (batch_size, num_anchors, 6)

    valid_count : tvm.Tensor
        1-D tensor with shape (batch_size,), number of valid anchor boxes.
    """
    n_batch = cls_prob.shape[0]
    n_class = cls_prob.shape[1]
    n_anchor = cls_prob.shape[2]
    count_dtype = "int32"
    count_shape = (n_batch,)
    box_shape = (n_batch, n_anchor, 6)
    # The temporaries are 1-D with one entry per anchor.
    tmp_shape = (n_anchor,)

    count_buf = api.decl_buffer(
        count_shape, count_dtype, "valid_count_buf", data_alignment=4)
    box_buf = api.decl_buffer(box_shape, cls_prob.dtype, "out_buf", data_alignment=8)
    flag_buf = api.decl_buffer(tmp_shape, count_dtype, "flag", data_alignment=8)
    cls_id_buf = api.decl_buffer(tmp_shape, count_dtype, "cls_id", data_alignment=8)
    score_buf = api.decl_buffer(tmp_shape, cls_prob.dtype, "score", data_alignment=8)

    def _first_step(ins, outs):
        return transform_loc_pre(ins[0], *outs, threshold)

    valid_count, flags, cls_ids, scores = tvm.extern(
        [count_shape, tmp_shape, tmp_shape, tmp_shape],
        [cls_prob],
        _first_step,
        dtype=[count_dtype, count_dtype, count_dtype, cls_prob.dtype],
        out_buffers=[count_buf, flag_buf, cls_id_buf, score_buf],
        tag="multibox_transform_loc_first_step")

    def _second_step(ins, outs):
        return transform_loc_ir(
            *ins, outs[0], clip, variances, n_batch, n_class, n_anchor)

    out = tvm.extern(
        [box_shape],
        [loc_pred, anchor, flags, cls_ids, scores],
        _second_step,
        dtype=[cls_prob.dtype],
        out_buffers=[box_buf],
        tag="multibox_transform_loc")

    return [out, valid_count]
def non_max_suppression(data, valid_count, max_output_size=-1,
                        iou_threshold=0.5, force_suppress=False, top_k=-1,
                        id_index=0, return_indices=True, invalid_to_bottom=False):
    """Non-maximum suppression operator for object detection.

    Scores (``data[:, :, 1]``) are sorted through the packed function
    ``tvm.contrib.sort.argsort`` and the resulting order is passed to
    ``hybrid_nms``.

    Parameters
    ----------
    data : tvm.Tensor
        3-D tensor with shape [batch_size, num_anchors, 6].
        The last dimension should be in format of
        [class_id, score, box_left, box_top, box_right, box_bottom].

    valid_count : tvm.Tensor
        1-D tensor for valid number of boxes.

    max_output_size : optional, int
        Max number of output valid boxes for each instance.
        By default all valid boxes are returned.

    iou_threshold : optional, float
        Non-maximum suppression threshold.

    force_suppress : optional, boolean
        Whether to suppress all detections regardless of class_id.

    top_k : optional, int
        Keep maximum top k detections before nms, -1 for no limit.

    id_index : optional, int
        Index of the class categories, -1 to disable.

    return_indices : optional, boolean
        Whether to return box indices in input data. Takes precedence over
        ``invalid_to_bottom``.

    invalid_to_bottom : optional, boolean
        Whether to move all valid bounding boxes to the top.

    Returns
    -------
    out : tvm.Tensor
        The box-index tensor when ``return_indices`` is true; otherwise a
        3-D tensor with shape [batch_size, num_anchors, 6].

    Example
    --------
    .. code-block:: python

        # An example to use non_max_suppression
        dshape = (1, 5, 6)
        data = tvm.placeholder(dshape, name="data")
        valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
        iou_threshold = 0.7
        force_suppress = True
        top_k = -1
        out = non_max_suppression(data, valid_count, iou_threshold=iou_threshold,
                                  force_suppress=force_suppress, top_k=top_k)
        np_data = np.random.uniform(size=dshape)
        np_valid_count = np.array([4])
        s = topi.generic.schedule_nms(out)
        f = tvm.build(s, [data, valid_count, out], "llvm")
        ctx = tvm.cpu()
        tvm_data = tvm.nd.array(np_data, ctx)
        tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
        tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
        f(tvm_data, tvm_valid_count, tvm_out)
    """
    n_batch = data.shape[0]
    n_anchor = data.shape[1]

    count_buf = api.decl_buffer(
        valid_count.shape, "int32", "valid_count_buf", data_alignment=4)

    # Score column lives at index 1 of the last dimension.
    score_axis = 1
    score_shape = (n_batch, n_anchor)
    scores = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
    scores_buf = api.decl_buffer(
        scores.shape, data.dtype, "score_tensor_buf", data_alignment=8)

    order_dtype = "int32"
    order_buf = api.decl_buffer(
        score_shape, order_dtype, "sort_tensor_buf", data_alignment=8)

    def _call_sort(ins, outs):
        return tvm.call_packed(
            "tvm.contrib.sort.argsort", ins[0], ins[1], outs[0], score_axis, True)

    order = tvm.extern(
        score_shape,
        [scores, valid_count],
        _call_sort,
        dtype=order_dtype,
        in_buffers=[scores_buf, count_buf],
        out_buffers=order_buf,
        name="nms_sort")

    out, box_indices = hybrid_nms(
        data,
        order,
        valid_count,
        tvm.const(max_output_size, dtype="int32"),
        tvm.const(iou_threshold, dtype="float32"),
        tvm.const(force_suppress, dtype="bool"),
        tvm.const(top_k, dtype="int32"),
        tvm.const(id_index, dtype="int32"))

    if return_indices:
        return box_indices
    if invalid_to_bottom:
        return hybrid_rearrange_out(out)
    return out
def get_valid_counts_gpu(data, score_threshold=0):
    """Get valid count of bounding boxes given a score threshold.
    Also moves valid boxes to the top of input data.

    The computation is expressed as five chained ``tvm.extern`` stages
    (``get_valid_counts_phase_one`` .. ``get_valid_counts_phase_five``).

    Parameters
    ----------
    data : tvm.Tensor
        Input data. 3-D tensor with shape [batch_size, num_anchors, elem_length].

    score_threshold : optional, float
        Lower limit of score for valid bounding boxes.

    Returns
    -------
    valid_count : tvm.Tensor
        1-D tensor for valid number of boxes.

    out_tensor : tvm.Tensor
        Rearranged data tensor.
    """
    n_batch = data.shape[0]
    n_boxes = data.shape[1]

    # Thread budget comes from the target in scope (allow_none=False).
    thread_limit = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    per_thread = n_boxes // thread_limit + 1
    n_partials = n_boxes // per_thread + 1

    box_shape = (n_batch, n_boxes)
    partial_shape = (n_batch, n_partials)

    flag_buf = api.decl_buffer(box_shape, "int32", "temp_flag",
                               data_alignment=8)
    # This one buffer object is the declared output of phases one, two and
    # four and an input of phase five.
    idx_buf = api.decl_buffer(box_shape, "int32", "temp_idx",
                              data_alignment=8)
    # Likewise shared as the output of phases two and three.
    partial_buf = api.decl_buffer(partial_shape, "int32", "temp_partial",
                                  data_alignment=8)
    data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf",
                               data_alignment=8)

    def _phase_one(ins, outs):
        return get_valid_counts_pre(ins[0], outs[0], outs[1], score_threshold)

    def _phase_two(ins, outs):
        return get_valid_counts_upsweep(ins[0], ins[1], outs[0], outs[1])

    def _phase_three(ins, outs):
        return get_valid_counts_scan(ins[0], ins[1], outs[0])

    def _phase_four(ins, outs):
        return get_valid_counts_downsweep(ins[0], ins[1], ins[2], outs[0])

    def _phase_five(ins, outs):
        return get_valid_counts_ir(ins[0], ins[1], ins[2], outs[0], outs[1])

    temp_flag, temp_idx = tvm.extern(
        [box_shape, box_shape],
        [data],
        _phase_one,
        dtype=["int32", "int32"],
        out_buffers=[flag_buf, idx_buf],
        name="get_valid_counts_phase_one")

    temp_idx_new, temp_partial = tvm.extern(
        [box_shape, partial_shape],
        [data, temp_idx],
        _phase_two,
        dtype=["int32", "int32"],
        out_buffers=[idx_buf, partial_buf],
        name="get_valid_counts_phase_two")

    temp_partial_new = tvm.extern(
        [partial_shape],
        [data, temp_partial],
        _phase_three,
        dtype=["int32"],
        out_buffers=[partial_buf],
        name="get_valid_counts_phase_three")

    temp_idx_final = tvm.extern(
        [box_shape],
        [data, temp_idx_new, temp_partial_new],
        _phase_four,
        dtype=["int32"],
        out_buffers=[idx_buf],
        name="get_valid_counts_phase_four")

    valid_count, out_tensor = tvm.extern(
        [(n_batch,), data.shape],
        [data, temp_flag, temp_idx_final],
        _phase_five,
        dtype=["int32", data.dtype],
        in_buffers=[data_buf, flag_buf, idx_buf],
        name="get_valid_counts_phase_five",
        tag="get_valid_counts_gpu")

    return [valid_count, out_tensor]
def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
                            iou_threshold=0.5, force_suppress=False, top_k=-1,
                            coord_start=2, score_index=1, id_index=0,
                            return_indices=True, invalid_to_bottom=False):
    """Non-maximum suppression operator for object detection.

    Parameters
    ----------
    data : tvm.Tensor
        3-D tensor with shape [batch_size, num_anchors, elem_length].
        The last dimension should be in format of
        [class_id, score, box_left, box_top, box_right, box_bottom].

    valid_count : tvm.Tensor
        1-D tensor for valid number of boxes.

    max_output_size : optional, int
        Max number of output valid boxes for each instance.
        By default all valid boxes are returned.

    iou_threshold : optional, float
        Non-maximum suppression threshold.

    force_suppress : optional, boolean
        Whether to suppress all detections regardless of class_id.

    top_k : optional, int
        Keep maximum top k detections before nms, -1 for no limit.

    coord_start : optional, int
        Start index of the consecutive 4 coordinates.

    score_index : optional, int
        Index of the scores/confidence of boxes.

    id_index : optional, int
        Index of the class categories, -1 to disable.

    return_indices : optional, boolean
        Whether to return box indices in input data.

    invalid_to_bottom : optional, boolean
        When True (and ``return_indices`` is False) the nms output is run
        through the two ``invalid_to_bottom`` extern stages before being
        returned.

    Returns
    -------
    out : tvm.Tensor
        The box-index tensor of shape [batch_size, num_anchors] when
        ``return_indices`` is True; otherwise a 3-D tensor with shape
        [batch_size, num_anchors, elem_length].

    Example
    --------
    .. code-block:: python

        # An example to use nms
        dshape = (1, 5, 6)
        data = tvm.placeholder(dshape, name="data")
        valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
        iou_threshold = 0.7
        force_suppress = True
        top_k = -1
        out = non_max_suppression(data=data, valid_count=valid_count,
                                  iou_threshold=iou_threshold,
                                  force_suppress=force_suppress, top_k=top_k,
                                  return_indices=False)
        np_data = np.random.uniform(size=dshape).astype(data.dtype)
        np_valid_count = np.array([4]).astype(valid_count.dtype)
        s = topi.generic.schedule_nms(out)
        f = tvm.build(s, [data, valid_count, out], "cuda")
        ctx = tvm.gpu(0)
        tvm_data = tvm.nd.array(np_data, ctx)
        tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
        tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
        f(tvm_data, tvm_valid_count, tvm_out)
    """
    n_batch = data.shape[0]
    n_boxes = data.shape[1]
    scores_shape = (n_batch, n_boxes)
    count_dtype = "int32"

    count_buf = api.decl_buffer(valid_count.shape, count_dtype,
                                "valid_count_buf", data_alignment=4)

    score_axis = score_index
    # Keep the lambda argument names: tvm.compute derives axis names from them.
    score_tensor = tvm.compute(scores_shape,
                               lambda i, j: data[i, j, score_axis])
    sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1,
                          is_ascend=False, flag=True)

    order_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
                                "sort_tensor_buf", data_alignment=8)
    data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf",
                               data_alignment=8)
    # Only consumed (as an input buffer) by the invalid_to_bottom stages.
    out_buf = api.decl_buffer(data.shape, data.dtype, "out_buf",
                              data_alignment=8)

    def _nms(ins, outs):
        return nms_ir(ins[0], ins[1], ins[2], outs[0], outs[1],
                      max_output_size, iou_threshold, force_suppress,
                      top_k, coord_start, id_index)

    out, box_indices = tvm.extern(
        [data.shape, scores_shape],
        [data, sort_tensor, valid_count],
        _nms,
        dtype=[data.dtype, "int32"],
        in_buffers=[data_buf, order_buf, count_buf],
        name="nms",
        tag="nms")

    if return_indices:
        return box_indices
    if not invalid_to_bottom:
        return out

    rearranged_buf = api.decl_buffer(data.shape, data.dtype, "output_buf",
                                     data_alignment=8)
    flag_buf = api.decl_buffer(scores_shape, count_dtype, "temp_flag",
                               data_alignment=8)
    idx_buf = api.decl_buffer(scores_shape, count_dtype, "temp_idx",
                              data_alignment=8)

    def _rearrange_pre(ins, outs):
        return invalid_to_bottom_pre(ins[0], outs[0], outs[1])

    def _rearrange(ins, outs):
        return invalid_to_bottom_ir(ins[0], ins[1], ins[2], outs[0])

    temp_flag, temp_idx = tvm.extern(
        [scores_shape, scores_shape],
        [out],
        _rearrange_pre,
        dtype=["int32", "int32"],
        in_buffers=[out_buf],
        out_buffers=[flag_buf, idx_buf],
        name="invalid_to_bottom_phase_one")

    output = tvm.extern(
        [data.shape],
        [out, temp_flag, temp_idx],
        _rearrange,
        dtype=[data.dtype],
        in_buffers=[out_buf, flag_buf, idx_buf],
        out_buffers=[rearranged_buf],
        name="invalid_to_bottom",
        tag="invalid_to_bottom")
    return output
def nms_gpu(data, valid_count, nms_threshold=0.5, force_suppress=False,
            nms_topk=-1):
    """Non-maximum suppression operator for object detection.

    Scores are extracted from column 1 of the last dimension, ranked with
    ``sort_gpu`` and the ranking is consumed by ``nms_ir``.

    Parameters
    ----------
    data : tvm.Tensor
        3-D tensor with shape [batch_size, num_anchors, 6].
        The last dimension should be in format of
        [class_id, score, box_left, box_top, box_right, box_bottom].

    valid_count : tvm.Tensor
        1-D tensor for valid number of boxes.

    nms_threshold : float
        Non-maximum suppression threshold.

    force_suppress : boolean
        Whether to suppress all detections regardless of class_id.

    nms_topk : int
        Keep maximum top k detections before nms, -1 for no limit.

    Returns
    -------
    out : tvm.Tensor
        3-D tensor with shape [batch_size, num_anchors, 6]; the output dtype
        is fixed to float32.

    Example
    --------
    .. code-block:: python

        # An example to use nms
        dshape = (1, 5, 6)
        data = tvm.placeholder(dshape, name="data")
        valid_count = tvm.placeholder(
            (dshape[0],), dtype="int32", name="valid_count")
        nms_threshold = 0.7
        force_suppress = True
        nms_topk = -1
        out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
        np_data = np.random.uniform(size=dshape).astype(data.dtype)
        np_valid_count = np.array([4]).astype(valid_count.dtype)
        s = topi.generic.schedule_nms(out)
        f = tvm.build(s, [data, valid_count, out], "llvm")
        ctx = tvm.cpu()
        tvm_data = tvm.nd.array(np_data, ctx)
        tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
        tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
        f(tvm_data, tvm_valid_count, tvm_out)
    """
    n_batch = data.shape[0]
    n_boxes = data.shape[1]
    scores_shape = (n_batch, n_boxes)
    index_dtype = "int32"

    count_buf = api.decl_buffer(valid_count.shape, index_dtype,
                                "valid_count_buf", data_alignment=4)
    data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf",
                               data_alignment=8)

    # Used both as the score column of `data` and as the sort axis.
    score_axis = 1
    # Keep the lambda argument names: tvm.compute derives axis names from them.
    score_tensor = tvm.compute(scores_shape,
                               lambda i, j: data[i, j, score_axis],
                               name="score_tensor")
    score_buf = api.decl_buffer(score_tensor.shape, data.dtype,
                                "score_tensor_buf", data_alignment=8)
    order_buf = api.decl_buffer(scores_shape, index_dtype,
                                "sort_tensor_buf", data_alignment=8)

    sort_tensor = sort_gpu(score_tensor, score_buf, valid_count, count_buf,
                           order_buf, score_axis, True)

    def _nms(ins, outs):
        return nms_ir(ins[0], ins[1], ins[2], outs[0],
                      nms_threshold, force_suppress, nms_topk)

    return tvm.extern(data.shape,
                      [data, sort_tensor, valid_count],
                      _nms,
                      dtype="float32",
                      in_buffers=[data_buf, order_buf, count_buf],
                      tag="nms")
def sort_gpu(data, data_buf, index, index_buf, output_buf, axis, is_descend):
    """Function to generate low level IR to do sorting on the GPU, use it by calling sort_gpu.

    The sort is assembled from four chained ``tvm.extern`` stages tagged
    ``sorting_prepare``, ``sorting_data``, ``sorting_oet`` and
    ``sorting_output``.

    Parameters
    ----------
    data : tvm.Tensor
        Tensor holding the values to sort along ``axis``. Must have dtype
        float32. (``nms_gpu`` passes the 2-D score tensor of shape
        [batch_size, num_anchors].)

    data_buf : Buffer
        Buffer bound to ``data`` in every stage that reads it.

    index : tvm.Tensor
        1-D tensor for valid number of boxes. Its dtype is also used for all
        index buffers and for the result.

    index_buf : Buffer
        Buffer of number of valid number of boxes.

    output_buf : Buffer
        Output buffer of indices of sorted tensor.

    axis : int
        The axis used for sorting. Negative values count from the last
        dimension; the value must lie in ``[-ndim, ndim)``.

    is_descend : bool
        If the sorted data is in descending order.

    Returns
    -------
    out : tvm.Tensor
        Tensor with the same shape as ``data`` and the dtype of ``index``.

    Raises
    ------
    AssertionError
        If ``data`` is not float32 or ``axis`` is outside ``[-ndim, ndim)``.
    """
    ndim = len(data.shape)
    assert data.dtype == "float32", "Currently only supports input dtype to be float32"

    # Normalize a negative axis *before* range checking so that values below
    # -ndim are rejected instead of silently producing a still-negative axis
    # (which would fold every dimension into axis_mul_after below).
    if axis < 0:
        axis = ndim + axis
    assert 0 <= axis < ndim, "Axis out of boundary for input ndim %d" % ndim

    # Product of the extents on either side of the sort axis.
    axis_mul_before = 1
    axis_mul_after = 1
    for i in range(ndim):
        if i < axis:
            axis_mul_before *= data.shape[i]
        elif i > axis:
            axis_mul_after *= data.shape[i]

    # dshape: number of independent slices to sort.
    # fshape: total number of elements (slices * extent of the sort axis).
    dshape = axis_mul_before * axis_mul_after
    fshape = data.shape[axis] * dshape

    loc_buf = api.decl_buffer(dshape, index.dtype, "sizes", data_alignment=8)
    new_index_buf = api.decl_buffer(
        fshape, index.dtype, "index_new", data_alignment=8)
    out_index_buf = api.decl_buffer(
        fshape, index.dtype, "index_out", data_alignment=8)
    new_data_buf = api.decl_buffer(
        dshape, data.dtype, "data_new", data_alignment=8)

    loc = tvm.extern(
        [(dshape,)],
        [index],
        lambda ins, outs: sort_pre_ir(
            ins[0], outs[0], axis_mul_before, axis_mul_after),
        dtype=[index.dtype],
        in_buffers=index_buf,
        out_buffers=[loc_buf],
        tag="sorting_prepare")

    data_new, index_new = tvm.extern(
        [(dshape,), (fshape,)],
        [data, index, loc],
        lambda ins, outs: sort_pre_ir_data(
            ins[0], ins[1], ins[2], outs[0], outs[1], axis,
            axis_mul_before, axis_mul_after),
        dtype=[data.dtype, index.dtype],
        in_buffers=[data_buf, index_buf, loc_buf],
        out_buffers=[new_data_buf, new_index_buf],
        tag="sorting_data")

    index_out = tvm.extern(
        [(fshape,)],
        [data, index, data_new, index_new, loc],
        lambda ins, outs: sort_oet_ir(
            ins[0], ins[1], ins[2], ins[3], ins[4], outs[0],
            axis_mul_before, axis_mul_after, axis, is_descend),
        dtype=[index.dtype],
        in_buffers=[data_buf, index_buf,
                    new_data_buf, new_index_buf, loc_buf],
        out_buffers=[out_index_buf],
        tag="sorting_oet")

    out = tvm.extern(
        [data.shape],
        [data, index, index_out, loc],
        lambda ins, outs: sort_ir_out(
            ins[0], ins[1], ins[2], ins[3], outs[0],
            axis_mul_before, axis_mul_after, axis),
        dtype=[index.dtype],
        in_buffers=[data_buf, index_buf, out_index_buf, loc_buf],
        out_buffers=output_buf,
        tag="sorting_output")
    return out
def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1):
    """Get valid count of bounding boxes given a score threshold.
    Also moves valid boxes to the top of input data.

    Built from three chained ``tvm.extern`` stages: ``get_valid_counts``,
    ``flag_scan`` and ``out_rewrite``.

    Parameters
    ----------
    data : tvm.Tensor
        Input data. 3-D tensor with shape [batch_size, num_anchors, elem_length].

    score_threshold : optional, float
        Lower limit of score for valid bounding boxes.

    id_index : optional, int
        Index of the class categories, -1 to disable.

    score_index : optional, int
        Index of the scores/confidence of boxes.

    Returns
    -------
    valid_count : tvm.Tensor
        1-D tensor for valid number of boxes.

    out_tensor : tvm.Tensor
        Rearranged data tensor.
    """
    n_batch = data.shape[0]
    n_boxes = data.shape[1]
    count_shape = (n_batch,)
    box_shape = (n_batch, n_boxes)

    data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf",
                               data_alignment=8)
    count_buf = api.decl_buffer(count_shape, "int32", "valid_count_buf",
                                data_alignment=8)
    flag_buf = api.decl_buffer(box_shape, "int32", "temp_flag",
                               data_alignment=8)
    partial_buf = api.decl_buffer(box_shape, "int32", "temp_partial",
                                  data_alignment=8)
    out_buf = api.decl_buffer(data.shape, data.dtype, "out_buf",
                              data_alignment=8)

    def _count(ins, outs):
        return get_valid_counts_ir(ins[0], outs[0], outs[1],
                                   score_threshold, id_index, score_index)

    def _scan(ins, outs):
        return flag_scan(ins[0], outs[0])

    def _rewrite(ins, outs):
        return out_rewrite(ins[0], ins[1], ins[2], ins[3], outs[0])

    valid_count, temp_flag = tvm.extern(
        [count_shape, box_shape],
        [data],
        _count,
        dtype=["int32", "int32"],
        in_buffers=[data_buf],
        out_buffers=[count_buf, flag_buf],
        name="get_valid_counts",
        tag="get_valid_counts_gpu")

    temp_partial = tvm.extern(
        [box_shape],
        [temp_flag],
        _scan,
        dtype=["int32"],
        in_buffers=[flag_buf],
        out_buffers=[partial_buf],
        name="flag_scan")

    out = tvm.extern(
        [data.shape],
        [data, temp_flag, temp_partial, valid_count],
        _rewrite,
        dtype=[data.dtype],
        in_buffers=[data_buf, flag_buf, partial_buf, count_buf],
        out_buffers=[out_buf],
        name="out_rewrite")

    return [valid_count, out]
def topk_gpu(data, k=1, axis=-1, ret_type="both", is_ascend=False,
             dtype="int64"):
    """Get the top k elements in an input tensor along the given axis.

    The whole tensor is sorted along ``axis`` with ``sort_ir``; when
    ``k >= 1`` the sorted result is then cut down to the first ``k`` entries
    along that axis with ``strided_slice``.

    Parameters
    ----------
    data : tvm.Tensor
        The input tensor.

    k : int, optional
        Number of top elements to select. Return all elements if k < 1.

    axis : int, optional
        Axis along which to sort the input tensor. Negative values count
        from the last dimension.

    ret_type : str, optional
        The return type [both, values, indices].
        "both": return both top k data and indices.
        "values": return top k data only.
        "indices": return top k indices only.

    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.

    dtype : string, optional
        The data type of the indices output.

    Returns
    -------
    out : tvm.Tensor or List[tvm.Tensor]
        The computed result. When ``k >= 1`` the result is always a list
        (one element for "values"/"indices", two for "both"); when ``k < 1``
        the unsliced output of the sort stage is returned (only the index
        tensor for "indices").
    """
    assert ret_type in ["both", "values", "indices"]
    ndim = len(data.shape)
    if axis < 0:
        axis = axis + ndim
    assert 0 <= axis < ndim

    values_buf = api.decl_buffer(data.shape, data.dtype, "values_buf",
                                 data_alignment=8)
    indices_buf = api.decl_buffer(data.shape, dtype, "indices_buf",
                                  data_alignment=8)

    def _sort_values(ins, outs):
        return sort_ir(ins[0], outs[0], axis, is_ascend)

    def _sort_values_and_indices(ins, outs):
        return sort_ir(ins[0], outs[0], axis, is_ascend, indices_out=outs[1])

    if ret_type == "values":
        # Single-output stage: indices are not materialized at all.
        output = tvm.extern([data.shape],
                            [data],
                            _sort_values,
                            out_buffers=[values_buf],
                            name="topk_gpu",
                            tag="topk_gpu")
    else:
        output = tvm.extern([data.shape, data.shape],
                            [data],
                            _sort_values_and_indices,
                            out_buffers=[values_buf, indices_buf],
                            name="topk_gpu",
                            tag="topk_gpu")

    if k < 1:
        # No truncation requested.
        return output[1] if ret_type == "indices" else output

    # Slice [0, k) along the sort axis and the full extent everywhere else.
    slice_begin = [0] * ndim
    slice_end = [k if dim == axis else data.shape[dim] for dim in range(ndim)]

    if ret_type == "both":
        sorted_values, sorted_indices = output
        top_values = strided_slice(sorted_values, slice_begin, slice_end)
        top_indices = strided_slice(sorted_indices, slice_begin, slice_end)
        return [top_values, top_indices]
    if ret_type == "values":
        return [strided_slice(output, slice_begin, slice_end)]
    # ret_type == "indices"
    return [strided_slice(output[1], slice_begin, slice_end)]