Ejemplo n.º 1
0
 def _scale_box_to_normalized_true_image(args):
   """Re-express boxes in a frame defined by the true image extent.

   Args:
     args: a 2-tuple `(boxes, true_image_shape)`. `boxes` is wrapped in a
       BoxList; the first two entries of `true_image_shape` are read as the
       true height and width.

   Returns:
     The box tensor obtained after changing the coordinate frame to the
     window `[0, 0, true_height / image_height, true_width / image_width]`.

   NOTE(review): `image_height` and `image_width` are not defined here; they
   are presumably captured from the enclosing scope -- confirm at the caller.
   """
   raw_boxes, shape = args
   shape_as_float = tf.cast(shape, tf.float32)
   height_in_pixels = shape_as_float[0]
   width_in_pixels = shape_as_float[1]
   # Window covering only the valid (unpadded) fraction of the image.
   window = tf.stack([
       0.0,
       0.0,
       height_in_pixels / image_height,
       width_in_pixels / image_width,
   ])
   rescaled = box_list_ops.change_coordinate_frame(
       box_list.BoxList(raw_boxes), window)
   return rescaled.get()
Ejemplo n.º 2
0
    def graph_fn():
      """Build the change_coordinate_frame graph and its expected result.

      Returns:
        A pair `(actual, expected)` of box tensors: the output of
        `box_list_ops.change_coordinate_frame` for two fixed boxes and a fixed
        window, and the hand-written corners the result is compared against.
      """
      frame = tf.constant([0.25, 0.25, 0.75, 0.75])
      input_boxes = box_list.BoxList(
          tf.constant([[0.25, 0.5, 0.75, 0.75], [0.5, 0.0, 1.0, 1.0]]))

      # Expected corners of the two boxes once expressed relative to `frame`.
      wanted = box_list.BoxList(
          tf.constant([[0, 0.5, 1.0, 1.0],
                       [0.5, -0.5, 1.5, 1.5]]))
      transformed = box_list_ops.change_coordinate_frame(input_boxes, frame)
      return transformed.get(), wanted.get()
Ejemplo n.º 3
0
  def test_change_coordinate_frame(self):
    """Checks change_coordinate_frame against hand-computed corners."""
    frame = tf.constant([0.25, 0.25, 0.75, 0.75])
    input_boxes = box_list.BoxList(
        tf.constant([[0.25, 0.5, 0.75, 0.75], [0.5, 0.0, 1.0, 1.0]]))

    # Corners the two boxes are expected to have relative to `frame`.
    wanted = box_list.BoxList(
        tf.constant([[0, 0.5, 1.0, 1.0], [0.5, -0.5, 1.5, 1.5]]))
    transformed = box_list_ops.change_coordinate_frame(input_boxes, frame)

    with self.test_session() as sess:
      fetches = [transformed.get(), wanted.get()]
      actual, reference = sess.run(fetches)
      self.assertAllClose(actual, reference)
Ejemplo n.º 4
0
  def test_change_coordinate_frame(self):
    """Verifies boxes re-framed to a window match the expected corners."""
    window_tensor = tf.constant([0.25, 0.25, 0.75, 0.75])
    source_corners = tf.constant(
        [[0.25, 0.5, 0.75, 0.75], [0.5, 0.0, 1.0, 1.0]])
    source_boxes = box_list.BoxList(source_corners)

    # Reference result written out by hand for the two input boxes.
    reference_corners = tf.constant(
        [[0, 0.5, 1.0, 1.0], [0.5, -0.5, 1.5, 1.5]])
    reference_boxes = box_list.BoxList(reference_corners)
    reframed = box_list_ops.change_coordinate_frame(source_boxes,
                                                    window_tensor)

    with self.test_session() as sess:
      got, want = sess.run([reframed.get(), reference_boxes.get()])
      self.assertAllClose(got, want)
Ejemplo n.º 5
0
def _clip_window_prune_boxes(sorted_boxes, clip_window, pad_to_max_output_size,
                             change_coordinate_frame):
    """Clip boxes to a window and invalidate the ones left with zero area.

    Boxes whose area is zero after clipping get their score replaced by -1,
    and the BoxList is then re-sorted by its scores field.

    Args:
      sorted_boxes: A BoxList containing k detections.
      clip_window: A float32 tensor of the form [y_min, x_min, y_max, x_max]
        representing the window to clip and normalize boxes to before
        performing non-max suppression.
      pad_to_max_output_size: flag indicating whether to pad to max output
        size or not. Its negation is passed to `clip_to_window` as
        `filter_nonoverlapping`.
      change_coordinate_frame: Whether to normalize coordinates after clipping
        relative to clip_window (this can only be set to True if a clip_window
        is provided).

    Returns:
      sorted_boxes: A BoxList containing k detections after pruning.
      num_valid_nms_boxes_cumulative: Number of boxes whose score is still
        non-negative after zero-area boxes were given a score of -1.
    """
    scores_key = fields.BoxListFields.scores

    clipped = box_list_ops.clip_to_window(
        sorted_boxes,
        clip_window,
        filter_nonoverlapping=not pad_to_max_output_size)

    # Zero-area boxes receive a score of -1, which preserves the default
    # behaviour of pruning them out.
    num_clipped = tf.shape(clipped.get())[0]
    has_area = tf.cast(box_list_ops.area(clipped), tf.bool)
    masked_scores = tf.where(has_area,
                             clipped.get_field(scores_key),
                             -1 * tf.ones(num_clipped))
    clipped.add_field(scores_key, masked_scores)

    still_valid = tf.greater_equal(masked_scores, 0)
    num_valid_nms_boxes_cumulative = tf.reduce_sum(
        tf.cast(still_valid, tf.int32))

    resorted = box_list_ops.sort_by_field(clipped, scores_key)
    if not change_coordinate_frame:
        return resorted, num_valid_nms_boxes_cumulative

    reframed = box_list_ops.change_coordinate_frame(resorted, clip_window)
    return reframed, num_valid_nms_boxes_cumulative
