def _create_regression_targets(self, anchors, groundtruth_boxes, match):
    """Builds the box regression target of every anchor.

    Matched anchors get the box-coder encoding of their assigned groundtruth
    box relative to the anchor; unmatched or ignored anchors get the default
    regression target.

    Args:
      anchors: a BoxList representing N anchors
      groundtruth_boxes: a BoxList representing M groundtruth_boxes
      match: a matcher.Match object

    Returns:
      reg_targets: a float32 tensor with shape [N, box_code_dimension]
    """
    positive_anchor_ids = match.matched_column_indices()
    negative_anchor_ids = match.unmatched_or_ignored_column_indices()
    positive_gt_ids = match.matched_row_indices()

    # Line up each matched anchor with the groundtruth box assigned to it.
    positive_anchors = box_list_ops.gather(anchors, positive_anchor_ids)
    positive_gt_boxes = box_list_ops.gather(groundtruth_boxes, positive_gt_ids)
    positive_targets = self._box_coder.encode(positive_gt_boxes,
                                              positive_anchors)

    # One copy of the default target per unmatched/ignored anchor.
    num_negatives = tf.size(negative_anchor_ids)
    negative_targets = tf.tile(self._default_regression_target(),
                               tf.stack([num_negatives, 1]))

    # Scatter both groups back into anchor order.
    # TODO: summarize the number of matches on average.
    return tf.dynamic_stitch(
        [positive_anchor_ids, negative_anchor_ids],
        [positive_targets, negative_targets])
  def _create_regression_targets(self, anchors, groundtruth_boxes, match):
    """Computes one regression target per anchor.

    Anchors that were matched are encoded against their groundtruth box with
    the box coder; every other anchor (unmatched or ignored) receives the
    default regression target.

    Args:
      anchors: a BoxList representing N anchors
      groundtruth_boxes: a BoxList representing M groundtruth_boxes
      match: a matcher.Match object

    Returns:
      reg_targets: a float32 tensor with shape [N, box_code_dimension]
    """
    matched_cols = match.matched_column_indices()
    other_cols = match.unmatched_or_ignored_column_indices()
    matched_rows = match.matched_row_indices()

    anchors_for_matches = box_list_ops.gather(anchors, matched_cols)
    gt_for_matches = box_list_ops.gather(groundtruth_boxes, matched_rows)
    encoded_matches = self._box_coder.encode(gt_for_matches,
                                             anchors_for_matches)

    # Repeat the default target once for each unmatched/ignored anchor.
    tile_multiples = tf.stack([tf.size(other_cols), 1])
    default_rows = tf.tile(self._default_regression_target(), tile_multiples)

    # Interleave the two groups so row i belongs to anchor i.
    stitched = tf.dynamic_stitch([matched_cols, other_cols],
                                 [encoded_matches, default_rows])
    # TODO: summarize the number of matches on average.
    return stitched
  def test_gather_with_invalid_field(self):
    """Requesting fields the BoxList does not hold raises ValueError."""
    box_corners = tf.constant([[0.0] * 4, [1.0] * 4])
    gather_indices = tf.constant([0, 1], tf.int32)
    box_weights = tf.constant([[.1], [.3]], tf.float32)

    boxlist = box_list.BoxList(box_corners)
    boxlist.add_field('weights', box_weights)
    # Only 'weights' was added, so 'foo' and 'bar' are unknown fields.
    self.assertRaises(ValueError, box_list_ops.gather, boxlist,
                      gather_indices, ['foo', 'bar'])
Exemple #4
0
  def test_gather_with_invalid_field(self):
    """gather must raise ValueError for field names that were never added."""
    two_boxes = tf.constant([[0.0] * 4, [1.0] * 4])
    picked = tf.constant([0, 1], tf.int32)
    weight_column = tf.constant([[.1], [.3]], tf.float32)

    boxlist = box_list.BoxList(two_boxes)
    boxlist.add_field('weights', weight_column)
    # The BoxList only carries 'weights'; the requested names are bogus.
    unknown_fields = ['foo', 'bar']
    with self.assertRaises(ValueError):
      box_list_ops.gather(boxlist, picked, unknown_fields)
 def test_gather_with_invalid_inputs(self):
   """gather raises ValueError for float indices and for rank-2 indices."""
   five_boxes = tf.constant([4 * [float(row)] for row in range(5)])

   # Case 1: indices with a float32 dtype.
   float_indices = tf.constant([0, 2, 4], tf.float32)
   boxlist = box_list.BoxList(five_boxes)
   self.assertRaises(ValueError, box_list_ops.gather, boxlist, float_indices)

   # Case 2: int32 indices given as a [1, 3] matrix instead of a vector.
   matrix_indices = tf.constant([[0, 2, 4]], tf.int32)
   boxlist = box_list.BoxList(five_boxes)
   self.assertRaises(ValueError, box_list_ops.gather, boxlist, matrix_indices)
Exemple #6
0
 def test_gather_with_invalid_inputs(self):
   """Checks that gather rejects non-integer and non-vector indices."""
   box_corners = tf.constant([[float(row)] * 4 for row in range(5)])

   # float32 indices are not accepted.
   bad_dtype_indices = tf.constant([0, 2, 4], tf.float32)
   boxlist = box_list.BoxList(box_corners)
   with self.assertRaises(ValueError):
     box_list_ops.gather(boxlist, bad_dtype_indices)

   # Indices wrapped in an extra dimension are not accepted either.
   bad_rank_indices = tf.constant([[0, 2, 4]], tf.int32)
   boxlist = box_list.BoxList(box_corners)
   with self.assertRaises(ValueError):
     box_list_ops.gather(boxlist, bad_rank_indices)
  def test_dynamic_gather_with_field(self):
    """Gathers boxes and a 'weights' field from placeholder-fed inputs."""
    corners_ph = tf.placeholder(tf.float32, [None, 4])
    indices_ph = tf.placeholder(tf.int32, [None])
    weights_ph = tf.placeholder(tf.float32, [None, 1])

    boxlist = box_list.BoxList(corners_ph)
    boxlist.add_field('weights', weights_ph)
    gathered = box_list_ops.gather(
        boxlist, indices_ph, ['weights'], use_static_shapes=True)

    # Concrete values are only supplied at session run time.
    feeds = {
        corners_ph: np.array([4 * [float(row)] for row in range(5)]),
        indices_ph: np.array([0, 2, 4]).astype(np.int32),
        weights_ph: np.array([[.1], [.3], [.5], [.7], [.9]]),
    }
    want_boxes = [4 * [0.0], 4 * [2.0], 4 * [4.0]]
    want_weights = [[.1], [.5], [.9]]
    with self.test_session() as sess:
      got_boxes, got_weights = sess.run(
          [gathered.get(), gathered.get_field('weights')], feed_dict=feeds)
      self.assertAllClose(got_boxes, want_boxes)
      self.assertAllClose(got_weights, want_weights)
    def _create_mtl_targets(self, groundtruth_boxes, match, field_name):
        """Returns, for each anchor, a target read from a groundtruth field.

        Matched anchors receive the `field_name` value of the groundtruth box
        they were matched to; unmatched or ignored anchors receive an
        all-zero target of the same width and dtype.

        Args:
          groundtruth_boxes: a BoxList representing M groundtruth_boxes. The
            field `field_name` is read from it after gathering the matched
            rows.
          match: a matcher.Match object
          field_name: name of the BoxList field used as target. The field is
            expected to be rank 2 with a statically known last dimension,
            because that dimension sizes the zero targets.

        Returns:
          targets: a tensor with one row per anchor and as many columns as
            the last dimension of the `field_name` field, in that field's
            dtype.
        """
        matched_anchor_indices = match.matched_column_indices()
        unmatched_ignored_anchor_indices = (
            match.unmatched_or_ignored_column_indices())
        matched_gt_indices = match.matched_row_indices()

        matched_gt_boxes = box_list_ops.gather(groundtruth_boxes,
                                               matched_gt_indices)
        matched_targets = matched_gt_boxes.get_field(field_name)

        # tf.dynamic_stitch requires every data tensor to share one dtype, so
        # the zero filler follows the field's dtype instead of assuming
        # float32 (identical result for float32 fields).
        target_width = matched_targets.get_shape().as_list()[-1]
        unmatched_ignored_targets = tf.tile(
            tf.zeros([1, target_width], dtype=matched_targets.dtype),
            tf.stack([tf.size(unmatched_ignored_anchor_indices), 1]))

        # Scatter both groups back into anchor order.
        targets = tf.dynamic_stitch(
            [matched_anchor_indices, unmatched_ignored_anchor_indices],
            [matched_targets, unmatched_ignored_targets])
        # TODO: summarize the number of matches on average.
        return targets
