def compute_loss(self, targets, head_outputs, matched_idxs):
    # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tuple[Tensor, List[Tensor]]
    """Compute the focal classification loss for a batch, plus its per-image terms.

    Args:
        targets: one dict per image; only the ``'labels'`` entry is read, and
            it is indexed with the matched ground-truth indices.
        head_outputs: dict whose ``'cls_logits'`` entry is iterated image by
            image alongside ``targets`` and ``matched_idxs``.
        matched_idxs: one tensor per image. Entries ``>= 0`` are treated as
            foreground and used as indices into that image's labels; entries
            equal to ``self.BETWEEN_THRESHOLDS`` are left out of the loss.

    Returns:
        A ``(batch_loss, per_image_losses)`` pair. ``per_image_losses`` holds,
        for every image, the summed sigmoid focal loss divided by
        ``max(1, number of foreground entries)``; ``batch_loss`` is their sum
        (via ``_sum``) divided by ``len(targets)``.
    """
    losses = []

    cls_logits = head_outputs['cls_logits']

    for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(
            targets, cls_logits, matched_idxs):
        # determine only the foreground
        foreground_idxs_per_image = matched_idxs_per_image >= 0
        num_foreground = foreground_idxs_per_image.sum()

        # TODO: enable support for images without annotations that works on distributed

        # create the target classification: a 1.0 at (foreground entry, its label)
        gt_classes_target = torch.zeros_like(cls_logits_per_image)
        matched_labels = targets_per_image['labels'][
            matched_idxs_per_image[foreground_idxs_per_image]]
        gt_classes_target[foreground_idxs_per_image, matched_labels] = 1.0

        # find indices for which anchors should be ignored
        valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

        # compute the classification loss; max(1, ...) avoids dividing by zero
        # when an image has no foreground entries
        losses.append(
            sigmoid_focal_loss(
                cls_logits_per_image[valid_idxs_per_image],
                gt_classes_target[valid_idxs_per_image],
                reduction='sum',
            ) / max(1, num_foreground))

    return _sum(losses) / len(targets), losses
def compute_loss(self, targets, head_outputs, matched_idxs):
    # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
    """Return the focal classification loss averaged over the images in a batch.

    Args:
        targets: one dict per image; only its ``'labels'`` entry is read.
        head_outputs: dict providing ``'cls_logits'``, iterated per image.
        matched_idxs: one tensor per image. Values ``>= 0`` mark foreground
            entries and index into the image's labels; values equal to
            ``self.BETWEEN_THRESHOLDS`` are excluded from the loss.

    Returns:
        The per-image losses (summed focal loss divided by
        ``max(1, foreground count)``) combined with ``_sum`` and divided by
        ``len(targets)``.
    """
    per_image_losses = []

    for target, logits, matches in zip(
            targets, head_outputs['cls_logits'], matched_idxs):
        # entries matched to a ground-truth box
        is_foreground = matches >= 0
        foreground_count = is_foreground.sum()

        # target tensor: zeros everywhere except a 1.0 at each foreground
        # entry's matched label
        foreground_labels = target['labels'][matches[is_foreground]]
        dense_target = torch.zeros_like(logits)
        dense_target[is_foreground, foreground_labels] = 1.0

        # entries flagged BETWEEN_THRESHOLDS do not contribute to the loss
        keep = matches != self.BETWEEN_THRESHOLDS

        focal_total = sigmoid_focal_loss(
            logits[keep],
            dense_target[keep],
            reduction='sum',
        )
        per_image_losses.append(focal_total / max(1, foreground_count))

    return _sum(per_image_losses) / len(targets)
