Beispiel #1
0
def nms_ir(data, sorted_index, valid_count, out, box_indices, max_output_size,
           iou_threshold, force_suppress, top_k, coord_start, id_index,
           score_index):
    """Low level IR routine for non-maximum suppression (NMS).

    The generated IR reorders the boxes of every batch according to
    ``sorted_index``, optionally truncates them to ``top_k``, suppresses
    lower-ranked boxes that overlap a higher-ranked one, pads unused slots
    with -1 and finally applies the ``max_output_size`` limit.

    Parameters
    ----------
    data : Buffer
        Buffer of output boxes with class and score.
        Indexed as (batch, anchor, box_data_length).

    sorted_index : Buffer
        Buffer of output box indexes sorted by score.

    valid_count : Buffer
        Buffer of number of valid output boxes.

    out : Buffer
        Output buffer.

    box_indices : Buffer
        Output buffer that receives, per output slot, the index of the
        source box, or -1 for suppressed / invalid slots.

    max_output_size : int
        Max number of output valid boxes for each instance.
        By default all valid boxes are returned.

    iou_threshold : float
        Overlapping(IoU) threshold to suppress object with smaller score.

    force_suppress : boolean
        Whether to suppress all detections regardless of class_id.

    top_k : int
        Keep maximum top k detections before nms, -1 for no limit.

    coord_start : int
        Start index of the consecutive 4 coordinates.

    id_index : int
        index of the class categories, -1 to disable.

    score_index : int
        Index of the scores/confidence of boxes.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
        """Return the IoU expression of two boxes stored in ``out_tensor``.

        Each index points at 4 consecutive values; elements +0/+1 are used
        as the lower corner and +2/+3 as the upper corner.
        """
        # Intersection extent, clamped at zero for disjoint boxes.
        w = tvm.max(
            0.0,
            tvm.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2]) -
            tvm.max(out_tensor[box_a_idx], out_tensor[box_b_idx]))
        h = tvm.max(
            0.0,
            tvm.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3]) -
            tvm.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]))
        i = w * h
        # Union = area(a) + area(b) - intersection.
        u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \
            (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
            (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \
            (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
        # Guard against division by a non-positive union.
        return tvm.expr.Select(u <= 0.0, 0.0, i / u)

    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    box_data_length = data.shape[2]

    ib = tvm.ir_builder.create()

    data = ib.buffer_ptr(data)
    sorted_index = ib.buffer_ptr(sorted_index)
    valid_count = ib.buffer_ptr(valid_count)
    out = ib.buffer_ptr(out)
    box_indices = ib.buffer_ptr(box_indices)
    # Single-element counter used by the max_output_size pass below.
    num_valid_boxes = ib.allocate("int32", (1, ),
                                  name="num_valid_boxes",
                                  scope="local")

    max_threads = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    # Enough blocks so that nthread_bx * nthread_tx exceeds num_anchors.
    nthread_bx = num_anchors // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    # Flattened thread index; used below as the box slot handled by a thread.
    j = bx * max_threads + tx

    # Wrap the Python scalars as IR immediates so they can appear in IR
    # expressions.
    iou_threshold = tvm.make.node("FloatImm",
                                  dtype="float32",
                                  value=iou_threshold)
    top_k = tvm.make.node("IntImm", dtype="int32", value=top_k)
    coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start)
    id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
    score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
    force_suppress = tvm.make.node("IntImm",
                                   dtype="int32",
                                   value=1 if force_suppress else 0)

    with ib.for_range(0, batch_size, for_type="unroll") as i:
        # Offset of the first element of batch i in the flattened buffers.
        base_idx = i * num_anchors * box_data_length
        with ib.if_scope(tvm.all(iou_threshold > 0, valid_count[i] > 0)):
            # Reorder output: slot j receives the box sorted_index ranks j-th.
            # nkeep is top_k when 0 < top_k < valid_count[i], otherwise
            # valid_count[i].
            nkeep = if_then_else( \
                    tvm.all(top_k > 0, top_k < valid_count[i]),
                    top_k, valid_count[i])
            with ib.if_scope(j < nkeep):
                with ib.for_range(0, box_data_length) as k:
                    out[(base_idx + j * box_data_length + k)] = \
                    data[(base_idx + sorted_index[i * num_anchors + j] \
                    * box_data_length + k)]
                box_indices[i * num_anchors +
                            j] = sorted_index[i * num_anchors + j]
            # When top_k truncates, fill slots nkeep .. valid_count[i]-1
            # with -1.
            with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])):
                with ib.if_scope(j < valid_count[i] - nkeep):
                    with ib.for_range(0, box_data_length) as k:
                        out[(base_idx + (j + nkeep) * box_data_length +
                             k)] = -1.0
                    box_indices[i * num_anchors + (j + nkeep)] = -1
            # Apply nms: k walks the boxes in order; for each still-valid
            # box k, the thread owning a later slot j (j > k) tests its own
            # box against it.
            with ib.for_range(0, valid_count[i]) as k:
                offset_k = k * box_data_length
                # Box k only suppresses others while its score is positive
                # and (if class ids are enabled) its class id is not negative.
                with ib.if_scope(tvm.all(out[base_idx + offset_k + score_index] > 0, \
                    tvm.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0))):
                    with ib.if_scope(j < valid_count[i]):
                        offset_j = j * box_data_length
                        # Box j is a candidate when it is still valid and,
                        # unless force_suppress is set or class ids are
                        # disabled, shares box k's class id.
                        with ib.if_scope(tvm.all(j > k, \
                            out[base_idx + offset_j + score_index] > 0, \
                                                 tvm.any(id_index < 0, \
                                                    out[base_idx + offset_j + id_index] >= 0), \
       tvm.any(force_suppress > 0, id_index < 0, \
                                                         out[base_idx + offset_k + id_index] == \
                                                         out[base_idx + offset_j + id_index]))):
                            iou = calculate_overlap(
                                out, base_idx + offset_j + coord_start,
                                base_idx + offset_k + coord_start)
                            # Suppress box j: mark score, class id (when
                            # enabled) and its source index as -1.
                            with ib.if_scope(iou >= iou_threshold):
                                out[base_idx + offset_j + score_index] = -1.0
                                with ib.if_scope(id_index >= 0):
                                    out[base_idx + offset_j + id_index] = -1.0
                                box_indices[i * num_anchors + j] = -1
        with ib.else_scope():
            # NMS disabled (iou_threshold <= 0) or no valid boxes: copy the
            # valid boxes through unchanged and record identity indices.
            with ib.if_scope(j < valid_count[i]):
                offset_j = j * box_data_length
                with ib.for_range(0, box_data_length) as k:
                    out[(base_idx + offset_j + k)] = data[base_idx + offset_j +
                                                          k]
                box_indices[i * num_anchors + j] = j
        # Set invalid entry to be -1 (slots valid_count[i] .. num_anchors-1).
        with ib.if_scope(j < num_anchors - valid_count[i]):
            with ib.for_range(0, box_data_length) as k:
                out[base_idx + (j + valid_count[i]) * box_data_length +
                    k] = -1.0
            box_indices[i * num_anchors + j + valid_count[i]] = -1
        # Only return max_output_size number of valid boxes
        # NOTE(review): num_valid_boxes is allocated with scope="local" while
        # j is a per-thread index, so each thread presumably sees its own
        # counter and inspects a single box - confirm that the limit is
        # actually enforced across threads.
        num_valid_boxes[0] = 0
        with ib.if_scope(max_output_size > 0):
            with ib.if_scope(j < valid_count[i]):
                offset_j = j * box_data_length
                # NOTE(review): validity is tested on element 0 of the box
                # here, not on score_index / id_index as in the NMS pass -
                # confirm this is intended.
                with ib.if_scope(out[base_idx + offset_j] >= 0):
                    with ib.if_scope(num_valid_boxes[0] == max_output_size):
                        with ib.for_range(0, box_data_length) as k:
                            out[base_idx + offset_j + k] = -1.0
                        box_indices[i * num_anchors + j] = -1
                    with ib.else_scope():
                        num_valid_boxes[0] += 1

    return ib.get()
Beispiel #2
0
def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
    """Build the IR that writes multibox prior (anchor) boxes into ``out``.

    One thread is assigned to every (row, column) cell of the input feature
    map; it emits ``len(sizes) + len(ratios) - 1`` boxes, each stored as four
    consecutive values ``(center_w - w, center_h - h, center_w + w,
    center_h + h)``.

    Parameters
    ----------
    data : Buffer
        Input data buffer. Only ``shape[2]`` (height) and ``shape[3]``
        (width) are read.

    out : Buffer
        Output buffer.

    sizes : tuple of float
        Tuple of sizes for anchor boxes.

    ratios : tuple of float
        Tuple of ratios for anchor boxes.

    steps : Tuple of float
        Priorbox step across y and x, -1 for auto calculation.

    offsets : tuple of int
        Priorbox center offsets, y and x respectively.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    current = tvm.target.Target.current(allow_none=False)
    # Use a square thread block: sqrt(max threads) along each axis.
    side = int(math.sqrt(current.max_num_threads))

    thread_x = tvm.thread_axis("threadIdx.x")
    thread_y = tvm.thread_axis("threadIdx.y")
    block_x = tvm.thread_axis("blockIdx.x")
    block_y = tvm.thread_axis("blockIdx.y")

    builder = tvm.ir_builder.create()
    out_ptr = builder.buffer_ptr(out)

    in_height = data.shape[2]
    in_width = data.shape[3]

    # Rows are spread over the x axes, columns over the y axes.
    blocks_x = in_height // side + 1
    blocks_y = in_width // side + 1
    builder.scope_attr(thread_x, "thread_extent", side)
    builder.scope_attr(thread_y, "thread_extent", side)
    builder.scope_attr(block_x, "thread_extent", blocks_x)
    builder.scope_attr(block_y, "thread_extent", blocks_y)

    num_sizes = len(sizes)
    num_ratios = len(ratios)
    size_ratio_concat = sizes + ratios

    # A non-positive step means "derive it from the feature map extent".
    if steps[0] > 0:
        steps_h = steps[0]
    else:
        steps_h = 1.0 / in_height
    if steps[1] > 0:
        steps_w = steps[1]
    else:
        steps_w = 1.0 / in_width
    offset_h, offset_w = offsets[0], offsets[1]

    boxes_per_cell = num_sizes + num_ratios - 1
    row = block_x * side + thread_x
    col = block_y * side + thread_y

    with builder.if_scope((row < in_height)):
        with builder.if_scope((col < in_width)):
            center_h = (row + offset_h) * steps_h
            center_w = (col + offset_w) * steps_w

            for box in range(boxes_per_cell):
                # The first num_sizes boxes come from sizes, the rest from
                # the first size combined with a ratio.
                from_size = box < num_sizes

                w_from_size = \
                    float(size_ratio_concat[box]) * in_height / in_width / 2.0
                w_from_ratio = \
                    float(size_ratio_concat[0]) * in_height / in_width * \
                    math.sqrt(size_ratio_concat[box + 1]) / 2.0
                half_w = if_then_else(from_size, w_from_size, w_from_ratio)

                h_from_size = size_ratio_concat[box] / 2.0
                h_from_ratio = size_ratio_concat[0] / \
                    math.sqrt(size_ratio_concat[box + 1]) / 2.0
                half_h = if_then_else(from_size, h_from_size, h_from_ratio)

                # Flat offset of this box: 4 values per box.
                base = (row * in_width * boxes_per_cell +
                        col * boxes_per_cell + box) * 4
                out_ptr[base] = center_w - half_w
                out_ptr[base + 1] = center_h - half_h
                out_ptr[base + 2] = center_w + half_w
                out_ptr[base + 3] = center_h + half_h

    return builder.get()
