def fcos_losses(self, instances):
        num_classes = instances.logits_pred.size(1)
        assert num_classes == self.num_classes

        labels = instances.labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(instances.logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            instances.logits_pred,
            class_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg

        instances = instances[pos_inds]
        instances.pos_inds = pos_inds

        ctrness_targets = compute_ctrness_targets(instances.reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        loss_denorm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)
        instances.gt_ctrs = ctrness_targets

        if pos_inds.numel() > 0:
            reg_loss = self.loc_loss_func(
                instances.reg_pred,
                instances.reg_targets,
                ctrness_targets
            ) / loss_denorm

            ctrness_loss = F.binary_cross_entropy_with_logits(
                instances.ctrness_pred,
                ctrness_targets,
                reduction="sum"
            ) / num_pos_avg
        else:
            reg_loss = instances.reg_pred.sum() * 0
            ctrness_loss = instances.ctrness_pred.sum() * 0

        losses = {
            "loss_fcos_cls": class_loss,
            "loss_fcos_loc": reg_loss,
            "loss_fcos_ctr": ctrness_loss
        }
        extras = {
            "instances": instances,
            "loss_denorm": loss_denorm
        }
        return extras, losses
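
These examples assume the usual imports (torch, torch.nn.functional as F) plus a few helpers defined elsewhere in the repo: get_world_size, reduce_sum and compute_ctrness_targets. As a minimal sketch, assuming the standard FCOS centerness definition and a plain distributed all-reduce (the actual helpers in the codebase may differ in details), they could look like this:

import torch
import torch.distributed as dist


def get_world_size():
    # Number of distributed workers; 1 when not running under torch.distributed.
    if not (dist.is_available() and dist.is_initialized()):
        return 1
    return dist.get_world_size()


def reduce_sum(tensor):
    # Sum a tensor across all workers (identity in single-GPU runs).
    # Sketch only: the helper used by these snippets may differ slightly.
    if get_world_size() == 1:
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor


def compute_ctrness_targets(reg_targets):
    # reg_targets: (N, 4) distances (l, t, r, b) from each positive location
    # to the four box sides; returns the FCOS centerness target in [0, 1]:
    # sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b))).
    if reg_targets.numel() == 0:
        return reg_targets.new_zeros(len(reg_targets))
    left_right = reg_targets[:, [0, 2]]
    top_bottom = reg_targets[:, [1, 3]]
    ctrness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
              (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
    return torch.sqrt(ctrness)
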
Example #2
def fcos_losses(
    labels,
    reg_targets,
    logits_pred,
    reg_pred,
    ctrness_pred,
    focal_loss_alpha,
    focal_loss_gamma,
    iou_loss,
):
    num_classes = logits_pred.size(1)
    labels = labels.flatten()

    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=focal_loss_alpha,
        gamma=focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    ctrness_norm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    reg_loss = iou_loss(reg_pred, reg_targets, ctrness_targets) / ctrness_norm

    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

    losses = {
        "loss_fcos_cls": class_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss
    }
    return losses, {}
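
Each classification branch above divides a summed sigmoid focal loss by the average number of positive locations per GPU. sigmoid_focal_loss_jit is the TorchScript focal loss provided by fvcore; for reference, a plain sketch of the same computation (the standard RetinaNet formulation, with alpha/gamma as passed in the calls above) is:

import torch
from torch.nn import functional as F


def sigmoid_focal_loss(inputs, targets, alpha=0.25, gamma=2.0, reduction="none"):
    # Per-class sigmoid focal loss: FL = alpha_t * (1 - p_t) ** gamma * BCE.
    # Sketch of the standard formulation; the scripted version may differ cosmetically.
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    return loss
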
Example #3
def compute_loss(p1_heatmap_list, p3_heatmap_list, p1_logits, p3_logits):
    # gt_bitmasks = gt_bitmasks.float()
    # mask_logits = mask_logits.sigmoid()
    num_gpus = get_world_size()

    num_dice = (p1_heatmap_list**2).sum()
    num_dice = reduce_sum(p1_logits.new_tensor([num_dice])).item()
    num_dice = max(num_dice / num_gpus, 1.0)

    p1_loss = F.mse_loss(p1_heatmap_list, p1_logits,
                         reduction='sum') / num_dice

    num_dice = (p3_heatmap_list**2).sum()
    num_dice = reduce_sum(p3_logits.new_tensor([num_dice])).item()
    num_dice = max(num_dice / num_gpus, 1.0)

    p3_loss = F.mse_loss(p3_heatmap_list, p3_logits,
                         reduction='sum') / num_dice

    # loss = (p1_loss + p3_loss) / 2

    return p1_loss, p3_loss
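
As a rough sanity check, compute_loss above can be exercised on a single GPU with random tensors; the shapes below are purely illustrative (not taken from the original training code) and rely on the reduce_sum / get_world_size sketches given earlier:

# Hypothetical single-GPU smoke test; shapes are illustrative only.
import torch

p1_target = torch.rand(2, 17, 64, 64)
p3_target = torch.rand(2, 17, 32, 32)
p1_logits = torch.rand(2, 17, 64, 64, requires_grad=True)
p3_logits = torch.rand(2, 17, 32, 32, requires_grad=True)

p1_loss, p3_loss = compute_loss(p1_target, p3_target, p1_logits, p3_logits)
(p1_loss + p3_loss).backward()
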
Example #4
    def fcos_losses(self, labels, reg_targets, logits_pred, reg_pred,
                    ctrness_pred, gt_inds, mask_centers_targets):
        num_classes = logits_pred.size(1)
        assert num_classes == self.num_classes

        labels = labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            logits_pred,
            class_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg

        reg_pred = reg_pred[pos_inds]
        reg_targets = reg_targets[pos_inds]
        ctrness_pred = ctrness_pred[pos_inds]
        gt_inds = gt_inds[pos_inds]
        mask_center = mask_centers_targets[pos_inds]

        # TODO: needs modification
        # ctrness_targets = compute_ctrness_targets(reg_targets)
        # ctrness_targets_sum = ctrness_targets.sum()
        # loss_denorm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

        if pos_inds.numel() > 0:
            reg_loss = self.loc_loss_func(
                reg_pred,
                reg_targets,
                ctrness_pred,
                mask_center,
            )
        else:
            reg_loss = reg_pred.sum() * 0
        losses = {"loss_fcos_cls": class_loss, "loss_fcos_loc": reg_loss}
        extras = {
            "pos_inds": pos_inds,
            "gt_inds": gt_inds,
        }
        return losses, extras
    def SMInst_losses(
            self,
            labels,
            reg_targets,
            logits_pred,
            reg_pred,
            ctrness_pred,
            mask_pred,
            mask_targets
    ):
        num_classes = logits_pred.size(1)
        labels = labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            logits_pred,
            class_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg

        # print(mask_pred.size(), mask_tower_interm_outputs.size())

        reg_pred = reg_pred[pos_inds]
        reg_targets = reg_targets[pos_inds]
        ctrness_pred = ctrness_pred[pos_inds]
        mask_pred = mask_pred[pos_inds]
        # mask_activation_pred = mask_activation_pred[pos_inds]

        assert mask_pred.shape[0] == mask_targets.shape[0], \
            ("The number of positives should match between "
             "mask_pred (predictions) and mask_targets (targets).")

        ctrness_targets = compute_ctrness_targets(reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        ctrness_norm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

        reg_loss = self.iou_loss(
            reg_pred,
            reg_targets,
            ctrness_targets
        ) / ctrness_norm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred,
            ctrness_targets,
            reduction="sum"
        ) / num_pos_avg

        mask_targets_ = self.mask_encoding.encoder(mask_targets)
        mask_pred_, mask_pred_bin = self.mask_encoding.decoder(mask_pred, is_train=True)

        # compute the loss for the activation code as binary classification
        # activation_targets = (torch.abs(mask_targets_) > 1e-4) * 1.
        # activation_loss = F.binary_cross_entropy_with_logits(
        #     mask_activation_pred,
        #     activation_targets,
        #     reduction='none'
        # )
        # activation_loss = activation_loss.sum(1) * ctrness_targets
        # activation_loss = activation_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)

        # if self.thresh_with_active:
        #     mask_pred = mask_pred * torch.sigmoid(mask_activation_pred)

        total_mask_loss = 0.
        if self.loss_on_mask:
            # n_components predictions --> m*m mask predictions without sigmoid
            # as sigmoid function is combined in loss.
            # mask_pred_, mask_pred_bin = self.mask_encoding.decoder(mask_pred, is_train=True)
            if 'mask_mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(
                    mask_pred_,
                    mask_targets,
                    reduction='none'
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.mask_size ** 2, 1.0)
                total_mask_loss += mask_loss
            if 'mask_iou' in self.mask_loss_type:
                overlap_ = torch.sum(mask_pred_bin * 1. * mask_targets, 1)
                union_ = torch.sum((mask_pred_bin + mask_targets) >= 1., 1)
                iou_loss = (1. - overlap_ / (union_ + 1e-4)) * ctrness_targets * self.mask_size ** 2
                iou_loss = iou_loss.sum() / max(ctrness_norm * self.mask_size ** 2, 1.0)
                total_mask_loss += iou_loss
            if 'mask_difference' in self.mask_loss_type:
                w_ = torch.abs(mask_pred_bin * 1. - mask_targets * 1)  # 1's are inconsistent pixels in hd_maps
                md_loss = torch.sum(w_, 1) * ctrness_targets
                md_loss = md_loss.sum() / max(ctrness_norm * self.mask_size ** 2, 1.0)
                total_mask_loss += md_loss
        if self.loss_on_code:
            # m*m mask labels --> n_components encoding labels
            # mask_targets_ = self.mask_encoding.encoder(mask_targets)
            if 'mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(
                    mask_pred,
                    mask_targets_,
                    reduction='none'
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                if self.mask_sparse_weight > 0.:
                    if self.sparsity_loss_type == 'L1':
                        sparsity_loss = torch.sum(torch.abs(mask_pred), 1) * ctrness_targets
                        sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'L0':
                        w_ = (torch.abs(mask_targets_) >= 1e-2) * 1.  # the number of codes that are active
                        sparsity_loss = torch.sum(w_, 1) * ctrness_targets
                        sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_L1':
                        w_ = (torch.abs(mask_targets_) < 1e-2) * 1.  # inactive codes, put L1 regularization on them
                        sparsity_loss = torch.sum(torch.abs(mask_pred) * w_, 1) / torch.sum(w_, 1) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_L2':
                        w_ = (torch.abs(mask_targets_) < 1e-2) * 1.  # inactive codes, put L2 regularization on them
                        sparsity_loss = torch.sum(mask_pred ** 2. * w_, 1) / torch.sum(w_, 1) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_KL':
                        w_ = (torch.abs(mask_targets_) < 1e-2) * 1.  # inactive codes, put L2 regularization on them
                        kl_ = kl_divergence(
                            mask_pred,
                            self.kl_rho
                        )
                        sparsity_loss = torch.sum(kl_ * w_, 1) / torch.sum(w_, 1) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    else:
                        raise NotImplementedError
                total_mask_loss += mask_loss
            if 'smooth' in self.mask_loss_type:
                mask_loss = F.smooth_l1_loss(
                    mask_pred,
                    mask_targets_,
                    reduction='none'
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            if 'cosine' in self.mask_loss_type:
                mask_loss = loss_cos_sim(
                    mask_pred,
                    mask_targets_
                )
                mask_loss = mask_loss * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            if 'kl_softmax' in self.mask_loss_type:
                mask_loss = loss_kl_div_softmax(
                    mask_pred,
                    mask_targets_
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            elif 'kl_sigmoid' in self.mask_loss_type:
                mask_loss = loss_kl_div_sigmoid(
                    mask_pred,
                    mask_targets_
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            elif 'kl' in self.mask_loss_type:
                mask_loss = kl_divergence(
                    mask_pred,
                    self.kl_rho
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss

        losses = {
            "loss_SMInst_cls": class_loss,
            "loss_SMInst_loc": reg_loss,
            "loss_SMInst_ctr": ctrness_loss,
            "loss_SMInst_mask": total_mask_loss,
        }
        return losses, {}
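
The 'weighted_KL' sparsity branch and the 'kl' code-loss branch above call kl_divergence(mask_pred, self.kl_rho), whose implementation lives elsewhere in the repo. Purely as an assumption, a common sparse-autoencoder-style form of this penalty, applied elementwise to sigmoid-activated codes against a target activation rate kl_rho, would be:

import torch


def kl_divergence(codes, rho, eps=1e-6):
    # Assumed sketch, not necessarily the author's implementation:
    # elementwise KL(rho || rho_hat), where rho_hat is the sigmoid-activated
    # code value. Returns a tensor with the same shape as codes, matching the
    # way the call sites above reduce it with .sum(1).
    rho_hat = torch.sigmoid(codes).clamp(eps, 1.0 - eps)
    return (rho * torch.log(rho / rho_hat) +
            (1.0 - rho) * torch.log((1.0 - rho) / (1.0 - rho_hat)))
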
Example #6
    def fcos_losses(self, instances):
        num_classes = instances.logits_pred.size(1)
        assert num_classes == self.num_classes

        labels = instances.labels.flatten()
        gt_object = instances.gt_inds

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        neg_inds = torch.nonzero(labels == num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(instances.logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            instances.logits_pred,
            class_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="none",
        )  #/ num_pos_avg

        positive_diff = (
            1 - instances.logits_pred[class_target == 1].sigmoid()).abs()
        negative_diff = (
            0 - instances.logits_pred[class_target == 0].sigmoid()).abs()

        positive_mean = positive_diff.mean().detach()
        positive_std = positive_diff.std().detach()

        negative_mean = negative_diff.mean().detach()
        negative_std = negative_diff.std().detach()

        upper_true_loss = class_loss.flatten()[(class_target == 1).flatten()][
            (positive_diff >
             (positive_mean + positive_std))].sum() / num_pos_avg
        under_true_loss = class_loss.flatten()[(class_target == 1).flatten()][
            (positive_diff <=
             (positive_mean + positive_std))].sum() / num_pos_avg
        upper_false_loss = class_loss.flatten()[(class_target == 0).flatten()][
            (negative_diff >
             (negative_mean + negative_std))].sum() / num_pos_avg
        under_false_loss = class_loss.flatten()[(class_target == 0).flatten()][
            (negative_diff <=
             (negative_mean + negative_std))].sum() / num_pos_avg

        storage = get_event_storage()
        if storage.iter % 20 == 0:
            logger.info(
                "upper_true {}, under_true {} upper_false {} under_false {}".
                format((positive_diff > positive_mean + positive_std).sum(),
                       (positive_diff <= positive_mean + positive_std).sum(),
                       (negative_diff > negative_mean + negative_std).sum(),
                       (negative_diff <= negative_mean + negative_std).sum()))

        instances = instances[pos_inds]
        instances.pos_inds = pos_inds

        #assert (instances.gt_inds.unique() != gt_object.unique()).sum() == 0

        ctrness_targets = compute_ctrness_targets(instances.reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        loss_denorm = max(
            reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)
        instances.gt_ctrs = ctrness_targets

        if pos_inds.numel() > 0:
            reg_loss = self.loc_loss_func(instances.reg_pred,
                                          instances.reg_targets,
                                          ctrness_targets) / loss_denorm

            ctrness_loss = torch.nn.MSELoss(reduction="sum")(
                instances.ctrness_pred.sigmoid(),
                ctrness_targets) / num_pos_avg
        else:
            reg_loss = instances.reg_pred.sum() * 0
            ctrness_loss = instances.ctrness_pred.sum() * 0

        losses = {
            "loss_upper_true_cls": upper_true_loss,
            "loss_under_true_cls": under_true_loss,
            "loss_upper_false_cls": upper_false_loss,
            "loss_under_false_cls": under_false_loss,
            "loss_fcos_loc": reg_loss,
            "loss_fcos_ctr": ctrness_loss,
            #"loss_negative_identity_mean": negative_identity_mean_loss,
            #"loss_negative_identity_std": negative_identity_std_loss,
            #"loss_positive_identity": positive_identity_loss,
        }
        extras = {"instances": instances, "loss_denorm": loss_denorm}
        return extras, losses
Example #7
    def MEInst_losses(self, labels, reg_targets, logits_pred, reg_pred,
                      ctrness_pred, mask_pred, mask_targets):
        num_classes = logits_pred.size(1)
        labels = labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            logits_pred,
            class_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg

        reg_pred = reg_pred[pos_inds]
        reg_targets = reg_targets[pos_inds]
        ctrness_pred = ctrness_pred[pos_inds]
        mask_pred = mask_pred[pos_inds]
        assert mask_pred.shape[0] == mask_targets.shape[0], \
            ("The number of positives should match between "
             "mask_pred (predictions) and mask_targets (targets).")

        ctrness_targets = compute_ctrness_targets(reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        ctrness_norm = max(
            reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

        reg_loss = self.iou_loss(reg_pred, reg_targets,
                                 ctrness_targets) / ctrness_norm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

        if self.loss_on_mask:
            # n_components predictions --> m*m mask predictions without sigmoid
            # as sigmoid function is combined in loss.
            mask_pred = self.mask_encoding.decoder(mask_pred, is_train=True)
            mask_loss = self.mask_loss_func(mask_pred, mask_targets)
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.mask_size**2,
                                              1.0)
        else:
            # m*m mask labels --> n_components encoding labels
            mask_targets = self.mask_encoding.encoder(mask_targets)
            if self.mask_loss_type == 'mse':
                mask_loss = self.mask_loss_func(mask_pred, mask_targets)
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.dim_mask,
                                                  1.0)
            else:
                raise NotImplementedError

        losses = {
            "loss_MEInst_cls": class_loss,
            "loss_MEInst_loc": reg_loss,
            "loss_MEInst_ctr": ctrness_loss,
            "loss_MEInst_mask": mask_loss,
        }
        return losses, {}
    def DTInst_losses(self, labels, reg_targets, logits_pred, reg_pred,
                      ctrness_pred, mask_pred, mask_targets):
        num_classes = logits_pred.size(1)
        labels = labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            logits_pred,
            class_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg

        reg_pred = reg_pred[pos_inds]
        reg_targets = reg_targets[pos_inds]
        ctrness_pred = ctrness_pred[pos_inds]
        mask_pred = mask_pred[pos_inds]
        assert mask_pred.shape[0] == mask_targets.shape[0], \
            ("The number of positives should match between "
             "mask_pred (predictions) and mask_targets (targets).")

        ctrness_targets = compute_ctrness_targets(reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        ctrness_norm = max(
            reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

        reg_loss = self.iou_loss(reg_pred, reg_targets,
                                 ctrness_targets) / ctrness_norm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

        total_mask_loss = 0.
        dtm_pred_, binary_pred_ = self.mask_encoding.decoder(mask_pred,
                                                             is_train=True)
        code_targets, dtm_targets, weight_maps, hd_maps = self.mask_encoding.encoder(
            mask_targets)
        if self.loss_on_mask:
            if 'mask_mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(dtm_pred_,
                                       dtm_targets,
                                       reduction='none')
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += mask_loss
            if 'weighted_mask_mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(dtm_pred_,
                                       dtm_targets,
                                       reduction='none')
                mask_loss = torch.sum(mask_loss * weight_maps, 1) / torch.sum(
                    weight_maps, 1) * ctrness_targets * self.mask_size**2
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += mask_loss
            if 'mask_difference' in self.mask_loss_type:
                w_ = torch.abs(
                    binary_pred_ * 1. -
                    mask_targets * 1)  # 1's are inconsistent pixels in hd_maps
                md_loss = torch.sum(w_, 1) * ctrness_targets
                md_loss = md_loss.sum() / max(ctrness_norm * self.mask_size**2,
                                              1.0)
                total_mask_loss += md_loss
            if 'hd_one_side_binary' in self.mask_loss_type:  # the first attempt, not really accurate
                w_ = torch.abs(
                    binary_pred_ * 1. -
                    mask_targets * 1)  # 1's are inconsistent pixels in hd_maps
                hausdorff_loss = torch.sum(w_ * hd_maps, 1) / (torch.sum(
                    w_, 1) + 1e-4) * ctrness_targets * self.mask_size**2
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'hd_two_side_binary' in self.mask_loss_type:  # the first attempt, not really accurate
                w_ = torch.abs(
                    binary_pred_ * 1. -
                    mask_targets * 1)  # 1's are inconsistent pixels in hd_maps
                hausdorff_loss = torch.sum(
                    w_ *
                    (torch.clamp(dtm_pred_**2, -0.1, 1.1) +
                     torch.clamp(dtm_targets**2, -0.1, 1)), 1) / (torch.sum(
                         w_, 1) + 1e-4) * ctrness_targets * self.mask_size**2
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'hd_weighted_one_side_dtm' in self.mask_loss_type:
                dtm_diff = (
                    dtm_pred_ -
                    dtm_targets)**2  # 1's are inconsistent pixels in hd_maps
                hausdorff_loss = torch.sum(
                    dtm_diff * weight_maps * hd_maps,
                    1) / (torch.sum(weight_maps, 1) +
                          1e-4) * ctrness_targets * self.mask_size**2
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'hd_weighted_two_side_dtm' in self.mask_loss_type:
                dtm_diff = (
                    dtm_pred_ -
                    dtm_targets)**2  # 1's are inconsistent pixels in hd_maps
                hausdorff_loss = torch.sum(
                    dtm_diff * weight_maps * (dtm_pred_**2 + dtm_targets**2),
                    1) / (torch.sum(weight_maps, 1) +
                          1e-4) * ctrness_targets * self.mask_size**2
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'hd_one_side_dtm' in self.mask_loss_type:
                dtm_diff = (
                    dtm_pred_ -
                    dtm_targets)**2  # 1's are inconsistent pixels in hd_maps
                hausdorff_loss = torch.sum(dtm_diff * hd_maps,
                                           1) * ctrness_targets
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'hd_two_side_dtm' in self.mask_loss_type:
                dtm_diff = (
                    dtm_pred_ -
                    dtm_targets)**2  # 1's are inconsistent pixels in hd_maps
                hausdorff_loss = torch.sum(
                    dtm_diff *
                    (torch.clamp(dtm_pred_, -1.1, 1.1)**2 + dtm_targets**2),
                    1) * ctrness_targets
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'contour_dice' in self.mask_loss_type:
                pred_contour = (dtm_pred_ + 0.9 < 0.55) * 1. * (
                    0.5 <= dtm_pred_ + 0.9
                )  # contour pixels with 0.05 tolerance
                target_contour = (0. <= dtm_targets) * 1. * (dtm_targets <
                                                             0.05)
                # pred_contour = 0.5 <= dtm_pred_ + 0.9 < 0.55  # contour pixels with 0.05 tolerance
                # target_contour = 0. <= dtm_targets < 0.05
                overlap_ = torch.sum(pred_contour * 2. * target_contour, 1)
                union_ = torch.sum(pred_contour**2, 1) + torch.sum(
                    target_contour**2, 1)
                dice_loss = (
                    1. - overlap_ /
                    (union_ + 1e-4)) * ctrness_targets * self.mask_size**2
                dice_loss = dice_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += dice_loss
            if 'mask_dice' in self.mask_loss_type:
                overlap_ = torch.sum(binary_pred_ * 2. * mask_targets, 1)
                union_ = torch.sum(binary_pred_**2, 1) + torch.sum(
                    mask_targets**2, 1)
                dice_loss = (
                    1. - overlap_ /
                    (union_ + 1e-4)) * ctrness_targets * self.mask_size**2
                dice_loss = dice_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += dice_loss
        if self.loss_on_code:
            # m*m mask labels --> n_components encoding labels
            if 'mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(mask_pred,
                                       code_targets,
                                       reduction='none')
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                if self.mask_sparse_weight > 0.:
                    if self.sparsity_loss_type == 'L1':
                        sparsity_loss = torch.sum(torch.abs(mask_pred),
                                                  1) * ctrness_targets
                        sparsity_loss = sparsity_loss.sum() / max(
                            ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_L1':
                        w_ = (
                            torch.abs(code_targets) < 1e-4
                        ) * 1.  # inactive codes, put L1 regularization on them
                        sparsity_loss = torch.sum(torch.abs(mask_pred) * w_, 1) / torch.sum(w_, 1) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(
                            ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_L2':
                        w_ = (
                            torch.abs(code_targets) < 1e-4
                        ) * 1.  # inactive codes, put L2 regularization on them
                        sparsity_loss = torch.sum(mask_pred ** 2. * w_, 1) / torch.sum(w_, 1) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(
                            ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    else:
                        raise NotImplementedError
                total_mask_loss += mask_loss
            if 'smooth' in self.mask_loss_type:
                mask_loss = F.smooth_l1_loss(mask_pred,
                                             code_targets,
                                             reduction='none')
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            if 'cosine' in self.mask_loss_type:
                mask_loss = loss_cos_sim(mask_pred, code_targets)
                mask_loss = mask_loss * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            if 'kl_softmax' in self.mask_loss_type:
                mask_loss = loss_kl_div_softmax(mask_pred, code_targets)
                mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss

        losses = {
            "loss_DTInst_cls": class_loss,
            "loss_DTInst_loc": reg_loss,
            "loss_DTInst_ctr": ctrness_loss,
            "loss_DTInst_mask": total_mask_loss
        }
        return losses, {}
Example #9
def fcos_losses(
    labels,
    reg_targets,
    bezier_targets,
    logits_pred,
    reg_pred,
    bezier_pred,
    ctrness_pred,
    focal_loss_alpha,
    focal_loss_gamma,
    iou_loss,
):
    num_classes = logits_pred.size(1)
    labels = labels.flatten()

    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=focal_loss_alpha,
        gamma=focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    reg_pred = reg_pred[pos_inds]
    bezier_pred = bezier_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    bezier_targets = bezier_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]

    ious, gious = compute_ious(reg_pred, reg_targets)
    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    loss_denorm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    if pos_inds.numel() > 0:
        reg_loss = iou_loss(ious, gious, ctrness_targets) / loss_denorm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

        bezier_loss = F.smooth_l1_loss(bezier_pred,
                                       bezier_targets,
                                       reduction="none")
        bezier_loss = ((bezier_loss.mean(dim=-1) * ctrness_targets).sum() /
                       loss_denorm)
    else:
        reg_loss = reg_pred.sum() * 0
        bezier_loss = bezier_pred.sum() * 0
        ctrness_loss = ctrness_pred.sum() * 0

    losses = {
        "loss_fcos_cls": class_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss,
        "loss_fcos_bezier": bezier_loss,
    }
    return losses
    def DTMRInst_losses(self, labels, reg_targets, logits_pred, reg_pred,
                        ctrness_pred, mask_pred, mask_pred_decoded_list,
                        mask_targets):
        num_classes = logits_pred.size(1)
        labels = labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            logits_pred,
            class_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg

        reg_pred = reg_pred[pos_inds]
        reg_targets = reg_targets[pos_inds]
        ctrness_pred = ctrness_pred[pos_inds]
        mask_pred = mask_pred[pos_inds]

        for i in range(len(mask_pred_decoded_list)):
            mask_pred_decoded_list[i] = mask_pred_decoded_list[i][pos_inds]

        assert mask_pred.shape[0] == mask_targets.shape[0], \
            ("The number of positives should match between "
             "mask_pred (predictions) and mask_targets (targets).")

        ctrness_targets = compute_ctrness_targets(reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        ctrness_norm = max(
            reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

        reg_loss = self.iou_loss(reg_pred, reg_targets,
                                 ctrness_targets) / ctrness_norm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

        total_mask_loss = 0.
        # _, binary_pred_ = self.mask_encoding.decoder(mask_pred, is_train=True)  # from sparse coefficients to DTMs/images
        # code_targets, dtm_targets, weight_maps, hd_maps = self.mask_encoding.encoder(mask_targets)
        code_targets, dtm_targets, weight_maps, _ = self.mask_encoding.encoder(
            mask_targets)

        if self.loss_on_mask:
            if 'mask_mse' in self.mask_loss_type:
                mask_loss = 0
                for mask_pred_ in mask_pred_decoded_list:
                    _loss = F.mse_loss(mask_pred_,
                                       mask_targets,
                                       reduction='none')
                    _loss = _loss.sum(1) * ctrness_targets
                    _loss = _loss.sum() / max(ctrness_norm * self.mask_size**2,
                                              1.0)
                    mask_loss += _loss
                total_mask_loss += mask_loss
            if 'weighted_mask_mse' in self.mask_loss_type:
                mask_loss = 0
                for mask_pred_ in mask_pred_decoded_list:
                    _loss = F.mse_loss(mask_pred_,
                                       mask_targets,
                                       reduction='none')
                    _loss = torch.sum(_loss * weight_maps, 1) * ctrness_targets
                    _loss = _loss.sum() / torch.sum(weight_maps) / max(
                        ctrness_norm * self.mask_size**2, 1.0)
                    mask_loss += _loss
                total_mask_loss += mask_loss
            if 'mask_dice' in self.mask_loss_type:
                # This is to use all the output to calculate the mask loss
                dice_loss = 0
                for mask_pred_ in mask_pred_decoded_list:
                    overlap_ = torch.sum(mask_pred_ * 2. * mask_targets, 1)
                    union_ = torch.sum(mask_pred_**2, 1) + torch.sum(
                        mask_targets**2, 1)
                    _loss = (
                        1. - overlap_ /
                        (union_ + 1e-5)) * ctrness_targets * self.mask_size**2
                    _loss = _loss.sum() / max(ctrness_norm * self.mask_size**2,
                                              1.0)
                    dice_loss += _loss

                # This is to just use the last output to calculate the mask loss
                mask_pred_ = mask_pred_decoded_list[-1]
                overlap_ = torch.sum(mask_pred_ * 2. * mask_targets, 1)
                union_ = torch.sum(mask_pred_**2, 1) + torch.sum(
                    mask_targets**2, 1)
                _loss = (1. - overlap_ /
                         (union_ + 1e-5)) * ctrness_targets * self.mask_size**2
                dice_loss = _loss.sum() / max(ctrness_norm * self.mask_size**2,
                                              1.0)

                total_mask_loss += dice_loss

        if self.loss_on_code:
            # m*m mask labels --> n_components encoding labels
            if 'mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(mask_pred,
                                       code_targets,
                                       reduction='none')
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                if self.mask_sparse_weight > 0.:
                    if self.sparsity_loss_type == 'L1':
                        sparsity_loss = torch.sum(torch.abs(mask_pred),
                                                  1) * ctrness_targets
                        sparsity_loss = sparsity_loss.sum() / max(
                            ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_L1':
                        w_ = (
                            torch.abs(code_targets) < 1e-3
                        ) * 1.  # inactive codes, put L1 regularization on them
                        sparsity_loss = torch.sum(
                            torch.abs(mask_pred) * w_, 1) * ctrness_targets
                        sparsity_loss = sparsity_loss.sum() / torch.sum(
                            w_) / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_L2':
                        w_ = (
                            torch.abs(code_targets) < 1e-3
                        ) * 1.  # inactive codes, put L2 regularization on them
                        sparsity_loss = torch.sum(mask_pred ** 2. * w_, 1) / torch.sum(w_, 1) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(
                            ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    else:
                        raise NotImplementedError
                total_mask_loss += mask_loss
            if 'smooth' in self.mask_loss_type:
                mask_loss = F.smooth_l1_loss(mask_pred,
                                             code_targets,
                                             reduction='none')
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            if 'cosine' in self.mask_loss_type:
                mask_loss = loss_cos_sim(mask_pred, code_targets)
                mask_loss = mask_loss * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            if 'kl_softmax' in self.mask_loss_type:
                mask_loss = loss_kl_div_softmax(mask_pred, code_targets)
                mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss

        losses = {
            "loss_DTMRInst_cls": class_loss,
            "loss_DTMRInst_loc": reg_loss,
            "loss_DTMRInst_ctr": ctrness_loss,
            "loss_DTMRInst_mask": total_mask_loss
        }
        return losses, {}
Example #11
def compute_loss_softmax(gt_bitmasks, mask_logits, num_loss, num_instances,
                         direction, direction_mask_logits, gt_keypoint,
                         max_ranges, distance_norm):
    assert not torch.isnan(mask_logits).any()
    assert not torch.isnan(direction).any()
    assert not torch.isnan(direction_mask_logits).any()
    # direction_mask_logits = direction_mask_logits.detach()
    N, K, H, W = gt_bitmasks.size()
    # gt_bitmasks = gt_bitmasks.float()
    num_gpus = get_world_size()
    assert not (num_loss == 0).any()
    loss_weight = 1 / num_loss  #TODO num_loss can be 0
    sum_loss_weight = loss_weight.sum()
    assert sum_loss_weight != 0
    loss_weight = loss_weight[:, None].repeat(1, 17).flatten()

    gt_bitmasks = gt_bitmasks.reshape(N * K, H * W)
    mask_logits = mask_logits.reshape(N * K, H * W)
    gt_bitmasks_visible_mask = gt_bitmasks.sum(dim=1).bool()
    # assert gt_bitmasks_visible_mask.sum()!=0 #TODO AssertionError
    if gt_bitmasks_visible_mask.sum() != 0:
        loss_weight = loss_weight[gt_bitmasks_visible_mask]
        mask_logits = mask_logits[gt_bitmasks_visible_mask]
        gt_bitmasks = gt_bitmasks[gt_bitmasks_visible_mask]
        mask_logits = F.log_softmax(mask_logits, dim=1)

        total_instances = reduce_sum(mask_logits.new_tensor([num_instances
                                                             ])).item()
        gpu_balance_factor = num_instances / total_instances

        loss = (-mask_logits[gt_bitmasks])
        loss = (loss * loss_weight).sum() / 17
        loss = (loss / sum_loss_weight) * gpu_balance_factor

        max_ranges = max_ranges[:, None].repeat(
            1, 17).flatten()[gt_bitmasks_visible_mask]
        gt_keypoint = gt_keypoint[:, :, [0, 1]]

        N, H, W, K, _ = direction_mask_logits.size()
        direction = direction - gt_keypoint[:, None, None, :, :]
        direction = direction.permute(0, 3, 1, 2, 4).reshape(N * 17, H, W, 2)
        direction = direction[gt_bitmasks_visible_mask]
        direction = (direction[:, :, :, 0]**2 +
                     direction[:, :, :, 1]**2).sqrt()[:, :, :, None]
        assert (max_ranges != 0).all()
        direction = direction / max_ranges[:, None, None, None]
        direction = direction * distance_norm
        direction = (direction.sigmoid() - 0.5) * 2
        direction_mask_logits = direction_mask_logits.permute(
            0, 3, 1, 2, 4).reshape(N * 17, H, W, 1)
        direction_mask_logits = direction_mask_logits[gt_bitmasks_visible_mask]
        direction = direction * direction_mask_logits
        direction = direction.flatten(start_dim=1).sum(dim=1)
        direction = direction * loss_weight
        assert distance_norm != 0
        direction_loss = (direction / sum_loss_weight *
                          gpu_balance_factor) / distance_norm
        direction_loss = direction_loss.sum()
        assert not torch.isnan(direction_loss).any()
        assert not torch.isnan(loss).any()
        return loss, direction_loss
    else:
        print('gt_bitmasks_visible_mask.sum()==0')
        total_instances = reduce_sum(mask_logits.new_tensor([num_instances
                                                             ])).item()
        loss = (mask_logits.sum() + direction.sum() +
                direction_mask_logits.sum())
        loss = loss * 0.0
        return loss, loss