Ejemplo n.º 6
0
def transform_input_data(tensor_dict,
                         model_preprocess_fn,
                         image_resizer_fn,
                         num_classes,
                         data_augmentation_fn=None,
                         merge_multiple_boxes=False,
                         retain_original_image=False,
                         use_multiclass_scores=False,
                         use_bfloat16=False,
                         retain_original_image_additional_channels=False):
    """A single function that is responsible for all input data transformations.

    Data transformation functions are applied in the following order.
    1. If key fields.InputDataFields.image_additional_channels is present in
       tensor_dict, the additional channels will be merged into
       fields.InputDataFields.image.
    2. data_augmentation_fn (optional): applied on tensor_dict.
    3. model_preprocess_fn: applied only on image tensor in tensor_dict.
    4. image_resizer_fn: applied on original image and instance mask tensor in
       tensor_dict.
    5. one_hot_encoding: applied to classes tensor in tensor_dict.
    6. merge_multiple_boxes (optional): when groundtruth boxes are exactly the
       same they can be merged into a single box with an associated k-hot
       class label.

    The input dictionary itself is not modified; all changes are made on a
    shallow copy.

    Args:
      tensor_dict: dictionary containing input tensors keyed by
        fields.InputDataFields.
      model_preprocess_fn: model's preprocess function to apply on image
        tensor. This function must take in a 4-D float tensor and return a
        4-D preprocess float tensor and a tensor containing the true image
        shape.
      image_resizer_fn: image resizer function to apply on groundtruth
        instance masks. This function must take a 3-D float tensor of an
        image and a 3-D tensor of instance masks and return a resized version
        of these along with the true shapes.
      num_classes: number of max classes to one-hot (or k-hot) encode the
        class labels.
      data_augmentation_fn: (optional) data augmentation function to apply on
        input `tensor_dict`.
      merge_multiple_boxes: (optional) whether to merge multiple groundtruth
        boxes and classes for a given image if the boxes are exactly the same.
      retain_original_image: (optional) whether to retain original image in
        the output dictionary.
      use_multiclass_scores: whether to use multiclass scores as class targets
        instead of one-hot encoding of `groundtruth_classes`. When
        this is True and multiclass_scores is empty, one-hot encoding of
        `groundtruth_classes` is used as a fallback.
      use_bfloat16: (optional) a bool, whether to use bfloat16 in training.
      retain_original_image_additional_channels: (optional) Whether to retain
        original image additional channels in the output dictionary.

    Returns:
      A dictionary keyed by fields.InputDataFields containing the tensors
      obtained after applying all the transformations.
    """
    out_tensor_dict = tensor_dict.copy()
    if fields.InputDataFields.multiclass_scores in out_tensor_dict:
        out_tensor_dict[
            fields.InputDataFields.
            multiclass_scores] = _multiclass_scores_or_one_hot_labels(
                out_tensor_dict[fields.InputDataFields.multiclass_scores],
                out_tensor_dict[fields.InputDataFields.groundtruth_boxes],
                out_tensor_dict[fields.InputDataFields.groundtruth_classes],
                num_classes)

    if fields.InputDataFields.groundtruth_boxes in out_tensor_dict:
        out_tensor_dict = util_ops.filter_groundtruth_with_nan_box_coordinates(
            out_tensor_dict)
        out_tensor_dict = util_ops.filter_unrecognized_classes(out_tensor_dict)

    if retain_original_image:
        # Resized (but not augmented or preprocessed) copy of the input image.
        out_tensor_dict[fields.InputDataFields.original_image] = tf.cast(
            image_resizer_fn(out_tensor_dict[fields.InputDataFields.image],
                             None)[0], tf.uint8)

    if fields.InputDataFields.image_additional_channels in out_tensor_dict:
        channels = out_tensor_dict[
            fields.InputDataFields.image_additional_channels]
        # Append the extra channels to the image along the channel axis.
        out_tensor_dict[fields.InputDataFields.image] = tf.concat(
            [out_tensor_dict[fields.InputDataFields.image], channels], axis=2)
        if retain_original_image_additional_channels:
            out_tensor_dict[
                fields.InputDataFields.image_additional_channels] = tf.cast(
                    image_resizer_fn(channels, None)[0], tf.uint8)

    # Apply data augmentation ops.
    if data_augmentation_fn is not None:
        out_tensor_dict = data_augmentation_fn(out_tensor_dict)

    # Apply model preprocessing ops and resize instance masks.
    image = out_tensor_dict[fields.InputDataFields.image]
    preprocessed_resized_image, true_image_shape = model_preprocess_fn(
        tf.expand_dims(tf.cast(image, dtype=tf.float32), axis=0))

    preprocessed_shape = tf.shape(preprocessed_resized_image)
    new_height, new_width = preprocessed_shape[1], preprocessed_shape[2]

    # Window of the preprocessed image measured in units of the true image
    # size. tf.cast is used (as elsewhere in this function) instead of the
    # deprecated tf.to_float.
    im_box = tf.stack([
        0.0, 0.0,
        tf.cast(new_height, tf.float32) /
        tf.cast(true_image_shape[0, 0], tf.float32),
        tf.cast(new_width, tf.float32) /
        tf.cast(true_image_shape[0, 1], tf.float32)
    ])

    if fields.InputDataFields.groundtruth_boxes in tensor_dict:
        bboxes = out_tensor_dict[fields.InputDataFields.groundtruth_boxes]
        boxlist = box_list.BoxList(bboxes)
        realigned_bboxes = box_list_ops.change_coordinate_frame(
            boxlist, im_box)
        out_tensor_dict[
            fields.InputDataFields.groundtruth_boxes] = realigned_bboxes.get()

    if fields.InputDataFields.groundtruth_keypoints in tensor_dict:
        keypoints = out_tensor_dict[
            fields.InputDataFields.groundtruth_keypoints]
        realigned_keypoints = keypoint_ops.change_coordinate_frame(
            keypoints, im_box)
        out_tensor_dict[
            fields.InputDataFields.groundtruth_keypoints] = realigned_keypoints

    if use_bfloat16:
        preprocessed_resized_image = tf.cast(preprocessed_resized_image,
                                             tf.bfloat16)
    # Drop the batch dimension that was added for model_preprocess_fn.
    out_tensor_dict[fields.InputDataFields.image] = tf.squeeze(
        preprocessed_resized_image, axis=0)
    out_tensor_dict[fields.InputDataFields.true_image_shape] = tf.squeeze(
        true_image_shape, axis=0)
    if fields.InputDataFields.groundtruth_instance_masks in out_tensor_dict:
        masks = out_tensor_dict[
            fields.InputDataFields.groundtruth_instance_masks]
        _, resized_masks, _ = image_resizer_fn(image, masks)
        if use_bfloat16:
            resized_masks = tf.cast(resized_masks, tf.bfloat16)
        out_tensor_dict[
            fields.InputDataFields.groundtruth_instance_masks] = resized_masks

    # Class labels are shifted down by one before one-hot encoding.
    label_offset = 1
    zero_indexed_groundtruth_classes = out_tensor_dict[
        fields.InputDataFields.groundtruth_classes] - label_offset
    if use_multiclass_scores:
        out_tensor_dict[
            fields.InputDataFields.groundtruth_classes] = out_tensor_dict[
                fields.InputDataFields.multiclass_scores]
    else:
        out_tensor_dict[
            fields.InputDataFields.groundtruth_classes] = tf.one_hot(
                zero_indexed_groundtruth_classes, num_classes)
    # multiclass_scores is never part of the output dictionary.
    out_tensor_dict.pop(fields.InputDataFields.multiclass_scores, None)

    if fields.InputDataFields.groundtruth_confidences in out_tensor_dict:
        groundtruth_confidences = out_tensor_dict[
            fields.InputDataFields.groundtruth_confidences]
        # Map the confidences to the one-hot encoding of classes
        out_tensor_dict[fields.InputDataFields.groundtruth_confidences] = (
            tf.reshape(groundtruth_confidences, [-1, 1]) *
            out_tensor_dict[fields.InputDataFields.groundtruth_classes])
    else:
        # Without explicit confidences every box gets confidence 1, so the
        # encoded confidences equal the encoded classes.
        groundtruth_confidences = tf.ones_like(
            zero_indexed_groundtruth_classes, dtype=tf.float32)
        out_tensor_dict[fields.InputDataFields.groundtruth_confidences] = (
            out_tensor_dict[fields.InputDataFields.groundtruth_classes])

    if merge_multiple_boxes:
        merged_boxes, merged_classes, merged_confidences, _ = (
            util_ops.merge_boxes_with_multiple_labels(
                out_tensor_dict[fields.InputDataFields.groundtruth_boxes],
                zero_indexed_groundtruth_classes, groundtruth_confidences,
                num_classes))
        merged_classes = tf.cast(merged_classes, tf.float32)
        out_tensor_dict[
            fields.InputDataFields.groundtruth_boxes] = merged_boxes
        out_tensor_dict[
            fields.InputDataFields.groundtruth_classes] = merged_classes
        out_tensor_dict[fields.InputDataFields.groundtruth_confidences] = (
            merged_confidences)
    if fields.InputDataFields.groundtruth_boxes in out_tensor_dict:
        out_tensor_dict[
            fields.InputDataFields.num_groundtruth_boxes] = tf.shape(
                out_tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]

    return out_tensor_dict
