コード例 #1
0
ファイル: postprocess.py プロジェクト: qing0991/tpu
    def __call__(self, box_outputs, class_outputs, anchor_boxes, image_shape):
        """Turns per-level network outputs into final NMS-ed detections.

        Args:
          box_outputs: per-level box regression outputs indexed by level; each
            is reshaped here to [batch_size, -1, 4].
          class_outputs: per-level class logits indexed by level.
          anchor_boxes: per-level anchors indexed by level; each is reshaped
            here to [batch_size, -1, 4].
          image_shape: image size forwarded to `box_utils.clip_boxes`.

        Returns:
          A tuple (nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections)
          from `self._generate_detections`, with class ids shifted up by one.
        """
        per_level_boxes = []
        per_level_scores = []
        for level in range(self._min_level, self._max_level + 1):
            logits = class_outputs[level]
            num_images = tf.shape(logits)[0]

            # Score activation; this also removes the implicit background
            # class.
            per_level_scores.append(
                _apply_score_activation(logits, self._num_classes,
                                        self._score_activation))

            # Anchors are shared by every image in the batch, and a one-stage
            # detector only does class-agnostic regression: one box per anchor.
            level_anchors = tf.reshape(anchor_boxes[level],
                                       [num_images, -1, 4])
            level_deltas = tf.reshape(box_outputs[level], [num_images, -1, 4])
            level_boxes = box_utils.decode_boxes(level_deltas, level_anchors)
            per_level_boxes.append(
                box_utils.clip_boxes(level_boxes, image_shape))

        # Insert a singleton axis at position 2 (presumably the class axis
        # expected by `_generate_detections`).
        all_boxes = tf.expand_dims(tf.concat(per_level_boxes, axis=1), axis=2)
        all_scores = tf.concat(per_level_scores, axis=1)

        nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
            self._generate_detections(all_boxes, all_scores))
        # Offset by one because the background class has index 0.
        return nmsed_boxes, nmsed_scores, nmsed_classes + 1, valid_detections
コード例 #2
0
    def get_proposals(self, rpn_boxes, rpn_labels, anchors, image_info):
        """Decodes RPN outputs into region proposals.

        Args:
          rpn_boxes: RPN box regression outputs.
          rpn_labels: RPN classification logits.
          anchors: anchor boxes matching `rpn_boxes`.
          image_info: dict with "input_size" and "valid_size" tensors; column
            0 is read as the height and column 1 as the width.

        Returns:
          A (rois, rois_scores) pair from `self.proposal_layer`, with `rois`
          converted back to absolute coordinates.
        """
        def _side(key, column):
            # One column of image_info[key], with a trailing axis, as float32.
            return tf.cast(image_info[key][:, column:column + 1, None],
                           tf.float32)

        with tf.name_scope("proposals_layer"):
            height = _side("input_size", 0)
            width = _side("input_size", 1)
            valid_height = _side("valid_size", 0)
            valid_width = _side("valid_size", 1)

            rpn_boxes, rpn_labels, anchors = (
                tf.cast(t, tf.float32) for t in (rpn_boxes, rpn_labels, anchors))

            if self.rpn_head.use_sigmoid:
                rpn_scores = tf.nn.sigmoid(rpn_labels)
            else:
                # Softmax over classes, then drop column 0 (presumably the
                # background class).
                rpn_scores = tf.nn.softmax(rpn_labels, -1)[:, :, 1:]

            decoded = self.rpn_head.bbox_decoder(anchors, rpn_boxes)
            clipped = box_utils.clip_boxes(decoded, valid_height, valid_width)
            kept_boxes, kept_scores = box_utils.filter_boxes(
                clipped, rpn_scores, self.proposal_cfg.min_size,
                valid_height, valid_width)
            normalized = box_utils.to_normalized_coordinates(
                kept_boxes, height, width)

            rois, rois_scores = self.proposal_layer(normalized, kept_scores)
            absolute_rois = box_utils.to_absolute_coordinates(
                rois, height, width)
            return absolute_rois, rois_scores
