def call(self, inputs):
    box_outputs, class_outputs, anchor_boxes, image_shape = inputs
    # Collects outputs from all levels into a list.
    boxes = []
    scores = []
    for i in range(self._min_level, self._max_level + 1):
      batch_size = tf.shape(input=class_outputs[i])[0]

      # Applies the score transformation and removes the implicit background class.
      scores_i = _apply_score_activation(
          class_outputs[i], self._num_classes, self._score_activation)

      # Box decoding.
      # The anchor boxes are shared for all data in a batch.
      # One-stage detectors only support class-agnostic box regression.
      anchor_boxes_i = tf.reshape(anchor_boxes[i], [batch_size, -1, 4])
      box_outputs_i = tf.reshape(box_outputs[i], [batch_size, -1, 4])
      boxes_i = box_utils.decode_boxes(box_outputs_i, anchor_boxes_i)

      # Box clipping.
      boxes_i = box_utils.clip_boxes(boxes_i, image_shape)

      boxes.append(boxes_i)
      scores.append(scores_i)
    boxes = tf.concat(boxes, axis=1)
    scores = tf.concat(scores, axis=1)
    boxes = tf.expand_dims(boxes, axis=2)

    (nmsed_boxes, nmsed_scores, nmsed_classes,
     valid_detections) = self._generate_detections(
         tf.cast(boxes, tf.float32), tf.cast(scores, tf.float32))
    # Adds 1 to offset the background class which has index 0.
    nmsed_classes += 1
    return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
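
# The helper `_apply_score_activation` is not shown in this snippet; the
# following is a minimal sketch consistent with how it is called above
# (per-level logits, number of classes, and an activation name). The supported
# activation names and the unconditional removal of the background column are
# assumptions, not the actual implementation.
import tensorflow as tf

def _apply_score_activation(logits, num_classes, activation):
    """Applies the score activation and strips the implicit background class."""
    batch_size = tf.shape(logits)[0]
    logits = tf.reshape(logits, [batch_size, -1, num_classes])
    if activation == 'sigmoid':
        scores = tf.sigmoid(logits)
    elif activation == 'softmax':
        scores = tf.nn.softmax(logits, axis=-1)
    else:
        raise ValueError('Unsupported score activation: %s' % activation)
    # Drops the implicit background class at index 0, as the caller expects
    # foreground-only scores.
    return tf.slice(scores, [0, 0, 1], [-1, -1, -1])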
def resize_and_crop_boxes(boxes, image_scale, output_size, offset):
    """Resizes boxes to output size with scale and offset.

  Args:
    boxes: `Tensor` of shape [N, 4] representing ground truth boxes.
    image_scale: 2D float `Tensor` representing scale factors that apply to
      [height, width] of input image.
    output_size: 2D `Tensor` or `int` representing [height, width] of target
      output image size.
    offset: 2D `Tensor` representing top-left corner [y0, x0] to crop scaled
      boxes.

  Returns:
    boxes: `Tensor` of shape [N, 4] representing the scaled boxes.
  """
    # Adjusts box coordinates based on image_scale and offset.
    boxes *= tf.tile(tf.expand_dims(image_scale, axis=0), [1, 2])
    boxes -= tf.tile(tf.expand_dims(offset, axis=0), [1, 2])
    # Clips the boxes.
    boxes = box_utils.clip_boxes(boxes, output_size)
    return boxes
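
# A hypothetical usage sketch of resize_and_crop_boxes: the boxes are in pixel
# coordinates of the original image, and image_scale/offset come from the
# resize step (the values below are made up for illustration).
import tensorflow as tf

boxes = tf.constant([[10.0, 20.0, 110.0, 220.0]])  # [y1, x1, y2, x2]
image_scale = tf.constant([0.5, 0.5])              # [height, width] scales
offset = tf.constant([0.0, 0.0])                   # crop top-left [y0, x0]
scaled = resize_and_crop_boxes(boxes, image_scale, [512, 512], offset)
# scaled == [[5.0, 10.0, 55.0, 110.0]]: every coordinate is multiplied by the
# matching scale factor and shifted by the offset before clipping.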
    def call(self, box_outputs, class_outputs, anchor_boxes, image_shape):
        # Collects outputs from all levels into a list.
        boxes = []
        scores = []
        for i in range(self._min_level, self._max_level + 1):
            box_outputs_i_shape = tf.shape(box_outputs[i])
            batch_size = box_outputs_i_shape[0]
            num_anchors_per_location = box_outputs_i_shape[-1] // 4
            num_classes = tf.shape(
                class_outputs[i])[-1] // num_anchors_per_location

            # Applies the score transformation and removes the implicit background class.
            scores_i = tf.sigmoid(
                tf.reshape(class_outputs[i], [batch_size, -1, num_classes]))
            scores_i = tf.slice(scores_i, [0, 0, 1], [-1, -1, -1])

            # Box decoding.
            # The anchor boxes are shared for all data in a batch.
            # One-stage detectors only support class-agnostic box regression.
            anchor_boxes_i = tf.reshape(anchor_boxes[i], [batch_size, -1, 4])
            box_outputs_i = tf.reshape(box_outputs[i], [batch_size, -1, 4])
            boxes_i = box_utils.decode_boxes(box_outputs_i, anchor_boxes_i)

            # Box clipping.
            boxes_i = box_utils.clip_boxes(boxes_i, image_shape)

            boxes.append(boxes_i)
            scores.append(scores_i)
        boxes = tf.concat(boxes, axis=1)
        scores = tf.concat(scores, axis=1)

        nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
            self._generate_detections(tf.expand_dims(boxes, axis=2), scores))

        # Adds 1 to offset the background class which has index 0.
        nmsed_classes += 1
        return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
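
# The size-1 class dimension added by tf.expand_dims(boxes, axis=2) exists
# because the downstream NMS expects per-class boxes; with class-agnostic
# regression every class shares the same box. A minimal sketch of what
# self._generate_detections could look like if it wraps batched NMS (the
# thresholds and sizes below are assumptions, not the actual configuration):
import tensorflow as tf

def generate_detections_sketch(boxes, scores, max_total_size=100,
                               iou_threshold=0.5, score_threshold=0.05):
    """boxes: [batch, N, 1, 4] class-agnostic; scores: [batch, N, classes]."""
    return tf.image.combined_non_max_suppression(
        boxes,
        scores,
        max_output_size_per_class=max_total_size,
        max_total_size=max_total_size,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        pad_per_class=False,
        clip_boxes=False)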
    def __call__(self,
                 box_outputs,
                 class_outputs,
                 anchor_boxes,
                 image_shape,
                 is_single_fg_score=False,
                 keep_nms=True):
        """Generate final detections for Object Localization Network (OLN).

    Args:
      box_outputs: a tensor of shape of [batch_size, K, num_classes * 4]
        representing the class-specific box coordinates relative to anchors.
      class_outputs: a tensor of shape of [batch_size, K, num_classes]
        representing the class logits before applying the score activation.
      anchor_boxes: a tensor of shape of [batch_size, K, 4] representing the
        corresponding anchor boxes w.r.t `box_outputs`.
      image_shape: a tensor of shape of [batch_size, 2] storing the image height
        and width w.r.t. the scaled image, i.e. the same image space as
        `box_outputs` and `anchor_boxes`.
      is_single_fg_score: a Bool indicating whether `class_outputs` contains
        only foreground scores. By default, `class_outputs` is a concatenation
        of both the foreground and background scores, i.e.
        `is_single_fg_score=False`.
      keep_nms: a Bool indicator of whether to perform NMS or not.

    Returns:
      nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
        representing top detected boxes in [y1, x1, y2, x2].
      nms_scores: `float` Tensor of shape [batch_size, max_total_size]
        representing sorted confidence scores for detected boxes. The values are
        between [0, 1].
      nms_classes: `int` Tensor of shape [batch_size, max_total_size]
        representing classes for detected boxes.
      valid_detections: `int` Tensor of shape [batch_size]; only the top
        `valid_detections` boxes are valid detections.
    """
        if is_single_fg_score:
            # Concatenates dummy background scores.
            dummy_bg_scores = tf.zeros_like(class_outputs)
            class_outputs = tf.stack([dummy_bg_scores, class_outputs], -1)
        else:
            class_outputs = tf.nn.softmax(class_outputs, axis=-1)

        # Removes the background class.
        class_outputs_shape = tf.shape(class_outputs)
        batch_size = class_outputs_shape[0]
        num_locations = class_outputs_shape[1]
        num_classes = class_outputs_shape[-1]
        num_detections = num_locations * (num_classes - 1)

        class_outputs = tf.slice(class_outputs, [0, 0, 1], [-1, -1, -1])
        box_outputs = tf.reshape(
            box_outputs,
            tf.stack([batch_size, num_locations, num_classes, 4], axis=-1))
        box_outputs = tf.slice(box_outputs, [0, 0, 1, 0], [-1, -1, -1, -1])
        anchor_boxes = tf.tile(tf.expand_dims(anchor_boxes, axis=2),
                               [1, 1, num_classes - 1, 1])
        box_outputs = tf.reshape(
            box_outputs, tf.stack([batch_size, num_detections, 4], axis=-1))
        anchor_boxes = tf.reshape(
            anchor_boxes, tf.stack([batch_size, num_detections, 4], axis=-1))

        # Box decoding. For RPN outputs, box_outputs are all zeros.
        decoded_boxes = box_utils.decode_boxes(box_outputs,
                                               anchor_boxes,
                                               weights=[10.0, 10.0, 5.0, 5.0])

        # Box clipping
        decoded_boxes = box_utils.clip_boxes(decoded_boxes, image_shape)

        decoded_boxes = tf.reshape(
            decoded_boxes,
            tf.stack([batch_size, num_locations, num_classes - 1, 4], axis=-1))

        if keep_nms:
            nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
                self._generate_detections(decoded_boxes, class_outputs))
            # Adds 1 to offset the background class which has index 0.
            nmsed_classes += 1
        else:
            nmsed_boxes = decoded_boxes[:, :, 0, :]
            nmsed_scores = class_outputs[:, :, 0]
            nmsed_classes = tf.cast(tf.ones_like(nmsed_scores), tf.int32)
            valid_detections = tf.cast(
                tf.reduce_sum(tf.ones_like(nmsed_scores), axis=-1), tf.int32)

        return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
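
# A hypothetical invocation for the OLN/RPN case. With is_single_fg_score=True
# the class_outputs carry only foreground scores and a zero background column
# is synthesized before the background slice; keep_nms=False returns the
# decoded boxes and scores directly (names below are illustrative).
boxes, scores, classes, valid = generator(
    box_outputs,              # [batch, K, num_classes * 4] box deltas
    class_outputs,            # foreground-only scores (e.g. centerness)
    anchor_boxes,             # [batch, K, 4]
    image_shape,              # [batch, 2] as [height, width]
    is_single_fg_score=True,
    keep_nms=False)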
    def call(self, box_outputs, class_outputs, anchor_boxes, image_shape):
        """Generate final detections.

    Args:
      box_outputs: a tensor of shape of [batch_size, K, num_classes * 4]
        representing the class-specific box coordinates relative to anchors.
      class_outputs: a tensor of shape of [batch_size, K, num_classes]
        representing the class logits before applying the score activation.
      anchor_boxes: a tensor of shape of [batch_size, K, 4] representing the
        corresponding anchor boxes w.r.t `box_outputs`.
      image_shape: a tensor of shape of [batch_size, 2] storing the image height
        and width w.r.t. the scaled image, i.e. the same image space as
        `box_outputs` and `anchor_boxes`.

    Returns:
      nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
        representing top detected boxes in [y1, x1, y2, x2].
      nms_scores: `float` Tensor of shape [batch_size, max_total_size]
        representing sorted confidence scores for detected boxes. The values are
        between [0, 1].
      nms_classes: `int` Tensor of shape [batch_size, max_total_size]
        representing classes for detected boxes.
      valid_detections: `int` Tensor of shape [batch_size]; only the top
        `valid_detections` boxes are valid detections.
    """
        class_outputs = tf.nn.softmax(class_outputs, axis=-1)

        # Removes the background class.
        class_outputs_shape = tf.shape(class_outputs)
        batch_size = class_outputs_shape[0]
        num_locations = class_outputs_shape[1]
        num_classes = class_outputs_shape[-1]
        num_detections = num_locations * (num_classes - 1)

        class_outputs = tf.slice(class_outputs, [0, 0, 1], [-1, -1, -1])
        box_outputs = tf.reshape(
            box_outputs,
            tf.stack([batch_size, num_locations, num_classes, 4], axis=-1))
        box_outputs = tf.slice(box_outputs, [0, 0, 1, 0], [-1, -1, -1, -1])
        anchor_boxes = tf.tile(tf.expand_dims(anchor_boxes, axis=2),
                               [1, 1, num_classes - 1, 1])
        box_outputs = tf.reshape(
            box_outputs, tf.stack([batch_size, num_detections, 4], axis=-1))
        anchor_boxes = tf.reshape(
            anchor_boxes, tf.stack([batch_size, num_detections, 4], axis=-1))

        # Box decoding.
        decoded_boxes = box_utils.decode_boxes(box_outputs,
                                               anchor_boxes,
                                               weights=[10.0, 10.0, 5.0, 5.0])

        # Box clipping
        decoded_boxes = box_utils.clip_boxes(decoded_boxes, image_shape)

        decoded_boxes = tf.reshape(
            decoded_boxes,
            tf.stack([batch_size, num_locations, num_classes - 1, 4], axis=-1))

        nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
            self._generate_detections(decoded_boxes, class_outputs))

        # Adds 1 to offset the background class which has index 0.
        nmsed_classes += 1

        return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
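
# The weights=[10.0, 10.0, 5.0, 5.0] are the standard Faster R-CNN box-coding
# scale factors: predicted deltas are divided by them before being applied to
# the anchors. A sketch of that decoding, assuming box_utils.decode_boxes uses
# the usual center/size parameterization over [y1, x1, y2, x2] anchors:
import tensorflow as tf

def decode_boxes_sketch(encoded, anchors, weights=(10.0, 10.0, 5.0, 5.0)):
    dy, dx, dh, dw = tf.unstack(encoded, axis=-1)
    dy, dx, dh, dw = (dy / weights[0], dx / weights[1],
                      dh / weights[2], dw / weights[3])
    y1, x1, y2, x2 = tf.unstack(anchors, axis=-1)
    ah, aw = y2 - y1, x2 - x1
    # Shifts the anchor center by the scaled deltas and rescales its size.
    cy, cx = y1 + 0.5 * ah + dy * ah, x1 + 0.5 * aw + dx * aw
    h, w = tf.exp(dh) * ah, tf.exp(dw) * aw
    return tf.stack(
        [cy - 0.5 * h, cx - 0.5 * w, cy + 0.5 * h, cx + 0.5 * w], axis=-1)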
    def oln_multilevel_propose_rois(
        self,
        rpn_boxes,
        rpn_scores,
        anchor_boxes,
        image_shape,
        rpn_pre_nms_top_k=2000,
        rpn_post_nms_top_k=1000,
        rpn_nms_threshold=0.7,
        rpn_score_threshold=0.0,
        rpn_min_size_threshold=0.0,
        decode_boxes=True,
        clip_boxes=True,
        use_batched_nms=False,
        apply_sigmoid_to_score=True,
        is_box_lrtb=False,
        rpn_object_scores=None,
    ):
        """Proposes RoIs given a group of candidates from different FPN levels.

    The following describes the steps:
      1. For each individual level:
        a. Adjust scores for each level if specified by rpn_object_scores.
        b. Apply sigmoid transform if specified.
        c. Decode boxes (in either xyhw or left-right-top-bottom format) if
           specified.
        d. Clip boxes if specified.
        e. Filter small boxes and those that fall outside the image if specified.
        f. Apply pre-NMS filtering including pre-NMS top k and score
           thresholding.
        g. Apply NMS.
      2. Aggregate post-NMS boxes from each level.
      3. Apply an overall top k to generate the final selected RoIs.

    Args:
      rpn_boxes: a dict with keys representing FPN levels and values
        representing box tensors of shape [batch_size, feature_h, feature_w,
        num_anchors * 4].
      rpn_scores: a dict with keys representing FPN levels and values
        representing logit tensors of shape [batch_size, feature_h, feature_w,
        num_anchors].
      anchor_boxes: a dict with keys representing FPN levels and values
        representing anchor box tensors of shape [batch_size, feature_h,
        feature_w, num_anchors * 4].
      image_shape: a tensor of shape [batch_size, 2] where the last dimension
        are [height, width] of the scaled image.
      rpn_pre_nms_top_k: an integer of top scoring RPN proposals *per level* to
        keep before applying NMS. Default: 2000.
      rpn_post_nms_top_k: an integer of top scoring RPN proposals *in total* to
        keep after applying NMS. Default: 1000.
      rpn_nms_threshold: a float between 0 and 1 representing the IoU threshold
        used for NMS. If 0.0, no NMS is applied. Default: 0.7.
      rpn_score_threshold: a float between 0 and 1 representing the minimal box
        score to keep before applying NMS. This is often used as a pre-filtering
        step for better performance. If 0, no filtering is applied. Default: 0.
      rpn_min_size_threshold: a float representing the minimal box size in each
        side (w.r.t. the scaled image) to keep before applying NMS. This is
        often used as a pre-filtering step for better performance. If 0, no
        filtering is applied. Default: 0.
      decode_boxes: a boolean indicating whether `rpn_boxes` needs to be decoded
        using `anchor_boxes`. If False, use `rpn_boxes` directly and ignore
        `anchor_boxes`. Default: True.
      clip_boxes: a boolean indicating whether boxes are first clipped to the
        scaled image size before applying NMS. If False, no clipping is applied
        and `image_shape` is ignored. Default: True.
      use_batched_nms: a boolean indicating whether NMS is applied in batch
        using `tf.image.combined_non_max_suppression`. Currently only available
        in CPU/GPU. Default: False.
      apply_sigmoid_to_score: a boolean indicating whether to apply sigmoid to
        `rpn_scores` before applying NMS. Default: True.
      is_box_lrtb: a bool indicating whether boxes are in lrtb (left, right,
        top, bottom) format.
      rpn_object_scores: predicted objectness scores (e.g., centerness). In
        OLN, object scores (centerness) are used in place of the classification
        scores at each level. A dict with keys representing FPN levels and
        values representing logit tensors of shape [batch_size, feature_h,
        feature_w, num_anchors].

    Returns:
      selected_rois: a tensor of shape [batch_size, rpn_post_nms_top_k, 4],
        representing the box coordinates of the selected proposals w.r.t. the
        scaled image.
      selected_roi_scores: a tensor of shape [batch_size, rpn_post_nms_top_k,
        1], representing the scores of the selected proposals.
    """
        with tf.name_scope('multilevel_propose_rois'):
            rois = []
            roi_scores = []
            image_shape = tf.expand_dims(image_shape, axis=1)
            for level in sorted(rpn_scores.keys()):
                with tf.name_scope('level_%d' % level):
                    _, feature_h, feature_w, num_anchors_per_location = (
                        rpn_scores[level].get_shape().as_list())

                    num_boxes = feature_h * feature_w * num_anchors_per_location
                    this_level_scores = tf.reshape(rpn_scores[level],
                                                   [-1, num_boxes])
                    this_level_boxes = tf.reshape(rpn_boxes[level],
                                                  [-1, num_boxes, 4])
                    this_level_anchors = tf.cast(tf.reshape(
                        anchor_boxes[level], [-1, num_boxes, 4]),
                                                 dtype=this_level_scores.dtype)

                    if rpn_object_scores:
                        this_level_object_scores = rpn_object_scores[level]
                        this_level_object_scores = tf.reshape(
                            this_level_object_scores, [-1, num_boxes])
                        this_level_object_scores = tf.cast(
                            this_level_object_scores, this_level_scores.dtype)
                        this_level_scores = this_level_object_scores

                    if apply_sigmoid_to_score:
                        this_level_scores = tf.sigmoid(this_level_scores)

                    if decode_boxes:
                        if is_box_lrtb:  # Box in left-right-top-bottom format.
                            this_level_boxes = box_utils.decode_boxes_lrtb(
                                this_level_boxes, this_level_anchors)
                        else:  # Box in standard x-y-h-w format.
                            this_level_boxes = box_utils.decode_boxes(
                                this_level_boxes, this_level_anchors)

                    if clip_boxes:
                        this_level_boxes = box_utils.clip_boxes(
                            this_level_boxes, image_shape)

                    if rpn_min_size_threshold > 0.0:
                        this_level_boxes, this_level_scores = box_utils.filter_boxes(
                            this_level_boxes, this_level_scores, image_shape,
                            rpn_min_size_threshold)

                    this_level_pre_nms_top_k = min(num_boxes,
                                                   rpn_pre_nms_top_k)
                    this_level_post_nms_top_k = min(num_boxes,
                                                    rpn_post_nms_top_k)
                    if rpn_nms_threshold > 0.0:
                        if use_batched_nms:
                            this_level_rois, this_level_roi_scores, _, _ = (
                                tf.image.combined_non_max_suppression(
                                    tf.expand_dims(this_level_boxes, axis=2),
                                    tf.expand_dims(this_level_scores, axis=-1),
                                    max_output_size_per_class=(
                                        this_level_pre_nms_top_k),
                                    max_total_size=this_level_post_nms_top_k,
                                    iou_threshold=rpn_nms_threshold,
                                    score_threshold=rpn_score_threshold,
                                    pad_per_class=False,
                                    clip_boxes=False))
                        else:
                            if rpn_score_threshold > 0.0:
                                this_level_boxes, this_level_scores = (
                                    box_utils.filter_boxes_by_scores(
                                        this_level_boxes, this_level_scores,
                                        rpn_score_threshold))
                            this_level_boxes, this_level_scores = box_utils.top_k_boxes(
                                this_level_boxes,
                                this_level_scores,
                                k=this_level_pre_nms_top_k)
                            this_level_roi_scores, this_level_rois = (
                                nms.sorted_non_max_suppression_padded(
                                    this_level_scores,
                                    this_level_boxes,
                                    max_output_size=this_level_post_nms_top_k,
                                    iou_threshold=rpn_nms_threshold))
                    else:
                        this_level_rois, this_level_roi_scores = box_utils.top_k_boxes(
                            this_level_boxes,
                            this_level_scores,
                            k=this_level_post_nms_top_k)

                    rois.append(this_level_rois)
                    roi_scores.append(this_level_roi_scores)

            all_rois = tf.concat(rois, axis=1)
            all_roi_scores = tf.concat(roi_scores, axis=1)

            with tf.name_scope('top_k_rois'):
                _, num_valid_rois = all_roi_scores.get_shape().as_list()
                overall_top_k = min(num_valid_rois, rpn_post_nms_top_k)

                selected_rois, selected_roi_scores = box_utils.top_k_boxes(
                    all_rois, all_roi_scores, k=overall_top_k)

            return selected_rois, selected_roi_scores
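
# A hypothetical multi-level call for FPN levels 2..6. Shapes follow the
# docstring; rpn_box_outputs, rpn_score_outputs, anchors and centerness are
# illustrative names for the per-level endpoint dicts, not real attributes.
rois, roi_scores = generator.oln_multilevel_propose_rois(
    rpn_boxes={l: rpn_box_outputs[l] for l in range(2, 7)},
    rpn_scores={l: rpn_score_outputs[l] for l in range(2, 7)},
    anchor_boxes={l: anchors[l] for l in range(2, 7)},
    image_shape=image_shape,        # [batch, 2] as [height, width]
    rpn_object_scores=centerness,   # per-level centerness logits (OLN)
    is_box_lrtb=True)               # OLN boxes are lrtb distances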
    def _parse_train_data(self, data):
        """Parse data for ShapeMask training."""
        classes = data['groundtruth_classes']
        boxes = data['groundtruth_boxes']
        masks = data['groundtruth_instance_masks']
        is_crowds = data['groundtruth_is_crowd']
        # Skips annotations with `is_crowd` = True.
        if self._skip_crowd_during_training and self._is_training:
            num_groundtruths = tf.shape(classes)[0]
            with tf.control_dependencies([num_groundtruths, is_crowds]):
                indices = tf.cond(
                    tf.greater(tf.size(is_crowds), 0),
                    lambda: tf.where(tf.logical_not(is_crowds))[:, 0],
                    lambda: tf.cast(tf.range(num_groundtruths), tf.int64))
            classes = tf.gather(classes, indices)
            boxes = tf.gather(boxes, indices)
            masks = tf.gather(masks, indices)

        # Gets original image and its size.
        image = data['image']
        image_shape = tf.shape(image)[0:2]

        # If not using category, maps all foreground categories to a single class with id = 1.
        if not self._use_category:
            classes = tf.cast(tf.greater(classes, 0), dtype=tf.float32)

        # Normalizes image with mean and std pixel values.
        image = input_utils.normalize_image(image)

        # Flips image randomly during training.
        if self._aug_rand_hflip:
            image, boxes, masks = input_utils.random_horizontal_flip(
                image, boxes, masks)

        # Converts boxes from normalized coordinates to pixel coordinates.
        boxes = box_utils.denormalize_boxes(boxes, image_shape)

        # Resizes and crops image.
        image, image_info = input_utils.resize_and_crop_image(
            image,
            self._output_size,
            self._output_size,
            aug_scale_min=self._aug_scale_min,
            aug_scale_max=self._aug_scale_max)
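        # image_info packs four rows of [height, width] pairs: the original
        # size, the scaled size, the y/x scale factors, and the y/x offset of
        # the crop; rows 2 and 3 are consumed below.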
        image_scale = image_info[2, :]
        offset = image_info[3, :]

        # Resizes and crops boxes and masks.
        boxes = input_utils.resize_and_crop_boxes(boxes, image_scale,
                                                  self._output_size, offset)

        # Filters out ground truth boxes that are all zeros.
        indices = input_utils.get_non_empty_box_indices(boxes)
        boxes = tf.gather(boxes, indices)
        classes = tf.gather(classes, indices)
        masks = tf.gather(masks, indices)

        # Assigns anchors.
        input_anchor = anchor.Anchor(self._min_level, self._max_level,
                                     self._num_scales, self._aspect_ratios,
                                     self._anchor_size, self._output_size)
        anchor_labeler = anchor.AnchorLabeler(input_anchor,
                                              self._match_threshold,
                                              self._unmatched_threshold)
        (cls_targets, box_targets,
         num_positives) = anchor_labeler.label_anchors(
             boxes, tf.cast(tf.expand_dims(classes, axis=1), tf.float32))

        # Sample groundtruth masks/boxes/classes for mask branch.
        num_masks = tf.shape(masks)[0]
        mask_shape = tf.shape(masks)[1:3]

        # Pad sampled boxes/masks/classes to a constant batch size.
        padded_boxes = input_utils.pad_to_fixed_size(boxes,
                                                     self._num_sampled_masks)
        padded_classes = input_utils.pad_to_fixed_size(classes,
                                                       self._num_sampled_masks)
        padded_masks = input_utils.pad_to_fixed_size(masks,
                                                     self._num_sampled_masks)

        # Randomly samples groundtruth masks for mask branch training. For images
        # without groundtruth masks, the dummy padded tensors are sampled instead.
        rand_indices = tf.random.shuffle(
            tf.range(tf.maximum(num_masks, self._num_sampled_masks)))
        rand_indices = tf.math.mod(rand_indices, tf.maximum(num_masks, 1))
        rand_indices = rand_indices[0:self._num_sampled_masks]
        rand_indices = tf.reshape(rand_indices, [self._num_sampled_masks])

        sampled_boxes = tf.gather(padded_boxes, rand_indices)
        sampled_classes = tf.gather(padded_classes, rand_indices)
        sampled_masks = tf.gather(padded_masks, rand_indices)
        # Jitters the sampled boxes to mimic noisy detections.
        sampled_boxes = box_utils.jitter_boxes(
            sampled_boxes, noise_scale=self._box_jitter_scale)
        sampled_boxes = box_utils.clip_boxes(sampled_boxes, self._output_size)
        # Compute mask targets in feature crop. A feature crop fully contains a
        # sampled box.
        mask_outer_boxes = box_utils.compute_outer_boxes(
            sampled_boxes, tf.shape(image)[0:2], scale=self._outer_box_scale)
        mask_outer_boxes = box_utils.clip_boxes(mask_outer_boxes,
                                                self._output_size)
        # Undoes the resize offset and scale to map mask_outer_boxes back to the
        # original image coordinates.
        mask_outer_boxes_ori = mask_outer_boxes
        mask_outer_boxes_ori += tf.tile(tf.expand_dims(offset, axis=0), [1, 2])
        mask_outer_boxes_ori /= tf.tile(tf.expand_dims(image_scale, axis=0),
                                        [1, 2])
        norm_mask_outer_boxes_ori = box_utils.normalize_boxes(
            mask_outer_boxes_ori, mask_shape)

        # Expands sampled_masks to shape [num_sampled_masks, height, width, 1].
        sampled_masks = tf.cast(tf.expand_dims(sampled_masks, axis=-1),
                                tf.float32)
        mask_targets = tf.image.crop_and_resize(
            sampled_masks,
            norm_mask_outer_boxes_ori,
            box_indices=tf.range(self._num_sampled_masks),
            crop_size=[self._mask_crop_size, self._mask_crop_size],
            method='bilinear',
            extrapolation_value=0,
            name='train_mask_targets')
        mask_targets = tf.where(tf.greater_equal(mask_targets, 0.5),
                                tf.ones_like(mask_targets),
                                tf.zeros_like(mask_targets))
        mask_targets = tf.squeeze(mask_targets, axis=-1)
        if self._up_sample_factor > 1:
            fine_mask_targets = tf.image.crop_and_resize(
                sampled_masks,
                norm_mask_outer_boxes_ori,
                box_indices=tf.range(self._num_sampled_masks),
                crop_size=[
                    self._mask_crop_size * self._up_sample_factor,
                    self._mask_crop_size * self._up_sample_factor
                ],
                method='bilinear',
                extrapolation_value=0,
                name='train_mask_targets')
            fine_mask_targets = tf.where(
                tf.greater_equal(fine_mask_targets, 0.5),
                tf.ones_like(fine_mask_targets),
                tf.zeros_like(fine_mask_targets))
            fine_mask_targets = tf.squeeze(fine_mask_targets, axis=-1)
        else:
            fine_mask_targets = mask_targets

        # If bfloat16 is used, casts input image to tf.bfloat16.
        if self._use_bfloat16:
            image = tf.cast(image, dtype=tf.bfloat16)

        valid_image = tf.cast(tf.not_equal(num_masks, 0), tf.int32)
        if self._mask_train_class == 'all':
            mask_is_valid = valid_image * tf.ones_like(sampled_classes,
                                                       tf.int32)
        else:
            # Get the intersection of sampled classes with training splits.
            mask_valid_classes = tf.cast(
                tf.expand_dims(
                    class_utils.coco_split_class_ids(self._mask_train_class),
                    1), sampled_classes.dtype)
            match = tf.reduce_any(
                tf.equal(tf.expand_dims(sampled_classes, 0),
                         mask_valid_classes), 0)
            mask_is_valid = valid_image * tf.cast(match, tf.int32)

        # Packs labels for model_fn outputs.
        labels = {
            'cls_targets': cls_targets,
            'box_targets': box_targets,
            'anchor_boxes': input_anchor.multilevel_boxes,
            'num_positives': num_positives,
            'image_info': image_info,
            # For ShapeMask.
            'mask_boxes': sampled_boxes,
            'mask_outer_boxes': mask_outer_boxes,
            'mask_targets': mask_targets,
            'fine_mask_targets': fine_mask_targets,
            'mask_classes': sampled_classes,
            'mask_is_valid': mask_is_valid,
        }
        return image, labels
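
# The index trick above (shuffle a range of length max(num_masks,
# num_sampled_masks), then fold modulo max(num_masks, 1)) always yields
# _num_sampled_masks valid indices, even with fewer or zero groundtruth masks.
# A small standalone illustration with assumed counts:
import tensorflow as tf

num_masks, num_sampled_masks = 2, 8
rand_indices = tf.random.shuffle(
    tf.range(tf.maximum(num_masks, num_sampled_masks)))
# Folds indices back into [0, num_masks). When num_masks == 0 the maximum()
# guard keeps the modulus at 1, so every index becomes 0 and only the dummy
# padded tensors are sampled.
rand_indices = tf.math.mod(rand_indices, tf.maximum(num_masks, 1))
rand_indices = rand_indices[:num_sampled_masks]
# e.g. rand_indices -> [1, 0, 0, 1, 1, 0, 1, 0]; all samples are real masks.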