Ejemplo n.º 7
0
def multiclass_non_max_suppression(boxes,
                                   scores,
                                   score_thresh,
                                   iou_thresh,
                                   max_size_per_class,
                                   max_total_size=0,
                                   clip_window=None,
                                   change_coordinate_frame=False,
                                   masks=None,
                                   additional_fields=None,
                                   scope=None):
    """Multi-class version of non maximum suppression.

    This op greedily selects a subset of detection bounding boxes, pruning
    away boxes that have high IOU (intersection over union) overlap (> thresh)
    with already selected boxes.  It operates independently for each class for
    which scores are provided, pruning boxes with score less than a provided
    threshold prior to applying NMS.

    Please note that this operation is performed on *all* classes, therefore
    any background classes should be removed prior to calling this function.

    Args:
      boxes: A [k, q, 4] float32 tensor containing k detections. `q` can be
        either number of classes or 1 depending on whether a separate box is
        predicted per class.
      scores: A [k, num_classes] float32 tensor containing the scores for each
        of the k detections.
      score_thresh: scalar threshold for score (low scoring boxes are
        removed).
      iou_thresh: scalar threshold for IOU (new boxes that have high IOU
        overlap with previously selected boxes are removed).
      max_size_per_class: maximum number of retained boxes per class.
      max_total_size: maximum number of boxes retained over all classes. By
        default (0) returns all boxes retained after capping boxes per class.
      clip_window: A float32 tensor of the form [y_min, x_min, y_max, x_max]
        representing the window to clip and normalize boxes to before
        performing non-max suppression.
      change_coordinate_frame: Whether to normalize coordinates after clipping
        relative to clip_window (this can only be set to True if a clip_window
        is provided)
      masks: (optional) a [k, q, mask_height, mask_width] float32 tensor
        containing box masks. `q` can be either number of classes or 1
        depending on whether a separate mask is predicted per class.
      additional_fields: (optional) If not None, a dictionary that maps keys
        to tensors whose first dimensions are all of size `k`. Each tensor is
        added as a field of the per-class BoxList before filtering.
      scope: name scope.

    Returns:
      a BoxList holding M boxes with a rank-1 scores field representing
        corresponding scores for each box, sorted by the scores field, and a
        rank-1 classes field representing a class label for each box. Fields
        added from `masks` and `additional_fields` are attached to the
        per-class BoxLists before selection.

    Raises:
      ValueError: if iou_thresh is not in [0, 1], if scores is not rank 2 or
        its second dimension is not statically defined, if boxes is not a
        rank-3 tensor whose second dimension is 1 or matches scores and whose
        last dimension is 4, or if change_coordinate_frame is True without a
        clip_window.
    """
    if not 0 <= iou_thresh <= 1.0:
        raise ValueError('iou_thresh must be between 0 and 1')
    if scores.shape.ndims != 2:
        raise ValueError('scores field must be of rank 2')
    if scores.shape[1].value is None:
        raise ValueError('scores must have statically defined second '
                         'dimension')
    if boxes.shape.ndims != 3:
        raise ValueError('boxes must be of rank 3.')
    if not (boxes.shape[1].value == scores.shape[1].value
            or boxes.shape[1].value == 1):
        raise ValueError('second dimension of boxes must be either 1 or equal '
                         'to the second dimension of scores')
    if boxes.shape[2].value != 4:
        raise ValueError('last dimension of boxes must be of size 4.')
    if change_coordinate_frame and clip_window is None:
        # The trailing space keeps the two concatenated literals from running
        # together as "clip_windowmust".
        raise ValueError(
            'if change_coordinate_frame is True, then a clip_window '
            'must be specified.')

    with tf.name_scope(scope, 'MultiClassNonMaxSuppression'):
        num_boxes = tf.shape(boxes)[0]
        num_scores = tf.shape(scores)[0]
        num_classes = scores.get_shape()[1]

        # Runtime check that boxes and scores describe the same detections.
        length_assert = tf.Assert(tf.equal(num_boxes, num_scores), [
            'Incorrect scores field length: actual vs expected.', num_scores,
            num_boxes
        ])

        selected_boxes_list = []
        per_class_boxes_list = tf.unstack(boxes, axis=1)
        if masks is not None:
            per_class_masks_list = tf.unstack(masks, axis=1)
        # When a single box set is shared by all classes (q == 1), every
        # class reads index 0; otherwise each class has its own slice.
        boxes_ids = (range(num_classes)
                     if len(per_class_boxes_list) > 1 else [0] * num_classes)
        for class_idx, boxes_idx in zip(range(num_classes), boxes_ids):
            per_class_boxes = per_class_boxes_list[boxes_idx]
            boxlist_and_class_scores = box_list.BoxList(per_class_boxes)
            with tf.control_dependencies([length_assert]):
                # Column `class_idx` of scores, flattened to rank 1.
                class_scores = tf.reshape(
                    tf.slice(scores, [0, class_idx], tf.stack([num_scores,
                                                               1])), [-1])
            boxlist_and_class_scores.add_field(fields.BoxListFields.scores,
                                               class_scores)
            if masks is not None:
                per_class_masks = per_class_masks_list[boxes_idx]
                boxlist_and_class_scores.add_field(fields.BoxListFields.masks,
                                                   per_class_masks)
            if additional_fields is not None:
                for key, tensor in additional_fields.items():
                    boxlist_and_class_scores.add_field(key, tensor)
            boxlist_filtered = box_list_ops.filter_greater_than(
                boxlist_and_class_scores, score_thresh)
            if clip_window is not None:
                boxlist_filtered = box_list_ops.clip_to_window(
                    boxlist_filtered, clip_window)
                if change_coordinate_frame:
                    boxlist_filtered = box_list_ops.change_coordinate_frame(
                        boxlist_filtered, clip_window)
            # Never request more boxes than survived the filtering above.
            max_selection_size = tf.minimum(max_size_per_class,
                                            boxlist_filtered.num_boxes())
            selected_indices = tf.image.non_max_suppression(
                boxlist_filtered.get(),
                boxlist_filtered.get_field(fields.BoxListFields.scores),
                max_selection_size,
                iou_threshold=iou_thresh)
            nms_result = box_list_ops.gather(boxlist_filtered,
                                             selected_indices)
            # Label every selected box with the class it was selected for.
            nms_result.add_field(fields.BoxListFields.classes, (tf.zeros_like(
                nms_result.get_field(fields.BoxListFields.scores)) +
                                                                class_idx))
            selected_boxes_list.append(nms_result)
        selected_boxes = box_list_ops.concatenate(selected_boxes_list)
        sorted_boxes = box_list_ops.sort_by_field(selected_boxes,
                                                  fields.BoxListFields.scores)
        if max_total_size:
            # Keep only the leading `max_total_size` boxes of the sorted list.
            max_total_size = tf.minimum(max_total_size,
                                        sorted_boxes.num_boxes())
            sorted_boxes = box_list_ops.gather(sorted_boxes,
                                               tf.range(max_total_size))
        return sorted_boxes