Exemple #9
0
 def graph_fn():
   """Builds a five-box BoxList and returns the boxes at indices 0, 2, 4."""
   all_corners = tf.constant([4 * [float(row)] for row in range(5)])
   picked = tf.constant([0, 2, 4], tf.int32)
   boxlist = box_list.BoxList(all_corners)
   return box_list_ops.gather(boxlist, picked).get()
  def test_dynamic_gather_with_field(self):
    """gather with use_static_shapes=True works on dynamically shaped input."""
    corners_in = tf.placeholder(tf.float32, [None, 4])
    indices_in = tf.placeholder(tf.int32, [None])
    weights_in = tf.placeholder(tf.float32, [None, 1])

    boxlist = box_list.BoxList(corners_in)
    boxlist.add_field('weights', weights_in)
    picked = box_list_ops.gather(
        boxlist, indices_in, ['weights'], use_static_shapes=True)
    fetches = [picked.get(), picked.get_field('weights')]

    # Values bound to the placeholders when the graph is executed.
    fed_corners = np.array([[float(row)] * 4 for row in range(5)])
    fed_indices = np.array([0, 2, 4]).astype(np.int32)
    fed_weights = np.array([[.1], [.3], [.5], [.7], [.9]])

    with self.test_session() as sess:
      boxes_out, weights_out = sess.run(
          fetches,
          feed_dict={
              corners_in: fed_corners,
              indices_in: fed_indices,
              weights_in: fed_weights,
          })
      self.assertAllClose(boxes_out, [4 * [0.0], 4 * [2.0], 4 * [4.0]])
      self.assertAllClose(weights_out, [[.1], [.5], [.9]])
Exemple #11
0
 def graph_fn():
   """Gathers the boxes whose weight exceeds 0.4, with their weights.

   Returns:
     A (boxes, weights) tuple taken from the gathered BoxList.
   """
   all_corners = tf.constant([4 * [float(row)] for row in range(5)])
   box_weights = tf.constant([.5, .3, .7, .1, .9], tf.float32)
   # tf.where yields one coordinate row per True entry; flatten to a vector.
   heavy_mask = tf.greater(box_weights, 0.4)
   heavy_indices = tf.reshape(tf.where(heavy_mask), [-1])
   boxlist = box_list.BoxList(all_corners)
   boxlist.add_field('weights', box_weights)
   gathered = box_list_ops.gather(boxlist, heavy_indices, ['weights'])
   return gathered.get(), gathered.get_field('weights')
 def test_gather(self):
   """Gathering indices [0, 2, 4] yields the corresponding box corners."""
   all_corners = tf.constant([4 * [float(row)] for row in range(5)])
   picked = tf.constant([0, 2, 4], tf.int32)
   boxlist = box_list.BoxList(all_corners)
   gathered = box_list_ops.gather(boxlist, picked)
   want = [4 * [0.0], 4 * [2.0], 4 * [4.0]]
   with self.test_session() as sess:
     got = sess.run(gathered.get())
     self.assertAllClose(got, want)
 def test_gather(self):
     """Checks gather against a hand-picked subset of five boxes."""
     box_corners = tf.constant([[float(row)] * 4 for row in range(5)])
     keep = tf.constant([0, 2, 4], tf.int32)
     boxlist = box_list.BoxList(box_corners)
     kept_boxes = box_list_ops.gather(boxlist, keep).get()
     with self.test_session() as sess:
         kept_values = sess.run(kept_boxes)
         self.assertAllClose(kept_values,
                             [4 * [0.0], 4 * [2.0], 4 * [4.0]])
