Beispiel #1
0
        def _sample_single_image(inputs):
            """Samples a box-classifier minibatch for one image.

            Args:
              inputs: a 6-tuple for a single image, in this order: proposal
                boxes, proposal scores, number of proposals, groundtruth
                boxes, groundtruth classes (with background) and groundtruth
                weights.

            Returns:
              A tuple (boxes, scores, num_sampled). `boxes` and `scores` are
              read from the sampled box list after it has been padded or
              clipped to `self._second_stage_batch_size` boxes; `num_sampled`
              is the minimum of the sampled box count and that batch size.
            """
            (proposal_boxes, proposal_scores, num_proposals, gt_boxes,
             gt_classes_with_background, gt_weights) = inputs
            scores_key = od_standard_fields.BoxListFields.scores

            proposals = box_list.BoxList(proposal_boxes)
            proposals.add_field(scores_key, proposal_scores)
            groundtruth = box_list.BoxList(gt_boxes)

            sampled = self._sample_box_classifier_minibatch_single_image(
                proposals, num_proposals, groundtruth,
                gt_classes_with_background, gt_weights)
            padded = box_list_ops.pad_or_clip_box_list(
                sampled, num_boxes=self._second_stage_batch_size)

            # Count is taken from the unpadded list and capped at the batch
            # size, so it never exceeds the number of rows in `padded`.
            num_sampled = tf.minimum(
                sampled.num_boxes(), self._second_stage_batch_size)
            boxes_out = padded.get()
            scores_out = padded.get_field(scores_key)
            return (boxes_out, scores_out, num_sampled)
Beispiel #2
0
 def graph_fn():
   """Builds a two-box list and runs it through pad_or_clip_box_list.

   The list carries a 'classes' and a 'scores' field and is padded or clipped
   with num_boxes=4.

   Returns:
     A tuple (boxes, classes, scores) read from the resulting box list.
   """
   corners = tf.constant(
       [[0.1, 0.1, 0.4, 0.4], [0.1, 0.1, 0.5, 0.5]], tf.float32)
   two_boxes = box_list.BoxList(corners)
   two_boxes.add_field('classes', tf.constant([0, 1]))
   two_boxes.add_field('scores', tf.constant([0.75, 0.2]))

   target_size = 4
   padded = box_list_ops.pad_or_clip_box_list(two_boxes, target_size)

   boxes_out = padded.get()
   classes_out = padded.get_field('classes')
   scores_out = padded.get_field('scores')
   return (boxes_out, classes_out, scores_out)