def old_compute_loss(self, targets, head_outputs, matched_idxs):
    """Return the focal classification loss averaged over the images in a batch.

    Unlike a one-hot formulation, the target tensor here is filled with the
    matched label *values* themselves (cast to float) at the foreground
    entries.
    NOTE(review): that assignment only broadcasts if the per-image logits and
    the label values have compatible shapes -- confirm against the caller.

    Args:
        targets: one dict per image; only its ``'labels'`` entry is read.
        head_outputs: dict providing ``'cls_logits'``, iterated per image.
        matched_idxs: one tensor per image. Values ``>= 0`` mark foreground
            entries; values equal to ``self.BETWEEN_THRESHOLDS`` are excluded
            from the loss.

    Returns:
        Sum of the per-image losses (summed focal loss divided by
        ``max(1, foreground count)``) divided by ``len(targets)``.
    """
    per_image_losses = []

    for target, logits, matches in zip(
            targets, head_outputs['cls_logits'], matched_idxs):
        # entries matched to a ground-truth box
        is_foreground = matches >= 0
        foreground_count = is_foreground.sum()

        # float target, zero except for the label values written at the
        # foreground entries
        dense_target = torch.zeros_like(logits).float()
        foreground_labels = target['labels'][matches[is_foreground]]
        dense_target[(is_foreground,)] = foreground_labels.float()

        # entries flagged BETWEEN_THRESHOLDS do not contribute to the loss
        keep = matches != self.BETWEEN_THRESHOLDS

        focal_total = sigmoid_focal_loss(
            logits[keep],
            dense_target[keep],
            reduction='sum',
        )
        per_image_losses.append(focal_total / max(1, foreground_count))

    # left-to-right accumulation starting from the first element, so an empty
    # batch raises IndexError here
    total = per_image_losses[0]
    for image_loss in per_image_losses[1:]:
        total = total + image_loss

    return total / len(targets)
def _compute_cls_loss(self, class_pred, bg_targets, fg_class_targets):
    """Return the summed sigmoid focal loss normalised by the non-background total.

    Args:
        class_pred: predictions, passed as the first argument to
            ``cv_ops.sigmoid_focal_loss``.
        bg_targets: background indicator tensor; the normaliser is
            ``(1 - bg_targets).sum()``. Presumably a 0/1 mask, so the
            normaliser is the number of foreground locations -- TODO confirm.
        fg_class_targets: targets, passed as the second argument to
            ``cv_ops.sigmoid_focal_loss``.

    Returns:
        The focal loss (``reduction='sum'``, using ``self.alpha`` and
        ``self.gamma``) divided by the normaliser. When the normaliser is
        exactly zero the un-normalised sum is returned instead of inf/nan.
    """
    norm = (1. - bg_targets).sum()
    # A batch without any foreground makes the normaliser exactly 0, which
    # would turn the loss into inf/nan. Substitute 1 in that case only; every
    # non-zero normaliser is left untouched.
    norm = norm.masked_fill(norm == 0, 1.)

    focal_total = cv_ops.sigmoid_focal_loss(
        class_pred,
        fg_class_targets,
        self.alpha,
        self.gamma,
        reduction='sum',
    )
    return focal_total / norm
