Example #1
    def compute_loss(self, targets, head_outputs, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tuple[Tensor, List[Tensor]]
        losses = []

        cls_logits = head_outputs['cls_logits']

        for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(
                targets, cls_logits, matched_idxs):
            # determine only the foreground
            foreground_idxs_per_image = matched_idxs_per_image >= 0
            num_foreground = foreground_idxs_per_image.sum()
            # no matched_idxs means there were no annotations in this image
            # TODO: enable support for images without annotations that works on distributed
            if False:  # matched_idxs_per_image.numel() == 0:
                gt_classes_target = torch.zeros_like(cls_logits_per_image)
                valid_idxs_per_image = torch.arange(
                    cls_logits_per_image.shape[0])
            else:
                # create the target classification
                gt_classes_target = torch.zeros_like(cls_logits_per_image)
                gt_classes_target[
                    foreground_idxs_per_image,
                    targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
                ] = 1.0

                # find indices for which anchors should be ignored
                valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

            # compute the classification loss
            losses.append(
                sigmoid_focal_loss(
                    cls_logits_per_image[valid_idxs_per_image],
                    gt_classes_target[valid_idxs_per_image],
                    reduction='sum',
                ) / max(1, num_foreground))
        return _sum(losses) / len(targets), losses
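
Note: sigmoid_focal_loss and _sum are referenced but not defined in this snippet. A minimal sketch of the assumed surroundings, using torchvision's focal loss (an assumption) and the same _sum helper that appears verbatim in the later examples:

import torch
from torchvision.ops import sigmoid_focal_loss  # assumption: the torchvision implementation


def _sum(x):
    # add up a list of scalar loss tensors
    res = x[0]
    for i in x[1:]:
        res = res + i
    return res
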
Example #2
    def compute_loss(self, targets, head_outputs, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
        losses = []

        cls_logits = head_outputs['cls_logits']

        for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(
                targets, cls_logits, matched_idxs):
            # determine only the foreground
            foreground_idxs_per_image = matched_idxs_per_image >= 0
            num_foreground = foreground_idxs_per_image.sum()

            # create the target classification
            gt_classes_target = torch.zeros_like(cls_logits_per_image)
            gt_classes_target[
                foreground_idxs_per_image,
                targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
            ] = 1.0

            # find indices for which anchors should be ignored
            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

            # compute the classification loss
            losses.append(
                sigmoid_focal_loss(
                    cls_logits_per_image[valid_idxs_per_image],
                    gt_classes_target[valid_idxs_per_image],
                    reduction='sum',
                ) / max(1, num_foreground))

        return _sum(losses) / len(targets)
    def old_compute_loss(self, targets, head_outputs, matched_idxs):
        def _sum(x):
            res = x[0]
            for i in x[1:]:
                res = res + i
            return res

        losses = []

        cls_logits = head_outputs['cls_logits']

        for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):
            # determine only the foreground
            foreground_idxs_per_image = matched_idxs_per_image >= 0
            num_foreground = foreground_idxs_per_image.sum()

            # create the target classification
            gt_classes_target = torch.zeros_like(cls_logits_per_image).float()
            gt_classes_target[foreground_idxs_per_image] = \
                targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]].float()

            # find indices for which anchors should be ignored
            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

            # compute the classification loss
            losses.append(sigmoid_focal_loss(
                cls_logits_per_image[valid_idxs_per_image],
                gt_classes_target[valid_idxs_per_image],
                reduction='sum',
            ) / max(1, num_foreground))

        return _sum(losses) / len(targets)
Example #4
    def _compute_cls_loss(self, class_pred, bg_targets, fg_class_targets):
        norm = (1. - bg_targets).sum()
        loss_cls = cv_ops.sigmoid_focal_loss(
            class_pred,
            fg_class_targets,
            self.alpha,
            self.gamma,
            reduction='sum') / norm
        return loss_cls
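
For reference, a self-contained sketch of this normalization, assuming cv_ops.sigmoid_focal_loss matches torchvision's sigmoid_focal_loss signature; the shapes, the 20-class setup, and the clamp(min=1.) guard against images with no foreground are illustrative assumptions, not part of the original:

import torch
from torch.nn import functional as F
from torchvision.ops import sigmoid_focal_loss

# hypothetical batch of 8 predictions over 20 foreground classes (index 0 = background)
class_pred = torch.randn(8, 20)
class_idx = torch.randint(0, 21, (8,))
one_hot = F.one_hot(class_idx, num_classes=21).float()
bg_targets = one_hot[..., 0]           # 1.0 where the target is background
fg_class_targets = one_hot[..., 1:]    # per-class foreground targets

norm = (1. - bg_targets).sum().clamp(min=1.)   # number of foreground samples
loss_cls = sigmoid_focal_loss(class_pred, fg_class_targets,
                              alpha=0.25, gamma=2.0,
                              reduction='sum') / norm
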
Example #5
    def forward(self, pred, target, points):
        class_pred = pred['class']
        distance_pred = pred['distance']
        centerness_pred = pred['centerness']
        class_targets = target['class']
        distance_targets = target['distance']
        centerness_targets = target['centerness']

        positive_idx = torch.nonzero(class_targets.reshape(-1)).reshape(-1)
        pos_distance_pred = distance_pred.reshape(
            -1, 4)[positive_idx]  # [num_positives, 4]
        pos_distance_targets = distance_targets.reshape(
            -1, 4)[positive_idx]  # [num_positives, 4]
        pos_centerness_pred = centerness_pred.reshape(-1)[
            positive_idx]  # [num_positives]
        pos_centerness_targets = centerness_targets.reshape(-1)[
            positive_idx]  # [num_positives]

        pos_points = points.reshape(-1, 2)[positive_idx]
        pos_decoded_bbox_pred = bbox_ops.convert_distance_to_bbox(
            pos_points, pos_distance_pred)
        pos_decoded_bbox_targets = bbox_ops.convert_distance_to_bbox(
            pos_points, pos_distance_targets)

        class_targets = func.one_hot(class_targets,
                                     num_classes=len(tools.VOC_CLASSES) +
                                     1).float()
        bg_targets = class_targets[..., 0]
        fg_class_targets = class_targets[..., 1:]
        loss_cls = cv_ops.sigmoid_focal_loss(
            class_pred,
            fg_class_targets,
            self.alpha,
            self.gamma,
            reduction='sum') / (1. - bg_targets).sum()

        iou_loss = -cv_ops.box_iou(
            pos_decoded_bbox_pred,
            pos_decoded_bbox_targets).diagonal().clamp(min=1e-6).log()
        # iou_loss = 1 - cv_ops.generalized_box_iou(pos_decoded_bbox_pred, pos_decoded_bbox_targets).diagonal()
        loss_bbox = (pos_centerness_targets *
                     iou_loss).sum() / pos_centerness_targets.sum()

        loss_centerness = func.binary_cross_entropy_with_logits(
            pos_centerness_pred, pos_centerness_targets)

        return loss_cls, loss_bbox, loss_centerness
    def OHE_compute_loss(self, targets, head_outputs, matched_idxs):
        def _sum(x):
            res = x[0]
            for i in x[1:]:
                res = res + i
            return res

        losses = []

        LOSS_ON_GPU = 1

        cls_logits = head_outputs['cls_logits'].to(config.devices[LOSS_ON_GPU])
        #
        # cls_logits = [x.to(config.devices[1]) for x in cls_logits]
        targets = [{
            'labels': x['labels'].to(config.devices[LOSS_ON_GPU])
        } for x in targets]

        for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):
            # determine only the foreground
            foreground_idxs_per_image = (matched_idxs_per_image >= 0).to(config.devices[LOSS_ON_GPU])
            num_foreground = foreground_idxs_per_image.sum().to(config.devices[LOSS_ON_GPU])

            # create the target classification
            gt_classes_target = torch.zeros_like(cls_logits_per_image).float().to(config.devices[LOSS_ON_GPU])
            gt_classes_target[foreground_idxs_per_image] = \
                targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]].float()

            # find indices for which anchors should be ignored
            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

            # compute the classification loss
            losses.append(sigmoid_focal_loss(
                cls_logits_per_image[valid_idxs_per_image],
                gt_classes_target[valid_idxs_per_image],
                reduction='sum',
            ) / max(1, num_foreground))

        loss = _sum(losses) / len(targets)

        loss = loss.to(config.devices[0])
        return loss
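
In the forward method above, bbox_ops.convert_distance_to_bbox is not shown; in FCOS-style code it typically decodes (left, top, right, bottom) distances around each point back into corner boxes. A hedged sketch of that assumed behaviour:

import torch


def convert_distance_to_bbox(points, distance):
    # points: [N, 2] (x, y) locations; distance: [N, 4] (left, top, right, bottom)
    x1 = points[:, 0] - distance[:, 0]
    y1 = points[:, 1] - distance[:, 1]
    x2 = points[:, 0] + distance[:, 2]
    y2 = points[:, 1] + distance[:, 3]
    return torch.stack([x1, y1, x2, y2], dim=-1)
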
Example #7
    def compute_loss(
        self,
        model_output: DigitDetectionModelOutput,
        model_target: DigitDetectionModelTarget,
    ) -> Optional[torch.Tensor]:
        loss_box_regression = 0
        loss_classification = 0
        smooth = SmoothL1Loss()
        loss_box_regression += smooth(model_output.box_regression_output,
                                      model_target.box_regression_target)
        loss_classification += sigmoid_focal_loss(
            model_output.classification_output,
            model_target.classification_target,
            reduction='mean')

        if len(model_target.matched_anchors) == 0:
            return None

        return (loss_box_regression + loss_classification
                ) * model_output.classification_output.shape[1] / len(
                    model_target.matched_anchors)
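
One detail worth noting in this example: both losses are computed before the matched-anchor check, so the early return only skips the final scaling. A sketch, written as a free function for brevity, under the same hypothetical dataclass fields and assuming torchvision's sigmoid_focal_loss, that performs the check first:

from torch.nn import SmoothL1Loss
from torchvision.ops import sigmoid_focal_loss  # assumption: the torchvision implementation


def compute_loss(model_output, model_target):
    # bail out before computing anything if no anchors were matched
    if len(model_target.matched_anchors) == 0:
        return None

    loss_box_regression = SmoothL1Loss()(model_output.box_regression_output,
                                         model_target.box_regression_target)
    loss_classification = sigmoid_focal_loss(model_output.classification_output,
                                             model_target.classification_target,
                                             reduction='mean')
    num_classes = model_output.classification_output.shape[1]
    return (loss_box_regression + loss_classification) * num_classes / len(
        model_target.matched_anchors)
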
Example #8
    def loss_masks(self, outputs, targets, indices, num_boxes):
        """
        Compute the losses related to the masks: the focal loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        # assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)

        src_masks = outputs["pred_masks"]

        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(
            [t["masks"] for t in targets]).decompose()
        target_masks = target_masks.to(src_masks)

        src_masks = src_masks[src_idx]
        # upsample predictions to the target size
        src_masks = interpolate(
            src_masks[:, None],
            size=target_masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks[tgt_idx].flatten(1)
        focal_loss = sigmoid_focal_loss(src_masks, target_masks)
        box_norm_focal_loss = focal_loss.mean(1).sum() / num_boxes
        norm_dice_loss = dice_loss(src_masks, target_masks) / num_boxes
        losses = {
            "loss_mask": box_norm_focal_loss,
            "loss_dice": norm_dice_loss,
        }
        return losses
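
dice_loss is not defined in this snippet. A common DETR-style sketch that fits the surrounding usage, where the caller divides the returned sum by num_boxes:

import torch


def dice_loss(inputs, targets):
    # inputs: mask logits [num_masks, H*W]; targets: binary masks [num_masks, H*W]
    inputs = inputs.sigmoid()
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(1) + targets.sum(1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum()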