コード例 #3
0
ファイル: postprocess_ops.py プロジェクト: vishalbelsare/tpu
    def __call__(self, box_outputs, class_outputs, anchor_boxes, image_shape):
        """Decodes multi-level outputs and optionally applies NMS.

        Args:
          box_outputs: per-level box outputs indexed by level, each with static
            shape [?, feature_h, feature_w, anchors_per_location * 4].
          class_outputs: per-level class logits indexed by level, whose last
            dimension is anchors_per_location * num_classes.
          anchor_boxes: per-level anchors indexed by level.
          image_shape: image size forwarded to `box_utils.clip_boxes`.

        Returns:
          If `self._apply_nms` is False, a dict with 'raw_boxes',
          'raw_encoded_boxes' and 'raw_scores'. Otherwise a dict with
          'num_detections', 'detection_boxes', 'detection_classes' (offset by
          one for the background class) and 'detection_scores'.
        """
        decoded_per_level, encoded_per_level, scores_per_level = [], [], []
        for level in range(self._min_level, self._max_level + 1):
            level_boxes = box_outputs[level]
            level_logits = class_outputs[level]

            _, height, width, box_channels = level_boxes.get_shape().as_list()
            anchors_per_location = box_channels // 4
            classes_per_anchor = (
                level_logits.get_shape().as_list()[-1] // anchors_per_location)
            anchors_in_level = height * width * anchors_per_location

            level_scores = tf.reshape(
                level_logits, [-1, anchors_in_level, classes_per_anchor])
            if self._apply_sigmoid:
                level_scores = tf.sigmoid(level_scores)
            # Drop the implicit background class (column 0).
            scores_per_level.append(level_scores[:, :, 1:])

            # Anchors are shared by the batch; regression is class agnostic.
            level_anchors = tf.reshape(anchor_boxes[level],
                                       [-1, anchors_in_level, 4])
            level_encoded = tf.reshape(level_boxes, [-1, anchors_in_level, 4])
            encoded_per_level.append(level_encoded)
            level_decoded = box_utils.decode_boxes(level_encoded, level_anchors)
            decoded_per_level.append(
                box_utils.clip_boxes(level_decoded, image_shape))

        boxes = tf.expand_dims(tf.concat(decoded_per_level, axis=1), axis=2)
        encoded_boxes = tf.concat(encoded_per_level, axis=1)
        scores = tf.concat(scores_per_level, axis=1)

        if not self._apply_nms:
            return {
                'raw_boxes': boxes,
                'raw_encoded_boxes': encoded_boxes,
                'raw_scores': scores,
            }

        nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
            self._generate_detections(boxes, scores))
        return {
            'num_detections': valid_detections,
            'detection_boxes': nmsed_boxes,
            # Offset by one because the background class has index 0.
            'detection_classes': nmsed_classes + 1,
            'detection_scores': nmsed_scores,
        }
コード例 #4
0
    def _box_outputs_to_rois(self, box_outputs, rois, correct_class,
                             image_info, regression_weights):
        """Converts `box_outputs` into the rois for the next cascade stage.

        Args:
          box_outputs: `tensor` with the boxes predicted by the most recent
            frcnn head. The predictions are relative to `rois`, so they are
            decoded here into y/x min/max coordinates usable as rois by the
            following stage. Unless `self._class_agnostic_bbox_pred` is True,
            the last dimension holds 4 values per class.
          rois: `tensor`, the rois used as input for the frcnn head.
          correct_class: `tensor` of shape [batch_size, num_rois] with the class
            whose box prediction should be kept. Needed because predictions are
            made for all classes when `self._class_agnostic_bbox_pred` is not
            True; ignored otherwise.
          image_info: `list`, the height and width of the input image; forwarded
            to `box_utils.clip_boxes`.
          regression_weights: `list`, weights used for l1 loss in bounding box
            regression; forwarded to `box_utils.decode_boxes`.

        Returns:
          new_rois: float32 rois to be used for the next frcnn layer in the
            cascade.
        """
        if self._class_agnostic_bbox_pred:
            new_rois = box_outputs
        else:
            # Only num_rois and the per-class width need to be static. The batch
            # dimension is inferred (-1), so an unknown batch size also works.
            _, num_rois, num_class_specific_boxes = (
                box_outputs.get_shape().as_list())
            num_classes = num_class_specific_boxes // 4
            box_outputs = tf.reshape(box_outputs,
                                     [-1, num_rois, num_classes, 4])

            # [batch_size, num_rois] -> one-hot mask [batch_size, num_rois,
            # num_classes, 1]. This broadcasts over the 4 box coordinates, so no
            # explicit tile is needed.
            correct_class_mask = tf.expand_dims(
                tf.one_hot(correct_class, num_classes, dtype=box_outputs.dtype),
                axis=-1)
            new_rois = tf.reduce_sum(box_outputs * correct_class_mask, axis=2)
        new_rois = tf.cast(new_rois, tf.float32)

        # new_rois still holds relative center coords and log-scale offsets.
        # Decode them to get the y/x min/max values needed by roi operations.
        new_rois = box_utils.decode_boxes(new_rois,
                                          rois,
                                          weights=regression_weights)
        new_rois = box_utils.clip_boxes(new_rois, image_info)
        return new_rois
