def _build_frcnn_losses(
    self, outputs: Mapping[str, Any],
    labels: Mapping[str, Any]) -> Tuple[tf.Tensor, tf.Tensor]:
  """Build losses for Fast R-CNN.

  Computes a classification loss and a box-regression loss for every
  detection head, reduces each with `tf.reduce_mean`, and averages the
  per-head values over the number of heads. There is one head when
  `cascade_iou_thresholds` is None, otherwise one plus the number of
  thresholds. Head 0 reads the bare output keys (e.g. 'class_outputs');
  head k >= 1 reads the keys suffixed with `_k` (e.g. 'class_outputs_1').

  Args:
    outputs: Model outputs. For each head this method reads the
      class_outputs, class_targets, box_outputs and box_targets entries.
    labels: Not used by this method.

  Returns:
    A `(frcnn_cls_loss, frcnn_box_loss)` tuple, each the mean over all
    detection heads of that head's reduced loss.
  """
  config = self.task_config
  thresholds = config.model.roi_sampler.cascade_iou_thresholds
  cls_loss_op = maskrcnn_losses.FastrcnnClassLoss()
  box_loss_op = maskrcnn_losses.FastrcnnBoxLoss(
      config.losses.frcnn_huber_loss_delta,
      config.model.detection_head.class_agnostic_bbox_pred)

  def head_key(base, head):
    # The first head uses the bare key; cascade heads append `_<index>`.
    return f'{base}_{head}' if head else base

  num_heads = 1 if thresholds is None else 1 + len(thresholds)

  # Final cls/box losses are computed as an average of all detection heads.
  total_cls = 0.0
  total_box = 0.0
  for head in range(num_heads):
    cls_logits = outputs[head_key('class_outputs', head)]
    cls_targets = outputs[head_key('class_targets', head)]
    total_cls += tf.reduce_mean(cls_loss_op(cls_logits, cls_targets))

    box_preds = outputs[head_key('box_outputs', head)]
    cls_targets = outputs[head_key('class_targets', head)]
    box_targets = outputs[head_key('box_targets', head)]
    total_box += tf.reduce_mean(
        box_loss_op(box_preds, cls_targets, box_targets))

  total_cls /= num_heads
  total_box /= num_heads
  return total_cls, total_box
def build_losses(self,
                 outputs: Mapping[str, Any],
                 labels: Mapping[str, Any],
                 aux_losses: Optional[Any] = None):
  """Build Mask R-CNN losses.

  Combines the RPN score/box losses, the Fast R-CNN class/box losses and,
  when `model.include_mask` is set, the mask loss into one weighted model
  loss. Any auxiliary losses are then added, and the sum is scaled by
  `losses.loss_weight`.

  Args:
    outputs: Model outputs. This method reads 'rpn_scores', 'rpn_boxes' and
      'box_outputs'. When masks are enabled it also reads 'mask_outputs',
      'mask_targets' and 'mask_class_targets'. The per-head Fast R-CNN keys
      are read by `_build_frcnn_losses`.
    labels: Ground-truth labels; 'rpn_score_targets' and 'rpn_box_targets'
      are read here.
    aux_losses: Optional extra losses. When truthy they are summed with
      `tf.reduce_sum` and added to the model loss before `loss_weight` is
      applied.

  Returns:
    A dict with keys 'total_loss', 'rpn_score_loss', 'rpn_box_loss',
    'frcnn_cls_loss', 'frcnn_box_loss', 'mask_loss' and 'model_loss'.
  """
  params = self.task_config

  rpn_score_loss_fn = maskrcnn_losses.RpnScoreLoss(
      tf.shape(outputs['box_outputs'])[1])
  rpn_box_loss_fn = maskrcnn_losses.RpnBoxLoss(
      params.losses.rpn_huber_loss_delta)
  rpn_score_loss = tf.reduce_mean(
      rpn_score_loss_fn(outputs['rpn_scores'], labels['rpn_score_targets']))
  rpn_box_loss = tf.reduce_mean(
      rpn_box_loss_fn(outputs['rpn_boxes'], labels['rpn_box_targets']))

  # Reuse the shared per-head Fast R-CNN computation instead of duplicating
  # the cascade loop here; it averages the losses over all detection heads.
  frcnn_cls_loss, frcnn_box_loss = self._build_frcnn_losses(outputs, labels)

  if params.model.include_mask:
    mask_loss_fn = maskrcnn_losses.MaskrcnnLoss()
    mask_class_targets = outputs['mask_class_targets']
    if self._task_config.allowed_mask_class_ids is not None:
      # Classes with ID=0 are ignored by mask_loss_fn in loss computation.
      mask_class_targets = zero_out_disallowed_class_ids(
          mask_class_targets, self._task_config.allowed_mask_class_ids)
    mask_loss = tf.reduce_mean(
        mask_loss_fn(outputs['mask_outputs'], outputs['mask_targets'],
                     mask_class_targets))
  else:
    mask_loss = 0.0

  model_loss = (params.losses.rpn_score_weight * rpn_score_loss +
                params.losses.rpn_box_weight * rpn_box_loss +
                params.losses.frcnn_class_weight * frcnn_cls_loss +
                params.losses.frcnn_box_weight * frcnn_box_loss +
                params.losses.mask_weight * mask_loss)

  total_loss = model_loss
  if aux_losses:
    reg_loss = tf.reduce_sum(aux_losses)
    total_loss = model_loss + reg_loss

  total_loss = params.losses.loss_weight * total_loss
  losses = {
      'total_loss': total_loss,
      'rpn_score_loss': rpn_score_loss,
      'rpn_box_loss': rpn_box_loss,
      'frcnn_cls_loss': frcnn_cls_loss,
      'frcnn_box_loss': frcnn_box_loss,
      'mask_loss': mask_loss,
      'model_loss': model_loss,
  }
  return losses