Beispiel #3
0
def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id,
                      temp_score, threshold):
    """Low level IR routing for transform location data preparation.

    For every (batch, anchor) pair the generated IR picks the best scoring
    class among classes ``1 .. num_classes - 1``, discards it when the score
    is below ``threshold`` and then accumulates, per batch, a running count
    of the anchors that kept a class.

    Parameters
    ----------
    cls_prob : Buffer
        Buffer of class probabilities, indexed as (batch, class, anchor).

    valid_count : Buffer
        Buffer of number of valid output boxes (one entry per batch).

    temp_valid_count : Buffer
        Output intermediate result buffer. After the scan, entry
        ``b * num_anchors + a`` holds the number of valid anchors of batch
        ``b`` with index ``<= a``.

    temp_cls_id : Buffer
        Output intermediate result buffer holding the selected class id per
        anchor (0 when no class was kept).

    temp_score : Buffer
        Output buffer holding the best class score per anchor.

    threshold : float
        Threshold to be a positive prediction.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = cls_prob.shape[0]
    num_classes = cls_prob.shape[1]
    num_anchors = cls_prob.shape[2]

    ib = tvm.ir_builder.create()

    cls_prob = ib.buffer_ptr(cls_prob)
    cls_id = ib.buffer_ptr(temp_cls_id)
    valid_count = ib.buffer_ptr(valid_count)
    temp_valid_count = ib.buffer_ptr(temp_valid_count)
    score = ib.buffer_ptr(temp_score)

    threshold = tvm.make.node("FloatImm", dtype="float32", value=threshold)

    max_threads = int(
        tvm.target.Target.current(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    # One thread per (batch, anchor) pair.
    nthread_bx = (batch_size * num_anchors) // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx
    idxd = tvm.indexdiv
    idxm = tvm.indexmod

    with ib.if_scope(tid < batch_size * num_anchors):
        # Decompose the flat thread index into (batch i, anchor j).
        i = idxd(tid, num_anchors)
        j = idxm(tid, num_anchors)
        valid_count[i] = 0
        score[tid] = -1.0
        cls_id[tid] = 0
        # Arg-max over classes 1 .. num_classes - 1; class 0 is skipped
        # (presumably the background class).
        with ib.for_range(0, num_classes - 1) as k:
            temp = cls_prob[i * num_classes * num_anchors +
                            (k + 1) * num_anchors + j]
            cls_id[tid] = if_then_else(temp > score[tid], k + 1, cls_id[tid])
            score[tid] = tvm.max(temp, score[tid])
        # Drop predictions whose best score is below the threshold.
        with ib.if_scope(tvm.all(cls_id[tid] > 0, score[tid] < threshold)):
            cls_id[tid] = 0
        with ib.if_scope(cls_id[tid] > 0):
            temp_valid_count[tid] = 1
        with ib.else_scope():
            temp_valid_count[tid] = 0

        # Thread `tid` (for tid < batch_size) runs a sequential inclusive
        # prefix sum over the flags of batch `tid`.
        # NOTE(review): the scan reads flags written by other threads and no
        # barrier is emitted in this routine - confirm the required ordering
        # is guaranteed elsewhere.
        with ib.if_scope(tid < batch_size):
            with ib.for_range(0, num_anchors) as k:
                with ib.if_scope(k > 0):
                    temp_valid_count[tid * num_anchors + k] += \
                    temp_valid_count[tid * num_anchors + k - 1]
            # The scan above covers batch `tid`, so its total belongs in
            # valid_count[tid]. Indexing with i (= tid // num_anchors) would
            # send every total to entry 0 whenever tid < num_anchors.
            valid_count[tid] = temp_valid_count[tid * num_anchors +
                                                num_anchors - 1]

    return ib.get()
Beispiel #4
0
def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
    """Low level IR routing for multibox_prior operator.

    One thread is assigned to every (row, column) cell of the input feature
    map; it emits ``len(sizes) + len(ratios) - 1`` boxes, each stored as four
    consecutive values ``(center_w - w, center_h - h, center_w + w,
    center_h + h)``.

    Parameters
    ----------
    data : Buffer
        Input data buffer. Only ``shape[2]`` (height) and ``shape[3]``
        (width) are read.

    out : Buffer
        Output buffer.

    sizes : tuple of float
        Tuple of sizes for anchor boxes.

    ratios : tuple of float
        Tuple of ratios for anchor boxes.

    steps : Tuple of float
        Priorbox step across y and x, -1 for auto calculation.

    offsets : tuple of int
        Priorbox center offsets, y and x respectively.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    # Use a square thread block: sqrt(max threads) along each axis.
    max_threads = int(math.sqrt(
        tvm.target.current_target(allow_none=False).max_num_threads))
    tx = tvm.thread_axis("threadIdx.x")
    ty = tvm.thread_axis("threadIdx.y")
    bx = tvm.thread_axis("blockIdx.x")
    by = tvm.thread_axis("blockIdx.y")
    ib = tvm.ir_builder.create()
    p_out = ib.buffer_ptr(out)
    in_height = data.shape[2]
    in_width = data.shape[3]
    # Rows are spread over the x axes, columns over the y axes.
    nthread_tx = max_threads
    nthread_bx = in_height // max_threads + 1
    nthread_ty = max_threads
    nthread_by = in_width // max_threads + 1
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(ty, "thread_extent", nthread_ty)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    ib.scope_attr(by, "thread_extent", nthread_by)

    num_sizes = len(sizes)
    num_ratios = len(ratios)
    size_ratio_concat = sizes + ratios
    # A non-positive step means "derive it from the feature map extent".
    steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height
    steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width
    offset_h = offsets[0]
    offset_w = offsets[1]

    i = bx * max_threads + tx
    j = by * max_threads + ty
    with ib.if_scope((i < in_height)):
        with ib.if_scope((j < in_width)):
            center_h = (i + offset_h) * steps_h
            center_w = (j + offset_w) * steps_w

            for k in range(num_sizes + num_ratios - 1):
                # Convert the size to float before it meets the integer
                # shape expressions; otherwise an integer size would make
                # `size * in_height / in_width` integer arithmetic in the IR.
                w = if_then_else(
                    k < num_sizes,
                    float(size_ratio_concat[k]) * in_height / in_width / 2.0,
                    float(size_ratio_concat[0]) * in_height / in_width *
                    math.sqrt(size_ratio_concat[k + 1]) / 2.0)
                h = if_then_else(
                    k < num_sizes, size_ratio_concat[k] / 2.0,
                    size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0)
                # Flat offset of box k of cell (i, j): 4 values per box.
                count = (i * in_width * (num_sizes + num_ratios - 1) +
                         j * (num_sizes + num_ratios - 1) + k) * 4
                p_out[count] = center_w - w
                p_out[count + 1] = center_h - h
                p_out[count + 2] = center_w + w
                p_out[count + 3] = center_h + h

    body = ib.get()
    return body