def multiclass_non_max_suppression(boxes,
                                   scores,
                                   score_thresh,
                                   iou_thresh,
                                   max_size_per_class,
                                   max_total_size=0,
                                   clip_window=None,
                                   change_coordinate_frame=False,
                                   masks=None,
                                   additional_fields=None,
                                   scope=None):
  """Multi-class version of non maximum suppression.

  This op greedily selects a subset of detection bounding boxes, pruning
  away boxes that have high IOU (intersection over union) overlap (> thresh)
  with already selected boxes.  It operates independently for each class for
  which scores are provided, pruning boxes with score less than a provided
  threshold prior to applying NMS.

  Please note that this operation is performed on *all* classes, therefore any
  background classes should be removed prior to calling this function.

  Args:
    boxes: A [k, q, 4] float32 tensor containing k detections. `q` can be either
      number of classes or 1 depending on whether a separate box is predicted
      per class.
    scores: A [k, num_classes] float32 tensor containing the scores for each of
      the k detections.
    score_thresh: scalar threshold for score (low scoring boxes are removed).
    iou_thresh: scalar threshold for IOU (new boxes that have high IOU overlap
      with previously selected boxes are removed).
    max_size_per_class: maximum number of retained boxes per class.
    max_total_size: maximum number of boxes retained over all classes. By
      default (0) returns all boxes retained after capping boxes per class.
    clip_window: A float32 tensor of the form [y_min, x_min, y_max, x_max]
      representing the window to clip and normalize boxes to before performing
      non-max suppression.
    change_coordinate_frame: Whether to normalize coordinates after clipping
      relative to clip_window (this can only be set to True if a clip_window
      is provided)
    masks: (optional) a [k, q, mask_height, mask_width] float32 tensor
      containing box masks. `q` can be either number of classes or 1 depending
      on whether a separate mask is predicted per class.
    additional_fields: (optional) If not None, a dictionary that maps keys to
      tensors whose first dimensions are all of size `k`. Each tensor is added
      as a field of the per-class BoxList before filtering.
    scope: name scope.

  Returns:
    a BoxList holding M boxes with a rank-1 scores field representing
      corresponding scores for each box, sorted by the scores field, and a
      rank-1 classes field representing a class label for each box. Fields
      added from `masks` and `additional_fields` are attached to the per-class
      BoxLists before selection.

  Raises:
    ValueError: if iou_thresh is not in [0, 1], if scores is not rank 2 or its
      second dimension is not statically defined, if boxes is not a rank-3
      tensor whose second dimension is 1 or matches scores and whose last
      dimension is 4, or if change_coordinate_frame is True without a
      clip_window.
  """
  if not 0 <= iou_thresh <= 1.0:
    raise ValueError('iou_thresh must be between 0 and 1')
  if scores.shape.ndims != 2:
    raise ValueError('scores field must be of rank 2')
  if scores.shape[1].value is None:
    raise ValueError('scores must have statically defined second '
                     'dimension')
  if boxes.shape.ndims != 3:
    raise ValueError('boxes must be of rank 3.')
  if not (boxes.shape[1].value == scores.shape[1].value or
          boxes.shape[1].value == 1):
    raise ValueError('second dimension of boxes must be either 1 or equal '
                     'to the second dimension of scores')
  if boxes.shape[2].value != 4:
    raise ValueError('last dimension of boxes must be of size 4.')
  if change_coordinate_frame and clip_window is None:
    # The trailing space keeps the two concatenated literals from running
    # together as "clip_windowmust".
    raise ValueError('if change_coordinate_frame is True, then a clip_window '
                     'must be specified.')

  with tf.name_scope(scope, 'MultiClassNonMaxSuppression'):
    num_boxes = tf.shape(boxes)[0]
    num_scores = tf.shape(scores)[0]
    num_classes = scores.get_shape()[1]

    # Runtime check that boxes and scores describe the same detections.
    length_assert = tf.Assert(
        tf.equal(num_boxes, num_scores),
        ['Incorrect scores field length: actual vs expected.',
         num_scores, num_boxes])

    selected_boxes_list = []
    per_class_boxes_list = tf.unstack(boxes, axis=1)
    if masks is not None:
      per_class_masks_list = tf.unstack(masks, axis=1)
    # When a single box set is shared by all classes (q == 1), every class
    # reads index 0; otherwise each class has its own slice.
    boxes_ids = (range(num_classes) if len(per_class_boxes_list) > 1
                 else [0] * num_classes)
    for class_idx, boxes_idx in zip(range(num_classes), boxes_ids):
      per_class_boxes = per_class_boxes_list[boxes_idx]
      boxlist_and_class_scores = box_list.BoxList(per_class_boxes)
      with tf.control_dependencies([length_assert]):
        # Column `class_idx` of scores, flattened to rank 1.
        class_scores = tf.reshape(
            tf.slice(scores, [0, class_idx], tf.stack([num_scores, 1])), [-1])
      boxlist_and_class_scores.add_field(fields.BoxListFields.scores,
                                         class_scores)
      if masks is not None:
        per_class_masks = per_class_masks_list[boxes_idx]
        boxlist_and_class_scores.add_field(fields.BoxListFields.masks,
                                           per_class_masks)
      if additional_fields is not None:
        for key, tensor in additional_fields.items():
          boxlist_and_class_scores.add_field(key, tensor)
      boxlist_filtered = box_list_ops.filter_greater_than(
          boxlist_and_class_scores, score_thresh)
      if clip_window is not None:
        boxlist_filtered = box_list_ops.clip_to_window(
            boxlist_filtered, clip_window)
        if change_coordinate_frame:
          boxlist_filtered = box_list_ops.change_coordinate_frame(
              boxlist_filtered, clip_window)
      # Never request more boxes than survived the filtering above.
      max_selection_size = tf.minimum(max_size_per_class,
                                      boxlist_filtered.num_boxes())
      selected_indices = tf.image.non_max_suppression(
          boxlist_filtered.get(),
          boxlist_filtered.get_field(fields.BoxListFields.scores),
          max_selection_size,
          iou_threshold=iou_thresh)
      nms_result = box_list_ops.gather(boxlist_filtered, selected_indices)
      # Label every selected box with the class it was selected for.
      nms_result.add_field(
          fields.BoxListFields.classes, (tf.zeros_like(
              nms_result.get_field(fields.BoxListFields.scores)) + class_idx))
      selected_boxes_list.append(nms_result)
    selected_boxes = box_list_ops.concatenate(selected_boxes_list)
    sorted_boxes = box_list_ops.sort_by_field(selected_boxes,
                                              fields.BoxListFields.scores)
    if max_total_size:
      # Keep only the leading `max_total_size` boxes of the sorted list.
      max_total_size = tf.minimum(max_total_size,
                                  sorted_boxes.num_boxes())
      sorted_boxes = box_list_ops.gather(sorted_boxes,
                                         tf.range(max_total_size))
    return sorted_boxes
