Пример #1
0
    def build_losses(self, outputs, labels, aux_losses=None):
        """Build Mask R-CNN losses.

        Computes the RPN objectness and box losses, the Fast R-CNN class and
        box-regression losses and, when `model.include_mask` is enabled, the
        mask loss. It then combines them into a weighted `model_loss` and adds
        any auxiliary losses on top to form `total_loss`.

        Args:
          outputs: Mapping of model outputs. Reads 'rpn_scores', 'rpn_boxes',
            'class_outputs', 'class_targets', 'box_outputs', 'box_targets' and,
            when masks are enabled, 'mask_outputs', 'mask_targets' and
            'mask_class_targets'.
          labels: Mapping of ground-truth targets. Reads 'rpn_score_targets'
            and 'rpn_box_targets'.
          aux_losses: Optional extra losses (e.g. regularization). When truthy,
            they are reduced with `tf.reduce_sum` and added to the model loss.

        Returns:
          A dict with keys 'total_loss', 'rpn_score_loss', 'rpn_box_loss',
          'frcnn_cls_loss', 'frcnn_box_loss', 'mask_loss' and 'model_loss'.
          'mask_loss' is the Python float 0.0 when masks are disabled.
        """
        cfg = self.task_config
        loss_cfg = cfg.losses

        # Region proposal network losses. RpnScoreLoss is constructed with the
        # second dimension of `box_outputs` (presumably ROIs per image -- TODO
        # confirm against the model's output layout).
        rpn_score_loss = tf.reduce_mean(
            maskrcnn_losses.RpnScoreLoss(tf.shape(outputs['box_outputs'])[1])(
                outputs['rpn_scores'], labels['rpn_score_targets']))
        rpn_box_loss = tf.reduce_mean(
            maskrcnn_losses.RpnBoxLoss(loss_cfg.rpn_huber_loss_delta)(
                outputs['rpn_boxes'], labels['rpn_box_targets']))

        # Second-stage (Fast R-CNN) classification and box-regression losses.
        frcnn_cls_loss = tf.reduce_mean(
            maskrcnn_losses.FastrcnnClassLoss()(
                outputs['class_outputs'], outputs['class_targets']))
        frcnn_box_loss = tf.reduce_mean(
            maskrcnn_losses.FastrcnnBoxLoss(loss_cfg.frcnn_huber_loss_delta)(
                outputs['box_outputs'], outputs['class_targets'],
                outputs['box_targets']))

        # Mask loss defaults to zero and is only computed for mask-enabled models.
        mask_loss = 0.0
        if cfg.model.include_mask:
            mask_loss_obj = maskrcnn_losses.MaskrcnnLoss()
            per_roi_classes = outputs['mask_class_targets']
            allowed_ids = self._task_config.allowed_mask_class_ids
            if allowed_ids is not None:
                # The mask loss ignores class ID 0, so mapping disallowed
                # classes to 0 excludes them from the loss.
                per_roi_classes = zero_out_disallowed_class_ids(
                    per_roi_classes, allowed_ids)
            mask_loss = tf.reduce_mean(
                mask_loss_obj(outputs['mask_outputs'], outputs['mask_targets'],
                              per_roi_classes))

        # Weighted sum of all task losses; term order kept fixed.
        model_loss = (loss_cfg.rpn_score_weight * rpn_score_loss +
                      loss_cfg.rpn_box_weight * rpn_box_loss +
                      loss_cfg.frcnn_class_weight * frcnn_cls_loss +
                      loss_cfg.frcnn_box_weight * frcnn_box_loss +
                      loss_cfg.mask_weight * mask_loss)

        total_loss = (model_loss + tf.reduce_sum(aux_losses)
                      if aux_losses else model_loss)

        return {
            'total_loss': total_loss,
            'rpn_score_loss': rpn_score_loss,
            'rpn_box_loss': rpn_box_loss,
            'frcnn_cls_loss': frcnn_cls_loss,
            'frcnn_box_loss': frcnn_box_loss,
            'mask_loss': mask_loss,
            'model_loss': model_loss,
        }