Exemple #14
0
  def _assign_rbox_target(self, anchors, groundtruth_boxes,
                          groundtruth_rotations, match):
    """Computes 5-dimensional regression targets and weights per anchor.

    Matched anchors are encoded by the box coder from their groundtruth box,
    the matching groundtruth rotation and the anchor itself; all other
    anchors receive an all-zero target. Weights are 1.0 for matched anchors
    and 0.0 otherwise.

    Args:
      anchors: a BoxList representing N anchors
      groundtruth_boxes: a BoxList representing M groundtruth boxes
      groundtruth_rotations: tensor gathered along its first axis with the
        matched groundtruth indices, i.e. one entry per groundtruth box
        (exact shape is not visible here -- TODO confirm).
      match: object exposing matched/unmatched column indices, matched row
        indices and a matched column indicator (presumably a matcher.Match).

    Returns:
      reg_targets: a float32 tensor with shape [num_anchors, 5]
      reg_weights: a float32 tensor with shape [num_anchors]

    Raises:
      ValueError: if anchors or groundtruth_boxes is not a BoxList.
    """
    if not isinstance(anchors, box_list.BoxList):
      raise ValueError('anchors must be an BoxList')
    if not isinstance(groundtruth_boxes, box_list.BoxList):
      raise ValueError('groundtruth_boxes must be an BoxList')

    positive_cols = match.matched_column_indices()
    negative_cols = match.unmatched_or_ignored_column_indices()
    positive_rows = match.matched_row_indices()

    # Collect, for every match, the anchor plus its groundtruth box/rotation.
    positive_anchors = box_list_ops.gather(anchors, positive_cols)
    positive_gt_boxes = box_list_ops.gather(groundtruth_boxes, positive_rows)
    positive_gt_rotations = tf.gather(groundtruth_rotations, positive_rows)
    positive_targets = self._box_coder.encode(
        positive_gt_boxes, positive_gt_rotations, positive_anchors)

    # Unmatched/ignored anchors all share one zero-valued 5-vector target.
    zero_target = tf.constant([5*[0]], tf.float32)
    num_negatives = tf.size(negative_cols)
    negative_targets = tf.tile(zero_target, tf.stack([num_negatives, 1]))

    # Put both groups back into anchor order.
    reg_targets = tf.dynamic_stitch([positive_cols, negative_cols],
                                    [positive_targets, negative_targets])
    reg_weights = tf.cast(match.matched_column_indicator(), tf.float32)
    return reg_targets, reg_weights
  def test_gather_with_field(self):
    """Gathering with a field list carries that field along with the boxes."""
    all_corners = tf.constant([4 * [float(row)] for row in range(5)])
    picked = tf.constant([0, 2, 4], tf.int32)
    all_weights = tf.constant([[.1], [.3], [.5], [.7], [.9]], tf.float32)

    boxlist = box_list.BoxList(all_corners)
    boxlist.add_field('weights', all_weights)
    gathered = box_list_ops.gather(boxlist, picked, ['weights'])
    fetches = [gathered.get(), gathered.get_field('weights')]

    want_boxes = [4 * [0.0], 4 * [2.0], 4 * [4.0]]
    want_weights = [[.1], [.5], [.9]]
    with self.test_session() as sess:
      got_boxes, got_weights = sess.run(fetches)
      self.assertAllClose(got_boxes, want_boxes)
      self.assertAllClose(got_weights, want_weights)
  def test_gather_with_field(self):
    """Checks that gather returns both the boxes and the 'weights' field."""
    box_corners = tf.constant([[float(row)] * 4 for row in range(5)])
    keep = tf.constant([0, 2, 4], tf.int32)
    box_weights = tf.constant([[.1], [.3], [.5], [.7], [.9]], tf.float32)

    boxlist = box_list.BoxList(box_corners)
    boxlist.add_field('weights', box_weights)
    kept = box_list_ops.gather(boxlist, keep, ['weights'])

    with self.test_session() as sess:
      kept_boxes, kept_weights = sess.run(
          [kept.get(), kept.get_field('weights')])
      # Rows 0, 2 and 4 of both the corners and the weights are expected.
      self.assertAllClose(kept_boxes, [4 * [0.0], 4 * [2.0], 4 * [4.0]])
      self.assertAllClose(kept_weights, [[.1], [.5], [.9]])
  def test_gather_with_dynamic_indexing(self):
    """Gathers with indices that are computed inside the graph."""
    all_corners = tf.constant([4 * [float(row)] for row in range(5)])
    all_weights = tf.constant([.5, .3, .7, .1, .9], tf.float32)
    # Positions of the weights above 0.4, flattened to an index vector.
    heavy_mask = tf.greater(all_weights, 0.4)
    heavy_indices = tf.reshape(tf.where(heavy_mask), [-1])

    boxlist = box_list.BoxList(all_corners)
    boxlist.add_field('weights', all_weights)
    gathered = box_list_ops.gather(boxlist, heavy_indices, ['weights'])

    want_boxes = [4 * [0.0], 4 * [2.0], 4 * [4.0]]
    want_weights = [.5, .7, .9]
    with self.test_session() as sess:
      got_boxes, got_weights = sess.run(
          [gathered.get(), gathered.get_field('weights')])
      self.assertAllClose(got_boxes, want_boxes)
      self.assertAllClose(got_weights, want_weights)
    def test_gather_with_dynamic_indexing(self):
        """gather accepts indices derived from a tf.where over the weights."""
        box_corners = tf.constant([[float(row)] * 4 for row in range(5)])
        box_weights = tf.constant([.5, .3, .7, .1, .9], tf.float32)
        above_threshold = tf.greater(box_weights, 0.4)
        # tf.where gives one coordinate row per hit; reshape to a flat vector.
        keep = tf.reshape(tf.where(above_threshold), [-1])

        boxlist = box_list.BoxList(box_corners)
        boxlist.add_field('weights', box_weights)
        kept = box_list_ops.gather(boxlist, keep, ['weights'])
        fetches = [kept.get(), kept.get_field('weights')]

        with self.test_session() as sess:
            kept_boxes, kept_weights = sess.run(fetches)
            self.assertAllClose(kept_boxes,
                                [4 * [0.0], 4 * [2.0], 4 * [4.0]])
            self.assertAllClose(kept_weights, [.5, .7, .9])
 def graph_fn(corners, weights, indices):
   """Gathers `indices` from a BoxList built on `corners` with 'weights'.

   Returns:
     A (boxes, weights) tuple read from the gathered BoxList's fields.
   """
   boxlist = box_list.BoxList(corners)
   boxlist.add_field('weights', weights)
   gathered = box_list_ops.gather(
       boxlist, indices, ['weights'], use_static_shapes=True)
   gathered_boxes = gathered.get_field('boxes')
   gathered_weights = gathered.get_field('weights')
   return (gathered_boxes, gathered_weights)