def multiclass_non_max_suppression(boxes,
                                   scores,
                                   score_thresh,
                                   iou_thresh,
                                   max_size_per_class,
                                   max_total_size=0,
                                   clip_window=None,
                                   change_coordinate_frame=False,
                                   masks=None,
                                   boundaries=None,
                                   pad_to_max_output_size=False,
                                   additional_fields=None,
                                   scope=None):
    """Multi-class version of non maximum suppression.

    This op greedily selects a subset of detection bounding boxes, pruning
    away boxes that have high IOU (intersection over union) overlap (> thresh)
    with already selected boxes.  It operates independently for each class for
    which scores are provided (via the scores field of the input box_list),
    pruning boxes with score less than a provided threshold prior to
    applying NMS.

    Please note that this operation is performed on *all* classes, therefore
    any background classes should be removed prior to calling this function.

    Selected boxes are guaranteed to be sorted in decreasing order by score
    (but the sort is not guaranteed to be stable).

    Args:
      boxes: A [k, q, 4] float32 tensor containing k detections. `q` can be
        either number of classes or 1 depending on whether a separate box is
        predicted per class.
      scores: A [k, num_classes] float32 tensor containing the scores for each
        of the k detections. The scores have to be non-negative when
        pad_to_max_output_size is True.
      score_thresh: scalar threshold for score (low scoring boxes are removed).
      iou_thresh: scalar threshold for IOU (new boxes that have high IOU
        overlap with previously selected boxes are removed).
      max_size_per_class: maximum number of retained boxes per class.
      max_total_size: maximum number of boxes retained over all classes. By
        default returns all boxes retained after capping boxes per class.
      clip_window: A float32 tensor of the form [y_min, x_min, y_max, x_max]
        representing the window to clip the selected boxes to. Clipping is
        applied after the per-class suppression results are merged.
      change_coordinate_frame: Whether to normalize coordinates after clipping
        relative to clip_window (this can only be set to True if a clip_window
        is provided).
      masks: (optional) a [k, q, mask_height, mask_width] float32 tensor
        containing box masks. `q` can be either number of classes or 1
        depending on whether a separate mask is predicted per class.
      boundaries: (optional) a [k, q, boundary_height, boundary_width] float32
        tensor containing box boundaries. `q` can be either number of classes
        or 1 depending on whether a separate boundary is predicted per class.
      pad_to_max_output_size: If true, the output nmsed boxes are padded to be
        of length `max_size_per_class`. Defaults to false.
      additional_fields: (optional) If not None, a dictionary that maps keys to
        tensors whose first dimensions are all of size `k`. After non-maximum
        suppression, all tensors corresponding to the selected boxes will be
        added to resulting BoxList.
      scope: name scope.

    Returns:
      A tuple of sorted_boxes and num_valid_nms_boxes. The sorted_boxes is a
        BoxList holds M boxes with a rank-1 scores field representing
        corresponding scores for each box with scores sorted in decreasing
        order and a rank-1 classes field representing a class label for each
        box. The num_valid_nms_boxes is a 0-D integer tensor representing the
        number of valid elements in `BoxList`, with the valid elements
        appearing first.

    Raises:
      ValueError: if iou_thresh is not in [0, 1]; if scores is not of rank 2
        or its second dimension is not statically defined; if boxes is not of
        rank 3, its second dimension is neither 1 nor the number of classes,
        or its last dimension is not 4; or if change_coordinate_frame is True
        while clip_window is None.
    """
    if not 0 <= iou_thresh <= 1.0:
        raise ValueError('iou_thresh must be between 0 and 1')
    if scores.shape.ndims != 2:
        raise ValueError('scores field must be of rank 2')
    if scores.shape[1].value is None:
        raise ValueError('scores must have statically defined second '
                         'dimension')
    if boxes.shape.ndims != 3:
        raise ValueError('boxes must be of rank 3.')
    if not (boxes.shape[1].value == scores.shape[1].value
            or boxes.shape[1].value == 1):
        raise ValueError('second dimension of boxes must be either 1 or equal '
                         'to the second dimension of scores')
    if boxes.shape[2].value != 4:
        raise ValueError('last dimension of boxes must be of size 4.')
    if change_coordinate_frame and clip_window is None:
        # The two literals are concatenated, so the first one must end with a
        # space to keep the words of the message separated.
        raise ValueError(
            'if change_coordinate_frame is True, then a clip_window '
            'must be specified.')

    with tf.name_scope(scope, 'MultiClassNonMaxSuppression'):
        num_scores = tf.shape(scores)[0]
        num_classes = scores.get_shape()[1]

        selected_boxes_list = []
        num_valid_nms_boxes_cumulative = tf.constant(0)
        per_class_boxes_list = tf.unstack(boxes, axis=1)
        if masks is not None:
            per_class_masks_list = tf.unstack(masks, axis=1)
        if boundaries is not None:
            per_class_boundaries_list = tf.unstack(boundaries, axis=1)
        # When a single box set is shared by all classes (q == 1), every class
        # reads entry 0; otherwise each class reads its own entry.
        boxes_ids = (range(num_classes) if len(per_class_boxes_list) > 1 else
                     [0] * num_classes.value)
        for class_idx, boxes_idx in zip(range(num_classes), boxes_ids):
            per_class_boxes = per_class_boxes_list[boxes_idx]
            boxlist_and_class_scores = box_list.BoxList(per_class_boxes)
            # Column `class_idx` of `scores`, flattened to rank 1.
            class_scores = tf.reshape(
                tf.slice(scores, [0, class_idx], tf.stack([num_scores, 1])),
                [-1])

            boxlist_and_class_scores.add_field(fields.BoxListFields.scores,
                                               class_scores)
            if masks is not None:
                per_class_masks = per_class_masks_list[boxes_idx]
                boxlist_and_class_scores.add_field(fields.BoxListFields.masks,
                                                   per_class_masks)
            if boundaries is not None:
                per_class_boundaries = per_class_boundaries_list[boxes_idx]
                boxlist_and_class_scores.add_field(
                    fields.BoxListFields.boundaries, per_class_boundaries)
            if additional_fields is not None:
                for key, tensor in additional_fields.items():
                    boxlist_and_class_scores.add_field(key, tensor)

            if pad_to_max_output_size:
                max_selection_size = max_size_per_class
                selected_indices, num_valid_nms_boxes = (
                    tf.image.non_max_suppression_padded(
                        boxlist_and_class_scores.get(),
                        boxlist_and_class_scores.get_field(
                            fields.BoxListFields.scores),
                        max_selection_size,
                        iou_threshold=iou_thresh,
                        score_threshold=score_thresh,
                        pad_to_max_output_size=True))
            else:
                max_selection_size = tf.minimum(
                    max_size_per_class, boxlist_and_class_scores.num_boxes())
                selected_indices = tf.image.non_max_suppression(
                    boxlist_and_class_scores.get(),
                    boxlist_and_class_scores.get_field(
                        fields.BoxListFields.scores),
                    max_selection_size,
                    iou_threshold=iou_thresh,
                    score_threshold=score_thresh)
                num_valid_nms_boxes = tf.shape(selected_indices)[0]
                # Pad the indices with zeros up to max_selection_size so that
                # both branches yield the same length; the padded entries are
                # marked invalid through their score below.
                selected_indices = tf.concat([
                    selected_indices,
                    tf.zeros(max_selection_size - num_valid_nms_boxes,
                             tf.int32)
                ], 0)
            nms_result = box_list_ops.gather(boxlist_and_class_scores,
                                             selected_indices)
            # Make the scores -1 for invalid boxes.
            valid_nms_boxes_indx = tf.less(tf.range(max_selection_size),
                                           num_valid_nms_boxes)
            nms_scores = nms_result.get_field(fields.BoxListFields.scores)
            nms_result.add_field(
                fields.BoxListFields.scores,
                tf.where(valid_nms_boxes_indx, nms_scores,
                         -1 * tf.ones(max_selection_size)))
            num_valid_nms_boxes_cumulative += num_valid_nms_boxes

            # Label every selected box with the index of the current class.
            nms_result.add_field(fields.BoxListFields.classes, (tf.zeros_like(
                nms_result.get_field(fields.BoxListFields.scores)) +
                                                                class_idx))
            selected_boxes_list.append(nms_result)
        selected_boxes = box_list_ops.concatenate(selected_boxes_list)
        sorted_boxes = box_list_ops.sort_by_field(selected_boxes,
                                                  fields.BoxListFields.scores)
        if clip_window is not None:
            # When pad_to_max_output_size is False, it prunes the boxes with
            # zero area.
            sorted_boxes = box_list_ops.clip_to_window(
                sorted_boxes,
                clip_window,
                filter_nonoverlapping=not pad_to_max_output_size)
            # Set the scores of boxes with zero area to -1 to keep the default
            # behaviour of pruning out zero area boxes.
            sorted_boxes_size = tf.shape(sorted_boxes.get())[0]
            non_zero_box_area = tf.cast(box_list_ops.area(sorted_boxes),
                                        tf.bool)
            sorted_boxes_scores = tf.where(
                non_zero_box_area,
                sorted_boxes.get_field(fields.BoxListFields.scores),
                -1 * tf.ones(sorted_boxes_size))
            sorted_boxes.add_field(fields.BoxListFields.scores,
                                   sorted_boxes_scores)
            # Recount the valid boxes: only entries with a non-negative score
            # remain valid after clipping.
            num_valid_nms_boxes_cumulative = tf.reduce_sum(
                tf.cast(tf.greater_equal(sorted_boxes_scores, 0), tf.int32))
            sorted_boxes = box_list_ops.sort_by_field(
                sorted_boxes, fields.BoxListFields.scores)
            if change_coordinate_frame:
                sorted_boxes = box_list_ops.change_coordinate_frame(
                    sorted_boxes, clip_window)

        if max_total_size:
            max_total_size = tf.minimum(max_total_size,
                                        sorted_boxes.num_boxes())
            sorted_boxes = box_list_ops.gather(sorted_boxes,
                                               tf.range(max_total_size))
            # The valid count can never exceed the number of boxes kept.
            num_valid_nms_boxes_cumulative = tf.where(
                max_total_size > num_valid_nms_boxes_cumulative,
                num_valid_nms_boxes_cumulative, max_total_size)
        # Select only the valid boxes if pad_to_max_output_size is False.
        if not pad_to_max_output_size:
            sorted_boxes = box_list_ops.gather(
                sorted_boxes, tf.range(num_valid_nms_boxes_cumulative))

        return sorted_boxes, num_valid_nms_boxes_cumulative
