def SMInst_losses(
        self, labels, reg_targets, logits_pred, reg_pred,
        ctrness_pred, mask_pred, mask_targets):
    """Compute SMInst training losses.

    Produces the focal classification loss, IoU box-regression loss,
    centerness BCE loss, and a configurable mask loss supervised either in
    mask space (``loss_on_mask``), in sparse-code space (``loss_on_code``),
    or both, selected by substrings of ``self.mask_loss_type``.

    Args:
        labels: per-location class labels; value ``num_classes`` marks background.
        reg_targets: per-location box regression targets.
        logits_pred: classification logits of shape (N, num_classes).
        reg_pred: box regression predictions.
        ctrness_pred: centerness logits.
        mask_pred: predicted sparse mask codes.
        mask_targets: ground-truth masks for the positive locations.

    Returns:
        Tuple ``(losses, {})`` where ``losses`` maps loss names to scalars.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()

    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Average the positive count across GPUs so normalization matches in
    # distributed training; floor at 1 to avoid division by zero.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # Keep only positive locations for the remaining losses.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    mask_pred = mask_pred[pos_inds]

    # BUGFIX: the assert message used to be `print(...)`, which evaluates to
    # None, so the AssertionError carried no message. Use a plain string.
    assert mask_pred.shape[0] == mask_targets.shape[0], \
        ("The number(positive) should be equal between "
         "masks_pred(prediction) and mask_targets(target).")

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    ctrness_norm = max(
        reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    reg_loss = self.iou_loss(
        reg_pred, reg_targets, ctrness_targets
    ) / ctrness_norm

    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred, ctrness_targets, reduction="sum"
    ) / num_pos_avg

    # Encode GT masks into code space and decode predictions back into mask
    # space (soft map + binarized map) so both domains can be supervised.
    mask_targets_ = self.mask_encoding.encoder(mask_targets)
    mask_pred_, mask_pred_bin = self.mask_encoding.decoder(
        mask_pred, is_train=True)

    total_mask_loss = 0.
    if self.loss_on_mask:
        # n_components predictions --> m*m mask predictions without sigmoid
        # as sigmoid function is combined in loss.
        if 'mask_mse' in self.mask_loss_type:
            mask_loss = F.mse_loss(mask_pred_, mask_targets, reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += mask_loss
        if 'mask_iou' in self.mask_loss_type:
            overlap_ = torch.sum(mask_pred_bin * 1. * mask_targets, 1)
            union_ = torch.sum((mask_pred_bin + mask_targets) >= 1., 1)
            iou_loss = (1. - overlap_ / (union_ + 1e-4)) \
                * ctrness_targets * self.mask_size ** 2
            iou_loss = iou_loss.sum() / max(
                ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += iou_loss
        if 'mask_difference' in self.mask_loss_type:
            # 1's mark pixels where binarized prediction and target disagree.
            w_ = torch.abs(mask_pred_bin * 1. - mask_targets * 1)
            md_loss = torch.sum(w_, 1) * ctrness_targets
            md_loss = md_loss.sum() / max(
                ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += md_loss

    if self.loss_on_code:
        # m*m mask labels --> n_components encoding labels
        if 'mse' in self.mask_loss_type:
            mask_loss = F.mse_loss(mask_pred, mask_targets_, reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            if self.mask_sparse_weight > 0.:
                if self.sparsity_loss_type == 'L1':
                    sparsity_loss = torch.sum(
                        torch.abs(mask_pred), 1) * ctrness_targets
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'L0':
                    # the number of codes that are active
                    w_ = (torch.abs(mask_targets_) >= 1e-2) * 1.
                    sparsity_loss = torch.sum(w_, 1) * ctrness_targets
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_L1':
                    # inactive codes, put L1 regularization on them
                    w_ = (torch.abs(mask_targets_) < 1e-2) * 1.
                    sparsity_loss = torch.sum(torch.abs(mask_pred) * w_, 1) \
                        / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_L2':
                    # inactive codes, put L2 regularization on them
                    w_ = (torch.abs(mask_targets_) < 1e-2) * 1.
                    sparsity_loss = torch.sum(mask_pred ** 2. * w_, 1) \
                        / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_KL':
                    # inactive codes, put KL regularization on them
                    w_ = (torch.abs(mask_targets_) < 1e-2) * 1.
                    kl_ = kl_divergence(mask_pred, self.kl_rho)
                    sparsity_loss = torch.sum(kl_ * w_, 1) \
                        / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                else:
                    raise NotImplementedError
            total_mask_loss += mask_loss
        if 'smooth' in self.mask_loss_type:
            mask_loss = F.smooth_l1_loss(
                mask_pred, mask_targets_, reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        if 'cosine' in self.mask_loss_type:
            mask_loss = loss_cos_sim(mask_pred, mask_targets_)
            mask_loss = mask_loss * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        # The elif chain is deliberate: 'kl' is a substring of 'kl_softmax'
        # and 'kl_sigmoid', so only one KL variant may fire.
        if 'kl_softmax' in self.mask_loss_type:
            mask_loss = loss_kl_div_softmax(mask_pred, mask_targets_)
            mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        elif 'kl_sigmoid' in self.mask_loss_type:
            mask_loss = loss_kl_div_sigmoid(mask_pred, mask_targets_)
            mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        elif 'kl' in self.mask_loss_type:
            mask_loss = kl_divergence(mask_pred, self.kl_rho)
            mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss

    losses = {
        "loss_SMInst_cls": class_loss,
        "loss_SMInst_loc": reg_loss,
        "loss_SMInst_ctr": ctrness_loss,
        "loss_SMInst_mask": total_mask_loss,
    }
    return losses, {}
def MEInst_losses(self, labels, reg_targets, logits_pred, reg_pred,
                  ctrness_pred, mask_pred, mask_targets):
    """Compute MEInst training losses.

    Focal classification loss, IoU box-regression loss, centerness BCE
    loss, and a mask loss supervised in mask space (``loss_on_mask``,
    via the encoding decoder + BCE) and/or code space (``loss_on_code``,
    selected by substrings of ``self.mask_loss_type``).

    Args:
        labels: per-location class labels; value ``num_classes`` marks background.
        reg_targets: per-location box regression targets.
        logits_pred: classification logits of shape (N, num_classes).
        reg_pred: box regression predictions.
        ctrness_pred: centerness logits.
        mask_pred: predicted mask encodings (dim ``self.dim_mask``).
        mask_targets: ground-truth masks for the positive locations.

    Returns:
        Tuple ``(losses, {})`` where ``losses`` maps loss names to scalars.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()

    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Average positives across GPUs; floor at 1 to avoid division by zero.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # Keep only positive locations for the remaining losses.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    mask_pred = mask_pred[pos_inds]

    # BUGFIX: the assert message used to be `print(...)`, which evaluates to
    # None, so the AssertionError carried no message. Use a plain string.
    assert mask_pred.shape[0] == mask_targets.shape[0], \
        ("The number(positive) should be equal between "
         "masks_pred(prediction) and mask_targets(target).")

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    ctrness_norm = max(
        reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    reg_loss = self.iou_loss(reg_pred, reg_targets,
                             ctrness_targets) / ctrness_norm

    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

    total_mask_loss = 0.
    if self.loss_on_mask:
        # n_components predictions --> m*m mask predictions without sigmoid
        # as sigmoid function is combined in loss.
        mask_pred_ = self.mask_encoding.decoder(mask_pred, is_train=True)
        mask_loss = self.bce(mask_pred_, mask_targets)
        mask_loss = mask_loss.sum(1) * ctrness_targets
        mask_loss = mask_loss.sum() / max(
            ctrness_norm * self.mask_size ** 2, 1.0)
        total_mask_loss += mask_loss

    if self.loss_on_code:
        # m*m mask labels --> n_components encoding labels
        mask_targets_ = self.mask_encoding.encoder(mask_targets)
        if 'mse' in self.mask_loss_type:
            mask_loss = F.mse_loss(mask_pred, mask_targets_, reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.dim_mask, 1.0)
            total_mask_loss += mask_loss
        if 'smooth' in self.mask_loss_type:
            mask_loss = F.smooth_l1_loss(
                mask_pred, mask_targets_, reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.dim_mask, 1.0)
            total_mask_loss += mask_loss
        if 'cosine' in self.mask_loss_type:
            mask_loss = loss_cos_sim(mask_pred, mask_targets_)
            mask_loss = mask_loss * ctrness_targets * self.dim_mask
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.dim_mask, 1.0)
            total_mask_loss += mask_loss
        if 'kl_softmax' in self.mask_loss_type:
            mask_loss = loss_kl_div_softmax(mask_pred, mask_targets_)
            mask_loss = mask_loss.sum(1) * ctrness_targets * self.dim_mask
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.dim_mask, 1.0)
            total_mask_loss += mask_loss

    losses = {
        "loss_MEInst_cls": class_loss,
        "loss_MEInst_loc": reg_loss,
        "loss_MEInst_ctr": ctrness_loss,
        "loss_MEInst_mask": total_mask_loss,
    }
    return losses, {}