Beispiel #3
0
        def single_image_nms_fn(args):
            """Applies multiclass NMS to one image and pads the result.

            Args:
              args: a 5-element sequence (boxes, scores, masks, motions,
                num_valid_boxes) for a single image. Only the first
                `num_valid_boxes` rows of each tensor are passed on to NMS.

            Returns:
              A list [boxes, scores, classes, masks, motions, num_detections].
              The first five entries are read from the NMS output after
              `pad_or_clip_box_list(..., max_total_size)`; `num_detections`
              is `num_boxes()` of the NMS output before that call.
            """
            boxes, scores, masks, motions, num_valid = args

            def _keep_valid_rows(tensor, rank, out_shape):
                # Keep the first `num_valid` entries along axis 0 and the
                # full extent (-1) of every other axis, then reshape.
                extent = tf.stack([num_valid] + (rank - 1) * [-1])
                return tf.reshape(
                    tf.slice(tensor, rank * [0], extent), out_shape)

            boxes = _keep_valid_rows(boxes, 3, [-1, q, 4])
            scores = _keep_valid_rows(scores, 2, [-1, num_classes])
            masks = _keep_valid_rows(
                masks, 4,
                [-1, q, masks.shape[2].value, masks.shape[3].value])
            motions = _keep_valid_rows(
                motions, 3, [-1, q, motions.shape[2].value])

            nms_result = multiclass_non_max_suppression(
                boxes,
                scores,
                score_thresh,
                iou_thresh,
                max_size_per_class,
                max_total_size,
                masks=masks,
                motions=motions,
                clip_window=clip_window,
                change_coordinate_frame=change_coordinate_frame)
            padded = box_list_ops.pad_or_clip_box_list(
                nms_result, max_total_size)

            # The detection count comes from the unpadded NMS output.
            num_detections = nms_result.num_boxes()
            field_keys = (fields.BoxListFields.scores,
                          fields.BoxListFields.classes,
                          fields.BoxListFields.masks,
                          fields.BoxListFields.motions)
            outputs = [padded.get()]
            outputs.extend(padded.get_field(key) for key in field_keys)
            outputs.append(num_detections)
            return outputs
  def test_pad_box_list(self):
    """Checks pad_or_clip_box_list on a two-box list with num_boxes=4.

    The expected output keeps the two input rows and fills the remaining two
    rows of boxes, classes and scores with zeros.
    """
    corners = tf.constant(
        [[0.1, 0.1, 0.4, 0.4], [0.1, 0.1, 0.5, 0.5]], tf.float32)
    two_boxes = box_list.BoxList(corners)
    two_boxes.add_field('classes', tf.constant([0, 1]))
    two_boxes.add_field('scores', tf.constant([0.75, 0.2]))
    padded = box_list_ops.pad_or_clip_box_list(two_boxes, 4)

    zero_box = [0, 0, 0, 0]
    want_boxes = [[0.1, 0.1, 0.4, 0.4], [0.1, 0.1, 0.5, 0.5],
                  zero_box, zero_box]
    want_classes = [0, 1, 0, 0]
    want_scores = [0.75, 0.2, 0, 0]
    with self.test_session() as sess:
      fetches = [
          padded.get(),
          padded.get_field('classes'),
          padded.get_field('scores'),
      ]
      got_boxes, got_classes, got_scores = sess.run(fetches)

      self.assertAllClose(want_boxes, got_boxes)
      self.assertAllEqual(want_classes, got_classes)
      self.assertAllClose(want_scores, got_scores)
  def test_pad_box_list(self):
    """Verifies that a two-box list comes back with four rows.

    Runs pad_or_clip_box_list with num_boxes=4 and compares boxes, classes
    and scores against the two input rows followed by two all-zero rows.
    """
    input_corners = [[0.1, 0.1, 0.4, 0.4], [0.1, 0.1, 0.5, 0.5]]
    source = box_list.BoxList(tf.constant(input_corners, tf.float32))
    source.add_field('classes', tf.constant([0, 1]))
    source.add_field('scores', tf.constant([0.75, 0.2]))
    result = box_list_ops.pad_or_clip_box_list(source, 4)

    expected = {
        'boxes': [[0.1, 0.1, 0.4, 0.4], [0.1, 0.1, 0.5, 0.5],
                  [0, 0, 0, 0], [0, 0, 0, 0]],
        'classes': [0, 1, 0, 0],
        'scores': [0.75, 0.2, 0, 0],
    }
    with self.test_session() as sess:
      actual_boxes, actual_classes, actual_scores = sess.run(
          [result.get(), result.get_field('classes'),
           result.get_field('scores')])

      self.assertAllClose(expected['boxes'], actual_boxes)
      self.assertAllEqual(expected['classes'], actual_classes)
      self.assertAllClose(expected['scores'], actual_scores)
    def single_image_nms_fn(args):
      """Applies multiclass NMS to one image and pads the result.

      Args:
        args: a 4-element sequence (boxes, scores, masks, num_valid_boxes)
          for a single image. Only the first `num_valid_boxes` rows of each
          tensor are passed on to NMS.

      Returns:
        A list [boxes, scores, classes, masks, num_detections]. The first
        four entries are read from the NMS output after
        `pad_or_clip_box_list(..., max_total_size)`; `num_detections` is
        `num_boxes()` of the NMS output before that call.
      """
      boxes, scores, masks, num_valid = args

      def _keep_valid_rows(tensor, rank, out_shape):
        # Keep the first `num_valid` entries along axis 0 and the full
        # extent (-1) of every other axis, then reshape.
        extent = tf.stack([num_valid] + (rank - 1) * [-1])
        return tf.reshape(tf.slice(tensor, rank * [0], extent), out_shape)

      boxes = _keep_valid_rows(boxes, 3, [-1, q, 4])
      scores = _keep_valid_rows(scores, 2, [-1, num_classes])
      masks = _keep_valid_rows(
          masks, 4, [-1, q, masks.shape[2].value, masks.shape[3].value])

      nms_result = multiclass_non_max_suppression(
          boxes,
          scores,
          score_thresh,
          iou_thresh,
          max_size_per_class,
          max_total_size,
          masks=masks,
          clip_window=clip_window,
          change_coordinate_frame=change_coordinate_frame)
      padded = box_list_ops.pad_or_clip_box_list(nms_result, max_total_size)

      # The detection count comes from the unpadded NMS output.
      num_detections = nms_result.num_boxes()
      return [
          padded.get(),
          padded.get_field(fields.BoxListFields.scores),
          padded.get_field(fields.BoxListFields.classes),
          padded.get_field(fields.BoxListFields.masks),
          num_detections,
      ]