Ejemplo n.º 10
0
def random_crop_to_aspect_ratio(image,
                                boxes,
                                labels,
                                difficult=None,
                                aspect_ratio=21. / 9.,
                                overlap_thresh=0.3):
    """Randomly crop an image to a target aspect ratio and realign its boxes.

    The crop keeps the full extent of one image dimension and shrinks the
    other so that width / height matches `aspect_ratio`; its offset along the
    shrunk dimension is drawn uniformly at random. Boxes are pruned against
    the crop window, re-expressed relative to it and clipped to [0, 1].

    Args:
      image: image tensor whose first two dimensions are height and width.
      boxes: box tensor wrapped in a `BoxList`; presumably [N, 4] normalized
        [y_min, x_min, y_max, x_max] coordinates, matching the form in which
        the crop window is built here (confirm against callers).
      labels: per-box labels, carried along as the 'labels' field.
      difficult: (optional) per-box values carried along as the 'difficult'
        field.
      aspect_ratio: target width / height ratio of the crop.
      overlap_thresh: boxes whose overlap with the crop window is below this
        threshold are removed.

    Returns:
      A tuple (new_image, new_boxes, new_labels), with new_difficult appended
      as a fourth element when `difficult` is not None. Labels and difficult
      values are read from the same final box list as the boxes, so the
      outputs stay aligned even if clipping drops boxes.
    """
    with tf.name_scope('RandomCropToAspectRatio', values=[image]):
        image_shape = tf.shape(image)
        orig_height = image_shape[0]
        orig_width = image_shape[1]
        orig_aspect_ratio = tf.to_float(orig_width) / tf.to_float(orig_height)
        target_aspect_ratio = tf.constant(aspect_ratio, dtype=tf.float32)

        def target_height_fn():
            return tf.to_int32(
                tf.round(tf.to_float(orig_width) / target_aspect_ratio))

        # An image at least as wide as the target keeps its full height.
        target_height = tf.cond(orig_aspect_ratio >= target_aspect_ratio,
                                lambda: orig_height, target_height_fn)

        def target_width_fn():
            return tf.to_int32(
                tf.round(tf.to_float(orig_height) * target_aspect_ratio))

        # An image at most as wide as the target keeps its full width.
        target_width = tf.cond(orig_aspect_ratio <= target_aspect_ratio,
                               lambda: orig_width, target_width_fn)

        # `maxval` is exclusive, hence the + 1 to allow the largest offset.
        offset_height = tf.random_uniform([],
                                          minval=0,
                                          maxval=orig_height - target_height +
                                          1,
                                          dtype=tf.int32)
        offset_width = tf.random_uniform([],
                                         minval=0,
                                         maxval=orig_width - target_width + 1,
                                         dtype=tf.int32)

        new_image = tf.image.crop_to_bounding_box(image, offset_height,
                                                  offset_width, target_height,
                                                  target_width)

        # Crop window as fractions of the original image size, ordered
        # [y_min, x_min, y_max, x_max].
        im_box = tf.stack([
            tf.to_float(offset_height) / tf.to_float(orig_height),
            tf.to_float(offset_width) / tf.to_float(orig_width),
            tf.to_float(offset_height + target_height) /
            tf.to_float(orig_height),
            tf.to_float(offset_width + target_width) / tf.to_float(orig_width)
        ])

        boxlist = box_list.BoxList(boxes)
        boxlist.add_field('labels', labels)

        if difficult is not None:
            boxlist.add_field('difficult', difficult)

        im_boxlist = box_list.BoxList(tf.expand_dims(im_box, axis=0))

        # remove boxes whose overlap with the image is less than overlap_thresh
        # (the indices of the kept boxes are not needed here)
        overlapping_boxlist, _ = box_list_ops.prune_non_overlapping_boxes(
            boxlist, im_boxlist, overlap_thresh)

        # change the coordinate of the remaining boxes
        new_boxlist = box_list_ops.change_coordinate_frame(
            overlapping_boxlist, im_box)
        new_boxlist = box_list_ops.clip_to_window(
            new_boxlist, tf.constant([0.0, 0.0, 1.0, 1.0], tf.float32))
        new_boxes = new_boxlist.get()
        # Read the labels from the final box list, like 'difficult' below:
        # clipping may drop boxes, and labels taken before it would then no
        # longer line up with new_boxes.
        new_labels = new_boxlist.get_field('labels')

        result = [new_image, new_boxes, new_labels]

        if difficult is not None:
            new_difficult = new_boxlist.get_field('difficult')
            result.append(new_difficult)

        return tuple(result)