def multiclass_non_max_suppression(boxes,
                                   scores,
                                   score_thresh,
                                   iou_thresh,
                                   max_size_per_class,
                                   max_total_size=0,
                                   clip_window=None,
                                   change_coordinate_frame=False,
                                   masks=None,
                                   boundaries=None,
                                   pad_to_max_output_size=False,
                                   additional_fields=None,
                                   scope=None):
    """Multi-class version of non maximum suppression.

    This op greedily selects a subset of detection bounding boxes, pruning
    away boxes that have high IOU (intersection over union) overlap (> thresh)
    with already selected boxes.  It operates independently for each class for
    which scores are provided (via the second dimension of `scores`),
    pruning boxes with score less than a provided threshold prior to
    applying NMS.

    Please note that this operation is performed on *all* classes, therefore
    any background classes should be removed prior to calling this function.

    Selected boxes are guaranteed to be sorted in decreasing order by score
    (but the sort is not guaranteed to be stable).

    Args:
      boxes: A [k, q, 4] float32 tensor containing k detections. `q` can be
        either number of classes or 1 depending on whether a separate box is
        predicted per class.
      scores: A [k, num_classes] float32 tensor containing the scores for each
        of the k detections. The scores have to be non-negative when
        pad_to_max_output_size is True.
      score_thresh: scalar threshold for score (low scoring boxes are
        removed).
      iou_thresh: scalar threshold for IOU (new boxes that have high IOU
        overlap with previously selected boxes are removed).
      max_size_per_class: maximum number of retained boxes per class.
      max_total_size: maximum number of boxes retained over all classes. By
        default returns all boxes retained after capping boxes per class.
      clip_window: A float32 tensor of the form [y_min, x_min, y_max, x_max]
        representing the window to clip boxes to. In this implementation the
        clipping is applied to the selected boxes, after non-max suppression.
      change_coordinate_frame: Whether to normalize coordinates after clipping
        relative to clip_window (this can only be set to True if a clip_window
        is provided)
      masks: (optional) a [k, q, mask_height, mask_width] float32 tensor
        containing box masks. `q` can be either number of classes or 1
        depending on whether a separate mask is predicted per class.
      boundaries: (optional) a [k, q, boundary_height, boundary_width] float32
        tensor containing box boundaries. `q` can be either number of classes
        or 1 depending on whether a separate boundary is predicted per class.
      pad_to_max_output_size: If true, the output nmsed boxes are padded to be
        of length `max_size_per_class`. Defaults to false.
      additional_fields: (optional) If not None, a dictionary that maps keys
        to tensors whose first dimensions are all of size `k`. After
        non-maximum suppression, all tensors corresponding to the selected
        boxes will be added to resulting BoxList.
      scope: name scope.

    Returns:
      A tuple of sorted_boxes and num_valid_nms_boxes. The sorted_boxes is a
        BoxList holds M boxes with a rank-1 scores field representing
        corresponding scores for each box with scores sorted in decreasing
        order and a rank-1 classes field representing a class label for each
        box. The num_valid_nms_boxes is a 0-D integer tensor representing the
        number of valid elements in `BoxList`, with the valid elements
        appearing first.

    Raises:
      ValueError: if iou_thresh is not in [0, 1], if scores is not rank 2 or
        has no static second dimension, if boxes is not a rank-3 tensor whose
        second dimension is 1 or num_classes and whose last dimension is 4, or
        if change_coordinate_frame is True without a clip_window.
    """
    if not 0 <= iou_thresh <= 1.0:
        raise ValueError('iou_thresh must be between 0 and 1')
    if scores.shape.ndims != 2:
        raise ValueError('scores field must be of rank 2')
    if scores.shape[1].value is None:
        raise ValueError('scores must have statically defined second '
                         'dimension')
    if boxes.shape.ndims != 3:
        raise ValueError('boxes must be of rank 3.')
    if not (boxes.shape[1].value == scores.shape[1].value
            or boxes.shape[1].value == 1):
        raise ValueError('second dimension of boxes must be either 1 or equal '
                         'to the second dimension of scores')
    if boxes.shape[2].value != 4:
        raise ValueError('last dimension of boxes must be of size 4.')
    if change_coordinate_frame and clip_window is None:
        # The first literal ends with a space so the two concatenated parts
        # do not run together as "clip_windowmust".
        raise ValueError(
            'if change_coordinate_frame is True, then a clip_window '
            'must be specified.')

    with tf.name_scope(scope, 'MultiClassNonMaxSuppression'):
        num_scores = tf.shape(scores)[0]
        num_classes = scores.get_shape()[1]

        selected_boxes_list = []
        num_valid_nms_boxes_cumulative = tf.constant(0)
        per_class_boxes_list = tf.unstack(boxes, axis=1)
        if masks is not None:
            per_class_masks_list = tf.unstack(masks, axis=1)
        if boundaries is not None:
            per_class_boundaries_list = tf.unstack(boundaries, axis=1)
        # When a single box set is shared by all classes (q == 1), every class
        # reads slot 0 of the unstacked boxes/masks/boundaries.
        boxes_ids = (range(num_classes) if len(per_class_boxes_list) > 1 else
                     [0] * num_classes.value)
        for class_idx, boxes_idx in zip(range(num_classes), boxes_ids):
            per_class_boxes = per_class_boxes_list[boxes_idx]
            boxlist_and_class_scores = box_list.BoxList(per_class_boxes)
            class_scores = tf.reshape(
                tf.slice(scores, [0, class_idx], tf.stack([num_scores, 1])),
                [-1])

            boxlist_and_class_scores.add_field(fields.BoxListFields.scores,
                                               class_scores)
            if masks is not None:
                per_class_masks = per_class_masks_list[boxes_idx]
                boxlist_and_class_scores.add_field(fields.BoxListFields.masks,
                                                   per_class_masks)
            if boundaries is not None:
                per_class_boundaries = per_class_boundaries_list[boxes_idx]
                boxlist_and_class_scores.add_field(
                    fields.BoxListFields.boundaries, per_class_boundaries)
            if additional_fields is not None:
                for key, tensor in additional_fields.items():
                    boxlist_and_class_scores.add_field(key, tensor)

            if pad_to_max_output_size:
                max_selection_size = max_size_per_class
                selected_indices, num_valid_nms_boxes = (
                    tf.image.non_max_suppression_padded(
                        boxlist_and_class_scores.get(),
                        boxlist_and_class_scores.get_field(
                            fields.BoxListFields.scores),
                        max_selection_size,
                        iou_threshold=iou_thresh,
                        score_threshold=score_thresh,
                        pad_to_max_output_size=True))
            else:
                max_selection_size = tf.minimum(
                    max_size_per_class, boxlist_and_class_scores.num_boxes())
                selected_indices = tf.image.non_max_suppression(
                    boxlist_and_class_scores.get(),
                    boxlist_and_class_scores.get_field(
                        fields.BoxListFields.scores),
                    max_selection_size,
                    iou_threshold=iou_thresh,
                    score_threshold=score_thresh)
                num_valid_nms_boxes = tf.shape(selected_indices)[0]
                # Pad the indices with zeros up to max_selection_size; the
                # padded entries are marked invalid (score -1) below.
                selected_indices = tf.concat([
                    selected_indices,
                    tf.zeros(max_selection_size - num_valid_nms_boxes,
                             tf.int32)
                ], 0)
            nms_result = box_list_ops.gather(boxlist_and_class_scores,
                                             selected_indices)
            # Make the scores -1 for invalid boxes.
            valid_nms_boxes_indx = tf.less(tf.range(max_selection_size),
                                           num_valid_nms_boxes)
            nms_scores = nms_result.get_field(fields.BoxListFields.scores)
            nms_result.add_field(
                fields.BoxListFields.scores,
                tf.where(valid_nms_boxes_indx, nms_scores,
                         -1 * tf.ones(max_selection_size)))
            num_valid_nms_boxes_cumulative += num_valid_nms_boxes

            # Label every selected box of this iteration with its class index.
            nms_result.add_field(
                fields.BoxListFields.classes,
                (tf.zeros_like(
                    nms_result.get_field(fields.BoxListFields.scores)) +
                 class_idx))
            selected_boxes_list.append(nms_result)
        selected_boxes = box_list_ops.concatenate(selected_boxes_list)
        sorted_boxes = box_list_ops.sort_by_field(selected_boxes,
                                                  fields.BoxListFields.scores)
        if clip_window is not None:
            # When pad_to_max_output_size is False, it prunes the boxes with
            # zero area.
            sorted_boxes = box_list_ops.clip_to_window(
                sorted_boxes,
                clip_window,
                filter_nonoverlapping=not pad_to_max_output_size)
            # Set the scores of boxes with zero area to -1 to keep the default
            # behaviour of pruning out zero area boxes.
            sorted_boxes_size = tf.shape(sorted_boxes.get())[0]
            non_zero_box_area = tf.cast(box_list_ops.area(sorted_boxes),
                                        tf.bool)
            sorted_boxes_scores = tf.where(
                non_zero_box_area,
                sorted_boxes.get_field(fields.BoxListFields.scores),
                -1 * tf.ones(sorted_boxes_size))
            sorted_boxes.add_field(fields.BoxListFields.scores,
                                   sorted_boxes_scores)
            # Recount the valid boxes: anything with a score >= 0 survived.
            num_valid_nms_boxes_cumulative = tf.reduce_sum(
                tf.cast(tf.greater_equal(sorted_boxes_scores, 0), tf.int32))
            sorted_boxes = box_list_ops.sort_by_field(
                sorted_boxes, fields.BoxListFields.scores)
            if change_coordinate_frame:
                sorted_boxes = box_list_ops.change_coordinate_frame(
                    sorted_boxes, clip_window)

        if max_total_size:
            max_total_size = tf.minimum(max_total_size,
                                        sorted_boxes.num_boxes())
            sorted_boxes = box_list_ops.gather(sorted_boxes,
                                               tf.range(max_total_size))
            # The valid count can never exceed the number of boxes kept.
            num_valid_nms_boxes_cumulative = tf.where(
                max_total_size > num_valid_nms_boxes_cumulative,
                num_valid_nms_boxes_cumulative, max_total_size)
        # Select only the valid boxes if pad_to_max_output_size is False.
        if not pad_to_max_output_size:
            sorted_boxes = box_list_ops.gather(
                sorted_boxes, tf.range(num_valid_nms_boxes_cumulative))

        return sorted_boxes, num_valid_nms_boxes_cumulative
Exemple #21
0
 def graph_fn(corners, weights, indices):
   """Wraps `corners` in a BoxList with 'weights' and gathers `indices`.

   Returns:
     A tuple of the gathered 'boxes' field and the gathered 'weights' field.
   """
   boxlist = box_list.BoxList(corners)
   boxlist.add_field('weights', weights)
   picked = box_list_ops.gather(
       boxlist, indices, ['weights'], use_static_shapes=True)
   picked_boxes = picked.get_field('boxes')
   picked_weights = picked.get_field('weights')
   return (picked_boxes, picked_weights)