Beispiel #7
0
    def _single_image_nms_fn(args):
      """Runs multiclass NMS on one image and pads or clips the output.

      Args:
        args: A sequence of per-image tensors laid out as
          [boxes, scores, masks, clip_window, <additional fields...>,
          num_valid_boxes]:
          boxes - reshaped here to [-1, q, 4].
          scores - reshaped here to [-1, num_classes].
          masks - reshaped here to [-1, q, mask_height, mask_width], using
            the static sizes of its last two dimensions.
          clip_window - forwarded unchanged to NMS as `clip_window`.
          additional fields - zero or more tensors, paired in order with the
            keys of the enclosing `additional_fields`.
          num_valid_boxes - number of leading rows of boxes, scores, masks
            and every additional field that are kept. It is stacked with
            literal -1 sizes, so it is expected to be a scalar.

      Returns:
        A list [nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
        <one tensor per additional field>, num_detections]. Every entry but
        the last is read after `pad_or_clip_box_list(..., max_total_size)`;
        `num_detections` is `num_boxes()` of the NMS output before that
        call.
      """
      boxes, scores, masks, window = args[0], args[1], args[2], args[3]
      raw_extra_fields = dict(zip(additional_fields, args[4:-1]))
      num_valid = args[-1]

      def _keep_valid_rows(tensor, rank, out_shape):
        # Keep the first `num_valid` entries along axis 0 and the full
        # extent (-1) of every other axis, then reshape.
        extent = tf.stack([num_valid] + (rank - 1) * [-1])
        return tf.reshape(tf.slice(tensor, rank * [0], extent), out_shape)

      boxes = _keep_valid_rows(boxes, 3, [-1, q, 4])
      scores = _keep_valid_rows(scores, 2, [-1, num_classes])
      masks = _keep_valid_rows(
          masks, 4, [-1, q, masks.shape[2].value, masks.shape[3].value])

      # Additional fields have arbitrary rank; their trailing static
      # dimensions are reused as the reshape target.
      extra_fields = {}
      for key, tensor in raw_extra_fields.items():
        static_shape = tensor.get_shape()
        extra_fields[key] = _keep_valid_rows(
            tensor, len(static_shape),
            [-1] + [dim.value for dim in static_shape[1:]])

      nms_result = multiclass_non_max_suppression(
          boxes,
          scores,
          score_thresh,
          iou_thresh,
          max_size_per_class,
          max_total_size,
          clip_window=window,
          change_coordinate_frame=change_coordinate_frame,
          masks=masks,
          additional_fields=extra_fields)
      padded = box_list_ops.pad_or_clip_box_list(nms_result, max_total_size)

      # The detection count comes from the unpadded NMS output.
      num_detections = nms_result.num_boxes()
      outputs = [
          padded.get(),
          padded.get_field(fields.BoxListFields.scores),
          padded.get_field(fields.BoxListFields.classes),
          padded.get_field(fields.BoxListFields.masks),
      ]
      outputs += [padded.get_field(key) for key in extra_fields]
      outputs += [num_detections]
      return outputs