Beispiel #5
0
def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp_score, threshold):
    """Low level IR routing for transform location data preparation.

    For every (batch, anchor) pair the generated IR picks the best scoring
    class among classes ``1 .. num_classes - 1``, discards it when the score
    is below ``threshold`` and then accumulates, per batch, a running count
    of the anchors that kept a class.

    Parameters
    ----------
    cls_prob : Buffer
        Buffer of class probabilities, indexed as (batch, class, anchor).

    valid_count : Buffer
        Buffer of number of valid output boxes (one entry per batch).

    temp_valid_count : Buffer
        Output intermediate result buffer. After the scan, entry
        ``b * num_anchors + a`` holds the number of valid anchors of batch
        ``b`` with index ``<= a``.

    temp_cls_id : Buffer
        Output intermediate result buffer holding the selected class id per
        anchor (0 when no class was kept).

    temp_score : Buffer
        Output buffer holding the best class score per anchor.

    threshold : float
        Threshold to be a positive prediction.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = cls_prob.shape[0]
    num_classes = cls_prob.shape[1]
    num_anchors = cls_prob.shape[2]

    ib = tvm.ir_builder.create()

    cls_prob = ib.buffer_ptr(cls_prob)
    cls_id = ib.buffer_ptr(temp_cls_id)
    valid_count = ib.buffer_ptr(valid_count)
    temp_valid_count = ib.buffer_ptr(temp_valid_count)
    score = ib.buffer_ptr(temp_score)

    threshold = tvm.make.node("FloatImm", dtype="float32", value=threshold)

    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    # One thread per (batch, anchor) pair.
    nthread_bx = (batch_size *  num_anchors) // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx

    with ib.if_scope(tid < batch_size * num_anchors):
        # Decompose the flat thread index into (batch i, anchor j).
        i = tid / num_anchors
        j = tid % num_anchors
        valid_count[i] = 0
        score[tid] = -1.0
        cls_id[tid] = 0
        # Arg-max over classes 1 .. num_classes - 1; class 0 is skipped
        # (presumably the background class).
        with ib.for_range(0, num_classes - 1) as k:
            temp = cls_prob[i * num_classes * num_anchors + (k + 1) * num_anchors + j]
            cls_id[tid] = if_then_else(temp > score[tid], k + 1, cls_id[tid])
            score[tid] = tvm.max(temp, score[tid])
        # Drop predictions whose best score is below the threshold.
        with ib.if_scope(tvm.all(cls_id[tid] > 0, score[tid] < threshold)):
            cls_id[tid] = 0
        with ib.if_scope(cls_id[tid] > 0):
            temp_valid_count[tid] = 1
        with ib.else_scope():
            temp_valid_count[tid] = 0

        # Thread `tid` (for tid < batch_size) runs a sequential inclusive
        # prefix sum over the flags of batch `tid`.
        # NOTE(review): the scan reads flags written by other threads and no
        # barrier is emitted in this routine - confirm the required ordering
        # is guaranteed elsewhere.
        with ib.if_scope(tid < batch_size):
            with ib.for_range(0, num_anchors) as k:
                with ib.if_scope(k > 0):
                    temp_valid_count[tid * num_anchors + k] += \
                    temp_valid_count[tid * num_anchors + k - 1]
            # The scan above covers batch `tid`, so its total belongs in
            # valid_count[tid]. Indexing with i (= tid / num_anchors) would
            # send every total to entry 0 whenever tid < num_anchors.
            valid_count[tid] = temp_valid_count[tid * num_anchors + num_anchors - 1]

    return ib.get()
Beispiel #6
0
Datei: nms.py Projekt: bddppq/tvm
def nms_ir(data, sorted_index, valid_count, out, box_indices,
           max_output_size, iou_threshold, force_suppress,
           top_k, coord_start, id_index):
    """Low level IR routine for non-maximum suppression (NMS).

    The generated IR reorders the boxes of every batch according to
    ``sorted_index``, optionally truncates them to ``top_k``, suppresses
    lower-ranked boxes that overlap a higher-ranked one, pads unused slots
    with -1 and finally applies the ``max_output_size`` limit.

    Parameters
    ----------
    data : Buffer
        Buffer of output boxes with class and score.
        Indexed as (batch, anchor, box_data_length).

    sorted_index : Buffer
        Buffer of output box indexes sorted by score.

    valid_count : Buffer
        Buffer of number of valid output boxes.

    out : Buffer
        Output buffer.

    box_indices : Buffer
        Output buffer that receives, per output slot, the index of the
        source box, or -1 for suppressed / invalid slots.

    max_output_size : int
        Max number of output valid boxes for each instance.
        By default all valid boxes are returned.

    iou_threshold : float
        Overlapping(IoU) threshold to suppress object with smaller score.

    force_suppress : boolean
        Whether to suppress all detections regardless of class_id.

    top_k : int
        Keep maximum top k detections before nms, -1 for no limit.

    coord_start : int
        Start index of the consecutive 4 coordinates.

    id_index : int
        index of the class categories, -1 to disable.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
        """Return the IoU expression of two boxes stored in ``out_tensor``.

        Each index points at 4 consecutive values; elements +0/+1 are used
        as the lower corner and +2/+3 as the upper corner.
        """
        # Intersection extent, clamped at zero for disjoint boxes.
        w = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
                    - tvm.max(out_tensor[box_a_idx], out_tensor[box_b_idx]))
        h = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
                    - tvm.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]))
        i = w * h
        # Union = area(a) + area(b) - intersection.
        u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \
            (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
            (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \
            (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
        # Guard against division by a non-positive union.
        return tvm.expr.Select(u <= 0.0, 0.0, i / u)

    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    box_data_length = data.shape[2]

    ib = tvm.ir_builder.create()

    data = ib.buffer_ptr(data)
    sorted_index = ib.buffer_ptr(sorted_index)
    valid_count = ib.buffer_ptr(valid_count)
    out = ib.buffer_ptr(out)
    box_indices = ib.buffer_ptr(box_indices)
    # Single-element counter used by the max_output_size pass below.
    num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")

    # The thread extent is the square root of the target's thread limit.
    max_threads = int(math.sqrt(
        tvm.target.current_target(allow_none=False).max_num_threads))
    nthread_tx = max_threads
    nthread_bx = num_anchors // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    # Flattened thread index. It plays two roles below: an element index
    # within one box (guards `k < box_data_length`) and a box index in the
    # NMS pass (guard `k < valid_count[i]`).
    k = bx * max_threads + tx

    # Wrap the Python scalars as IR immediates so they can appear in IR
    # expressions.
    iou_threshold = tvm.make.node("FloatImm", dtype="float32", value=iou_threshold)
    top_k = tvm.make.node("IntImm", dtype="int32", value=top_k)
    coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start)
    id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
    force_suppress = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0)

    with ib.for_range(0, batch_size, for_type="unroll") as i:
        # Offset of the first element of batch i in the flattened buffers.
        base_idx = i * num_anchors * box_data_length
        with ib.if_scope(tvm.all(iou_threshold > 0, valid_count[i] > 0)):
            # Reorder output: slot j receives the box sorted_index ranks j-th.
            # nkeep is top_k when 0 < top_k < valid_count[i], otherwise
            # valid_count[i].
            nkeep = if_then_else( \
                    tvm.all(top_k > 0, top_k < valid_count[i]),
                    top_k, valid_count[i])
            with ib.for_range(0, nkeep) as j:
                # Thread k copies element k of the box.
                with ib.if_scope(k < box_data_length):
                    out[(base_idx + j * box_data_length + k)] = \
                    data[(base_idx + sorted_index[i * num_anchors + j] \
                    * box_data_length + k)]
                # Not guarded by k: every thread performs this same write.
                box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
            # When top_k truncates, fill slots nkeep .. valid_count[i]-1
            # with -1.
            with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])):
                with ib.for_range(0, valid_count[i] - nkeep) as j:
                    with ib.if_scope(k < box_data_length):
                        out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
                    box_indices[i * num_anchors + (j + nkeep)] = -1
            # Apply nms: j walks the boxes in order; for each still-valid
            # box j, the thread owning a later box k (k > j) tests its own
            # box against it.
            # NOTE(review): element 0 of a box serves both as the validity
            # marker (>= 0) and as the value compared for class equality;
            # id_index is only consulted for its sign - confirm the class id
            # is expected at position 0.
            with ib.for_range(0, valid_count[i]) as j:
                offset_j = j * box_data_length
                with ib.if_scope(out[base_idx + offset_j] >= 0):
                    with ib.if_scope(k < valid_count[i]):
                        offset_k = k * box_data_length
                        with ib.if_scope(tvm.all(k > j, out[base_idx + offset_k] >= 0, \
						 tvm.any(force_suppress > 0, id_index < 0, \
                                                         out[base_idx + offset_j] == \
                                                         out[base_idx + offset_k]))):
                            iou = calculate_overlap(out, base_idx + offset_k + coord_start,
                                                    base_idx + offset_j + coord_start)
                            # Suppress box k: mark element 0 and its source
                            # index as -1.
                            with ib.if_scope(iou >= iou_threshold):
                                out[base_idx + offset_k] = -1.0
                                box_indices[i * num_anchors + k] = -1
                # Emit a tvm_storage_sync('shared') call after each
                # iteration of the j loop.
                ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                      tvm.convert(['shared']),
                                      tvm.expr.Call.Intrinsic, None, 0))
        with ib.else_scope():
            # NMS disabled (iou_threshold <= 0) or no valid boxes: copy the
            # valid boxes through unchanged and record identity indices.
            with ib.for_range(0, valid_count[i]) as j:
                offset_j = j * box_data_length
                with ib.if_scope(k < box_data_length):
                    out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
                box_indices[i * num_anchors + j] = j
        # Set invalid entry to be -1 (slots valid_count[i] .. num_anchors-1).
        with ib.for_range(0, num_anchors - valid_count[i]) as j:
            with ib.if_scope(k < box_data_length):
                out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
            box_indices[i * num_anchors + j + valid_count[i]] = -1
        # Only return max_output_size number of valid boxes
        # Walk the boxes in order, counting those whose element 0 is >= 0;
        # once the counter has reached max_output_size, further such boxes
        # are overwritten with -1.
        num_valid_boxes[0] = 0
        with ib.if_scope(max_output_size > 0):
            with ib.for_range(0, valid_count[i]) as j:
                offset_j = j * box_data_length
                with ib.if_scope(out[base_idx + offset_j] >= 0):
                    with ib.if_scope(num_valid_boxes[0] == max_output_size):
                        with ib.if_scope(k < box_data_length):
                            out[base_idx + offset_j + k] = -1.0
                        box_indices[i * num_anchors + j] = -1
                    with ib.else_scope():
                        num_valid_boxes[0] += 1
                # Emit a tvm_storage_sync('shared') call after each
                # iteration of the j loop.
                ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                      tvm.convert(['shared']),
                                      tvm.expr.Call.Intrinsic, None, 0))

    return ib.get()