def forward(self, pred, target, points):
    """Compute the classification, box and centerness losses.

    Args:
        pred: dict with ``'class'``, ``'distance'`` and ``'centerness'``
            predictions. ``'distance'`` is reshaped to ``(-1, 4)`` and
            ``'centerness'`` is flattened.
        target: dict with the same three keys. ``target['class']`` holds
            integer class indices that are one-hot encoded with
            ``len(tools.VOC_CLASSES) + 1`` classes; index 0 is treated as
            background and every non-zero entry as a positive location.
        points: locations, reshaped to ``(-1, 2)`` and used to decode the
            distances into boxes.

    Returns:
        ``(loss_cls, loss_bbox, loss_centerness)``:
        * ``loss_cls`` -- summed sigmoid focal loss divided by the number of
          positive locations (at least 1).
        * ``loss_bbox`` -- ``-log(IoU)`` of decoded predicted vs. target boxes,
          weighted by and normalised with the centerness targets.
        * ``loss_centerness`` -- binary cross entropy with logits on the
          positive locations.
        When there are no positive locations, ``loss_bbox`` and
        ``loss_centerness`` are zero-valued tensors that stay connected to the
        predictions' autograd graph (instead of nan).
    """
    class_pred = pred['class']
    distance_pred = pred['distance']
    centerness_pred = pred['centerness']
    class_targets = target['class']
    distance_targets = target['distance']
    centerness_targets = target['centerness']

    # flat indices of the locations whose class index is non-zero
    positive_idx = torch.nonzero(class_targets.reshape(-1)).reshape(-1)

    pos_distance_pred = distance_pred.reshape(-1, 4)[positive_idx]  # [num_positives, 4]
    pos_centerness_pred = centerness_pred.reshape(-1)[positive_idx]  # [num_positives]

    # --- classification -------------------------------------------------
    class_targets = func.one_hot(
        class_targets, num_classes=len(tools.VOC_CLASSES) + 1).float()
    bg_targets = class_targets[..., 0]
    fg_class_targets = class_targets[..., 1:]
    # bg_targets is a 0/1 one-hot column, so this sum is the positive count;
    # clamp to 1 so a batch without positives does not divide by zero.
    num_foreground = (1. - bg_targets).sum().clamp(min=1.)
    loss_cls = cv_ops.sigmoid_focal_loss(
        class_pred, fg_class_targets, self.alpha, self.gamma,
        reduction='sum') / num_foreground

    if positive_idx.numel() == 0:
        # No positives: the box loss would be 0/0 and the centerness loss a
        # mean over an empty tensor (both nan). The sum of an empty selection
        # is 0 and keeps the result attached to the predictions' graph.
        return loss_cls, pos_distance_pred.sum(), pos_centerness_pred.sum()

    # --- box regression ---------------------------------------------------
    pos_distance_targets = distance_targets.reshape(-1, 4)[positive_idx]  # [num_positives, 4]
    pos_centerness_targets = centerness_targets.reshape(-1)[positive_idx]  # [num_positives]
    pos_points = points.reshape(-1, 2)[positive_idx]

    pos_decoded_bbox_pred = bbox_ops.convert_distance_to_bbox(
        pos_points, pos_distance_pred)
    pos_decoded_bbox_targets = bbox_ops.convert_distance_to_bbox(
        pos_points, pos_distance_targets)

    # the diagonal pairs each prediction with its own target; the clamp keeps
    # log() finite when the IoU is 0
    iou_loss = -cv_ops.box_iou(
        pos_decoded_bbox_pred,
        pos_decoded_bbox_targets).diagonal().clamp(min=1e-6).log()
    # alternative: iou_loss = 1 - cv_ops.generalized_box_iou(pos_decoded_bbox_pred, pos_decoded_bbox_targets).diagonal()
    loss_bbox = (pos_centerness_targets * iou_loss).sum() / pos_centerness_targets.sum()

    # --- centerness -------------------------------------------------------
    loss_centerness = func.binary_cross_entropy_with_logits(
        pos_centerness_pred, pos_centerness_targets)

    return loss_cls, loss_bbox, loss_centerness
def OHE_compute_loss(self, targets, head_outputs, matched_idxs):
    """Compute the focal classification loss on a secondary device.

    All loss arithmetic runs on ``config.devices[1]``; the final scalar is
    moved back to ``config.devices[0]``.

    NOTE(review): despite the name, no one-hot encoding happens here -- the
    matched label values (cast to float) are written directly into the target
    tensor at the foreground entries. Confirm this is intended.

    Args:
        targets: one dict per image; only its ``'labels'`` entry is read.
        head_outputs: dict providing ``'cls_logits'``, iterated per image.
        matched_idxs: one tensor per image. Values ``>= 0`` mark foreground
            entries; values equal to ``self.BETWEEN_THRESHOLDS`` are excluded
            from the loss.

    Returns:
        Sum of the per-image losses (summed focal loss divided by
        ``max(1, foreground count)``) divided by the number of images, on
        ``config.devices[0]``.
    """
    LOSS_ON_GPU = 1
    loss_device = config.devices[LOSS_ON_GPU]

    cls_logits = head_outputs['cls_logits'].to(loss_device)
    labels_per_image = [t['labels'].to(loss_device) for t in targets]

    losses = []
    for labels, cls_logits_per_image, matched_idxs_per_image in zip(
            labels_per_image, cls_logits, matched_idxs):
        # Bring the matches onto the loss device first, so that every mask
        # derived from them lives with the tensors it indexes. This is a
        # no-op when the matches are already there.
        matched_idxs_per_image = matched_idxs_per_image.to(loss_device)

        # determine only the foreground
        foreground_idxs_per_image = matched_idxs_per_image >= 0
        num_foreground = foreground_idxs_per_image.sum()

        # create the target classification
        gt_classes_target = torch.zeros_like(cls_logits_per_image).float()
        gt_classes_target[(foreground_idxs_per_image,)] = labels[
            matched_idxs_per_image[foreground_idxs_per_image]].float()

        # find indices for which anchors should be ignored
        valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

        # compute the classification loss; max(1, ...) avoids dividing by zero
        # when an image has no foreground entries
        losses.append(sigmoid_focal_loss(
            cls_logits_per_image[valid_idxs_per_image],
            gt_classes_target[valid_idxs_per_image],
            reduction='sum',
        ) / max(1, num_foreground))

    # left-to-right accumulation starting from the first element, so an empty
    # batch raises IndexError here
    total = losses[0]
    for image_loss in losses[1:]:
        total = total + image_loss

    loss = total / len(labels_per_image)
    return loss.to(config.devices[0])
