def fcos_losses(self, instances):
    """Compute the FCOS classification, box-regression and centerness losses.

    Args:
        instances: an Instances-like container holding per-location fields
            ``logits_pred``, ``labels``, ``reg_pred``, ``reg_targets`` and
            ``ctrness_pred`` (schema defined by the caller — verify there).

    Returns:
        (extras, losses): ``extras`` carries the positive-only ``instances``
        and the centerness normalizer ``loss_denorm``; ``losses`` is a dict
        of the three scalar loss terms.
    """
    num_classes = instances.logits_pred.size(1)
    assert num_classes == self.num_classes
    labels = instances.labels.flatten()
    # Foreground locations: label == num_classes encodes "background".
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Sum the positive count over all workers so every worker divides by the
    # same (average) denominator; clamp to 1 to avoid division by zero.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot classification targets for the focal loss
    class_target = torch.zeros_like(instances.logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        instances.logits_pred,
        class_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # Keep only foreground locations for regression / centerness supervision.
    instances = instances[pos_inds]
    instances.pos_inds = pos_inds

    ctrness_targets = compute_ctrness_targets(instances.reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    # Cross-worker average of the centerness mass; used to normalize the
    # (centerness-weighted) regression loss.
    loss_denorm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)
    instances.gt_ctrs = ctrness_targets

    if pos_inds.numel() > 0:
        reg_loss = self.loc_loss_func(
            instances.reg_pred,
            instances.reg_targets,
            ctrness_targets
        ) / loss_denorm
        ctrness_loss = F.binary_cross_entropy_with_logits(
            instances.ctrness_pred,
            ctrness_targets,
            reduction="sum"
        ) / num_pos_avg
    else:
        # No positives: emit zero losses that stay connected to the graph.
        reg_loss = instances.reg_pred.sum() * 0
        ctrness_loss = instances.ctrness_pred.sum() * 0

    losses = {
        "loss_fcos_cls": class_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss
    }
    extras = {
        "instances": instances,
        "loss_denorm": loss_denorm
    }
    return extras, losses
def fcos_losses(
    labels,
    reg_targets,
    logits_pred,
    reg_pred,
    ctrness_pred,
    focal_loss_alpha,
    focal_loss_gamma,
    iou_loss,
):
    """Functional FCOS loss: classification, box regression, centerness.

    Returns a ``(losses, extras)`` pair where ``losses`` maps loss names to
    scalar tensors and ``extras`` is an empty dict.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()

    # Foreground = every location whose label is not the background id
    # (background is encoded as ``num_classes``).
    foreground = torch.nonzero(labels != num_classes).squeeze(1)

    # Average the foreground count across distributed workers so each worker
    # normalizes by the same denominator; never divide by less than 1.
    world_size = get_world_size()
    global_pos = reduce_sum(foreground.new_tensor([foreground.numel()])).item()
    pos_normalizer = max(global_pos / world_size, 1.0)

    # One-hot classification targets for the focal loss.
    one_hot = torch.zeros_like(logits_pred)
    one_hot[foreground, labels[foreground]] = 1
    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        one_hot,
        alpha=focal_loss_alpha,
        gamma=focal_loss_gamma,
        reduction="sum",
    ) / pos_normalizer

    # Regression and centerness are supervised at foreground locations only.
    reg_pred = reg_pred[foreground]
    reg_targets = reg_targets[foreground]
    ctrness_pred = ctrness_pred[foreground]

    ctrness_targets = compute_ctrness_targets(reg_targets)
    # Cross-worker average of the total centerness mass, clamped away from 0.
    ctrness_norm = max(
        reduce_sum(ctrness_targets.sum()).item() / world_size, 1e-6)

    reg_loss = iou_loss(reg_pred, reg_targets, ctrness_targets) / ctrness_norm
    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred, ctrness_targets, reduction="sum") / pos_normalizer

    return {
        "loss_fcos_cls": class_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss
    }, {}
def compute_loss(p1_heatmap_list, p3_heatmap_list, p1_logits, p3_logits):
    """Return normalized MSE heatmap losses for the P1 and P3 levels.

    Each level's summed MSE is divided by the cross-worker average of the
    target's squared energy, clamped to at least 1.
    """
    world_size = get_world_size()

    def level_loss(target, logits):
        # Cross-worker sum of the squared target values, averaged per worker;
        # the max() guard prevents division by zero on empty targets.
        energy = (target ** 2).sum()
        energy = reduce_sum(logits.new_tensor([energy])).item()
        denom = max(energy / world_size, 1.0)
        return F.mse_loss(target, logits, reduction='sum') / denom

    p1 = level_loss(p1_heatmap_list, p1_logits)
    p3 = level_loss(p3_heatmap_list, p3_logits)
    return p1, p3
def fcos_losses(self, labels, reg_targets, logits_pred, reg_pred,
                ctrness_pred, gt_inds, mask_centers_targets):
    """FCOS classification + localization losses for a mask-center variant.

    Unlike the plain FCOS loss, the localization loss here also receives the
    raw ``ctrness_pred`` and per-location mask centers; the centerness BCE
    term is disabled (see the commented-out block below).

    Returns:
        (losses, extras): ``losses`` has the cls and loc terms; ``extras``
        carries the positive indices and their matched gt indices.
    """
    num_classes = logits_pred.size(1)
    assert num_classes == self.num_classes
    labels = labels.flatten()
    # Foreground locations; label == num_classes encodes "background".
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Average positive count over all workers for a shared normalizer.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot classification targets for the focal loss
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # Restrict all remaining supervision to foreground locations.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    gt_inds = gt_inds[pos_inds]
    mask_center = mask_centers_targets[pos_inds]

    # NOTE: original comment said "needs modification" — the centerness
    # normalizer below is intentionally disabled; the loc loss is unnormalized.
    # ctrness_targets = compute_ctrness_targets(reg_targets)
    # ctrness_targets_sum = ctrness_targets.sum()
    # loss_denorm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    if pos_inds.numel() > 0:
        # loc_loss_func consumes the centerness prediction and mask centers
        # directly — semantics defined by the configured loss; verify there.
        reg_loss = self.loc_loss_func(
            reg_pred,
            reg_targets,
            ctrness_pred,
            mask_center,
        )
    else:
        # No positives: zero loss that stays connected to the graph.
        reg_loss = reg_pred.sum() * 0

    losses = {"loss_fcos_cls": class_loss, "loss_fcos_loc": reg_loss}
    extras = {
        "pos_inds": pos_inds,
        "gt_inds": gt_inds,
    }
    return losses, extras
def SMInst_losses(
        self,
        labels,
        reg_targets,
        logits_pred,
        reg_pred,
        ctrness_pred,
        mask_pred,
        mask_targets
):
    """SMInst losses: FCOS detection terms plus sparse-code mask terms.

    ``mask_pred`` are predicted sparse codes; ``self.mask_encoding`` maps
    between m*m binary masks and those codes.  Which mask terms are active
    is selected by substring matching on ``self.mask_loss_type``.

    Returns:
        (losses, extras): dict of scalar losses and an empty extras dict.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()
    # Foreground locations (label == num_classes is background).
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Cross-worker average positive count, clamped to avoid divide-by-zero.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot classification targets for the focal loss
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # print(mask_pred.size(), mask_tower_interm_outputs.size())
    # Restrict all remaining supervision to foreground locations.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    mask_pred = mask_pred[pos_inds]
    # mask_activation_pred = mask_activation_pred[pos_inds]

    # NOTE(review): the assert message is `print(...)`, whose return value is
    # None — on failure this prints and then raises `AssertionError: None`.
    # A plain string message would be clearer.
    assert mask_pred.shape[0] == mask_targets.shape[0], \
        print("The number(positive) should be equal between "
              "masks_pred(prediction) and mask_targets(target).")

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    # Cross-worker average centerness mass; normalizer for weighted terms.
    ctrness_norm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    reg_loss = self.iou_loss(
        reg_pred,
        reg_targets,
        ctrness_targets
    ) / ctrness_norm

    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred,
        ctrness_targets,
        reduction="sum"
    ) / num_pos_avg

    # Encode gt masks into sparse codes, and decode predicted codes back into
    # (real-valued, binarized) mask reconstructions.
    mask_targets_ = self.mask_encoding.encoder(mask_targets)
    mask_pred_, mask_pred_bin = self.mask_encoding.decoder(mask_pred, is_train=True)

    # compute the loss for the activation code as binary classification
    # activation_targets = (torch.abs(mask_targets_) > 1e-4) * 1.
    # activation_loss = F.binary_cross_entropy_with_logits(
    #     mask_activation_pred,
    #     activation_targets,
    #     reduction='none'
    # )
    # activation_loss = activation_loss.sum(1) * ctrness_targets
    # activation_loss = activation_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
    # if self.thresh_with_active:
    #     mask_pred = mask_pred * torch.sigmoid(mask_activation_pred)

    total_mask_loss = 0.
    if self.loss_on_mask:
        # n_components predictions --> m*m mask predictions without sigmoid
        # as sigmoid function is combined in loss.
        # mask_pred_, mask_pred_bin = self.mask_encoding.decoder(mask_pred, is_train=True)
        if 'mask_mse' in self.mask_loss_type:
            # Pixel-space MSE between decoded masks and gt masks,
            # centerness-weighted per instance.
            mask_loss = F.mse_loss(
                mask_pred_,
                mask_targets,
                reduction='none'
            )
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += mask_loss
        if 'mask_iou' in self.mask_loss_type:
            # Soft IoU on the binarized reconstruction.
            overlap_ = torch.sum(mask_pred_bin * 1. * mask_targets, 1)
            union_ = torch.sum((mask_pred_bin + mask_targets) >= 1., 1)
            iou_loss = (1. - overlap_ / (union_ + 1e-4)) * ctrness_targets * self.mask_size ** 2
            iou_loss = iou_loss.sum() / max(ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += iou_loss
        if 'mask_difference' in self.mask_loss_type:
            w_ = torch.abs(mask_pred_bin * 1.
                           - mask_targets * 1)  # 1's are inconsistent pixels in hd_maps
            md_loss = torch.sum(w_, 1) * ctrness_targets
            md_loss = md_loss.sum() / max(ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += md_loss

    if self.loss_on_code:
        # m*m mask labels --> n_components encoding labels
        # mask_targets_ = self.mask_encoding.encoder(mask_targets)
        if 'mse' in self.mask_loss_type:
            # Code-space MSE, optionally combined with a sparsity penalty.
            mask_loss = F.mse_loss(
                mask_pred,
                mask_targets_,
                reduction='none'
            )
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
            if self.mask_sparse_weight > 0.:
                if self.sparsity_loss_type == 'L1':
                    sparsity_loss = torch.sum(torch.abs(mask_pred), 1) * ctrness_targets
                    sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'L0':
                    w_ = (torch.abs(mask_targets_) >= 1e-2) * 1.  # the number of codes that are active
                    sparsity_loss = torch.sum(w_, 1) * ctrness_targets
                    sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_L1':
                    w_ = (torch.abs(mask_targets_) < 1e-2) * 1.  # inactive codes, put L1 regularization on them
                    sparsity_loss = torch.sum(torch.abs(mask_pred) * w_, 1) / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_L2':
                    w_ = (torch.abs(mask_targets_) < 1e-2) * 1.  # inactive codes, put L2 regularization on them
                    sparsity_loss = torch.sum(mask_pred ** 2.
                                              * w_, 1) / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_KL':
                    w_ = (torch.abs(mask_targets_) < 1e-2) * 1.  # inactive codes, put L2 regularization on them
                    kl_ = kl_divergence(
                        mask_pred,
                        self.kl_rho
                    )
                    sparsity_loss = torch.sum(kl_ * w_, 1) / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                else:
                    raise NotImplementedError
            total_mask_loss += mask_loss
        if 'smooth' in self.mask_loss_type:
            mask_loss = F.smooth_l1_loss(
                mask_pred,
                mask_targets_,
                reduction='none'
            )
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        if 'cosine' in self.mask_loss_type:
            mask_loss = loss_cos_sim(
                mask_pred,
                mask_targets_
            )
            mask_loss = mask_loss * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        # NOTE(review): 'kl_sigmoid'/'kl' are elif-chained to 'kl_softmax',
        # so at most one KL variant fires; the other terms above use
        # independent ifs.  Confirm this asymmetry is intended.
        if 'kl_softmax' in self.mask_loss_type:
            mask_loss = loss_kl_div_softmax(
                mask_pred,
                mask_targets_
            )
            mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        elif 'kl_sigmoid' in self.mask_loss_type:
            mask_loss = loss_kl_div_sigmoid(
                mask_pred,
                mask_targets_
            )
            mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        elif 'kl' in self.mask_loss_type:
            mask_loss = kl_divergence(
                mask_pred,
                self.kl_rho
            )
            mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss

    losses = {
        "loss_SMInst_cls": class_loss,
        "loss_SMInst_loc": reg_loss,
        "loss_SMInst_ctr": ctrness_loss,
        "loss_SMInst_mask": total_mask_loss,
    }
    return losses, {}
def fcos_losses(self, instances):
    """FCOS losses with the focal classification loss split into four parts.

    The unreduced focal loss is partitioned by (a) whether the location is a
    true/false class target and (b) whether its |probability - target| error
    lies above or below mean+std of its group — yielding the four
    ``upper/under x true/false`` terms.  Regression uses the standard
    centerness-weighted loc loss; centerness uses MSE on the sigmoid.

    Returns:
        (extras, losses) — note the order: extras first, like the sibling
        instance-based fcos_losses.
    """
    num_classes = instances.logits_pred.size(1)
    assert num_classes == self.num_classes
    labels = instances.labels.flatten()
    gt_object = instances.gt_inds  # kept only for the commented check below
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    neg_inds = torch.nonzero(labels == num_classes).squeeze(1)  # NOTE(review): unused
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Cross-worker average of the positive count, clamped to >= 1.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot classification targets
    class_target = torch.zeros_like(instances.logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    # Unreduced focal loss — reduction happens per split below.
    class_loss = sigmoid_focal_loss_jit(
        instances.logits_pred,
        class_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="none",
    )  # / num_pos_avg

    # Absolute probability error for positive (target 1) and negative
    # (target 0) entries; statistics are detached so the split thresholds
    # do not receive gradients.
    positive_diff = (
        1 - instances.logits_pred[class_target == 1].sigmoid()).abs()
    negative_diff = (
        0 - instances.logits_pred[class_target == 0].sigmoid()).abs()
    positive_mean = positive_diff.mean().detach()
    positive_std = positive_diff.std().detach()
    negative_mean = negative_diff.mean().detach()
    negative_std = negative_diff.std().detach()

    # Split each group at mean + std of its own error distribution.
    upper_true_loss = class_loss.flatten()[(class_target == 1).flatten()][
        (positive_diff > (positive_mean + positive_std))].sum() / num_pos_avg
    under_true_loss = class_loss.flatten()[(class_target == 1).flatten()][
        (positive_diff <= (positive_mean + positive_std))].sum() / num_pos_avg
    upper_false_loss = class_loss.flatten()[(class_target == 0).flatten()][
        (negative_diff > (negative_mean + negative_std))].sum() / num_pos_avg
    under_false_loss = class_loss.flatten()[(class_target == 0).flatten()][
        (negative_diff <= (negative_mean + negative_std))].sum() / num_pos_avg

    # Periodically log how many entries fall into each split.
    storage = get_event_storage()
    if storage.iter % 20 == 0:
        logger.info(
            "upper_true {}, under_true {} upper_false {} under_false {}".
            format((positive_diff > positive_mean + positive_std).sum(),
                   (positive_diff <= positive_mean + positive_std).sum(),
                   (negative_diff > negative_mean + negative_std).sum(),
                   (negative_diff <= negative_mean + negative_std).sum()))

    instances = instances[pos_inds]
    instances.pos_inds = pos_inds
    #assert (instances.gt_inds.unique() != gt_object.unique()).sum() == 0

    ctrness_targets = compute_ctrness_targets(instances.reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    # Cross-worker average centerness mass, clamped away from zero.
    loss_denorm = max(
        reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)
    instances.gt_ctrs = ctrness_targets

    if pos_inds.numel() > 0:
        reg_loss = self.loc_loss_func(instances.reg_pred,
                                      instances.reg_targets,
                                      ctrness_targets) / loss_denorm
        # Centerness is regressed with MSE on the sigmoid output here,
        # not with BCE-with-logits as in the sibling variants.
        ctrness_loss = torch.nn.MSELoss(reduction="sum")(
            instances.ctrness_pred.sigmoid(), ctrness_targets) / num_pos_avg
    else:
        # No positives: zero losses that stay connected to the graph.
        reg_loss = instances.reg_pred.sum() * 0
        ctrness_loss = instances.ctrness_pred.sum() * 0

    losses = {
        "loss_upper_true_cls": upper_true_loss,
        "loss_under_true_cls": under_true_loss,
        "loss_upper_false_cls": upper_false_loss,
        "loss_under_false_cls": under_false_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss,
        #"loss_negative_identity_mean": negative_identity_mean_loss,
        #"loss_negative_identity_std": negative_identity_std_loss,
        #"loss_positive_identity": positive_identity_loss,
    }
    extras = {"instances": instances, "loss_denorm": loss_denorm}
    return extras, losses
def MEInst_losses(self, labels, reg_targets, logits_pred, reg_pred,
                  ctrness_pred, mask_pred, mask_targets):
    """MEInst losses: FCOS detection terms plus a mask-encoding term.

    ``mask_pred`` holds predicted mask encodings; depending on
    ``self.loss_on_mask`` the mask loss is computed either in pixel space
    (after decoding) or in encoding space (after encoding the targets).

    Returns:
        (losses, extras): dict of scalar loss tensors and an empty dict.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()
    # Foreground locations; label == num_classes encodes "background".
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Cross-worker average positive count; clamp to >= 1 to avoid div-by-zero.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot classification targets for the focal loss
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # Restrict all remaining supervision to foreground locations.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    mask_pred = mask_pred[pos_inds]

    # BUGFIX: the message used to be `print(...)`, whose return value is
    # None, so a failing assert raised `AssertionError: None`.  Use a plain
    # string so the message is attached to the exception.
    assert mask_pred.shape[0] == mask_targets.shape[0], (
        "The number(positive) should be equal between "
        "masks_pred(prediction) and mask_targets(target).")

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    # Cross-worker average centerness mass; normalizer for weighted terms.
    ctrness_norm = max(
        reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    reg_loss = self.iou_loss(reg_pred, reg_targets,
                             ctrness_targets) / ctrness_norm

    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

    if self.loss_on_mask:
        # n_components predictions --> m*m mask predictions without sigmoid
        # as sigmoid function is combined in loss.
        mask_pred = self.mask_encoding.decoder(mask_pred, is_train=True)
        mask_loss = self.mask_loss_func(mask_pred, mask_targets)
        mask_loss = mask_loss.sum(1) * ctrness_targets
        mask_loss = mask_loss.sum() / max(ctrness_norm * self.mask_size**2,
                                          1.0)
    else:
        # m*m mask labels --> n_components encoding labels
        mask_targets = self.mask_encoding.encoder(mask_targets)
        if self.mask_loss_type == 'mse':
            mask_loss = self.mask_loss_func(mask_pred, mask_targets)
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.dim_mask,
                                              1.0)
        else:
            raise NotImplementedError

    losses = {
        "loss_MEInst_cls": class_loss,
        "loss_MEInst_loc": reg_loss,
        "loss_MEInst_ctr": ctrness_loss,
        "loss_MEInst_mask": mask_loss,
    }
    return losses, {}
def DTInst_losses(self, labels, reg_targets, logits_pred, reg_pred,
                  ctrness_pred, mask_pred, mask_targets):
    """DTInst losses: FCOS detection terms plus distance-transform mask terms.

    ``mask_pred`` are sparse codes; ``self.mask_encoding`` decodes them into
    a distance-transform map (``dtm_pred_``) and a binarized mask
    (``binary_pred_``), and encodes gt masks into codes, target DTMs, weight
    maps and Hausdorff-distance maps.  Active terms are selected by substring
    matching on ``self.mask_loss_type``.

    Returns:
        (losses, extras): dict of scalar losses and an empty extras dict.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()
    # Foreground locations; label == num_classes encodes "background".
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Cross-worker average positive count, clamped to >= 1.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot classification targets for the focal loss
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # Restrict all remaining supervision to foreground locations.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    mask_pred = mask_pred[pos_inds]

    # NOTE(review): the assert message is `print(...)` which returns None —
    # a failing assert raises `AssertionError: None` after printing.
    assert mask_pred.shape[0] == mask_targets.shape[0], \
        print("The number(positive) should be equal between "
              "masks_pred(prediction) and mask_targets(target).")

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    # Cross-worker average centerness mass; normalizer for weighted terms.
    ctrness_norm = max(
        reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    reg_loss = self.iou_loss(reg_pred, reg_targets,
                             ctrness_targets) / ctrness_norm

    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

    total_mask_loss = 0.
    # Decode predicted codes to a DTM + binary mask; encode gt masks into
    # codes, target DTMs, pixel weight maps and Hausdorff maps.
    dtm_pred_, binary_pred_ = self.mask_encoding.decoder(mask_pred,
                                                         is_train=True)
    code_targets, dtm_targets, weight_maps, hd_maps = self.mask_encoding.encoder(
        mask_targets)
    if self.loss_on_mask:
        if 'mask_mse' in self.mask_loss_type:
            # Plain MSE between predicted and target DTMs.
            mask_loss = F.mse_loss(dtm_pred_, dtm_targets, reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += mask_loss
        if 'weighted_mask_mse' in self.mask_loss_type:
            # Per-pixel weighted MSE, normalized by the weight mass.
            mask_loss = F.mse_loss(dtm_pred_, dtm_targets, reduction='none')
            mask_loss = torch.sum(mask_loss * weight_maps, 1) / torch.sum(
                weight_maps, 1) * ctrness_targets * self.mask_size**2
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += mask_loss
        if 'mask_difference' in self.mask_loss_type:
            w_ = torch.abs(
                binary_pred_ * 1. -
                mask_targets * 1)  # 1's are inconsistent pixels in hd_maps
            md_loss = torch.sum(w_, 1) * ctrness_targets
            md_loss = md_loss.sum() / max(ctrness_norm * self.mask_size**2,
                                          1.0)
            total_mask_loss += md_loss
        if 'hd_one_side_binary' in self.mask_loss_type:
            # the first attempt, not really accurate
            w_ = torch.abs(
                binary_pred_ * 1. -
                mask_targets * 1)  # 1's are inconsistent pixels in hd_maps
            hausdorff_loss = torch.sum(w_ * hd_maps, 1) / (torch.sum(
                w_, 1) + 1e-4) * ctrness_targets * self.mask_size**2
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'hd_two_side_binary' in self.mask_loss_type:
            # the first attempt, not really accurate
            w_ = torch.abs(
                binary_pred_ * 1. -
                mask_targets * 1)  # 1's are inconsistent pixels in hd_maps
            hausdorff_loss = torch.sum(
                w_ * (torch.clamp(dtm_pred_**2, -0.1, 1.1) +
                      torch.clamp(dtm_targets**2, -0.1, 1)), 1) / (torch.sum(
                          w_, 1) + 1e-4) * ctrness_targets * self.mask_size**2
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'hd_weighted_one_side_dtm' in self.mask_loss_type:
            dtm_diff = (
                dtm_pred_ -
                dtm_targets)**2  # 1's are inconsistent pixels in hd_maps
            hausdorff_loss = torch.sum(
                dtm_diff * weight_maps * hd_maps,
                1) / (torch.sum(weight_maps, 1) +
                      1e-4) * ctrness_targets * self.mask_size**2
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'hd_weighted_two_side_dtm' in self.mask_loss_type:
            dtm_diff = (
                dtm_pred_ -
                dtm_targets)**2  # 1's are inconsistent pixels in hd_maps
            hausdorff_loss = torch.sum(
                dtm_diff * weight_maps * (dtm_pred_**2 + dtm_targets**2),
                1) / (torch.sum(weight_maps, 1) +
                      1e-4) * ctrness_targets * self.mask_size**2
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'hd_one_side_dtm' in self.mask_loss_type:
            dtm_diff = (
                dtm_pred_ -
                dtm_targets)**2  # 1's are inconsistent pixels in hd_maps
            hausdorff_loss = torch.sum(dtm_diff * hd_maps,
                                       1) * ctrness_targets
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'hd_two_side_dtm' in self.mask_loss_type:
            dtm_diff = (
                dtm_pred_ -
                dtm_targets)**2  # 1's are inconsistent pixels in hd_maps
            hausdorff_loss = torch.sum(
                dtm_diff *
                (torch.clamp(dtm_pred_, -1.1, 1.1)**2 + dtm_targets**2),
                1) * ctrness_targets
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'contour_dice' in self.mask_loss_type:
            # Dice on a thin contour band of the DTMs.
            pred_contour = (dtm_pred_ + 0.9 < 0.55) * 1. * (
                0.5 <= dtm_pred_ + 0.9
            )  # contour pixels with 0.05 tolerance
            # NOTE(review): the second factor repeats `< 0.05` — the
            # commented lines below suggest `0. <= dtm_targets` was intended.
            target_contour = (dtm_targets < 0.05) * 1. * (dtm_targets < 0.05)
            # pred_contour = 0.5 <= dtm_pred_ + 0.9 < 0.55  # contour pixels with 0.05 tolerance
            # target_contour = 0. <= dtm_targets < 0.05
            overlap_ = torch.sum(pred_contour * 2. * target_contour, 1)
            union_ = torch.sum(pred_contour**2, 1) + torch.sum(
                target_contour**2, 1)
            dice_loss = (
                1. - overlap_ /
                (union_ + 1e-4)) * ctrness_targets * self.mask_size**2
            dice_loss = dice_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += dice_loss
        if 'mask_dice' in self.mask_loss_type:
            # Dice on the binarized mask reconstruction.
            overlap_ = torch.sum(binary_pred_ * 2. * mask_targets, 1)
            union_ = torch.sum(binary_pred_**2, 1) + torch.sum(
                mask_targets**2, 1)
            dice_loss = (
                1. - overlap_ /
                (union_ + 1e-4)) * ctrness_targets * self.mask_size**2
            dice_loss = dice_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += dice_loss

    if self.loss_on_code:
        # m*m mask labels --> n_components encoding labels
        if 'mse' in self.mask_loss_type:
            # Code-space MSE, optionally combined with a sparsity penalty.
            mask_loss = F.mse_loss(mask_pred, code_targets, reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            if self.mask_sparse_weight > 0.:
                if self.sparsity_loss_type == 'L1':
                    sparsity_loss = torch.sum(torch.abs(mask_pred),
                                              1) * ctrness_targets
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_L1':
                    w_ = (
                        torch.abs(code_targets) < 1e-4
                    ) * 1.  # inactive codes, put L1 regularization on them
                    sparsity_loss = torch.sum(torch.abs(mask_pred) * w_, 1) / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_L2':
                    w_ = (
                        torch.abs(code_targets) < 1e-4
                    ) * 1.  # inactive codes, put L2 regularization on them
                    sparsity_loss = torch.sum(mask_pred ** 2. * w_, 1) / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                else:
                    raise NotImplementedError
            total_mask_loss += mask_loss
        if 'smooth' in self.mask_loss_type:
            mask_loss = F.smooth_l1_loss(mask_pred,
                                         code_targets,
                                         reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        if 'cosine' in self.mask_loss_type:
            mask_loss = loss_cos_sim(mask_pred, code_targets)
            mask_loss = mask_loss * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        if 'kl_softmax' in self.mask_loss_type:
            mask_loss = loss_kl_div_softmax(mask_pred, code_targets)
            mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss

    losses = {
        "loss_DTInst_cls": class_loss,
        "loss_DTInst_loc": reg_loss,
        "loss_DTInst_ctr": ctrness_loss,
        "loss_DTInst_mask": total_mask_loss
    }
    return losses, {}
def fcos_losses(
    labels,
    reg_targets,
    bezier_targets,
    logits_pred,
    reg_pred,
    bezier_pred,
    ctrness_pred,
    focal_loss_alpha,
    focal_loss_gamma,
    iou_loss,
):
    """FCOS losses with an extra smooth-L1 Bezier control-point term.

    Returns:
        dict with classification, IoU regression, centerness and Bezier
        loss scalars.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()
    # Foreground locations; label == num_classes encodes "background".
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Cross-worker average positive count, clamped to >= 1.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot classification targets for the focal loss
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=focal_loss_alpha,
        gamma=focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # Restrict regression-style supervision to foreground locations.
    reg_pred = reg_pred[pos_inds]
    bezier_pred = bezier_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    bezier_targets = bezier_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]

    ious, gious = compute_ious(reg_pred, reg_targets)
    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    # Cross-worker average centerness mass, clamped away from zero.
    loss_denorm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    if pos_inds.numel() > 0:
        reg_loss = iou_loss(ious, gious, ctrness_targets) / loss_denorm
        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg
        # BUGFIX: the Bezier loss used to be recomputed unconditionally after
        # this if/else, which made the zero-guard in the else branch dead
        # code.  Compute it only when positives exist.
        bezier_loss = F.smooth_l1_loss(bezier_pred,
                                       bezier_targets,
                                       reduction="none")
        bezier_loss = ((bezier_loss.mean(dim=-1) *
                        ctrness_targets).sum() / loss_denorm)
    else:
        # No positives: zero losses that stay connected to the graph.
        reg_loss = reg_pred.sum() * 0
        bezier_loss = bezier_pred.sum() * 0
        ctrness_loss = ctrness_pred.sum() * 0

    losses = {
        "loss_fcos_cls": class_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss,
        "loss_fcos_bezier": bezier_loss,
    }
    return losses
def DTMRInst_losses(self, labels, reg_targets, logits_pred, reg_pred,
                    ctrness_pred, mask_pred, mask_pred_decoded_list,
                    mask_targets):
    """DTMRInst losses: FCOS detection terms plus multi-stage mask terms.

    ``mask_pred`` are predicted sparse codes and ``mask_pred_decoded_list``
    holds one decoded mask tensor per refinement stage; pixel-space terms
    iterate over the list so every stage is supervised.

    Returns:
        (losses, extras): dict of scalar losses and an empty extras dict.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()
    # Foreground locations; label == num_classes encodes "background".
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Cross-worker average positive count, clamped to >= 1.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot classification targets for the focal loss
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # Restrict all remaining supervision to foreground locations.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    mask_pred = mask_pred[pos_inds]
    # NOTE(review): this mutates the caller's list in place (each element is
    # replaced by its foreground slice) — a visible side effect on the
    # argument; confirm no caller reuses the full list afterwards.
    for i in range(len(mask_pred_decoded_list)):
        mask_pred_decoded_list[i] = mask_pred_decoded_list[i][pos_inds]

    # NOTE(review): the assert message is `print(...)` which returns None —
    # a failing assert raises `AssertionError: None` after printing.
    assert mask_pred.shape[0] == mask_targets.shape[0], \
        print("The number(positive) should be equal between "
              "masks_pred(prediction) and mask_targets(target).")

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    # Cross-worker average centerness mass; normalizer for weighted terms.
    ctrness_norm = max(
        reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    reg_loss = self.iou_loss(reg_pred, reg_targets,
                             ctrness_targets) / ctrness_norm

    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

    total_mask_loss = 0.
    # _, binary_pred_ = self.mask_encoding.decoder(mask_pred, is_train=True)
    # from sparse coefficients to DTMs/images
    # code_targets, dtm_targets, weight_maps, hd_maps = self.mask_encoding.encoder(mask_targets)
    code_targets, dtm_targets, weight_maps, _ = self.mask_encoding.encoder(
        mask_targets)
    if self.loss_on_mask:
        if 'mask_mse' in self.mask_loss_type:
            # Sum of per-stage pixel MSE losses.
            mask_loss = 0
            for mask_pred_ in mask_pred_decoded_list:
                _loss = F.mse_loss(mask_pred_, mask_targets,
                                   reduction='none')
                _loss = _loss.sum(1) * ctrness_targets
                _loss = _loss.sum() / max(ctrness_norm * self.mask_size**2,
                                          1.0)
                mask_loss += _loss
            total_mask_loss += mask_loss
        if 'weighted_mask_mse' in self.mask_loss_type:
            # Sum of per-stage weighted pixel MSE losses.
            mask_loss = 0
            for mask_pred_ in mask_pred_decoded_list:
                _loss = F.mse_loss(mask_pred_, mask_targets,
                                   reduction='none')
                _loss = torch.sum(_loss * weight_maps, 1) * ctrness_targets
                _loss = _loss.sum() / torch.sum(weight_maps) / max(
                    ctrness_norm * self.mask_size**2, 1.0)
                mask_loss += _loss
            total_mask_loss += mask_loss
        if 'mask_dice' in self.mask_loss_type:
            # This is to use all the output to calculate the mask loss
            # NOTE(review): the accumulated `dice_loss` from this loop is
            # overwritten by the last-stage-only value below, so only the
            # final stage contributes — confirm this is intended.
            dice_loss = 0
            for mask_pred_ in mask_pred_decoded_list:
                overlap_ = torch.sum(mask_pred_ * 2. * mask_targets, 1)
                union_ = torch.sum(mask_pred_**2, 1) + torch.sum(
                    mask_targets**2, 1)
                _loss = (
                    1. - overlap_ /
                    (union_ + 1e-5)) * ctrness_targets * self.mask_size**2
                _loss = _loss.sum() / max(ctrness_norm * self.mask_size**2,
                                          1.0)
                dice_loss += _loss
            # This is to just use the last output to calculate the mask loss
            mask_pred_ = mask_pred_decoded_list[-1]
            overlap_ = torch.sum(mask_pred_ * 2. * mask_targets, 1)
            union_ = torch.sum(mask_pred_**2, 1) + torch.sum(
                mask_targets**2, 1)
            _loss = (1. -
                     overlap_ /
                     (union_ + 1e-5)) * ctrness_targets * self.mask_size**2
            dice_loss = _loss.sum() / max(ctrness_norm * self.mask_size**2,
                                          1.0)
            total_mask_loss += dice_loss

    if self.loss_on_code:
        # m*m mask labels --> n_components encoding labels
        if 'mse' in self.mask_loss_type:
            # Code-space MSE, optionally combined with a sparsity penalty.
            mask_loss = F.mse_loss(mask_pred, code_targets, reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            if self.mask_sparse_weight > 0.:
                if self.sparsity_loss_type == 'L1':
                    sparsity_loss = torch.sum(torch.abs(mask_pred),
                                              1) * ctrness_targets
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_L1':
                    w_ = (
                        torch.abs(code_targets) < 1e-3
                    ) * 1.  # inactive codes, put L1 regularization on them
                    sparsity_loss = torch.sum(
                        torch.abs(mask_pred) * w_, 1) * ctrness_targets
                    sparsity_loss = sparsity_loss.sum() / torch.sum(
                        w_) / max(ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_L2':
                    w_ = (
                        torch.abs(code_targets) < 1e-3
                    ) * 1.  # inactive codes, put L2 regularization on them
                    sparsity_loss = torch.sum(mask_pred ** 2.
                                              * w_, 1) / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                else:
                    raise NotImplementedError
            total_mask_loss += mask_loss
        if 'smooth' in self.mask_loss_type:
            mask_loss = F.smooth_l1_loss(mask_pred,
                                         code_targets,
                                         reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        if 'cosine' in self.mask_loss_type:
            mask_loss = loss_cos_sim(mask_pred, code_targets)
            mask_loss = mask_loss * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        if 'kl_softmax' in self.mask_loss_type:
            mask_loss = loss_kl_div_softmax(mask_pred, code_targets)
            mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss

    losses = {
        "loss_DTMRInst_cls": class_loss,
        "loss_DTMRInst_loc": reg_loss,
        "loss_DTMRInst_ctr": ctrness_loss,
        "loss_DTMRInst_mask": total_mask_loss
    }
    return losses, {}
def compute_loss_softmax(gt_bitmasks, mask_logits, num_loss, num_instances,
                         direction, direction_mask_logits, gt_keypoint,
                         max_ranges, distance_norm):
    """Softmax keypoint-heatmap loss plus a direction/distance loss.

    The per-keypoint heatmap is treated as a spatial softmax (negative
    log-likelihood at gt pixels); the direction term penalizes normalized
    distances from each gt keypoint, gated by ``direction_mask_logits``.
    The hard-coded 17 matches ``K`` (COCO-style keypoint count) —
    assumed, confirm against the caller.

    Returns:
        (loss, direction_loss) — both zero-valued (but graph-connected)
        when no keypoint is visible in the batch.
    """
    assert not torch.isnan(mask_logits).any()
    assert not torch.isnan(direction).any()
    assert not torch.isnan(direction_mask_logits).any()
    # direction_mask_logits = direction_mask_logits.detach()
    N, K, H, W = gt_bitmasks.size()
    # gt_bitmasks = gt_bitmasks.float()
    num_gpus = get_world_size()
    assert not (num_loss == 0).any()
    loss_weight = 1 / num_loss  #TODO num_loss can be 0
    sum_loss_weight = loss_weight.sum()
    assert sum_loss_weight != 0
    # Broadcast the per-instance weight to all 17 keypoints.
    loss_weight = loss_weight[:, None].repeat(1, 17).flatten()
    gt_bitmasks = gt_bitmasks.reshape(N * K, H * W)
    mask_logits = mask_logits.reshape(N * K, H * W)
    # A keypoint row is "visible" if its gt bitmask has any nonzero pixel.
    gt_bitmasks_visible_mask = gt_bitmasks.sum(dim=1).bool()
    # assert gt_bitmasks_visible_mask.sum()!=0 #TODO AssertionError
    if gt_bitmasks_visible_mask.sum() != 0:
        loss_weight = loss_weight[gt_bitmasks_visible_mask]
        mask_logits = mask_logits[gt_bitmasks_visible_mask]
        gt_bitmasks = gt_bitmasks[gt_bitmasks_visible_mask]
        # Spatial log-softmax over each keypoint's H*W heatmap.
        mask_logits = F.log_softmax(mask_logits, dim=1)
        # Balance workers by their share of the global instance count.
        total_instances = reduce_sum(mask_logits.new_tensor([num_instances
                                                             ])).item()
        gpu_balence_factor = num_instances / total_instances
        # NLL at the gt pixels; assumes gt_bitmasks indexes one (or a fixed
        # number of) pixel(s) per row so shapes line up with loss_weight —
        # TODO confirm against how gt_bitmasks is built.
        loss = (-mask_logits[gt_bitmasks])
        loss = (loss * loss_weight).sum() / 17
        loss = (loss / sum_loss_weight) * gpu_balence_factor

        max_ranges = max_ranges[:, None].repeat(
            1, 17).flatten()[gt_bitmasks_visible_mask]
        gt_keypoint = gt_keypoint[:, :, [0, 1]]
        # Note: N/H/W/K are rebound here to direction_mask_logits' shape.
        N, H, W, K, _ = direction_mask_logits.size()
        # Offset of every location from its gt keypoint.
        direction = direction - gt_keypoint[:, None, None, :, :]
        direction = direction.permute(0, 3, 1, 2, 4).reshape(N * 17, H, W, 2)
        direction = direction[gt_bitmasks_visible_mask]
        # Euclidean distance, normalized into (-1, 1) via a scaled sigmoid.
        direction = (direction[:, :, :, 0]**2 +
                     direction[:, :, :, 1]**2).sqrt()[:, :, :, None]
        assert (max_ranges != 0).all()
        direction = direction / max_ranges[:, None, None, None]
        direction = direction * distance_norm
        direction = (direction.sigmoid() - 0.5) * 2
        direction_mask_logits = direction_mask_logits.permute(
            0, 3, 1, 2, 4).reshape(N * 17, H, W, 1)
        direction_mask_logits = direction_mask_logits[gt_bitmasks_visible_mask]
        # Gate distances by the direction mask and reduce per keypoint.
        direction = direction * direction_mask_logits
        direction = direction.flatten(start_dim=1).sum(dim=1)
        direction = direction * loss_weight
        assert distance_norm != 0
        direction_loss = (direction / sum_loss_weight *
                          gpu_balence_factor) / distance_norm
        direction_loss = direction_loss.sum()
        assert not torch.isnan(direction_loss).any()
        assert not torch.isnan(loss).any()
        return loss, direction_loss
    else:
        # No visible keypoints anywhere: return zeros that keep all inputs
        # in the autograd graph (avoids unused-parameter issues in DDP).
        print('gt_bitmasks_visible_mask.sum()==0')
        total_instances = reduce_sum(mask_logits.new_tensor([num_instances
                                                             ])).item()
        loss = mask_logits.sum() + direction.sum() + direction_mask_logits.sum(
        )
        loss = loss * 0.0
        return loss, loss