Exemple #22
0
def multiclass_non_max_suppression(boxes,
                                   scores,
                                   score_thresh,
                                   iou_thresh,
                                   max_size_per_class,
                                   max_total_size=0,
                                   clip_window=None,
                                   change_coordinate_frame=False,
                                   masks=None,
                                   boundaries=None,
                                   pad_to_max_output_size=False,
                                   additional_fields=None,
                                   scope=None):
  """Multi-class version of non maximum suppression.

  This op greedily selects a subset of detection bounding boxes, pruning
  away boxes that have high IOU (intersection over union) overlap (> thresh)
  with already selected boxes.  It operates independently for each class for
  which scores are provided (via the scores field of the input box_list),
  pruning boxes with score less than a provided threshold prior to
  applying NMS.

  Please note that this operation is performed on *all* classes, therefore any
  background classes should be removed prior to calling this function.

  Selected boxes are guaranteed to be sorted in decreasing order by score (but
  the sort is not guaranteed to be stable).

  Args:
    boxes: A [k, q, 4] float32 tensor containing k detections. `q` can be either
      number of classes or 1 depending on whether a separate box is predicted
      per class.
    scores: A [k, num_classes] float32 tensor containing the scores for each of
      the k detections. The scores have to be non-negative when
      pad_to_max_output_size is True.
    score_thresh: scalar threshold for score (low scoring boxes are removed).
    iou_thresh: scalar threshold for IOU (new boxes that have high IOU overlap
      with previously selected boxes are removed).
    max_size_per_class: maximum number of retained boxes per class.
    max_total_size: maximum number of boxes retained over all classes. By
      default returns all boxes retained after capping boxes per class.
    clip_window: A float32 tensor of the form [y_min, x_min, y_max, x_max]
      representing the window to clip and normalize boxes to before performing
      non-max suppression.
    change_coordinate_frame: Whether to normalize coordinates after clipping
      relative to clip_window (this can only be set to True if a clip_window
      is provided)
    masks: (optional) a [k, q, mask_height, mask_width] float32 tensor
      containing box masks. `q` can be either number of classes or 1 depending
      on whether a separate mask is predicted per class.
    boundaries: (optional) a [k, q, boundary_height, boundary_width] float32
      tensor containing box boundaries. `q` can be either number of classes or 1
      depending on whether a separate boundary is predicted per class.
    pad_to_max_output_size: If true, the output nmsed boxes are padded to be of
      length `max_size_per_class`. Defaults to false.
    additional_fields: (optional) If not None, a dictionary that maps keys to
      tensors whose first dimensions are all of size `k`. After non-maximum
      suppression, all tensors corresponding to the selected boxes will be
      added to resulting BoxList.
    scope: name scope.

  Returns:
    A tuple of sorted_boxes and num_valid_nms_boxes. The sorted_boxes is a
      BoxList holds M boxes with a rank-1 scores field representing
      corresponding scores for each box with scores sorted in decreasing order
      and a rank-1 classes field representing a class label for each box. The
      num_valid_nms_boxes is a 0-D integer tensor representing the number of
      valid elements in `BoxList`, with the valid elements appearing first.

  Raises:
    ValueError: if iou_thresh is not in [0, 1] or if input boxlist does not have
      a valid scores field.
  """
  if not 0 <= iou_thresh <= 1.0:
    raise ValueError('iou_thresh must be between 0 and 1')
  if scores.shape.ndims != 2:
    raise ValueError('scores field must be of rank 2')
  if scores.shape[1].value is None:
    raise ValueError('scores must have statically defined second '
                     'dimension')
  if boxes.shape.ndims != 3:
    raise ValueError('boxes must be of rank 3.')
  if not (boxes.shape[1].value == scores.shape[1].value or
          boxes.shape[1].value == 1):
    raise ValueError('second dimension of boxes must be either 1 or equal '
                     'to the second dimension of scores')
  if boxes.shape[2].value != 4:
    raise ValueError('last dimension of boxes must be of size 4.')
  if change_coordinate_frame and clip_window is None:
    raise ValueError('if change_coordinate_frame is True, then a clip_window'
                     'must be specified.')

  with tf.name_scope(scope, 'MultiClassNonMaxSuppression'):
    num_scores = tf.shape(scores)[0]
    num_classes = scores.get_shape()[1]

    selected_boxes_list = []
    num_valid_nms_boxes_cumulative = tf.constant(0)
    per_class_boxes_list = tf.unstack(boxes, axis=1)
    if masks is not None:
      per_class_masks_list = tf.unstack(masks, axis=1)
    if boundaries is not None:
      per_class_boundaries_list = tf.unstack(boundaries, axis=1)
    boxes_ids = (range(num_classes) if len(per_class_boxes_list) > 1
                 else [0] * num_classes.value)
    for class_idx, boxes_idx in zip(range(num_classes), boxes_ids):
      per_class_boxes = per_class_boxes_list[boxes_idx]
      boxlist_and_class_scores = box_list.BoxList(per_class_boxes)
      class_scores = tf.reshape(
          tf.slice(scores, [0, class_idx], tf.stack([num_scores, 1])), [-1])

      boxlist_and_class_scores.add_field(fields.BoxListFields.scores,
                                         class_scores)
      if masks is not None:
        per_class_masks = per_class_masks_list[boxes_idx]
        boxlist_and_class_scores.add_field(fields.BoxListFields.masks,
                                           per_class_masks)
      if boundaries is not None:
        per_class_boundaries = per_class_boundaries_list[boxes_idx]
        boxlist_and_class_scores.add_field(fields.BoxListFields.boundaries,
                                           per_class_boundaries)
      if additional_fields is not None:
        for key, tensor in additional_fields.items():
          boxlist_and_class_scores.add_field(key, tensor)

      if pad_to_max_output_size:
        max_selection_size = max_size_per_class
        selected_indices, num_valid_nms_boxes = (
            tf.image.non_max_suppression_padded(
                boxlist_and_class_scores.get(),
                boxlist_and_class_scores.get_field(fields.BoxListFields.scores),
                max_selection_size,
                iou_threshold=iou_thresh,
                score_threshold=score_thresh,
                pad_to_max_output_size=True))
      else:
        max_selection_size = tf.minimum(max_size_per_class,
                                        boxlist_and_class_scores.num_boxes())
        selected_indices = tf.image.non_max_suppression(
            boxlist_and_class_scores.get(),
            boxlist_and_class_scores.get_field(fields.BoxListFields.scores),
            max_selection_size,
            iou_threshold=iou_thresh,
            score_threshold=score_thresh)
        num_valid_nms_boxes = tf.shape(selected_indices)[0]
        selected_indices = tf.concat(
            [selected_indices,
             tf.zeros(max_selection_size-num_valid_nms_boxes, tf.int32)], 0)
      nms_result = box_list_ops.gather(boxlist_and_class_scores,
                                       selected_indices)
      # Make the scores -1 for invalid boxes.
      valid_nms_boxes_indx = tf.less(
          tf.range(max_selection_size), num_valid_nms_boxes)
      nms_scores = nms_result.get_field(fields.BoxListFields.scores)
      nms_result.add_field(fields.BoxListFields.scores,
                           tf.where(valid_nms_boxes_indx,
                                    nms_scores, -1*tf.ones(max_selection_size)))
      num_valid_nms_boxes_cumulative += num_valid_nms_boxes

      nms_result.add_field(
          fields.BoxListFields.classes, (tf.zeros_like(
              nms_result.get_field(fields.BoxListFields.scores)) + class_idx))
      selected_boxes_list.append(nms_result)
    selected_boxes = box_list_ops.concatenate(selected_boxes_list)
    sorted_boxes = box_list_ops.sort_by_field(selected_boxes,
                                              fields.BoxListFields.scores)
    if clip_window is not None:
      # When pad_to_max_output_size is False, it prunes the boxes with zero
      # area.
      sorted_boxes = box_list_ops.clip_to_window(
          sorted_boxes,
          clip_window,
          filter_nonoverlapping=not pad_to_max_output_size)
      # Set the scores of boxes with zero area to -1 to keep the default
      # behaviour of pruning out zero area boxes.
      sorted_boxes_size = tf.shape(sorted_boxes.get())[0]
      non_zero_box_area = tf.cast(box_list_ops.area(sorted_boxes), tf.bool)
      sorted_boxes_scores = tf.where(
          non_zero_box_area,
          sorted_boxes.get_field(fields.BoxListFields.scores),
          -1*tf.ones(sorted_boxes_size))
      sorted_boxes.add_field(fields.BoxListFields.scores, sorted_boxes_scores)
      num_valid_nms_boxes_cumulative = tf.reduce_sum(
          tf.cast(tf.greater_equal(sorted_boxes_scores, 0), tf.int32))
      sorted_boxes = box_list_ops.sort_by_field(sorted_boxes,
                                                fields.BoxListFields.scores)
      if change_coordinate_frame:
        sorted_boxes = box_list_ops.change_coordinate_frame(
            sorted_boxes, clip_window)

    if max_total_size:
      max_total_size = tf.minimum(max_total_size,
                                  sorted_boxes.num_boxes())
      sorted_boxes = box_list_ops.gather(sorted_boxes,
                                         tf.range(max_total_size))
      num_valid_nms_boxes_cumulative = tf.where(
          max_total_size > num_valid_nms_boxes_cumulative,
          num_valid_nms_boxes_cumulative, max_total_size)
    # Select only the valid boxes if pad_to_max_output_size is False.
    if not pad_to_max_output_size:
      sorted_boxes = box_list_ops.gather(
          sorted_boxes, tf.range(num_valid_nms_boxes_cumulative))

    return sorted_boxes, num_valid_nms_boxes_cumulative
