Esempio n. 1
0
  def __call__(self, fpn_features, boxes, outer_boxes, classes,
               is_training):
    """Generates ShapeMask detection priors from boxes and FPN features.

    Implements the prior-generation step shown in Fig. 4 of the ShapeMask
    paper (https://arxiv.org/pdf/1904.03239.pdf).

    Args:
      fpn_features: a dictionary of FPN features.
      boxes: a float tensor of shape [batch_size, num_instances, 4] holding
        the tight gt boxes from the dataloader/detection step.
      outer_boxes: a float tensor of shape [batch_size, num_instances, 4]
        holding the loose gt boxes from the dataloader/detection step;
        features are cropped from these boxes.
      classes: an int tensor of shape [batch_size, num_instances] with the
        instance classes, used to gather class-specific shape priors.
      is_training: training mode or not. Not read by this method.

    Returns:
      instance_features: FPN features cropped at `outer_boxes` to
        `self._mask_crop_size` x `self._mask_crop_size` and projected by a
        dense layer to `self._num_downsample_channels` channels.
      detection_priors: class shape priors weighted by the predicted prior
        distribution, summed over axis 2, and mapped from the outer-box frame
        into the tight-box frame with `crop_mask_in_target_box`.
      NOTE(review): an earlier docstring described both outputs as having a
      flattened leading dimension of batch_size * num_instances (and a
      trailing channel of 1 for the priors); nothing in this method reshapes
      them, so confirm against `spatial_transform_ops`.
    """
    with tf.variable_scope('prior_mask', reuse=tf.AUTO_REUSE):
      batch_size, num_instances, _ = boxes.get_shape().as_list()
      if batch_size is None:
        # The static batch size is unknown; use the dynamic one instead.
        batch_size = tf.shape(boxes)[0]
      crop_size = self._mask_crop_size

      cropped_features = spatial_transform_ops.multilevel_crop_and_resize(
          fpn_features, outer_boxes, output_size=crop_size)
      instance_features = tf.layers.dense(cropped_features,
                                          self._num_downsample_channels)
      feature_dtype = instance_features.dtype
      shape_priors = tf.cast(self._get_priors(), feature_dtype)

      # An all-ones mask per outer box, mapped into the tight-box frame.
      all_ones = tf.ones([batch_size, num_instances, crop_size, crop_size])
      uniform_priors = tf.cast(
          spatial_transform_ops.crop_mask_in_target_box(
              all_ones, boxes, outer_boxes, crop_size),
          feature_dtype)

      # Weight the per-class priors by the predicted distribution (broadcast
      # over the two trailing spatial axes) and sum over axis 2.
      prior_weights = self._classify_shape_priors(
          instance_features, uniform_priors, classes)
      prior_weights = tf.expand_dims(tf.expand_dims(prior_weights, -1), -1)
      weighted_priors = tf.gather(shape_priors, classes) * prior_weights
      mixed_priors = tf.reduce_sum(weighted_priors, axis=2)
      detection_priors = spatial_transform_ops.crop_mask_in_target_box(
          mixed_priors, boxes, outer_boxes, crop_size)

      return instance_features, detection_priors
