Example #1
    def call(self,
             images,
             image_shape,
             anchor_boxes=None,
             gt_boxes=None,
             gt_classes=None,
             gt_masks=None,
             training=None):
        model_outputs = {}

        # Feature extraction.
        features = self.backbone(images)
        if self.decoder:
            features = self.decoder(features)

        # Region proposal network.
        rpn_scores, rpn_boxes = self.rpn_head(features)

        model_outputs.update({
            'rpn_boxes': rpn_boxes,
            'rpn_scores': rpn_scores
        })

        # Generate RoIs.
        rois, _ = self.roi_generator(rpn_boxes, rpn_scores, anchor_boxes,
                                     image_shape, training)

        if training:
            rois = tf.stop_gradient(rois)

            rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
                self.roi_sampler(rois, gt_boxes, gt_classes))
            # Assign target for the 2nd stage classification.
            box_targets = box_ops.encode_boxes(matched_gt_boxes,
                                               rois,
                                               weights=[10.0, 10.0, 5.0, 5.0])
            # If the target is background, the box target is set to all 0s.
            box_targets = tf.where(
                tf.tile(
                    tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
                    [1, 1, 4]), tf.zeros_like(box_targets), box_targets)
            model_outputs.update({
                'class_targets': matched_gt_classes,
                'box_targets': box_targets,
            })

        # RoI align.
        roi_features = self.roi_aligner(features, rois)

        # Detection head.
        raw_scores, raw_boxes = self.detection_head(roi_features)

        if training:
            model_outputs.update({
                'class_outputs': raw_scores,
                'box_outputs': raw_boxes,
            })
        else:
            # Post-processing.
            detections = self.detection_generator(raw_boxes, raw_scores, rois,
                                                  image_shape)
            model_outputs.update({
                'detection_boxes': detections['detection_boxes'],
                'detection_scores': detections['detection_scores'],
                'detection_classes': detections['detection_classes'],
                'num_detections': detections['num_detections'],
            })

        if not self._include_mask:
            return model_outputs

        if training:
            if self._config_dict['use_gt_boxes_for_masks']:
                mask_size = (
                    self.mask_roi_aligner._config_dict['crop_size'] *  # pylint:disable=protected-access
                    self.mask_head._config_dict['upsample_factor']  # pylint:disable=protected-access
                )
                gt_masks = resize_as(source=gt_masks, size=mask_size)

                logging.info('Using GT class and mask targets.')
                model_outputs.update({
                    'mask_class_targets': gt_classes,
                    'mask_targets': gt_masks,
                })
            else:
                rois, roi_classes, roi_masks = self.mask_sampler(
                    rois, matched_gt_boxes, matched_gt_classes,
                    matched_gt_indices, gt_masks)
                roi_masks = tf.stop_gradient(roi_masks)
                model_outputs.update({
                    'mask_class_targets': roi_classes,
                    'mask_targets': roi_masks,
                })
        else:
            rois = model_outputs['detection_boxes']
            roi_classes = model_outputs['detection_classes']

        # Mask RoI align and mask head.
        if training and self._config_dict['use_gt_boxes_for_masks']:
            logging.info('Using GT mask roi features.')
            mask_roi_features = self.mask_roi_aligner(features, gt_boxes)
            raw_masks = self.mask_head([mask_roi_features, gt_classes])
        else:
            mask_roi_features = self.mask_roi_aligner(features, rois)
            raw_masks = self.mask_head([mask_roi_features, roi_classes])

        # Mask outputs.
        if training:
            model_outputs.update({
                'mask_outputs': raw_masks,
            })
        else:
            model_outputs.update({
                'detection_masks': tf.math.sigmoid(raw_masks),
            })
        return model_outputs
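The second-stage box-target assignment above is easier to follow in isolation. Below is a minimal, self-contained sketch (not part of the model) of the same background-masking idea: wherever the matched class is 0 (background), the four box-regression targets are zeroed so they do not contribute to the regression loss. Shapes and values are made up for illustration.

import tensorflow as tf

matched_gt_classes = tf.constant([[3, 0, 7]])          # [batch, num_rois]
box_targets = tf.random.uniform([1, 3, 4])             # [batch, num_rois, 4]

# Broadcast the "is background" flag over the 4 box coordinates.
is_background = tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1)
mask = tf.tile(is_background, [1, 1, 4])
box_targets = tf.where(mask, tf.zeros_like(box_targets), box_targets)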
Example #2
def test_fn(boxes, anchors):
    # `weights` is captured from the enclosing test scope where this helper is defined.
    encoded_boxes = box_ops.encode_boxes(boxes, anchors, weights)
    decoded_boxes = box_ops.decode_boxes(encoded_boxes, anchors, weights)
    return decoded_boxes
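The helper above relies on encode_boxes and decode_boxes being inverses for a fixed set of anchors and weights. A hypothetical round-trip check is sketched below; the import path and the [y_min, x_min, y_max, x_max] box layout are assumptions, not taken from this page, and the weights mirror the second-stage weights used elsewhere in these examples.

import tensorflow as tf
from official.vision.ops import box_ops  # assumed import path

boxes = tf.constant([[[10.0, 10.0, 50.0, 60.0]]])      # [batch, num_boxes, 4]
anchors = tf.constant([[[0.0, 0.0, 40.0, 40.0]]])
weights = [10.0, 10.0, 5.0, 5.0]

decoded = box_ops.decode_boxes(
    box_ops.encode_boxes(boxes, anchors, weights), anchors, weights)
# `decoded` should be numerically close to `boxes`.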
    def _run_frcnn_head(self, features, rois, gt_boxes, gt_classes, training,
                        model_outputs, cascade_num, regression_weights):
        """Runs the frcnn head that does both class and box prediction.

    Args:
      features: `list` of features from the feature extractor.
      rois: `list` of current rois that will be used to predict bbox refinement
        and classes from.
      gt_boxes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES, 4].
        This tensor might have paddings with a negative value.
      gt_classes: [batch_size, MAX_INSTANCES] representing the groundtruth box
        classes. It is padded with -1s to indicate the invalid classes.
      training: `bool`, if model is training or being evaluated.
      model_outputs: `dict`, used for storing outputs used for eval and losses.
      cascade_num: `int`, the current frcnn layer in the cascade.
      regression_weights: `list`, weights used for l1 loss in bounding box
        regression.

    Returns:
      class_outputs: Class predictions for rois.
      box_outputs: Box predictions for rois. These are formatted for the
        regression loss and need to be converted before being used as rois
        in the next stage.
      model_outputs: Updated dict with predictions used for losses and eval.
      matched_gt_boxes: If `is_training` is true, then these give the gt box
        location of its positive match.
      matched_gt_classes: If `is_training` is true, then these give the gt class
         of the predicted box.
      matched_gt_boxes: If `is_training` is true, then these give the box
        location of its positive match.
      matched_gt_indices: If `is_training` is true, then gives the index of
        the positive box match. Used for mask prediction.
      rois: The sampled rois used for this layer.
    """
        # Only used during training.
        matched_gt_boxes, matched_gt_classes, matched_gt_indices = (None, None,
                                                                    None)
        if training and gt_boxes is not None:
            rois = tf.stop_gradient(rois)

            current_roi_sampler = self.roi_sampler[cascade_num]
            rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
                current_roi_sampler(rois, gt_boxes, gt_classes))
            # Create bounding box training targets.
            box_targets = box_ops.encode_boxes(matched_gt_boxes,
                                               rois,
                                               weights=regression_weights)
            # If the target is background, the box target is set to all 0s.
            box_targets = tf.where(
                tf.tile(
                    tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
                    [1, 1, 4]), tf.zeros_like(box_targets), box_targets)
            model_outputs.update({
                'class_targets_{}'.format(cascade_num) if cascade_num else 'class_targets':
                matched_gt_classes,
                'box_targets_{}'.format(cascade_num) if cascade_num else 'box_targets':
                box_targets,
            })

        # Get roi features.
        roi_features = self.roi_aligner(features, rois)

        # Run frcnn head to get class and bbox predictions.
        current_detection_head = self.detection_head[cascade_num]
        class_outputs, box_outputs = current_detection_head(roi_features)

        model_outputs.update({
            'class_outputs_{}'.format(cascade_num) if cascade_num else 'class_outputs':
            class_outputs,
            'box_outputs_{}'.format(cascade_num) if cascade_num else 'box_outputs':
            box_outputs,
        })
        return (class_outputs, box_outputs, model_outputs, matched_gt_boxes,
                matched_gt_classes, matched_gt_indices, rois)
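The `'..._{}'.format(cascade_num) if cascade_num else '...'` expressions above implement a simple key convention: the first cascade stage writes bare keys and later stages append their index. A tiny standalone illustration (names here are illustrative only):

def _stage_key(base, cascade_num):
    # First stage (cascade_num == 0) keeps the bare key; later stages get a suffix.
    return '{}_{}'.format(base, cascade_num) if cascade_num else base

assert _stage_key('class_outputs', 0) == 'class_outputs'
assert _stage_key('box_outputs', 2) == 'box_outputs_2'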
    def call(self,
             images: tf.Tensor,
             image_shape: tf.Tensor,
             anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
             gt_boxes: Optional[tf.Tensor] = None,
             gt_classes: Optional[tf.Tensor] = None,
             gt_masks: Optional[tf.Tensor] = None,
             training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
        model_outputs = {}

        # Feature extraction.
        features = self.backbone(images)
        if self.decoder:
            features = self.decoder(features)

        # Region proposal network.
        rpn_scores, rpn_boxes = self.rpn_head(features)

        model_outputs.update({
            'rpn_boxes': rpn_boxes,
            'rpn_scores': rpn_scores
        })

        # Generate RoIs.
        rois, _ = self.roi_generator(rpn_boxes, rpn_scores, anchor_boxes,
                                     image_shape, training)

        if training:
            rois = tf.stop_gradient(rois)

            rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
                self.roi_sampler(rois, gt_boxes, gt_classes))
            # Assign target for the 2nd stage classification.
            box_targets = box_ops.encode_boxes(matched_gt_boxes,
                                               rois,
                                               weights=[10.0, 10.0, 5.0, 5.0])
            # If the target is background, the box target is set to all 0s.
            box_targets = tf.where(
                tf.tile(
                    tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
                    [1, 1, 4]), tf.zeros_like(box_targets), box_targets)
            model_outputs.update({
                'class_targets': matched_gt_classes,
                'box_targets': box_targets,
            })

        # RoI align.
        roi_features = self.roi_aligner(features, rois)

        # Detection head.
        raw_scores, raw_boxes = self.detection_head(roi_features)

        if training:
            model_outputs.update({
                'class_outputs': raw_scores,
                'box_outputs': raw_boxes,
            })
        else:
            # Post-processing.
            detections = self.detection_generator(raw_boxes, raw_scores, rois,
                                                  image_shape)
            model_outputs.update({
                'detection_boxes': detections['detection_boxes'],
                'detection_scores': detections['detection_scores'],
                'detection_classes': detections['detection_classes'],
                'num_detections': detections['num_detections'],
            })

        if not self._include_mask:
            return model_outputs

        if training:
            rois, roi_classes, roi_masks = self.mask_sampler(
                rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices,
                gt_masks)
            roi_masks = tf.stop_gradient(roi_masks)

            model_outputs.update({
                'mask_class_targets': roi_classes,
                'mask_targets': roi_masks,
            })
        else:
            rois = model_outputs['detection_boxes']
            roi_classes = model_outputs['detection_classes']

        # Mask RoI align.
        mask_roi_features = self.mask_roi_aligner(features, rois)

        # Mask head.
        raw_masks = self.mask_head([mask_roi_features, roi_classes])

        if training:
            model_outputs.update({
                'mask_outputs': raw_masks,
            })
        else:
            model_outputs.update({
                'detection_masks': tf.math.sigmoid(raw_masks),
            })
        return model_outputs
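At inference time the model returns the raw mask logits passed through a sigmoid as 'detection_masks'. The short sketch below (not part of the model) shows one common way callers consume them, thresholding the per-RoI probabilities into binary masks; the 28x28 crop size and the 0.5 threshold are assumptions for illustration.

import tensorflow as tf

raw_masks = tf.random.normal([1, 100, 28, 28])       # [batch, num_detections, H, W]
mask_probs = tf.math.sigmoid(raw_masks)               # corresponds to 'detection_masks'
binary_masks = tf.cast(mask_probs > 0.5, tf.uint8)    # typical post-processing step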