Exemple #23
0
def multiclass_non_max_suppression(boxes,
                                   scores,
                                   score_thresh,
                                   iou_thresh,
                                   max_size_per_class,
                                   max_total_size=0,
                                   clip_window=None,
                                   change_coordinate_frame=False,
                                   masks=None,
                                   additional_fields=None,
                                   scope=None):
    """Multi-class version of non maximum suppression.

    This op greedily selects a subset of detection bounding boxes, pruning
    away boxes that have high IOU (intersection over union) overlap (> thresh)
    with already selected boxes.  It operates independently for each class for
    which scores are provided (one column of `scores` per class), pruning
    boxes with score less than a provided threshold prior to applying NMS.

    Please note that this operation is performed on *all* classes, therefore
    any background classes should be removed prior to calling this function.

    Args:
      boxes: A [k, q, 4] float32 tensor containing k detections. `q` can be
        either number of classes or 1 depending on whether a separate box is
        predicted per class.
      scores: A [k, num_classes] float32 tensor containing the scores for each
        of the k detections. The second dimension must be statically known.
      score_thresh: scalar threshold for score (low scoring boxes are
        removed).
      iou_thresh: scalar threshold for IOU (new boxes that have high IOU
        overlap with previously selected boxes are removed).
      max_size_per_class: maximum number of retained boxes per class.
      max_total_size: maximum number of boxes retained over all classes. The
        default of 0 disables the cap, so all boxes retained after capping
        boxes per class are returned.
      clip_window: A float32 tensor of the form [y_min, x_min, y_max, x_max]
        representing the window to clip and normalize boxes to before
        performing non-max suppression.
      change_coordinate_frame: Whether to normalize coordinates after clipping
        relative to clip_window (this can only be set to True if a clip_window
        is provided)
      masks: (optional) a [k, q, mask_height, mask_width] float32 tensor
        containing box masks. `q` can be either number of classes or 1
        depending on whether a separate mask is predicted per class.
      additional_fields: (optional) If not None, a dictionary that maps keys
        to tensors whose first dimensions are all of size `k`. After
        non-maximum suppression, all tensors corresponding to the selected
        boxes will be added to resulting BoxList.
      scope: name scope.

    Returns:
      a BoxList holding M boxes with a rank-1 scores field representing
        corresponding scores for each box with scores sorted in decreasing
        order and a rank-1 classes field representing a class label for each
        box. If `masks` or `additional_fields` are given, the entries that
        correspond to the selected boxes are carried as fields of the
        BoxList as well.

    Raises:
      ValueError: if iou_thresh is not in [0, 1], if scores is not of rank 2
        with a statically defined second dimension, if boxes is not of rank 3
        with shape [k, 1 or num_classes, 4], or if change_coordinate_frame is
        True while clip_window is None.
    """
    if not 0 <= iou_thresh <= 1.0:
        raise ValueError('iou_thresh must be between 0 and 1')
    if scores.shape.ndims != 2:
        raise ValueError('scores field must be of rank 2')
    if scores.shape[1].value is None:
        raise ValueError('scores must have statically defined second '
                         'dimension')
    if boxes.shape.ndims != 3:
        raise ValueError('boxes must be of rank 3.')
    if not (boxes.shape[1].value == scores.shape[1].value
            or boxes.shape[1].value == 1):
        raise ValueError('second dimension of boxes must be either 1 or equal '
                         'to the second dimension of scores')
    if boxes.shape[2].value != 4:
        raise ValueError('last dimension of boxes must be of size 4.')
    if change_coordinate_frame and clip_window is None:
        # The trailing space keeps the two literals from fusing into
        # "clip_windowmust".
        raise ValueError(
            'if change_coordinate_frame is True, then a clip_window '
            'must be specified.')

    with tf.name_scope(scope, 'MultiClassNonMaxSuppression'):
        num_boxes = tf.shape(boxes)[0]
        num_scores = tf.shape(scores)[0]
        # Plain Python int (checked to be non-None above), so that `range()`
        # and list repetition below do not depend on implicit coercion of a
        # Dimension object.
        num_classes = scores.shape[1].value

        length_assert = tf.Assert(tf.equal(num_boxes, num_scores), [
            'Incorrect scores field length: actual vs expected.', num_scores,
            num_boxes
        ])

        selected_boxes_list = []
        per_class_boxes_list = tf.unstack(boxes, axis=1)
        if masks is not None:
            per_class_masks_list = tf.unstack(masks, axis=1)
        # When a single box is shared by all classes (q == 1), every class
        # reads slot 0 of the unstacked boxes/masks.
        boxes_ids = (range(num_classes)
                     if len(per_class_boxes_list) > 1 else [0] * num_classes)
        for class_idx, boxes_idx in zip(range(num_classes), boxes_ids):
            per_class_boxes = per_class_boxes_list[boxes_idx]
            boxlist_and_class_scores = box_list.BoxList(per_class_boxes)
            # Tie the runtime length check to the score slice so it is
            # evaluated whenever the class scores are used.
            with tf.control_dependencies([length_assert]):
                class_scores = tf.reshape(
                    tf.slice(scores, [0, class_idx], tf.stack([num_scores,
                                                               1])), [-1])
            boxlist_and_class_scores.add_field(fields.BoxListFields.scores,
                                               class_scores)
            if masks is not None:
                per_class_masks = per_class_masks_list[boxes_idx]
                boxlist_and_class_scores.add_field(fields.BoxListFields.masks,
                                                   per_class_masks)
            if additional_fields is not None:
                for key, tensor in additional_fields.items():
                    boxlist_and_class_scores.add_field(key, tensor)
            boxlist_filtered = box_list_ops.filter_greater_than(
                boxlist_and_class_scores, score_thresh)
            if clip_window is not None:
                boxlist_filtered = box_list_ops.clip_to_window(
                    boxlist_filtered, clip_window)
                if change_coordinate_frame:
                    boxlist_filtered = box_list_ops.change_coordinate_frame(
                        boxlist_filtered, clip_window)
            max_selection_size = tf.minimum(max_size_per_class,
                                            boxlist_filtered.num_boxes())
            selected_indices = tf.image.non_max_suppression(
                boxlist_filtered.get(),
                boxlist_filtered.get_field(fields.BoxListFields.scores),
                max_selection_size,
                iou_threshold=iou_thresh)
            nms_result = box_list_ops.gather(boxlist_filtered,
                                             selected_indices)
            # Label every surviving box with the index of the current class.
            nms_result.add_field(
                fields.BoxListFields.classes,
                tf.zeros_like(
                    nms_result.get_field(fields.BoxListFields.scores)) +
                class_idx)
            selected_boxes_list.append(nms_result)
        selected_boxes = box_list_ops.concatenate(selected_boxes_list)
        sorted_boxes = box_list_ops.sort_by_field(selected_boxes,
                                                  fields.BoxListFields.scores)
        if max_total_size:
            # Keep only the leading entries of the score-sorted list.
            total_cap = tf.minimum(max_total_size, sorted_boxes.num_boxes())
            sorted_boxes = box_list_ops.gather(sorted_boxes,
                                               tf.range(total_cap))
        return sorted_boxes