Beispiel #8
0
	  3. _remove_invalid_anchors_and_predictions
	    I: pruned_anchors_boxlist, keep_indices = box_list_ops.prune_outside_window()
	    II: _batch_gather_kept_indices(box_encodings)
	        _batch_gather_kept_indices(objectness_predictions_with_background)
	        'Extremely hard to get through!!!!!!'

	b: Classification (second stage)
	  1. _predict_second_stage
	  	I: flattened_proposal_feature_maps = self._postprocess_rpn()
      'Very complicated function!!!!!!'
        i: self._format_groundtruth_data(): 
        ii: decoded_boxes = self._box_coder.decode(rpn_box_encodings, box_list.BoxList(anchors))
                            --> faster_rcnn_box_coder.FasterRcnnBoxCoder._decode()
            objectness_scores = tf.nn.softmax(rpn_objectness_predictions_with_background)
        iii:proposal_boxlist = post_processing.multiclass_non_max_suppression()
        iv: padded_proposals = box_list_ops.pad_or_clip_box_list()
	  	II: self._compute_second_stage_input_feature_maps()
	  	III: box_classifier_features = self._feature_extractor.extract_box_classifier_features()
	  	IV: box_predictions = self._mask_rcnn_box_predictor.predict()
	  	V: absolute_proposal_boxes = ops.normalized_to_image_coordinates()

B: losses_dict = detection_model.loss
	a. _loss_rpn
	  1. target_assigner.batch_assign_targets()
      I: target_assigner.assign()
        i: match_quality_matrix = self._similarity_calc.compare(groundtruth_boxes,anchors)
                -->sim_calc.IouSimilarity()-->box_list_ops.iou()
        ii: match = self._matcher.match(match_quality_matrix, **params)
                -->argmax_matcher.ArgMaxMatcher._match()
        iii: reg_targets = self._create_regression_targets(anchors,groundtruth_boxes,match)
        iv: cls_targets = self._create_classification_targets(groundtruth_labels,match)
   def _single_image_nms_fn(args):
       """Runs multiclass NMS on one image and pads or clips the output.

       Args:
         args: A sequence of per-image tensors laid out as
           [boxes, scores, masks, <additional fields...>, num_valid_boxes]:
           boxes - reshaped here to [-1, q, 4].
           scores - reshaped here to [-1, num_classes].
           masks - reshaped here to [-1, q, mask_height, mask_width], using
             the static sizes of its last two dimensions.
           additional fields - zero or more tensors, paired in order with
             the keys of the enclosing `additional_fields`.
           num_valid_boxes - number of leading rows of boxes, scores, masks
             and every additional field that are kept. It is stacked with
             literal -1 sizes, so it is expected to be a scalar.

       Returns:
         A list [nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
         <one tensor per additional field>, num_detections]. Every entry
         but the last is read after
         `pad_or_clip_box_list(..., max_total_size)`; `num_detections` is
         `num_boxes()` of the NMS output before that call.
       """
       boxes, scores, masks = args[0], args[1], args[2]
       raw_extra_fields = dict(zip(additional_fields, args[3:-1]))
       num_valid = args[-1]

       def _keep_valid_rows(tensor, rank, out_shape):
           # Keep the first `num_valid` entries along axis 0 and the full
           # extent (-1) of every other axis, then reshape.
           extent = tf.stack([num_valid] + (rank - 1) * [-1])
           return tf.reshape(
               tf.slice(tensor, rank * [0], extent), out_shape)

       boxes = _keep_valid_rows(boxes, 3, [-1, q, 4])
       scores = _keep_valid_rows(scores, 2, [-1, num_classes])
       masks = _keep_valid_rows(
           masks, 4, [-1, q, masks.shape[2].value, masks.shape[3].value])

       # Additional fields have arbitrary rank; their trailing static
       # dimensions are reused as the reshape target.
       extra_fields = {}
       for key, tensor in raw_extra_fields.items():
           static_shape = tensor.get_shape()
           extra_fields[key] = _keep_valid_rows(
               tensor, len(static_shape),
               [-1] + [dim.value for dim in static_shape[1:]])

       nms_result = multiclass_non_max_suppression(
           boxes,
           scores,
           score_thresh,
           iou_thresh,
           max_size_per_class,
           max_total_size,
           clip_window=clip_window,
           change_coordinate_frame=change_coordinate_frame,
           masks=masks,
           additional_fields=extra_fields)
       padded = box_list_ops.pad_or_clip_box_list(
           nms_result, max_total_size)

       # The detection count comes from the unpadded NMS output.
       num_detections = nms_result.num_boxes()
       outputs = [
           padded.get(),
           padded.get_field(fields.BoxListFields.scores),
           padded.get_field(fields.BoxListFields.classes),
           padded.get_field(fields.BoxListFields.masks),
       ]
       outputs += [padded.get_field(key) for key in extra_fields]
       outputs += [num_detections]
       return outputs