def compute_loss(
    self,
    model_output: DigitDetectionModelOutput,
    model_target: DigitDetectionModelTarget,
) -> Optional[torch.Tensor]:
    """Combine the box-regression and classification losses for one batch.

    Args:
        model_output: provides ``box_regression_output`` and
            ``classification_output``.
        model_target: provides ``box_regression_target``,
            ``classification_target`` and ``matched_anchors``.

    Returns:
        ``None`` when ``model_target.matched_anchors`` is empty. Otherwise
        ``(smooth_l1 + focal) * classification_output.shape[1] / len(matched_anchors)``,
        where ``smooth_l1`` is ``SmoothL1Loss()`` with its default settings and
        ``focal`` is the sigmoid focal loss with ``reduction='mean'``.
    """
    # Decide first whether there is anything to normalise by, so no loss is
    # computed only to be thrown away.
    num_matched = len(model_target.matched_anchors)
    if num_matched == 0:
        return None

    loss_box_regression = SmoothL1Loss()(
        model_output.box_regression_output,
        model_target.box_regression_target,
    )
    loss_classification = sigmoid_focal_loss(
        model_output.classification_output,
        model_target.classification_target,
        reduction='mean',
    )

    # rescale the summed losses by shape[1] of the classification output
    # relative to the number of matched anchors
    scale = model_output.classification_output.shape[1]
    return (loss_box_regression + loss_classification) * scale / num_matched
def loss_masks(self, outputs, targets, indices, num_boxes):
    """Compute the mask losses: the focal loss and the dice loss.

    Each dict in ``targets`` must contain the key "masks" holding a tensor of
    dim [nb_target_boxes, h, w]; ``outputs`` must contain "pred_masks".

    Returns:
        A dict with "loss_mask" (focal loss, averaged over dimension 1,
        summed, then divided by ``num_boxes``) and "loss_dice" (dice loss
        divided by ``num_boxes``).
    """
    pred_selector = self._get_src_permutation_idx(indices)
    gt_selector = self._get_tgt_permutation_idx(indices)

    all_pred_masks = outputs["pred_masks"]

    # TODO use valid to mask invalid areas due to padding in loss
    batched_gt = nested_tensor_from_tensor_list([t["masks"] for t in targets])
    gt_masks, _valid = batched_gt.decompose()
    # match the predictions' dtype/device
    gt_masks = gt_masks.to(all_pred_masks)

    matched_pred = all_pred_masks[pred_selector]

    # upsample predictions to the target size
    upsampled_pred = interpolate(
        matched_pred[:, None],
        size=gt_masks.shape[-2:],
        mode="bilinear",
        align_corners=False,
    )

    # flatten everything after the first dimension
    flat_pred = upsampled_pred[:, 0].flatten(1)
    flat_gt = gt_masks[gt_selector].flatten(1)

    elementwise_focal = sigmoid_focal_loss(flat_pred, flat_gt)

    return {
        "loss_mask": elementwise_focal.mean(1).sum() / num_boxes,
        "loss_dice": dice_loss(flat_pred, flat_gt) / num_boxes,
    }