def multiclass_non_max_suppression(boxes,
                                   scores,
                                   score_thresh,
                                   iou_thresh,
                                   max_size_per_class,
                                   max_total_size=0,
                                   clip_window=None,
                                   change_coordinate_frame=False,
                                   masks=None,
                                   additional_fields=None,
                                   scope=None):
  """Multi-class version of non maximum suppression.

  This op greedily selects a subset of detection bounding boxes, pruning
  away boxes that have high IOU (intersection over union) overlap (> thresh)
  with already selected boxes.  It operates independently for each class for
  which scores are provided (one column of `scores` per class), pruning boxes
  with score less than a provided threshold prior to applying NMS.

  Please note that this operation is performed on *all* classes, therefore any
  background classes should be removed prior to calling this function.

  Args:
    boxes: A [k, q, 4] float32 tensor containing k detections. `q` can be either
      number of classes or 1 depending on whether a separate box is predicted
      per class.
    scores: A [k, num_classes] float32 tensor containing the scores for each of
      the k detections. The second dimension must be statically known.
    score_thresh: scalar threshold for score (low scoring boxes are removed).
    iou_thresh: scalar threshold for IOU (new boxes that have high IOU overlap
      with previously selected boxes are removed).
    max_size_per_class: maximum number of retained boxes per class.
    max_total_size: maximum number of boxes retained over all classes. The
      default of 0 disables the cap, so all boxes retained after capping boxes
      per class are returned.
    clip_window: A float32 tensor of the form [y_min, x_min, y_max, x_max]
      representing the window to clip and normalize boxes to before performing
      non-max suppression.
    change_coordinate_frame: Whether to normalize coordinates after clipping
      relative to clip_window (this can only be set to True if a clip_window
      is provided)
    masks: (optional) a [k, q, mask_height, mask_width] float32 tensor
      containing box masks. `q` can be either number of classes or 1 depending
      on whether a separate mask is predicted per class.
    additional_fields: (optional) If not None, a dictionary that maps keys to
      tensors whose first dimensions are all of size `k`. After non-maximum
      suppression, all tensors corresponding to the selected boxes will be
      added to resulting BoxList.
    scope: name scope.

  Returns:
    a BoxList holding M boxes with a rank-1 scores field representing
      corresponding scores for each box with scores sorted in decreasing order
      and a rank-1 classes field representing a class label for each box.
      If `masks` or `additional_fields` are given, the entries that correspond
      to the selected boxes are carried as fields of the BoxList as well.

  Raises:
    ValueError: if iou_thresh is not in [0, 1], if scores is not of rank 2 with
      a statically defined second dimension, if boxes is not of rank 3 with
      shape [k, 1 or num_classes, 4], or if change_coordinate_frame is True
      while clip_window is None.
  """
  if not 0 <= iou_thresh <= 1.0:
    raise ValueError('iou_thresh must be between 0 and 1')
  if scores.shape.ndims != 2:
    raise ValueError('scores field must be of rank 2')
  if scores.shape[1].value is None:
    raise ValueError('scores must have statically defined second '
                     'dimension')
  if boxes.shape.ndims != 3:
    raise ValueError('boxes must be of rank 3.')
  if not (boxes.shape[1].value == scores.shape[1].value or
          boxes.shape[1].value == 1):
    raise ValueError('second dimension of boxes must be either 1 or equal '
                     'to the second dimension of scores')
  if boxes.shape[2].value != 4:
    raise ValueError('last dimension of boxes must be of size 4.')
  if change_coordinate_frame and clip_window is None:
    # The trailing space keeps the two literals from fusing into
    # "clip_windowmust".
    raise ValueError('if change_coordinate_frame is True, then a clip_window '
                     'must be specified.')

  with tf.name_scope(scope, 'MultiClassNonMaxSuppression'):
    num_boxes = tf.shape(boxes)[0]
    num_scores = tf.shape(scores)[0]
    # Plain Python int (checked to be non-None above), so that `range()` and
    # list repetition below do not depend on implicit coercion of a Dimension
    # object.
    num_classes = scores.shape[1].value

    length_assert = tf.Assert(
        tf.equal(num_boxes, num_scores),
        ['Incorrect scores field length: actual vs expected.',
         num_scores, num_boxes])

    selected_boxes_list = []
    per_class_boxes_list = tf.unstack(boxes, axis=1)
    if masks is not None:
      per_class_masks_list = tf.unstack(masks, axis=1)
    # When a single box is shared by all classes (q == 1), every class reads
    # slot 0 of the unstacked boxes/masks.
    boxes_ids = (range(num_classes) if len(per_class_boxes_list) > 1
                 else [0] * num_classes)
    for class_idx, boxes_idx in zip(range(num_classes), boxes_ids):
      per_class_boxes = per_class_boxes_list[boxes_idx]
      boxlist_and_class_scores = box_list.BoxList(per_class_boxes)
      # Tie the runtime length check to the score slice so it is evaluated
      # whenever the class scores are used.
      with tf.control_dependencies([length_assert]):
        class_scores = tf.reshape(
            tf.slice(scores, [0, class_idx], tf.stack([num_scores, 1])), [-1])
      boxlist_and_class_scores.add_field(fields.BoxListFields.scores,
                                         class_scores)
      if masks is not None:
        per_class_masks = per_class_masks_list[boxes_idx]
        boxlist_and_class_scores.add_field(fields.BoxListFields.masks,
                                           per_class_masks)
      if additional_fields is not None:
        for key, tensor in additional_fields.items():
          boxlist_and_class_scores.add_field(key, tensor)
      boxlist_filtered = box_list_ops.filter_greater_than(
          boxlist_and_class_scores, score_thresh)
      if clip_window is not None:
        boxlist_filtered = box_list_ops.clip_to_window(
            boxlist_filtered, clip_window)
        if change_coordinate_frame:
          boxlist_filtered = box_list_ops.change_coordinate_frame(
              boxlist_filtered, clip_window)
      max_selection_size = tf.minimum(max_size_per_class,
                                      boxlist_filtered.num_boxes())
      selected_indices = tf.image.non_max_suppression(
          boxlist_filtered.get(),
          boxlist_filtered.get_field(fields.BoxListFields.scores),
          max_selection_size,
          iou_threshold=iou_thresh)
      nms_result = box_list_ops.gather(boxlist_filtered, selected_indices)
      # Label every surviving box with the index of the current class.
      nms_result.add_field(
          fields.BoxListFields.classes, (tf.zeros_like(
              nms_result.get_field(fields.BoxListFields.scores)) + class_idx))
      selected_boxes_list.append(nms_result)
    selected_boxes = box_list_ops.concatenate(selected_boxes_list)
    sorted_boxes = box_list_ops.sort_by_field(selected_boxes,
                                              fields.BoxListFields.scores)
    if max_total_size:
      # Keep only the leading entries of the score-sorted list.
      total_cap = tf.minimum(max_total_size, sorted_boxes.num_boxes())
      sorted_boxes = box_list_ops.gather(sorted_boxes, tf.range(total_cap))
    return sorted_boxes