def batch_multiclass_non_max_suppression(boxes,
                                         scores,
                                         score_thresh,
                                         iou_thresh,
                                         max_size_per_class,
                                         max_total_size=0,
                                         clip_window=None,
                                         change_coordinate_frame=False,
                                         num_valid_boxes=None,
                                         masks=None,
                                         scope=None):
  """Runs `multiclass_non_max_suppression` on every image of a batch.

  The batch is unstacked along axis 0, each image is processed on its own,
  and the per-image results are padded or clipped to `max_total_size` boxes
  so that they can be stacked back into batched tensors.

  Args:
    boxes: A [batch_size, num_anchors, q, 4] tensor of detections, where `q`
      must be either 1 or the number of classes.
    scores: A [batch_size, num_anchors, num_classes] tensor of scores.
    score_thresh: score threshold, forwarded to the per-image NMS.
    iou_thresh: IOU threshold, forwarded to the per-image NMS.
    max_size_per_class: per-class cap, forwarded to the per-image NMS.
    max_total_size: forwarded to the per-image NMS and also used as the
      number of boxes every image's output is padded or clipped to.
      NOTE(review): the default is 0 and is passed to
      `pad_or_clip_box_list` as is -- confirm what that yields.
    clip_window: forwarded to the per-image NMS; the same value is used for
      every image.
    change_coordinate_frame: forwarded to the per-image NMS.
    num_valid_boxes: (optional) a tensor that is unstacked into one entry per
      image. When given, only that many leading anchors of boxes, scores and
      masks are considered for the corresponding image.
    masks: (optional) a [batch_size, num_anchors, q, mask_height, mask_width]
      tensor of box masks.
    scope: tf name scope; defaults to 'BatchMultiClassNonMaxSuppression'.

  Returns:
    A dictionary with the following entries, each stacked over the batch:
    'detection_boxes': boxes of the padded per-image NMS output.
    'detection_scores': the scores field of the padded output.
    'detection_classes': the classes field of the padded output.
    'num_detections': a float tensor holding, per image, the number of boxes
      NMS returned before padding or clipping.
    'detection_masks': the masks field of the padded output; only present
      when `masks` is not None.

  Raises:
    ValueError: if the third dimension of `boxes` is neither 1 nor equal to
      the third dimension of `scores`. `multiclass_non_max_suppression` may
      raise for other invalid arguments.
  """
  q = boxes.shape[2].value
  num_classes = scores.shape[2].value
  if not (q == 1 or q == num_classes):
    raise ValueError('third dimension of boxes must be either 1 or equal '
                     'to the third dimension of scores')

  with tf.name_scope(scope, 'BatchMultiClassNonMaxSuppression'):
    boxes_per_image = tf.unstack(boxes)
    scores_per_image = tf.unstack(scores)
    batch_len = len(boxes_per_image)

    # Missing optional inputs become per-image `None` placeholders so that a
    # single zip can drive the loop below.
    if num_valid_boxes is None:
      valid_counts = batch_len * [None]
    else:
      valid_counts = tf.unstack(num_valid_boxes)
    if masks is None:
      masks_per_image = batch_len * [None]
    else:
      masks_per_image = tf.unstack(masks)

    out_boxes = []
    out_scores = []
    out_classes = []
    out_counts = []
    out_masks = []
    per_image_inputs = zip(boxes_per_image, scores_per_image,
                           masks_per_image, valid_counts)
    for image_boxes, image_scores, image_masks, valid_count in per_image_inputs:
      if valid_count is not None:
        # Drop everything past the first `valid_count` anchors.
        image_boxes = tf.reshape(
            tf.slice(image_boxes, 3 * [0],
                     tf.stack([valid_count, -1, -1])),
            [-1, q, 4])
        image_scores = tf.reshape(
            tf.slice(image_scores, [0, 0], tf.stack([valid_count, -1])),
            [-1, num_classes])
        if masks is not None:
          image_masks = tf.reshape(
              tf.slice(image_masks, 4 * [0],
                       tf.stack([valid_count, -1, -1, -1])),
              [-1, q, masks.shape[3].value, masks.shape[4].value])

      nms_result = multiclass_non_max_suppression(
          image_boxes,
          image_scores,
          score_thresh,
          iou_thresh,
          max_size_per_class,
          max_total_size,
          masks=image_masks,
          clip_window=clip_window,
          change_coordinate_frame=change_coordinate_frame)

      # The detection count is taken before padding or clipping.
      out_counts.append(tf.to_float(nms_result.num_boxes()))
      padded = box_list_ops.pad_or_clip_box_list(nms_result, max_total_size)
      out_boxes.append(padded.get())
      out_scores.append(padded.get_field(fields.BoxListFields.scores))
      out_classes.append(padded.get_field(fields.BoxListFields.classes))
      if masks is not None:
        out_masks.append(padded.get_field(fields.BoxListFields.masks))

    nms_dict = {
        'detection_boxes': tf.stack(out_boxes),
        'detection_scores': tf.stack(out_scores),
        'detection_classes': tf.stack(out_classes),
        'num_detections': tf.stack(out_counts)
    }
    if masks is not None:
      nms_dict['detection_masks'] = tf.stack(out_masks)
    return nms_dict