def DTInst_losses(self, labels, reg_targets, logits_pred, reg_pred,
                  ctrness_pred, mask_pred, mask_targets):
    """Compute DTInst training losses.

    Focal classification loss, IoU box-regression loss, centerness BCE
    loss, and a mask loss built from distance-transform maps (DTMs). The
    encoder yields code targets, DTM targets, weight maps and Hausdorff
    maps; the decoder yields the predicted DTM and its binarization.
    Individual mask-loss terms are selected by substrings of
    ``self.mask_loss_type``.

    Args:
        labels: per-location class labels; value ``num_classes`` marks background.
        reg_targets: per-location box regression targets.
        logits_pred: classification logits of shape (N, num_classes).
        reg_pred: box regression predictions.
        ctrness_pred: centerness logits.
        mask_pred: predicted sparse mask codes.
        mask_targets: ground-truth masks for the positive locations.

    Returns:
        Tuple ``(losses, {})`` where ``losses`` maps loss names to scalars.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()

    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Average positives across GPUs; floor at 1 to avoid division by zero.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # Keep only positive locations for the remaining losses.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    mask_pred = mask_pred[pos_inds]

    # BUGFIX: the assert message used to be `print(...)`, which evaluates to
    # None, so the AssertionError carried no message. Use a plain string.
    assert mask_pred.shape[0] == mask_targets.shape[0], \
        ("The number(positive) should be equal between "
         "masks_pred(prediction) and mask_targets(target).")

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    ctrness_norm = max(
        reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    reg_loss = self.iou_loss(reg_pred, reg_targets,
                             ctrness_targets) / ctrness_norm

    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

    total_mask_loss = 0.
    # Decode predictions into (soft DTM, binarized mask); encode targets into
    # (codes, DTM, per-pixel weights, Hausdorff distance maps).
    dtm_pred_, binary_pred_ = self.mask_encoding.decoder(
        mask_pred, is_train=True)
    code_targets, dtm_targets, weight_maps, hd_maps = \
        self.mask_encoding.encoder(mask_targets)

    if self.loss_on_mask:
        if 'mask_mse' in self.mask_loss_type:
            mask_loss = F.mse_loss(dtm_pred_, dtm_targets, reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += mask_loss
        if 'weighted_mask_mse' in self.mask_loss_type:
            mask_loss = F.mse_loss(dtm_pred_, dtm_targets, reduction='none')
            mask_loss = torch.sum(mask_loss * weight_maps, 1) / torch.sum(
                weight_maps, 1) * ctrness_targets * self.mask_size ** 2
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += mask_loss
        if 'mask_difference' in self.mask_loss_type:
            # 1's are inconsistent pixels in hd_maps
            w_ = torch.abs(binary_pred_ * 1. - mask_targets * 1)
            md_loss = torch.sum(w_, 1) * ctrness_targets
            md_loss = md_loss.sum() / max(
                ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += md_loss
        if 'hd_one_side_binary' in self.mask_loss_type:
            # the first attempt, not really accurate
            w_ = torch.abs(binary_pred_ * 1. - mask_targets * 1)
            hausdorff_loss = torch.sum(w_ * hd_maps, 1) / (
                torch.sum(w_, 1) + 1e-4) \
                * ctrness_targets * self.mask_size ** 2
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'hd_two_side_binary' in self.mask_loss_type:
            # the first attempt, not really accurate
            w_ = torch.abs(binary_pred_ * 1. - mask_targets * 1)
            hausdorff_loss = torch.sum(
                w_ * (torch.clamp(dtm_pred_ ** 2, -0.1, 1.1) +
                      torch.clamp(dtm_targets ** 2, -0.1, 1)),
                1) / (torch.sum(w_, 1) + 1e-4) \
                * ctrness_targets * self.mask_size ** 2
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'hd_weighted_one_side_dtm' in self.mask_loss_type:
            dtm_diff = (dtm_pred_ - dtm_targets) ** 2
            hausdorff_loss = torch.sum(
                dtm_diff * weight_maps * hd_maps,
                1) / (torch.sum(weight_maps, 1) + 1e-4) \
                * ctrness_targets * self.mask_size ** 2
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'hd_weighted_two_side_dtm' in self.mask_loss_type:
            dtm_diff = (dtm_pred_ - dtm_targets) ** 2
            hausdorff_loss = torch.sum(
                dtm_diff * weight_maps * (dtm_pred_ ** 2 + dtm_targets ** 2),
                1) / (torch.sum(weight_maps, 1) + 1e-4) \
                * ctrness_targets * self.mask_size ** 2
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'hd_one_side_dtm' in self.mask_loss_type:
            dtm_diff = (dtm_pred_ - dtm_targets) ** 2
            hausdorff_loss = torch.sum(dtm_diff * hd_maps,
                                       1) * ctrness_targets
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'hd_two_side_dtm' in self.mask_loss_type:
            dtm_diff = (dtm_pred_ - dtm_targets) ** 2
            hausdorff_loss = torch.sum(
                dtm_diff * (torch.clamp(dtm_pred_, -1.1, 1.1) ** 2 +
                            dtm_targets ** 2),
                1) * ctrness_targets
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'contour_dice' in self.mask_loss_type:
            # contour pixels with 0.05 tolerance (intended band:
            # 0.5 <= dtm_pred_ + 0.9 < 0.55, 0. <= dtm_targets < 0.05).
            pred_contour = (dtm_pred_ + 0.9 < 0.55) * 1. \
                * (0.5 <= dtm_pred_ + 0.9)
            # NOTE(review): the second factor repeats the first condition;
            # presumably it should be (0. <= dtm_targets) — confirm intent.
            target_contour = (dtm_targets < 0.05) * 1. * (dtm_targets < 0.05)
            overlap_ = torch.sum(pred_contour * 2. * target_contour, 1)
            union_ = torch.sum(pred_contour ** 2, 1) + torch.sum(
                target_contour ** 2, 1)
            dice_loss = (1. - overlap_ / (union_ + 1e-4)) \
                * ctrness_targets * self.mask_size ** 2
            dice_loss = dice_loss.sum() / max(
                ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += dice_loss
        if 'mask_dice' in self.mask_loss_type:
            overlap_ = torch.sum(binary_pred_ * 2. * mask_targets, 1)
            union_ = torch.sum(binary_pred_ ** 2, 1) + torch.sum(
                mask_targets ** 2, 1)
            dice_loss = (1. - overlap_ / (union_ + 1e-4)) \
                * ctrness_targets * self.mask_size ** 2
            dice_loss = dice_loss.sum() / max(
                ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += dice_loss

    if self.loss_on_code:
        # m*m mask labels --> n_components encoding labels
        if 'mse' in self.mask_loss_type:
            mask_loss = F.mse_loss(mask_pred, code_targets, reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            if self.mask_sparse_weight > 0.:
                if self.sparsity_loss_type == 'L1':
                    sparsity_loss = torch.sum(
                        torch.abs(mask_pred), 1) * ctrness_targets
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_L1':
                    # inactive codes, put L1 regularization on them
                    w_ = (torch.abs(code_targets) < 1e-4) * 1.
                    sparsity_loss = torch.sum(torch.abs(mask_pred) * w_, 1) \
                        / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_L2':
                    # inactive codes, put L2 regularization on them
                    w_ = (torch.abs(code_targets) < 1e-4) * 1.
                    sparsity_loss = torch.sum(mask_pred ** 2. * w_, 1) \
                        / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                else:
                    raise NotImplementedError
            total_mask_loss += mask_loss
        if 'smooth' in self.mask_loss_type:
            mask_loss = F.smooth_l1_loss(
                mask_pred, code_targets, reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        if 'cosine' in self.mask_loss_type:
            mask_loss = loss_cos_sim(mask_pred, code_targets)
            mask_loss = mask_loss * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        if 'kl_softmax' in self.mask_loss_type:
            mask_loss = loss_kl_div_softmax(mask_pred, code_targets)
            mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss

    losses = {
        "loss_DTInst_cls": class_loss,
        "loss_DTInst_loc": reg_loss,
        "loss_DTInst_ctr": ctrness_loss,
        "loss_DTInst_mask": total_mask_loss
    }
    return losses, {}
def DTMRInst_losses(self, labels, reg_targets, logits_pred, reg_pred,
                    ctrness_pred, mask_pred, mask_pred_decoded_list,
                    mask_targets):
    """Compute DTMRInst training losses.

    Like DTInst, but the decoder side produces a *list* of decoded mask
    predictions (one per refinement stage); mask-space losses iterate over
    that list. Terms are selected by substrings of ``self.mask_loss_type``.

    Args:
        labels: per-location class labels; value ``num_classes`` marks background.
        reg_targets: per-location box regression targets.
        logits_pred: classification logits of shape (N, num_classes).
        reg_pred: box regression predictions.
        ctrness_pred: centerness logits.
        mask_pred: predicted sparse mask codes.
        mask_pred_decoded_list: decoded mask predictions, one tensor per stage.
        mask_targets: ground-truth masks for the positive locations.

    Returns:
        Tuple ``(losses, {})`` where ``losses`` maps loss names to scalars.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()

    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Average positives across GPUs; floor at 1 to avoid division by zero.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # Keep only positive locations for the remaining losses (including every
    # decoded refinement stage).
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    mask_pred = mask_pred[pos_inds]
    for i in range(len(mask_pred_decoded_list)):
        mask_pred_decoded_list[i] = mask_pred_decoded_list[i][pos_inds]

    # BUGFIX: the assert message used to be `print(...)`, which evaluates to
    # None, so the AssertionError carried no message. Use a plain string.
    assert mask_pred.shape[0] == mask_targets.shape[0], \
        ("The number(positive) should be equal between "
         "masks_pred(prediction) and mask_targets(target).")

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    ctrness_norm = max(
        reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    reg_loss = self.iou_loss(reg_pred, reg_targets,
                             ctrness_targets) / ctrness_norm

    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

    total_mask_loss = 0.
    # Encode targets into (codes, DTM, per-pixel weights); the Hausdorff
    # maps are unused in this variant.
    code_targets, dtm_targets, weight_maps, _ = self.mask_encoding.encoder(
        mask_targets)

    if self.loss_on_mask:
        if 'mask_mse' in self.mask_loss_type:
            mask_loss = 0
            for mask_pred_ in mask_pred_decoded_list:
                _loss = F.mse_loss(mask_pred_, mask_targets,
                                   reduction='none')
                _loss = _loss.sum(1) * ctrness_targets
                _loss = _loss.sum() / max(
                    ctrness_norm * self.mask_size ** 2, 1.0)
                mask_loss += _loss
            total_mask_loss += mask_loss
        if 'weighted_mask_mse' in self.mask_loss_type:
            mask_loss = 0
            for mask_pred_ in mask_pred_decoded_list:
                _loss = F.mse_loss(mask_pred_, mask_targets,
                                   reduction='none')
                _loss = torch.sum(_loss * weight_maps, 1) * ctrness_targets
                _loss = _loss.sum() / torch.sum(weight_maps) / max(
                    ctrness_norm * self.mask_size ** 2, 1.0)
                mask_loss += _loss
            total_mask_loss += mask_loss
        if 'mask_dice' in self.mask_loss_type:
            # This is to use all the output to calculate the mask loss.
            # NOTE(review): this accumulated value is overwritten below and
            # never used — only the last stage's dice loss is applied.
            dice_loss = 0
            for mask_pred_ in mask_pred_decoded_list:
                overlap_ = torch.sum(mask_pred_ * 2. * mask_targets, 1)
                union_ = torch.sum(mask_pred_ ** 2, 1) + torch.sum(
                    mask_targets ** 2, 1)
                _loss = (1. - overlap_ / (union_ + 1e-5)) \
                    * ctrness_targets * self.mask_size ** 2
                _loss = _loss.sum() / max(
                    ctrness_norm * self.mask_size ** 2, 1.0)
                dice_loss += _loss
            # This is to just use the last output to calculate the mask loss.
            mask_pred_ = mask_pred_decoded_list[-1]
            overlap_ = torch.sum(mask_pred_ * 2. * mask_targets, 1)
            union_ = torch.sum(mask_pred_ ** 2, 1) + torch.sum(
                mask_targets ** 2, 1)
            _loss = (1. - overlap_ / (union_ + 1e-5)) \
                * ctrness_targets * self.mask_size ** 2
            dice_loss = _loss.sum() / max(
                ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += dice_loss

    if self.loss_on_code:
        # m*m mask labels --> n_components encoding labels
        if 'mse' in self.mask_loss_type:
            mask_loss = F.mse_loss(mask_pred, code_targets, reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            if self.mask_sparse_weight > 0.:
                if self.sparsity_loss_type == 'L1':
                    sparsity_loss = torch.sum(
                        torch.abs(mask_pred), 1) * ctrness_targets
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_L1':
                    # inactive codes, put L1 regularization on them
                    w_ = (torch.abs(code_targets) < 1e-3) * 1.
                    sparsity_loss = torch.sum(
                        torch.abs(mask_pred) * w_, 1) * ctrness_targets
                    sparsity_loss = sparsity_loss.sum() / torch.sum(
                        w_) / max(ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_L2':
                    # inactive codes, put L2 regularization on them
                    w_ = (torch.abs(code_targets) < 1e-3) * 1.
                    sparsity_loss = torch.sum(mask_pred ** 2. * w_, 1) \
                        / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                else:
                    raise NotImplementedError
            total_mask_loss += mask_loss
        if 'smooth' in self.mask_loss_type:
            mask_loss = F.smooth_l1_loss(
                mask_pred, code_targets, reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        if 'cosine' in self.mask_loss_type:
            mask_loss = loss_cos_sim(mask_pred, code_targets)
            mask_loss = mask_loss * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        if 'kl_softmax' in self.mask_loss_type:
            mask_loss = loss_kl_div_softmax(mask_pred, code_targets)
            mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss

    losses = {
        "loss_DTMRInst_cls": class_loss,
        "loss_DTMRInst_loc": reg_loss,
        "loss_DTMRInst_ctr": ctrness_loss,
        "loss_DTMRInst_mask": total_mask_loss
    }
    return losses, {}