Ejemplo n.º 11
0
def transform_input_data(tensor_dict,
                         model_preprocess_fn,
                         image_resizer_fn,
                         num_classes,
                         data_augmentation_fn=None,
                         merge_multiple_boxes=False,
                         retain_original_image=False,
                         use_multiclass_scores=False,
                         use_bfloat16=False,
                         retain_original_image_additional_channels=False,
                         keypoint_type_weight=None):
    """Apply every input data transformation to one example's tensors.

    The transformations run in this order:
    1. If fields.InputDataFields.image_additional_channels is present in
       tensor_dict, those channels are concatenated onto
       fields.InputDataFields.image.
    2. data_augmentation_fn (optional) is applied to the whole dictionary.
    3. model_preprocess_fn is applied to the image tensor only.
    4. If groundtruth keypoints are present, per-keypoint weights are derived
       from the keypoint visibilities and `keypoint_type_weight`; when no
       visibilities are given, all keypoints are treated as visible.
    5. image_resizer_fn is applied to the image and instance masks.
    6. The class labels are one-hot encoded.
    7. merge_multiple_boxes (optional): identical groundtruth boxes are merged
       into a single box with a k-hot class label.

    Args:
      tensor_dict: dictionary of input tensors keyed by
        fields.InputDataFields. It is copied, not modified.
      model_preprocess_fn: the model's preprocess function. It receives a 4-D
        float tensor and must return a 4-D preprocessed float tensor together
        with a tensor holding the true image shape.
      image_resizer_fn: image resizer function used for the groundtruth
        instance masks and the retained original images. It receives an image
        tensor and a masks tensor (or None) and returns the resized versions
        along with the true shapes.
      num_classes: number of classes used for the one-hot (or k-hot) encoding
        of the class labels.
      data_augmentation_fn: (optional) data augmentation function applied to
        the tensor dictionary.
      merge_multiple_boxes: (optional) whether to merge groundtruth boxes (and
        their classes) that are exactly the same.
      retain_original_image: (optional) whether to keep a resized uint8 copy
        of the original image in the output dictionary.
      use_multiclass_scores: whether to use the multiclass scores as class
        targets instead of the one-hot encoding of `groundtruth_classes`.
      use_bfloat16: (optional) whether to cast the image, context features and
        instance masks to bfloat16.
      retain_original_image_additional_channels: (optional) whether to keep a
        resized uint8 copy of the original additional channels in the output
        dictionary.
      keypoint_type_weight: per-keypoint-type weights forwarded to the
        keypoint weight computation; may be None.

    Returns:
      A dictionary keyed by fields.InputDataFields holding the tensors that
      result from all the transformations.
    """
    features = tensor_dict.copy()
    input_fields = fields.InputDataFields

    labeled_key = input_fields.groundtruth_labeled_classes
    if labeled_key in features:
        # tf_example_decoder casts unrecognized labels to -1; drop those
        # before turning the labeled classes into a k-hot vector.
        recognized_classes = _remove_unrecognized_classes(
            features[labeled_key], unrecognized_label=-1)
        features[labeled_key] = _convert_labeled_classes_to_k_hot(
            recognized_classes, num_classes)

    if input_fields.multiclass_scores in features:
        features[input_fields.multiclass_scores] = (
            _multiclass_scores_or_one_hot_labels(
                features[input_fields.multiclass_scores],
                features[input_fields.groundtruth_boxes],
                features[input_fields.groundtruth_classes],
                num_classes))

    if input_fields.groundtruth_boxes in features:
        features = util_ops.filter_unrecognized_classes(
            util_ops.filter_groundtruth_with_nan_box_coordinates(features))

    if retain_original_image:
        resized_original = image_resizer_fn(
            features[input_fields.image], None)[0]
        features[input_fields.original_image] = tf.cast(
            resized_original, tf.uint8)

    if input_fields.image_additional_channels in features:
        extra_channels = features[input_fields.image_additional_channels]
        features[input_fields.image] = tf.concat(
            [features[input_fields.image], extra_channels], axis=2)
        if retain_original_image_additional_channels:
            resized_channels = image_resizer_fn(extra_channels, None)[0]
            features[input_fields.image_additional_channels] = tf.cast(
                resized_channels, tf.uint8)

    # Data augmentation.
    if data_augmentation_fn is not None:
        features = data_augmentation_fn(features)

    # Model preprocessing runs on a float image with a leading batch axis.
    image = features[input_fields.image]
    batched_image = tf.expand_dims(tf.cast(image, dtype=tf.float32), axis=0)
    preprocessed_resized_image, true_image_shape = model_preprocess_fn(
        batched_image)

    resized_shape = tf.shape(preprocessed_resized_image)
    resized_height = resized_shape[1]
    resized_width = resized_shape[2]

    # Window spanning the preprocessed image, expressed relative to the true
    # image shape reported by the preprocess function.
    height_scale = (tf.to_float(resized_height) /
                    tf.to_float(true_image_shape[0, 0]))
    width_scale = (tf.to_float(resized_width) /
                   tf.to_float(true_image_shape[0, 1]))
    im_box = tf.stack([0.0, 0.0, height_scale, width_scale])

    if input_fields.groundtruth_boxes in tensor_dict:
        realigned_boxlist = box_list_ops.change_coordinate_frame(
            box_list.BoxList(features[input_fields.groundtruth_boxes]),
            im_box)
        features[input_fields.groundtruth_boxes] = (
            assert_or_prune_invalid_boxes(realigned_boxlist.get()))

    if input_fields.groundtruth_keypoints in tensor_dict:
        realigned_keypoints = keypoint_ops.change_coordinate_frame(
            features[input_fields.groundtruth_keypoints], im_box)
        features[input_fields.groundtruth_keypoints] = realigned_keypoints
        visibility_key = input_fields.groundtruth_keypoint_visibilities
        weights_key = input_fields.groundtruth_keypoint_weights
        if visibility_key not in features:
            # Without visibility information, mark every keypoint visible.
            features[visibility_key] = tf.ones_like(
                realigned_keypoints[:, :, 0], dtype=tf.bool)
        features[weights_key] = (
            keypoint_ops.keypoint_weights_from_visibilities(
                features[visibility_key], keypoint_type_weight))

    if use_bfloat16:
        preprocessed_resized_image = tf.cast(preprocessed_resized_image,
                                             tf.bfloat16)
        context_key = input_fields.context_features
        if context_key in features:
            features[context_key] = tf.cast(features[context_key],
                                            tf.bfloat16)

    # Drop the batch axis that was added for preprocessing.
    features[input_fields.image] = tf.squeeze(
        preprocessed_resized_image, axis=0)
    features[input_fields.true_image_shape] = tf.squeeze(
        true_image_shape, axis=0)

    masks_key = input_fields.groundtruth_instance_masks
    if masks_key in features:
        _, resized_masks, _ = image_resizer_fn(image, features[masks_key])
        if use_bfloat16:
            resized_masks = tf.cast(resized_masks, tf.bfloat16)
        features[masks_key] = resized_masks

    classes_key = input_fields.groundtruth_classes
    zero_indexed_classes = features[classes_key] - _LABEL_OFFSET
    if use_multiclass_scores:
        class_targets = features[input_fields.multiclass_scores]
    else:
        class_targets = tf.one_hot(zero_indexed_classes, num_classes)
    features[classes_key] = class_targets
    features.pop(input_fields.multiclass_scores, None)

    confidences_key = input_fields.groundtruth_confidences
    if confidences_key in features:
        per_box_confidences = features[confidences_key]
        # Spread each box confidence over the encoding of its classes.
        features[confidences_key] = (
            tf.reshape(per_box_confidences, [-1, 1]) * features[classes_key])
    else:
        per_box_confidences = tf.ones_like(
            zero_indexed_classes, dtype=tf.float32)
        features[confidences_key] = features[classes_key]

    if merge_multiple_boxes:
        merged_boxes, merged_classes, merged_confidences, _ = (
            util_ops.merge_boxes_with_multiple_labels(
                features[input_fields.groundtruth_boxes],
                zero_indexed_classes, per_box_confidences, num_classes))
        features[input_fields.groundtruth_boxes] = merged_boxes
        features[classes_key] = tf.cast(merged_classes, tf.float32)
        features[confidences_key] = merged_confidences

    if input_fields.groundtruth_boxes in features:
        features[input_fields.num_groundtruth_boxes] = tf.shape(
            features[input_fields.groundtruth_boxes])[0]

    return features
