Example 1
    def SMInst_losses(
            self,
            labels,
            reg_targets,
            logits_pred,
            reg_pred,
            ctrness_pred,
            mask_pred,
            mask_targets
    ):
        num_classes = logits_pred.size(1)
        labels = labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
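        # sum the positive count across GPUs so every rank normalizes its
        # loss by the same average number of positives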
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
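        # rows for negative locations stay all-zero, so the focal loss treats
        # every class as background at those positions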
        class_target = torch.zeros_like(logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            logits_pred,
            class_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg

        reg_pred = reg_pred[pos_inds]
        reg_targets = reg_targets[pos_inds]
        ctrness_pred = ctrness_pred[pos_inds]
        mask_pred = mask_pred[pos_inds]
        # mask_activation_pred = mask_activation_pred[pos_inds]

        assert mask_pred.shape[0] == mask_targets.shape[0], \
            "mask_pred and mask_targets must contain the same number of positive samples"

        ctrness_targets = compute_ctrness_targets(reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        ctrness_norm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

        reg_loss = self.iou_loss(
            reg_pred,
            reg_targets,
            ctrness_targets
        ) / ctrness_norm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred,
            ctrness_targets,
            reduction="sum"
        ) / num_pos_avg

        mask_targets_ = self.mask_encoding.encoder(mask_targets)
        mask_pred_, mask_pred_bin = self.mask_encoding.decoder(mask_pred, is_train=True)
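        # the decoder returns both the continuous reconstruction (mask_pred_) and a
        # binarized mask (mask_pred_bin); the IoU and mask-difference terms below use the latter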

        # compute the loss for the activation code as binary classification
        # activation_targets = (torch.abs(mask_targets_) > 1e-4) * 1.
        # activation_loss = F.binary_cross_entropy_with_logits(
        #     mask_activation_pred,
        #     activation_targets,
        #     reduction='none'
        # )
        # activation_loss = activation_loss.sum(1) * ctrness_targets
        # activation_loss = activation_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)

        # if self.thresh_with_active:
        #     mask_pred = mask_pred * torch.sigmoid(mask_activation_pred)

        total_mask_loss = 0.
        if self.loss_on_mask:
            # n_components predictions --> m*m mask predictions without sigmoid,
            # since the sigmoid is folded into the loss
            # mask_pred_, mask_pred_bin = self.mask_encoding.decoder(mask_pred, is_train=True)
            if 'mask_mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(
                    mask_pred_,
                    mask_targets,
                    reduction='none'
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.mask_size ** 2, 1.0)
                total_mask_loss += mask_loss
            if 'mask_iou' in self.mask_loss_type:
                overlap_ = torch.sum(mask_pred_bin * 1. * mask_targets, 1)
                union_ = torch.sum((mask_pred_bin + mask_targets) >= 1., 1)
                iou_loss = (1. - overlap_ / (union_ + 1e-4)) * ctrness_targets * self.mask_size ** 2
                iou_loss = iou_loss.sum() / max(ctrness_norm * self.mask_size ** 2, 1.0)
                total_mask_loss += iou_loss
            if 'mask_difference' in self.mask_loss_type:
                w_ = torch.abs(mask_pred_bin * 1. - mask_targets * 1.)  # 1 marks pixels where prediction and target disagree
                md_loss = torch.sum(w_, 1) * ctrness_targets
                md_loss = md_loss.sum() / max(ctrness_norm * self.mask_size ** 2, 1.0)
                total_mask_loss += md_loss
        if self.loss_on_code:
            # m*m mask labels --> n_components encoding labels
            # mask_targets_ = self.mask_encoding.encoder(mask_targets)
            if 'mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(
                    mask_pred,
                    mask_targets_,
                    reduction='none'
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                if self.mask_sparse_weight > 0.:
                    if self.sparsity_loss_type == 'L1':
                        sparsity_loss = torch.sum(torch.abs(mask_pred), 1) * ctrness_targets
                        sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'L0':
                        w_ = (torch.abs(mask_targets_) >= 1e-2) * 1.  # the number of codes that are active
                        sparsity_loss = torch.sum(w_, 1) * ctrness_targets
                        sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_L1':
                        w_ = (torch.abs(mask_targets_) < 1e-2) * 1.  # inactive codes, put L1 regularization on them
                        sparsity_loss = torch.sum(torch.abs(mask_pred) * w_, 1) / torch.sum(w_, 1) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_L2':
                        w_ = (torch.abs(mask_targets_) < 1e-2) * 1.  # inactive codes, put L2 regularization on them
                        sparsity_loss = torch.sum(mask_pred ** 2. * w_, 1) / torch.sum(w_, 1) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_KL':
                        w_ = (torch.abs(mask_targets_) < 1e-2) * 1.  # inactive codes, put KL regularization on them
                        kl_ = kl_divergence(
                            mask_pred,
                            self.kl_rho
                        )
                        sparsity_loss = torch.sum(kl_ * w_, 1) / torch.sum(w_, 1) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    else:
                        raise NotImplementedError
                total_mask_loss += mask_loss
            if 'smooth' in self.mask_loss_type:
                mask_loss = F.smooth_l1_loss(
                    mask_pred,
                    mask_targets_,
                    reduction='none'
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            if 'cosine' in self.mask_loss_type:
                mask_loss = loss_cos_sim(
                    mask_pred,
                    mask_targets_
                )
                mask_loss = mask_loss * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            if 'kl_softmax' in self.mask_loss_type:
                mask_loss = loss_kl_div_softmax(
                    mask_pred,
                    mask_targets_
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            elif 'kl_sigmoid' in self.mask_loss_type:
                mask_loss = loss_kl_div_sigmoid(
                    mask_pred,
                    mask_targets_
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            elif 'kl' in self.mask_loss_type:
                mask_loss = kl_divergence(
                    mask_pred,
                    self.kl_rho
                )
                mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss

        losses = {
            "loss_SMInst_cls": class_loss,
            "loss_SMInst_loc": reg_loss,
            "loss_SMInst_ctr": ctrness_loss,
            "loss_SMInst_mask": total_mask_loss,
        }
        return losses, {}
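A minimal single-process sketch of the positive-sample bookkeeping shared by all
four loss functions. reduce_sum and get_world_size are detectron2-style
distributed helpers; they are stubbed here (an assumption, so the snippet runs
standalone) to show how num_pos_avg keeps the loss scale identical on every rank:

    import torch

    def get_world_size() -> int:
        return 1  # single-process stand-in for the distributed helper

    def reduce_sum(t: torch.Tensor) -> torch.Tensor:
        return t  # with one process, the all-reduce is the identity

    num_classes = 80
    labels = torch.tensor([0, 3, 80, 80, 2])  # label == num_classes marks background
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    total_num_pos = reduce_sum(pos_inds.new_tensor([pos_inds.numel()])).item()
    num_pos_avg = max(total_num_pos / get_world_size(), 1.0)
    print(pos_inds.tolist(), num_pos_avg)  # [0, 1, 4] 3.0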
Example 2
    def MEInst_losses(self, labels, reg_targets, logits_pred, reg_pred,
                      ctrness_pred, mask_pred, mask_targets):
        num_classes = logits_pred.size(1)
        labels = labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            logits_pred,
            class_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg

        reg_pred = reg_pred[pos_inds]
        reg_targets = reg_targets[pos_inds]
        ctrness_pred = ctrness_pred[pos_inds]
        mask_pred = mask_pred[pos_inds]
        assert mask_pred.shape[0] == mask_targets.shape[0], \
            "mask_pred and mask_targets must contain the same number of positive samples"

        ctrness_targets = compute_ctrness_targets(reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        ctrness_norm = max(
            reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

        reg_loss = self.iou_loss(reg_pred, reg_targets,
                                 ctrness_targets) / ctrness_norm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

        total_mask_loss = 0.
        if self.loss_on_mask:
            # n_components predictions --> m*m mask predictions without sigmoid,
            # since the sigmoid is folded into the loss
            mask_pred_ = self.mask_encoding.decoder(mask_pred, is_train=True)
            mask_loss = self.bce(mask_pred_, mask_targets)
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.mask_size**2,
                                              1.0)
            total_mask_loss += mask_loss
        if self.loss_on_code:
            # m*m mask labels --> n_components encoding labels
            mask_targets_ = self.mask_encoding.encoder(mask_targets)
            if 'mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(mask_pred,
                                       mask_targets_,
                                       reduction='none')
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.dim_mask,
                                                  1.0)
                total_mask_loss += mask_loss
            if 'smooth' in self.mask_loss_type:
                mask_loss = F.smooth_l1_loss(mask_pred,
                                             mask_targets_,
                                             reduction='none')
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.dim_mask,
                                                  1.0)
                total_mask_loss += mask_loss
            if 'cosine' in self.mask_loss_type:
                mask_loss = loss_cos_sim(mask_pred, mask_targets_)
                mask_loss = mask_loss * ctrness_targets * self.dim_mask
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.dim_mask,
                                                  1.0)
                total_mask_loss += mask_loss
            if 'kl_softmax' in self.mask_loss_type:
                mask_loss = loss_kl_div_softmax(mask_pred, mask_targets_)
                mask_loss = mask_loss.sum(1) * ctrness_targets * self.dim_mask
                mask_loss = mask_loss.sum() / max(ctrness_norm * self.dim_mask,
                                                  1.0)
                total_mask_loss += mask_loss

        losses = {
            "loss_MEInst_cls": class_loss,
            "loss_MEInst_loc": reg_loss,
            "loss_MEInst_ctr": ctrness_loss,
            "loss_MEInst_mask": total_mask_loss,
        }
        return losses, {}
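compute_ctrness_targets is called but not defined in these examples; below is a
sketch consistent with the FCOS centerness definition, assuming reg_targets
stores the (left, top, right, bottom) distances from each location to its box:

    import torch

    def compute_ctrness_targets(reg_targets: torch.Tensor) -> torch.Tensor:
        # centerness = sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))
        if len(reg_targets) == 0:
            return reg_targets.new_zeros(len(reg_targets))
        left_right = reg_targets[:, [0, 2]]
        top_bottom = reg_targets[:, [1, 3]]
        ctrness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
                  (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
        return torch.sqrt(ctrness)

    # a perfectly centered location scores 1, an off-center one scores lower
    print(compute_ctrness_targets(torch.tensor([[2., 2., 2., 2.],
                                                [1., 4., 4., 4.]])))  # tensor([1.0, 0.5])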
Example 3
    def DTInst_losses(self, labels, reg_targets, logits_pred, reg_pred,
                      ctrness_pred, mask_pred, mask_targets):
        num_classes = logits_pred.size(1)
        labels = labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            logits_pred,
            class_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg

        reg_pred = reg_pred[pos_inds]
        reg_targets = reg_targets[pos_inds]
        ctrness_pred = ctrness_pred[pos_inds]
        mask_pred = mask_pred[pos_inds]
        assert mask_pred.shape[0] == mask_targets.shape[0], \
            "mask_pred and mask_targets must contain the same number of positive samples"

        ctrness_targets = compute_ctrness_targets(reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        ctrness_norm = max(
            reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

        reg_loss = self.iou_loss(reg_pred, reg_targets,
                                 ctrness_targets) / ctrness_norm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

        total_mask_loss = 0.
        dtm_pred_, binary_pred_ = self.mask_encoding.decoder(mask_pred,
                                                             is_train=True)
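        # the encoder below yields sparse code targets, distance-transform maps,
        # per-pixel weight maps, and Hausdorff-distance maps for the ground-truth masks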
        code_targets, dtm_targets, weight_maps, hd_maps = self.mask_encoding.encoder(
            mask_targets)
        if self.loss_on_mask:
            if 'mask_mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(dtm_pred_,
                                       dtm_targets,
                                       reduction='none')
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += mask_loss
            if 'weighted_mask_mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(dtm_pred_,
                                       dtm_targets,
                                       reduction='none')
                mask_loss = torch.sum(mask_loss * weight_maps, 1) / torch.sum(
                    weight_maps, 1) * ctrness_targets * self.mask_size**2
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += mask_loss
            if 'mask_difference' in self.mask_loss_type:
                w_ = torch.abs(binary_pred_ * 1. - mask_targets * 1.)  # 1 marks pixels where prediction and target disagree
                md_loss = torch.sum(w_, 1) * ctrness_targets
                md_loss = md_loss.sum() / max(ctrness_norm * self.mask_size**2,
                                              1.0)
                total_mask_loss += md_loss
            if 'hd_one_side_binary' in self.mask_loss_type:  # the first attempt, not really accurate
                w_ = torch.abs(binary_pred_ * 1. - mask_targets * 1.)  # 1 marks inconsistent pixels, used to gate hd_maps
                hausdorff_loss = torch.sum(w_ * hd_maps, 1) / (torch.sum(
                    w_, 1) + 1e-4) * ctrness_targets * self.mask_size**2
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'hd_two_side_binary' in self.mask_loss_type:  # the first attempt, not really accurate
                w_ = torch.abs(binary_pred_ * 1. - mask_targets * 1.)  # 1 marks inconsistent pixels, used to gate the squared-DTM terms
                hausdorff_loss = torch.sum(
                    w_ *
                    (torch.clamp(dtm_pred_**2, -0.1, 1.1) +
                     torch.clamp(dtm_targets**2, -0.1, 1)), 1) / (torch.sum(
                         w_, 1) + 1e-4) * ctrness_targets * self.mask_size**2
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'hd_weighted_one_side_dtm' in self.mask_loss_type:
                dtm_diff = (dtm_pred_ - dtm_targets)**2  # squared per-pixel DTM error
                hausdorff_loss = torch.sum(
                    dtm_diff * weight_maps * hd_maps,
                    1) / (torch.sum(weight_maps, 1) +
                          1e-4) * ctrness_targets * self.mask_size**2
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'hd_weighted_two_side_dtm' in self.mask_loss_type:
                dtm_diff = (dtm_pred_ - dtm_targets)**2  # squared per-pixel DTM error
                hausdorff_loss = torch.sum(
                    dtm_diff * weight_maps * (dtm_pred_**2 + dtm_targets**2),
                    1) / (torch.sum(weight_maps, 1) +
                          1e-4) * ctrness_targets * self.mask_size**2
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'hd_one_side_dtm' in self.mask_loss_type:
                dtm_diff = (dtm_pred_ - dtm_targets)**2  # squared per-pixel DTM error
                hausdorff_loss = torch.sum(dtm_diff * hd_maps,
                                           1) * ctrness_targets
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'hd_two_side_dtm' in self.mask_loss_type:
                dtm_diff = (dtm_pred_ - dtm_targets)**2  # squared per-pixel DTM error
                hausdorff_loss = torch.sum(
                    dtm_diff *
                    (torch.clamp(dtm_pred_, -1.1, 1.1)**2 + dtm_targets**2),
                    1) * ctrness_targets
                hausdorff_loss = hausdorff_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += hausdorff_loss
            if 'contour_dice' in self.mask_loss_type:
                # contour pixels with 0.05 tolerance; chained comparisons such as
                # 0.5 <= x < 0.55 are not valid elementwise on tensors, hence the products
                pred_contour = (0.5 <= dtm_pred_ + 0.9) * 1. * (dtm_pred_ + 0.9 < 0.55)
                target_contour = (0. <= dtm_targets) * 1. * (dtm_targets < 0.05)
                overlap_ = torch.sum(pred_contour * 2. * target_contour, 1)
                union_ = torch.sum(pred_contour**2, 1) + torch.sum(
                    target_contour**2, 1)
                dice_loss = (
                    1. - overlap_ /
                    (union_ + 1e-4)) * ctrness_targets * self.mask_size**2
                dice_loss = dice_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += dice_loss
            if 'mask_dice' in self.mask_loss_type:
                overlap_ = torch.sum(binary_pred_ * 2. * mask_targets, 1)
                union_ = torch.sum(binary_pred_**2, 1) + torch.sum(
                    mask_targets**2, 1)
                dice_loss = (
                    1. - overlap_ /
                    (union_ + 1e-4)) * ctrness_targets * self.mask_size**2
                dice_loss = dice_loss.sum() / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                total_mask_loss += dice_loss
        if self.loss_on_code:
            # m*m mask labels --> n_components encoding labels
            if 'mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(mask_pred,
                                       code_targets,
                                       reduction='none')
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                if self.mask_sparse_weight > 0.:
                    if self.sparsity_loss_type == 'L1':
                        sparsity_loss = torch.sum(torch.abs(mask_pred),
                                                  1) * ctrness_targets
                        sparsity_loss = sparsity_loss.sum() / max(
                            ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_L1':
                        w_ = (
                            torch.abs(code_targets) < 1e-4
                        ) * 1.  # inactive codes, put L1 regularization on them
                        sparsity_loss = torch.sum(torch.abs(mask_pred) * w_, 1) / torch.sum(w_, 1) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(
                            ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_L2':
                        w_ = (
                            torch.abs(code_targets) < 1e-4
                        ) * 1.  # inactive codes, put L2 regularization on them
                        sparsity_loss = torch.sum(mask_pred ** 2. * w_, 1) / torch.sum(w_, 1) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(
                            ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    else:
                        raise NotImplementedError
                total_mask_loss += mask_loss
            if 'smooth' in self.mask_loss_type:
                mask_loss = F.smooth_l1_loss(mask_pred,
                                             code_targets,
                                             reduction='none')
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            if 'cosine' in self.mask_loss_type:
                mask_loss = loss_cos_sim(mask_pred, code_targets)
                mask_loss = mask_loss * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            if 'kl_softmax' in self.mask_loss_type:
                mask_loss = loss_kl_div_softmax(mask_pred, code_targets)
                mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss

        losses = {
            "loss_DTInst_cls": class_loss,
            "loss_DTInst_loc": reg_loss,
            "loss_DTInst_ctr": ctrness_loss,
            "loss_DTInst_mask": total_mask_loss
        }
        return losses, {}
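A toy check of the dice term used in the 'mask_dice' branch above, on flattened
binary masks (the values are made up for illustration):

    import torch

    binary_pred_ = torch.tensor([[1., 1., 0., 0.]])
    mask_targets = torch.tensor([[1., 0., 0., 0.]])
    overlap_ = torch.sum(binary_pred_ * 2. * mask_targets, 1)               # 2 * |P ∩ T|
    union_ = torch.sum(binary_pred_**2, 1) + torch.sum(mask_targets**2, 1)  # |P| + |T|
    dice_loss = 1. - overlap_ / (union_ + 1e-4)
    print(dice_loss)  # ~0.33: one of the three foreground pixels disagrees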
Example 4
    def DTMRInst_losses(self, labels, reg_targets, logits_pred, reg_pred,
                        ctrness_pred, mask_pred, mask_pred_decoded_list,
                        mask_targets):
        num_classes = logits_pred.size(1)
        labels = labels.flatten()

        pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
        num_pos_local = pos_inds.numel()
        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
        num_pos_avg = max(total_num_pos / num_gpus, 1.0)

        # prepare one_hot
        class_target = torch.zeros_like(logits_pred)
        class_target[pos_inds, labels[pos_inds]] = 1

        class_loss = sigmoid_focal_loss_jit(
            logits_pred,
            class_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg

        reg_pred = reg_pred[pos_inds]
        reg_targets = reg_targets[pos_inds]
        ctrness_pred = ctrness_pred[pos_inds]
        mask_pred = mask_pred[pos_inds]

        # keep only the positive locations in each intermediate decoded mask
        for i in range(len(mask_pred_decoded_list)):
            mask_pred_decoded_list[i] = mask_pred_decoded_list[i][pos_inds]

        assert mask_pred.shape[0] == mask_targets.shape[0], \
            "mask_pred and mask_targets must contain the same number of positive samples"

        ctrness_targets = compute_ctrness_targets(reg_targets)
        ctrness_targets_sum = ctrness_targets.sum()
        ctrness_norm = max(
            reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

        reg_loss = self.iou_loss(reg_pred, reg_targets,
                                 ctrness_targets) / ctrness_norm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

        total_mask_loss = 0.
        # _, binary_pred_ = self.mask_encoding.decoder(mask_pred, is_train=True)  # from sparse coefficients to DTMs/images
        # code_targets, dtm_targets, weight_maps, hd_maps = self.mask_encoding.encoder(mask_targets)
        code_targets, dtm_targets, weight_maps, _ = self.mask_encoding.encoder(
            mask_targets)

        if self.loss_on_mask:
            if 'mask_mse' in self.mask_loss_type:
                mask_loss = 0
                for mask_pred_ in mask_pred_decoded_list:
                    _loss = F.mse_loss(mask_pred_,
                                       mask_targets,
                                       reduction='none')
                    _loss = _loss.sum(1) * ctrness_targets
                    _loss = _loss.sum() / max(ctrness_norm * self.mask_size**2,
                                              1.0)
                    mask_loss += _loss
                total_mask_loss += mask_loss
            if 'weighted_mask_mse' in self.mask_loss_type:
                mask_loss = 0
                for mask_pred_ in mask_pred_decoded_list:
                    _loss = F.mse_loss(mask_pred_,
                                       mask_targets,
                                       reduction='none')
                    _loss = torch.sum(_loss * weight_maps, 1) * ctrness_targets
                    _loss = _loss.sum() / torch.sum(weight_maps) / max(
                        ctrness_norm * self.mask_size**2, 1.0)
                    mask_loss += _loss
                total_mask_loss += mask_loss
            if 'mask_dice' in self.mask_loss_type:
                # This is to use all the output to calculate the mask loss
                dice_loss = 0
                for mask_pred_ in mask_pred_decoded_list:
                    overlap_ = torch.sum(mask_pred_ * 2. * mask_targets, 1)
                    union_ = torch.sum(mask_pred_**2, 1) + torch.sum(
                        mask_targets**2, 1)
                    _loss = (
                        1. - overlap_ /
                        (union_ + 1e-5)) * ctrness_targets * self.mask_size**2
                    _loss = _loss.sum() / max(ctrness_norm * self.mask_size**2,
                                              1.0)
                    dice_loss += _loss

                # This is to just use the last output to calculate the mask loss
                mask_pred_ = mask_pred_decoded_list[-1]
                overlap_ = torch.sum(mask_pred_ * 2. * mask_targets, 1)
                union_ = torch.sum(mask_pred_**2, 1) + torch.sum(
                    mask_targets**2, 1)
                _loss = (1. - overlap_ /
                         (union_ + 1e-5)) * ctrness_targets * self.mask_size**2
                # note: this reassignment discards the loop total above, so only
                # the last decoded output actually contributes to the dice loss
                dice_loss = _loss.sum() / max(ctrness_norm * self.mask_size**2,
                                              1.0)

                total_mask_loss += dice_loss

        if self.loss_on_code:
            # m*m mask labels --> n_components encoding labels
            if 'mse' in self.mask_loss_type:
                mask_loss = F.mse_loss(mask_pred,
                                       code_targets,
                                       reduction='none')
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                if self.mask_sparse_weight > 0.:
                    if self.sparsity_loss_type == 'L1':
                        sparsity_loss = torch.sum(torch.abs(mask_pred),
                                                  1) * ctrness_targets
                        sparsity_loss = sparsity_loss.sum() / max(
                            ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_L1':
                        w_ = (
                            torch.abs(code_targets) < 1e-3
                        ) * 1.  # inactive codes, put L1 regularization on them
                        sparsity_loss = torch.sum(
                            torch.abs(mask_pred) * w_, 1) * ctrness_targets
                        sparsity_loss = sparsity_loss.sum() / torch.sum(
                            w_) / max(ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + sparsity_loss * self.mask_sparse_weight
                    elif self.sparsity_loss_type == 'weighted_L2':
                        w_ = (
                            torch.abs(code_targets) < 1e-3
                        ) * 1.  # inactive codes, put L2 regularization on them
                        sparsity_loss = torch.sum(mask_pred ** 2. * w_, 1) / torch.sum(w_, 1) \
                                        * ctrness_targets * self.num_codes
                        sparsity_loss = sparsity_loss.sum() / max(
                            ctrness_norm * self.num_codes, 1.0)
                        mask_loss = mask_loss * self.mask_loss_weight + \
                                    sparsity_loss * self.mask_sparse_weight
                    else:
                        raise NotImplementedError
                total_mask_loss += mask_loss
            if 'smooth' in self.mask_loss_type:
                mask_loss = F.smooth_l1_loss(mask_pred,
                                             code_targets,
                                             reduction='none')
                mask_loss = mask_loss.sum(1) * ctrness_targets
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            if 'cosine' in self.mask_loss_type:
                mask_loss = loss_cos_sim(mask_pred, code_targets)
                mask_loss = mask_loss * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss
            if 'kl_softmax' in self.mask_loss_type:
                mask_loss = loss_kl_div_softmax(mask_pred, code_targets)
                mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
                mask_loss = mask_loss.sum() / max(
                    ctrness_norm * self.num_codes, 1.0)
                total_mask_loss += mask_loss

        losses = {
            "loss_DTMRInst_cls": class_loss,
            "loss_DTMRInst_loc": reg_loss,
            "loss_DTMRInst_ctr": ctrness_loss,
            "loss_DTMRInst_mask": total_mask_loss
        }
        return losses, {}
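A minimal sketch of the 'weighted_L1' sparsity term that Examples 1, 3, and 4
share: the L1 penalty is applied only to code slots that are inactive in the
target encoding (toy values; threshold as in Example 1):

    import torch

    code_targets = torch.tensor([[0.9, 0.0, 0.0, -0.4]])
    mask_pred = torch.tensor([[0.8, 0.3, -0.2, -0.5]])
    w_ = (torch.abs(code_targets) < 1e-2) * 1.  # 1 where the target code is ~0
    sparsity = torch.sum(torch.abs(mask_pred) * w_, 1) / torch.sum(w_, 1)
    print(sparsity)  # tensor([0.2500]): mean |prediction| over the two inactive slots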