def test_select_from_ten_identical_boxes(self):
        """Ten identical, equally scored rboxes must collapse to one box.

        All copies coincide, so NMS keeps a single one even though the
        output budget (three boxes) would allow more.
        """
        single_rbox = [0, 0, 1, 1, 0]
        copies = 10
        iou_threshold = .5
        output_budget = 3

        duplicated = rbox_list.RBoxList(
            tf.constant(copies * [single_rbox], tf.float32))
        duplicated.add_field('scores', tf.constant(copies * [.9]))

        suppressed = rbox_list_ops.non_max_suppression(
            duplicated, iou_threshold, output_budget)
        expected = [[0, 0, 1, 1, 0]]
        with self.test_session() as session:
            self.assertAllClose(session.run(suppressed.get()), expected)

    def test_select_at_most_two_boxes_from_three_clusters(self):
        """Three separated clusters of rboxes, output capped at two boxes.

        The best-scoring representative of the two highest-scoring clusters
        is expected, highest score first.
        """
        cluster_rboxes = [
            [0.5, 0.5, 1, 1, 0.0],    # near y=0.5, score .9
            [0.5, 0.6, 1, 1, 0.0],    # near y=0.5, score .75
            [0.5, 0.4, 1, 1, 0.0],    # near y=0.5, score .6
            [0.5, 10.5, 1, 1, 0.0],   # near y=10.5, score .95
            [0.5, 10.6, 1, 1, 0.0],   # near y=10.5, score .5
            [0.5, 100.5, 1, 1, 0.0],  # isolated, score .3
        ]
        cluster_scores = [.9, .75, .6, .95, .5, .3]

        boxes = rbox_list.RBoxList(tf.constant(cluster_rboxes, tf.float32))
        boxes.add_field('scores', tf.constant(cluster_scores))

        iou_threshold = .5
        output_budget = 2
        suppressed = rbox_list_ops.non_max_suppression(
            boxes, iou_threshold, output_budget)

        expected = [[0.5, 10.5, 1, 1, 0.0], [0.5, 0.5, 1, 1, 0.0]]
        with self.test_session() as session:
            self.assertAllClose(session.run(suppressed.get()), expected)
def _gather_selected_class(tensor, class_indices, num_boxes):
    """Select, for every detection, the per-class slice given by class_indices.

    Args:
      tensor: a [k, q, ...] tensor of per-class values (boxes or masks).
      class_indices: a [k] int32 tensor holding the chosen class per detection.
      num_boxes: scalar int32 tensor equal to k.

    Returns:
      a [k, 1, ...] tensor. When q == 1 (one slice shared by all classes) the
      input is returned unchanged: gathering a class index > 0 from a size-1
      axis would be out of range.
    """
    if tensor.shape[1].value == 1:
        return tensor
    indices = tf.stack(
        [tf.range(num_boxes, dtype=tf.int32), class_indices], axis=1)
    return tf.expand_dims(tf.gather_nd(tensor, indices), axis=1)