Ejemplo n.º 12
0
def multiclass_non_max_suppression(boxes,
                                   scores,
                                   score_thresh,
                                   iou_thresh,
                                   max_size_per_class,
                                   max_total_size=0,
                                   clip_window=None,
                                   change_coordinate_frame=False,
                                   masks=None,
                                   boundaries=None,
                                   pad_to_max_output_size=False,
                                   additional_fields=None,
                                   scope=None):
  """Multi-class version of non maximum suppression.

  This op greedily selects a subset of detection bounding boxes, pruning
  away boxes that have high IOU (intersection over union) overlap (> thresh)
  with already selected boxes.  It operates independently for each class for
  which scores are provided (one column of `scores` per class), pruning boxes
  with score less than a provided threshold prior to applying NMS.

  Please note that this operation is performed on *all* classes, therefore any
  background classes should be removed prior to calling this function.

  Selected boxes are guaranteed to be sorted in decreasing order by score (but
  the sort is not guaranteed to be stable).

  Args:
    boxes: A [k, q, 4] float32 tensor containing k detections. `q` can be either
      number of classes or 1 depending on whether a separate box is predicted
      per class.
    scores: A [k, num_classes] float32 tensor containing the scores for each of
      the k detections. The second dimension must be statically known. The
      scores have to be non-negative when pad_to_max_output_size is True
      (padding entries are marked with a score of -1).
    score_thresh: scalar threshold for score (low scoring boxes are removed).
    iou_thresh: scalar threshold for IOU (new boxes that have high IOU overlap
      with previously selected boxes are removed). Must lie in [0, 1].
    max_size_per_class: maximum number of retained boxes per class.
    max_total_size: maximum number of boxes retained over all classes. The
      default (0) disables this cap, so all boxes retained after capping boxes
      per class are returned.
    clip_window: A float32 tensor of the form [y_min, x_min, y_max, x_max]
      representing the window the selected boxes are clipped to. Clipping is
      applied to the boxes that survive per-class NMS.
    change_coordinate_frame: Whether to normalize coordinates after clipping
      relative to clip_window (this can only be set to True if a clip_window
      is provided)
    masks: (optional) a [k, q, mask_height, mask_width] float32 tensor
      containing box masks. `q` can be either number of classes or 1 depending
      on whether a separate mask is predicted per class.
    boundaries: (optional) a [k, q, boundary_height, boundary_width] float32
      tensor containing box boundaries. `q` can be either number of classes or 1
      depending on whether a separate boundary is predicted per class.
    pad_to_max_output_size: If true, the NMS output of every class is padded
      to length `max_size_per_class` and padded entries are kept in the result
      (with score -1). Defaults to false, in which case only the valid boxes
      are returned.
    additional_fields: (optional) If not None, a dictionary that maps keys to
      tensors whose first dimensions are all of size `k`. After non-maximum
      suppression, all tensors corresponding to the selected boxes will be
      added to resulting BoxList.
    scope: name scope.

  Returns:
    A tuple of sorted_boxes and num_valid_nms_boxes. The sorted_boxes is a
      BoxList holds M boxes with a rank-1 scores field representing
      corresponding scores for each box with scores sorted in decreasing order
      and a rank-1 classes field representing a class label for each box. The
      num_valid_nms_boxes is a 0-D integer tensor representing the number of
      valid elements in `BoxList`, with the valid elements appearing first.

  Raises:
    ValueError: if iou_thresh is not in [0, 1]; if scores is not rank 2 or its
      second dimension is not statically defined; if boxes is not rank 3, its
      second dimension is neither 1 nor the number of classes, or its last
      dimension is not 4; or if change_coordinate_frame is True while
      clip_window is None.
  """
  if not 0 <= iou_thresh <= 1.0:
    raise ValueError('iou_thresh must be between 0 and 1')
  if scores.shape.ndims != 2:
    raise ValueError('scores field must be of rank 2')
  if scores.shape[1].value is None:
    raise ValueError('scores must have statically defined second '
                     'dimension')
  if boxes.shape.ndims != 3:
    raise ValueError('boxes must be of rank 3.')
  if not (boxes.shape[1].value == scores.shape[1].value or
          boxes.shape[1].value == 1):
    raise ValueError('second dimension of boxes must be either 1 or equal '
                     'to the second dimension of scores')
  if boxes.shape[2].value != 4:
    raise ValueError('last dimension of boxes must be of size 4.')
  if change_coordinate_frame and clip_window is None:
    # The two literals are concatenated, so the first one needs the trailing
    # space to keep "clip_window" and "must" apart in the final message.
    raise ValueError('if change_coordinate_frame is True, then a clip_window '
                     'must be specified.')

  with tf.name_scope(scope, 'MultiClassNonMaxSuppression'):
    num_scores = tf.shape(scores)[0]
    # Static class count as a plain Python int (validated as non-None above).
    num_classes = scores.shape[1].value

    selected_boxes_list = []
    num_valid_nms_boxes_cumulative = tf.constant(0)
    per_class_boxes_list = tf.unstack(boxes, axis=1)
    if masks is not None:
      per_class_masks_list = tf.unstack(masks, axis=1)
    if boundaries is not None:
      per_class_boundaries_list = tf.unstack(boundaries, axis=1)
    # When q == 1 the single set of boxes is shared by every class.
    boxes_ids = (range(num_classes) if len(per_class_boxes_list) > 1
                 else [0] * num_classes)
    for class_idx, boxes_idx in zip(range(num_classes), boxes_ids):
      per_class_boxes = per_class_boxes_list[boxes_idx]
      boxlist_and_class_scores = box_list.BoxList(per_class_boxes)
      # Column `class_idx` of scores, flattened to rank 1.
      class_scores = tf.reshape(
          tf.slice(scores, [0, class_idx], tf.stack([num_scores, 1])), [-1])

      boxlist_and_class_scores.add_field(fields.BoxListFields.scores,
                                         class_scores)
      if masks is not None:
        per_class_masks = per_class_masks_list[boxes_idx]
        boxlist_and_class_scores.add_field(fields.BoxListFields.masks,
                                           per_class_masks)
      if boundaries is not None:
        per_class_boundaries = per_class_boundaries_list[boxes_idx]
        boxlist_and_class_scores.add_field(fields.BoxListFields.boundaries,
                                           per_class_boundaries)
      if additional_fields is not None:
        for key, tensor in additional_fields.items():
          boxlist_and_class_scores.add_field(key, tensor)

      if pad_to_max_output_size:
        max_selection_size = max_size_per_class
        selected_indices, num_valid_nms_boxes = (
            tf.image.non_max_suppression_padded(
                boxlist_and_class_scores.get(),
                boxlist_and_class_scores.get_field(fields.BoxListFields.scores),
                max_selection_size,
                iou_threshold=iou_thresh,
                score_threshold=score_thresh,
                pad_to_max_output_size=True))
      else:
        max_selection_size = tf.minimum(max_size_per_class,
                                        boxlist_and_class_scores.num_boxes())
        selected_indices = tf.image.non_max_suppression(
            boxlist_and_class_scores.get(),
            boxlist_and_class_scores.get_field(fields.BoxListFields.scores),
            max_selection_size,
            iou_threshold=iou_thresh,
            score_threshold=score_thresh)
        num_valid_nms_boxes = tf.shape(selected_indices)[0]
        # Pad the indices with zeros up to max_selection_size so both branches
        # produce the same length; the padded slots are invalidated below.
        selected_indices = tf.concat(
            [selected_indices,
             tf.zeros(max_selection_size-num_valid_nms_boxes, tf.int32)], 0)
      nms_result = box_list_ops.gather(boxlist_and_class_scores,
                                       selected_indices)
      # Make the scores -1 for invalid boxes.
      valid_nms_boxes_indx = tf.less(
          tf.range(max_selection_size), num_valid_nms_boxes)
      nms_scores = nms_result.get_field(fields.BoxListFields.scores)
      nms_result.add_field(fields.BoxListFields.scores,
                           tf.where(valid_nms_boxes_indx,
                                    nms_scores, -1*tf.ones(max_selection_size)))
      num_valid_nms_boxes_cumulative += num_valid_nms_boxes

      # Label every selected box with the class it was selected for.
      nms_result.add_field(
          fields.BoxListFields.classes, (tf.zeros_like(
              nms_result.get_field(fields.BoxListFields.scores)) + class_idx))
      selected_boxes_list.append(nms_result)
    selected_boxes = box_list_ops.concatenate(selected_boxes_list)
    sorted_boxes = box_list_ops.sort_by_field(selected_boxes,
                                              fields.BoxListFields.scores)
    if clip_window is not None:
      # When pad_to_max_output_size is False, it prunes the boxes with zero
      # area.
      sorted_boxes = box_list_ops.clip_to_window(
          sorted_boxes,
          clip_window,
          filter_nonoverlapping=not pad_to_max_output_size)
      # Set the scores of boxes with zero area to -1 to keep the default
      # behaviour of pruning out zero area boxes.
      sorted_boxes_size = tf.shape(sorted_boxes.get())[0]
      non_zero_box_area = tf.cast(box_list_ops.area(sorted_boxes), tf.bool)
      sorted_boxes_scores = tf.where(
          non_zero_box_area,
          sorted_boxes.get_field(fields.BoxListFields.scores),
          -1*tf.ones(sorted_boxes_size))
      sorted_boxes.add_field(fields.BoxListFields.scores, sorted_boxes_scores)
      # Recount valid boxes: anything with a non-negative score is valid.
      num_valid_nms_boxes_cumulative = tf.reduce_sum(
          tf.cast(tf.greater_equal(sorted_boxes_scores, 0), tf.int32))
      # Re-sort so the newly invalidated (-1) boxes move to the end.
      sorted_boxes = box_list_ops.sort_by_field(sorted_boxes,
                                                fields.BoxListFields.scores)
      if change_coordinate_frame:
        sorted_boxes = box_list_ops.change_coordinate_frame(
            sorted_boxes, clip_window)

    if max_total_size:
      max_total_size = tf.minimum(max_total_size,
                                  sorted_boxes.num_boxes())
      sorted_boxes = box_list_ops.gather(sorted_boxes,
                                         tf.range(max_total_size))
      # The valid count can never exceed the number of boxes kept.
      num_valid_nms_boxes_cumulative = tf.where(
          max_total_size > num_valid_nms_boxes_cumulative,
          num_valid_nms_boxes_cumulative, max_total_size)
    # Select only the valid boxes if pad_to_max_output_size is False.
    if not pad_to_max_output_size:
      sorted_boxes = box_list_ops.gather(
          sorted_boxes, tf.range(num_valid_nms_boxes_cumulative))

    return sorted_boxes, num_valid_nms_boxes_cumulative