Exemple #25
0
def class_agnostic_non_max_suppression(boxes,
                                       boxes_3d,
                                       scores,
                                       score_thresh,
                                       iou_thresh,
                                       max_classes_per_detection=1,
                                       max_total_size=0,
                                       clip_window=None,
                                       change_coordinate_frame=False,
                                       boundaries=None,
                                       pad_to_max_output_size=False,
                                       additional_fields=None,
                                       scope=None):
    """Class-agnostic version of non maximum suppression.

    Runs a single NMS pass over all detections. Each detection is represented
    by its highest class score, and the index of that class becomes its label.
    When `boxes` carries one box per class, only the box belonging to the
    top-scoring class of each detection takes part.

    Because the op ignores class identity, any background classes should be
    removed prior to calling this function.

    The returned boxes are sorted by score via `box_list_ops.sort_by_field`.

    Args:
      boxes: A [k, q, 4] float32 tensor containing k detections. `q` can be
        either number of classes or 1 depending on whether a separate box is
        predicted per class.
      boxes_3d: A tensor of per-detection 3D box data. It is gathered with the
        same per-detection class index as `boxes` when q > 1 and squeezed on
        axis 1, so its leading dimensions are presumably [k, q] as well
        (confirm against callers).
      scores: A [k, num_classes] float32 tensor containing the scores for each
        of the k detections.
      score_thresh: scalar score threshold forwarded to the TensorFlow NMS op.
      iou_thresh: scalar IOU threshold forwarded to the TensorFlow NMS op.
      max_classes_per_detection: maximum number of retained classes per
        detection box. Only values <= 1 are accepted.
      max_total_size: maximum number of boxes to select. It is used both as
        the NMS selection size and as the final cap on the output.
      clip_window: (optional) A float32 tensor of the form
        [y_min, x_min, y_max, x_max]. When given, the sorted boxes are passed
        through `_clip_window_prune_boxes` together with this window.
      change_coordinate_frame: forwarded to `_clip_window_prune_boxes`; only
        meaningful when a clip_window is provided.
      boundaries: (optional) tensor of box boundaries, gathered and squeezed
        on axis 1 in the same way as `boxes`.
      pad_to_max_output_size: If true, the padded NMS op is used with
        `max_total_size` as the output length and invalid slots are kept with
        score -1. If false, only the valid boxes are returned.
      additional_fields: (optional) If not None, a dictionary that maps keys
        to tensors; each is attached as a field of the BoxList before NMS so
        that it is gathered along with the selected boxes.
      scope: name scope.

    Returns:
      A tuple `(sorted_boxes, num_valid_nms_boxes)`. `sorted_boxes` is a
        BoxList with scores, classes (float32) and boxes_3d fields, plus
        boundaries / additional fields when supplied. `num_valid_nms_boxes` is
        an integer tensor giving the number of valid entries, which appear
        first.

    Raises:
      ValueError: if max_classes_per_detection > 1. The argument checks
        delegated to `_validate_boxes_scores_iou_thresh` may raise as well.
    """
    _validate_boxes_scores_iou_thresh(boxes, scores, iou_thresh,
                                      change_coordinate_frame, clip_window)
    if max_classes_per_detection > 1:
        raise ValueError('Max classes per detection box >1 not supported.')

    has_boundaries = boundaries is not None
    if boxes.shape[1].value > 1:
        # One box per class: keep only the entry of the top-scoring class.
        top_class_column = tf.expand_dims(
            tf.argmax(scores, axis=1, output_type=tf.int32), axis=1)
        boxes = tf.batch_gather(boxes, top_class_column)
        boxes_3d = tf.batch_gather(boxes_3d, top_class_column)
        if has_boundaries:
            boundaries = tf.batch_gather(boundaries, top_class_column)
    # Drop the (now size-1) per-class axis.
    boxes = tf.squeeze(boxes, axis=[1])
    boxes_3d = tf.squeeze(boxes_3d, axis=[1])
    if has_boundaries:
        boundaries = tf.squeeze(boundaries, axis=[1])

    score_key = fields.BoxListFields.scores
    with tf.name_scope(scope, 'ClassAgnosticNonMaxSuppression'):
        candidates = box_list.BoxList(boxes)
        top_scores = tf.reduce_max(scores, axis=-1)
        top_classes = tf.argmax(scores, axis=-1)
        candidates.add_field(score_key, top_scores)
        candidates.add_field(fields.BoxListFields.boxes_3d, boxes_3d)
        if has_boundaries:
            candidates.add_field(fields.BoxListFields.boundaries, boundaries)
        if additional_fields is not None:
            for field_name, field_tensor in additional_fields.items():
                candidates.add_field(field_name, field_tensor)

        if pad_to_max_output_size:
            selection_size = max_total_size
            keep_indices, num_valid = tf.image.non_max_suppression_padded(
                candidates.get(),
                candidates.get_field(score_key),
                selection_size,
                iou_threshold=iou_thresh,
                score_threshold=score_thresh,
                pad_to_max_output_size=True)
        else:
            selection_size = tf.minimum(max_total_size,
                                        candidates.num_boxes())
            keep_indices = tf.image.non_max_suppression(
                candidates.get(),
                candidates.get_field(score_key),
                selection_size,
                iou_threshold=iou_thresh,
                score_threshold=score_thresh)
            num_valid = tf.shape(keep_indices)[0]
            # Zero-pad the index vector up to selection_size so both branches
            # hand the same-length selection to the code below.
            index_padding = tf.zeros(selection_size - num_valid, tf.int32)
            keep_indices = tf.concat([keep_indices, index_padding], 0)

        kept = box_list_ops.gather(candidates, keep_indices)

        # Slots past the valid count are padding; mark them with score -1.
        is_valid_slot = tf.less(tf.range(selection_size), num_valid)
        kept_scores = kept.get_field(score_key)
        kept.add_field(
            score_key,
            tf.where(is_valid_slot, kept_scores,
                     -1 * tf.ones(selection_size)))
        kept.add_field(
            fields.BoxListFields.classes,
            tf.cast(tf.gather(top_classes, keep_indices), tf.float32))

        sorted_boxes = box_list_ops.sort_by_field(kept, score_key)
        if clip_window is not None:
            # The helper returns the clipped boxes together with a refreshed
            # count of valid boxes.
            sorted_boxes, num_valid = _clip_window_prune_boxes(
                sorted_boxes, clip_window, pad_to_max_output_size,
                change_coordinate_frame)

        if max_total_size:
            total_cap = tf.minimum(max_total_size, sorted_boxes.num_boxes())
            sorted_boxes = box_list_ops.gather(sorted_boxes,
                                               tf.range(total_cap))
            num_valid = tf.where(total_cap > num_valid, num_valid, total_cap)

        # Without padding, return only the valid leading entries.
        if not pad_to_max_output_size:
            sorted_boxes = box_list_ops.gather(sorted_boxes,
                                               tf.range(num_valid))

        return sorted_boxes, num_valid