示例#1
0
    def _parse_single_example(self, example):
        """Decodes one serialized tf.Example into a groundtruth dictionary.

        Args:
          example: a serialized tf.Example proto string.

        Returns:
          A dictionary of groundtruth with the following fields:
            source_id: the decoded `source_id` string converted to an int64
              number.
            height: the decoded image height.
            width: the decoded image width.
            num_detections: the size of the first dimension of the
              groundtruth classes, i.e. the number of annotated instances.
            boxes: the groundtruth boxes denormalized with the [height, width]
              of the decoded image, i.e. in absolute coordinates with respect
              to the original image size.
            classes: the class label of each instance.
            is_crowds: the crowd flag of each instance.
            areas: the area of each instance.
            masks: the png-encoded mask of each instance. Only present when
              `self._include_mask` is set.
        """
        decoded = tf_example_decoder.TfExampleDecoder(
            include_mask=self._include_mask).decode(example)

        # The two leading dimensions of the decoded image give the size that
        # the normalized boxes are scaled by.
        original_size = tf.shape(decoded['image'])[0:2]
        pixel_boxes = box_utils.denormalize_boxes(
            decoded['groundtruth_boxes'], original_size)

        numeric_source_id = tf.string_to_number(
            decoded['source_id'], out_type=tf.int64)
        instance_count = tf.shape(decoded['groundtruth_classes'])[0]

        groundtruths = dict(
            source_id=numeric_source_id,
            height=decoded['height'],
            width=decoded['width'],
            num_detections=instance_count,
            boxes=pixel_boxes,
            classes=decoded['groundtruth_classes'],
            is_crowds=decoded['groundtruth_is_crowd'],
            areas=decoded['groundtruth_area'],
        )
        if self._include_mask:
            groundtruths['masks'] = decoded['groundtruth_instance_masks_png']
        return groundtruths
    def _parse_predict_data(self, data):
        """Parses data for prediction.

        Args:
          data: the decoded tensor dictionary. 'image' is always read. In
            `ModeKeys.PREDICT_WITH_GT` mode 'source_id', 'groundtruth_boxes',
            'groundtruth_classes', 'groundtruth_area' and
            'groundtruth_is_crowd' are read as well.

        Returns:
          image: the normalized, resized and cropped image, cast to
            tf.bfloat16 when `self._use_bfloat16` is set.
          labels: a dictionary with the following entries.
            anchor_boxes: the multilevel anchor boxes generated for the
              processed image size.
            image_info: the image-info tensor returned by
              `input_utils.resize_and_crop_image`.
            groundtruths: (PREDICT_WITH_GT only) the groundtruth dictionary
              with boxes in pixel coordinates of the original image, padded
              via `pad_groundtruths_to_fixed_size` with
              `self._max_num_instances`.
            cls_targets, box_targets, num_positives: (PREDICT_WITH_GT only)
              the anchor labeling outputs used to compute an evaluation loss.
        """
        # Gets original image and its size.
        image = data['image']
        image_shape = tf.shape(input=image)[0:2]

        # Normalizes image with mean and std pixel values.
        image = input_utils.normalize_image(image)

        # Resizes and crops image. Scale augmentation is disabled (min = max =
        # 1.0) for prediction.
        image, image_info = input_utils.resize_and_crop_image(
            image,
            self._output_size,
            padded_size=input_utils.compute_padded_size(
                self._output_size, 2**self._max_level),
            aug_scale_min=1.0,
            aug_scale_max=1.0)
        image_height, image_width, _ = image.get_shape().as_list()

        # If bfloat16 is used, casts input image to tf.bfloat16.
        if self._use_bfloat16:
            image = tf.cast(image, dtype=tf.bfloat16)

        # Compute Anchor boxes.
        input_anchor = anchor.Anchor(self._min_level, self._max_level,
                                     self._num_scales, self._aspect_ratios,
                                     self._anchor_size,
                                     (image_height, image_width))

        labels = {
            'anchor_boxes': input_anchor.multilevel_boxes,
            'image_info': image_info,
        }
        # If mode is PREDICT_WITH_GT, returns groundtruths and training targets
        # in labels.
        if self._mode == ModeKeys.PREDICT_WITH_GT:
            # Converts boxes from normalized coordinates to pixel coordinates.
            boxes = box_utils.denormalize_boxes(data['groundtruth_boxes'],
                                                image_shape)
            groundtruths = {
                'source_id': data['source_id'],
                'num_detections': tf.shape(data['groundtruth_classes']),
                'boxes': boxes,
                'classes': data['groundtruth_classes'],
                'areas': data['groundtruth_area'],
                'is_crowds': tf.cast(data['groundtruth_is_crowd'], tf.int32),
            }
            groundtruths['source_id'] = process_source_id(
                groundtruths['source_id'])
            groundtruths = pad_groundtruths_to_fixed_size(
                groundtruths, self._max_num_instances)
            labels['groundtruths'] = groundtruths

            # Computes training objective for evaluation loss.
            classes = data['groundtruth_classes']

            image_scale = image_info[2, :]
            offset = image_info[3, :]
            boxes = input_utils.resize_and_crop_boxes(boxes, image_scale,
                                                      image_info[1, :], offset)
            # Filters out ground truth boxes that are all zeros. The classes
            # must be filtered with the same indices, otherwise boxes and
            # classes handed to the anchor labeler no longer line up.
            indices = box_utils.get_non_empty_box_indices(boxes)
            boxes = tf.gather(boxes, indices)
            classes = tf.gather(classes, indices)

            # Assigns anchors.
            anchor_labeler = anchor.AnchorLabeler(input_anchor,
                                                  self._match_threshold,
                                                  self._unmatched_threshold)
            (cls_targets, box_targets,
             num_positives) = anchor_labeler.label_anchors(
                 boxes, tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
            labels['cls_targets'] = cls_targets
            labels['box_targets'] = box_targets
            labels['num_positives'] = num_positives
        return image, labels
    def _parse_train_data(self, data):
        """Builds the image and the anchor-based labels for training.

        Args:
          data: the decoded tensor dictionary. The keys 'image',
            'groundtruth_classes', 'groundtruth_boxes' and
            'groundtruth_is_crowd' are read.

        Returns:
          image: the normalized, optionally flipped, resized and cropped
            image, cast to tf.bfloat16 when `self._use_bfloat16` is set.
          labels: a dictionary with the keys 'cls_targets', 'box_targets',
            'anchor_boxes', 'num_positives' and 'image_info'.
        """
        gt_classes = data['groundtruth_classes']
        gt_boxes = data['groundtruth_boxes']
        crowd_flags = data['groundtruth_is_crowd']

        # Drops crowd annotations when configured to do so during training.
        if self._skip_crowd_during_training and self._is_training:
            num_gt = tf.shape(input=gt_classes)[0]
            with tf.control_dependencies([num_gt, crowd_flags]):
                # With no crowd flags at all, every groundtruth is kept.
                keep = tf.cond(
                    pred=tf.greater(tf.size(input=crowd_flags), 0),
                    true_fn=lambda: tf.where(
                        tf.logical_not(crowd_flags))[:, 0],
                    false_fn=lambda: tf.cast(tf.range(num_gt), tf.int64))
            gt_classes = tf.gather(gt_classes, keep)
            gt_boxes = tf.gather(gt_boxes, keep)

        # Original image and its [height, width].
        image = data['image']
        original_size = tf.shape(input=image)[0:2]

        # Mean/std pixel normalization.
        image = input_utils.normalize_image(image)

        # Optional random horizontal flip of image and boxes together.
        if self._aug_rand_hflip:
            image, gt_boxes = input_utils.random_horizontal_flip(
                image, gt_boxes)

        # Normalized box coordinates -> pixel coordinates.
        gt_boxes = box_utils.denormalize_boxes(gt_boxes, original_size)

        # Resize and crop the image with scale augmentation.
        padded_size = input_utils.compute_padded_size(
            self._output_size, 2**self._max_level)
        image, image_info = input_utils.resize_and_crop_image(
            image,
            self._output_size,
            padded_size=padded_size,
            aug_scale_min=self._aug_scale_min,
            aug_scale_max=self._aug_scale_max)
        out_height, out_width, _ = image.get_shape().as_list()

        # Apply the same resize and crop to the boxes.
        scale = image_info[2, :]
        crop_offset = image_info[3, :]
        gt_boxes = input_utils.resize_and_crop_boxes(
            gt_boxes, scale, image_info[1, :], crop_offset)

        # Keep only the boxes (and their classes) that are not all zeros.
        non_empty = box_utils.get_non_empty_box_indices(gt_boxes)
        gt_boxes = tf.gather(gt_boxes, non_empty)
        gt_classes = tf.gather(gt_classes, non_empty)

        # Anchor generation and target assignment.
        input_anchor = anchor.Anchor(
            self._min_level, self._max_level, self._num_scales,
            self._aspect_ratios, self._anchor_size, (out_height, out_width))
        labeler = anchor.AnchorLabeler(
            input_anchor, self._match_threshold, self._unmatched_threshold)
        class_column = tf.cast(
            tf.expand_dims(gt_classes, axis=1), tf.float32)
        cls_targets, box_targets, num_positives = labeler.label_anchors(
            gt_boxes, class_column)

        # Optional bfloat16 cast of the input image.
        if self._use_bfloat16:
            image = tf.cast(image, dtype=tf.bfloat16)

        # Labels consumed by model_fn.
        labels = {
            'cls_targets': cls_targets,
            'box_targets': box_targets,
            'anchor_boxes': input_anchor.multilevel_boxes,
            'num_positives': num_positives,
            'image_info': image_info,
        }
        return image, labels
    def _parse_predict_data(self, data):
        """Builds the image and labels used for prediction.

        Args:
          data: the decoded tensor dictionary from TfExampleDecoder.

        Returns:
          image: image tensor that is preprocessed to have normalized values
            and then resized and cropped; cast to tf.bfloat16 when
            `self._use_bfloat16` is set.
          labels: a dictionary of tensors with the following entries.
            source_id: the result of `dataloader_utils.process_source_id`
              applied to the source id of the example.
            anchor_boxes: the multilevel anchor boxes generated for the
              processed image size.
            image_info: the image-info tensor returned by
              `input_utils.resize_and_crop_image`.
            groundtruths: only in `ModeKeys.PREDICT_WITH_GT` mode. The
              groundtruth dictionary ('source_id', 'height', 'width',
              'num_detections', 'boxes', 'classes', 'areas', 'is_crowds')
              with boxes in pixel coordinates of the original image, padded
              via `dataloader_utils.pad_groundtruths_to_fixed_size` with
              `self._max_num_instances`.
        """
        # Original image and its [height, width].
        image = data['image']
        original_size = tf.shape(image)[0:2]

        # Mean/std pixel normalization.
        image = input_utils.normalize_image(image)

        # Resize and crop without scale augmentation (min = max = 1.0).
        padded_size = input_utils.compute_padded_size(
            self._output_size, 2**self._max_level)
        image, image_info = input_utils.resize_and_crop_image(
            image,
            self._output_size,
            padded_size=padded_size,
            aug_scale_min=1.0,
            aug_scale_max=1.0)
        out_height, out_width, _ = image.get_shape().as_list()

        # Optional bfloat16 cast of the input image.
        if self._use_bfloat16:
            image = tf.cast(image, dtype=tf.bfloat16)

        # Anchors for the processed image size.
        input_anchor = anchor.Anchor(
            self._min_level, self._max_level, self._num_scales,
            self._aspect_ratios, self._anchor_size, (out_height, out_width))

        labels = {
            'source_id': dataloader_utils.process_source_id(data['source_id']),
            'anchor_boxes': input_anchor.multilevel_boxes,
            'image_info': image_info,
        }

        if self._mode != ModeKeys.PREDICT_WITH_GT:
            return image, labels

        # Normalized box coordinates -> pixel coordinates of the original
        # image.
        pixel_boxes = box_utils.denormalize_boxes(
            data['groundtruth_boxes'], original_size)
        groundtruths = {
            'source_id': data['source_id'],
            'height': data['height'],
            'width': data['width'],
            'num_detections': tf.shape(data['groundtruth_classes']),
            'boxes': pixel_boxes,
            'classes': data['groundtruth_classes'],
            'areas': data['groundtruth_area'],
            'is_crowds': tf.cast(data['groundtruth_is_crowd'], tf.int32),
        }
        groundtruths['source_id'] = dataloader_utils.process_source_id(
            groundtruths['source_id'])
        labels['groundtruths'] = (
            dataloader_utils.pad_groundtruths_to_fixed_size(
                groundtruths, self._max_num_instances))

        return image, labels
    def _parse_train_data(self, data):
        """Turns a decoded example into the image and RPN training labels.

        Args:
          data: the decoded tensor dictionary from TfExampleDecoder.

        Returns:
          image: image tensor that is normalized, optionally flipped, and
            resized and cropped; cast to tf.bfloat16 when
            `self._use_bfloat16` is set.
          labels: a dictionary of tensors used for training, with keys:
            anchor_boxes: the multilevel anchor boxes generated for the
              processed image size.
            image_info: the image-info tensor returned by
              `input_utils.resize_and_crop_image`. This method uses row 1 as
              the image size for normalizing the mask boxes, row 2 as the
              image scale and row 3 as the crop offset.
            rpn_score_targets: first output of
              `RpnAnchorLabeler.label_anchors`.
            rpn_box_targets: second output of
              `RpnAnchorLabeler.label_anchors`.
            gt_boxes: the remaining groundtruth boxes, w.r.t. the scaled
              image that is fed to the network, padded with -1 to
              `self._max_num_instances` entries.
            gt_classes: the remaining groundtruth classes, padded with -1 to
              `self._max_num_instances` entries.
            gt_masks: only when `self._include_mask` is set. Groundtruth
              masks cropped by their bounding boxes, resized to
              [mask_crop_size, mask_crop_size] and padded with -1 to
              `self._max_num_instances` entries.
        """
        gt_classes = data['groundtruth_classes']
        gt_boxes = data['groundtruth_boxes']
        if self._include_mask:
            gt_masks = data['groundtruth_instance_masks']

        crowd_flags = data['groundtruth_is_crowd']
        # Drops crowd annotations when configured to do so during training.
        if self._skip_crowd_during_training and self._is_training:
            num_gt = tf.shape(gt_classes)[0]
            with tf.control_dependencies([num_gt, crowd_flags]):
                # With no crowd flags at all, every groundtruth is kept.
                keep = tf.cond(
                    tf.greater(tf.size(crowd_flags), 0),
                    lambda: tf.where(tf.logical_not(crowd_flags))[:, 0],
                    lambda: tf.cast(tf.range(num_gt), tf.int64))
            gt_classes = tf.gather(gt_classes, keep)
            gt_boxes = tf.gather(gt_boxes, keep)
            if self._include_mask:
                gt_masks = tf.gather(gt_masks, keep)

        # Original image and its [height, width].
        image = data['image']
        original_size = tf.shape(image)[0:2]

        # Mean/std pixel normalization.
        image = input_utils.normalize_image(image)

        # Optional random horizontal flip; masks are flipped together with
        # the image and boxes when they are in use.
        if self._aug_rand_hflip and self._include_mask:
            image, gt_boxes, gt_masks = input_utils.random_horizontal_flip(
                image, gt_boxes, gt_masks)
        elif self._aug_rand_hflip:
            image, gt_boxes = input_utils.random_horizontal_flip(
                image, gt_boxes)

        # Normalized box coordinates -> pixel coordinates w.r.t. the original
        # image.
        gt_boxes = box_utils.denormalize_boxes(gt_boxes, original_size)

        # Resize and crop the image with scale augmentation.
        padded_size = input_utils.compute_padded_size(
            self._output_size, 2**self._max_level)
        image, image_info = input_utils.resize_and_crop_image(
            image,
            self._output_size,
            padded_size=padded_size,
            aug_scale_min=self._aug_scale_min,
            aug_scale_max=self._aug_scale_max)
        out_height, out_width, _ = image.get_shape().as_list()

        # Apply the same resize and crop to the boxes, which are afterwards
        # expressed w.r.t. the scaled image.
        scale = image_info[2, :]
        crop_offset = image_info[3, :]
        gt_boxes = input_utils.resize_and_crop_boxes(
            gt_boxes, scale, (out_height, out_width), crop_offset)

        # Keep only the annotations whose boxes are not all zeros.
        non_empty = box_utils.get_non_empty_box_indices(gt_boxes)
        gt_boxes = tf.gather(gt_boxes, non_empty)
        gt_classes = tf.gather(gt_classes, non_empty)
        if self._include_mask:
            gt_masks = tf.gather(gt_masks, non_empty)
            # Adds the crop offset back to all four box coordinates and
            # normalizes by row 1 of image_info, giving the boxes that are
            # cropped out of the masks.
            offset_per_corner = tf.cast(
                tf.tile(tf.expand_dims(crop_offset, axis=0), [1, 2]),
                dtype=tf.float32)
            mask_boxes = gt_boxes + offset_per_corner
            mask_boxes = box_utils.normalize_boxes(
                mask_boxes, image_info[1, :])
            mask_count = tf.shape(gt_masks)[0]
            masks_with_channel = tf.expand_dims(gt_masks, axis=-1)
            # Each mask is cropped by its own box (box i -> mask i).
            gt_masks = tf.image.crop_and_resize(
                masks_with_channel,
                mask_boxes,
                box_indices=tf.range(mask_count, dtype=tf.int32),
                crop_size=[self._mask_crop_size, self._mask_crop_size],
                method='bilinear')
            gt_masks = tf.squeeze(gt_masks, axis=-1)

        # Anchor generation and RPN target assignment on the scaled image.
        input_anchor = anchor.Anchor(
            self._min_level, self._max_level, self._num_scales,
            self._aspect_ratios, self._anchor_size, (out_height, out_width))
        labeler = anchor.RpnAnchorLabeler(
            input_anchor,
            self._rpn_match_threshold,
            self._rpn_unmatched_threshold,
            self._rpn_batch_size_per_im,
            self._rpn_fg_fraction)
        class_column = tf.cast(
            tf.expand_dims(gt_classes, axis=-1), dtype=tf.float32)
        rpn_score_targets, rpn_box_targets = labeler.label_anchors(
            gt_boxes, class_column)

        # Optional bfloat16 cast of the input image.
        if self._use_bfloat16:
            image = tf.cast(image, dtype=tf.bfloat16)

        # Labels consumed by model_fn; groundtruths are padded with -1 to a
        # fixed number of instances.
        labels = {
            'anchor_boxes': input_anchor.multilevel_boxes,
            'image_info': image_info,
            'rpn_score_targets': rpn_score_targets,
            'rpn_box_targets': rpn_box_targets,
            'gt_boxes': input_utils.pad_to_fixed_size(
                gt_boxes, self._max_num_instances, -1),
            'gt_classes': input_utils.pad_to_fixed_size(
                gt_classes, self._max_num_instances, -1),
        }
        if self._include_mask:
            labels['gt_masks'] = input_utils.pad_to_fixed_size(
                gt_masks, self._max_num_instances, -1)

        return image, labels
示例#6
0
    def _parse_predict_data(self, data):
        """Parses data for ShapeMask prediction.

        Args:
          data: the decoded tensor dictionary. 'image', 'groundtruth_classes'
            and 'groundtruth_boxes' are always read. In
            `ModeKeys.PREDICT_WITH_GT` mode 'source_id', 'groundtruth_area'
            and 'groundtruth_is_crowd' are read as well.

        Returns:
          image: the normalized, resized and cropped image, cast to
            tf.bfloat16 when `self._use_bfloat16` is set.
          labels: a dictionary with the following entries.
            anchor_boxes: the multilevel anchor boxes generated for
              `self._output_size`.
            image_info: the image-info tensor returned by
              `input_utils.resize_and_crop_image`.
            cls_targets, box_targets, num_positives: (PREDICT_WITH_GT only)
              the anchor labeling outputs.
            groundtruths: (PREDICT_WITH_GT only) the groundtruth dictionary
              with boxes in pixel coordinates of the original image, padded
              via `dataloader_utils.pad_groundtruths_to_fixed_size` with
              `self._max_num_instances`.
        """
        # NOTE: instance masks are not part of the returned groundtruths, so
        # they are neither read nor resized here.
        classes = data['groundtruth_classes']
        boxes = data['groundtruth_boxes']

        # Gets original image and its size.
        image = data['image']
        image_shape = tf.shape(image)[0:2]

        # If not using category, collapses the labels to 1.0 for every class
        # id greater than 0 and 0.0 otherwise.
        if not self._use_category:
            classes = tf.cast(tf.greater(classes, 0), dtype=tf.float32)

        # Normalizes image with mean and std pixel values.
        image = input_utils.normalize_image(image)

        # Converts boxes from normalized coordinates to pixel coordinates.
        boxes = box_utils.denormalize_boxes(boxes, image_shape)

        # Resizes and crops image. Scale augmentation is disabled (min = max =
        # 1.0) for prediction.
        image, image_info = input_utils.resize_and_crop_image(
            image,
            self._output_size,
            self._output_size,
            aug_scale_min=1.0,
            aug_scale_max=1.0)
        image_scale = image_info[2, :]
        offset = image_info[3, :]

        # Resizes and crops boxes.
        boxes = input_utils.resize_and_crop_boxes(boxes, image_scale,
                                                  self._output_size, offset)

        # Filters out ground truth boxes that are all zeros.
        indices = input_utils.get_non_empty_box_indices(boxes)
        boxes = tf.gather(boxes, indices)
        classes = tf.gather(classes, indices)

        # Assigns anchors.
        input_anchor = anchor.Anchor(self._min_level, self._max_level,
                                     self._num_scales, self._aspect_ratios,
                                     self._anchor_size, self._output_size)
        anchor_labeler = anchor.AnchorLabeler(input_anchor,
                                              self._match_threshold,
                                              self._unmatched_threshold)

        # If bfloat16 is used, casts input image to tf.bfloat16.
        if self._use_bfloat16:
            image = tf.cast(image, dtype=tf.bfloat16)

        labels = {
            'anchor_boxes': input_anchor.multilevel_boxes,
            'image_info': image_info,
        }
        if self._mode == ModeKeys.PREDICT_WITH_GT:
            # Groundtruths use the unfiltered annotations, with boxes in pixel
            # coordinates of the original image.
            groundtruths = {
                'source_id':
                data['source_id'],
                'num_detections':
                tf.shape(data['groundtruth_classes']),
                'boxes':
                box_utils.denormalize_boxes(data['groundtruth_boxes'],
                                            image_shape),
                'classes':
                data['groundtruth_classes'],
                'areas':
                data['groundtruth_area'],
                'is_crowds':
                tf.cast(data['groundtruth_is_crowd'], tf.int32),
            }
            groundtruths['source_id'] = dataloader_utils.process_source_id(
                groundtruths['source_id'])
            groundtruths = dataloader_utils.pad_groundtruths_to_fixed_size(
                groundtruths, self._max_num_instances)
            # Computes training labels.
            (cls_targets, box_targets,
             num_positives) = anchor_labeler.label_anchors(
                 boxes, tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
            # Packs labels for model_fn outputs.
            labels.update({
                'cls_targets': cls_targets,
                'box_targets': box_targets,
                'num_positives': num_positives,
                'groundtruths': groundtruths,
            })
        return image, labels
示例#7
0
    def _parse_train_data(self, data):
        """Parse data for ShapeMask training.

        Args:
          data: the decoded tensor dictionary. The keys 'image',
            'groundtruth_classes', 'groundtruth_boxes',
            'groundtruth_instance_masks' and 'groundtruth_is_crowd' are read.

        Returns:
          image: the normalized, optionally flipped, resized and cropped
            image, cast to tf.bfloat16 when `self._use_bfloat16` is set.
          labels: a dictionary with the following entries.
            cls_targets, box_targets, num_positives: the anchor labeling
              outputs.
            anchor_boxes: the multilevel anchor boxes generated for
              `self._output_size`.
            image_info: the image-info tensor returned by
              `input_utils.resize_and_crop_image`.
            mask_boxes: the `self._num_sampled_masks` sampled groundtruth
              boxes after jittering and clipping to `self._output_size`.
            mask_outer_boxes: the outer boxes computed from `mask_boxes`,
              clipped to `self._output_size`.
            mask_targets: binary masks cropped from the sampled groundtruth
              masks, with spatial size
              [mask_crop_size, mask_crop_size].
            fine_mask_targets: like `mask_targets` but with the crop size
              multiplied by `self._up_sample_factor`; identical to
              `mask_targets` when that factor is not greater than 1.
            mask_classes: the classes of the sampled groundtruths.
            mask_is_valid: int32 flags, 0 for every sample of an image
              without groundtruth masks and, unless
              `self._mask_train_class` is 'all', for samples whose class is
              not in the configured training split.
        """
        classes = data['groundtruth_classes']
        boxes = data['groundtruth_boxes']
        masks = data['groundtruth_instance_masks']
        is_crowds = data['groundtruth_is_crowd']
        # Skips annotations with `is_crowd` = True.
        if self._skip_crowd_during_training and self._is_training:
            num_groundtrtuhs = tf.shape(classes)[0]
            with tf.control_dependencies([num_groundtrtuhs, is_crowds]):
                # When there are no crowd flags at all, keeps every
                # groundtruth index.
                indices = tf.cond(
                    tf.greater(tf.size(is_crowds), 0),
                    lambda: tf.where(tf.logical_not(is_crowds))[:, 0],
                    lambda: tf.cast(tf.range(num_groundtrtuhs), tf.int64))
            classes = tf.gather(classes, indices)
            boxes = tf.gather(boxes, indices)
            masks = tf.gather(masks, indices)

        # Gets original image and its size.
        image = data['image']
        image_shape = tf.shape(image)[0:2]

        # If not using category, collapses the labels to 1.0 for every class
        # id greater than 0 and 0.0 otherwise.
        if not self._use_category:
            classes = tf.cast(tf.greater(classes, 0), dtype=tf.float32)

        # Normalizes image with mean and std pixel values.
        image = input_utils.normalize_image(image)

        # Flips image randomly during training.
        if self._aug_rand_hflip:
            image, boxes, masks = input_utils.random_horizontal_flip(
                image, boxes, masks)

        # Converts boxes from normalized coordinates to pixel coordinates.
        boxes = box_utils.denormalize_boxes(boxes, image_shape)

        # Resizes and crops image.
        image, image_info = input_utils.resize_and_crop_image(
            image,
            self._output_size,
            self._output_size,
            aug_scale_min=self._aug_scale_min,
            aug_scale_max=self._aug_scale_max)
        image_scale = image_info[2, :]
        offset = image_info[3, :]

        # Resizes and crops boxes. The masks are NOT resized here; they are
        # cropped further below with boxes mapped back to the mask's own
        # coordinates.
        boxes = input_utils.resize_and_crop_boxes(boxes, image_scale,
                                                  self._output_size, offset)

        # Filters out ground truth boxes that are all zeros.
        indices = input_utils.get_non_empty_box_indices(boxes)
        boxes = tf.gather(boxes, indices)
        classes = tf.gather(classes, indices)
        masks = tf.gather(masks, indices)

        # Assigns anchors.
        input_anchor = anchor.Anchor(self._min_level, self._max_level,
                                     self._num_scales, self._aspect_ratios,
                                     self._anchor_size, self._output_size)
        anchor_labeler = anchor.AnchorLabeler(input_anchor,
                                              self._match_threshold,
                                              self._unmatched_threshold)
        (cls_targets, box_targets,
         num_positives) = anchor_labeler.label_anchors(
             boxes, tf.cast(tf.expand_dims(classes, axis=1), tf.float32))

        # Sample groundtruth masks/boxes/classes for mask branch.
        num_masks = tf.shape(masks)[0]
        mask_shape = tf.shape(masks)[1:3]

        # Pad sampled boxes/masks/classes to a constant batch size.
        padded_boxes = input_utils.pad_to_fixed_size(boxes,
                                                     self._num_sampled_masks)
        padded_classes = input_utils.pad_to_fixed_size(classes,
                                                       self._num_sampled_masks)
        padded_masks = input_utils.pad_to_fixed_size(masks,
                                                     self._num_sampled_masks)

        # Randomly sample groundtruth masks for mask branch training. For the image
        # without groundtruth masks, it will sample the dummy padded tensors.
        # The shuffled range is folded with a modulo so that every index lies
        # in [0, max(num_masks, 1)); with fewer masks than samples this
        # repeats groundtruths instead of picking padded entries.
        rand_indices = tf.random.shuffle(
            tf.range(tf.maximum(num_masks, self._num_sampled_masks)))
        rand_indices = tf.math.mod(rand_indices, tf.maximum(num_masks, 1))
        rand_indices = rand_indices[0:self._num_sampled_masks]
        rand_indices = tf.reshape(rand_indices, [self._num_sampled_masks])

        sampled_boxes = tf.gather(padded_boxes, rand_indices)
        sampled_classes = tf.gather(padded_classes, rand_indices)
        sampled_masks = tf.gather(padded_masks, rand_indices)
        # Jitter the sampled boxes to mimic the noisy detections.
        sampled_boxes = box_utils.jitter_boxes(
            sampled_boxes, noise_scale=self._box_jitter_scale)
        sampled_boxes = box_utils.clip_boxes(sampled_boxes, self._output_size)
        # Compute mask targets in feature crop. A feature crop fully contains a
        # sampled box.
        mask_outer_boxes = box_utils.compute_outer_boxes(
            sampled_boxes, tf.shape(image)[0:2], scale=self._outer_box_scale)
        mask_outer_boxes = box_utils.clip_boxes(mask_outer_boxes,
                                                self._output_size)
        # Compensate the offset of mask_outer_boxes to map it back to original image
        # scale.
        # Adding the crop offset and dividing by the image scale inverts the
        # resize-and-crop applied to the boxes above; the result is then
        # normalized by the mask's own height and width.
        mask_outer_boxes_ori = mask_outer_boxes
        mask_outer_boxes_ori += tf.tile(tf.expand_dims(offset, axis=0), [1, 2])
        mask_outer_boxes_ori /= tf.tile(tf.expand_dims(image_scale, axis=0),
                                        [1, 2])
        norm_mask_outer_boxes_ori = box_utils.normalize_boxes(
            mask_outer_boxes_ori, mask_shape)

        # Set sampled_masks shape to [batch_size, height, width, 1].
        sampled_masks = tf.cast(tf.expand_dims(sampled_masks, axis=-1),
                                tf.float32)
        # Crops each sampled mask with its own outer box (box i -> mask i);
        # regions outside the mask are filled with 0.
        mask_targets = tf.image.crop_and_resize(
            sampled_masks,
            norm_mask_outer_boxes_ori,
            box_indices=tf.range(self._num_sampled_masks),
            crop_size=[self._mask_crop_size, self._mask_crop_size],
            method='bilinear',
            extrapolation_value=0,
            name='train_mask_targets')
        # Binarizes the bilinearly interpolated masks at 0.5.
        mask_targets = tf.where(tf.greater_equal(mask_targets, 0.5),
                                tf.ones_like(mask_targets),
                                tf.zeros_like(mask_targets))
        mask_targets = tf.squeeze(mask_targets, axis=-1)
        if self._up_sample_factor > 1:
            # Same crop at a finer resolution. NOTE(review): this op reuses
            # the name 'train_mask_targets' of the coarse crop above.
            fine_mask_targets = tf.image.crop_and_resize(
                sampled_masks,
                norm_mask_outer_boxes_ori,
                box_indices=tf.range(self._num_sampled_masks),
                crop_size=[
                    self._mask_crop_size * self._up_sample_factor,
                    self._mask_crop_size * self._up_sample_factor
                ],
                method='bilinear',
                extrapolation_value=0,
                name='train_mask_targets')
            fine_mask_targets = tf.where(
                tf.greater_equal(fine_mask_targets, 0.5),
                tf.ones_like(fine_mask_targets),
                tf.zeros_like(fine_mask_targets))
            fine_mask_targets = tf.squeeze(fine_mask_targets, axis=-1)
        else:
            fine_mask_targets = mask_targets

        # If bfloat16 is used, casts input image to tf.bfloat16.
        if self._use_bfloat16:
            image = tf.cast(image, dtype=tf.bfloat16)

        # 1 when the image has at least one groundtruth mask, 0 otherwise.
        valid_image = tf.cast(tf.not_equal(num_masks, 0), tf.int32)
        if self._mask_train_class == 'all':
            mask_is_valid = valid_image * tf.ones_like(sampled_classes,
                                                       tf.int32)
        else:
            # Get the intersection of sampled classes with training splits.
            # The valid class ids form a column that is compared against the
            # row of sampled classes; a sample is valid if any id matches.
            mask_valid_classes = tf.cast(
                tf.expand_dims(
                    class_utils.coco_split_class_ids(self._mask_train_class),
                    1), sampled_classes.dtype)
            match = tf.reduce_any(
                tf.equal(tf.expand_dims(sampled_classes, 0),
                         mask_valid_classes), 0)
            mask_is_valid = valid_image * tf.cast(match, tf.int32)

        # Packs labels for model_fn outputs.
        labels = {
            'cls_targets': cls_targets,
            'box_targets': box_targets,
            'anchor_boxes': input_anchor.multilevel_boxes,
            'num_positives': num_positives,
            'image_info': image_info,
            # For ShapeMask.
            'mask_boxes': sampled_boxes,
            'mask_outer_boxes': mask_outer_boxes,
            'mask_targets': mask_targets,
            'fine_mask_targets': fine_mask_targets,
            'mask_classes': sampled_classes,
            'mask_is_valid': mask_is_valid,
        }
        return image, labels