Пример #2
0
  def build_losses(self,
                   outputs: Mapping[str, Any],
                   labels: Mapping[str, Any],
                   aux_losses: Optional[Any] = None):
    """Build Mask R-CNN losses, averaging Fast R-CNN losses over cascade heads.

    The number of detection heads is 1 when
    `model.roi_sampler.cascade_iou_thresholds` is None, otherwise
    1 + len(thresholds). Head 0 reads the plain output keys ('class_outputs',
    'class_targets', 'box_outputs', 'box_targets'). Head k >= 1 reads the same
    keys with a '_k' suffix. The per-head class and box losses are averaged.

    Args:
      outputs: Mapping of model outputs. Reads 'rpn_scores', 'rpn_boxes', the
        per-head class/box keys described above and, when masks are enabled,
        'mask_outputs', 'mask_targets' and 'mask_class_targets'.
      labels: Mapping of ground-truth targets. Reads 'rpn_score_targets' and
        'rpn_box_targets'.
      aux_losses: Optional extra losses (e.g. regularization). When truthy,
        they are reduced with `tf.reduce_sum` and added before scaling.

    Returns:
      A dict with keys 'total_loss', 'rpn_score_loss', 'rpn_box_loss',
      'frcnn_cls_loss', 'frcnn_box_loss', 'mask_loss' and 'model_loss'.
      Only 'total_loss' is multiplied by `losses.loss_weight`. 'mask_loss' is
      the Python float 0.0 when masks are disabled.
    """
    cfg = self.task_config
    loss_cfg = cfg.losses
    cascade_ious = cfg.model.roi_sampler.cascade_iou_thresholds

    # Region proposal network losses. RpnScoreLoss is constructed with the
    # second dimension of `box_outputs` (presumably ROIs per image -- TODO
    # confirm against the model's output layout).
    rpn_score_loss = tf.reduce_mean(
        maskrcnn_losses.RpnScoreLoss(tf.shape(outputs['box_outputs'])[1])(
            outputs['rpn_scores'], labels['rpn_score_targets']))
    rpn_box_loss = tf.reduce_mean(
        maskrcnn_losses.RpnBoxLoss(loss_cfg.rpn_huber_loss_delta)(
            outputs['rpn_boxes'], labels['rpn_box_targets']))

    cls_loss_obj = maskrcnn_losses.FastrcnnClassLoss()
    box_loss_obj = maskrcnn_losses.FastrcnnBoxLoss(
        loss_cfg.frcnn_huber_loss_delta,
        cfg.model.detection_head.class_agnostic_bbox_pred)

    def head_key(base, head):
      # The first head uses the bare key; cascade head k uses '<base>_k'.
      return '{}_{}'.format(base, head) if head else base

    num_heads = 1 if cascade_ious is None else len(cascade_ious) + 1
    cls_terms = []
    box_terms = []
    for head in range(num_heads):
      cls_terms.append(tf.reduce_mean(cls_loss_obj(
          outputs[head_key('class_outputs', head)],
          outputs[head_key('class_targets', head)])))
      box_terms.append(tf.reduce_mean(box_loss_obj(
          outputs[head_key('box_outputs', head)],
          outputs[head_key('class_targets', head)],
          outputs[head_key('box_targets', head)])))
    # Mean over detection heads. The running sum starts at 0.0 and adds the
    # heads in order.
    frcnn_cls_loss = sum(cls_terms, 0.0) / num_heads
    frcnn_box_loss = sum(box_terms, 0.0) / num_heads

    # Mask loss defaults to zero and is only computed for mask-enabled models.
    mask_loss = 0.0
    if cfg.model.include_mask:
      mask_loss_obj = maskrcnn_losses.MaskrcnnLoss()
      per_roi_classes = outputs['mask_class_targets']
      allowed_ids = self._task_config.allowed_mask_class_ids
      if allowed_ids is not None:
        # The mask loss ignores class ID 0, so mapping disallowed classes to 0
        # excludes them from the loss.
        per_roi_classes = zero_out_disallowed_class_ids(
            per_roi_classes, allowed_ids)
      mask_loss = tf.reduce_mean(
          mask_loss_obj(outputs['mask_outputs'], outputs['mask_targets'],
                        per_roi_classes))

    # Weighted sum of all task losses; term order kept fixed.
    model_loss = (
        loss_cfg.rpn_score_weight * rpn_score_loss +
        loss_cfg.rpn_box_weight * rpn_box_loss +
        loss_cfg.frcnn_class_weight * frcnn_cls_loss +
        loss_cfg.frcnn_box_weight * frcnn_box_loss +
        loss_cfg.mask_weight * mask_loss)

    unscaled_total = (model_loss + tf.reduce_sum(aux_losses)
                      if aux_losses else model_loss)

    return {
        'total_loss': loss_cfg.loss_weight * unscaled_total,
        'rpn_score_loss': rpn_score_loss,
        'rpn_box_loss': rpn_box_loss,
        'frcnn_cls_loss': frcnn_cls_loss,
        'frcnn_box_loss': frcnn_box_loss,
        'mask_loss': mask_loss,
        'model_loss': model_loss,
    }