def multiclass_non_max_suppression_rbox(boxes,
                                        scores,
                                        score_thresh,
                                        iou_thresh,
                                        max_size_per_class,
                                        max_total_size=0,
                                        image_shape=(1, 1, 1, 1),
                                        clip_window=None,
                                        change_coordinate_frame=False,
                                        masks=None,
                                        handle_as_single_class=False,
                                        additional_fields=None,
                                        intersection_tf=False,
                                        scope=None):
    """Multi-class version of non maximum suppression for rotated boxes.

    This op greedily selects a subset of detection boxes, pruning away boxes
    that have high IOU (intersection over union) overlap (> thresh) with
    already selected boxes. It operates independently for each class for which
    scores are provided, pruning boxes with score not greater than a provided
    threshold prior to applying NMS.

    Please note that this operation is performed on *all* classes, therefore any
    background classes should be removed prior to calling this function.

    Args:
      boxes: A [k, q, 5] float32 tensor containing k detections, each described
        by 5 values (encoding as expected by rbox_list.RBoxList). `q` can be
        either number of classes or 1 depending on whether a separate box is
        predicted per class.
      scores: A [k, num_classes] float32 tensor containing the scores for each of
        the k detections. num_classes must be statically known.
      score_thresh: scalar threshold for score (boxes are kept only if their
        score passes rbox_list_ops.filter_greater_than).
      iou_thresh: scalar threshold for IOU (new boxes that have high IOU overlap
        with previously selected boxes are removed). Must be in [0, 1].
      max_size_per_class: maximum number of retained boxes per class.
      max_total_size: maximum number of boxes retained over all classes. By
        default (0) returns all boxes retained after capping boxes per class.
      image_shape: a 4-element shape; only elements 1 and 2 are used and are
        forwarded to rbox_list_ops.non_max_suppression as its image_shape.
        Presumably [batch, height, width, channels] -- confirm with callers.
      clip_window: A float32 tensor of the form [y_min, x_min, y_max, x_max].
        NOTE: clipping is not implemented for rotated boxes; the window is only
        used to re-normalize coordinates when change_coordinate_frame is True.
      change_coordinate_frame: Whether to normalize coordinates relative to
        clip_window (this can only be set to True if a clip_window is provided).
      masks: (optional) a [k, q, mask_height, mask_width] float32 tensor
        containing box masks. `q` can be either number of classes or 1 depending
        on whether a separate mask is predicted per class (independently of the
        `q` of `boxes`).
      handle_as_single_class: (optional) if True and there is more than one
        class, each detection is assigned its highest-scoring class (its score
        becomes that maximum) and NMS runs once over all detections; the
        resulting classes field holds the chosen class index.
      additional_fields: (optional) If not None, a dictionary that maps keys to
        tensors whose first dimensions are all of size `k`. After non-maximum
        suppression, all tensors corresponding to the selected boxes will be
        added to the resulting box list.
      intersection_tf: (optional) Whether to use a tf version of the
        intersection; forwarded to rbox_list_ops.non_max_suppression.
      scope: name scope.

    Returns:
      an RBoxList (as produced by rbox_list_ops) holding M boxes with a rank-1
        scores field sorted by rbox_list_ops.sort_by_field and a rank-1 float
        classes field holding a class label for each box. If masks or
        additional_fields are given, the box list also contains the
        corresponding masks / fields for the selected boxes.

    Raises:
      ValueError: if iou_thresh is not in [0, 1], if scores is not rank 2 with
        a static second dimension, if boxes is not a [k, q, 5] tensor with q
        equal to 1 or num_classes, or if change_coordinate_frame is True without
        a clip_window.
    """
    if not 0 <= iou_thresh <= 1.0:
        raise ValueError('iou_thresh must be between 0 and 1')
    if scores.shape.ndims != 2:
        raise ValueError('scores field must be of rank 2')
    if scores.shape[1].value is None:
        raise ValueError(
            'scores must have statically defined second dimension')
    if boxes.shape.ndims != 3:
        raise ValueError('boxes must be of rank 3.')
    if not (boxes.shape[1].value == scores.shape[1].value
            or boxes.shape[1].value == 1):
        raise ValueError(
            'second dimension of boxes must be either 1 or equal to the second dimension of scores'
        )
    if boxes.shape[2].value != 5:
        raise ValueError('last dimension of boxes must be of size 5.')
    if change_coordinate_frame and clip_window is None:
        raise ValueError(
            'if change_coordinate_frame is True, then a clip_window must be specified.'
        )

    with tf.name_scope(scope, 'MultiClassNonMaxSuppression'):
        num_boxes = tf.shape(boxes)[0]
        num_scores = tf.shape(scores)[0]
        # Plain Python int (validated non-None above); drives the class loop.
        num_classes = scores.shape[1].value

        length_assert = tf.Assert(tf.equal(num_boxes, num_scores), [
            'Incorrect scores field length: actual vs expected.', num_scores,
            num_boxes
        ])

        # float32 [k] labels when classes are merged into one, else None.
        merged_class_labels = None
        if handle_as_single_class and num_classes > 1:
            # Best-scoring class per detection: [k].
            max_idxs = tf.argmax(scores, axis=1, output_type=tf.int32)
            boxes = _gather_selected_class(boxes, max_idxs, num_boxes)
            if masks is not None:
                # Keep each detection's mask consistent with its chosen class.
                masks = _gather_selected_class(masks, max_idxs, num_boxes)
            scores = tf.reduce_max(scores, axis=1, keep_dims=True)
            merged_class_labels = tf.cast(max_idxs, tf.float32)
            num_classes = 1

        per_class_boxes_list = tf.unstack(boxes, axis=1)
        per_class_masks_list = (tf.unstack(masks, axis=1)
                                if masks is not None else None)

        # With a single shared slot (q == 1) every class reads slot 0;
        # otherwise class i reads slot i. Boxes and masks are indexed
        # independently because their q may differ.
        boxes_ids = (range(num_classes)
                     if len(per_class_boxes_list) > 1 else [0] * num_classes)

        selected_boxes_list = []
        for class_idx, boxes_idx in zip(range(num_classes), boxes_ids):
            boxlist_and_class_scores = rbox_list.RBoxList(
                per_class_boxes_list[boxes_idx])
            with tf.control_dependencies([length_assert]):
                class_scores = tf.reshape(
                    tf.slice(scores, [0, class_idx],
                             tf.stack([num_scores, 1])), [-1])
            boxlist_and_class_scores.add_field(fields.BoxListFields.scores,
                                               class_scores)

            if merged_class_labels is not None:
                boxlist_and_class_scores.add_field(
                    fields.BoxListFields.classes, merged_class_labels)

            if per_class_masks_list is not None:
                masks_idx = class_idx if len(per_class_masks_list) > 1 else 0
                boxlist_and_class_scores.add_field(
                    fields.BoxListFields.masks, per_class_masks_list[masks_idx])

            if additional_fields is not None:
                for key, tensor in additional_fields.items():
                    boxlist_and_class_scores.add_field(key, tensor)

            rboxlist_filtered = rbox_list_ops.filter_greater_than(
                boxlist_and_class_scores, score_thresh)
            # Clipping to clip_window is not implemented for rotated boxes;
            # the window is only used to re-normalize coordinates.
            if clip_window is not None and change_coordinate_frame:
                rboxlist_filtered = rbox_list_ops.change_coordinate_frame(
                    rboxlist_filtered, clip_window)
            max_selection_size = tf.minimum(max_size_per_class,
                                            rboxlist_filtered.num_boxes())

            # TODO: support frozen graph
            nms_result = rbox_list_ops.non_max_suppression(
                rboxlist_filtered,
                iou_thresh,
                max_selection_size,
                image_shape=(image_shape[1], image_shape[2]),
                intersection_tf=intersection_tf)

            if merged_class_labels is None:
                # Label every surviving box with this iteration's class index.
                nms_result.add_field(
                    fields.BoxListFields.classes,
                    tf.zeros_like(
                        nms_result.get_field(fields.BoxListFields.scores)) +
                    class_idx)
            selected_boxes_list.append(nms_result)

        selected_boxes = rbox_list_ops.concatenate(selected_boxes_list)
        sorted_boxes = rbox_list_ops.sort_by_field(selected_boxes,
                                                   fields.BoxListFields.scores)
        if max_total_size:
            max_total_size = tf.minimum(max_total_size,
                                        sorted_boxes.num_boxes())
            sorted_boxes = rbox_list_ops.gather(sorted_boxes,
                                                tf.range(max_total_size))
        return sorted_boxes