Beispiel #11
0
def batch_multiclass_non_max_suppression(boxes,
                                         scores,
                                         score_thresh,
                                         iou_thresh,
                                         max_size_per_class,
                                         max_total_size=0,
                                         clip_window=None,
                                         change_coordinate_frame=False,
                                         num_valid_boxes=None,
                                         masks=None,
                                         scope=None):
    """Applies `multiclass_non_max_suppression` to every image of a batch.

    The batch is unstacked along its first axis, each image is (optionally)
    trimmed to its number of valid boxes, run through
    `multiclass_non_max_suppression`, padded or clipped to `max_total_size`
    entries, and the per-image results are stacked back into batch tensors.

    Args:
      boxes: A [batch_size, num_anchors, q, 4] float32 tensor containing
        detections. `q` must be 1 (the same boxes are shared by all classes)
        or equal to the number of classes (class-specific boxes).
      scores: A [batch_size, num_anchors, num_classes] float32 tensor with the
        scores of each of the `num_anchors` detections.
      score_thresh: scalar score threshold, forwarded to
        `multiclass_non_max_suppression`.
      iou_thresh: scalar IOU threshold, forwarded to
        `multiclass_non_max_suppression`.
      max_size_per_class: maximum number of boxes retained per class.
      max_total_size: maximum number of boxes retained over all classes. It is
        forwarded to `multiclass_non_max_suppression` and is also the length
        every per-image result is padded or clipped to.
        NOTE(review): the default of 0 is therefore used as the pad/clip
        length as well -- confirm callers always pass an explicit value.
      clip_window: A float32 tensor of the form [y_min, x_min, y_max, x_max]
        giving the window boxes are clipped to before non-max suppression.
        The same window is used for every image.
      change_coordinate_frame: Whether to normalize coordinates relative to
        `clip_window` after clipping (only meaningful when a `clip_window` is
        given).
      num_valid_boxes: (optional) an int32 tensor of shape [batch_size] with
        the number of leading boxes to consider for each image; the remaining
        rows are dropped before suppression, which allows zero padding to be
        ignored.
      masks: (optional) a [batch_size, num_anchors, q, mask_height,
        mask_width] float32 tensor containing box masks. `q` is either the
        number of classes or 1, depending on whether a separate mask is
        predicted per class.
      scope: tf name scope.

    Returns:
      A dictionary with the following entries:
      'detection_boxes': A [batch_size, max_detections, 4] float32 tensor
        with the non-max suppressed boxes.
      'detection_scores': A [batch_size, max_detections] float32 tensor with
        the scores of those boxes.
      'detection_classes': A [batch_size, max_detections] float32 tensor with
        the class of each box.
      'num_detections': A [batch_size] float32 tensor with the number of
        valid detections per image. Only the first num_detections[i] entries
        of the other outputs are valid for image i; the rest is padding.
      'detection_masks': (only present when `masks` is given) a
        [batch_size, max_detections, mask_height, mask_width] float32 tensor
        with the mask of each selected box.

    Raises:
      ValueError: if the third dimension of `boxes` is neither 1 nor equal to
        the third dimension of `scores`. `multiclass_non_max_suppression` may
        raise further errors for invalid thresholds or inputs.
    """
    q = boxes.shape[2].value
    num_classes = scores.shape[2].value
    if q not in (1, num_classes):
        raise ValueError('third dimension of boxes must be either 1 or equal '
                         'to the third dimension of scores')

    with tf.name_scope(scope, 'BatchMultiClassNonMaxSuppression'):
        boxes_per_image = tf.unstack(boxes)
        scores_per_image = tf.unstack(scores)
        batch_size = len(boxes_per_image)
        # Missing optional inputs are represented by one `None` per image so
        # that all four sequences can be zipped together below.
        valid_counts = (tf.unstack(num_valid_boxes)
                        if num_valid_boxes is not None
                        else batch_size * [None])
        masks_per_image = (tf.unstack(masks)
                           if masks is not None
                           else batch_size * [None])

        def _valid_prefix(tensor, valid_count, rank):
            # Keeps the first `valid_count` entries along axis 0 and the full
            # extent (-1) of the remaining `rank - 1` axes.
            extent = tf.stack([valid_count] + (rank - 1) * [-1])
            return tf.slice(tensor, rank * [0], extent)

        batch_boxes = []
        batch_scores = []
        batch_classes = []
        batch_num_detections = []
        batch_masks = []
        per_image_inputs = zip(boxes_per_image, scores_per_image,
                               masks_per_image, valid_counts)
        for image_boxes, image_scores, image_masks, valid_count in (
                per_image_inputs):
            if valid_count is not None:
                # Drop padded rows; the reshape restores the static trailing
                # dimensions after the dynamically sized slice.
                image_boxes = tf.reshape(
                    _valid_prefix(image_boxes, valid_count, 3), [-1, q, 4])
                image_scores = tf.reshape(
                    _valid_prefix(image_scores, valid_count, 2),
                    [-1, num_classes])
                if masks is not None:
                    image_masks = tf.reshape(
                        _valid_prefix(image_masks, valid_count, 4),
                        [-1, q, masks.shape[3].value, masks.shape[4].value])

            nmsed = multiclass_non_max_suppression(
                image_boxes,
                image_scores,
                score_thresh,
                iou_thresh,
                max_size_per_class,
                max_total_size,
                masks=image_masks,
                clip_window=clip_window,
                change_coordinate_frame=change_coordinate_frame)
            # Record the true detection count before padding hides it.
            batch_num_detections.append(tf.to_float(nmsed.num_boxes()))
            padded = box_list_ops.pad_or_clip_box_list(nmsed, max_total_size)
            batch_boxes.append(padded.get())
            batch_scores.append(
                padded.get_field(fields.BoxListFields.scores))
            batch_classes.append(
                padded.get_field(fields.BoxListFields.classes))
            if masks is not None:
                batch_masks.append(
                    padded.get_field(fields.BoxListFields.masks))

        detections = {
            'detection_boxes': tf.stack(batch_boxes),
            'detection_scores': tf.stack(batch_scores),
            'detection_classes': tf.stack(batch_classes),
            'num_detections': tf.stack(batch_num_detections)
        }
        if masks is not None:
            detections['detection_masks'] = tf.stack(batch_masks)
        return detections
        def _single_image_nms_fn(args):
            """Runs multiclass NMS on one image and returns its outputs as a list.

            Besides `args`, this closure reads `ordered_additional_fields`,
            `use_static_shapes`, `q`, `num_classes`, `score_thresh`,
            `iou_thresh`, `max_size_per_class`, `max_total_size` and
            `change_coordinate_frame` from the enclosing scope.

            Args:
              args: A sequence of per-image tensors, read positionally:
                args[0] - per_image_boxes: A [num_anchors, q, 4] float32
                  tensor containing detections. If `q` is 1 the same boxes
                  are used for all classes; if `q` equals the number of
                  classes, class-specific boxes are used.
                args[1] - per_image_scores: A [num_anchors, num_classes]
                  float32 tensor with the score of each detection.
                args[2] - per_image_masks: A [num_anchors, q, mask_height,
                  mask_width] float32 tensor containing box masks.
                args[3] - per_image_clip_window: A 1D float32 tensor of the
                  form [ymin, xmin, ymax, xmax] giving the window boxes are
                  clipped to.
                args[4:-1] - additional fields: zero or more tensors of size
                  [num_anchors, ...], paired in order with the keys of
                  `ordered_additional_fields`.
                args[-1] - per_image_num_valid_boxes: int32 count of leading
                  boxes to consider for this image (it is used as a single
                  element of a slice size); allows padding to be ignored.

            Returns:
              A list `[nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
              *nmsed_additional_fields, num_detections]` where
              nmsed_boxes: A [max_detections, 4] float32 tensor with the
                non-max suppressed boxes.
              nmsed_scores: A [max_detections] float32 tensor with the scores
                of those boxes.
              nmsed_classes: A [max_detections] float32 tensor with the class
                of each box.
              nmsed_masks: the masks field of the suppressed box list.
              nmsed_additional_fields: one tensor per additional field,
                ordered by sorted field key.
              num_detections: the valid-detection count returned by
                `multiclass_non_max_suppression`; only that many leading
                entries of the other outputs are valid, the rest is padding.
            """
            per_image_boxes, per_image_scores = args[0], args[1]
            per_image_masks, per_image_clip_window = args[2], args[3]
            # The additional-field tensors must arrive in the same order in
            # which `ordered_additional_fields` yields its keys.
            per_image_additional_fields = dict(
                zip(ordered_additional_fields, args[4:-1]))
            per_image_num_valid_boxes = args[-1]

            def _valid_prefix(tensor, rank):
                # Keeps the first `per_image_num_valid_boxes` entries along
                # axis 0 and the full extent (-1) of the remaining axes.
                extent = tf.stack(
                    [per_image_num_valid_boxes] + (rank - 1) * [-1])
                return tf.slice(tensor, rank * [0], extent)

            if use_static_shapes:
                # Shapes must stay fixed, so instead of slicing away the
                # padded rows their scores are overwritten with the smallest
                # float32 value.
                scores_shape = tf.shape(per_image_scores)
                is_valid_row = tf.less(
                    tf.range(scores_shape[0]), per_image_num_valid_boxes)
                lowest_scores = tf.fill(scores_shape,
                                        np.finfo('float32').min)
                per_image_scores = tf.where(
                    is_valid_row, per_image_scores, lowest_scores)
            else:
                # Dynamic shapes: physically drop the padded rows from every
                # per-anchor tensor, then restore the static trailing dims.
                per_image_boxes = tf.reshape(
                    _valid_prefix(per_image_boxes, 3), [-1, q, 4])
                per_image_scores = tf.reshape(
                    _valid_prefix(per_image_scores, 2), [-1, num_classes])
                per_image_masks = tf.reshape(
                    _valid_prefix(per_image_masks, 4),
                    [
                        -1, q, per_image_masks.shape[2].value,
                        per_image_masks.shape[3].value
                    ])
                # Only values of existing keys are replaced, so updating the
                # dict while iterating over it is safe.
                for key, tensor in per_image_additional_fields.items():
                    static_shape = tensor.get_shape()
                    rank = len(static_shape)
                    valid_part = _valid_prefix(tensor, rank)
                    per_image_additional_fields[key] = tf.reshape(
                        valid_part,
                        [-1] + [dim.value for dim in static_shape[1:]])

            nmsed_boxlist, num_detections = multiclass_non_max_suppression(
                per_image_boxes,
                per_image_scores,
                score_thresh,
                iou_thresh,
                max_size_per_class,
                max_total_size,
                clip_window=per_image_clip_window,
                change_coordinate_frame=change_coordinate_frame,
                masks=per_image_masks,
                pad_to_max_output_size=use_static_shapes,
                additional_fields=per_image_additional_fields)

            # With static shapes the padding is requested from the NMS call
            # itself (`pad_to_max_output_size`); otherwise pad or clip here.
            if not use_static_shapes:
                nmsed_boxlist = box_list_ops.pad_or_clip_box_list(
                    nmsed_boxlist, max_total_size)

            core_outputs = [
                nmsed_boxlist.get(),
                nmsed_boxlist.get_field(fields.BoxListFields.scores),
                nmsed_boxlist.get_field(fields.BoxListFields.classes),
                nmsed_boxlist.get_field(fields.BoxListFields.masks),
            ]
            # Sorted keys keep the additional outputs in the same order
            # across different execution runs.
            nmsed_additional_fields = [
                nmsed_boxlist.get_field(key)
                for key in sorted(per_image_additional_fields)
            ]
            return core_outputs + nmsed_additional_fields + [num_detections]