コード例 #5
0
def resize_and_crop_boxes(boxes, image_scale, output_size, offset):
    """Maps boxes into the resized-and-cropped image space and clips them.

    Args:
      boxes: `Tensor` of shape [N, 4] representing ground truth boxes.
      image_scale: 2-element float `Tensor` with the scale factors applied to
        the [height, width] of the input image.
      output_size: 2-element `Tensor` or `int` with the [height, width] of the
        target output image.
      offset: 2-element `Tensor` with the top-left corner [y0, x0] used to
        crop the scaled boxes.

    Returns:
      `Tensor` of shape [N, 4] with the scaled, shifted and clipped boxes.
    """
    # Repeat each [y, x] pair so it lines up with the four box coordinates.
    scale_per_coord = tf.tile(tf.expand_dims(image_scale, axis=0), [1, 2])
    offset_per_coord = tf.tile(tf.expand_dims(offset, axis=0), [1, 2])

    boxes *= scale_per_coord
    boxes -= offset_per_coord
    return box_utils.clip_boxes(boxes, output_size)
コード例 #6
0
    def get_boxes(self, outputs, image_info):
        """Decodes RCNN head outputs into final detections.

        Args:
          outputs: dict with "rcnn_boxes", "rcnn_labels" and "rois" tensors.
          image_info: dict with "input_size" and "valid_size" tensors; column
            0 is read as the height and column 1 as the width.

        Returns:
          Whatever `self.nms` returns for the normalized (and [0, 1]-clipped)
          boxes and the per-class scores.
        """
        with tf.name_scope("get_boxes"):
            predicted_boxes = tf.cast(outputs["rcnn_boxes"], tf.float32)
            predicted_labels = tf.cast(outputs["rcnn_labels"], tf.float32)
            rois = tf.cast(outputs["rois"], tf.float32)

            # Cast the image sizes so they combine cleanly with float32 boxes.
            input_size = tf.cast(image_info["input_size"], tf.float32)
            valid_size = tf.cast(image_info["valid_size"], tf.float32)
            height = input_size[:, 0:1, None]
            width = input_size[:, 1:2, None]
            valid_height = valid_size[:, 0:1, None]
            valid_width = valid_size[:, 1:2, None]

            predicted_boxes = self.rcnn_head.bbox_head.bbox_decoder(
                rois, predicted_boxes)
            # Clip the decoded predictions to the valid image region.
            predicted_boxes = box_utils.clip_boxes(
                predicted_boxes, valid_height, valid_width)
            predicted_boxes = box_utils.to_normalized_coordinates(
                predicted_boxes, height, width)
            predicted_boxes = tf.clip_by_value(predicted_boxes, 0, 1)

            if self.rcnn_head.bbox_head.use_sigmoid:
                predicted_scores = tf.nn.sigmoid(predicted_labels)
            else:
                # Softmax over classes, then drop column 0 (presumably the
                # background class).
                predicted_scores = tf.nn.softmax(predicted_labels, axis=-1)
                predicted_scores = predicted_scores[:, :, 1:]

            return self.nms(predicted_boxes, predicted_scores)
