def _generate_detections(boxes,
                         scores,
                         max_total_size=100,
                         nms_iou_threshold=0.3,
                         score_threshold=0.05,
                         pre_nms_num_boxes=5000):
    """Generate the final detections given the model outputs.

    Classes are unrolled in a Python loop and each class is run through the
    padded, while-loop based NMS independently. Every op is batched, so the
    work is parallel across the batch dimension.

    Args:
      boxes: a tensor with shape [batch_size, N, num_classes, 4] or
        [batch_size, N, 1, 4], which stacks box predictions on all feature
        levels. N is the total number of anchors on all levels. When the class
        dimension is 1, the same boxes are reused for every class.
      scores: a tensor with shape [batch_size, N, num_classes], which stacks
        class scores on all feature levels. N is the total number of anchors on
        all levels and num_classes is the number of classes predicted by the
        model. The scores are used as-is (no sigmoid/softmax is applied here).
      max_total_size: a scalar representing maximum number of boxes retained over
        all classes.
      nms_iou_threshold: a float representing the threshold for deciding whether
        boxes overlap too much with respect to IOU.
      score_threshold: a float representing the threshold for deciding when to
        remove boxes based on score.
      pre_nms_num_boxes: an int number of top candidate detections per class
        before NMS.

    Returns:
      nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
        representing top detected boxes in [y1, x1, y2, x2].
      nms_scores: `float` Tensor of shape [batch_size, max_total_size]
        representing confidence scores for detected boxes, sorted in
        descending order (in [0, 1] when `scores` are probabilities).
      nms_classes: `int` Tensor of shape [batch_size, max_total_size]
        representing classes for detected boxes.
      valid_detections: `int` Tensor of shape [batch_size] only the top
        `valid_detections` boxes are valid detections.
    """
    with tf.name_scope('generate_detections'):
        nmsed_boxes = []
        nmsed_classes = []
        nmsed_scores = []
        # Static shapes are required: the per-class loop below is unrolled at
        # graph-construction time.
        batch_size, _, num_classes_for_box, _ = boxes.get_shape().as_list()
        _, total_anchors, num_classes = scores.get_shape().as_list()
        # Selects top pre_nms_num_boxes scores and indices before NMS.
        scores, indices = _select_top_k_scores(
            scores, min(total_anchors, pre_nms_num_boxes))
        for i in range(num_classes):
            # With a single box "class" (shape [..., 1, 4]) index 0 is reused
            # for every class; otherwise each class uses its own boxes.
            boxes_i = boxes[:, :, min(num_classes_for_box - 1, i), :]
            scores_i = scores[:, :, i]
            # Gather the boxes matching this class's pre-NMS top-k scores.
            boxes_i = tf.gather(boxes_i,
                                indices[:, :, i],
                                batch_dims=1,
                                axis=1)

            # Filter out low scores.
            boxes_i, scores_i = box_utils.filter_boxes_by_scores(
                boxes_i, scores_i, min_score_threshold=score_threshold)

            (nmsed_scores_i,
             nmsed_boxes_i) = nms.sorted_non_max_suppression_padded(
                 tf.cast(scores_i, tf.float32),
                 tf.cast(boxes_i, tf.float32),
                 max_total_size,
                 iou_threshold=nms_iou_threshold)
            nmsed_classes_i = tf.fill([batch_size, max_total_size], i)
            nmsed_boxes.append(nmsed_boxes_i)
            nmsed_scores.append(nmsed_scores_i)
            nmsed_classes.append(nmsed_classes_i)

        # Merge the per-class results and keep the overall top max_total_size.
        nmsed_boxes = tf.concat(nmsed_boxes, axis=1)
        nmsed_scores = tf.concat(nmsed_scores, axis=1)
        nmsed_classes = tf.concat(nmsed_classes, axis=1)
        nmsed_scores, indices = tf.nn.top_k(nmsed_scores,
                                            k=max_total_size,
                                            sorted=True)
        nmsed_boxes = tf.gather(nmsed_boxes, indices, batch_dims=1, axis=1)
        nmsed_classes = tf.gather(nmsed_classes, indices, batch_dims=1, axis=1)
        # NOTE(review): this counts every score > -1. If
        # nms.sorted_non_max_suppression_padded pads unused slots with 0
        # rather than -1, the count is always max_total_size; confirm the
        # padding value used by nms/box_utils before relying on this.
        valid_detections = tf.reduce_sum(
            input_tensor=tf.cast(tf.greater(nmsed_scores, -1), tf.int32),
            axis=1)
    return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
    def oln_multilevel_propose_rois(
        self,
        rpn_boxes,
        rpn_scores,
        anchor_boxes,
        image_shape,
        rpn_pre_nms_top_k=2000,
        rpn_post_nms_top_k=1000,
        rpn_nms_threshold=0.7,
        rpn_score_threshold=0.0,
        rpn_min_size_threshold=0.0,
        decode_boxes=True,
        clip_boxes=True,
        use_batched_nms=False,
        apply_sigmoid_to_score=True,
        is_box_lrtb=False,
        rpn_object_scores=None,
    ):
        """Proposes RoIs given a group of candidates from different FPN levels.

        The following describes the steps:
          1. For each individual level:
            a. Replace the scores of each level with rpn_object_scores if
               specified.
            b. Apply sigmoid transform if specified.
            c. Decode boxes (either of xyhw or left-right-top-bottom format) if
               specified.
            d. Clip boxes if specified.
            e. Filter small boxes and those that fall outside the image if
               specified.
            f. Apply pre-NMS filtering including pre-NMS top k and score
               thresholding.
            g. Apply NMS.
          2. Aggregate post-NMS boxes from each level.
          3. Apply an overall top k to generate the final selected RoIs.

        Args:
          rpn_boxes: a dict with keys representing FPN levels and values
            representing box tensors of shape [batch_size, feature_h,
            feature_w, num_anchors * 4].
          rpn_scores: a dict with keys representing FPN levels and values
            representing logit tensors of shape [batch_size, feature_h,
            feature_w, num_anchors].
          anchor_boxes: a dict with keys representing FPN levels and values
            representing anchor box tensors of shape [batch_size, feature_h,
            feature_w, num_anchors * 4].
          image_shape: a tensor of shape [batch_size, 2] where the last
            dimension are [height, width] of the scaled image.
          rpn_pre_nms_top_k: an integer of top scoring RPN proposals *per level*
            to keep before applying NMS. Default: 2000.
          rpn_post_nms_top_k: an integer of top scoring RPN proposals *in total*
            to keep after applying NMS. Default: 1000.
          rpn_nms_threshold: a float between 0 and 1 representing the IoU
            threshold used for NMS. If 0.0, no NMS is applied and each level is
            reduced to its top `rpn_post_nms_top_k` boxes instead. Default: 0.7.
          rpn_score_threshold: a float between 0 and 1 representing the minimal
            box score to keep before applying NMS. This is often used as a
            pre-filtering step for better performance. If 0, no filtering is
            applied. Only used when NMS is applied (rpn_nms_threshold > 0).
            Default: 0.
          rpn_min_size_threshold: a float representing the minimal box size in
            each side (w.r.t. the scaled image) to keep before applying NMS.
            This is often used as a pre-filtering step for better performance.
            If 0, no filtering is applied. Default: 0.
          decode_boxes: a boolean indicating whether `rpn_boxes` needs to be
            decoded using `anchor_boxes`. If False, use `rpn_boxes` directly and
            ignore `anchor_boxes`. Default: True.
          clip_boxes: a boolean indicating whether boxes are first clipped to the
            scaled image size before applying NMS. If False, no clipping is
            applied and `image_shape` is ignored. Default: True.
          use_batched_nms: a boolean indicating whether NMS is applied in batch
            using `tf.image.combined_non_max_suppression`. Currently only
            available in CPU/GPU. Default: False.
          apply_sigmoid_to_score: a boolean indicating whether to apply sigmoid
            to `rpn_scores` before applying NMS. Default: True.
          is_box_lrtb: a bool indicating whether boxes are in lrtb
            (=left,right,top,bottom) format.
          rpn_object_scores: a predicted objectness score (e.g., centerness). In
            OLN, we use object_scores=centerness as a replacement of the scores
            at each level. A dict with keys representing FPN levels and values
            representing logit tensors of shape [batch_size, feature_h,
            feature_w, num_anchors].

        Returns:
          selected_rois: a tensor of shape [batch_size, rpn_post_nms_top_k, 4],
            representing the box coordinates of the selected proposals w.r.t.
            the scaled image.
          selected_roi_scores: a tensor of shape [batch_size,
            rpn_post_nms_top_k, 1], representing the scores of the selected
            proposals.
        """
        with tf.name_scope('multilevel_propose_rois'):
            rois = []
            roi_scores = []
            # Insert an axis so the per-image shape broadcasts over that
            # image's boxes (assumes the documented [batch_size, 2] input).
            image_shape = tf.expand_dims(image_shape, axis=1)
            for level in sorted(rpn_scores.keys()):
                with tf.name_scope('level_%d' % level):
                    _, feature_h, feature_w, num_anchors_per_location = (
                        rpn_scores[level].get_shape().as_list())

                    # Flatten every (h, w, anchor) position of this level into
                    # one box axis.
                    num_boxes = feature_h * feature_w * num_anchors_per_location
                    this_level_scores = tf.reshape(rpn_scores[level],
                                                   [-1, num_boxes])
                    this_level_boxes = tf.reshape(rpn_boxes[level],
                                                  [-1, num_boxes, 4])
                    this_level_anchors = tf.cast(tf.reshape(
                        anchor_boxes[level], [-1, num_boxes, 4]),
                                                 dtype=this_level_scores.dtype)

                    if rpn_object_scores:
                        # OLN: objectness scores replace the RPN scores.
                        this_level_object_scores = rpn_object_scores[level]
                        this_level_object_scores = tf.reshape(
                            this_level_object_scores, [-1, num_boxes])
                        this_level_object_scores = tf.cast(
                            this_level_object_scores, this_level_scores.dtype)
                        this_level_scores = this_level_object_scores

                    if apply_sigmoid_to_score:
                        this_level_scores = tf.sigmoid(this_level_scores)

                    if decode_boxes:
                        if is_box_lrtb:  # Box in left-right-top-bottom format.
                            this_level_boxes = box_utils.decode_boxes_lrtb(
                                this_level_boxes, this_level_anchors)
                        else:  # Box in standard x-y-h-w format.
                            this_level_boxes = box_utils.decode_boxes(
                                this_level_boxes, this_level_anchors)

                    if clip_boxes:
                        this_level_boxes = box_utils.clip_boxes(
                            this_level_boxes, image_shape)

                    if rpn_min_size_threshold > 0.0:
                        this_level_boxes, this_level_scores = box_utils.filter_boxes(
                            this_level_boxes, this_level_scores, image_shape,
                            rpn_min_size_threshold)

                    this_level_pre_nms_top_k = min(num_boxes,
                                                   rpn_pre_nms_top_k)
                    this_level_post_nms_top_k = min(num_boxes,
                                                    rpn_post_nms_top_k)
                    if rpn_nms_threshold > 0.0:
                        if use_batched_nms:
                            this_level_rois, this_level_roi_scores, _, _ = (
                                tf.image.combined_non_max_suppression(
                                    tf.expand_dims(this_level_boxes, axis=2),
                                    tf.expand_dims(this_level_scores, axis=-1),
                                    max_output_size_per_class=
                                    this_level_pre_nms_top_k,
                                    max_total_size=this_level_post_nms_top_k,
                                    iou_threshold=rpn_nms_threshold,
                                    score_threshold=rpn_score_threshold,
                                    pad_per_class=False,
                                    clip_boxes=False))
                        else:
                            if rpn_score_threshold > 0.0:
                                this_level_boxes, this_level_scores = (
                                    box_utils.filter_boxes_by_scores(
                                        this_level_boxes, this_level_scores,
                                        rpn_score_threshold))
                            this_level_boxes, this_level_scores = box_utils.top_k_boxes(
                                this_level_boxes,
                                this_level_scores,
                                k=this_level_pre_nms_top_k)
                            this_level_roi_scores, this_level_rois = (
                                nms.sorted_non_max_suppression_padded(
                                    this_level_scores,
                                    this_level_boxes,
                                    max_output_size=this_level_post_nms_top_k,
                                    iou_threshold=rpn_nms_threshold))
                    else:
                        # No NMS: take this level's top boxes directly. This must
                        # read this_level_boxes; this_level_rois is not yet
                        # assigned here (the previous code raised
                        # UnboundLocalError on the first level).
                        this_level_rois, this_level_roi_scores = box_utils.top_k_boxes(
                            this_level_boxes,
                            this_level_scores,
                            k=this_level_post_nms_top_k)

                    rois.append(this_level_rois)
                    roi_scores.append(this_level_roi_scores)

            all_rois = tf.concat(rois, axis=1)
            all_roi_scores = tf.concat(roi_scores, axis=1)

            with tf.name_scope('top_k_rois'):
                _, num_valid_rois = all_roi_scores.get_shape().as_list()
                overall_top_k = min(num_valid_rois, rpn_post_nms_top_k)

                selected_rois, selected_roi_scores = box_utils.top_k_boxes(
                    all_rois, all_roi_scores, k=overall_top_k)

            return selected_rois, selected_roi_scores