Esempio n. 2
0
    def build_outputs(self, features, labels, mode):
        """Builds the Mask R-CNN graph and collects its outputs.

        Args:
          features: image tensor fed to the backbone. Its static dimensions
            1:3 are used as the anchor image size when `labels` carries no
            precomputed 'anchor_boxes'.
          labels: `dict` of label tensors. 'image_info' is always read,
            'anchor_boxes' is used when present, and 'gt_boxes',
            'gt_classes' and 'gt_masks' are read in training mode.
          mode: a `mode_keys` value; `mode_keys.TRAIN` selects training.

        Returns:
          `dict` of model outputs. It always holds the RPN and Fast R-CNN head
          outputs. Training adds the sampled class/box targets (plus mask
          targets and outputs when masks are enabled); evaluation adds the
          results of `self._generate_detections_fn` (plus 'detection_masks'
          when masks are enabled).
        """
        training = mode == mode_keys.TRAIN
        outputs = {}

        if 'anchor_boxes' not in labels:
            anchor_boxes = anchor.Anchor(
                self._anchor_params.min_level, self._anchor_params.max_level,
                self._anchor_params.num_scales,
                self._anchor_params.aspect_ratios,
                self._anchor_params.anchor_size,
                features.get_shape().as_list()[1:3]).multilevel_boxes
        else:
            anchor_boxes = labels['anchor_boxes']

        fpn_features = self._fpn_fn(
            self._backbone_fn(features, training), training)

        rpn_scores, rpn_boxes = self._rpn_head_fn(fpn_features, training)
        outputs['rpn_score_outputs'] = rpn_scores
        outputs['rpn_box_outputs'] = rpn_boxes
        rois, _ = self._generate_rois_fn(rpn_boxes, rpn_scores, anchor_boxes,
                                         labels['image_info'][:, 1, :],
                                         training)

        if training:
            # Proposals act as constants for the second stage; sample them
            # and match each sampled roi to a ground-truth box and class.
            rois, gt_boxes, gt_classes, gt_indices = self._sample_rois_fn(
                tf.stop_gradient(rois), labels['gt_boxes'],
                labels['gt_classes'])

            encoded = box_utils.encode_boxes(
                gt_boxes, rois, weights=[10.0, 10.0, 5.0, 5.0])
            # Rois matched to background (class 0) get all-zero box targets.
            background = tf.tile(
                tf.expand_dims(tf.equal(gt_classes, 0), axis=-1), [1, 1, 4])
            outputs['class_targets'] = gt_classes
            outputs['box_targets'] = tf.where(
                background, tf.zeros_like(encoded), encoded)

        class_outputs, box_outputs = self._frcnn_head_fn(
            spatial_transform_ops.multilevel_crop_and_resize(
                fpn_features, rois, output_size=7),
            training)
        outputs['class_outputs'] = class_outputs
        outputs['box_outputs'] = box_outputs

        if not training:
            detection_results = self._generate_detections_fn(
                box_outputs, class_outputs, rois,
                labels['image_info'][:, 1:2, :])
            outputs.update(detection_results)

        if not self._include_mask:
            self._log_model_statistics(features)
            return outputs

        if training:
            rois, mask_classes, mask_targets = self._sample_masks_fn(
                rois, gt_boxes, gt_classes, gt_indices, labels['gt_masks'])
            mask_targets = tf.stop_gradient(mask_targets)
            mask_classes = tf.cast(mask_classes, dtype=tf.int32)
            outputs['mask_targets'] = mask_targets
            outputs['sampled_class_targets'] = mask_classes
        else:
            # At inference time masks are predicted for the final detections.
            rois = detection_results['detection_boxes']
            mask_classes = tf.cast(detection_results['detection_classes'],
                                   dtype=tf.int32)

        mask_outputs = self._mrcnn_head_fn(
            spatial_transform_ops.multilevel_crop_and_resize(
                fpn_features, rois, output_size=14),
            mask_classes, training)

        if training:
            outputs['mask_outputs'] = mask_outputs
        else:
            outputs['detection_masks'] = tf.nn.sigmoid(mask_outputs)

        self._log_model_statistics(features)
        return outputs
    def _build_outputs(self, images, labels, mode):
        """Builds the Mask R-CNN graph, optionally with an attribute head.

        Args:
          images: image tensor fed to the backbone. Its static dimensions 1:3
            are used as the anchor image size when `labels` has no
            "anchor_boxes"; the generated anchors are then tiled over the
            dynamic batch size.
          labels: `dict` of label tensors. "image_info" is always read,
            "anchor_boxes" is used when present, and "gt_boxes",
            "gt_classes", "gt_masks" and (with attributes) "gt_attributes"
            are read during training.
          mode: a `mode_keys` value; `mode_keys.TRAIN` selects training.

        Returns:
          `dict` of model outputs: RPN and Fast R-CNN head outputs; training
          targets or detection results depending on the mode; mask
          targets/outputs or "detection_masks" when masks are enabled; and
          attribute outputs/targets or "detection_attributes" when
          attributes are enabled.
        """
        training = mode == mode_keys.TRAIN
        outputs = {}

        if "anchor_boxes" not in labels:
            anchor_boxes = anchor.Anchor(
                self._params.architecture.min_level,
                self._params.architecture.max_level,
                self._params.anchor.num_scales,
                self._params.anchor.aspect_ratios,
                self._params.anchor.anchor_size,
                images.get_shape().as_list()[1:3],
            ).multilevel_boxes

            # Give every level's anchors a leading batch dimension.
            batch_size = tf.shape(input=images)[0]
            for level in anchor_boxes:
                level_boxes = tf.expand_dims(anchor_boxes[level], 0)
                anchor_boxes[level] = tf.tile(level_boxes, [batch_size, 1, 1])
        else:
            anchor_boxes = labels["anchor_boxes"]

        fpn_features = self._fpn_fn(
            self._backbone_fn(images, training), training)

        rpn_scores, rpn_boxes = self._rpn_head_fn(fpn_features, training)
        outputs["rpn_score_outputs"] = rpn_scores
        outputs["rpn_box_outputs"] = rpn_boxes
        rois, _ = self._generate_rois_fn(
            rpn_boxes,
            rpn_scores,
            anchor_boxes,
            labels["image_info"][:, 1, :],
            training,
        )

        if training:
            # Proposals act as constants for the second stage; sample them
            # and match each sampled roi to a ground-truth box and class.
            rois, gt_boxes, gt_classes, gt_indices = self._sample_rois_fn(
                tf.stop_gradient(rois), labels["gt_boxes"],
                labels["gt_classes"])

            encoded = box_utils.encode_boxes(
                gt_boxes, rois, weights=[10.0, 10.0, 5.0, 5.0])
            # Rois matched to background (class 0) get all-zero box targets.
            background = tf.tile(
                tf.expand_dims(tf.equal(gt_classes, 0), axis=-1), [1, 1, 4])
            outputs["class_targets"] = gt_classes
            outputs["box_targets"] = tf.compat.v1.where(
                background, tf.zeros_like(encoded), encoded)

        class_outputs, box_outputs = self._frcnn_head_fn(
            spatial_transform_ops.multilevel_crop_and_resize(
                fpn_features, rois, output_size=7),
            training)
        outputs["class_outputs"] = class_outputs
        outputs["box_outputs"] = box_outputs

        if not training:
            detection_results = self._generate_detections_fn(
                box_outputs, class_outputs, rois,
                labels["image_info"][:, 1:2, :])
            outputs.update(detection_results)

        if not self._include_mask:
            return outputs

        if training:
            (rois, mask_classes, mask_targets,
             gt_gather_indices) = self._sample_masks_fn(
                 rois, gt_boxes, gt_classes, gt_indices, labels["gt_masks"])
            mask_targets = tf.stop_gradient(mask_targets)
            mask_classes = tf.cast(mask_classes, dtype=tf.int32)
            outputs["mask_targets"] = mask_targets
            outputs["sampled_class_targets"] = mask_classes
        else:
            # At inference time masks are predicted for the final detections.
            rois = detection_results["detection_boxes"]
            mask_classes = tf.cast(detection_results["detection_classes"],
                                   dtype=tf.int32)

        mask_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
            fpn_features, rois, output_size=14)
        mask_outputs = self._mrcnn_head_fn(mask_roi_features, mask_classes,
                                           training)

        if training:
            outputs["mask_outputs"] = mask_outputs
        else:
            outputs["detection_masks"] = tf.nn.sigmoid(mask_outputs)

        if not self._include_attributes:
            return outputs

        # The attribute head reuses the 14x14 mask RoI features.
        attribute_outputs = self._attributes_head_fn(mask_roi_features,
                                                     training)

        if training:
            outputs["attribute_outputs"] = attribute_outputs
            # Presumably [batch, K, num_attributes]; indices come from mask
            # sampling.
            outputs["attribute_targets"] = tf.gather_nd(
                labels["gt_attributes"], gt_gather_indices)
        else:
            outputs["detection_attributes"] = tf.nn.sigmoid(attribute_outputs)

        return outputs
    def _run_frcnn_head(self, fpn_features, rois, labels, is_training,
                        model_outputs, layer_num, iou_threshold,
                        regression_weights):
        """Runs one Fast R-CNN stage of the cascade (class + box prediction).

        Args:
          fpn_features: FPN features that `rois` are pooled from.
          rois: current rois to classify and refine.
          labels: `dict` of label information. During training its
            'gt_boxes' and 'gt_classes' are used to assign each roi a
            ground-truth box and class for the losses.
          is_training: `bool`, whether the model is training.
          model_outputs: `dict` that is updated in place with this stage's
            outputs (and, during training, its targets) and also returned.
          layer_num: `int`, index of this stage in the cascade. Stage 0
            samples with `self._sample_rois_fn`; later stages use
            `target_ops.assign_and_sample_proposals` without subsampling.
          iou_threshold: `float`, foreground/background IoU threshold used
            by stages after the first (ignored by stage 0).
          regression_weights: weights passed to `box_utils.encode_boxes`
            when encoding this stage's box regression targets.

        Returns:
          A tuple `(class_outputs, box_outputs, model_outputs,
          matched_gt_boxes, matched_gt_classes, matched_gt_indices, rois)`:
            class_outputs: class predictions for the rois.
            box_outputs: box predictions in the encoded regression format;
              they must be decoded before being used as the next stage's
              rois.
            model_outputs: the updated outputs dict.
            matched_gt_boxes: gt box matched to each roi (None unless
              training).
            matched_gt_classes: gt class of each roi (None unless training).
            matched_gt_indices: index of each roi's matched gt box (None
              unless training); used downstream for mask sampling.
            rois: the (possibly resampled) rois used by this stage.
        """
        # The matching results only exist in training mode.
        matched_gt_boxes = matched_gt_classes = matched_gt_indices = None

        if is_training:
            rois = tf.stop_gradient(rois)
            gt_boxes = labels['gt_boxes']
            gt_classes = labels['gt_classes']

            if layer_num == 0:
                # First stage: NMS plus sampling criteria that keep a
                # constant foreground/background fraction in the batch.
                sampled = self._sample_rois_fn(rois, gt_boxes, gt_classes)
            else:
                # Later stages already have a constant number of proposals,
                # so plain IoU assignment without NMS/subsampling suffices.
                sampled = target_ops.assign_and_sample_proposals(
                    rois,
                    gt_boxes,
                    gt_classes,
                    num_samples_per_image=self._num_roi_samples,
                    mix_gt_boxes=False,
                    fg_iou_thresh=iou_threshold,
                    bg_iou_thresh_hi=iou_threshold,
                    bg_iou_thresh_lo=0.0,
                    skip_subsampling=True)
            (rois, matched_gt_boxes, matched_gt_classes,
             matched_gt_indices) = sampled

            foreground = tf.cast(tf.greater(matched_gt_classes, 0), rois.dtype)
            self.add_scalar_summary(f'fg_bg_ratio_{layer_num}',
                                    tf.reduce_mean(foreground))

            encoded = box_utils.encode_boxes(matched_gt_boxes, rois,
                                             weights=regression_weights)
            # Rois matched to background (class 0) get all-zero box targets.
            background = tf.tile(
                tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
                [1, 1, 4])
            model_outputs[f'class_targets_{layer_num}'] = matched_gt_classes
            model_outputs[f'box_targets_{layer_num}'] = tf.where(
                background, tf.zeros_like(encoded), encoded)

        roi_features = spatial_transform_ops.multilevel_crop_and_resize(
            fpn_features, rois, output_size=7)

        # Each stage owns its head variables under its own scope.
        with tf.variable_scope(f'frcnn_layer_{layer_num}'):
            class_outputs, box_outputs = self._frcnn_head_fn(
                roi_features, is_training)
        model_outputs[f'class_outputs_{layer_num}'] = class_outputs
        model_outputs[f'box_outputs_{layer_num}'] = box_outputs

        return (class_outputs, box_outputs, model_outputs, matched_gt_boxes,
                matched_gt_classes, matched_gt_indices, rois)
    def _build_outputs(self, images, labels, mode):
        """Builds the Cascade (Mask) R-CNN graph and collects its outputs.

        Args:
          images: image tensor fed to the backbone. Its static dimensions 1:3
            are used as the anchor image size when `labels` has no
            'anchor_boxes'; the generated anchors are then tiled over the
            dynamic batch size.
          labels: `dict` of label tensors. 'image_info' is always read,
            'anchor_boxes' is used when present, and 'gt_boxes',
            'gt_classes' and 'gt_masks' are read during training.
          mode: a `mode_keys` value; `mode_keys.TRAIN` selects training.

        Returns:
          `dict` of model outputs: RPN outputs, the per-stage outputs (and
          training targets) written by `_run_frcnn_head`, the detection
          results in eval mode, and mask targets/outputs or
          'detection_masks' when masks are enabled.
        """
        is_training = mode == mode_keys.TRAIN
        model_outputs = {}

        if 'anchor_boxes' in labels:
            anchor_boxes = labels['anchor_boxes']
        else:
            anchor_boxes = anchor.Anchor(
                self._params.architecture.min_level,
                self._params.architecture.max_level,
                self._params.anchor.num_scales,
                self._params.anchor.aspect_ratios,
                self._params.anchor.anchor_size,
                images.get_shape().as_list()[1:3]).multilevel_boxes

            # Give every level's anchors a leading batch dimension.
            batch_size = tf.shape(images)[0]
            for level in anchor_boxes:
                anchor_boxes[level] = tf.tile(
                    tf.expand_dims(anchor_boxes[level], 0), [batch_size, 1, 1])

        backbone_features = self._backbone_fn(images, is_training)
        fpn_features = self._fpn_fn(backbone_features, is_training)

        rpn_score_outputs, rpn_box_outputs = self._rpn_head_fn(
            fpn_features, is_training)
        model_outputs.update({
            'rpn_score_outputs': rpn_score_outputs,
            'rpn_box_outputs': rpn_box_outputs,
        })
        # Run the RPN layer to get bbox coordinates for first frcnn layer.
        current_rois, _ = self._generate_rois_fn(rpn_box_outputs,
                                                 rpn_score_outputs,
                                                 anchor_boxes,
                                                 labels['image_info'][:, 1, :],
                                                 is_training)

        # Stage 0 samples the RPN proposals with `_sample_rois_fn` and ignores
        # its IoU threshold, so -1 is only a placeholder.
        cascade_ious = [-1]
        if self._cascade_iou_thresholds is not None:
            # `list()` accepts any sequence (e.g. a tuple from a config);
            # `list + tuple` would raise a TypeError.
            cascade_ious = cascade_ious + list(self._cascade_iou_thresholds)
        next_rois = current_rois
        # Stores the class predictions for each RCNN head.
        all_class_outputs = []
        for cascade_num, iou_threshold in enumerate(cascade_ious):
            # In cascade RCNN we want the higher layers to have different
            # regression weights as the predicted deltas become smaller.
            regression_weights = self._cascade_layer_to_weights[cascade_num]
            current_rois = next_rois
            (class_outputs, box_outputs, model_outputs, matched_gt_boxes,
             matched_gt_classes, matched_gt_indices,
             current_rois) = self._run_frcnn_head(fpn_features, current_rois,
                                                  labels, is_training,
                                                  model_outputs, cascade_num,
                                                  iou_threshold,
                                                  regression_weights)
            all_class_outputs.append(class_outputs)

            # Decode this stage's box predictions into rois for the next
            # stage. Boxes may be predicted per class (when
            # `class_agnostic_bbox_pred` is false), so the ground-truth class
            # is used during training and the top-scoring class otherwise.
            # This also runs after the final stage, whose result is unused.
            if is_training:
                correct_class = matched_gt_classes
            else:
                # `tf.argmax` replaces the deprecated `tf.arg_max`; both
                # return int64 indices.
                correct_class = tf.argmax(class_outputs, axis=-1)

            next_rois = self._box_outputs_to_rois(
                box_outputs, current_rois, correct_class,
                labels['image_info'][:, 1:2, :], regression_weights)

        if not is_training:
            tf.logging.info('(self._class_agnostic_bbox_pred): {}'.format(
                self._class_agnostic_bbox_pred))
            if self._cascade_class_ensemble:
                # Average the class outputs of all cascade stages.
                class_outputs = tf.add_n(all_class_outputs) / len(
                    all_class_outputs)
            # Post processing/NMS is done here for final boxes. Note NMS is
            # also done earlier to generate proposals from the RPN head.
            # The background class is also removed here.
            detection_results = self._generate_detections_fn(
                box_outputs,
                class_outputs,
                current_rois,
                labels['image_info'][:, 1:2, :],
                regression_weights,
                bbox_per_class=(not self._class_agnostic_bbox_pred))
            model_outputs.update(detection_results)

        if not self._include_mask:
            return model_outputs

        if is_training:
            # Masks are sampled from the rois/matches of the last stage.
            current_rois, classes, mask_targets = self._sample_masks_fn(
                current_rois, matched_gt_boxes, matched_gt_classes,
                matched_gt_indices, labels['gt_masks'])
            mask_targets = tf.stop_gradient(mask_targets)

            classes = tf.cast(classes, dtype=tf.int32)

            model_outputs.update({
                'mask_targets': mask_targets,
                'sampled_class_targets': classes,
            })
        else:
            current_rois = detection_results['detection_boxes']
            classes = tf.cast(detection_results['detection_classes'],
                              dtype=tf.int32)

        mask_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
            fpn_features, current_rois, output_size=14)
        mask_outputs = self._mrcnn_head_fn(mask_roi_features, classes,
                                           is_training)

        if is_training:
            model_outputs.update({
                'mask_outputs': mask_outputs,
            })
        else:
            model_outputs.update(
                {'detection_masks': tf.nn.sigmoid(mask_outputs)})

        return model_outputs