コード例 #7
0
ファイル: roi_ops.py プロジェクト: yuezha01/tpu
def multilevel_propose_rois(rpn_boxes,
                            rpn_scores,
                            anchor_boxes,
                            image_shape,
                            rpn_pre_nms_top_k=2000,
                            rpn_post_nms_top_k=1000,
                            rpn_nms_threshold=0.7,
                            rpn_score_threshold=0.0,
                            rpn_min_size_threshold=0.0,
                            decode_boxes=True,
                            clip_boxes=True,
                            use_batched_nms=False,
                            apply_sigmoid_to_score=True):
    """Proposes RoIs given a group of candidates from different FPN levels.

    The following describes the steps:
      1. For each individual level:
        a. Apply sigmoid transform if specified.
        b. Decode boxes if specified.
        c. Clip boxes if specified.
        d. Filter small boxes and those that fall outside the image if
           specified.
        e. Apply pre-NMS filtering including pre-NMS top k and score
           thresholding.
        f. Apply NMS; if NMS is disabled, keep the top k boxes instead.
      2. Aggregate post-NMS boxes from each level.
      3. Apply an overall top k to generate the final selected RoIs.

    Args:
      rpn_boxes: a dict with keys representing FPN levels and values
        representing box tensors of shape
        [batch_size, feature_h, feature_w, num_anchors * 4].
      rpn_scores: a dict with keys representing FPN levels and values
        representing logit tensors of shape
        [batch_size, feature_h, feature_w, num_anchors].
      anchor_boxes: a dict with keys representing FPN levels and values
        representing anchor box tensors of shape
        [batch_size, feature_h, feature_w, num_anchors * 4].
      image_shape: a tensor of shape [batch_size, 2] where the last dimension
        is [height, width] of the scaled image. Only used when `clip_boxes` is
        True or `rpn_min_size_threshold` > 0.
      rpn_pre_nms_top_k: an integer of top scoring RPN proposals *per level* to
        keep before applying NMS. Default: 2000.
      rpn_post_nms_top_k: an integer of top scoring RPN proposals *in total* to
        keep after applying NMS. Default: 1000.
      rpn_nms_threshold: a float between 0 and 1 representing the IoU threshold
        used for NMS. If 0.0, no NMS is applied. Default: 0.7.
      rpn_score_threshold: a float between 0 and 1 representing the minimal box
        score to keep before applying NMS. This is often used as a
        pre-filtering step for better performance. If 0, no filtering is
        applied. Default: 0.
      rpn_min_size_threshold: a float representing the minimal box size in each
        side (w.r.t. the scaled image) to keep before applying NMS. This is
        often used as a pre-filtering step for better performance. If 0, no
        filtering is applied. Default: 0.
      decode_boxes: a boolean indicating whether `rpn_boxes` needs to be
        decoded using `anchor_boxes`. If False, use `rpn_boxes` directly and
        ignore `anchor_boxes`. Default: True.
      clip_boxes: a boolean indicating whether boxes are first clipped to the
        scaled image size before applying NMS. If False, no clipping is applied
        and `image_shape` is ignored. Default: True.
      use_batched_nms: a boolean indicating whether NMS is applied in batch
        using `tf.image.combined_non_max_suppression`. Currently only available
        in CPU/GPU. Default: False.
      apply_sigmoid_to_score: a boolean indicating whether to apply sigmoid to
        `rpn_scores` before applying NMS. Default: True.

    Returns:
      selected_rois: a tensor of shape [batch_size, K, 4], with
        K = min(total per-level RoIs, rpn_post_nms_top_k), representing the box
        coordinates of the selected proposals w.r.t. the scaled image.
      selected_roi_scores: a tensor of shape [batch_size, K], representing the
        scores of the selected proposals.
    """
    with tf.name_scope('multilevel_propose_rois'):
        rois = []
        roi_scores = []
        # [batch_size, 2] -> [batch_size, 1, 2] so the size broadcasts over
        # boxes. This is done once, outside the level loop, so the rank does not
        # grow with each level, and only when image_shape is actually used.
        if clip_boxes or rpn_min_size_threshold > 0.0:
            image_shape = tf.expand_dims(image_shape, axis=1)
        for level in sorted(rpn_scores.keys()):
            with tf.name_scope('level_%d' % level):
                _, feature_h, feature_w, num_anchors_per_location = (
                    rpn_scores[level].get_shape().as_list())

                num_boxes = feature_h * feature_w * num_anchors_per_location
                this_level_scores = tf.reshape(rpn_scores[level],
                                               [-1, num_boxes])
                this_level_boxes = tf.reshape(rpn_boxes[level],
                                              [-1, num_boxes, 4])
                this_level_anchors = tf.cast(
                    tf.reshape(anchor_boxes[level], [-1, num_boxes, 4]),
                    dtype=this_level_scores.dtype)

                if apply_sigmoid_to_score:
                    this_level_scores = tf.sigmoid(this_level_scores)

                if decode_boxes:
                    this_level_boxes = box_utils.decode_boxes(
                        this_level_boxes, this_level_anchors)
                if clip_boxes:
                    this_level_boxes = box_utils.clip_boxes(
                        this_level_boxes, image_shape)

                if rpn_min_size_threshold > 0.0:
                    this_level_boxes, this_level_scores = box_utils.filter_boxes(
                        this_level_boxes, this_level_scores, image_shape,
                        rpn_min_size_threshold)

                if rpn_nms_threshold > 0.0:
                    this_level_pre_nms_top_k = min(num_boxes,
                                                   rpn_pre_nms_top_k)
                    if use_batched_nms:
                        this_level_rois, this_level_roi_scores, _, _ = (
                            tf.image.combined_non_max_suppression(
                                tf.expand_dims(this_level_boxes, axis=2),
                                tf.expand_dims(this_level_scores, axis=-1),
                                max_output_size_per_class=
                                this_level_pre_nms_top_k,
                                max_total_size=rpn_post_nms_top_k,
                                iou_threshold=rpn_nms_threshold,
                                score_threshold=rpn_score_threshold,
                                pad_per_class=False,
                                clip_boxes=False))
                    else:
                        if rpn_score_threshold > 0.0:
                            this_level_boxes, this_level_scores = (
                                box_utils.filter_boxes_by_scores(
                                    this_level_boxes, this_level_scores,
                                    rpn_score_threshold))
                        this_level_boxes, this_level_scores = box_utils.top_k_boxes(
                            this_level_boxes,
                            this_level_scores,
                            k=this_level_pre_nms_top_k)
                        this_level_roi_scores, this_level_rois = (
                            nms.sorted_non_max_suppression_padded(
                                this_level_scores,
                                this_level_boxes,
                                max_output_size=rpn_post_nms_top_k,
                                iou_threshold=rpn_nms_threshold))
                else:
                    # No NMS: keep this level's top boxes directly. k is capped
                    # at the level's box count because top-k cannot return more
                    # elements than exist.
                    this_level_rois, this_level_roi_scores = box_utils.top_k_boxes(
                        this_level_boxes,
                        this_level_scores,
                        k=min(num_boxes, rpn_post_nms_top_k))

                rois.append(this_level_rois)
                roi_scores.append(this_level_roi_scores)

        rois = tf.concat(rois, axis=1)
        roi_scores = tf.concat(roi_scores, axis=1)

        with tf.name_scope('top_k_rois'):
            _, num_valid_rois = roi_scores.get_shape().as_list()
            overall_top_k = min(num_valid_rois, rpn_post_nms_top_k)

            selected_rois, selected_roi_scores = box_utils.top_k_boxes(
                rois, roi_scores, k=overall_top_k)

        return selected_rois, selected_roi_scores
