Example #1
    def loss_labels(self, outputs, targets, indices, num_boxes, log=False):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2],
                                    self.num_classes,
                                    dtype=torch.int64,
                                    device=src_logits.device)
        target_classes[idx] = target_classes_o

        if self.use_focal:
            src_logits = src_logits.flatten(0, 1)
            # prepare one_hot target.
            target_classes = target_classes.flatten(0, 1)
            pos_inds = torch.nonzero(target_classes != self.num_classes,
                                     as_tuple=True)[0]
            labels = torch.zeros_like(src_logits)
            labels[pos_inds, target_classes[pos_inds]] = 1
            # comp focal loss.
            class_loss = sigmoid_focal_loss_jit(
                src_logits,
                labels,
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
                reduction="sum",
            ) / num_boxes
            losses = {'loss_ce': class_loss}
        else:
            loss_ce = F.cross_entropy(src_logits.transpose(1, 2),
                                      target_classes, self.empty_weight)
            losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx],
                                                   target_classes_o)[0]
        return losses
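All of these examples call `sigmoid_focal_loss_jit`, the TorchScript-compiled variant of fvcore's `sigmoid_focal_loss`. For reference, a minimal sketch of the underlying computation (the focal loss of Lin et al., "Focal Loss for Dense Object Detection"; `alpha < 0` disables class balancing):

import torch
import torch.nn.functional as F

def sigmoid_focal_loss(inputs, targets, alpha=-1.0, gamma=2.0, reduction="none"):
    # Per-element binary cross-entropy, down-weighted by (1 - p_t)^gamma so that
    # easy, well-classified examples contribute little to the total loss.
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    return loss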
Example #2
    def losses(
            self,
            anchors,
            gt_classes,
            gt_boxes,
            pred_class_logits,
            pred_anchor_deltas,
            pred_class_logits_var=None,
            pred_bbox_cov=None):
        """
        Args:
            For `gt_classes` and `gt_boxes` parameters, see
                :meth:`RetinaNet.get_ground_truth`.
            Their shapes are (N, R) and (N, R, 4), respectively, where R is
            the total number of anchors across levels, i.e. sum(Hi x Wi x A)
            For `pred_class_logits`, `pred_anchor_deltas`, `pred_class_logits_var` and `pred_bbox_cov`, see
                :meth:`RetinaNetHead.forward`.
        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        num_images = len(gt_classes)
        gt_labels = torch.stack(gt_classes)  # (N, R)
        anchors = type(anchors[0]).cat(anchors).tensor  # (R, 4)
        gt_anchor_deltas = [
            self.box2box_transform.get_deltas(
                anchors, k) for k in gt_boxes]
        gt_anchor_deltas = torch.stack(gt_anchor_deltas)  # (N, R, 4)

        valid_mask = gt_labels >= 0
        pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
        num_pos_anchors = pos_mask.sum().item()
        get_event_storage().put_scalar("num_pos_anchors", num_pos_anchors / num_images)
        self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + \
            (1 - self.loss_normalizer_momentum) * max(num_pos_anchors, 1)

        # classification and regression loss

        # Shapes after concatenating the per-level lists below:
        # (N, R, K) for pred_class_logits and pred_class_logits_var.
        # (N, R, 4) and (N, R, 10) for pred_anchor_deltas and pred_bbox_cov, respectively.

        # Transform per-feature layer lists to a single tensor
        pred_class_logits = cat(pred_class_logits, dim=1)
        pred_anchor_deltas = cat(pred_anchor_deltas, dim=1)

        if pred_class_logits_var is not None:
            pred_class_logits_var = cat(
                pred_class_logits_var, dim=1)

        if pred_bbox_cov is not None:
            pred_bbox_cov = cat(
                pred_bbox_cov, dim=1)

        gt_classes_target = torch.nn.functional.one_hot(
            gt_labels[valid_mask],
            num_classes=self.num_classes + 1
        )[:, :-1].to(pred_class_logits[0].dtype)  # no loss for the last (background) class

        # Classification losses
        if self.compute_cls_var:
            # Compute classification variance according to:
            # "What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?", NIPS 2017
            if self.cls_var_loss == 'loss_attenuation':
                num_samples = self.cls_var_num_samples
                # Compute standard deviation
                pred_class_logits_var = torch.sqrt(torch.exp(
                    pred_class_logits_var[valid_mask]))

                pred_class_logits = pred_class_logits[valid_mask]

                # Produce normal samples using logits as the mean and the standard deviation computed above
                # Scales with GPU memory. 12 GB ---> 3 Samples per anchor for
                # COCO dataset.
                univariate_normal_dists = distributions.normal.Normal(
                    pred_class_logits, scale=pred_class_logits_var)

                pred_class_stochastic_logits = univariate_normal_dists.rsample(
                    (num_samples,))
                pred_class_stochastic_logits = pred_class_stochastic_logits.view(
                    (pred_class_stochastic_logits.shape[1] * num_samples, pred_class_stochastic_logits.shape[2], -1))
                pred_class_stochastic_logits = pred_class_stochastic_logits.squeeze(
                    2)

                # Produce copies of the target classes to match the number of
                # stochastic samples.
                gt_classes_target = torch.unsqueeze(gt_classes_target, 0)
                gt_classes_target = torch.repeat_interleave(
                    gt_classes_target, num_samples, dim=0).view(
                    (gt_classes_target.shape[1] * num_samples, gt_classes_target.shape[2], -1))
                gt_classes_target = gt_classes_target.squeeze(2)

                # Compute the focal loss over the stochastic logit samples and
                # normalize by the number of samples.
                loss_cls = sigmoid_focal_loss_jit(
                    pred_class_stochastic_logits,
                    gt_classes_target,
                    alpha=self.focal_loss_alpha,
                    gamma=self.focal_loss_gamma,
                    reduction="sum",
                ) / (num_samples * max(1, self.loss_normalizer))
            else:
                raise ValueError(
                    'Invalid classification loss name {}.'.format(
                        self.cls_var_loss))
        else:
            # Standard loss computation in case one wants to use this code
            # without any probabilistic inference.
            loss_cls = sigmoid_focal_loss_jit(
                pred_class_logits[valid_mask],
                gt_classes_target,
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
                reduction="sum",
            ) / max(1, self.loss_normalizer)

        # Compute Regression Loss
        pred_anchor_deltas = pred_anchor_deltas[pos_mask]
        gt_anchors_deltas = gt_anchor_deltas[pos_mask]
        if self.compute_bbox_cov:
            # We have to clamp the output variance else probabilistic metrics
            # go to infinity.
            pred_bbox_cov = clamp_log_variance(pred_bbox_cov[pos_mask])
            if self.bbox_cov_loss == 'negative_log_likelihood':
                if self.bbox_cov_type == 'diagonal':
                    # Compute regression variance according to:
                    # "What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?", NIPS 2017
                    # This implementation with smooth_l1_loss outperforms using
                    # torch.distribution.multivariate_normal. Losses might have different numerical values
                    # since we do not include constants in this implementation.
                    loss_box_reg = 0.5 * torch.exp(-pred_bbox_cov) * smooth_l1_loss(
                        pred_anchor_deltas,
                        gt_anchors_deltas,
                        beta=self.smooth_l1_beta)
                    loss_covariance_regularize = 0.5 * pred_bbox_cov
                    loss_box_reg += loss_covariance_regularize

                    # Sum over all elements
                    loss_box_reg = torch.sum(
                        loss_box_reg) / max(1, self.loss_normalizer)
                else:
                    # Multivariate negative log likelihood. Implemented with
                    # pytorch multivariate_normal.log_prob function. Custom implementations fail to finish training
                    # due to NAN loss.

                    # This is the Cholesky decomposition of the covariance matrix. We reconstruct it from 10 estimated
                    # parameters as a lower triangular matrix.
                    forecaster_cholesky = covariance_output_to_cholesky(
                        pred_bbox_cov)

                    # Compute multivariate normal distribution using torch
                    # distribution functions.
                    multivariate_normal_dists = distributions.multivariate_normal.MultivariateNormal(
                        pred_anchor_deltas, scale_tril=forecaster_cholesky)

                    loss_box_reg = -multivariate_normal_dists.log_prob(
                        gt_anchors_deltas)
                    loss_box_reg = torch.sum(
                        loss_box_reg) / max(1, self.loss_normalizer)

            elif self.bbox_cov_loss == 'second_moment_matching':
                # Compute regression covariance using second moment matching.
                loss_box_reg = smooth_l1_loss(
                    pred_anchor_deltas,
                    gt_anchors_deltas,
                    beta=self.smooth_l1_beta)

                # Compute errors
                errors = (pred_anchor_deltas - gt_anchors_deltas)

                if self.bbox_cov_type == 'diagonal':
                    # Compute second moment matching term.
                    second_moment_matching_term = smooth_l1_loss(
                        torch.exp(pred_bbox_cov), errors ** 2, beta=self.smooth_l1_beta)
                    loss_box_reg += second_moment_matching_term
                    loss_box_reg = torch.sum(
                        loss_box_reg) / max(1, self.loss_normalizer)
                else:
                    # Compute second moment matching term.
                    errors = torch.unsqueeze(errors, 2)
                    gt_error_covar = torch.matmul(
                        errors, torch.transpose(errors, 2, 1))

                    # This is the Cholesky decomposition of the covariance matrix. We reconstruct it
                    # from 10 estimated parameters as a lower triangular matrix.
                    forecaster_cholesky = covariance_output_to_cholesky(
                        pred_bbox_cov)

                    predicted_covar = torch.matmul(
                        forecaster_cholesky, torch.transpose(
                            forecaster_cholesky, 2, 1))

                    second_moment_matching_term = smooth_l1_loss(
                        predicted_covar, gt_error_covar, beta=self.smooth_l1_beta, reduction='sum')

                    loss_box_reg = (torch.sum(
                        loss_box_reg) + second_moment_matching_term) / max(1, self.loss_normalizer)

            elif self.bbox_cov_loss == 'energy_loss':
                # Compute regression variance according to energy score loss.
                forecaster_means = pred_anchor_deltas

                # Compute forecaster cholesky. Takes care of diagonal case
                # automatically.
                forecaster_cholesky = covariance_output_to_cholesky(
                    pred_bbox_cov)

                # Define normal distribution samples. To compute the energy
                # score, we need num_samples + 1 samples.

                # Define per-anchor Distributions
                multivariate_normal_dists = distributions.multivariate_normal.MultivariateNormal(
                    forecaster_means, scale_tril=forecaster_cholesky)

                # Define Monte-Carlo Samples
                distributions_samples = multivariate_normal_dists.rsample(
                    (self.bbox_cov_num_samples + 1,))

                distributions_samples_1 = distributions_samples[
                    0:self.bbox_cov_num_samples, :, :]
                distributions_samples_2 = distributions_samples[
                    1:self.bbox_cov_num_samples + 1, :, :]

                # Compute energy score
                gt_anchors_deltas_samples = torch.repeat_interleave(
                    gt_anchors_deltas.unsqueeze(0), self.bbox_cov_num_samples, dim=0)

                energy_score_first_term = 2.0 * smooth_l1_loss(
                    distributions_samples_1,
                    gt_anchors_deltas_samples,
                    beta=self.smooth_l1_beta,
                    reduction="sum") / self.bbox_cov_num_samples  # First term

                energy_score_second_term = - smooth_l1_loss(
                    distributions_samples_1,
                    distributions_samples_2,
                    beta=self.smooth_l1_beta,
                    reduction="sum") / self.bbox_cov_num_samples   # Second term

                # Final Loss
                loss_box_reg = (
                    energy_score_first_term + energy_score_second_term) / max(1, self.loss_normalizer)

            else:
                raise ValueError(
                    'Invalid regression loss name {}.'.format(
                        self.bbox_cov_loss))

            # Perform loss annealing. Essential for reliably training variance estimates using NLL in RetinaNet.
            # For energy score and second moment matching, this is optional.
            standard_regression_loss = smooth_l1_loss(
                pred_anchor_deltas,
                gt_anchors_deltas,
                beta=self.smooth_l1_beta,
                reduction="sum",
            ) / max(1, self.loss_normalizer)

            probabilistic_loss_weight = get_probabilistic_loss_weight(
                self.current_step, self.annealing_step)
            loss_box_reg = (1.0 - probabilistic_loss_weight) * \
                standard_regression_loss + probabilistic_loss_weight * loss_box_reg
        else:
            # Standard regression loss in case no variance is needed to be
            # estimated.
            loss_box_reg = smooth_l1_loss(
                pred_anchor_deltas,
                gt_anchors_deltas,
                beta=self.smooth_l1_beta,
                reduction="sum",
            ) / max(1, self.loss_normalizer)

        return {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
Example #3
    def losses(self, init_gt_classes, init_reg_targets, refine_gt_classes, refine_reg_targets,
               pred_class_logits, pred_box_reg_init, pred_box_reg, pred_center_score, strides, pred_ratio):

        strides = strides.repeat(pred_class_logits[0].shape[0])  # [N*X]
        pred_class_logits, pred_box_reg_init, pred_box_reg, pred_center_score, pred_ratio = \
            permute_and_concat(pred_class_logits, pred_box_reg_init, pred_box_reg, pred_center_score, pred_ratio, self.num_classes)
        # Shapes: (N x R) and (N x R, 4), (N x R) respectively.

        init_gt_classes = init_gt_classes.flatten()
        init_reg_targets = init_reg_targets.view(-1, 4)

        init_foreground_idxs = (init_gt_classes >= 0) & (init_gt_classes != self.num_classes)
        init_pos_inds = torch.nonzero(init_foreground_idxs).squeeze(1)

        num_gpus = get_num_gpus()
        # sync num_pos from all gpus
        init_total_num_pos = reduce_sum(init_pos_inds.new_tensor([init_pos_inds.numel()])).item()
        init_num_pos_avg_per_gpu = max(init_total_num_pos / float(num_gpus), 1.0)

        refine_gt_classes = refine_gt_classes.flatten()
        refine_reg_targets = refine_reg_targets.view(-1, 4)

        refine_foreground_idxs = (refine_gt_classes >= 0) & (refine_gt_classes != self.num_classes)
        refine_pos_inds = torch.nonzero(refine_foreground_idxs).squeeze(1)

        # sync num_pos from all gpus
        refine_total_num_pos = reduce_sum(refine_pos_inds.new_tensor([refine_pos_inds.numel()])).item()
        refine_num_pos_avg_per_gpu = max(refine_total_num_pos / float(num_gpus), 1.0)

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[refine_foreground_idxs, refine_gt_classes[refine_foreground_idxs]] = 1

        # logits loss
        cls_loss = sigmoid_focal_loss_jit(
            pred_class_logits, gt_classes_target,
            alpha=self.focal_loss_alpha, gamma=self.focal_loss_gamma, reduction="sum",
        ) / refine_num_pos_avg_per_gpu
        
        init_foreground_targets = init_reg_targets[init_foreground_idxs]
        gt_ratio_1 = (init_foreground_targets[:,0] + init_foreground_targets[:,2]) \
            / (init_foreground_targets[:,1] + init_foreground_targets[:,3])
        gt_ratio_2 = 1 / gt_ratio_1
        gt_ratios = torch.stack((gt_ratio_1, gt_ratio_2), dim=1)
        gt_ratio = gt_ratios.min(dim=1)[0]
        gt_center_score = compute_centerness_targets(init_reg_targets[init_foreground_idxs], gt_ratio)
        
        # average sum_centerness_targets from all gpus,
        # which is used to normalize centerness-weighed reg loss
        sum_centerness_targets_avg_per_gpu = \
            reduce_sum(gt_center_score.sum()).item() / float(num_gpus)
        reg_loss_init = iou_loss(
            pred_box_reg_init[init_foreground_idxs], init_reg_targets[init_foreground_idxs], gt_center_score,
            loss_type=self.iou_loss_type
        ) / sum_centerness_targets_avg_per_gpu

        coords_norm_refine = strides[refine_foreground_idxs].unsqueeze(-1) * 4
        reg_loss = smooth_l1_loss(
            pred_box_reg[refine_foreground_idxs] / coords_norm_refine,
            refine_reg_targets[refine_foreground_idxs] / coords_norm_refine,
            0.11, reduction="sum") / max(1, refine_num_pos_avg_per_gpu)
        #        reg_loss = iou_loss(
        #            pred_box_reg[refine_foreground_idxs], refine_reg_targets[refine_foreground_idxs], 1,
        #            loss_type=self.iou_loss_type
        #        ) / sum_centerness_targets_avg_per_gpu
        centerness_loss = F.binary_cross_entropy_with_logits(
            torch.pow(torch.abs(pred_center_score[init_foreground_idxs]),
                      pred_ratio[init_foreground_idxs]),
            gt_center_score, reduction='sum'
        ) / init_num_pos_avg_per_gpu

        return dict(cls_loss=cls_loss, reg_loss_init=reg_loss_init, reg_loss=reg_loss, centerness_loss=centerness_loss)
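`reduce_sum` and `get_num_gpus` above are small distributed helpers from the surrounding codebase. A typical definition, sketched under the assumption of a detectron2-style setup:

import torch.distributed as dist
from detectron2.utils.comm import get_world_size

def get_num_gpus():
    return get_world_size()

def reduce_sum(tensor):
    # Sum a tensor over all distributed workers; a no-op for single-GPU runs.
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor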
Example #4
    def losses(self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes):
        """
        Args:
            anchors (list[Boxes]): a list of #feature level Boxes
            gt_labels, gt_boxes: see output of :meth:`RetinaNet.label_anchors`.
                Their shapes are (N, R) and (N, R, 4), respectively, where R is
                the total number of anchors across levels, i.e. sum(Hi x Wi x Ai)
            pred_logits, pred_anchor_deltas: both are list[Tensor]. Each element in the
                list corresponds to one level and has shape (N, Hi * Wi * Ai, K or 4).
                Where K is the number of classes used in `pred_logits`.

        Returns:
            dict[str, Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        num_images = len(gt_labels)
        gt_labels = torch.stack(gt_labels)  # (N, R)
        anchors = type(anchors[0]).cat(anchors).tensor  # (R, 4)
        gt_anchor_deltas = [self.box2box_transform.get_deltas(anchors, k) for k in gt_boxes]
        gt_anchor_deltas = torch.stack(gt_anchor_deltas)  # (N, R, 4)

        valid_mask = gt_labels >= 0
        pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
        num_pos_anchors = pos_mask.sum().item()
        get_event_storage().put_scalar("num_pos_anchors", num_pos_anchors / num_images)
        self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
            1 - self.loss_normalizer_momentum
        ) * max(num_pos_anchors, 1)

        # classification and regression loss
        gt_labels_target = F.one_hot(gt_labels[valid_mask], num_classes=self.num_classes + 1)[
            :, :-1
        ]  # no loss for the last (background) class
        loss_cls = sigmoid_focal_loss_jit(
            cat(pred_logits, dim=1)[valid_mask],
            gt_labels_target.to(pred_logits[0].dtype),
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        )

        if self.box_reg_loss_type == "smooth_l1":
            loss_box_reg = smooth_l1_loss(
                cat(pred_anchor_deltas, dim=1)[pos_mask],
                gt_anchor_deltas[pos_mask],
                beta=self.smooth_l1_beta,
                reduction="sum",
            )
        elif self.box_reg_loss_type == "giou":
            pred_boxes = [
                self.box2box_transform.apply_deltas(k, anchors)
                for k in cat(pred_anchor_deltas, dim=1)
            ]
            loss_box_reg = giou_loss(
                torch.stack(pred_boxes)[pos_mask], torch.stack(gt_boxes)[pos_mask], reduction="sum"
            )
        else:
            raise ValueError(f"Invalid bbox reg loss type '{self.box_reg_loss_type}'")

        return {
            "loss_cls": loss_cls / self.loss_normalizer,
            "loss_box_reg": loss_box_reg / self.loss_normalizer,
        }
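The `loss_normalizer` here is an exponential moving average of the per-batch positive-anchor count rather than the raw count, which smooths the loss scale across batches. A small worked example, assuming detectron2's RetinaNet defaults (momentum 0.9, initial value 100):

momentum, normalizer = 0.9, 100.0
for num_pos_anchors in [40, 60, 50]:
    normalizer = momentum * normalizer + (1 - momentum) * max(num_pos_anchors, 1)
    # 94.0, then 90.6, then 86.54: the normalizer drifts toward recent counts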
Example #5
    def forward(self, features, gt_instances=None):
        for i, f in enumerate(self.in_features):
            if i == 0:
                x = self.refine[i](features[f])
            else:
                x_p = self.refine[i](features[f])

                target_h, target_w = x.size()[2:]
                h, w = x_p.size()[2:]
                assert target_h % h == 0
                assert target_w % w == 0
                factor_h, factor_w = target_h // h, target_w // w
                assert factor_h == factor_w
                x_p = F.interpolate(x_p,
                                    scale_factor=factor_h,
                                    mode='bilinear',
                                    align_corners=True)
                # x_p = aligned_bilinear(x_p, factor_h)
                x = x + x_p

        mask_feats = self.tower(x)

        if self.num_outputs == 0:
            mask_feats = mask_feats[:, :self.num_outputs]

        losses = {}
        # auxiliary thing semantic loss
        if self.training and self.sem_loss_on:
            logits_pred = self.logits(
                self.seg_head(features[self.in_features[0]]))

            # compute semantic targets
            semantic_targets = []
            for per_im_gt in gt_instances:
                h, w = per_im_gt.gt_bitmasks_full.size()[-2:]
                areas = per_im_gt.gt_bitmasks_full.sum(dim=-1).sum(dim=-1)
                areas = areas[:, None, None].repeat(1, h, w)
                areas[per_im_gt.gt_bitmasks_full == 0] = INF
                areas = areas.permute(1, 2, 0).reshape(h * w, -1)
                min_areas, inds = areas.min(dim=1)
                per_im_semantic_targets = per_im_gt.gt_classes[inds] + 1
                per_im_semantic_targets[min_areas == INF] = 0
                per_im_semantic_targets = per_im_semantic_targets.reshape(h, w)
                semantic_targets.append(per_im_semantic_targets)

            semantic_targets = torch.stack(semantic_targets, dim=0)

            # resize target to reduce memory
            semantic_targets = semantic_targets[
                :, None,
                self.out_stride // 2::self.out_stride,
                self.out_stride // 2::self.out_stride]

            # prepare one-hot targets
            num_classes = logits_pred.size(1)
            class_range = torch.arange(
                num_classes, dtype=logits_pred.dtype,
                device=logits_pred.device)[:, None, None]
            class_range = class_range + 1
            one_hot = (semantic_targets == class_range).float()
            num_pos = (one_hot > 0).sum().float().clamp(min=1.0)

            loss_sem = sigmoid_focal_loss_jit(
                logits_pred,
                one_hot,
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
                reduction="sum",
            ) / num_pos
            losses['loss_sem'] = loss_sem

        return mask_feats, losses
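`aligned_bilinear` (referenced in the commented-out line above and used in later examples) is AdelaiDet's alignment-preserving upsampler. A sketch approximating the upstream helper:

import torch.nn.functional as F

def aligned_bilinear(tensor, factor):
    # Upsample an NCHW tensor by an integer factor while keeping pixel centers
    # aligned: replicate-pad, interpolate with align_corners=True, then crop.
    assert tensor.dim() == 4 and factor >= 1 and int(factor) == factor
    if factor == 1:
        return tensor
    h, w = tensor.size()[2:]
    tensor = F.pad(tensor, pad=(0, 1, 0, 1), mode="replicate")
    oh, ow = factor * h + 1, factor * w + 1
    tensor = F.interpolate(tensor, size=(oh, ow), mode="bilinear", align_corners=True)
    tensor = F.pad(tensor, pad=(factor // 2, 0, factor // 2, 0), mode="replicate")
    return tensor[:, :, :oh - 1, :ow - 1]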
Example #6
    def fcos_losses(self, instances):
        num_classes = instances.logits_pred.size(1)
        assert num_classes == self.num_classes

        labels = instances.labels.flatten()
        gt_object = instances.gt_inds

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        neg_inds = torch.nonzero(labels == num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(instances.logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            instances.logits_pred,
            class_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="none",
        )  #/ num_pos_avg

        positive_diff = (
            1 - instances.logits_pred[class_target == 1].sigmoid()).abs()
        negative_diff = (
            0 - instances.logits_pred[class_target == 0].sigmoid()).abs()

        positive_mean = positive_diff.mean().detach()
        positive_std = positive_diff.std().detach()

        negative_mean = negative_diff.mean().detach()
        negative_std = negative_diff.std().detach()

        upper_true_loss = class_loss.flatten()[(class_target == 1).flatten()][
            (positive_diff >
             (positive_mean + positive_std))].sum() / num_pos_avg
        under_true_loss = class_loss.flatten()[(class_target == 1).flatten()][
            (positive_diff <=
             (positive_mean + positive_std))].sum() / num_pos_avg
        upper_false_loss = class_loss.flatten()[(class_target == 0).flatten()][
            (negative_diff >
             (negative_mean + negative_std))].sum() / num_pos_avg
        under_false_loss = class_loss.flatten()[(class_target == 0).flatten()][
            (negative_diff <=
             (negative_mean + negative_std))].sum() / num_pos_avg

        storage = get_event_storage()
        if storage.iter % 20 == 0:
            logger.info(
                "upper_true {}, under_true {} upper_false {} under_false {}".
                format((positive_diff > positive_mean + positive_std).sum(),
                       (positive_diff <= positive_mean + positive_std).sum(),
                       (negative_diff > negative_mean + negative_std).sum(),
                       (negative_diff <= negative_mean + negative_std).sum()))

        instances = instances[pos_inds]
        instances.pos_inds = pos_inds

        #assert (instances.gt_inds.unique() != gt_object.unique()).sum() == 0

        ctrness_targets = compute_ctrness_targets(instances.reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        loss_denorm = max(
            reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)
        instances.gt_ctrs = ctrness_targets

        if pos_inds.numel() > 0:
            reg_loss = self.loc_loss_func(instances.reg_pred,
                                          instances.reg_targets,
                                          ctrness_targets) / loss_denorm

            ctrness_loss = torch.nn.MSELoss(reduction="sum")(
                instances.ctrness_pred.sigmoid(),
                ctrness_targets) / num_pos_avg
        else:
            reg_loss = instances.reg_pred.sum() * 0
            ctrness_loss = instances.ctrness_pred.sum() * 0

        losses = {
            "loss_upper_true_cls": upper_true_loss,
            "loss_under_true_cls": under_true_loss,
            "loss_upper_false_cls": upper_false_loss,
            "loss_under_false_cls": under_false_loss,
            "loss_fcos_loc": reg_loss,
            "loss_fcos_ctr": ctrness_loss,
            #"loss_negative_identity_mean": negative_identity_mean_loss,
            #"loss_negative_identity_std": negative_identity_std_loss,
            #"loss_positive_identity": positive_identity_loss,
        }
        extras = {"instances": instances, "loss_denorm": loss_denorm}
        return extras, losses
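`compute_ctrness_targets` converts the (l, t, r, b) regression targets into FCOS centerness scores. A sketch matching the usual FCOS/AdelaiDet definition:

import torch

def compute_ctrness_targets(reg_targets):
    # reg_targets: (num_pos, 4) distances (l, t, r, b) to the box sides.
    # centerness = sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))
    if len(reg_targets) == 0:
        return reg_targets.new_zeros(len(reg_targets))
    left_right = reg_targets[:, [0, 2]]
    top_bottom = reg_targets[:, [1, 3]]
    ctrness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
              (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
    return torch.sqrt(ctrness)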
Example #7
    def MEInst_losses(self, labels, reg_targets, logits_pred, reg_pred,
                      ctrness_pred, mask_pred, mask_targets):
        num_classes = logits_pred.size(1)
        labels = labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            logits_pred,
            class_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg

        reg_pred = reg_pred[pos_inds]
        reg_targets = reg_targets[pos_inds]
        ctrness_pred = ctrness_pred[pos_inds]
        mask_pred = mask_pred[pos_inds]
        assert mask_pred.shape[0] == mask_targets.shape[0], \
            "The number of positives should be equal between " \
            "mask_pred (prediction) and mask_targets (target)."

        ctrness_targets = compute_ctrness_targets(reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        ctrness_norm = max(
            reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

        reg_loss = self.iou_loss(reg_pred, reg_targets,
                                 ctrness_targets) / ctrness_norm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

        if self.loss_on_mask:
            # n_components predictions --> m*m mask predictions without sigmoid
            # as sigmoid function is combined in loss.
            mask_pred = self.mask_encoding.decoder(mask_pred, is_train=True)
            mask_loss = self.mask_loss_func(mask_pred, mask_targets)
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.mask_size**2,
                                              1.0)
        else:
            # m*m mask labels --> n_components encoding labels
            mask_targets = self.mask_encoding.encoder(mask_targets)
            if self.mask_loss_type == 'mse':
                mask_loss = self.mask_loss_func(mask_pred, mask_targets)
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.dim_mask,
                                                  1.0)
            else:
                raise NotImplementedError

        losses = {
            "loss_MEInst_cls": class_loss,
            "loss_MEInst_loc": reg_loss,
            "loss_MEInst_ctr": ctrness_loss,
            "loss_MEInst_mask": mask_loss,
        }
        return losses, {}
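`self.mask_encoding` in MEInst projects m*m binary masks to and from an n_components code with a linear (PCA-style) transform fit offline; it is not shown in this example. A rough sketch of the interface assumed above, with hypothetical field names:

import torch

class MaskEncoding:
    """Hypothetical PCA-style mask codec; `components` (dim_mask, mask_size**2)
    and `means` (mask_size**2,) are assumed to be fit offline."""
    def __init__(self, components, means):
        self.components = components
        self.means = means

    def encoder(self, masks):
        # (N, mask_size**2) binary masks -> (N, dim_mask) codes
        return (masks.float() - self.means) @ self.components.t()

    def decoder(self, codes, is_train=True):
        # (N, dim_mask) codes -> (N, mask_size**2) mask values
        return codes @ self.components + self.means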
Example #8
    def fcos_losses(self, instances):
        num_classes = instances.logits_pred.size(1)
        assert num_classes == self.num_classes

        labels = instances.labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)

        num_pos_local = torch.ones_like(pos_inds).sum()
        num_pos_avg = max(reduce_mean(num_pos_local).item(), 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(instances.logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(instances.logits_pred,
                                            class_target,
                                            alpha=self.focal_loss_alpha,
                                            gamma=self.focal_loss_gamma,
                                            reduction="sum")

        if self.loss_normalizer_cls == "moving_fg":
            self.moving_num_fg = self.moving_num_fg_momentum * self.moving_num_fg + (
                1 - self.moving_num_fg_momentum) * num_pos_avg
            class_loss = class_loss / self.moving_num_fg
        elif self.loss_normalizer_cls == "fg":
            class_loss = class_loss / num_pos_avg
        else:
            num_samples_local = torch.ones_like(labels).sum()
            num_samples_avg = max(reduce_mean(num_samples_local).item(), 1.0)
            class_loss = class_loss / num_samples_avg

        class_loss = class_loss * self.loss_weight_cls

        instances = instances[pos_inds]
        instances.pos_inds = pos_inds

        ctrness_targets = compute_ctrness_targets(instances.reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        loss_denorm = max(reduce_mean(ctrness_targets_sum).item(), 1e-6)
        instances.gt_ctrs = ctrness_targets

        if pos_inds.numel() > 0:
            reg_loss = self.loc_loss_func(instances.reg_pred,
                                          instances.reg_targets,
                                          ctrness_targets) / loss_denorm

            ctrness_loss = F.binary_cross_entropy_with_logits(
                instances.ctrness_pred, ctrness_targets,
                reduction="sum") / num_pos_avg
        else:
            reg_loss = instances.reg_pred.sum() * 0
            ctrness_loss = instances.ctrness_pred.sum() * 0

        losses = {
            "loss_fcos_cls": class_loss,
            "loss_fcos_loc": reg_loss,
            "loss_fcos_ctr": ctrness_loss
        }
        extras = {"instances": instances, "loss_denorm": loss_denorm}
        return extras, losses
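Unlike the `reduce_sum` used in other examples, this variant normalizes with a distributed mean. A typical `reduce_mean`, sketched as an assumption about the surrounding codebase:

import torch.distributed as dist
from detectron2.utils.comm import get_world_size

def reduce_mean(tensor):
    # Average a tensor over all distributed workers.
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor / world_size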
Example #9
def fcose_losses(
        labels,
        reg_targets,
        ext_targets,
        logits_pred,
        reg_pred,
        ext_pred,
        ctrness_pred,
        focal_loss_alpha,
        focal_loss_gamma,
        iou_loss,
        ext_loss
):
    num_classes = logits_pred.size(1)
    labels = labels.flatten()

    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1    # background stays all-zero; one binary classifier per class

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=focal_loss_alpha,
        gamma=focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ext_pred = ext_pred[pos_inds]
    ext_targets = ext_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    ctrness_norm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    ext_pt_loss = ext_loss(
        ext_pred,
        ext_targets,
        ctrness_targets
    ) / ctrness_norm

    reg_loss = iou_loss(
        reg_pred,
        reg_targets,
        ctrness_targets
    ) / ctrness_norm

    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred,
        ctrness_targets,
        reduction="sum"
    ) / num_pos_avg

    losses = {
        "loss_fcos_cls": class_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss,
        "loss_ext_pts": ext_pt_loss
    }
    return losses, {}
Example #10
    def forward(self, indices, gt_instances, anchors, pred_class_logits,
                pred_anchor_deltas):
        pred_class_logits = cat(pred_class_logits,
                                dim=1).view(-1, self.num_classes)
        pred_anchor_deltas = cat(pred_anchor_deltas, dim=1).view(-1, 4)

        anchors = [Boxes.cat(anchors_i) for anchors_i in anchors]
        N = len(anchors)
        # list[Tensor(R, 4)], one for each image
        all_anchors = Boxes.cat(anchors).tensor
        # Boxes(Tensor(N*R, 4))
        predicted_boxes = self.box2box_transform.apply_deltas(
            pred_anchor_deltas, all_anchors)
        predicted_boxes = predicted_boxes.reshape(N, -1, 4)

        ious = []
        pos_ious = []
        for i in range(N):
            src_idx, tgt_idx = indices[i]
            iou = box_iou(predicted_boxes[i, ...],
                          gt_instances[i].gt_boxes.tensor)
            if iou.numel() == 0:
                max_iou = iou.new_full((iou.size(0), ), 0)
            else:
                max_iou = iou.max(dim=1)[0]
            a_iou = box_iou(anchors[i].tensor, gt_instances[i].gt_boxes.tensor)
            if a_iou.numel() == 0:
                pos_iou = a_iou.new_full((0, ), 0)
            else:
                pos_iou = a_iou[src_idx, tgt_idx]
            ious.append(max_iou)
            pos_ious.append(pos_iou)
        ious = torch.cat(ious)
        ignore_idx = ious > self.neg_ignore_thresh
        pos_ious = torch.cat(pos_ious)
        pos_ignore_idx = pos_ious < self.pos_ignore_thresh

        src_idx = torch.cat([
            src + idx * anchors[0].tensor.shape[0]
            for idx, (src, _) in enumerate(indices)
        ])
        gt_classes = torch.full(pred_class_logits.shape[:1],
                                self.num_classes,
                                dtype=torch.int64,
                                device=pred_class_logits.device)
        gt_classes[ignore_idx] = -1
        target_classes_o = torch.cat(
            [t.gt_classes[J] for t, (_, J) in zip(gt_instances, indices)])
        target_classes_o[pos_ignore_idx] = -1
        gt_classes[src_idx] = target_classes_o

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        if comm.get_world_size() > 1:
            dist.all_reduce(num_foreground)
        num_foreground = num_foreground * 1.0 / comm.get_world_size()

        # cls loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        )
        # reg loss
        target_boxes = torch.cat(
            [t.gt_boxes.tensor[i] for t, (_, i) in zip(gt_instances, indices)],
            dim=0)
        target_boxes = target_boxes[~pos_ignore_idx]
        matched_predicted_boxes = predicted_boxes.reshape(
            -1, 4)[src_idx[~pos_ignore_idx]]
        loss_box_reg = giou_loss(matched_predicted_boxes,
                                 target_boxes,
                                 reduction="sum")

        return {
            "loss_cls": loss_cls / max(1, num_foreground),
            "loss_box_reg": loss_box_reg / max(1, num_foreground),
        }
Example #11
    def losses(self, indices, gt_instances, anchors, pred_class_logits,
               pred_anchor_deltas):
        pred_class_logits = cat(pred_class_logits,
                                dim=1).view(-1, self.num_classes)
        pred_anchor_deltas = cat(pred_anchor_deltas, dim=1).view(-1, 4)

        anchors = [Boxes.cat(anchors_i) for anchors_i in anchors]
        N = len(anchors)
        # list[Tensor(R, 4)], one for each image
        all_anchors = Boxes.cat(anchors).tensor
        # Boxes(Tensor(N*R, 4))
        predicted_boxes = self.box2box_transform.apply_deltas(
            pred_anchor_deltas, all_anchors)
        predicted_boxes = predicted_boxes.reshape(N, -1, 4)

        # We obtain positive anchors by choosing gt boxes' k nearest anchors
        # and leave the rest to be negative anchors. However, there may
        # exist negative anchors that have similar distances with the chosen
        # positives. These negatives may cause ambiguity for model training
        # if we just set them as negatives. Given that we want the model's
        # predict boxes on negative anchors to have low IoU with gt boxes,
        # we set a threshold on the IoU between predicted boxes and gt boxes
        # instead of the IoU between anchor boxes and gt boxes.
        ious = []
        pos_ious = []
        for i in range(N):
            src_idx, tgt_idx = indices[i]
            iou = box_iou(predicted_boxes[i, ...],
                          gt_instances[i].gt_boxes.tensor)
            if iou.numel() == 0:
                max_iou = iou.new_full((iou.size(0), ), 0)
            else:
                max_iou = iou.max(dim=1)[0]
            a_iou = box_iou(anchors[i].tensor, gt_instances[i].gt_boxes.tensor)
            if a_iou.numel() == 0:
                pos_iou = a_iou.new_full((0, ), 0)
            else:
                pos_iou = a_iou[src_idx, tgt_idx]
            ious.append(max_iou)
            pos_ious.append(pos_iou)
        ious = torch.cat(ious)
        ignore_idx = ious > self.neg_ignore_thresh
        pos_ious = torch.cat(pos_ious)
        pos_ignore_idx = pos_ious < self.pos_ignore_thresh

        src_idx = torch.cat([
            src + idx * anchors[0].tensor.shape[0]
            for idx, (src, _) in enumerate(indices)
        ])
        gt_classes = torch.full(pred_class_logits.shape[:1],
                                self.num_classes,
                                dtype=torch.int64,
                                device=pred_class_logits.device)
        gt_classes[ignore_idx] = -1
        target_classes_o = torch.cat(
            [t.gt_classes[J] for t, (_, J) in zip(gt_instances, indices)])
        target_classes_o[pos_ignore_idx] = -1
        gt_classes[src_idx] = target_classes_o

        valid_idxs = gt_classes >= 0
        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        num_foreground = foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(pred_class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        if comm.get_world_size() > 1:
            dist.all_reduce(num_foreground)
        num_foreground = num_foreground * 1.0 / comm.get_world_size()

        # cls loss
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_idxs],
            gt_classes_target[valid_idxs],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        )
        # reg loss
        target_boxes = torch.cat(
            [t.gt_boxes.tensor[i] for t, (_, i) in zip(gt_instances, indices)],
            dim=0)
        target_boxes = target_boxes[~pos_ignore_idx]
        matched_predicted_boxes = predicted_boxes.reshape(
            -1, 4)[src_idx[~pos_ignore_idx]]
        loss_box_reg = giou_loss(matched_predicted_boxes,
                                 target_boxes,
                                 reduction="sum")

        return {
            "loss_cls": loss_cls / max(1, num_foreground),
            "loss_box_reg": loss_box_reg / max(1, num_foreground),
        }
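`box_iou` here computes the pairwise IoU matrix between two sets of boxes. A sketch equivalent to torchvision.ops.box_iou, for boxes in (x1, y1, x2, y2) form:

import torch

def box_iou(boxes1, boxes2):
    # Returns an (N, M) IoU matrix between N boxes1 and M boxes2.
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # (N, M, 2) top-left
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # (N, M, 2) bottom-right
    wh = (rb - lt).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    return inter / (area1[:, None] + area2 - inter)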
Example #12
    def losses(self, center_pts, cls_outs, pts_outs_init, pts_outs_refine,
               targets):
        """
        Args:
            center_pts: (list[list[Tensor]]): a list of N=#image elements. Each
                is a list of #feature level tensors. The tensors contains
                shifts of this image on the specific feature level.
            cls_outs: List[Tensor], each item in list with
                shape:[N, num_classes, H, W]
            pts_outs_init: List[Tensor], each item in list with
                shape:[N, num_points*2, H, W]
            pts_outs_refine: List[Tensor], each item in list with
                shape:[N, num_points*2, H, W]
            targets: (list[Instances]): a list of N `Instances`s. The i-th
                `Instances` contains the ground-truth per-instance annotations
                for the i-th input image.
                Specify `targets` during training only.
        Returns:
            dict[str:Tensor]:
                mapping from a named loss to scalar tensor
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_outs]
        assert len(featmap_sizes) == len(center_pts[0])

        pts_dim = 2 * self.num_points

        cls_outs = [
            cls_out.permute(0, 2, 3, 1).reshape(cls_out.size(0), -1,
                                                self.num_classes)
            for cls_out in cls_outs
        ]
        pts_outs_init = [
            pts_out_init.permute(0, 2, 3, 1).reshape(pts_out_init.size(0), -1,
                                                     pts_dim)
            for pts_out_init in pts_outs_init
        ]
        pts_outs_refine = [
            pts_out_refine.permute(0, 2, 3, 1).reshape(pts_out_refine.size(0),
                                                       -1, pts_dim)
            for pts_out_refine in pts_outs_refine
        ]

        cls_outs = torch.cat(cls_outs, dim=1)
        pts_outs_init = torch.cat(pts_outs_init, dim=1)
        pts_outs_refine = torch.cat(pts_outs_refine, dim=1)

        pts_strides = []
        for i, s in enumerate(center_pts[0]):
            pts_strides.append(
                cls_outs.new_full((s.size(0), ), self.fpn_strides[i]))
        pts_strides = torch.cat(pts_strides, dim=0)

        center_pts = [
            torch.cat(c_pts, dim=0).to(cls_outs.device) for c_pts in center_pts
        ]

        pred_cls = []
        pred_init = []
        pred_refine = []

        target_cls = []
        target_init = []
        target_refine = []

        num_pos_init = 0
        num_pos_refine = 0

        for img, (per_center_pts, cls_prob, pts_init, pts_refine,
                  per_targets) in enumerate(
                      zip(center_pts, cls_outs, pts_outs_init, pts_outs_refine,
                          targets)):
            assert per_center_pts.shape[:-1] == cls_prob.shape[:-1]

            gt_bboxes = per_targets.gt_boxes.to(cls_prob.device)
            gt_labels = per_targets.gt_classes.to(cls_prob.device)

            pts_init_bbox_targets, pts_init_labels_targets = \
                self.point_targets(per_center_pts, pts_strides, gt_bboxes.tensor, gt_labels)

            # per_center_pts, shape:[N, 18]
            per_center_pts_repeat = per_center_pts.repeat(1, self.num_points)

            normalize_term = self.point_base_scale * pts_strides
            normalize_term = normalize_term.reshape(-1, 1)

            # bbox_center = torch.cat([per_center_pts, per_center_pts], dim=1)
            per_pts_strides = pts_strides.reshape(-1, 1)
            pts_init_coordinate = pts_init * per_pts_strides + \
                                  per_center_pts_repeat
            init_bbox_pred = self.pts_to_bbox(pts_init_coordinate)

            foreground_idxs = (pts_init_labels_targets >= 0) & \
                              (pts_init_labels_targets != self.num_classes)

            pred_init.append(init_bbox_pred[foreground_idxs] /
                             normalize_term[foreground_idxs])
            target_init.append(pts_init_bbox_targets[foreground_idxs] /
                               normalize_term[foreground_idxs])
            num_pos_init += foreground_idxs.sum()

            # An alternative way to convert the predicted offsets to a bbox:
            # bbox_pred_init = self.pts_to_bbox(pts_init.detach()) * \
            #     per_pts_strides
            # init_bbox_pred = bbox_center + bbox_pred_init

            pts_refine_bbox_targets, pts_refine_labels_targets = \
                self.bbox_targets(init_bbox_pred, gt_bboxes, gt_labels)

            pts_refine_coordinate = pts_refine * per_pts_strides + per_center_pts_repeat
            refine_bbox_pred = self.pts_to_bbox(pts_refine_coordinate)

            # bbox_pred_refine = self.pts_to_bbox(pts_refine) * per_pts_strides
            # refine_bbox_pred = bbox_center + bbox_pred_refine

            foreground_idxs = (pts_refine_labels_targets >= 0) & \
                              (pts_refine_labels_targets != self.num_classes)

            pred_refine.append(refine_bbox_pred[foreground_idxs] /
                               normalize_term[foreground_idxs])
            target_refine.append(pts_refine_bbox_targets[foreground_idxs] /
                                 normalize_term[foreground_idxs])
            num_pos_refine += foreground_idxs.sum()

            gt_classes_target = torch.zeros_like(cls_prob)
            gt_classes_target[foreground_idxs,
                              pts_refine_labels_targets[foreground_idxs]] = 1
            pred_cls.append(cls_prob)
            target_cls.append(gt_classes_target)

        pred_cls = torch.cat(pred_cls, dim=0)
        pred_init = torch.cat(pred_init, dim=0)
        pred_refine = torch.cat(pred_refine, dim=0)

        target_cls = torch.cat(target_cls, dim=0)
        target_init = torch.cat(target_init, dim=0)
        target_refine = torch.cat(target_refine, dim=0)

        loss_cls = sigmoid_focal_loss_jit(
            pred_cls,
            target_cls,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum") / max(
                1, num_pos_refine.item()) * self.loss_cls_weight

        loss_pts_init = smooth_l1_loss(
            pred_init, target_init, beta=0.11, reduction='sum') / max(
                1, num_pos_init.item()) * self.loss_loc_init_weight

        loss_pts_refine = smooth_l1_loss(
            pred_refine, target_refine, beta=0.11, reduction='sum') / max(
                1, num_pos_refine.item()) * self.loss_loc_refine_weight

        return {
            "loss_cls": loss_cls,
            "loss_pts_init": loss_pts_init,
            "loss_pts_refine": loss_pts_refine
        }
Example #13
    def fcos_losses(
        self,
        labels,
        reg_targets,
        logits_pred,
        reg_pred,
        ctrness_pred,
        controllers_pred,
        focal_loss_alpha,
        focal_loss_gamma,
        iou_loss,
        matched_idxes,
        im_idxes,
        locations,
    ):
        num_classes = logits_pred.size(1)
        labels = labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = (sigmoid_focal_loss_jit(
            logits_pred,
            class_target,
            alpha=focal_loss_alpha,
            gamma=focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg)

        reg_pred = reg_pred[pos_inds]
        reg_targets = reg_targets[pos_inds]
        ctrness_pred = ctrness_pred[pos_inds]
        controllers_pred = controllers_pred[pos_inds]
        matched_idxes = matched_idxes[pos_inds]
        im_idxes = im_idxes[pos_inds]
        locations = locations[pos_inds]

        ctrness_targets = compute_ctrness_targets(reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        ctrness_norm = max(
            reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

        reg_loss = iou_loss(reg_pred, reg_targets,
                            ctrness_targets) / ctrness_norm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

        # for CondInst
        batch_ins = pos_inds.shape[0]
        N, C, h, w = self.masks.shape
        center_x = torch.clamp(locations[:, 0], min=0, max=w - 1).long()
        center_y = torch.clamp(locations[:, 1], min=0, max=h - 1).long()
        x_range = torch.linspace(-1, 1, w, device=self.masks.device)
        y_range = torch.linspace(-1, 1, h, device=self.masks.device)
        y, x = torch.meshgrid(y_range, x_range)
        x = x.unsqueeze(0).unsqueeze(0)
        y = y.unsqueeze(0).unsqueeze(0)
        grid = torch.cat([x, y], 1)
        offset_x = x_range[center_x].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        offset_y = y_range[center_y].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        offset_xy = torch.cat([offset_x, offset_y], 1)
        coords_feat = grid - offset_xy
        masks_feat = self.masks
        r_h = int(h * self.strides[0])
        r_w = int(w * self.strides[0])
        targets_masks = [
            target_im.gt_masks.tensor for target_im in self.gt_instances
        ]
        masks_t = self.prepare_masks(h, w, r_h, r_w, targets_masks)
        mask_loss = masks_feat[0].new_tensor(0.0)
        batch_ins = im_idxes.shape[0]
        # for each image
        for i in range(N):
            inds = (im_idxes == i).nonzero().flatten()
            ins_num = inds.shape[0]
            if ins_num > 0:
                controllers = controllers_pred[inds]
                coord_feat = coords_feat[inds]
                mask_feat = masks_feat[None, i]
                mask_feat = torch.cat([mask_feat] * ins_num, dim=0)
                comb_feat = torch.cat((mask_feat, coord_feat),
                                      dim=1).view(1, -1, h, w)
                weight1, bias1, weight2, bias2, weight3, bias3 = torch.split(
                    controllers, [80, 8, 64, 8, 8, 1], dim=1)
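                # Split sizes: conv1 maps 10 channels (8 mask channels + 2
                # relative-coordinate channels) to 8 (80 weights + 8 biases),
                # conv2 maps 8 to 8 (64 + 8), and conv3 maps 8 to 1 (8 + 1),
                # all applied below as grouped 1x1 convolutions per instance.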
                bias1, bias2, bias3 = \
                    bias1.flatten(), bias2.flatten(), bias3.flatten()
                weight1 = weight1.reshape(-1, 10).unsqueeze(-1).unsqueeze(-1)
                weight2 = weight2.reshape(-1, 8).unsqueeze(-1).unsqueeze(-1)
                weight3 = weight3.unsqueeze(-1).unsqueeze(-1)
                conv1 = F.conv2d(comb_feat, weight1, bias1,
                                 groups=ins_num).relu()
                conv2 = F.conv2d(conv1, weight2, bias2, groups=ins_num).relu()
                masks_per_image = F.conv2d(conv2,
                                           weight3,
                                           bias3,
                                           groups=ins_num)
                masks_per_image = aligned_bilinear(
                    masks_per_image, self.strides[0])[0].sigmoid()
                for j in range(ins_num):
                    ind = inds[j]
                    mask_gt = masks_t[i][matched_idxes[ind]].float()
                    mask_pred = masks_per_image[j]
                    mask_loss += self.dice_loss(mask_pred, mask_gt)

        if batch_ins > 0:
            mask_loss = mask_loss / batch_ins

        losses = {
            "loss_fcos_cls": class_loss,
            "loss_fcos_loc": reg_loss,
            "loss_fcos_ctr": ctrness_loss,
            "loss_mask": mask_loss,
        }
        return losses, {}
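`self.dice_loss` above is not shown. One common formulation for a sigmoided mask prediction against a binary target (CondInst-style, with squared terms in the denominator), as a sketch:

def dice_loss(pred, target, eps=1e-5):
    # 1 - Dice coefficient between a soft prediction and a binary target.
    inter = (pred * target).sum()
    union = (pred * pred).sum() + (target * target).sum()
    return 1.0 - (2.0 * inter) / (union + eps)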
Example #14
def FCOSLosses(cls_scores, bbox_preds, centernesses, labels, bbox_targets,
               reg_loss, cfg):
    """
    Arguments:
        cls_scores, bbox_preds, centernesses: Same as the output of :meth:`FCOSHead.forward`
        labels, bbox_targets: Same as the output of :func:`FCOSTargets`

    Returns:
        losses (dict[str: Tensor]): A dict mapping from loss name to loss value.
    """
    # fmt: off
    num_classes = cfg.MODEL.FCOS.NUM_CLASSES
    focal_loss_alpha = cfg.MODEL.FCOS.LOSS_ALPHA
    focal_loss_gamma = cfg.MODEL.FCOS.LOSS_GAMMA
    # fmt: on

    # Collect all logits and regression predictions over feature maps
    # and images to arrive at the same shape as the labels and targets
    # The final ordering is L, N, H, W from slowest to fastest axis.
    flatten_cls_scores = cat(
        [
            # Reshape: (N, C, Hi, Wi) -> (N, Hi, Wi, C) -> (N*Hi*Wi, C)
            cls_score.permute(0, 2, 3, 1).reshape(-1, num_classes)
            for cls_score in cls_scores
        ],
        dim=0)

    flatten_bbox_preds = cat(
        [
            # Reshape: (N, 4, Hi, Wi) -> (N, Hi, Wi, 4) -> (N*Hi*Wi, 4)
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ],
        dim=0)
    flatten_centernesses = cat(
        [
            # Reshape: (N, 1, Hi, Wi) -> (N*Hi*Wi,)
            centerness.reshape(-1) for centerness in centernesses
        ],
        dim=0)

    # flatten classification and regression targets.
    flatten_labels = cat(labels)
    flatten_bbox_targets = cat(bbox_targets)

    # retain indices of positive predictions.
    pos_inds = torch.nonzero(flatten_labels != num_classes).squeeze(1)
    num_pos = max(len(pos_inds), 1.0)

    # prepare one_hot label.
    class_target = torch.zeros_like(flatten_cls_scores)
    class_target[pos_inds, flatten_labels[pos_inds]] = 1

    # classification loss: Focal loss
    loss_cls = sigmoid_focal_loss_jit(
        flatten_cls_scores,
        class_target,
        alpha=focal_loss_alpha,
        gamma=focal_loss_gamma,
        reduction="sum",
    ) / num_pos

    # compute regression loss and centerness loss only for positive samples.
    pos_bbox_preds = flatten_bbox_preds[pos_inds]
    pos_centernesses = flatten_centernesses[pos_inds]
    pos_bbox_targets = flatten_bbox_targets[pos_inds]

    # compute centerness targets.
    pos_centerness_targets = compute_centerness_targets(pos_bbox_targets)
    centerness_norm = max(pos_centerness_targets.sum(), 1e-6)

    # regression loss: IoU loss
    loss_bbox = reg_loss(pos_bbox_preds,
                         pos_bbox_targets,
                         weight=pos_centerness_targets) / centerness_norm

    # centerness loss: Binary CrossEntropy loss
    loss_centerness = F.binary_cross_entropy_with_logits(
        pos_centernesses, pos_centerness_targets, reduction="sum") / num_pos

    # final loss dict.
    losses = dict(loss_fcos_cls=loss_cls,
                  loss_fcos_loc=loss_bbox,
                  loss_fcos_ctr=loss_centerness)
    return losses
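`compute_centerness_targets` is assumed here to implement the FCOS centerness target, sqrt(min(l, r)/max(l, r) * min(t, b)/max(t, b)), computed from the (l, t, r, b) regression targets. A sketch under that assumption:

import torch

def compute_centerness_targets(reg_targets):
    # reg_targets: (num_pos, 4) distances (l, t, r, b) to the four box sides.
    left_right = reg_targets[:, [0, 2]]
    top_bottom = reg_targets[:, [1, 3]]
    centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
                 (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
    return torch.sqrt(centerness)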
Beispiel #15
0
    def fcos_losses(
        self,
        labels,
        reg_targets,
        logits_pred,
        reg_pred,
        ctrness_pred,
        coeffs_pred,
        protos,
        focal_loss_alpha,
        focal_loss_gamma,
        iou_loss,
        matched_idxes,
        im_idxes
    ):
        num_classes = logits_pred.size(1)
        labels = labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            logits_pred,
            class_target,
            alpha=focal_loss_alpha,
            gamma=focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg

        reg_pred = reg_pred[pos_inds]
        reg_targets = reg_targets[pos_inds]
        ctrness_pred = ctrness_pred[pos_inds]

        ctrness_targets = compute_ctrness_targets(reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        ctrness_norm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)
        
        reg_loss = iou_loss(
            reg_pred,
            reg_targets,
            ctrness_targets
        ) / ctrness_norm
        
        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred,
            ctrness_targets,
            reduction="sum"
        ) / num_pos_avg
        
        # for yolact
        coeffs_pred = coeffs_pred[pos_inds]
        matched_idxes = matched_idxes[pos_inds]
        im_idxes = im_idxes[pos_inds]

        N, _, m_h, m_w = protos.shape
        r_h = int(m_h * self.strides[0])
        r_w = int(m_w * self.strides[0])
        targets_masks = [target_im.gt_masks.tensor for target_im in self.gt_instances]
        masks_t = self.prepare_masks(m_h, m_w, r_h, r_w, targets_masks)
        num_ins = coeffs_pred.shape[0]
        mask_loss = coeffs_pred[0].new_tensor(0.0)
        for i in range(num_ins):
            im_id = im_idxes[i]
            mask_pred = torch.sigmoid((protos[im_id]*coeffs_pred[i].view(self.num_protos,1,1)).sum(dim=0))
            mask_gt = masks_t[im_id][matched_idxes[i]].float()
            mask_loss += self.dice_loss(mask_pred, mask_gt)
        
        if num_ins > 0:
            mask_loss = mask_loss/num_ins

        losses = {
            "loss_fcos_cls": class_loss,
            "loss_fcos_loc": reg_loss,
            "loss_fcos_ctr": ctrness_loss,
            "loss_mask": mask_loss
        }
        return losses, {}
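The per-instance loop above builds each mask as a coefficient-weighted sum of the prototype maps, i.e. the YOLACT linear-combination design. An equivalent vectorized form for all instances of one image, as a sketch (it assumes `coeffs` has shape (num_ins, num_protos) and `protos[im_id]` has shape (num_protos, H, W)):

# one sigmoid mask per instance, shape (num_ins, H, W)
masks = torch.sigmoid(torch.einsum('np,phw->nhw', coeffs, protos[im_id]))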
Beispiel #16
0
    def fcos_losses(self, labels, reg_targets, logits_pred, reg_pred,
                    ctrness_pred, controllers_pred, focal_loss_alpha,
                    focal_loss_gamma, iou_loss, matched_idxes, im_idxes):
        num_classes = logits_pred.size(1)
        labels = labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            logits_pred,
            class_target,
            alpha=focal_loss_alpha,
            gamma=focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg

        reg_pred = reg_pred[pos_inds]
        reg_targets = reg_targets[pos_inds]
        ctrness_pred = ctrness_pred[pos_inds]
        controllers_pred = controllers_pred[pos_inds]
        matched_idxes = matched_idxes[pos_inds]
        im_idxes = im_idxes[pos_inds]

        ctrness_targets = compute_ctrness_targets(reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        ctrness_norm = max(
            reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

        reg_loss = iou_loss(reg_pred, reg_targets,
                            ctrness_targets) / ctrness_norm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

        # for CondInst
        N, C, h, w = self.masks.shape
        grid_x = torch.arange(w).view(1, -1).float().repeat(
            h, 1).cuda() / (w - 1) * 2 - 1
        grid_y = torch.arange(h).view(-1, 1).float().repeat(
            1, w).cuda() / (h - 1) * 2 - 1
        x_map = grid_x.view(1, 1, h, w).repeat(N, 1, 1, 1)
        y_map = grid_y.view(1, 1, h, w).repeat(N, 1, 1, 1)
        masks_feat = torch.cat((self.masks, x_map, y_map), dim=1)
        r_h = int(h * self.strides[0])
        r_w = int(w * self.strides[0])

        # seg head (the CondInst mask branch below is commented out, so the
        # mask loss is a constant zero here)
        mask_loss = 0
        '''
        targets_masks = [target_im.gt_masks.tensor for target_im in self.gt_instances]
        masks_t = self.prepare_masks(h, w, r_h, r_w, targets_masks)
        mask_loss = masks_feat[0].new_tensor(0.0)
        batch_ins = im_idxes.shape[0] 
        # for each image
        for i in range(N):
            inds = (im_idxes==i).nonzero().flatten()
            ins_num = inds.shape[0]
            if ins_num > 0:
                controllers = controllers_pred[inds]
                mask_feat = masks_feat[None, i]
                weights1 = controllers[:, :80].reshape(-1,8,10).reshape(-1,10).unsqueeze(-1).unsqueeze(-1)
                bias1 = controllers[:, 80:88].flatten()            
                weights2 = controllers[:, 88:152].reshape(-1,8,8).reshape(-1,8).unsqueeze(-1).unsqueeze(-1)
                bias2 = controllers[:, 152:160].flatten()
                weights3 = controllers[:, 160:168].unsqueeze(-1).unsqueeze(-1)
                bias3 = controllers[:,168:169].flatten()
                conv1 = F.conv2d(mask_feat,weights1,bias1).relu()
                conv2 = F.conv2d(conv1, weights2, bias2, groups = ins_num).relu()
                #masks_per_image = F.conv2d(conv2, weights3, bias3, groups = ins_num)[0].sigmoid()
                masks_per_image = F.conv2d(conv2, weights3, bias3, groups = ins_num) 
                masks_per_image = aligned_bilinear(masks_per_image, self.strides[0])[0].sigmoid()         
                for j in range(ins_num):
                    ind = inds[j]
                    mask_gt = masks_t[i][matched_idxes[ind]].float()
                    mask_pred = masks_per_image[j]
                    mask_loss += self.dice_loss(mask_pred, mask_gt)
            
        if batch_ins > 0:
            mask_loss = mask_loss / batch_ins
        '''

        losses = {
            "loss_fcos_cls": class_loss,
            "loss_fcos_loc": reg_loss,
            "loss_fcos_ctr": ctrness_loss,
            "loss_mask": mask_loss
        }
        return losses, {}
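These CondInst-style heads rely on `F.conv2d` with `groups=ins_num`: per-instance feature copies are stacked along the channel axis of a single (1, ins_num * C, H, W) tensor, and the grouped convolution applies each instance's predicted filters only to its own channel slice. A self-contained sketch of that trick (shapes are illustrative, not the snippet's configuration):

import torch
import torch.nn.functional as F

ins_num, c_in, c_out, h, w = 3, 10, 8, 56, 56
feats = torch.randn(1, ins_num * c_in, h, w)        # per-instance features, stacked on channels
weights = torch.randn(ins_num * c_out, c_in, 1, 1)  # each instance contributes c_out 1x1 filters
bias = torch.randn(ins_num * c_out)
out = F.conv2d(feats, weights, bias, groups=ins_num)
assert out.shape == (1, ins_num * c_out, h, w)      # instance i's output sits in its channel slice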
Beispiel #17
0
    def losses(self, anchors, pred_logits, pred_boxes_init, pred_anchor_deltas,
               gt_instances, point_centers, strides):
        """
        Args:
            anchors (list[Boxes]): a list of #feature level Boxes
            gt_instances (list[Instances]): ground-truth instances. Per-anchor
                labels and boxes are derived from them via `self.label_anchors`;
                their shapes are (N, R) and (N, R, 4), where R is the total
                number of anchors across levels, i.e. sum(Hi x Wi x Ai)
            pred_logits, pred_anchor_deltas: both are list[Tensor]. Each element in the
                list corresponds to one level and has shape (N, Hi * Wi * Ai, K or 4),
                where K is the number of classes used in `pred_logits`.
            pred_boxes_init, point_centers, strides: initial box predictions and the
                per-position centers and strides used for the init localization loss.

        Returns:
            dict[str, Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        gt_labels, gt_boxes = self.label_anchors(anchors, gt_instances)
        gt_labels_init, gt_boxes_init = self.get_ground_truth(
            point_centers, strides, gt_instances)

        # Transpose the Hi*Wi*A dimension to the middle:
        pred_logits = [
            permute_to_N_HWA_K(x, self.num_classes) for x in pred_logits
        ]
        pred_anchor_deltas = [
            permute_to_N_HWA_K(x, 4) for x in pred_anchor_deltas
        ]

        num_images = len(gt_labels)
        gt_labels = torch.stack(gt_labels)  # (N, R)
        anchors = type(anchors[0]).cat(anchors).tensor  # (R, 4)
        gt_anchor_deltas = [
            self.box2box_transform.get_deltas(anchors, k) for k in gt_boxes
        ]
        gt_anchor_deltas = torch.stack(gt_anchor_deltas)  # (N, R, 4)

        valid_mask = gt_labels >= 0
        pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
        num_pos_anchors = pos_mask.sum().item()
        get_event_storage().put_scalar("num_pos_anchors",
                                       num_pos_anchors / num_images)
        self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
            1 - self.loss_normalizer_momentum) * max(num_pos_anchors, 1)

        # classification and regression loss
        gt_labels_target = F.one_hot(gt_labels[valid_mask],
                                     num_classes=self.num_classes + 1)[:, :-1]
        # no loss for the last (background) class
        loss_cls = sigmoid_focal_loss_jit(
            cat(pred_logits, dim=1)[valid_mask],
            gt_labels_target.to(pred_logits[0].dtype),
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) * self.loss_cls_weight

        init_foreground_idxs = gt_labels_init > 0
        strides = strides[None].repeat(pred_logits[0].shape[0], 1)
        coords_norm_init = strides[init_foreground_idxs].unsqueeze(-1) * 4
        loss_loc_init = smooth_l1_loss(
            pred_boxes_init[init_foreground_idxs] / coords_norm_init,
            gt_boxes_init[init_foreground_idxs] / coords_norm_init,
            beta=0.11,
            reduction="sum",
        ) / max(init_foreground_idxs.sum(), 1)
        if self.box_reg_loss_type == "smooth_l1":
            loss_loc_refine = smooth_l1_loss(
                cat(pred_anchor_deltas, dim=1)[pos_mask],
                gt_anchor_deltas[pos_mask],
                beta=0.11,
                reduction="sum",
            )
        elif self.box_reg_loss_type == "giou":
            pred_boxes = [
                self.box2box_transform.apply_deltas(k, anchors)
                for k in cat(pred_anchor_deltas, dim=1)
            ]
            loss_loc_refine = giou_loss(torch.stack(pred_boxes)[pos_mask],
                                        torch.stack(gt_boxes)[pos_mask],
                                        reduction="sum")
        else:
            raise ValueError(
                f"Invalid bbox reg loss type '{self.box_reg_loss_type}'")

        return {
            "loss_cls":
            loss_cls / self.loss_normalizer,
            "loss_loc_init":
            loss_loc_init * self.loss_loc_init_weight,
            "loss_loc_refine":
            loss_loc_refine / self.loss_normalizer *
            self.loss_loc_refine_weight,
        }
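`self.loss_normalizer` above is an exponential moving average of the per-batch positive count: dividing by a smoothed count keeps the loss scale stable when the number of positive anchors fluctuates between batches. A small sketch of the update rule (the momentum and counts are illustrative):

momentum = 0.9
loss_normalizer = 100.0  # running estimate of positives per batch
for num_pos_anchors in (120, 4, 95):
    loss_normalizer = momentum * loss_normalizer + (1 - momentum) * max(num_pos_anchors, 1)
# losses are divided by this smoothed count rather than the raw per-batch count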
Beispiel #18
0
def fcos_losses(
        labels,
        reg_targets,
        logits_pred,
        reg_pred,
        ctrness_pred,
        focal_loss_alpha,
        focal_loss_gamma,
        iou_loss,
        gt_inds,
):
    num_classes = logits_pred.size(1)
    labels = labels.flatten()  # flatten into a 1-D tensor
    # pick out the locations on the feature maps that were assigned a class,
    # i.e. the indices of positive samples (locations that carry a label)
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    # pos_inds : tensor([ 7971,  7972,  7973,  8123,  8124,  8125,  8275,  8276,  8277, 17133,
    # 17134, 17135, 20057, 20058, 20059, 20068, 20069, 20070, 20076, 20077,
    # 20078, 20087, 20088, 20089, 20095, 20096, 20097, 20106, 20107, 20108,
    # 20243, 20244], device='cuda:0')
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=focal_loss_alpha,
        gamma=focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg
    # gather the positive samples using pos_inds
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    gt_inds = gt_inds[pos_inds]

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    loss_denorm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    if pos_inds.numel() > 0:

        # IoU loss between the positive predicted boxes and the ground-truth
        # boxes; note that the centerness targets enter as per-sample weights
        reg_loss = iou_loss(
            reg_pred,
            reg_targets,
            ctrness_targets
        ) / loss_denorm
        # centerness loss
        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred,
            ctrness_targets,
            reduction="sum"
        ) / num_pos_avg
    else:
        reg_loss = reg_pred.sum() * 0
        ctrness_loss = ctrness_pred.sum() * 0

    losses = {
        "loss_fcos_cls": class_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss
    }
    extras = {
        "pos_inds": pos_inds,
        "gt_inds": gt_inds,
        "gt_ctr": ctrness_targets,
        "loss_denorm": loss_denorm
    }
    return losses, extras
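`reduce_sum` and `get_world_size` are distributed-training utilities: summing the positive count over all GPUs makes every worker divide by the same normalizer. A minimal sketch on top of `torch.distributed` (assuming the process group is already initialized; the names mirror the calls used in these snippets):

import torch.distributed as dist

def get_world_size():
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1

def reduce_sum(tensor):
    # sum a tensor across all workers; a no-op for single-GPU runs
    if get_world_size() == 1:
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor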
Beispiel #19
0
    def DTInst_losses(self, labels, reg_targets, logits_pred, reg_pred,
                      ctrness_pred, mask_pred, mask_targets):
        num_classes = logits_pred.size(1)
        labels = labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            logits_pred,
            class_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg

        reg_pred = reg_pred[pos_inds]
        reg_targets = reg_targets[pos_inds]
        ctrness_pred = ctrness_pred[pos_inds]
        mask_pred = mask_pred[pos_inds]
        assert mask_pred.shape[0] == mask_targets.shape[0], (
            "The number of positives should be equal between "
            "mask_pred (prediction) and mask_targets (target).")

        ctrness_targets = compute_ctrness_targets(reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        ctrness_norm = max(
            reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

        reg_loss = self.iou_loss(reg_pred, reg_targets,
                                 ctrness_targets) / ctrness_norm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

        total_mask_loss = 0.
        dtm_pred_, binary_pred_ = self.mask_encoding.decoder(mask_pred,
                                                             is_train=True)
        code_targets, dtm_targets, weight_maps, hd_maps = self.mask_encoding.encoder(
            mask_targets)
        if self.loss_on_mask:
            if 'mask_mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(dtm_pred_,
                                       dtm_targets,
                                       reduction='none')
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += mask_loss
            if 'weighted_mask_mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(dtm_pred_,
                                       dtm_targets,
                                       reduction='none')
                mask_loss = torch.sum(mask_loss * weight_maps, 1) / torch.sum(
                    weight_maps, 1) * ctrness_targets * self.mask_size**2
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += mask_loss
            if 'mask_difference' in self.mask_loss_type:
                w_ = torch.abs(
                    binary_pred_ * 1. -
                    mask_targets * 1)  # 1's are inconsistent pixels in hd_maps
                md_loss = torch.sum(w_, 1) * ctrness_targets
                md_loss = md_loss.sum() / max(ctrness_norm * self.mask_size**2,
                                              1.0)
                total_mask_loss += md_loss
            if 'hd_one_side_binary' in self.mask_loss_type:  # the first attempt, not really accurate
                w_ = torch.abs(
                    binary_pred_ * 1. -
                    mask_targets * 1)  # 1's are inconsistent pixels in hd_maps
                hausdorff_loss = torch.sum(w_ * hd_maps, 1) / (torch.sum(
                    w_, 1) + 1e-4) * ctrness_targets * self.mask_size**2
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'hd_two_side_binary' in self.mask_loss_type:  # the first attempt, not really accurate
                w_ = torch.abs(
                    binary_pred_ * 1. -
                    mask_targets * 1)  # 1's are inconsistent pixels in hd_maps
                hausdorff_loss = torch.sum(
                    w_ *
                    (torch.clamp(dtm_pred_**2, -0.1, 1.1) +
                     torch.clamp(dtm_targets**2, -0.1, 1)), 1) / (torch.sum(
                         w_, 1) + 1e-4) * ctrness_targets * self.mask_size**2
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'hd_weighted_one_side_dtm' in self.mask_loss_type:
                dtm_diff = (
                    dtm_pred_ -
                    dtm_targets)**2  # 1's are inconsistent pixels in hd_maps
                hausdorff_loss = torch.sum(
                    dtm_diff * weight_maps * hd_maps,
                    1) / (torch.sum(weight_maps, 1) +
                          1e-4) * ctrness_targets * self.mask_size**2
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'hd_weighted_two_side_dtm' in self.mask_loss_type:
                dtm_diff = (
                    dtm_pred_ -
                    dtm_targets)**2  # 1's are inconsistent pixels in hd_maps
                hausdorff_loss = torch.sum(
                    dtm_diff * weight_maps * (dtm_pred_**2 + dtm_targets**2),
                    1) / (torch.sum(weight_maps, 1) +
                          1e-4) * ctrness_targets * self.mask_size**2
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'hd_one_side_dtm' in self.mask_loss_type:
                dtm_diff = (
                    dtm_pred_ -
                    dtm_targets)**2  # 1's are inconsistent pixels in hd_maps
                hausdorff_loss = torch.sum(dtm_diff * hd_maps,
                                           1) * ctrness_targets
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'hd_two_side_dtm' in self.mask_loss_type:
                dtm_diff = (
                    dtm_pred_ -
                    dtm_targets)**2  # 1's are inconsistent pixels in hd_maps
                hausdorff_loss = torch.sum(
                    dtm_diff *
                    (torch.clamp(dtm_pred_, -1.1, 1.1)**2 + dtm_targets**2),
                    1) * ctrness_targets
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'contour_dice' in self.mask_loss_type:
                # contour pixels with 0.05 tolerance: 0.5 <= dtm_pred_ + 0.9 < 0.55
                pred_contour = (0.5 <= dtm_pred_ + 0.9) * 1. * (
                    dtm_pred_ + 0.9 < 0.55)
                # contour pixels: 0. <= dtm_targets < 0.05
                target_contour = (0. <= dtm_targets) * 1. * (dtm_targets < 0.05)
                overlap_ = torch.sum(pred_contour * 2. * target_contour, 1)
                union_ = torch.sum(pred_contour**2, 1) + torch.sum(
                    target_contour**2, 1)
                dice_loss = (
                    1. - overlap_ /
                    (union_ + 1e-4)) * ctrness_targets * self.mask_size**2
                dice_loss = dice_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += dice_loss
            if 'mask_dice' in self.mask_loss_type:
                overlap_ = torch.sum(binary_pred_ * 2. * mask_targets, 1)
                union_ = torch.sum(binary_pred_**2, 1) + torch.sum(
                    mask_targets**2, 1)
                dice_loss = (
                    1. - overlap_ /
                    (union_ + 1e-4)) * ctrness_targets * self.mask_size**2
                dice_loss = dice_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += dice_loss
        if self.loss_on_code:
            # m*m mask labels --> n_components encoding labels
            if 'mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(mask_pred,
                                       code_targets,
                                       reduction='none')
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                if self.mask_sparse_weight > 0.:
                    if self.sparsity_loss_type == 'L1':
                        sparsity_loss = torch.sum(torch.abs(mask_pred),
                                                  1) * ctrness_targets
                        sparsity_loss = sparsity_loss.sum() / max(
                            ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_L1':
                        w_ = (
                            torch.abs(code_targets) < 1e-4
                        ) * 1.  # inactive codes, put L1 regularization on them
                        sparsity_loss = torch.sum(torch.abs(mask_pred) * w_, 1) / torch.sum(w_, 1) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(
                            ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_L2':
                        w_ = (
                            torch.abs(code_targets) < 1e-4
                        ) * 1.  # inactive codes, put L2 regularization on them
                        sparsity_loss = torch.sum(mask_pred ** 2. * w_, 1) / torch.sum(w_, 1) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(
                            ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    else:
                        raise NotImplementedError
                total_mask_loss += mask_loss
            if 'smooth' in self.mask_loss_type:
                mask_loss = F.smooth_l1_loss(mask_pred,
                                             code_targets,
                                             reduction='none')
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            if 'cosine' in self.mask_loss_type:
                mask_loss = loss_cos_sim(mask_pred, code_targets)
                mask_loss = mask_loss * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            if 'kl_softmax' in self.mask_loss_type:
                mask_loss = loss_kl_div_softmax(mask_pred, code_targets)
                mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss

        losses = {
            "loss_DTInst_cls": class_loss,
            "loss_DTInst_loc": reg_loss,
            "loss_DTInst_ctr": ctrness_loss,
            "loss_DTInst_mask": total_mask_loss
        }
        return losses, {}
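`loss_cos_sim`, used in the 'cosine' branch above, is not shown. A plausible minimal sketch is one minus the per-sample cosine similarity between predicted and target code vectors (the helper name follows the snippet; the body is an assumption):

import torch.nn.functional as F

def loss_cos_sim(pred, target, eps=1e-8):
    # pred, target: (num_pos, num_codes); returns a (num_pos,) loss vector
    return 1.0 - F.cosine_similarity(pred, target, dim=1, eps=eps)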
Beispiel #20
0
    def losses(self, locations, class_logits, center_score, box_reg_init,
               box_reg, gt_instances):
        gt_classes, loc_targets, topk_locations = self.get_ground_truth(
            locations, gt_instances)

        class_logits, box_reg_init, box_reg, center_score = permute_and_concat_v2(
            class_logits, box_reg_init, box_reg, center_score,
            self.num_classes)
        # Shapes: (N x R) and (N x R, 4), (N x R) respectively.

        gt_classes = gt_classes.flatten()
        loc_targets = loc_targets.view(-1, 4)

        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
        pos_inds = torch.nonzero(foreground_idxs).squeeze(1)

        num_gpus = get_num_gpus()
        # sync num_pos from all gpus
        total_num_pos = reduce_sum(pos_inds.new_tensor([pos_inds.numel()
                                                        ])).item()
        num_pos_avg_per_gpu = max(total_num_pos / float(num_gpus), 1.0)

        gt_classes_target = torch.zeros_like(class_logits)
        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

        # logits loss
        cls_loss = sigmoid_focal_loss_jit(
            class_logits,
            gt_classes_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg_per_gpu

        if pos_inds.numel() > 0:
            if self.slender_centerness:
                gt_center_score = compute_slender_centerness_targets(
                    loc_targets[foreground_idxs])
            else:
                gt_center_score = compute_centerness_targets(
                    loc_targets[foreground_idxs])
            # average sum_centerness_targets from all gpus,
            # which is used to normalize centerness-weighed reg loss
            sum_centerness_targets_avg_per_gpu = \
                reduce_sum(gt_center_score.sum()).item() / float(num_gpus)

            topk_locations = topk_locations.view(-1)
            topk_gt_center_score = compute_centerness_targets(
                loc_targets[topk_locations])
            sum_topk_centerness_targets_avg_per_gpu = \
                reduce_sum(topk_gt_center_score.sum()).item() / float(num_gpus)

            loss_loc_init = iou_loss(
                box_reg_init[topk_locations],
                loc_targets[topk_locations],
                topk_gt_center_score,
                loss_type=self.iou_loss_type
            ) / sum_topk_centerness_targets_avg_per_gpu

            loss_loc_refine = iou_loss(box_reg[foreground_idxs],
                                       loc_targets[foreground_idxs],
                                       gt_center_score,
                                       loss_type=self.iou_loss_type
                                       ) / sum_centerness_targets_avg_per_gpu

            centerness_loss = F.binary_cross_entropy_with_logits(
                center_score[foreground_idxs],
                gt_center_score,
                reduction='sum') / num_pos_avg_per_gpu
        else:
            loss_loc_init = box_reg_init[foreground_idxs].sum()
            loss_loc_refine = box_reg[foreground_idxs].sum()
            reduce_sum(center_score[foreground_idxs].new_tensor([0.0]))
            centerness_loss = center_score[foreground_idxs].sum()

        return dict(
            loss_cls=cls_loss * self.loss_cls_weight,
            centerness_loss=centerness_loss * self.loss_cls_weight,
            loss_loc_init=loss_loc_init * self.loss_loc_init_weight,
            loss_loc_refine=loss_loc_refine * self.loss_loc_refine_weight,
        )
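`sigmoid_focal_loss_jit`, used throughout these examples, computes the focal loss FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t) on per-class sigmoid probabilities. A plain (non-JIT) sketch of the same computation:

import torch
import torch.nn.functional as F

def sigmoid_focal_loss(inputs, targets, alpha=0.25, gamma=2.0, reduction="sum"):
    p = torch.sigmoid(inputs)
    ce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)  # probability of the true class
    loss = ce * (1 - p_t) ** gamma               # down-weight easy examples
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    if reduction == "sum":
        loss = loss.sum()
    elif reduction == "mean":
        loss = loss.mean()
    return loss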
Beispiel #21
0
    def losses(self, anchors: List[Boxes], pred_logits: List[Tensor],
               gt_classes: List[Tensor], pred_anchor_deltas: List[Tensor],
               gt_boxes: List[Tensor]) -> Dict[str, Tensor]:
        """
        Args:
            For `gt_classes` and `gt_anchors_deltas` parameters, see
                :meth:`RetinaNet.get_ground_truth`.
            Their shapes are (N, R) and (N, R, 4), respectively, where R is
            the total number of anchors across levels, i.e. sum(Hi x Wi x A)
            For `pred_class_logits` and `pred_anchor_deltas`, see
                :meth:`RetinaNetHead.forward`.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls" and "loss_box_reg"
        """
        num_images: int = len(gt_classes)

        # shape(gt_classes) = (N, R)
        gt_classes_tensor: Tensor = torch.stack(gt_classes)

        # shape(anchors) = (R, 4)
        anchors_tensor: Tensor = type(anchors[0]).cat(anchors).tensor
        gt_anchor_deltas: List[Tensor] = [
            self.box2box_transform.get_deltas(anchors_tensor, k)
            for k in gt_boxes
        ]
        # shape(gt_anchor_deltas) = (N, R, 4)
        gt_anchor_deltas_tensor: Tensor = torch.stack(gt_anchor_deltas)

        valid_mask: Tensor = gt_classes_tensor >= 0
        pos_mask: Tensor = (gt_classes_tensor >= 0) & (gt_classes_tensor !=
                                                       self.num_classes)
        num_pos_anchors: int = pos_mask.sum().item()
        get_event_storage().put_scalar("num_pos_anchors",
                                       num_pos_anchors / num_images)
        self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer\
                                + (1 - self.loss_normalizer_momentum) * max(num_pos_anchors, 1)

        # classification and regression loss
        # no loss for the last (background) class --> [:, :-1]
        gt_classes_target: LongTensor = F.one_hot(
            gt_classes_tensor[valid_mask],
            num_classes=self.num_classes + 1)[:, :-1]

        # logits loss
        loss_cls = sigmoid_focal_loss_jit(
            inputs=cat(pred_logits, dim=1)[valid_mask],
            targets=gt_classes_target.to(pred_logits[0].dtype),
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum") / self.loss_normalizer

        # regression loss
        loss_box_reg = smooth_l1_loss(input=cat(pred_anchor_deltas,
                                                dim=1)[pos_mask],
                                      target=gt_anchor_deltas_tensor[pos_mask],
                                      beta=self.smooth_l1_loss_beta,
                                      reduction="sum") / self.loss_normalizer

        return {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
Beispiel #22
0
    def classification_losses(self, gt_classes, pred_class_logits):
        """
        Args:
            For `gt_classes` and `gt_anchors_deltas` parameters, see
                :meth:`RetinaNet.get_ground_truth`.
            Their shapes are (N, R) and (N, R, 4), respectively, where R is
            the total number of anchors across levels, i.e. sum(Hi x Wi x A)
            For `pred_class_logits`, see
                :meth:`RetinaNetHead.forward`.

        Returns:
            dict[str: Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls"
        """
        start = 0
        pred_category = pred_class_logits[:, start:start +
                                          self.classification_classes[0]]
        start += self.classification_classes[0]
        pred_part = pred_class_logits[:, start:start +
                                      self.classification_classes[1]]
        start += self.classification_classes[1]
        pred_toward = pred_class_logits[:, start:start +
                                        self.classification_classes[2]]

        valid_idxs = gt_classes[self.classification_tasks[0]][
            1::self.classification_classes[0]] == 1
        data_type = pred_category.dtype
        num_batchs = pred_category.size()[0]
        num_model = valid_idxs.sum()

        valid_category = gt_classes[self.classification_tasks[0]][:] > -1
        # category loss
        if valid_category.sum() > 0:
            if self.activation == 'sigmoid':
                loss_category = sigmoid_focal_loss_jit(
                    pred_category.flatten()[valid_category],
                    gt_classes[self.classification_tasks[0]].to(
                        dtype=data_type)[valid_category],
                    alpha=self.focal_loss_alpha,
                    gamma=self.focal_loss_gamma,
                    reduction="sum",
                ) / max(1,
                        valid_category.sum() / self.classification_classes[0])
            elif self.activation == 'softmax':
                gt_category = torch.argmax(
                    gt_classes[self.classification_tasks[0]].view(
                        num_batchs, -1),
                    dim=1)
                valid_category = valid_category.view(num_batchs,
                                                     -1).sum(dim=1) > 0
                loss_category = F.cross_entropy(
                    pred_category[valid_category],
                    gt_category[valid_category],
                    reduction="sum",
                ) / max(1, valid_category.sum())
            else:
                raise Exception("Not implement classification activation!")
        else:
            loss_category = 0.0

        valid_part = gt_classes[self.classification_tasks[1]][:] > -1
        if valid_part.sum() > 0:
            loss_part = sigmoid_focal_loss_jit(
                pred_part.flatten()[valid_part],
                gt_classes[self.classification_tasks[1]].to(
                    dtype=data_type)[valid_part],
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
                reduction="sum",
            ) / max(1,
                    valid_part.sum() / self.classification_classes[1])
        else:
            loss_part = 0.0

        valid_toward = gt_classes[self.classification_tasks[2]][:] > -1
        if valid_toward.sum() > 0:
            loss_toward = sigmoid_focal_loss_jit(
                pred_toward.flatten()[valid_toward],
                gt_classes[self.classification_tasks[2]].to(
                    dtype=data_type)[valid_toward],
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
                reduction="sum",
            ) / max(1,
                    valid_toward.sum() / self.classification_classes[2])
        else:
            loss_toward = 0.0

        return {
            "loss_category": loss_category,
            "loss_part": loss_part,
            "loss_toward": loss_toward
        }
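The head above carves three task-specific logit groups out of one concatenated vector with running `start` offsets. `torch.split` expresses the same slicing without the bookkeeping, as a sketch reusing the snippet's `pred_class_logits` (the task sizes below are illustrative):

classification_classes = [10, 5, 3]  # category / part / toward sizes, illustrative
pred_category, pred_part, pred_toward = torch.split(
    pred_class_logits, classification_classes, dim=1)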
Beispiel #23
0
    def fcos_losses(self, instances):
        losses, extras = {}, {}

        # 1. compute the cls loss
        num_classes = instances.logits_pred.size(1)
        assert num_classes == self.num_classes

        labels = instances.labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)

        num_pos_local = torch.ones_like(pos_inds).sum()
        num_pos_avg = max(reduce_mean(num_pos_local).item(), 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(instances.logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(instances.logits_pred,
                                            class_target,
                                            alpha=self.focal_loss_alpha,
                                            gamma=self.focal_loss_gamma,
                                            reduction="sum")

        if self.loss_normalizer_cls == "moving_fg":
            self.moving_num_fg = self.moving_num_fg_momentum * self.moving_num_fg + (
                1 - self.moving_num_fg_momentum) * num_pos_avg
            class_loss = class_loss / self.moving_num_fg
        elif self.loss_normalizer_cls == "fg":
            class_loss = class_loss / num_pos_avg
        else:
            num_samples_local = torch.ones_like(labels).sum()
            num_samples_avg = max(reduce_mean(num_samples_local).item(), 1.0)
            class_loss = class_loss / num_samples_avg

        losses["loss_fcos_cls"] = class_loss * self.loss_weight_cls

        # 2. compute the box regression and quality loss
        instances = instances[pos_inds]
        instances.pos_inds = pos_inds

        ious, gious = compute_ious(instances.reg_pred, instances.reg_targets)

        if self.box_quality == "ctrness":
            ctrness_targets = compute_ctrness_targets(instances.reg_targets)
            instances.gt_ctrs = ctrness_targets

            ctrness_targets_sum = ctrness_targets.sum()
            loss_denorm = max(reduce_mean(ctrness_targets_sum).item(), 1e-6)
            extras["loss_denorm"] = loss_denorm

            reg_loss = self.loc_loss_func(ious, gious,
                                          ctrness_targets) / loss_denorm
            losses["loss_fcos_loc"] = reg_loss

            ctrness_loss = F.binary_cross_entropy_with_logits(
                instances.ctrness_pred, ctrness_targets,
                reduction="sum") / num_pos_avg
            losses["loss_fcos_ctr"] = ctrness_loss
        elif self.box_quality == "iou":
            reg_loss = self.loc_loss_func(ious, gious) / num_pos_avg
            losses["loss_fcos_loc"] = reg_loss

            quality_loss = F.binary_cross_entropy_with_logits(
                instances.ctrness_pred, ious.detach(),
                reduction="sum") / num_pos_avg
            losses["loss_fcos_iou"] = quality_loss
        else:
            raise NotImplementedError

        extras["instances"] = instances

        return extras, losses
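`compute_ious` is assumed to return per-sample IoU and GIoU for (l, t, r, b) distance-encoded boxes predicted at shared locations, as in FCOS-style heads. A sketch under that assumption:

import torch

def compute_ious(pred, target):
    # pred, target: (N, 4) distances (l, t, r, b) measured from the same points.
    pred_l, pred_t, pred_r, pred_b = pred.unbind(dim=1)
    gt_l, gt_t, gt_r, gt_b = target.unbind(dim=1)
    pred_area = (pred_l + pred_r) * (pred_t + pred_b)
    gt_area = (gt_l + gt_r) * (gt_t + gt_b)
    w_inter = torch.min(pred_l, gt_l) + torch.min(pred_r, gt_r)
    h_inter = torch.min(pred_t, gt_t) + torch.min(pred_b, gt_b)
    inter = w_inter.clamp(min=0) * h_inter.clamp(min=0)
    union = (pred_area + gt_area - inter).clamp(min=1e-6)
    ious = inter / union
    # smallest enclosing box, for the GIoU penalty term
    w_encl = torch.max(pred_l, gt_l) + torch.max(pred_r, gt_r)
    h_encl = torch.max(pred_t, gt_t) + torch.max(pred_b, gt_b)
    area_encl = (w_encl * h_encl).clamp(min=1e-6)
    gious = ious - (area_encl - union) / area_encl
    return ious, gious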
Beispiel #24
0
    def losses(self, pred_logits, pred_init_boxes, pred_refine_boxes,
               gt_init_objectness, gt_init_bboxes, gt_cls: torch.Tensor,
               gt_refine_bboxes, strides):
        """
        Loss computation.
        Args:
            pred_logits: (N, X, C). Classification prediction, where X is the number
                of positions from all feature levels, C is the number of object classes.
            pred_init_boxes: (N, X, 4). Init box prediction.
            pred_refine_boxes: (N, X, 4). Refined box prediction.
            gt_init_objectness: (N, X). Foreground/background classification for initial
                prediction.
            gt_init_bboxes: (N, X, 4). Initial box prediction.
            gt_cls: (N, X), Long. GT for box classification, -1 indicates ignoring.
            gt_refine_bboxes: (N, X, 4). Refined box prediction.
            strides: (X). Scale factor at each position.
        Returns:
            dict[str, Tensor]:
                mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The dict keys are:
                "loss_cls", "loss_localization_init", and "loss_localization_refine".
        """

        valid_idxs = gt_cls >= 0
        foreground_idxs = valid_idxs.logical_and(gt_cls != self.num_classes)
        num_foreground = foreground_idxs.sum().item() / gt_init_bboxes.shape[0]
        get_event_storage().put_scalar("num_foreground", num_foreground)

        gt_cls_target = torch.zeros_like(pred_logits)
        gt_cls_target[foreground_idxs, gt_cls[foreground_idxs]] = 1

        self.loss_normalizer = (
            self.loss_normalizer_momentum * self.loss_normalizer +
            (1 - self.loss_normalizer_momentum) * num_foreground)

        loss_cls = sigmoid_focal_loss_jit(pred_logits[valid_idxs],
                                          gt_cls_target[valid_idxs],
                                          alpha=self.focal_loss_alpha,
                                          gamma=self.focal_loss_gamma,
                                          reduction="sum") / max(
                                              1, self.loss_normalizer)

        init_foreground_idxs = gt_init_objectness > 0
        strides = strides[None].repeat(pred_logits.shape[0], 1)
        coords_norm_init = strides[init_foreground_idxs].unsqueeze(-1) * 4
        loss_localization_init = smooth_l1_loss(
            pred_init_boxes[init_foreground_idxs] / coords_norm_init,
            gt_init_bboxes[init_foreground_idxs] / coords_norm_init,
            0.11,
            reduction='sum') / max(1, gt_init_objectness.sum()) * 0.5

        coords_norm_refine = strides[foreground_idxs].unsqueeze(-1) * 4
        loss_localization_refine = smooth_l1_loss(
            pred_refine_boxes[foreground_idxs] / coords_norm_refine,
            gt_refine_bboxes[foreground_idxs] / coords_norm_refine,
            0.11,
            reduction="sum") / max(1, self.loss_normalizer)

        return {
            "loss_cls": loss_cls,
            "loss_localization_init": loss_localization_init,
            "loss_localization_refine": loss_localization_refine
        }
Beispiel #25
0
def fcos_losses(
    labels,
    reg_targets,
    logits_pred,
    reg_pred,
    ctrness_pred,
    focal_loss_alpha,
    focal_loss_gamma,
    iou_loss,
    gt_inds,
):
    num_classes = logits_pred.size(1)
    labels = labels.flatten()

    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=focal_loss_alpha,
        gamma=focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    gt_inds = gt_inds[pos_inds]

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    loss_denorm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    if pos_inds.numel() > 0:
        reg_loss = iou_loss(reg_pred, reg_targets,
                            ctrness_targets) / loss_denorm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg
    else:
        reg_loss = reg_pred.sum() * 0
        ctrness_loss = ctrness_pred.sum() * 0

    losses = {
        "loss_fcos_cls": class_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss
    }
    extras = {
        "pos_inds": pos_inds,
        "gt_inds": gt_inds,
        "gt_ctr": ctrness_targets,
        "loss_denorm": loss_denorm
    }
    return losses, extras
Beispiel #26
0
    def loss(self, cate_preds, kernel_preds, ins_pred, targets):
        ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list = targets
        # ins
        ins_labels = [
            torch.cat([
                ins_labels_level_img
                for ins_labels_level_img in ins_labels_level
            ], 0) for ins_labels_level in zip(*ins_label_list)
        ]

        kernel_preds = [[
            kernel_preds_level_img.view(kernel_preds_level_img.shape[0],
                                        -1)[:, grid_orders_level_img]
            for kernel_preds_level_img, grid_orders_level_img in zip(
                kernel_preds_level, grid_orders_level)
        ] for kernel_preds_level, grid_orders_level in zip(
            kernel_preds, zip(*grid_order_list))]
        # generate masks
        ins_pred_list = []
        for b_kernel_pred in kernel_preds:
            b_mask_pred = []
            for idx, kernel_pred in enumerate(b_kernel_pred):

                if kernel_pred.size()[-1] == 0:
                    continue
                cur_ins_pred = ins_pred[idx, ...]
                H, W = cur_ins_pred.shape[-2:]
                N, I = kernel_pred.shape
                cur_ins_pred = cur_ins_pred.unsqueeze(0)
                kernel_pred = kernel_pred.permute(1, 0).view(I, -1, 1, 1)
                cur_ins_pred = F.conv2d(cur_ins_pred, kernel_pred,
                                        stride=1).view(-1, H, W)
                b_mask_pred.append(cur_ins_pred)
            if len(b_mask_pred) == 0:
                b_mask_pred = None
            else:
                b_mask_pred = torch.cat(b_mask_pred, 0)
            ins_pred_list.append(b_mask_pred)

        ins_ind_labels = [
            torch.cat([
                ins_ind_labels_level_img.flatten()
                for ins_ind_labels_level_img in ins_ind_labels_level
            ]) for ins_ind_labels_level in zip(*ins_ind_label_list)
        ]
        flatten_ins_ind_labels = torch.cat(ins_ind_labels)

        num_ins = flatten_ins_ind_labels.sum()

        # dice loss
        loss_ins = []
        for input, target in zip(ins_pred_list, ins_labels):
            if input is None:
                continue
            input = torch.sigmoid(input)
            loss_ins.append(dice_loss(input, target))

        loss_ins_mean = torch.cat(loss_ins).mean()
        loss_ins = loss_ins_mean * self.ins_loss_weight

        # cate
        cate_labels = [
            torch.cat([
                cate_labels_level_img.flatten()
                for cate_labels_level_img in cate_labels_level
            ]) for cate_labels_level in zip(*cate_label_list)
        ]
        flatten_cate_labels = torch.cat(cate_labels)

        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
            for cate_pred in cate_preds
        ]
        flatten_cate_preds = torch.cat(cate_preds)

        # prepare one_hot
        pos_inds = torch.nonzero(
            flatten_cate_labels != self.num_classes).squeeze(1)

        flatten_cate_labels_oh = torch.zeros_like(flatten_cate_preds)
        flatten_cate_labels_oh[pos_inds, flatten_cate_labels[pos_inds]] = 1

        loss_cate = self.focal_loss_weight * sigmoid_focal_loss_jit(
            flatten_cate_preds,
            flatten_cate_labels_oh,
            gamma=self.focal_loss_gamma,
            alpha=self.focal_loss_alpha,
            reduction="sum") / (num_ins + 1)
        return {'loss_ins': loss_ins, 'loss_cate': loss_cate}
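The dynamic 1x1 convolution above (each `kernel_pred` reshaped to (I, C, 1, 1) and convolved with a (1, C, H, W) feature map) is just a matrix multiply over the channel axis. An equivalent einsum form, as a sketch (`kernels` of shape (I, C) and `feats` of shape (C, H, W) are assumed names):

# (I, C) per-instance kernels x (C, H, W) mask features -> (I, H, W) mask logits
masks = torch.einsum('ic,chw->ihw', kernels, feats)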
Beispiel #27
0
    def SMInst_losses(
            self,
            labels,
            reg_targets,
            logits_pred,
            reg_pred,
            ctrness_pred,
            mask_pred,
            mask_targets
    ):
        num_classes = logits_pred.size(1)
        labels = labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            logits_pred,
            class_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg

        # print(mask_pred.size(), mask_tower_interm_outputs.size())

        reg_pred = reg_pred[pos_inds]
        reg_targets = reg_targets[pos_inds]
        ctrness_pred = ctrness_pred[pos_inds]
        mask_pred = mask_pred[pos_inds]
        # mask_activation_pred = mask_activation_pred[pos_inds]

        assert mask_pred.shape[0] == mask_targets.shape[0], (
            "The number of positives should be equal between "
            "mask_pred (prediction) and mask_targets (target).")

        ctrness_targets = compute_ctrness_targets(reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        ctrness_norm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

        reg_loss = self.iou_loss(
            reg_pred,
            reg_targets,
            ctrness_targets
        ) / ctrness_norm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred,
            ctrness_targets,
            reduction="sum"
        ) / num_pos_avg

        mask_targets_ = self.mask_encoding.encoder(mask_targets)
        mask_pred_, mask_pred_bin = self.mask_encoding.decoder(mask_pred, is_train=True)

        # compute the loss for the activation code as binary classification
        # activation_targets = (torch.abs(mask_targets_) > 1e-4) * 1.
        # activation_loss = F.binary_cross_entropy_with_logits(
        #     mask_activation_pred,
        #     activation_targets,
        #     reduction='none'
        # )
        # activation_loss = activation_loss.sum(1) * ctrness_targets
        # activation_loss = activation_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)

        # if self.thresh_with_active:
        #     mask_pred = mask_pred * torch.sigmoid(mask_activation_pred)

        total_mask_loss = 0.
        if self.loss_on_mask:
            # n_components predictions --> m*m mask predictions without sigmoid
            # as sigmoid function is combined in loss.
            # mask_pred_, mask_pred_bin = self.mask_encoding.decoder(mask_pred, is_train=True)
            if 'mask_mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(
                    mask_pred_,
                    mask_targets,
                    reduction='none'
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.mask_size ** 2, 1.0)
                total_mask_loss += mask_loss
            if 'mask_iou' in self.mask_loss_type:
                overlap_ = torch.sum(mask_pred_bin * 1. * mask_targets, 1)
                union_ = torch.sum((mask_pred_bin + mask_targets) >= 1., 1)
                iou_loss = (1. - overlap_ / (union_ + 1e-4)) * ctrness_targets * self.mask_size ** 2
                iou_loss = iou_loss.sum() / max(ctrness_norm * self.mask_size ** 2, 1.0)
                total_mask_loss += iou_loss
            if 'mask_difference' in self.mask_loss_type:
                w_ = torch.abs(mask_pred_bin * 1. - mask_targets * 1)  # 1's are inconsistent pixels in hd_maps
                md_loss = torch.sum(w_, 1) * ctrness_targets
                md_loss = md_loss.sum() / max(ctrness_norm * self.mask_size ** 2, 1.0)
                total_mask_loss += md_loss
        if self.loss_on_code:
            # m*m mask labels --> n_components encoding labels
            # mask_targets_ = self.mask_encoding.encoder(mask_targets)
            if 'mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(
                    mask_pred,
                    mask_targets_,
                    reduction='none'
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                if self.mask_sparse_weight > 0.:
                    if self.sparsity_loss_type == 'L1':
                        sparsity_loss = torch.sum(torch.abs(mask_pred), 1) * ctrness_targets
                        sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'L0':
                        w_ = (torch.abs(mask_targets_) >= 1e-2) * 1.  # the number of codes that are active
                        sparsity_loss = torch.sum(w_, 1) * ctrness_targets
                        sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_L1':
                        w_ = (torch.abs(mask_targets_) < 1e-2) * 1.  # inactive codes, put L1 regularization on them
                        sparsity_loss = torch.sum(torch.abs(mask_pred) * w_, 1) \
                                        / torch.sum(w_, 1).clamp(min=1.0) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_L2':
                        w_ = (torch.abs(mask_targets_) < 1e-2) * 1.  # inactive codes, put L2 regularization on them
                        sparsity_loss = torch.sum(mask_pred ** 2. * w_, 1) \
                                        / torch.sum(w_, 1).clamp(min=1.0) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_KL':
                        w_ = (torch.abs(mask_targets_) < 1e-2) * 1.  # inactive codes, put a KL penalty on them
                        kl_ = kl_divergence(
                            mask_pred,
                            self.kl_rho
                        )
                        sparsity_loss = torch.sum(kl_ * w_, 1) \
                                        / torch.sum(w_, 1).clamp(min=1.0) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    else:
                        raise NotImplementedError
                total_mask_loss += mask_loss
            if 'smooth' in self.mask_loss_type:
                mask_loss = F.smooth_l1_loss(
                    mask_pred,
                    mask_targets_,
                    reduction='none'
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            if 'cosine' in self.mask_loss_type:
                mask_loss = loss_cos_sim(
                    mask_pred,
                    mask_targets_
                )
                mask_loss = mask_loss * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
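            # 'kl' is a substring of 'kl_softmax' and 'kl_sigmoid', so the
            # branches below form an if/elif chain and apply only one KL variant.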
            if 'kl_softmax' in self.mask_loss_type:
                mask_loss = loss_kl_div_softmax(
                    mask_pred,
                    mask_targets_
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            elif 'kl_sigmoid' in self.mask_loss_type:
                mask_loss = loss_kl_div_sigmoid(
                    mask_pred,
                    mask_targets_
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            elif 'kl' in self.mask_loss_type:
                mask_loss = kl_divergence(
                    mask_pred,
                    self.kl_rho
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss

        losses = {
            "loss_SMInst_cls": class_loss,
            "loss_SMInst_loc": reg_loss,
            "loss_SMInst_ctr": ctrness_loss,
            "loss_SMInst_mask": total_mask_loss,
        }
        return losses, {}
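
The code-space branches above rely on small helpers (`kl_divergence`, `loss_cos_sim`, `loss_kl_div_softmax`, `loss_kl_div_sigmoid`) that are defined elsewhere in the repository. A minimal sketch of plausible implementations, consistent with how each return value is reduced above but not taken from the original source, could be:

import torch
import torch.nn.functional as F

def kl_divergence(pred, rho, eps=1e-6):
    # Per-code Bernoulli KL(rho || rho_hat), the classic sparse-autoencoder
    # penalty; rho_hat is the sigmoid-squashed code activation. Returns one
    # value per code, matching the `.sum(1)` applied by the caller.
    rho_hat = torch.sigmoid(pred).clamp(eps, 1.0 - eps)
    return rho * torch.log(rho / rho_hat) + \
        (1.0 - rho) * torch.log((1.0 - rho) / (1.0 - rho_hat))

def loss_cos_sim(pred, target, eps=1e-6):
    # One value per instance: 1 - cosine similarity of the code vectors.
    return 1.0 - F.cosine_similarity(pred, target, dim=1, eps=eps)

def loss_kl_div_softmax(pred, target):
    # Per-element KL between softmax-normalized code distributions
    # (reduction='none' so the caller can sum over dim 1).
    return F.kl_div(F.log_softmax(pred, dim=1),
                    F.softmax(target, dim=1), reduction='none')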
Example #28
    def __call__(self, locations, box_cls, box_regression, centerness,
                 targets):
        """
        Arguments:
            locations (list[BoxList])
            box_cls (list[Tensor])
            box_regression (list[Tensor])
            centerness (list[Tensor])
            targets (list[BoxList])

        Returns:
            cls_loss (Tensor)
            reg_loss (Tensor)
            centerness_loss (Tensor)
        """
        N = box_cls[0].size(0)
        num_classes = box_cls[0].size(1)
        labels, reg_targets = self.prepare_targets(locations, targets)
        box_cls_flatten = []
        box_regression_flatten = []
        centerness_flatten = []
        labels_flatten = []
        reg_targets_flatten = []
        for l in range(len(labels)):
            box_cls_flatten.append(box_cls[l].permute(0, 2, 3, 1).reshape(
                -1, num_classes))
            box_regression_flatten.append(box_regression[l].permute(
                0, 2, 3, 1).reshape(-1, 4))
            labels_flatten.append(labels[l].reshape(-1))
            reg_targets_flatten.append(reg_targets[l].reshape(-1, 4))
            centerness_flatten.append(centerness[l].reshape(-1))

        box_cls_flatten = torch.cat(box_cls_flatten, dim=0)
        box_regression_flatten = torch.cat(box_regression_flatten, dim=0)
        centerness_flatten = torch.cat(centerness_flatten, dim=0)
        labels_flatten = torch.cat(labels_flatten, dim=0)
        reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0)

        # background locations carry the label `num_classes` (80 for COCO)
        pos_inds = torch.nonzero(labels_flatten != num_classes).squeeze(1)

        box_regression_flatten = box_regression_flatten[pos_inds]
        reg_targets_flatten = reg_targets_flatten[pos_inds]
        centerness_flatten = centerness_flatten[pos_inds]

        # sync num_pos across all GPUs so the loss normalizer is consistent
        total_num_pos = reduce_sum(pos_inds.new_tensor([pos_inds.numel()
                                                        ])).item()
        total_num_pos = max(total_num_pos, 1.0)  # avoid division by zero when no positives exist

        gt_classes_target = torch.zeros_like(box_cls_flatten)
        gt_classes_target[pos_inds, labels_flatten[pos_inds]] = 1

        cls_loss = sigmoid_focal_loss_jit(
            box_cls_flatten,
            gt_classes_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / total_num_pos
        if pos_inds.numel() > 0:
            centerness_targets = self.compute_centerness_targets(
                reg_targets_flatten)

            # sum of centerness targets across all GPUs, used to normalize
            # the centerness-weighted regression loss
            sum_centerness_targets = \
                reduce_sum(centerness_targets.sum()).item()
            reg_loss = self.box_reg_loss_func(
                box_regression_flatten, reg_targets_flatten,
                centerness_targets) / sum_centerness_targets
            centerness_loss = self.centerness_loss_func(
                centerness_flatten,
                centerness_targets) / total_num_pos
        else:
            reg_loss = box_regression_flatten.sum()
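            # dummy all-reduce so the collective-call count stays in sync across GPUs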
            reduce_sum(centerness_flatten.new_tensor([0.0]))
            centerness_loss = centerness_flatten.sum()

        return cls_loss, reg_loss, centerness_loss
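
Both this example and the next compute centerness targets with a helper that is not shown. A sketch matching the FCOS paper's definition (the square root of the product of the min/max ratios of the left/right and top/bottom regression distances), assuming `reg_targets` is laid out as (l, t, r, b):

import torch

def compute_centerness_targets(reg_targets):
    # reg_targets: (num_pos, 4) = distances (left, top, right, bottom)
    # from each positive location to the sides of its ground-truth box.
    left_right = reg_targets[:, [0, 2]]
    top_bottom = reg_targets[:, [1, 3]]
    centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
                 (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
    return torch.sqrt(centerness)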
Example #29
    def losses(self, labels, reg_targets, box_cls, box_regression, centerness):
        N, num_classes = box_cls[0].shape[:2]

        box_cls_flatten = []
        box_regression_flatten = []
        centerness_flatten = []
        labels_flatten = []
        reg_targets_flatten = []
        for l in range(len(labels)):
            box_cls_flatten.append(box_cls[l].permute(0, 2, 3, 1).reshape(
                -1, num_classes))
            box_regression_flatten.append(box_regression[l].permute(
                0, 2, 3, 1).reshape(-1, 4))
            labels_flatten.append(labels[l].reshape(-1))
            reg_targets_flatten.append(reg_targets[l].reshape(-1, 4))
            centerness_flatten.append(centerness[l].reshape(-1))

        box_cls_flatten = torch.cat(box_cls_flatten, dim=0)
        box_regression_flatten = torch.cat(box_regression_flatten, dim=0)
        centerness_flatten = torch.cat(centerness_flatten, dim=0)
        labels_flatten = torch.cat(labels_flatten, dim=0)
        reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0)

        pos_inds = torch.nonzero((labels_flatten >= 0) & (
            labels_flatten != self.num_classes)).squeeze(1)

        box_regression_flatten = box_regression_flatten[pos_inds]
        reg_targets_flatten = reg_targets_flatten[pos_inds]
        centerness_flatten = centerness_flatten[pos_inds]

        num_gpus = get_num_gpus()
        # sync num_pos from all gpus
        total_num_pos = reduce_sum(pos_inds.new_tensor([pos_inds.numel()
                                                        ])).item()
        num_pos_avg_per_gpu = max(total_num_pos / float(num_gpus), 1.0)

        gt_classes_target = torch.zeros_like(box_cls_flatten)
        foreground_idxs = (labels_flatten >= 0) & (labels_flatten !=
                                                   self.num_classes)
        gt_classes_target[foreground_idxs, labels_flatten[foreground_idxs]] = 1

        cls_loss = sigmoid_focal_loss_jit(
            box_cls_flatten,
            gt_classes_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg_per_gpu

        if pos_inds.numel() > 0:
            centerness_targets = compute_centerness_targets(
                reg_targets_flatten)
            # average sum_centerness_targets from all gpus,
            # which is used to normalize centerness-weighed reg loss
            sum_centerness_targets_avg_per_gpu = \
                reduce_sum(centerness_targets.sum()).item() / float(num_gpus)

            reg_loss = iou_loss(box_regression_flatten,
                                reg_targets_flatten,
                                centerness_targets,
                                loss_type=self.iou_loss_type
                                ) / sum_centerness_targets_avg_per_gpu

            centerness_loss = F.binary_cross_entropy_with_logits(
                centerness_flatten, centerness_targets,
                reduction='sum') / num_pos_avg_per_gpu
        else:
            reg_loss = box_regression_flatten.sum()
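            # dummy all-reduce so the collective-call count stays in sync across GPUs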
            reduce_sum(centerness_flatten.new_tensor([0.0]))
            centerness_loss = centerness_flatten.sum()

        return dict(cls_loss=cls_loss,
                    reg_loss=reg_loss,
                    centerness_loss=centerness_loss)
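
Both FCOS-style examples normalize their losses with values synchronized across GPUs via `reduce_sum` and `get_num_gpus`, which are not shown. A minimal sketch using torch.distributed, assuming the usual all-reduce pattern, is:

import torch
import torch.distributed as dist

def get_num_gpus():
    # Assumption: the world size equals the number of GPUs in use.
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1

def reduce_sum(tensor):
    # Sum a tensor across all processes; a no-op for single-GPU runs.
    if get_num_gpus() == 1:
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor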
Example #30
    def run_focal_loss_jit() -> None:
        # `inputs`, `targets`, and `alpha` are captured from the enclosing
        # benchmark scope (see the setup sketch below).
        fl = sigmoid_focal_loss_jit(
            inputs, targets, gamma=0, alpha=alpha, reduction="mean"
        )
        fl.backward()
        torch.cuda.synchronize()
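
Example #30 is a closure taken from a micro-benchmark: `inputs`, `targets`, and `alpha` are captured from the enclosing scope. A hypothetical setup for those captured names, so the snippet can actually run (the sizes are illustrative and the import path is an assumption, since the op is commonly provided by fvcore), might look like:

import torch
from fvcore.nn import sigmoid_focal_loss_jit  # assumed import path

# Illustrative shapes: flattened per-anchor logits over 80 classes;
# requires a CUDA device because of the synchronize() call above.
inputs = torch.randn(4096, 80, device="cuda", requires_grad=True)
targets = torch.randint(0, 2, (4096, 80), device="cuda").float()
alpha = 0.25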