Esempio n. 6
0
    def _build_outputs(self, images, labels, mode):
        """Builds the Mask R-CNN graph with optional RoI feature distillation.

        When training with `self._feat_distill`, the boxes in
        `labels['roi_boxes']` are pooled together with the sampled proposals,
        and the Fast R-CNN head's `distill_feat_outputs` is exported. All
        later steps (detections, mask sampling) use only the sampled
        proposals.

        Args:
          images: image tensor fed to the backbone. Its static dimensions 1:3
            are used as the anchor image size when `labels` has no
            'anchor_boxes'; the generated anchors are then tiled over the
            dynamic batch size.
          labels: `dict` of label tensors. 'image_info' is always read,
            'anchor_boxes' is used when present, and 'gt_boxes',
            'gt_classes', 'gt_masks' and (with distillation) 'roi_boxes' are
            read during training.
          mode: a `mode_keys` value; `mode_keys.TRAIN` selects training.

        Returns:
          `dict` of model outputs: RPN and Fast R-CNN head outputs, training
          targets (plus 'distill_feat_outputs' when distilling) or detection
          results, and mask targets/outputs or 'detection_masks' when masks
          are enabled.
        """
        is_training = mode == mode_keys.TRAIN
        model_outputs = {}

        if 'anchor_boxes' in labels:
            anchor_boxes = labels['anchor_boxes']
        else:
            anchor_boxes = anchor.Anchor(
                self._params.architecture.min_level,
                self._params.architecture.max_level,
                self._params.anchor.num_scales,
                self._params.anchor.aspect_ratios,
                self._params.anchor.anchor_size,
                images.get_shape().as_list()[1:3]).multilevel_boxes

            # Give every level's anchors a leading batch dimension.
            batch_size = tf.shape(images)[0]
            for level in anchor_boxes:
                anchor_boxes[level] = tf.tile(
                    tf.expand_dims(anchor_boxes[level], 0), [batch_size, 1, 1])

        backbone_features = self._backbone_fn(images, is_training)
        fpn_features = self._fpn_fn(backbone_features, is_training)

        rpn_score_outputs, rpn_box_outputs = self._rpn_head_fn(
            fpn_features, is_training)
        model_outputs.update({
            'rpn_score_outputs': rpn_score_outputs,
            'rpn_box_outputs': rpn_box_outputs,
        })
        rpn_rois, _ = self._generate_rois_fn(rpn_box_outputs,
                                             rpn_score_outputs, anchor_boxes,
                                             labels['image_info'][:, 1, :],
                                             is_training)

        if is_training:
            rpn_rois = tf.stop_gradient(rpn_rois)

            # Sample proposals.
            rpn_rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
                self._sample_rois_fn(rpn_rois, labels['gt_boxes'],
                                     labels['gt_classes']))

            # Foreground fraction among rois with a non-negative class.
            self.add_scalar_summary(
                'fg_bg_ratio_{}'.format(0),
                tf.reduce_sum(
                    tf.cast(tf.greater(matched_gt_classes, 0), tf.float32)) /
                tf.reduce_sum(
                    tf.cast(tf.greater_equal(matched_gt_classes, 0),
                            tf.float32)))

            # Create bounding box training targets.
            box_targets = box_utils.encode_boxes(
                matched_gt_boxes, rpn_rois, weights=[10.0, 10.0, 5.0, 5.0])
            # If the target is background, the box target is set to all 0s.
            box_targets = tf.where(
                tf.tile(
                    tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
                    [1, 1, 4]), tf.zeros_like(box_targets), box_targets)
            model_outputs.update({
                'class_targets': matched_gt_classes,
                'box_targets': box_targets,
            })

        # Boxes to pool RoI features from. When distilling, the distillation
        # boxes are appended after the proposals in a separate tensor, so
        # `rpn_rois` keeps only the proposals and never has to be split back
        # out (which previously required a static proposal count and exactly
        # `self._max_distill_rois` distillation boxes).
        crop_rois = rpn_rois
        if is_training and self._feat_distill:
            tf.logging.info(f'rois before concat distill boxes: {rpn_rois}')
            # [batch_size, num_rois + num_distill_rois, 4]
            crop_rois = tf.concat([rpn_rois, labels['roi_boxes']], axis=1)
            tf.logging.info(f'rois after concat distill boxes: {crop_rois}')

        roi_features = spatial_transform_ops.multilevel_crop_and_resize(
            fpn_features, crop_rois, output_size=7)

        (class_outputs, box_outputs, distill_feat_outputs,
         distill_class_outputs) = self._frcnn_head_fn(roi_features,
                                                      is_training)
        model_outputs.update({
            'class_outputs': class_outputs,
            'box_outputs': box_outputs,
        })
        if is_training and self._feat_distill:
            model_outputs.update(
                {'distill_feat_outputs': distill_feat_outputs})

        if not is_training:
            detection_results = self._generate_detections_fn(
                box_outputs,
                class_outputs,
                rpn_rois,
                labels['image_info'][:, 1:2, :],
                bbox_per_class=not self._params.frcnn_head.
                class_agnostic_bbox_pred,
                distill_class_outputs=distill_class_outputs,
            )
            model_outputs.update(detection_results)

        if not self._include_mask:
            return model_outputs

        if is_training:
            rpn_rois, classes, mask_targets = self._sample_masks_fn(
                rpn_rois, matched_gt_boxes, matched_gt_classes,
                matched_gt_indices, labels['gt_masks'])
            mask_targets = tf.stop_gradient(mask_targets)

            classes = tf.cast(classes, dtype=tf.int32)

            model_outputs.update({
                'mask_targets': mask_targets,
                'sampled_class_targets': classes,
            })
        else:
            rpn_rois = detection_results['detection_boxes']
            classes = tf.cast(detection_results['detection_classes'],
                              dtype=tf.int32)

        mask_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
            fpn_features, rpn_rois, output_size=14)

        mask_outputs = self._mrcnn_head_fn(mask_roi_features, classes,
                                           is_training)

        if is_training:
            model_outputs.update({
                'mask_outputs': mask_outputs,
            })
        else:
            model_outputs.update(
                {'detection_masks': tf.nn.sigmoid(mask_outputs)})

        return model_outputs
  def build_outputs(self, inputs, mode):
    """Builds the Mask R-CNN graph from an `inputs` dict and collects outputs.

    Args:
      inputs: `dict` of input tensors. 'image' and 'image_info' are always
        read, and 'gt_boxes', 'gt_classes' and 'gt_masks' are read during
        training. The static height/width of 'image' size the anchors.
      mode: a `mode_keys` value; `mode_keys.TRAIN` selects training.

    Returns:
      `dict` of model outputs. RPN, Fast R-CNN and mask head outputs, as well
      as 'detection_masks', are cast to float32. 'num_detections',
      'detection_boxes', 'detection_classes' and 'detection_scores' are
      produced in every mode. Training adds class/box targets and, with masks
      enabled, 'mask_targets', 'sampled_class_targets' and 'mask_outputs';
      evaluation adds 'detection_masks' when masks are enabled.
    """
    is_training = mode == mode_keys.TRAIN
    model_outputs = {}

    def _to_float32(structure):
      # Casts every tensor in a (possibly nested) structure to float32.
      return tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), structure)

    image = inputs['image']
    _, image_height, image_width, _ = image.get_shape().as_list()
    backbone_features = self._backbone_fn(image, is_training)
    fpn_features = self._fpn_fn(backbone_features, is_training)

    rpn_score_outputs, rpn_box_outputs = self._rpn_head_fn(
        fpn_features, is_training)
    model_outputs.update({
        'rpn_score_outputs': _to_float32(rpn_score_outputs),
        'rpn_box_outputs': _to_float32(rpn_box_outputs),
    })
    input_anchor = anchor.Anchor(self._params.anchor.min_level,
                                 self._params.anchor.max_level,
                                 self._params.anchor.num_scales,
                                 self._params.anchor.aspect_ratios,
                                 self._params.anchor.anchor_size,
                                 (image_height, image_width))
    rpn_rois, _ = self._generate_rois_fn(rpn_box_outputs, rpn_score_outputs,
                                         input_anchor.multilevel_boxes,
                                         inputs['image_info'][:, 1, :],
                                         is_training)
    if is_training:
      rpn_rois = tf.stop_gradient(rpn_rois)

      # Sample proposals.
      rpn_rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
          self._sample_rois_fn(rpn_rois, inputs['gt_boxes'],
                               inputs['gt_classes']))

      # Create bounding box training targets.
      box_targets = box_utils.encode_boxes(
          matched_gt_boxes, rpn_rois, weights=[10.0, 10.0, 5.0, 5.0])
      # If the target is background, the box target is set to all 0s.
      box_targets = tf.where(
          tf.tile(
              tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
              [1, 1, 4]),
          tf.zeros_like(box_targets),
          box_targets)
      model_outputs.update({
          'class_targets': matched_gt_classes,
          'box_targets': box_targets,
      })

    roi_features = spatial_transform_ops.multilevel_crop_and_resize(
        fpn_features, rpn_rois, output_size=7)

    class_outputs, box_outputs = self._frcnn_head_fn(roi_features, is_training)

    model_outputs.update({
        'class_outputs': _to_float32(class_outputs),
        'box_outputs': _to_float32(box_outputs),
    })

    # Add this output to train to make the checkpoint loadable in predict mode.
    # If we skip it in train mode, the heads will be out-of-order and checkpoint
    # loading will fail.
    boxes, scores, classes, valid_detections = self._generate_detections_fn(
        box_outputs, class_outputs, rpn_rois, inputs['image_info'][:, 1:2, :])
    model_outputs.update({
        'num_detections': valid_detections,
        'detection_boxes': boxes,
        'detection_classes': classes,
        'detection_scores': scores,
    })

    if not self._include_mask:
      return model_outputs

    if is_training:
      rpn_rois, classes, mask_targets = self._sample_masks_fn(
          rpn_rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices,
          inputs['gt_masks'])
      mask_targets = tf.stop_gradient(mask_targets)

      classes = tf.cast(classes, dtype=tf.int32)

      model_outputs.update({
          'mask_targets': mask_targets,
          'sampled_class_targets': classes,
      })
    else:
      # At inference time masks are predicted for the final detections.
      rpn_rois = boxes
      classes = tf.cast(classes, dtype=tf.int32)

    mask_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
        fpn_features, rpn_rois, output_size=14)

    mask_outputs = self._mrcnn_head_fn(mask_roi_features, classes, is_training)

    if is_training:
      model_outputs.update({
          'mask_outputs': _to_float32(mask_outputs),
      })
    else:
      # Cast before the sigmoid, like every other head output above, so the
      # mask probabilities are float32 even when the head runs in a lower
      # precision. This is a no-op when the head already outputs float32.
      model_outputs.update({
          'detection_masks': tf.nn.sigmoid(_to_float32(mask_outputs))
      })

    return model_outputs