コード例 #8
0
ファイル: postprocess_ops.py プロジェクト: vishalbelsare/tpu
    def __call__(self,
                 box_outputs,
                 class_outputs,
                 anchor_boxes,
                 image_shape,
                 regression_weights=None,
                 bbox_per_class=True,
                 distill_class_outputs=None):
        """Generate final detections.

        Args:
          box_outputs: a tensor of shape of [batch_size, K, num_classes * 4]
            representing the class-specific box coordinates relative to
            anchors (presumably [batch_size, K, 4] when `bbox_per_class` is
            False).
          class_outputs: a tensor of shape of [batch_size, K, num_classes]
            representing the class logits before applying score activation.
          anchor_boxes: a tensor of shape of [batch_size, K, 4] representing the
            corresponding anchor boxes w.r.t `box_outputs`.
          image_shape: a tensor of shape of [batch_size, 2] storing the image
            height and width w.r.t. the scaled image, i.e. the same image space
            as `box_outputs` and `anchor_boxes`.
          regression_weights: A list of four float numbers to scale
            coordinates. Defaults to [10.0, 10.0, 5.0, 5.0] when None.
          bbox_per_class: A `bool`. If True, perform per-class box regression.
          distill_class_outputs: a float tensor of shape of
            [batch_size, K, num_classes-1] representing the distilled class
            logits before applying score activation, without the background
            class. Required when `self._feat_distill == 'double_branch'`.

        Returns:
          If `self._apply_nms` is False, a dict with:
            'raw_boxes': decoded, clipped boxes of shape
              [batch_size, K, num_classes - 1, 4] when `bbox_per_class` is
              True, otherwise with a singleton third axis.
            'raw_scores': foreground class probabilities.
          Otherwise a dict with:
            'num_detections': `int` Tensor of shape [batch_size]; only the top
              `num_detections` boxes are valid detections.
            'detection_boxes': `float` Tensor of shape
              [batch_size, max_total_size, 4] representing top detected boxes
              in [y1, x1, y2, x2].
            'detection_classes': `int` Tensor of shape
              [batch_size, max_total_size]; offset by one because index 0 is
              the background class.
            'detection_scores': `float` Tensor of shape
              [batch_size, max_total_size] representing sorted confidence
              scores in [0, 1].

        Raises:
          ValueError: if `self._feat_distill == 'double_branch'` and
            `distill_class_outputs` is None.
        """
        class_outputs_shape = tf.shape(class_outputs)
        num_locations = class_outputs_shape[1]
        num_classes = class_outputs_shape[-1]

        if self._discard_background:
            # Removes the background class before softmax.
            class_outputs = tf.slice(class_outputs, [0, 0, 1], [-1, -1, -1])

        class_outputs = tf.nn.softmax(class_outputs, axis=-1)

        if not self._discard_background:
            # Removes the background class.
            class_outputs = tf.slice(class_outputs, [0, 0, 1], [-1, -1, -1])

        if self._feat_distill == 'double_branch':
            if distill_class_outputs is None:
                raise ValueError(
                    '`distill_class_outputs` is required when `feat_distill` '
                    "is 'double_branch'.")
            # Same shape as the distilled logits: [B, num_rois, num_classes - 1].
            distill_class_outputs = tf.nn.softmax(
                distill_class_outputs, axis=-1)
            # Blend with the rare mask, then take the geometric mean of the
            # three probability terms.
            third_component = (
                1.0 - self._rare_mask
            ) * distill_class_outputs + self._rare_mask * class_outputs
            weighted_product = distill_class_outputs * class_outputs * third_component
            class_outputs = tf.pow(weighted_product, 1.0 / 3.0)

        if bbox_per_class:
            # Drop the background boxes and repeat each anchor once per
            # foreground class, then flatten to [batch, detections, 4].
            num_detections = num_locations * (num_classes - 1)
            box_outputs = tf.reshape(box_outputs,
                                     [-1, num_locations, num_classes, 4])
            box_outputs = tf.slice(box_outputs, [0, 0, 1, 0], [-1, -1, -1, -1])
            anchor_boxes = tf.tile(tf.expand_dims(anchor_boxes, axis=2),
                                   [1, 1, num_classes - 1, 1])
            box_outputs = tf.reshape(box_outputs, [-1, num_detections, 4])
            anchor_boxes = tf.reshape(anchor_boxes, [-1, num_detections, 4])

        # Box decoding.
        if regression_weights is None:
            regression_weights = [10.0, 10.0, 5.0, 5.0]
        decoded_boxes = box_utils.decode_boxes(box_outputs,
                                               anchor_boxes,
                                               weights=regression_weights)

        # Box clipping
        decoded_boxes = box_utils.clip_boxes(decoded_boxes, image_shape)

        if bbox_per_class:
            decoded_boxes = tf.reshape(decoded_boxes,
                                       [-1, num_locations, num_classes - 1, 4])
        else:
            decoded_boxes = tf.expand_dims(decoded_boxes, axis=2)

        if not self._apply_nms:
            return {
                'raw_boxes': decoded_boxes,
                'raw_scores': class_outputs,
            }

        nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
            self._generate_detections(decoded_boxes, class_outputs))

        # Adds 1 to offset the background class which has index 0.
        nmsed_classes += 1

        return {
            'num_detections': valid_detections,
            'detection_boxes': nmsed_boxes,
            'detection_classes': nmsed_classes,
            'detection_scores': nmsed_scores,
        }
コード例 #9
0
  def __call__(self, box_outputs, class_outputs, anchor_boxes, image_shape):
    """Decodes class-specific boxes and runs NMS to produce final detections.

    Args:
      box_outputs: a tensor of shape of [batch_size, K, num_classes * 4]
        representing the class-specific box coordinates relative to anchors.
      class_outputs: a tensor of shape of [batch_size, K, num_classes]
        representing the class logits before applying score activation.
      anchor_boxes: a tensor of shape of [batch_size, K, 4] representing the
        corresponding anchor boxes w.r.t `box_outputs`.
      image_shape: a tensor of shape of [batch_size, 2] storing the image height
        and width w.r.t. the scaled image, i.e. the same image space as
        `box_outputs` and `anchor_boxes`.

    Returns:
      nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
        representing top detected boxes in [y1, x1, y2, x2].
      nms_scores: `float` Tensor of shape [batch_size, max_total_size]
        representing sorted confidence scores for detected boxes. The values are
        between [0, 1].
      nms_classes: `int` Tensor of shape [batch_size, max_total_size]
        representing classes for detected boxes, offset by one so that index 0
        stays the background class.
      valid_detections: `int` Tensor of shape [batch_size] only the top
        `valid_detections` boxes are valid detections.
    """
    probs = tf.nn.softmax(class_outputs, axis=-1)

    dims = tf.shape(probs)
    batch_size, num_locations, num_classes = dims[0], dims[1], dims[-1]
    num_fg_classes = num_classes - 1
    num_detections = num_locations * num_fg_classes

    # Drop the background entry (index 0) from both scores and regressions.
    fg_scores = tf.slice(probs, [0, 0, 1], [-1, -1, -1])
    per_class_deltas = tf.reshape(
        box_outputs,
        tf.stack([batch_size, num_locations, num_classes, 4], axis=-1))
    fg_deltas = tf.slice(per_class_deltas, [0, 0, 1, 0], [-1, -1, -1, -1])

    # Repeat every anchor once per foreground class so it pairs with its deltas.
    per_class_anchors = tf.tile(
        tf.expand_dims(anchor_boxes, axis=2), [1, 1, num_fg_classes, 1])

    flat_shape = tf.stack([batch_size, num_detections, 4], axis=-1)
    decoded = box_utils.decode_boxes(
        tf.reshape(fg_deltas, flat_shape),
        tf.reshape(per_class_anchors, flat_shape),
        weights=[10.0, 10.0, 5.0, 5.0])
    decoded = box_utils.clip_boxes(decoded, image_shape)
    decoded = tf.reshape(
        decoded,
        tf.stack([batch_size, num_locations, num_fg_classes, 4], axis=-1))

    nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
        self._generate_detections(decoded, fg_scores))
    # Shift class ids by one: index 0 is the background class.
    return nmsed_boxes, nmsed_scores, nmsed_classes + 1, valid_detections
コード例 #10
0
    def parse_train_data(self, data):
        """Parse data for ShapeMask training.

        Args:
          data: decoded example dict. This method reads the keys
            'groundtruth_classes', 'groundtruth_boxes',
            'groundtruth_instance_masks' and 'groundtruth_is_crowd', and
            builds the image with `self.get_normalized_image(data)`.

        Returns:
          A tuple (image, labels).
          `image` is the resized-and-cropped image, cast to bfloat16 when
          `self._use_bfloat16` is set.
          `labels` is a dict with:
            - the anchor targets: 'cls_targets', 'box_targets',
              'anchor_boxes' and 'num_positives';
            - 'image_info';
            - the ShapeMask mask-branch targets: 'mask_boxes',
              'mask_outer_boxes', 'mask_targets', 'fine_mask_targets',
              'mask_classes' and 'mask_is_valid'.

        Side effects:
          Sets `self._train_image_scale` and `self._train_offset` from rows 2
          and 3 of `image_info`.
        """
        classes = data['groundtruth_classes']
        boxes = data['groundtruth_boxes']
        masks = data['groundtruth_instance_masks']
        is_crowds = data['groundtruth_is_crowd']
        # Skips annotations with `is_crowd` = True.
        if self._skip_crowd_during_training and self._is_training:
            num_groundtrtuhs = tf.shape(classes)[0]
            with tf.control_dependencies([num_groundtrtuhs, is_crowds]):
                # Keep the non-crowd indices. If `is_crowds` is empty, fall back
                # to keeping every groundtruth.
                indices = tf.cond(
                    tf.greater(tf.size(is_crowds), 0),
                    lambda: tf.where(tf.logical_not(is_crowds))[:, 0],
                    lambda: tf.cast(tf.range(num_groundtrtuhs), tf.int64))
            classes = tf.gather(classes, indices)
            boxes = tf.gather(boxes, indices)
            masks = tf.gather(masks, indices)

        # If not using category, collapses classes to a binary label: ids > 0
        # become 1.0 and all others 0.0 (float32).
        if not self._use_category:
            classes = tf.cast(tf.greater(classes, 0), dtype=tf.float32)

        image = self.get_normalized_image(data)

        # Randomly flips image, boxes and masks horizontally when
        # `self._aug_rand_hflip` is set (not gated on `self._is_training`).
        if self._aug_rand_hflip:
            image, boxes, masks = input_utils.random_horizontal_flip(
                image, boxes, masks)

        # Converts boxes from normalized coordinates to pixel coordinates.
        image_shape = tf.shape(image)[0:2]
        boxes = box_utils.denormalize_boxes(boxes, image_shape)

        # Resizes and crops image.
        image, image_info = input_utils.resize_and_crop_image(
            image,
            self._output_size,
            self._output_size,
            aug_scale_min=self._aug_scale_min,
            aug_scale_max=self._aug_scale_max)
        # Cache the scale (row 2) and offset (row 3) of image_info on self; they
        # are reused below to map boxes in both directions.
        self._train_image_scale = image_info[2, :]
        self._train_offset = image_info[3, :]

        # Resizes and crops boxes. Masks are not resized here; mask targets are
        # cropped from the original-resolution masks further below.
        boxes = input_utils.resize_and_crop_boxes(boxes,
                                                  self._train_image_scale,
                                                  image_info[1, :],
                                                  self._train_offset)

        # Filters out ground truth boxes that are all zeros.
        indices = box_utils.get_non_empty_box_indices(boxes)
        boxes = tf.gather(boxes, indices)
        classes = tf.gather(classes, indices)
        masks = tf.gather(masks, indices)

        # Assigns anchors.
        input_anchor = anchor.Anchor(self._min_level, self._max_level,
                                     self._num_scales, self._aspect_ratios,
                                     self._anchor_size, self._output_size)
        anchor_labeler = anchor.AnchorLabeler(input_anchor,
                                              self._match_threshold,
                                              self._unmatched_threshold)
        (cls_targets, box_targets,
         num_positives) = anchor_labeler.label_anchors(
             boxes, tf.cast(tf.expand_dims(classes, axis=1), tf.float32))

        # Sample groundtruth masks/boxes/classes for mask branch.
        num_masks = tf.shape(masks)[0]
        mask_shape = tf.shape(masks)[1:3]

        # Pad sampled boxes/masks/classes to a constant batch size.
        padded_boxes = input_utils.pad_to_fixed_size(boxes,
                                                     self._num_sampled_masks)
        padded_classes = input_utils.pad_to_fixed_size(classes,
                                                       self._num_sampled_masks)
        padded_masks = input_utils.pad_to_fixed_size(masks,
                                                     self._num_sampled_masks)

        # Randomly sample groundtruth masks for mask branch training. For the image
        # without groundtruth masks, it will sample the dummy padded tensors.
        rand_indices = tf.random.shuffle(
            tf.range(tf.maximum(num_masks, self._num_sampled_masks)))
        # The modulo folds indices into [0, num_masks), so groundtruths repeat
        # when there are fewer than `self._num_sampled_masks`. Using
        # max(num_masks, 1) avoids a zero divisor.
        rand_indices = tf.mod(rand_indices, tf.maximum(num_masks, 1))
        rand_indices = rand_indices[0:self._num_sampled_masks]
        rand_indices = tf.reshape(rand_indices, [self._num_sampled_masks])

        sampled_boxes = tf.gather(padded_boxes, rand_indices)
        sampled_classes = tf.gather(padded_classes, rand_indices)
        sampled_masks = tf.gather(padded_masks, rand_indices)
        # Jitter the sampled boxes to mimic the noisy detections.
        sampled_boxes = box_utils.jitter_boxes(
            sampled_boxes, noise_scale=self._box_jitter_scale)
        sampled_boxes = box_utils.clip_boxes(sampled_boxes, self._output_size)
        # Compute mask targets in feature crop. A feature crop fully contains a
        # sampled box.
        mask_outer_boxes = box_utils.compute_outer_boxes(
            sampled_boxes, tf.shape(image)[0:2], scale=self._outer_box_scale)
        mask_outer_boxes = box_utils.clip_boxes(mask_outer_boxes,
                                                self._output_size)
        # Compensate the offset of mask_outer_boxes to map it back to original image
        # scale.
        # Assumes tensors: `+=` / `/=` below rebind mask_outer_boxes_ori to new
        # tensors, so `mask_outer_boxes` (returned in labels) keeps the
        # resized-image coordinates.
        mask_outer_boxes_ori = mask_outer_boxes
        mask_outer_boxes_ori += tf.tile(
            tf.expand_dims(self._train_offset, axis=0), [1, 2])
        mask_outer_boxes_ori /= tf.tile(
            tf.expand_dims(self._train_image_scale, axis=0), [1, 2])
        norm_mask_outer_boxes_ori = box_utils.normalize_boxes(
            mask_outer_boxes_ori, mask_shape)

        # Adds a trailing channel axis and casts to float32 for crop_and_resize
        # (presumably [num_sampled_masks, height, width, 1]).
        sampled_masks = tf.cast(tf.expand_dims(sampled_masks, axis=-1),
                                tf.float32)
        # NOTE(review): `box_ind` is the TF1 keyword (TF2 renamed it
        # `box_indices`); confirm the TensorFlow version this runs under.
        mask_targets = tf.image.crop_and_resize(
            sampled_masks,
            norm_mask_outer_boxes_ori,
            box_ind=tf.range(self._num_sampled_masks),
            crop_size=[self._mask_crop_size, self._mask_crop_size],
            method='bilinear',
            extrapolation_value=0,
            name='train_mask_targets')
        # Binarize the bilinearly resampled masks at 0.5.
        mask_targets = tf.where(tf.greater_equal(mask_targets, 0.5),
                                tf.ones_like(mask_targets),
                                tf.zeros_like(mask_targets))
        mask_targets = tf.squeeze(mask_targets, axis=-1)
        # Higher-resolution targets: same crops, `self._up_sample_factor` times
        # larger. Without up-sampling, reuse the coarse targets.
        if self._up_sample_factor > 1:
            fine_mask_targets = tf.image.crop_and_resize(
                sampled_masks,
                norm_mask_outer_boxes_ori,
                box_ind=tf.range(self._num_sampled_masks),
                crop_size=[
                    self._mask_crop_size * self._up_sample_factor,
                    self._mask_crop_size * self._up_sample_factor
                ],
                method='bilinear',
                extrapolation_value=0,
                name='train_mask_targets')
            fine_mask_targets = tf.where(
                tf.greater_equal(fine_mask_targets, 0.5),
                tf.ones_like(fine_mask_targets),
                tf.zeros_like(fine_mask_targets))
            fine_mask_targets = tf.squeeze(fine_mask_targets, axis=-1)
        else:
            fine_mask_targets = mask_targets

        # If bfloat16 is used, casts input image to tf.bfloat16.
        if self._use_bfloat16:
            image = tf.cast(image, dtype=tf.bfloat16)

        # valid_image is 1 when the image has at least one groundtruth mask,
        # else 0. It zeroes out mask_is_valid for images with only padding.
        valid_image = tf.cast(tf.not_equal(num_masks, 0), tf.int32)
        if self._mask_train_class == 'all':
            mask_is_valid = valid_image * tf.ones_like(sampled_classes,
                                                       tf.int32)
        else:
            # Get the intersection of sampled classes with training splits.
            mask_valid_classes = tf.cast(
                tf.expand_dims(
                    class_utils.coco_split_class_ids(self._mask_train_class),
                    1), sampled_classes.dtype)
            match = tf.reduce_any(
                tf.equal(tf.expand_dims(sampled_classes, 0),
                         mask_valid_classes), 0)
            mask_is_valid = valid_image * tf.cast(match, tf.int32)

        # Packs labels for model_fn outputs.
        labels = {
            'cls_targets': cls_targets,
            'box_targets': box_targets,
            'anchor_boxes': input_anchor.multilevel_boxes,
            'num_positives': num_positives,
            'image_info': image_info,
            # For ShapeMask.
            'mask_boxes': sampled_boxes,
            'mask_outer_boxes': mask_outer_boxes,
            'mask_targets': mask_targets,
            'fine_mask_targets': fine_mask_targets,
            'mask_classes': sampled_classes,
            'mask_is_valid': mask_is_valid,
        }
        return image, labels