def loss_labels(self, outputs, targets, indices, num_boxes, log=False):
    """Classification loss (NLL).

    `targets` dicts must contain the key "labels" containing a tensor of
    dim [nb_target_boxes].

    Args:
        outputs: dict; must contain 'pred_logits'. The `.shape[:2]` /
            `flatten(0, 1)` usage below implies a (batch, queries, classes)
            layout -- assumed, confirm against the caller.
        targets: list of per-image dicts, each with a "labels" tensor.
        indices: per-image (src, tgt) index pairs from the matcher.
        num_boxes: normalizer for the focal-loss branch.
        log: if True, also report 'class_error' (non-focal branch only).

    Returns:
        dict with 'loss_ce' (and optionally 'class_error').
    """
    assert 'pred_logits' in outputs
    src_logits = outputs['pred_logits']

    idx = self._get_src_permutation_idx(indices)
    # Matched ground-truth class for every matched prediction, batch-concatenated.
    target_classes_o = torch.cat(
        [t["labels"][J] for t, (_, J) in zip(targets, indices)])
    # Default every slot to the "no object" class, then fill matched slots.
    target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                dtype=torch.int64, device=src_logits.device)
    target_classes[idx] = target_classes_o

    if self.use_focal:
        src_logits = src_logits.flatten(0, 1)
        # prepare one_hot target.
        target_classes = target_classes.flatten(0, 1)
        pos_inds = torch.nonzero(target_classes != self.num_classes, as_tuple=True)[0]
        labels = torch.zeros_like(src_logits)
        labels[pos_inds, target_classes[pos_inds]] = 1
        # comp focal loss.
        class_loss = sigmoid_focal_loss_jit(
            src_logits,
            labels,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_boxes
        losses = {'loss_ce': class_loss}
    else:
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}

        # NOTE(review): kept inside the non-focal branch -- in the focal branch
        # `src_logits` has been flattened, so the batch-index tuple `idx` would
        # no longer apply.
        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
    return losses
def losses(
        self, anchors, gt_classes, gt_boxes, pred_class_logits,
        pred_anchor_deltas, pred_class_logits_var=None, pred_bbox_cov=None):
    """Compute classification and regression losses, optionally with
    predicted (probabilistic) variances/covariances.

    Args:
        For `gt_classes` and `gt_anchors_deltas` parameters, see
            :meth:`RetinaNet.get_ground_truth`.
        Their shapes are (N, R) and (N, R, 4), respectively, where R is
        the total number of anchors across levels, i.e. sum(Hi x Wi x A)
        For `pred_class_logits`, `pred_anchor_deltas`,
        `pred_class_logits_var` and `pred_bbox_cov`, see
            :meth:`RetinaNetHead.forward`.

    Returns:
        dict[str: Tensor]:
            mapping from a named loss to a scalar tensor storing the loss.
            Used during training only. The dict keys are "loss_cls" and
            "loss_box_reg".
    """
    num_images = len(gt_classes)
    gt_labels = torch.stack(gt_classes)  # (N, R)
    anchors = type(anchors[0]).cat(anchors).tensor  # (R, 4)
    gt_anchor_deltas = [
        self.box2box_transform.get_deltas(anchors, k) for k in gt_boxes]
    gt_anchor_deltas = torch.stack(gt_anchor_deltas)  # (N, R, 4)

    # Label -1 marks anchors to ignore; label == num_classes is background.
    valid_mask = gt_labels >= 0
    pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
    num_pos_anchors = pos_mask.sum().item()
    get_event_storage().put_scalar("num_pos_anchors", num_pos_anchors / num_images)
    # Exponential moving average of the foreground count, used to normalize
    # both losses below.
    self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + \
        (1 - self.loss_normalizer_momentum) * max(num_pos_anchors, 1)

    # classification and regression loss

    # Shapes:
    # (N x R, K) for class_logits and class_logits_var.
    # (N x R, 4), (N x R x 10) for pred_anchor_deltas and pred_class_bbox_cov respectively.

    # Transform per-feature layer lists to a single tensor
    pred_class_logits = cat(pred_class_logits, dim=1)
    pred_anchor_deltas = cat(pred_anchor_deltas, dim=1)

    if pred_class_logits_var is not None:
        pred_class_logits_var = cat(
            pred_class_logits_var, dim=1)

    if pred_bbox_cov is not None:
        pred_bbox_cov = cat(
            pred_bbox_cov, dim=1)

    gt_classes_target = torch.nn.functional.one_hot(
        gt_labels[valid_mask],
        num_classes=self.num_classes + 1)[:, :-1].to(
        pred_class_logits[0].dtype)  # no loss for the last (background) class

    # Classification losses
    if self.compute_cls_var:
        # Compute classification variance according to:
        # "What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?", NIPS 2017
        if self.cls_var_loss == 'loss_attenuation':
            num_samples = self.cls_var_num_samples
            # Compute standard deviation (the head predicts log-variance).
            pred_class_logits_var = torch.sqrt(torch.exp(
                pred_class_logits_var[valid_mask]))

            pred_class_logits = pred_class_logits[valid_mask]

            # Produce normal samples using logits as the mean and the standard deviation computed above
            # Scales with GPU memory. 12 GB ---> 3 Samples per anchor for
            # COCO dataset.
            univariate_normal_dists = distributions.normal.Normal(
                pred_class_logits, scale=pred_class_logits_var)

            pred_class_stochastic_logits = univariate_normal_dists.rsample(
                (num_samples,))
            # Fold the sample dimension into the anchor dimension.
            pred_class_stochastic_logits = pred_class_stochastic_logits.view(
                (pred_class_stochastic_logits.shape[1] * num_samples,
                 pred_class_stochastic_logits.shape[2], -1))
            pred_class_stochastic_logits = pred_class_stochastic_logits.squeeze(2)

            # Produce copies of the target classes to match the number of
            # stochastic samples.
            gt_classes_target = torch.unsqueeze(gt_classes_target, 0)
            gt_classes_target = torch.repeat_interleave(
                gt_classes_target, num_samples, dim=0).view(
                (gt_classes_target.shape[1] * num_samples,
                 gt_classes_target.shape[2], -1))
            gt_classes_target = gt_classes_target.squeeze(2)

            # Stochastic focal loss, normalized by samples x loss normalizer.
            loss_cls = sigmoid_focal_loss_jit(
                pred_class_stochastic_logits,
                gt_classes_target,
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
                reduction="sum",
            ) / (num_samples * max(1, self.loss_normalizer))
        else:
            # NOTE(review): the message interpolates self.bbox_cov_loss even
            # though this branch is about the classification loss -- looks
            # like a copy/paste slip in the original; behavior kept as-is.
            raise ValueError(
                'Invalid classification loss name {}.'.format(
                    self.bbox_cov_loss))
    else:
        # Standard loss computation in case one wants to use this code
        # without any probabilistic inference.
        loss_cls = sigmoid_focal_loss_jit(
            pred_class_logits[valid_mask],
            gt_classes_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1, self.loss_normalizer)

    # Compute Regression Loss
    pred_anchor_deltas = pred_anchor_deltas[pos_mask]
    gt_anchors_deltas = gt_anchor_deltas[pos_mask]
    if self.compute_bbox_cov:
        # We have to clamp the output variance else probabilistic metrics
        # go to infinity.
        pred_bbox_cov = clamp_log_variance(pred_bbox_cov[pos_mask])

        if self.bbox_cov_loss == 'negative_log_likelihood':
            if self.bbox_cov_type == 'diagonal':
                # Compute regression variance according to:
                # "What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?", NIPS 2017
                # This implementation with smooth_l1_loss outperforms using
                # torch.distribution.multivariate_normal. Losses might have different numerical values
                # since we do not include constants in this implementation.
                loss_box_reg = 0.5 * torch.exp(-pred_bbox_cov) * smooth_l1_loss(
                    pred_anchor_deltas,
                    gt_anchors_deltas,
                    beta=self.smooth_l1_beta)

                loss_covariance_regularize = 0.5 * pred_bbox_cov
                loss_box_reg += loss_covariance_regularize

                # Sum over all elements
                loss_box_reg = torch.sum(
                    loss_box_reg) / max(1, self.loss_normalizer)
            else:
                # Multivariate negative log likelihood. Implemented with
                # pytorch multivariate_normal.log_prob function. Custom implementations fail to finish training
                # due to NAN loss.

                # This is the Cholesky decomposition of the covariance matrix. We reconstruct it from 10 estimated
                # parameters as a lower triangular matrix.
                forecaster_cholesky = covariance_output_to_cholesky(
                    pred_bbox_cov)

                # Compute multivariate normal distribution using torch
                # distribution functions.
                multivariate_normal_dists = distributions.multivariate_normal.MultivariateNormal(
                    pred_anchor_deltas, scale_tril=forecaster_cholesky)

                loss_box_reg = - \
                    multivariate_normal_dists.log_prob(gt_anchors_deltas)
                loss_box_reg = torch.sum(
                    loss_box_reg) / max(1, self.loss_normalizer)

        elif self.bbox_cov_loss == 'second_moment_matching':
            # Compute regression covariance using second moment matching.
            loss_box_reg = smooth_l1_loss(
                pred_anchor_deltas,
                gt_anchors_deltas,
                beta=self.smooth_l1_beta)

            # Compute errors
            errors = (pred_anchor_deltas - gt_anchors_deltas)

            if self.bbox_cov_type == 'diagonal':
                # Compute second moment matching term.
                second_moment_matching_term = smooth_l1_loss(
                    torch.exp(pred_bbox_cov), errors ** 2,
                    beta=self.smooth_l1_beta)
                loss_box_reg += second_moment_matching_term
                loss_box_reg = torch.sum(
                    loss_box_reg) / max(1, self.loss_normalizer)
            else:
                # Compute second moment matching term.
                errors = torch.unsqueeze(errors, 2)
                gt_error_covar = torch.matmul(
                    errors, torch.transpose(errors, 2, 1))

                # This is the cholesky decomposition of the covariance matrix. We reconstruct it from 10 estimated
                # parameters as a lower triangular matrix.
                forecaster_cholesky = covariance_output_to_cholesky(
                    pred_bbox_cov)

                predicted_covar = torch.matmul(
                    forecaster_cholesky, torch.transpose(
                        forecaster_cholesky, 2, 1))

                second_moment_matching_term = smooth_l1_loss(
                    predicted_covar, gt_error_covar,
                    beta=self.smooth_l1_beta, reduction='sum')
                loss_box_reg = (torch.sum(
                    loss_box_reg) + second_moment_matching_term) / max(1, self.loss_normalizer)

        elif self.bbox_cov_loss == 'energy_loss':
            # Compute regression variance according to energy score loss.
            forecaster_means = pred_anchor_deltas

            # Compute forecaster cholesky. Takes care of diagonal case
            # automatically.
            forecaster_cholesky = covariance_output_to_cholesky(
                pred_bbox_cov)

            # Define normal distribution samples. To compute energy score,
            # we need i+1 samples.

            # Define per-anchor Distributions
            multivariate_normal_dists = distributions.multivariate_normal.MultivariateNormal(
                forecaster_means, scale_tril=forecaster_cholesky)

            # Define Monte-Carlo Samples
            distributions_samples = multivariate_normal_dists.rsample(
                (self.bbox_cov_num_samples + 1,))

            distributions_samples_1 = distributions_samples[0:self.bbox_cov_num_samples, :, :]
            distributions_samples_2 = distributions_samples[1:self.bbox_cov_num_samples + 1, :, :]

            # Compute energy score
            gt_anchors_deltas_samples = torch.repeat_interleave(
                gt_anchors_deltas.unsqueeze(0), self.bbox_cov_num_samples, dim=0)

            energy_score_first_term = 2.0 * smooth_l1_loss(
                distributions_samples_1,
                gt_anchors_deltas_samples,
                beta=self.smooth_l1_beta,
                reduction="sum") / self.bbox_cov_num_samples  # First term

            energy_score_second_term = - smooth_l1_loss(
                distributions_samples_1,
                distributions_samples_2,
                beta=self.smooth_l1_beta,
                reduction="sum") / self.bbox_cov_num_samples  # Second term

            # Final Loss
            loss_box_reg = (
                energy_score_first_term + energy_score_second_term) / max(1, self.loss_normalizer)

        else:
            raise ValueError(
                'Invalid regression loss name {}.'.format(
                    self.bbox_cov_loss))

        # Perform loss annealing. Essential for reliably training variance estimates using NLL in RetinaNet.
        # For energy score and second moment matching, this is optional.
        standard_regression_loss = smooth_l1_loss(
            pred_anchor_deltas,
            gt_anchors_deltas,
            beta=self.smooth_l1_beta,
            reduction="sum",
        ) / max(1, self.loss_normalizer)

        probabilistic_loss_weight = get_probabilistic_loss_weight(
            self.current_step, self.annealing_step)

        loss_box_reg = (1.0 - probabilistic_loss_weight) * \
            standard_regression_loss + probabilistic_loss_weight * loss_box_reg
    else:
        # Standard regression loss in case no variance is needed to be
        # estimated.
        loss_box_reg = smooth_l1_loss(
            pred_anchor_deltas,
            gt_anchors_deltas,
            beta=self.smooth_l1_beta,
            reduction="sum",
        ) / max(1, self.loss_normalizer)

    return {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
def losses(self, init_gt_classes, init_reg_targets, refine_gt_classes, refine_reg_targets,
           pred_class_logits, pred_box_reg_init, pred_box_reg, pred_center_score, strides, pred_ratio):
    """Compute losses for the init and refine stages of the head.

    Classification uses the refine-stage assignment; the init-stage
    regression uses a centerness-weighted IoU loss; the refine-stage
    regression uses a stride-normalized smooth-L1 loss; centerness is
    supervised on the init-stage foreground.

    Args:
        init_gt_classes / init_reg_targets: init-stage per-location labels
            and (l, t, r, b) regression targets -- 4-column layout implied
            by the `.view(-1, 4)` below.
        refine_gt_classes / refine_reg_targets: same for the refine stage.
        pred_class_logits, pred_box_reg_init, pred_box_reg,
        pred_center_score, pred_ratio: per-level prediction lists, flattened
            by `permute_and_concat`.
        strides: per-location FPN strides; repeated per image below.

    Returns:
        dict with cls_loss, reg_loss_init, reg_loss, centerness_loss.
    """
    strides = strides.repeat(pred_class_logits[0].shape[0])  # [N*X]
    pred_class_logits, pred_box_reg_init, pred_box_reg, pred_center_score, pred_ratio = \
        permute_and_concat(pred_class_logits, pred_box_reg_init, pred_box_reg,
                           pred_center_score, pred_ratio, self.num_classes)
    # Shapes: (N x R) and (N x R, 4), (N x R) respectively.

    init_gt_classes = init_gt_classes.flatten()
    init_reg_targets = init_reg_targets.view(-1, 4)
    # Foreground = labeled (>= 0) and not background (num_classes).
    init_foreground_idxs = (init_gt_classes >= 0) & (init_gt_classes != self.num_classes)
    init_pos_inds = torch.nonzero(init_foreground_idxs).squeeze(1)

    num_gpus = get_num_gpus()
    # sync num_pos from all gpus
    init_total_num_pos = reduce_sum(init_pos_inds.new_tensor([init_pos_inds.numel()])).item()
    init_num_pos_avg_per_gpu = max(init_total_num_pos / float(num_gpus), 1.0)

    refine_gt_classes = refine_gt_classes.flatten()
    refine_reg_targets = refine_reg_targets.view(-1, 4)
    refine_foreground_idxs = (refine_gt_classes >= 0) & (refine_gt_classes != self.num_classes)
    refine_pos_inds = torch.nonzero(refine_foreground_idxs).squeeze(1)

    # sync num_pos from all gpus
    refine_total_num_pos = reduce_sum(refine_pos_inds.new_tensor([refine_pos_inds.numel()])).item()
    refine_num_pos_avg_per_gpu = max(refine_total_num_pos / float(num_gpus), 1.0)

    # One-hot classification targets from the refine-stage assignment.
    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[refine_foreground_idxs, refine_gt_classes[refine_foreground_idxs]] = 1

    # logits loss
    cls_loss = sigmoid_focal_loss_jit(
        pred_class_logits, gt_classes_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / refine_num_pos_avg_per_gpu

    # Aspect ratio of the init-stage target box, folded to <= 1 by taking
    # min(r, 1/r); feeds the centerness target.
    init_foreground_targets = init_reg_targets[init_foreground_idxs]
    gt_ratio_1 = (init_foreground_targets[:, 0] + init_foreground_targets[:, 2]) \
        / (init_foreground_targets[:, 1] + init_foreground_targets[:, 3])
    gt_ratio_2 = 1 / gt_ratio_1
    gt_ratios = torch.stack((gt_ratio_1, gt_ratio_2), dim=1)
    gt_ratio = gt_ratios.min(dim=1)[0]

    gt_center_score = compute_centerness_targets(init_reg_targets[init_foreground_idxs], gt_ratio)

    # average sum_centerness_targets from all gpus,
    # which is used to normalize centerness-weighed reg loss
    sum_centerness_targets_avg_per_gpu = \
        reduce_sum(gt_center_score.sum()).item() / float(num_gpus)

    reg_loss_init = iou_loss(
        pred_box_reg_init[init_foreground_idxs],
        init_reg_targets[init_foreground_idxs],
        gt_center_score,
        loss_type=self.iou_loss_type
    ) / sum_centerness_targets_avg_per_gpu

    # Normalize coordinates by 4x the stride before smooth-L1.
    coords_norm_refine = strides[refine_foreground_idxs].unsqueeze(-1) * 4
    reg_loss = smooth_l1_loss(
        pred_box_reg[refine_foreground_idxs] / coords_norm_refine,
        refine_reg_targets[refine_foreground_idxs] / coords_norm_refine,
        0.11, reduction="sum") / max(1, refine_num_pos_avg_per_gpu)

    # Alternative refine-stage regression kept from the original for reference:
    # reg_loss = iou_loss(
    #     pred_box_reg[refine_foreground_idxs], refine_reg_targets[refine_foreground_idxs], 1,
    #     loss_type=self.iou_loss_type
    # ) / sum_centerness_targets_avg_per_gpu

    # NOTE(review): |score|**ratio is passed as *logits* to BCE-with-logits;
    # presumably intentional in the original design, but worth confirming.
    centerness_loss = F.binary_cross_entropy_with_logits(
        torch.pow(torch.abs(pred_center_score[init_foreground_idxs]),
                  pred_ratio[init_foreground_idxs]),
        gt_center_score,
        reduction='sum'
    ) / init_num_pos_avg_per_gpu

    return dict(cls_loss=cls_loss, reg_loss_init=reg_loss_init,
                reg_loss=reg_loss, centerness_loss=centerness_loss)
def losses(self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes):
    """
    Args:
        anchors (list[Boxes]): a list of #feature level Boxes
        gt_labels, gt_boxes: see output of :meth:`RetinaNet.label_anchors`.
            Their shapes are (N, R) and (N, R, 4), respectively, where R is
            the total number of anchors across levels, i.e. sum(Hi x Wi x Ai)
        pred_logits, pred_anchor_deltas: both are list[Tensor]. Each element in the
            list corresponds to one level and has shape (N, Hi * Wi * Ai, K or 4).
            Where K is the number of classes used in `pred_logits`.

    Returns:
        dict[str, Tensor]:
            mapping from a named loss to a scalar tensor
            storing the loss. Used during training only. The dict keys are:
            "loss_cls" and "loss_box_reg"
    """
    num_images = len(gt_labels)
    gt_labels = torch.stack(gt_labels)  # (N, R)
    anchors = type(anchors[0]).cat(anchors).tensor  # (R, 4)
    gt_anchor_deltas = [self.box2box_transform.get_deltas(anchors, k) for k in gt_boxes]
    gt_anchor_deltas = torch.stack(gt_anchor_deltas)  # (N, R, 4)

    # Label -1 marks ignored anchors; label == num_classes is background.
    valid_mask = gt_labels >= 0
    pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
    num_pos_anchors = pos_mask.sum().item()
    get_event_storage().put_scalar("num_pos_anchors", num_pos_anchors / num_images)
    # EMA of the foreground-anchor count, used to normalize both losses.
    self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
        1 - self.loss_normalizer_momentum
    ) * max(num_pos_anchors, 1)

    # classification and regression loss
    gt_labels_target = F.one_hot(gt_labels[valid_mask], num_classes=self.num_classes + 1)[
        :, :-1
    ]  # no loss for the last (background) class
    loss_cls = sigmoid_focal_loss_jit(
        cat(pred_logits, dim=1)[valid_mask],
        gt_labels_target.to(pred_logits[0].dtype),
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    )

    if self.box_reg_loss_type == "smooth_l1":
        loss_box_reg = smooth_l1_loss(
            cat(pred_anchor_deltas, dim=1)[pos_mask],
            gt_anchor_deltas[pos_mask],
            beta=self.smooth_l1_beta,
            reduction="sum",
        )
    elif self.box_reg_loss_type == "giou":
        # Decode per-image deltas back to boxes so GIoU can be computed in
        # box space rather than delta space.
        pred_boxes = [
            self.box2box_transform.apply_deltas(k, anchors)
            for k in cat(pred_anchor_deltas, dim=1)
        ]
        loss_box_reg = giou_loss(
            torch.stack(pred_boxes)[pos_mask], torch.stack(gt_boxes)[pos_mask], reduction="sum"
        )
    else:
        raise ValueError(f"Invalid bbox reg loss type '{self.box_reg_loss_type}'")

    return {
        "loss_cls": loss_cls / self.loss_normalizer,
        "loss_box_reg": loss_box_reg / self.loss_normalizer,
    }
def forward(self, features, gt_instances=None):
    """Fuse multi-level features into a mask feature map; during training,
    optionally add an auxiliary semantic-segmentation focal loss.

    Args:
        features: dict of per-level feature maps, keyed by self.in_features.
        gt_instances: per-image ground truth; required only when training
            with self.sem_loss_on (reads gt_bitmasks_full and gt_classes).

    Returns:
        (mask_feats, losses): the fused feature tensor and a dict that is
        empty except for 'loss_sem' when the auxiliary loss is active.
    """
    # Progressively upsample-and-add each coarser level onto the finest one.
    for i, f in enumerate(self.in_features):
        if i == 0:
            x = self.refine[i](features[f])
        else:
            x_p = self.refine[i](features[f])

            target_h, target_w = x.size()[2:]
            h, w = x_p.size()[2:]
            assert target_h % h == 0
            assert target_w % w == 0
            factor_h, factor_w = target_h // h, target_w // w
            assert factor_h == factor_w
            x_p = F.interpolate(x_p, scale_factor=factor_h, mode='bilinear', align_corners=True)
            # x_p = aligned_bilinear(x_p, factor_h)
            x = x + x_p

    mask_feats = self.tower(x)

    # NOTE(review): this slices to zero channels when num_outputs == 0;
    # presumably intentional (no mask-feature output requested) -- confirm.
    if self.num_outputs == 0:
        mask_feats = mask_feats[:, :self.num_outputs]

    losses = {}
    # auxiliary thing semantic loss
    if self.training and self.sem_loss_on:
        logits_pred = self.logits(
            self.seg_head(features[self.in_features[0]]))

        # compute semantic targets: for each pixel, the class of the
        # smallest-area instance covering it (0 = background).
        semantic_targets = []
        for per_im_gt in gt_instances:
            h, w = per_im_gt.gt_bitmasks_full.size()[-2:]
            areas = per_im_gt.gt_bitmasks_full.sum(dim=-1).sum(dim=-1)
            areas = areas[:, None, None].repeat(1, h, w)
            areas[per_im_gt.gt_bitmasks_full == 0] = INF
            areas = areas.permute(1, 2, 0).reshape(h * w, -1)
            min_areas, inds = areas.min(dim=1)
            per_im_sematic_targets = per_im_gt.gt_classes[inds] + 1
            per_im_sematic_targets[min_areas == INF] = 0
            per_im_sematic_targets = per_im_sematic_targets.reshape(h, w)
            semantic_targets.append(per_im_sematic_targets)

        semantic_targets = torch.stack(semantic_targets, dim=0)

        # resize target to reduce memory
        semantic_targets = semantic_targets[
            :, None, self.out_stride // 2::self.out_stride,
            self.out_stride // 2::self.out_stride
        ]

        # prepare one-hot targets (classes are stored 1-based, 0 = background)
        num_classes = logits_pred.size(1)
        class_range = torch.arange(
            num_classes, dtype=logits_pred.dtype,
            device=logits_pred.device)[:, None, None]
        class_range = class_range + 1
        one_hot = (semantic_targets == class_range).float()
        num_pos = (one_hot > 0).sum().float().clamp(min=1.0)

        loss_sem = sigmoid_focal_loss_jit(
            logits_pred, one_hot,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos
        losses['loss_sem'] = loss_sem

    return mask_feats, losses
def fcos_losses(self, instances):
    """FCOS losses with the focal classification loss split into four terms
    (upper/under x true/false) by how far each prediction's sigmoid score
    deviates from its target relative to mean + std of the deviations.

    Args:
        instances: flattened per-location predictions/targets; reads
            logits_pred, labels, gt_inds, reg_pred, reg_targets,
            ctrness_pred.

    Returns:
        (extras, losses): extras holds the foreground-only `instances`
        and `loss_denorm`; losses holds the six loss terms.
    """
    num_classes = instances.logits_pred.size(1)
    assert num_classes == self.num_classes

    labels = instances.labels.flatten()
    gt_object = instances.gt_inds

    # label == num_classes is background.
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    neg_inds = torch.nonzero(labels == num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot
    class_target = torch.zeros_like(instances.logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    # Per-element focal loss; normalization happens per split term below.
    class_loss = sigmoid_focal_loss_jit(
        instances.logits_pred,
        class_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="none",
    )  # / num_pos_avg

    # Absolute deviation of the sigmoid score from its 1/0 target,
    # detached so the split thresholds carry no gradient.
    positive_diff = (
        1 - instances.logits_pred[class_target == 1].sigmoid()).abs()
    negative_diff = (
        0 - instances.logits_pred[class_target == 0].sigmoid()).abs()
    positive_mean = positive_diff.mean().detach()
    positive_std = positive_diff.std().detach()
    negative_mean = negative_diff.mean().detach()
    negative_std = negative_diff.std().detach()

    # Split the focal loss by whether each deviation exceeds mean + std.
    upper_true_loss = class_loss.flatten()[(class_target == 1).flatten()][
        (positive_diff > (positive_mean + positive_std))].sum() / num_pos_avg

    under_true_loss = class_loss.flatten()[(class_target == 1).flatten()][
        (positive_diff <= (positive_mean + positive_std))].sum() / num_pos_avg

    upper_false_loss = class_loss.flatten()[(class_target == 0).flatten()][
        (negative_diff > (negative_mean + negative_std))].sum() / num_pos_avg

    under_false_loss = class_loss.flatten()[(class_target == 0).flatten()][
        (negative_diff <= (negative_mean + negative_std))].sum() / num_pos_avg

    storage = get_event_storage()
    if storage.iter % 20 == 0:
        logger.info(
            "upper_true {}, under_true {} upper_false {} under_false {}".
            format((positive_diff > positive_mean + positive_std).sum(),
                   (positive_diff <= positive_mean + positive_std).sum(),
                   (negative_diff > negative_mean + negative_std).sum(),
                   (negative_diff <= negative_mean + negative_std).sum()))

    instances = instances[pos_inds]
    instances.pos_inds = pos_inds
    #assert (instances.gt_inds.unique() != gt_object.unique()).sum() == 0

    ctrness_targets = compute_ctrness_targets(instances.reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    loss_denorm = max(
        reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)
    instances.gt_ctrs = ctrness_targets

    if pos_inds.numel() > 0:
        reg_loss = self.loc_loss_func(instances.reg_pred,
                                      instances.reg_targets,
                                      ctrness_targets) / loss_denorm

        # MSE on the sigmoid of the centerness logits (not BCE-with-logits).
        ctrness_loss = torch.nn.MSELoss(reduction="sum")(
            instances.ctrness_pred.sigmoid(), ctrness_targets) / num_pos_avg
    else:
        # No positives: emit zero losses that stay connected to the graph.
        reg_loss = instances.reg_pred.sum() * 0
        ctrness_loss = instances.ctrness_pred.sum() * 0

    losses = {
        "loss_upper_true_cls": upper_true_loss,
        "loss_under_true_cls": under_true_loss,
        "loss_upper_false_cls": upper_false_loss,
        "loss_under_false_cls": under_false_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss,
        #"loss_negative_identity_mean": negative_identity_mean_loss,
        #"loss_negative_identity_std": negative_identity_std_loss,
        #"loss_positive_identity": positive_identity_loss,
    }
    extras = {"instances": instances, "loss_denorm": loss_denorm}
    return extras, losses
def MEInst_losses(self, labels, reg_targets, logits_pred, reg_pred,
                  ctrness_pred, mask_pred, mask_targets):
    """Compute MEInst losses: focal classification, centerness-weighted IoU
    regression, centerness BCE, and the mask-encoding loss.

    Args:
        labels: per-location class labels; num_classes means background.
        reg_targets: per-location box regression targets.
        logits_pred, reg_pred, ctrness_pred: flattened head outputs.
        mask_pred: predicted mask encodings (one row per location).
        mask_targets: ground-truth masks for the positive locations only.

    Returns:
        (losses, extras): dict of the four named losses, and an empty dict.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()

    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Average the positive count across workers so each GPU uses the same
    # normalizer for the classification/centerness losses.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # Keep only foreground locations for the remaining losses.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    mask_pred = mask_pred[pos_inds]
    # Fix: the original wrote `assert cond, print(...)`, which makes the
    # AssertionError message None (print() returns None) and sends the text
    # to stdout instead. Pass the message string directly.
    assert mask_pred.shape[0] == mask_targets.shape[0], (
        "The number(positive) should be equal between "
        "masks_pred(prediction) and mask_targets(target).")

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    # Sum of centerness targets averaged over workers; normalizes the
    # centerness-weighted regression and mask losses.
    ctrness_norm = max(
        reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    reg_loss = self.iou_loss(reg_pred, reg_targets, ctrness_targets) / ctrness_norm

    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

    if self.loss_on_mask:
        # n_components predictions --> m*m mask predictions without sigmoid
        # as sigmoid function is combined in loss.
        mask_pred = self.mask_encoding.decoder(mask_pred, is_train=True)
        mask_loss = self.mask_loss_func(mask_pred, mask_targets)
        mask_loss = mask_loss.sum(1) * ctrness_targets
        mask_loss = mask_loss.sum() / max(ctrness_norm * self.mask_size**2, 1.0)
    else:
        # m*m mask labels --> n_components encoding labels
        mask_targets = self.mask_encoding.encoder(mask_targets)
        if self.mask_loss_type == 'mse':
            mask_loss = self.mask_loss_func(mask_pred, mask_targets)
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.dim_mask, 1.0)
        else:
            raise NotImplementedError

    losses = {
        "loss_MEInst_cls": class_loss,
        "loss_MEInst_loc": reg_loss,
        "loss_MEInst_ctr": ctrness_loss,
        "loss_MEInst_mask": mask_loss,
    }
    return losses, {}
def fcos_losses(self, instances):
    """Standard FCOS losses with a configurable classification normalizer.

    Args:
        instances: flattened per-location predictions/targets; reads
            logits_pred, labels, reg_pred, reg_targets, ctrness_pred.

    Returns:
        (extras, losses): extras holds the foreground-only `instances` and
        `loss_denorm`; losses holds loss_fcos_cls / loss_fcos_loc /
        loss_fcos_ctr.
    """
    num_classes = instances.logits_pred.size(1)
    assert num_classes == self.num_classes

    labels = instances.labels.flatten()

    # label == num_classes is background.
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = torch.ones_like(pos_inds).sum()
    # Average the positive count over workers for a consistent normalizer.
    num_pos_avg = max(reduce_mean(num_pos_local).item(), 1.0)

    # prepare one_hot
    class_target = torch.zeros_like(instances.logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(instances.logits_pred,
                                        class_target,
                                        alpha=self.focal_loss_alpha,
                                        gamma=self.focal_loss_gamma,
                                        reduction="sum")
    if self.loss_normalizer_cls == "moving_fg":
        # EMA of the foreground count as the normalizer.
        self.moving_num_fg = self.moving_num_fg_momentum * self.moving_num_fg + (
            1 - self.moving_num_fg_momentum) * num_pos_avg
        class_loss = class_loss / self.moving_num_fg
    elif self.loss_normalizer_cls == "fg":
        class_loss = class_loss / num_pos_avg
    else:
        # Normalize by the (worker-averaged) total number of locations.
        num_samples_local = torch.ones_like(labels).sum()
        num_samples_avg = max(reduce_mean(num_samples_local).item(), 1.0)
        class_loss = class_loss / num_samples_avg
    class_loss = class_loss * self.loss_weight_cls

    instances = instances[pos_inds]
    instances.pos_inds = pos_inds

    ctrness_targets = compute_ctrness_targets(instances.reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    loss_denorm = max(reduce_mean(ctrness_targets_sum).item(), 1e-6)
    instances.gt_ctrs = ctrness_targets

    if pos_inds.numel() > 0:
        reg_loss = self.loc_loss_func(instances.reg_pred,
                                      instances.reg_targets,
                                      ctrness_targets) / loss_denorm

        ctrness_loss = F.binary_cross_entropy_with_logits(
            instances.ctrness_pred, ctrness_targets,
            reduction="sum") / num_pos_avg
    else:
        # No positives: emit zero losses that stay connected to the graph.
        reg_loss = instances.reg_pred.sum() * 0
        ctrness_loss = instances.ctrness_pred.sum() * 0

    losses = {
        "loss_fcos_cls": class_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss
    }
    extras = {"instances": instances, "loss_denorm": loss_denorm}
    return extras, losses
def fcose_losses(
        labels,
        reg_targets,
        ext_targets,
        logits_pred,
        reg_pred,
        ext_pred,
        ctrness_pred,
        focal_loss_alpha,
        focal_loss_gamma,
        iou_loss,
        ext_loss
):
    """Compute FCOS-with-extreme-points training losses.

    Args:
        labels / reg_targets / ext_targets: per-location class labels, box
            regression targets and extreme-point targets.
        logits_pred / reg_pred / ext_pred / ctrness_pred: flattened head
            outputs aligned with the targets.
        focal_loss_alpha, focal_loss_gamma: focal-loss hyperparameters.
        iou_loss, ext_loss: callables for the box and extreme-point terms.

    Returns:
        (losses, extras): dict with loss_fcos_cls / loss_fcos_loc /
        loss_fcos_ctr / loss_ext_pts, and an empty extras dict.
    """
    n_cls = logits_pred.size(1)
    flat_labels = labels.flatten()
    # label == n_cls marks background; everything else is foreground.
    fg_inds = (flat_labels != n_cls).nonzero().squeeze(1)

    # Average the foreground count across workers so every GPU divides by
    # the same normalizer.
    world = get_world_size()
    pos_total = reduce_sum(fg_inds.new_tensor([fg_inds.numel()])).item()
    pos_avg = max(pos_total / world, 1.0)

    # background-0; C binary cls
    one_hot = torch.zeros_like(logits_pred)
    one_hot[fg_inds, flat_labels[fg_inds]] = 1
    cls_loss = sigmoid_focal_loss_jit(
        logits_pred,
        one_hot,
        alpha=focal_loss_alpha,
        gamma=focal_loss_gamma,
        reduction="sum",
    ) / pos_avg

    # Restrict the remaining terms to foreground locations.
    fg_reg_pred = reg_pred[fg_inds]
    fg_reg_targets = reg_targets[fg_inds]
    fg_ext_pred = ext_pred[fg_inds]
    fg_ext_targets = ext_targets[fg_inds]
    fg_ctrness_pred = ctrness_pred[fg_inds]

    fg_ctrness_targets = compute_ctrness_targets(fg_reg_targets)
    # Worker-averaged sum of centerness targets normalizes the
    # centerness-weighted box / extreme-point losses.
    ctrness_norm = max(
        reduce_sum(fg_ctrness_targets.sum()).item() / world, 1e-6)

    pts_loss = ext_loss(
        fg_ext_pred,
        fg_ext_targets,
        fg_ctrness_targets
    ) / ctrness_norm

    box_loss = iou_loss(
        fg_reg_pred,
        fg_reg_targets,
        fg_ctrness_targets
    ) / ctrness_norm

    ctr_loss = F.binary_cross_entropy_with_logits(
        fg_ctrness_pred,
        fg_ctrness_targets,
        reduction="sum"
    ) / pos_avg

    return {
        "loss_fcos_cls": cls_loss,
        "loss_fcos_loc": box_loss,
        "loss_fcos_ctr": ctr_loss,
        "loss_ext_pts": pts_loss
    }, {}
def forward(self, indices, gt_instances, anchors, pred_class_logits, pred_anchor_deltas):
    """Compute classification (focal) and regression (GIoU) losses for a
    matcher-based single-level dense detector.

    Args:
        indices: per-image (src, tgt) matched index pairs from the matcher.
        gt_instances: per-image ground truth (reads gt_boxes, gt_classes).
        anchors: per-image lists of per-level Boxes; all images are assumed
            to share the same anchor count (see the `anchors[0]` offset below).
        pred_class_logits, pred_anchor_deltas: per-level prediction lists.

    Returns:
        dict with "loss_cls" and "loss_box_reg", both normalized by the
        worker-averaged foreground count.
    """
    pred_class_logits = cat(
        pred_class_logits, dim=1).view(-1, self.num_classes)
    pred_anchor_deltas = cat(pred_anchor_deltas, dim=1).view(-1, 4)

    anchors = [Boxes.cat(anchors_i) for anchors_i in anchors]
    N = len(anchors)
    # list[Tensor(R, 4)], one for each image
    all_anchors = Boxes.cat(anchors).tensor
    # Boxes(Tensor(N*R, 4))
    predicted_boxes = self.box2box_transform.apply_deltas(
        pred_anchor_deltas, all_anchors)
    predicted_boxes = predicted_boxes.reshape(N, -1, 4)

    # Per image: max IoU of every *predicted* box with any gt (for picking
    # negatives to ignore), and the anchor-gt IoU of each matched pair
    # (for picking positives to ignore).
    ious = []
    pos_ious = []
    for i in range(N):
        src_idx, tgt_idx = indices[i]
        iou = box_iou(predicted_boxes[i, ...],
                      gt_instances[i].gt_boxes.tensor)
        if iou.numel() == 0:
            max_iou = iou.new_full((iou.size(0), ), 0)
        else:
            max_iou = iou.max(dim=1)[0]
        a_iou = box_iou(anchors[i].tensor, gt_instances[i].gt_boxes.tensor)
        if a_iou.numel() == 0:
            pos_iou = a_iou.new_full((0, ), 0)
        else:
            pos_iou = a_iou[src_idx, tgt_idx]
        ious.append(max_iou)
        pos_ious.append(pos_iou)
    ious = torch.cat(ious)
    ignore_idx = ious > self.neg_ignore_thresh
    pos_ious = torch.cat(pos_ious)
    pos_ignore_idx = pos_ious < self.pos_ignore_thresh

    # Offset per-image src indices into the flattened (N*R) index space.
    src_idx = torch.cat([
        src + idx * anchors[0].tensor.shape[0]
        for idx, (src, _) in enumerate(indices)
    ])
    # Default to background; -1 marks ignored locations.
    gt_classes = torch.full(pred_class_logits.shape[:1],
                            self.num_classes,
                            dtype=torch.int64,
                            device=pred_class_logits.device)
    gt_classes[ignore_idx] = -1
    target_classes_o = torch.cat(
        [t.gt_classes[J] for t, (_, J) in zip(gt_instances, indices)])
    target_classes_o[pos_ignore_idx] = -1
    gt_classes[src_idx] = target_classes_o

    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    # Average the foreground count across workers.
    if comm.get_world_size() > 1:
        dist.all_reduce(num_foreground)
    num_foreground = num_foreground * 1.0 / comm.get_world_size()

    # cls loss
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    )
    # reg loss
    target_boxes = torch.cat(
        [t.gt_boxes.tensor[i] for t, (_, i) in zip(gt_instances, indices)],
        dim=0)
    target_boxes = target_boxes[~pos_ignore_idx]
    matched_predicted_boxes = predicted_boxes.reshape(
        -1, 4)[src_idx[~pos_ignore_idx]]
    loss_box_reg = giou_loss(matched_predicted_boxes,
                             target_boxes,
                             reduction="sum")

    return {
        "loss_cls": loss_cls / max(1, num_foreground),
        "loss_box_reg": loss_box_reg / max(1, num_foreground),
    }
def losses(self, indices, gt_instances, anchors, pred_class_logits, pred_anchor_deltas):
    """Compute classification (focal) and regression (GIoU) losses for a
    matcher-based dense detector.

    Args:
        indices: per-image (src, tgt) matched index pairs from the matcher.
        gt_instances: per-image ground truth (reads gt_boxes, gt_classes).
        anchors: per-image lists of per-level Boxes; all images are assumed
            to share the same anchor count (see the `anchors[0]` offset below).
        pred_class_logits, pred_anchor_deltas: per-level prediction lists.

    Returns:
        dict with "loss_cls" and "loss_box_reg", both normalized by the
        worker-averaged foreground count.
    """
    pred_class_logits = cat(
        pred_class_logits, dim=1).view(-1, self.num_classes)
    pred_anchor_deltas = cat(pred_anchor_deltas, dim=1).view(-1, 4)

    anchors = [Boxes.cat(anchors_i) for anchors_i in anchors]
    N = len(anchors)
    # list[Tensor(R, 4)], one for each image
    all_anchors = Boxes.cat(anchors).tensor
    # Boxes(Tensor(N*R, 4))
    predicted_boxes = self.box2box_transform.apply_deltas(
        pred_anchor_deltas, all_anchors)
    predicted_boxes = predicted_boxes.reshape(N, -1, 4)

    # We obtain positive anchors by choosing gt boxes' k nearest anchors
    # and leave the rest to be negative anchors. However, there may
    # exist negative anchors that have similar distances with the chosen
    # positives. These negatives may cause ambiguity for model training
    # if we just set them as negatives. Given that we want the model's
    # predict boxes on negative anchors to have low IoU with gt boxes,
    # we set a threshold on the IoU between predicted boxes and gt boxes
    # instead of the IoU between anchor boxes and gt boxes.
    ious = []
    pos_ious = []
    for i in range(N):
        src_idx, tgt_idx = indices[i]
        iou = box_iou(predicted_boxes[i, ...],
                      gt_instances[i].gt_boxes.tensor)
        if iou.numel() == 0:
            max_iou = iou.new_full((iou.size(0), ), 0)
        else:
            max_iou = iou.max(dim=1)[0]
        a_iou = box_iou(anchors[i].tensor, gt_instances[i].gt_boxes.tensor)
        if a_iou.numel() == 0:
            pos_iou = a_iou.new_full((0, ), 0)
        else:
            pos_iou = a_iou[src_idx, tgt_idx]
        ious.append(max_iou)
        pos_ious.append(pos_iou)
    ious = torch.cat(ious)
    ignore_idx = ious > self.neg_ignore_thresh
    pos_ious = torch.cat(pos_ious)
    pos_ignore_idx = pos_ious < self.pos_ignore_thresh

    # Offset per-image src indices into the flattened (N*R) index space.
    src_idx = torch.cat([
        src + idx * anchors[0].tensor.shape[0]
        for idx, (src, _) in enumerate(indices)
    ])
    # Default to background; -1 marks ignored locations.
    gt_classes = torch.full(pred_class_logits.shape[:1],
                            self.num_classes,
                            dtype=torch.int64,
                            device=pred_class_logits.device)
    gt_classes[ignore_idx] = -1
    target_classes_o = torch.cat(
        [t.gt_classes[J] for t, (_, J) in zip(gt_instances, indices)])
    target_classes_o[pos_ignore_idx] = -1
    gt_classes[src_idx] = target_classes_o

    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    # Average the foreground count across workers.
    if comm.get_world_size() > 1:
        dist.all_reduce(num_foreground)
    num_foreground = num_foreground * 1.0 / comm.get_world_size()

    # cls loss
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    )
    # reg loss
    target_boxes = torch.cat(
        [t.gt_boxes.tensor[i] for t, (_, i) in zip(gt_instances, indices)],
        dim=0)
    target_boxes = target_boxes[~pos_ignore_idx]
    matched_predicted_boxes = predicted_boxes.reshape(
        -1, 4)[src_idx[~pos_ignore_idx]]
    loss_box_reg = giou_loss(matched_predicted_boxes,
                             target_boxes,
                             reduction="sum")

    return {
        "loss_cls": loss_cls / max(1, num_foreground),
        "loss_box_reg": loss_box_reg / max(1, num_foreground),
    }
def losses(self, center_pts, cls_outs, pts_outs_init, pts_outs_refine,
           targets):
    """Compute RepPoints-style losses: focal classification loss plus
    smooth-L1 localization losses for the init and refine point stages.

    Args:
        center_pts (list[list[Tensor]]): a list of N=#image elements. Each is
            a list of #feature level tensors. The tensors contain shifts of
            this image on the specific feature level.
        cls_outs: List[Tensor], each item with shape [N, num_classes, H, W].
        pts_outs_init: List[Tensor], each item with shape
            [N, num_points*2, H, W].
        pts_outs_refine: List[Tensor], each item with shape
            [N, num_points*2, H, W].
        targets (list[Instances]): a list of N `Instances`. The i-th
            `Instances` contains the ground-truth per-instance annotations
            for the i-th input image.

    Returns:
        dict[str, Tensor]: "loss_cls", "loss_pts_init", "loss_pts_refine".
    """
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_outs]
    # One stride entry per feature level; center_pts[0] is the per-level
    # list for the first image.
    assert len(featmap_sizes) == len(center_pts[0])

    pts_dim = 2 * self.num_points
    # Reshape each level from (N, C, H, W) to (N, H*W, C) and concatenate
    # all levels along the location axis.
    cls_outs = [
        cls_out.permute(0, 2, 3, 1).reshape(cls_out.size(0), -1,
                                            self.num_classes)
        for cls_out in cls_outs
    ]
    pts_outs_init = [
        pts_out_init.permute(0, 2, 3, 1).reshape(pts_out_init.size(0), -1,
                                                 pts_dim)
        for pts_out_init in pts_outs_init
    ]
    pts_outs_refine = [
        pts_out_refine.permute(0, 2, 3, 1).reshape(pts_out_refine.size(0),
                                                   -1, pts_dim)
        for pts_out_refine in pts_outs_refine
    ]
    cls_outs = torch.cat(cls_outs, dim=1)
    pts_outs_init = torch.cat(pts_outs_init, dim=1)
    pts_outs_refine = torch.cat(pts_outs_refine, dim=1)

    # Per-location FPN stride, aligned with the concatenated location axis.
    pts_strides = []
    for i, s in enumerate(center_pts[0]):
        pts_strides.append(
            cls_outs.new_full((s.size(0), ), self.fpn_strides[i]))
    pts_strides = torch.cat(pts_strides, dim=0)

    # Flatten the per-level center points of every image into one tensor.
    center_pts = [
        torch.cat(c_pts, dim=0).to(cls_outs.device) for c_pts in center_pts
    ]

    pred_cls = []
    pred_init = []
    pred_refine = []
    target_cls = []
    target_init = []
    target_refine = []
    num_pos_init = 0
    num_pos_refine = 0

    # Accumulate per-image predictions/targets for both stages.
    for img, (per_center_pts, cls_prob, pts_init, pts_refine,
              per_targets) in enumerate(
                  zip(center_pts, cls_outs, pts_outs_init, pts_outs_refine,
                      targets)):
        assert per_center_pts.shape[:-1] == cls_prob.shape[:-1]
        gt_bboxes = per_targets.gt_boxes.to(cls_prob.device)
        gt_labels = per_targets.gt_classes.to(cls_prob.device)

        # Stage 1 (init): targets assigned by point-based matching.
        pts_init_bbox_targets, pts_init_labels_targets = \
            self.point_targets(per_center_pts, pts_strides,
                               gt_bboxes.tensor, gt_labels)

        # per_center_pts repeated once per learned point, shape [R, num_points*2].
        per_center_pts_repeat = per_center_pts.repeat(1, self.num_points)
        # Normalizer for box coordinates (scale * stride per location).
        normalize_term = self.point_base_scale * pts_strides
        normalize_term = normalize_term.reshape(-1, 1)
        # bbox_center = torch.cat([per_center_pts, per_center_pts], dim=1)
        per_pts_strides = pts_strides.reshape(-1, 1)
        # Decode predicted offsets into absolute point coordinates, then a box.
        pts_init_coordinate = pts_init * per_pts_strides + \
            per_center_pts_repeat
        init_bbox_pred = self.pts_to_bbox(pts_init_coordinate)

        foreground_idxs = (pts_init_labels_targets >= 0) & \
            (pts_init_labels_targets != self.num_classes)
        pred_init.append(init_bbox_pred[foreground_idxs] /
                         normalize_term[foreground_idxs])
        target_init.append(pts_init_bbox_targets[foreground_idxs] /
                           normalize_term[foreground_idxs])
        num_pos_init += foreground_idxs.sum()

        # A another way to convert predicted offset to bbox:
        # bbox_pred_init = self.pts_to_bbox(pts_init.detach()) * \
        #     per_pts_strides
        # init_bbox_pred = bbox_center + bbox_pred_init

        # Stage 2 (refine): targets assigned by matching against the stage-1
        # decoded boxes.
        pts_refine_bbox_targets, pts_refine_labels_targets = \
            self.bbox_targets(init_bbox_pred, gt_bboxes, gt_labels)
        pts_refine_coordinate = pts_refine * per_pts_strides + \
            per_center_pts_repeat
        refine_bbox_pred = self.pts_to_bbox(pts_refine_coordinate)
        # bbox_pred_refine = self.pts_to_bbox(pts_refine) * per_pts_strides
        # refine_bbox_pred = bbox_center + bbox_pred_refine

        foreground_idxs = (pts_refine_labels_targets >= 0) & \
            (pts_refine_labels_targets != self.num_classes)
        pred_refine.append(refine_bbox_pred[foreground_idxs] /
                           normalize_term[foreground_idxs])
        target_refine.append(pts_refine_bbox_targets[foreground_idxs] /
                             normalize_term[foreground_idxs])
        num_pos_refine += foreground_idxs.sum()

        # One-hot classification target built from the *refine* assignment.
        gt_classes_target = torch.zeros_like(cls_prob)
        gt_classes_target[foreground_idxs,
                          pts_refine_labels_targets[foreground_idxs]] = 1
        pred_cls.append(cls_prob)
        target_cls.append(gt_classes_target)

    pred_cls = torch.cat(pred_cls, dim=0)
    pred_init = torch.cat(pred_init, dim=0)
    pred_refine = torch.cat(pred_refine, dim=0)
    target_cls = torch.cat(target_cls, dim=0)
    target_init = torch.cat(target_init, dim=0)
    target_refine = torch.cat(target_refine, dim=0)

    # Classification is normalized by the refine-stage positive count.
    loss_cls = sigmoid_focal_loss_jit(
        pred_cls,
        target_cls,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum") / max(1, num_pos_refine.item()) * \
        self.loss_cls_weight
    loss_pts_init = smooth_l1_loss(
        pred_init, target_init, beta=0.11, reduction='sum') / max(
            1, num_pos_init.item()) * self.loss_loc_init_weight
    loss_pts_refine = smooth_l1_loss(
        pred_refine, target_refine, beta=0.11, reduction='sum') / max(
            1, num_pos_refine.item()) * self.loss_loc_refine_weight
    return {
        "loss_cls": loss_cls,
        "loss_pts_init": loss_pts_init,
        "loss_pts_refine": loss_pts_refine
    }
def fcos_losses(
    self,
    labels,
    reg_targets,
    logits_pred,
    reg_pred,
    ctrness_pred,
    controllers_pred,
    focal_loss_alpha,
    focal_loss_gamma,
    iou_loss,
    matched_idxes,
    im_idxes,
    locations,
):
    """FCOS losses plus a CondInst-style mask loss.

    Classification uses sigmoid focal loss over all locations; regression
    (IoU loss weighted by centerness) and centerness BCE use positive
    locations only. Each positive location's predicted controller vector
    parameterizes a tiny 3-layer dynamic conv head applied to the shared
    mask features `self.masks`, and the resulting per-instance masks are
    compared to ground truth with dice loss.

    Returns:
        (dict[str, Tensor], dict): the loss dict and an empty extras dict.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()

    # Positive locations are those whose label is not the background index.
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Average positive count across GPUs so the normalizer is consistent.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot classification targets
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = (sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=focal_loss_alpha,
        gamma=focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg)

    # Restrict everything below to positive locations.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    controllers_pred = controllers_pred[pos_inds]
    matched_idxes = matched_idxes[pos_inds]
    im_idxes = im_idxes[pos_inds]
    locations = locations[pos_inds]

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    # Centerness-weighted normalizer for the IoU regression loss,
    # averaged across GPUs.
    ctrness_norm = max(
        reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)
    reg_loss = iou_loss(reg_pred, reg_targets,
                        ctrness_targets) / ctrness_norm
    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

    # for CondInst
    batch_ins = pos_inds.shape[0]  # NOTE(review): overwritten below; redundant
    N, C, h, w = self.masks.shape
    # Instance center in mask-feature coordinates, clamped to the map.
    center_x = torch.clamp(locations[:, 0], min=0, max=w - 1).long()
    center_y = torch.clamp(locations[:, 1], min=0, max=h - 1).long()

    # Normalized coordinate grid in [-1, 1], shape (1, 2, h, w).
    x_range = torch.linspace(-1, 1, w, device=self.masks.device)
    y_range = torch.linspace(-1, 1, h, device=self.masks.device)
    y, x = torch.meshgrid(y_range, x_range)
    x = x.unsqueeze(0).unsqueeze(0)
    y = y.unsqueeze(0).unsqueeze(0)
    grid = torch.cat([x, y], 1)
    # Per-instance relative coordinates: grid minus the instance center.
    offset_x = x_range[center_x].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    offset_y = y_range[center_y].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
    offset_xy = torch.cat([offset_x, offset_y], 1)
    coords_feat = grid - offset_xy

    masks_feat = self.masks
    # Ground-truth masks are prepared at self.strides[0] times the feature
    # resolution.
    r_h = int(h * self.strides[0])
    r_w = int(w * self.strides[0])
    targets_masks = [
        target_im.gt_masks.tensor for target_im in self.gt_instances
    ]
    masks_t = self.prepare_masks(h, w, r_h, r_w, targets_masks)

    mask_loss = masks_feat[0].new_tensor(0.0)
    batch_ins = im_idxes.shape[0]
    # for each image
    for i in range(N):
        inds = (im_idxes == i).nonzero().flatten()
        ins_num = inds.shape[0]
        if ins_num > 0:
            controllers = controllers_pred[inds]
            coord_feat = coords_feat[inds]
            # Tile the image's mask features once per instance and fold the
            # instance axis into channels so grouped conv runs all
            # instances in one call.
            mask_feat = masks_feat[None, i]
            mask_feat = torch.cat([mask_feat] * ins_num, dim=0)
            comb_feat = torch.cat((mask_feat, coord_feat),
                                  dim=1).view(1, -1, h, w)
            # Controller layout (169 ch): conv1 w(8x10)=80 + b=8,
            # conv2 w(8x8)=64 + b=8, conv3 w(1x8)=8 + b=1.
            weight1, bias1, weight2, bias2, weight3, bias3 = torch.split(
                controllers, [80, 8, 64, 8, 8, 1], dim=1)
            bias1, bias2, bias3 = bias1.flatten(), bias2.flatten(
            ), bias3.flatten()
            weight1 = weight1.reshape(-1, 8, 10).reshape(
                -1, 10).unsqueeze(-1).unsqueeze(-1)
            weight2 = weight2.reshape(-1, 8, 8).reshape(
                -1, 8).unsqueeze(-1).unsqueeze(-1)
            weight3 = weight3.unsqueeze(-1).unsqueeze(-1)
            # Dynamic 3-layer 1x1 conv head, one group per instance.
            conv1 = F.conv2d(comb_feat, weight1, bias1,
                             groups=ins_num).relu()
            conv2 = F.conv2d(conv1, weight2, bias2, groups=ins_num).relu()
            masks_per_image = F.conv2d(conv2, weight3, bias3,
                                       groups=ins_num)
            masks_per_image = aligned_bilinear(
                masks_per_image, self.strides[0])[0].sigmoid()
            for j in range(ins_num):
                ind = inds[j]
                mask_gt = masks_t[i][matched_idxes[ind]].float()
                mask_pred = masks_per_image[j]
                mask_loss += self.dice_loss(mask_pred, mask_gt)
    if batch_ins > 0:
        mask_loss = mask_loss / batch_ins

    losses = {
        "loss_fcos_cls": class_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss,
        "loss_mask": mask_loss,
    }
    return losses, {}
def FCOSLosses(cls_scores, bbox_preds, centernesses, labels, bbox_targets,
               reg_loss, cfg):
    """
    Compute the three FCOS training losses.

    Arguments:
        cls_scores, bbox_preds, centernesses: Same as the output of
            :meth:`FCOSHead.forward`
        labels, bbox_targets: Same as the output of :func:`FCOSTargets`
        reg_loss: callable computing the (weighted) box regression loss.
        cfg: config node providing MODEL.FCOS.{NUM_CLASSES, LOSS_ALPHA,
            LOSS_GAMMA}.

    Returns:
        losses (dict[str: Tensor]): A dict mapping from loss name to loss
            value: "loss_fcos_cls", "loss_fcos_loc", "loss_fcos_ctr".
    """
    # fmt: off
    num_classes      = cfg.MODEL.FCOS.NUM_CLASSES
    focal_loss_alpha = cfg.MODEL.FCOS.LOSS_ALPHA
    focal_loss_gamma = cfg.MODEL.FCOS.LOSS_GAMMA
    # fmt: on

    # Flatten every feature level and image into one long location axis so
    # predictions line up with the flattened targets (L, N, H, W ordering).
    logits_all = cat(
        [s.permute(0, 2, 3, 1).reshape(-1, num_classes) for s in cls_scores],
        dim=0)
    deltas_all = cat(
        [d.permute(0, 2, 3, 1).reshape(-1, 4) for d in bbox_preds], dim=0)
    ctr_all = cat([c.reshape(-1) for c in centernesses], dim=0)

    labels_all = cat(labels)
    targets_all = cat(bbox_targets)

    # Locations whose label is not the background index are positives.
    pos_inds = torch.nonzero(labels_all != num_classes).squeeze(1)
    num_pos = max(len(pos_inds), 1.0)

    # One-hot classification target (background rows stay all-zero).
    onehot = torch.zeros_like(logits_all)
    onehot[pos_inds, labels_all[pos_inds]] = 1

    # Classification: focal loss over every location, normalized by #pos.
    loss_cls = sigmoid_focal_loss_jit(
        logits_all,
        onehot,
        alpha=focal_loss_alpha,
        gamma=focal_loss_gamma,
        reduction="sum",
    ) / num_pos

    # Regression and centerness are computed on positives only.
    pos_deltas = deltas_all[pos_inds]
    pos_ctr = ctr_all[pos_inds]
    pos_targets = targets_all[pos_inds]

    ctr_targets = compute_centerness_targets(pos_targets)
    ctr_norm = max(ctr_targets.sum(), 1e-6)

    # Box regression: IoU-style loss weighted by centerness targets.
    loss_bbox = reg_loss(pos_deltas, pos_targets,
                         weight=ctr_targets) / ctr_norm

    # Centerness: binary cross-entropy against the centerness targets.
    loss_centerness = F.binary_cross_entropy_with_logits(
        pos_ctr, ctr_targets, reduction="sum") / num_pos

    return dict(loss_fcos_cls=loss_cls,
                loss_fcos_loc=loss_bbox,
                loss_fcos_ctr=loss_centerness)
def fcos_losses(
        self,
        labels,
        reg_targets,
        logits_pred,
        reg_pred,
        ctrness_pred,
        coeffs_pred,
        protos,
        focal_loss_alpha,
        focal_loss_gamma,
        iou_loss,
        matched_idxes,
        im_idxes
):
    """FCOS losses plus a YOLACT-style mask loss.

    Classification uses sigmoid focal loss over all locations; regression
    (IoU loss weighted by centerness) and centerness BCE use positives
    only. For each positive instance, the predicted coefficient vector
    linearly combines the image's prototype masks `protos`; the combined
    mask is compared to ground truth with dice loss.

    Returns:
        (dict[str, Tensor], dict): the loss dict and an empty extras dict.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()

    # Positive locations: label differs from the background index.
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Average positive count across GPUs for a consistent normalizer.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot classification targets
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=focal_loss_alpha,
        gamma=focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # Restrict regression/centerness to positive locations.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    ctrness_norm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus,
                       1e-6)

    reg_loss = iou_loss(
        reg_pred,
        reg_targets,
        ctrness_targets
    ) / ctrness_norm

    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred,
        ctrness_targets,
        reduction="sum"
    ) / num_pos_avg

    # for yolact
    coeffs_pred = coeffs_pred[pos_inds]
    matched_idxes = matched_idxes[pos_inds]
    im_idxes = im_idxes[pos_inds]

    N, _, m_h, m_w = protos.shape
    # Ground-truth masks are prepared at self.strides[0] times the
    # prototype resolution.
    r_h = int(m_h * self.strides[0])
    r_w = int(m_w * self.strides[0])
    targets_masks = [
        target_im.gt_masks.tensor for target_im in self.gt_instances
    ]
    masks_t = self.prepare_masks(m_h, m_w, r_h, r_w, targets_masks)

    num_ins = coeffs_pred.shape[0]
    # BUGFIX: was `coeffs_pred[0].new_tensor(0.0)`, which raises IndexError
    # when there are no positive samples (coeffs_pred is empty).
    # `new_tensor` on the tensor itself gives the same dtype/device zero.
    mask_loss = coeffs_pred.new_tensor(0.0)
    for i in range(num_ins):
        im_id = im_idxes[i]
        # Linear combination of this image's prototypes by the instance's
        # coefficients, then sigmoid -> per-instance soft mask.
        mask_pred = torch.sigmoid(
            (protos[im_id] *
             coeffs_pred[i].view(self.num_protos, 1, 1)).sum(dim=0))
        mask_gt = masks_t[im_id][matched_idxes[i]].float()
        mask_loss += self.dice_loss(mask_pred, mask_gt)
    if num_ins > 0:
        mask_loss = mask_loss / num_ins

    losses = {
        "loss_fcos_cls": class_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss,
        "loss_mask": mask_loss
    }
    return losses, {}
def fcos_losses(self, labels, reg_targets, logits_pred, reg_pred,
                ctrness_pred, controllers_pred, focal_loss_alpha,
                focal_loss_gamma, iou_loss, matched_idxes, im_idxes):
    """FCOS losses for a CondInst variant whose mask branch is disabled.

    Computes focal classification loss, centerness-weighted IoU regression
    loss, and centerness BCE. The CondInst dynamic-conv mask loss is
    commented out below, so "loss_mask" is a constant 0 (a Python int, not
    a tensor) and `controllers_pred`/`matched_idxes` are gathered but
    otherwise unused.

    Returns:
        (dict[str, Tensor | int], dict): the loss dict and an empty extras
        dict.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()

    # Positive locations: label differs from the background index.
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Average the positive count across GPUs for a consistent normalizer.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot classification targets
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=focal_loss_alpha,
        gamma=focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # Restrict everything below to positive locations.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    controllers_pred = controllers_pred[pos_inds]
    matched_idxes = matched_idxes[pos_inds]
    im_idxes = im_idxes[pos_inds]

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    ctrness_norm = max(
        reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    reg_loss = iou_loss(reg_pred, reg_targets,
                        ctrness_targets) / ctrness_norm

    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

    # for CondInst: build mask features with appended normalized coordinate
    # maps. NOTE(review): hard-codes .cuda() rather than using the input's
    # device — fails on CPU; confirm this path is GPU-only.
    N, C, h, w = self.masks.shape
    grid_x = torch.arange(w).view(1, -1).float().repeat(
        h, 1).cuda() / (w - 1) * 2 - 1
    grid_y = torch.arange(h).view(-1, 1).float().repeat(
        1, w).cuda() / (h - 1) * 2 - 1
    x_map = grid_x.view(1, 1, h, w).repeat(N, 1, 1, 1)
    y_map = grid_y.view(1, 1, h, w).repeat(N, 1, 1, 1)
    masks_feat = torch.cat((self.masks, x_map, y_map), dim=1)
    r_h = int(h * self.strides[0])
    r_w = int(w * self.strides[0])

    # seg head — disabled; loss_mask stays 0.
    mask_loss = 0
    '''
    targets_masks = [target_im.gt_masks.tensor
                     for target_im in self.gt_instances]
    masks_t = self.prepare_masks(h, w, r_h, r_w, targets_masks)
    mask_loss = masks_feat[0].new_tensor(0.0)
    batch_ins = im_idxes.shape[0]
    # for each image
    for i in range(N):
        inds = (im_idxes==i).nonzero().flatten()
        ins_num = inds.shape[0]
        if ins_num > 0:
            controllers = controllers_pred[inds]
            mask_feat = masks_feat[None, i]
            weights1 = controllers[:, :80].reshape(-1,8,10).reshape(-1,10).unsqueeze(-1).unsqueeze(-1)
            bias1 = controllers[:, 80:88].flatten()
            weights2 = controllers[:, 88:152].reshape(-1,8,8).reshape(-1,8).unsqueeze(-1).unsqueeze(-1)
            bias2 = controllers[:, 152:160].flatten()
            weights3 = controllers[:, 160:168].unsqueeze(-1).unsqueeze(-1)
            bias3 = controllers[:,168:169].flatten()
            conv1 = F.conv2d(mask_feat,weights1,bias1).relu()
            conv2 = F.conv2d(conv1, weights2, bias2, groups = ins_num).relu()
            #masks_per_image = F.conv2d(conv2, weights3, bias3, groups = ins_num)[0].sigmoid()
            masks_per_image = F.conv2d(conv2, weights3, bias3, groups = ins_num)
            masks_per_image = aligned_bilinear(masks_per_image, self.strides[0])[0].sigmoid()
            for j in range(ins_num):
                ind = inds[j]
                mask_gt = masks_t[i][matched_idxes[ind]].float()
                mask_pred = masks_per_image[j]
                mask_loss += self.dice_loss(mask_pred, mask_gt)
    if batch_ins > 0:
        mask_loss = mask_loss / batch_ins
    '''
    losses = {
        "loss_fcos_cls": class_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss,
        "loss_mask": mask_loss
    }
    return losses, {}
def losses(self, anchors, pred_logits, pred_boxes_init, pred_anchor_deltas,
           gt_instances, point_centers, strides):
    """
    Two-stage (init + refine) detection losses.

    Args:
        anchors (list[Boxes]): a list of #feature level Boxes.
        pred_logits, pred_anchor_deltas: both are list[Tensor]. Each element
            in the list corresponds to one level and has shape
            (N, Hi * Wi * Ai, K or 4), where K is the number of classes used
            in `pred_logits`.
        pred_boxes_init: init-stage box predictions, indexed per location.
        gt_instances (list[Instances]): ground truth per image.
        point_centers, strides: per-location centers and strides used for
            the init-stage ground truth.

    Returns:
        dict[str, Tensor]: "loss_cls", "loss_loc_init", "loss_loc_refine".
    """
    # Refine-stage (anchor-based) and init-stage (point-based) targets.
    gt_labels, gt_boxes = self.label_anchors(anchors, gt_instances)
    gt_labels_init, gt_boxes_init = self.get_ground_truth(
        point_centers, strides, gt_instances)

    # Transpose the Hi*Wi*A dimension to the middle:
    pred_logits = [
        permute_to_N_HWA_K(x, self.num_classes) for x in pred_logits
    ]
    pred_anchor_deltas = [
        permute_to_N_HWA_K(x, 4) for x in pred_anchor_deltas
    ]

    num_images = len(gt_labels)
    gt_labels = torch.stack(gt_labels)  # (N, R)
    anchors = type(anchors[0]).cat(anchors).tensor  # (R, 4)
    gt_anchor_deltas = [
        self.box2box_transform.get_deltas(anchors, k) for k in gt_boxes
    ]
    gt_anchor_deltas = torch.stack(gt_anchor_deltas)  # (N, R, 4)

    # valid: matched or background (label >= 0); positive: matched to a GT.
    valid_mask = gt_labels >= 0
    pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
    num_pos_anchors = pos_mask.sum().item()
    get_event_storage().put_scalar("num_pos_anchors",
                                   num_pos_anchors / num_images)
    # EMA of the positive count, used as the loss normalizer.
    self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
        1 - self.loss_normalizer_momentum) * max(num_pos_anchors, 1)

    # classification and regression loss
    gt_labels_target = F.one_hot(gt_labels[valid_mask],
                                 num_classes=self.num_classes +
                                 1)[:, :-1]  # no loss for the last (background) class
    loss_cls = sigmoid_focal_loss_jit(
        cat(pred_logits, dim=1)[valid_mask],
        gt_labels_target.to(pred_logits[0].dtype),
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) * self.loss_cls_weight

    # NOTE(review): init foreground uses `> 0`, unlike the
    # `(>= 0) & (!= num_classes)` pattern above — if gt_labels_init holds
    # class indices this would drop class 0; confirm its encoding.
    init_foreground_idxs = gt_labels_init > 0
    # Broadcast per-location strides across the batch dimension.
    strides = strides[None].repeat(pred_logits[0].shape[0], 1)
    # Normalize init-stage coordinates by 4x the location stride.
    coords_norm_init = strides[init_foreground_idxs].unsqueeze(-1) * 4
    loss_loc_init = smooth_l1_loss(
        pred_boxes_init[init_foreground_idxs] / coords_norm_init,
        gt_boxes_init[init_foreground_idxs] / coords_norm_init,
        beta=0.11,
        reduction="sum",
    ) / max(init_foreground_idxs.sum(), 1)

    if self.box_reg_loss_type == "smooth_l1":
        loss_loc_refine = smooth_l1_loss(
            cat(pred_anchor_deltas, dim=1)[pos_mask],
            gt_anchor_deltas[pos_mask],
            beta=0.11,
            reduction="sum",
        )
    elif self.box_reg_loss_type == "giou":
        # Decode deltas per image, then compute GIoU on positives.
        pred_boxes = [
            self.box2box_transform.apply_deltas(k, anchors)
            for k in cat(pred_anchor_deltas, dim=1)
        ]
        loss_loc_refine = giou_loss(torch.stack(pred_boxes)[pos_mask],
                                    torch.stack(gt_boxes)[pos_mask],
                                    reduction="sum")
    else:
        raise ValueError(
            f"Invalid bbox reg loss type '{self.box_reg_loss_type}'")

    return {
        "loss_cls": loss_cls / self.loss_normalizer,
        "loss_loc_init": loss_loc_init * self.loss_loc_init_weight,
        "loss_loc_refine":
        loss_loc_refine / self.loss_normalizer * self.loss_loc_refine_weight,
    }
def fcos_losses(
    labels,
    reg_targets,
    logits_pred,
    reg_pred,
    ctrness_pred,
    focal_loss_alpha,
    focal_loss_gamma,
    iou_loss,
    gt_inds,
):
    """Standalone FCOS losses (classification, regression, centerness).

    Returns:
        (dict[str, Tensor], dict): the loss dict and an extras dict holding
        positive indices, their matched GT indices, centerness targets, and
        the regression-loss denominator.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()  # collapse to a 1-D array of location labels

    # Indices of positive locations (those with a real class label;
    # `num_classes` is the background index).
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    # e.g. pos_inds: tensor([7971, 7972, 7973, 8123, ...], device='cuda:0')
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Sync the positive count across GPUs for a consistent normalizer.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot classification targets
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=focal_loss_alpha,
        gamma=focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # Gather positive samples by pos_inds.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    gt_inds = gt_inds[pos_inds]

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    # Regression normalizer: sum of centerness targets, averaged over GPUs.
    loss_denorm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus,
                      1e-6)

    if pos_inds.numel() > 0:
        # IoU loss between predicted and target boxes, weighted by the
        # centerness targets.
        reg_loss = iou_loss(
            reg_pred,
            reg_targets,
            ctrness_targets
        ) / loss_denorm

        # Centerness loss (BCE with logits).
        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred,
            ctrness_targets,
            reduction="sum"
        ) / num_pos_avg
    else:
        # No positives: zero losses that still carry the graph.
        reg_loss = reg_pred.sum() * 0
        ctrness_loss = ctrness_pred.sum() * 0

    losses = {
        "loss_fcos_cls": class_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss
    }
    extras = {
        "pos_inds": pos_inds,
        "gt_inds": gt_inds,
        "gt_ctr": ctrness_targets,
        "loss_denorm": loss_denorm
    }
    return losses, extras
def DTInst_losses(self, labels, reg_targets, logits_pred, reg_pred,
                  ctrness_pred, mask_pred, mask_targets):
    """DTInst losses: FCOS-style detection losses plus mask losses computed
    either on the decoded masks / distance-transform maps (`loss_on_mask`)
    or directly on the mask encoding codes (`loss_on_code`).

    Notes:
        - `self.mask_loss_type` is matched by substring (`in`), so several
          branches may accumulate into the total mask loss.
        - All mask terms are weighted by per-sample centerness targets and
          normalized by the (GPU-averaged) centerness sum.

    Returns:
        (dict[str, Tensor], dict): losses keyed "loss_DTInst_{cls,loc,ctr,mask}"
        and an empty extras dict.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()

    # Positive locations: label differs from the background index.
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # GPU-averaged positive count used as the classification normalizer.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot classification targets
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # Restrict everything below to positive locations.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    mask_pred = mask_pred[pos_inds]

    # BUGFIX: the message used to be `print(...)`, which evaluates to None,
    # so a failing assert printed once and raised with no message. Pass the
    # string itself as the assertion message.
    assert mask_pred.shape[0] == mask_targets.shape[0], \
        ("The number(positive) should be equal between "
         "masks_pred(prediction) and mask_targets(target).")

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    # Centerness-weighted normalizer, averaged across GPUs.
    ctrness_norm = max(
        reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    reg_loss = self.iou_loss(reg_pred, reg_targets,
                             ctrness_targets) / ctrness_norm

    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg

    total_mask_loss = 0.
    # Decode predicted codes into DTM maps + binary masks; encode GT masks
    # into codes, DTM targets, and auxiliary weight / Hausdorff maps.
    dtm_pred_, binary_pred_ = self.mask_encoding.decoder(mask_pred,
                                                         is_train=True)
    code_targets, dtm_targets, weight_maps, hd_maps = self.mask_encoding.encoder(
        mask_targets)

    if self.loss_on_mask:
        if 'mask_mse' in self.mask_loss_type:
            mask_loss = F.mse_loss(dtm_pred_, dtm_targets,
                                   reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += mask_loss
        if 'weighted_mask_mse' in self.mask_loss_type:
            mask_loss = F.mse_loss(dtm_pred_, dtm_targets,
                                   reduction='none')
            mask_loss = torch.sum(mask_loss * weight_maps, 1) / torch.sum(
                weight_maps, 1) * ctrness_targets * self.mask_size**2
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += mask_loss
        if 'mask_difference' in self.mask_loss_type:
            w_ = torch.abs(
                binary_pred_ * 1. -
                mask_targets * 1)  # 1's are inconsistent pixels
            md_loss = torch.sum(w_, 1) * ctrness_targets
            md_loss = md_loss.sum() / max(ctrness_norm * self.mask_size**2,
                                          1.0)
            total_mask_loss += md_loss
        if 'hd_one_side_binary' in self.mask_loss_type:
            # the first attempt, not really accurate
            w_ = torch.abs(
                binary_pred_ * 1. -
                mask_targets * 1)  # 1's are inconsistent pixels in hd_maps
            hausdorff_loss = torch.sum(w_ * hd_maps, 1) / (torch.sum(
                w_, 1) + 1e-4) * ctrness_targets * self.mask_size**2
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'hd_two_side_binary' in self.mask_loss_type:
            # the first attempt, not really accurate
            w_ = torch.abs(
                binary_pred_ * 1. -
                mask_targets * 1)  # 1's are inconsistent pixels in hd_maps
            # NOTE(review): clamp upper bounds differ (1.1 vs 1) between the
            # two terms — possibly intentional asymmetry; confirm.
            hausdorff_loss = torch.sum(
                w_ * (torch.clamp(dtm_pred_**2, -0.1, 1.1) +
                      torch.clamp(dtm_targets**2, -0.1, 1)),
                1) / (torch.sum(
                    w_, 1) + 1e-4) * ctrness_targets * self.mask_size**2
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'hd_weighted_one_side_dtm' in self.mask_loss_type:
            dtm_diff = (dtm_pred_ - dtm_targets)**2
            hausdorff_loss = torch.sum(
                dtm_diff * weight_maps * hd_maps,
                1) / (torch.sum(weight_maps, 1) +
                      1e-4) * ctrness_targets * self.mask_size**2
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'hd_weighted_two_side_dtm' in self.mask_loss_type:
            dtm_diff = (dtm_pred_ - dtm_targets)**2
            hausdorff_loss = torch.sum(
                dtm_diff * weight_maps * (dtm_pred_**2 + dtm_targets**2),
                1) / (torch.sum(weight_maps, 1) +
                      1e-4) * ctrness_targets * self.mask_size**2
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'hd_one_side_dtm' in self.mask_loss_type:
            dtm_diff = (dtm_pred_ - dtm_targets)**2
            hausdorff_loss = torch.sum(dtm_diff * hd_maps,
                                       1) * ctrness_targets
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'hd_two_side_dtm' in self.mask_loss_type:
            dtm_diff = (dtm_pred_ - dtm_targets)**2
            hausdorff_loss = torch.sum(
                dtm_diff *
                (torch.clamp(dtm_pred_, -1.1, 1.1)**2 + dtm_targets**2),
                1) * ctrness_targets
            hausdorff_loss = hausdorff_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += hausdorff_loss
        if 'contour_dice' in self.mask_loss_type:
            # contour pixels with 0.05 tolerance:
            # pred band:   0.5 <= dtm_pred_ + 0.9 < 0.55
            # target band: 0.  <= dtm_targets    < 0.05
            pred_contour = (dtm_pred_ + 0.9 < 0.55) * 1. * (
                0.5 <= dtm_pred_ + 0.9)
            # BUGFIX: the lower bound used to repeat `(dtm_targets < 0.05)`,
            # which reduced the band to a one-sided threshold and pulled
            # negative-DTM (interior) pixels into the "contour"; use the
            # intended `0. <= dtm_targets` lower bound instead.
            target_contour = (dtm_targets < 0.05) * 1. * (0. <= dtm_targets)
            overlap_ = torch.sum(pred_contour * 2. * target_contour, 1)
            union_ = torch.sum(pred_contour**2, 1) + torch.sum(
                target_contour**2, 1)
            dice_loss = (
                1. - overlap_ /
                (union_ + 1e-4)) * ctrness_targets * self.mask_size**2
            dice_loss = dice_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += dice_loss
        if 'mask_dice' in self.mask_loss_type:
            overlap_ = torch.sum(binary_pred_ * 2. * mask_targets, 1)
            union_ = torch.sum(binary_pred_**2, 1) + torch.sum(
                mask_targets**2, 1)
            dice_loss = (
                1. - overlap_ /
                (union_ + 1e-4)) * ctrness_targets * self.mask_size**2
            dice_loss = dice_loss.sum() / max(
                ctrness_norm * self.mask_size**2, 1.0)
            total_mask_loss += dice_loss

    if self.loss_on_code:
        # m*m mask labels --> n_components encoding labels
        # NOTE(review): substring match means 'mask_mse' also satisfies
        # 'mse' here — both terms accumulate if loss_on_mask and
        # loss_on_code are both enabled; confirm that is intended.
        if 'mse' in self.mask_loss_type:
            mask_loss = F.mse_loss(mask_pred, code_targets,
                                   reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            if self.mask_sparse_weight > 0.:
                if self.sparsity_loss_type == 'L1':
                    sparsity_loss = torch.sum(torch.abs(mask_pred),
                                              1) * ctrness_targets
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_L1':
                    # inactive codes, put L1 regularization on them
                    w_ = (torch.abs(code_targets) < 1e-4) * 1.
                    sparsity_loss = torch.sum(torch.abs(mask_pred) * w_, 1) / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_L2':
                    # inactive codes, put L2 regularization on them
                    w_ = (torch.abs(code_targets) < 1e-4) * 1.
                    sparsity_loss = torch.sum(mask_pred ** 2. * w_, 1) / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(
                        ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                else:
                    raise NotImplementedError
            total_mask_loss += mask_loss
        if 'smooth' in self.mask_loss_type:
            mask_loss = F.smooth_l1_loss(mask_pred, code_targets,
                                         reduction='none')
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        if 'cosine' in self.mask_loss_type:
            mask_loss = loss_cos_sim(mask_pred, code_targets)
            mask_loss = mask_loss * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        if 'kl_softmax' in self.mask_loss_type:
            mask_loss = loss_kl_div_softmax(mask_pred, code_targets)
            mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(
                ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss

    losses = {
        "loss_DTInst_cls": class_loss,
        "loss_DTInst_loc": reg_loss,
        "loss_DTInst_ctr": ctrness_loss,
        "loss_DTInst_mask": total_mask_loss
    }
    return losses, {}
def losses(self, locations, class_logits, center_score, box_reg_init,
           box_reg, gt_instances):
    """FCOS-style losses with init/refine box regression and centerness.

    Args:
        locations: per-level location coordinates fed to the GT assigner.
        class_logits, center_score, box_reg_init, box_reg: raw per-level
            head outputs, flattened by `permute_and_concat_v2`.
        gt_instances (list[Instances]): ground truth per image.

    Returns:
        dict[str, Tensor]: weighted "loss_cls", "centerness_loss",
        "loss_loc_init", "loss_loc_refine".
    """
    gt_classes, loc_targets, topk_locations = self.get_ground_truth(
        locations, gt_instances)
    class_logits, box_reg_init, box_reg, center_score = permute_and_concat_v2(
        class_logits, box_reg_init, box_reg, center_score,
        self.num_classes)
    # Shapes: (N x R) and (N x R, 4), (N x R) respectively.
    gt_classes = gt_classes.flatten()
    loc_targets = loc_targets.view(-1, 4)

    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    pos_inds = torch.nonzero(foreground_idxs).squeeze(1)

    num_gpus = get_num_gpus()
    # sync num_pos from all gpus
    total_num_pos = reduce_sum(pos_inds.new_tensor([pos_inds.numel()
                                                    ])).item()
    num_pos_avg_per_gpu = max(total_num_pos / float(num_gpus), 1.0)

    # One-hot classification targets (background rows stay zero).
    gt_classes_target = torch.zeros_like(class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    # logits loss
    cls_loss = sigmoid_focal_loss_jit(
        class_logits,
        gt_classes_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg_per_gpu

    if pos_inds.numel() > 0:
        if self.slender_centerness:
            gt_center_score = compute_slender_centerness_targets(
                loc_targets[foreground_idxs])
        else:
            gt_center_score = compute_centerness_targets(
                loc_targets[foreground_idxs])

        # average sum_centerness_targets from all gpus,
        # which is used to normalize centerness-weighed reg loss
        sum_centerness_targets_avg_per_gpu = \
            reduce_sum(gt_center_score.sum()).item() / float(num_gpus)

        # Init stage regresses at the top-k assigned locations only.
        topk_locations = topk_locations.view(-1)
        topk_gt_center_score = compute_centerness_targets(
            loc_targets[topk_locations])
        sum_topk_centerness_targets_avg_per_gpu = \
            reduce_sum(topk_gt_center_score.sum()).item() / float(num_gpus)
        loss_loc_init = iou_loss(
            box_reg_init[topk_locations],
            loc_targets[topk_locations],
            topk_gt_center_score,
            loss_type=self.iou_loss_type
        ) / sum_topk_centerness_targets_avg_per_gpu

        loss_loc_refine = iou_loss(
            box_reg[foreground_idxs],
            loc_targets[foreground_idxs],
            gt_center_score,
            loss_type=self.iou_loss_type
        ) / sum_centerness_targets_avg_per_gpu

        centerness_loss = F.binary_cross_entropy_with_logits(
            center_score[foreground_idxs], gt_center_score,
            reduction='sum') / num_pos_avg_per_gpu
    else:
        # No positives: zero-valued losses that keep the graph alive.
        loss_loc_init = box_reg_init[foreground_idxs].sum()
        loss_loc_refine = box_reg[foreground_idxs].sum()
        # NOTE(review): dummy collective so other ranks' reduce_sum calls
        # don't block — but the positive branch issues TWO reduce_sums and
        # this branch only one; presumably ranks can diverge here. Confirm.
        reduce_sum(center_score[foreground_idxs].new_tensor([0.0]))
        centerness_loss = center_score[foreground_idxs].sum()

    return dict(
        loss_cls=cls_loss * self.loss_cls_weight,
        centerness_loss=centerness_loss * self.loss_cls_weight,
        loss_loc_init=loss_loc_init * self.loss_loc_init_weight,
        loss_loc_refine=loss_loc_refine * self.loss_loc_refine_weight,
    )
def losses(self, anchors: List[Boxes], pred_logits: List[Tensor],
           gt_classes: List[Tensor], pred_anchor_deltas: List[Tensor],
           gt_boxes: List[Tensor]) -> Dict[str, Tensor]:
    """Compute RetinaNet classification and box regression losses.

    Args:
        anchors: per-level anchor Boxes; concatenated below to (R, 4) where
            R is the total number of anchors across levels, i.e.
            sum(Hi x Wi x A).
        pred_logits, pred_anchor_deltas: per-level head outputs, see
            :meth:`RetinaNetHead.forward`.
        gt_classes, gt_boxes: per-image ground truth, see
            :meth:`RetinaNet.get_ground_truth`; stacked shapes are (N, R)
            and (N, R, 4).

    Returns:
        dict[str, Tensor]: mapping from a named loss to a scalar tensor
        storing the loss. Used during training only. The dict keys are:
        "loss_cls" and "loss_box_reg".
        (Fix: the annotation previously said ``Dict[str, float]`` although
        the values are tensors.)
    """
    num_images: int = len(gt_classes)
    # shape(gt_classes) = (N, R)
    gt_classes_tensor: Tensor = torch.stack(gt_classes)
    # shape(anchors) = (R, 4)
    anchors_tensor: Tensor = type(anchors[0]).cat(anchors).tensor
    gt_anchor_deltas: List[Tensor] = [
        self.box2box_transform.get_deltas(anchors_tensor, k) for k in gt_boxes
    ]
    # shape(gt_anchor_deltas) = (N, R, 4)
    gt_anchor_deltas_tensor: Tensor = torch.stack(gt_anchor_deltas)

    # valid = labelled (>= 0); positive = valid and not background.
    valid_mask: Tensor = gt_classes_tensor >= 0
    pos_mask: Tensor = valid_mask & (gt_classes_tensor != self.num_classes)
    num_pos_anchors: int = pos_mask.sum().item()
    get_event_storage().put_scalar("num_pos_anchors",
                                   num_pos_anchors / num_images)

    # EMA of the positive-anchor count, used to normalize both losses.
    self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer\
        + (1 - self.loss_normalizer_momentum) * max(num_pos_anchors, 1)

    # classification and regression loss
    # no loss for the last (background) class --> [:, :-1]
    gt_classes_target: LongTensor = F.one_hot(
        gt_classes_tensor[valid_mask], num_classes=self.num_classes + 1)[:, :-1]

    # logits loss
    loss_cls = sigmoid_focal_loss_jit(
        inputs=cat(pred_logits, dim=1)[valid_mask],
        targets=gt_classes_target.to(pred_logits[0].dtype),
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum") / self.loss_normalizer

    # regression loss, over positives only
    loss_box_reg = smooth_l1_loss(
        input=cat(pred_anchor_deltas, dim=1)[pos_mask],
        target=gt_anchor_deltas_tensor[pos_mask],
        beta=self.smooth_l1_loss_beta,
        reduction="sum") / self.loss_normalizer

    return {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
def classification_losses(self, gt_classes, pred_class_logits):
    """Multi-task classification loss over three label groups
    (category / part / toward) sliced out of one concatenated logit tensor.

    Args:
        gt_classes: mapping (indexed via self.classification_tasks) from
            task to its flattened ground-truth labels; entries <= -1 are
            ignored ("no label").
        pred_class_logits: (N, sum(self.classification_classes)) logits,
            see :meth:`RetinaNetHead.forward`.

    Returns:
        dict[str: Tensor]: the dict keys are "loss_category", "loss_part"
        and "loss_toward". A task with no valid labels contributes 0.0.
    """
    # Slice the concatenated logits into per-task chunks:
    # [category | part | toward], sized by self.classification_classes.
    start = 0
    pred_category = pred_class_logits[:, start:start + self.classification_classes[0]]
    start += self.classification_classes[0]
    pred_part = pred_class_logits[:, start:start + self.classification_classes[1]]
    start += self.classification_classes[1]
    pred_toward = pred_class_logits[:, start:start + self.classification_classes[2]]

    data_type = pred_category.dtype
    num_batchs = pred_category.size()[0]
    # (Fix: removed dead locals `valid_idxs` / `num_model` — they were
    # computed but never used.)

    valid_category = gt_classes[self.classification_tasks[0]][:] > -1
    # category loss
    if valid_category.sum() > 0:
        if self.activation == 'sigmoid':
            loss_category = sigmoid_focal_loss_jit(
                pred_category.flatten()[valid_category],
                gt_classes[self.classification_tasks[0]].to(
                    dtype=data_type)[valid_category],
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
                reduction="sum",
            ) / max(1, valid_category.sum() / self.classification_classes[0])
        elif self.activation == 'softmax':
            # One-hot labels -> class indices; an image counts as valid if
            # any of its per-class entries is labelled.
            gt_category = torch.argmax(
                gt_classes[self.classification_tasks[0]].view(
                    num_batchs, -1), dim=1)
            valid_category = valid_category.view(num_batchs, -1).sum(dim=1) > 0
            loss_category = F.cross_entropy(
                pred_category[valid_category],
                gt_category[valid_category],
                reduction="sum",
            ) / max(1, valid_category.sum())
        else:
            # Fix: was a generic `raise Exception`; NotImplementedError is
            # the precise type (and is still an Exception subclass).
            raise NotImplementedError("Not implement classification activation!")
    else:
        loss_category = 0.0

    # part loss (sigmoid focal only)
    valid_part = gt_classes[self.classification_tasks[1]][:] > -1
    if valid_part.sum() > 0:
        loss_part = sigmoid_focal_loss_jit(
            pred_part.flatten()[valid_part],
            gt_classes[self.classification_tasks[1]].to(
                dtype=data_type)[valid_part],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1, valid_part.sum() / self.classification_classes[1])
    else:
        loss_part = 0.0

    # toward loss (sigmoid focal only)
    valid_toward = gt_classes[self.classification_tasks[2]][:] > -1
    if valid_toward.sum() > 0:
        loss_toward = sigmoid_focal_loss_jit(
            pred_toward.flatten()[valid_toward],
            gt_classes[self.classification_tasks[2]].to(
                dtype=data_type)[valid_toward],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / max(1, valid_toward.sum() / self.classification_classes[2])
    else:
        loss_toward = 0.0

    return {
        "loss_category": loss_category,
        "loss_part": loss_part,
        "loss_toward": loss_toward
    }
def fcos_losses(self, instances):
    """Compute FCOS classification, localization and box-quality losses.

    Args:
        instances: an Instances-like container with fields ``logits_pred``,
            ``labels``, ``reg_pred``, ``reg_targets`` and ``ctrness_pred``.

    Returns:
        (extras, losses): extras carries the positive-subset instances (and
        the ctrness normalizer when box_quality == "ctrness"); losses maps
        "loss_fcos_cls" / "loss_fcos_loc" / "loss_fcos_ctr"-or-"loss_fcos_iou"
        to scalar tensors.
    """
    losses, extras = {}, {}
    # 1. compute the cls loss
    num_classes = instances.logits_pred.size(1)
    assert num_classes == self.num_classes

    # Background positions carry the label `num_classes`.
    labels = instances.labels.flatten()
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = torch.ones_like(pos_inds).sum()
    # Collective call: averaged positive count across ranks, floored at 1.
    num_pos_avg = max(reduce_mean(num_pos_local).item(), 1.0)

    # prepare one_hot
    class_target = torch.zeros_like(instances.logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(instances.logits_pred,
                                        class_target,
                                        alpha=self.focal_loss_alpha,
                                        gamma=self.focal_loss_gamma,
                                        reduction="sum")
    # Choose the classification-loss normalizer:
    if self.loss_normalizer_cls == "moving_fg":
        # EMA of the foreground count (mutates module state).
        self.moving_num_fg = self.moving_num_fg_momentum * self.moving_num_fg + (
            1 - self.moving_num_fg_momentum) * num_pos_avg
        class_loss = class_loss / self.moving_num_fg
    elif self.loss_normalizer_cls == "fg":
        class_loss = class_loss / num_pos_avg
    else:
        # Normalize by the (GPU-averaged) total number of samples.
        num_samples_local = torch.ones_like(labels).sum()
        num_samples_avg = max(reduce_mean(num_samples_local).item(), 1.0)
        class_loss = class_loss / num_samples_avg
    losses["loss_fcos_cls"] = class_loss * self.loss_weight_cls

    # 2. compute the box regression and quality loss
    # Restrict to positive positions from here on.
    instances = instances[pos_inds]
    instances.pos_inds = pos_inds
    ious, gious = compute_ious(instances.reg_pred, instances.reg_targets)

    if self.box_quality == "ctrness":
        ctrness_targets = compute_ctrness_targets(instances.reg_targets)
        instances.gt_ctrs = ctrness_targets
        ctrness_targets_sum = ctrness_targets.sum()
        # Collective call: centerness-weighted normalizer, floored at 1e-6.
        loss_denorm = max(reduce_mean(ctrness_targets_sum).item(), 1e-6)
        extras["loss_denorm"] = loss_denorm
        reg_loss = self.loc_loss_func(ious, gious,
                                      ctrness_targets) / loss_denorm
        losses["loss_fcos_loc"] = reg_loss
        ctrness_loss = F.binary_cross_entropy_with_logits(
            instances.ctrness_pred, ctrness_targets,
            reduction="sum") / num_pos_avg
        losses["loss_fcos_ctr"] = ctrness_loss
    elif self.box_quality == "iou":
        reg_loss = self.loc_loss_func(ious, gious) / num_pos_avg
        losses["loss_fcos_loc"] = reg_loss
        # Quality head regresses the detached IoU of each prediction.
        quality_loss = F.binary_cross_entropy_with_logits(
            instances.ctrness_pred, ious.detach(),
            reduction="sum") / num_pos_avg
        losses["loss_fcos_iou"] = quality_loss
    else:
        raise NotImplementedError

    extras["instances"] = instances
    return extras, losses
def losses(self, pred_logits, pred_init_boxes, pred_refine_boxes,
           gt_init_objectness, gt_init_bboxes, gt_cls: torch.Tensor,
           gt_refine_bboxes, strides):
    """Compute the classification loss and the two localization losses.

    Args:
        pred_logits: (N, X, C) classification logits, where X is the number
            of positions pooled from all feature levels and C the number of
            object classes.
        pred_init_boxes: (N, X, 4) initial box predictions.
        pred_refine_boxes: (N, X, 4) refined box predictions.
        gt_init_objectness: (N, X) fg/bg assignment for the init stage.
        gt_init_bboxes: (N, X, 4) init-stage box targets.
        gt_cls: (N, X) long class targets; -1 marks ignored positions.
        gt_refine_bboxes: (N, X, 4) refine-stage box targets.
        strides: (X,) per-position scale factors.

    Returns:
        dict[str, Tensor] with keys "loss_cls", "loss_localization_init"
        and "loss_localization_refine". Used during training only.
    """
    valid_mask = gt_cls >= 0
    fg_mask = valid_mask & (gt_cls != self.num_classes)
    num_foreground = fg_mask.sum().item() / gt_init_bboxes.shape[0]
    get_event_storage().put_scalar("num_foreground", num_foreground)

    # One-hot targets; ignored and background rows stay all-zero.
    cls_target = torch.zeros_like(pred_logits)
    cls_target[fg_mask, gt_cls[fg_mask]] = 1

    # EMA of the per-image foreground count, used as the normalizer.
    momentum = self.loss_normalizer_momentum
    self.loss_normalizer = (
        momentum * self.loss_normalizer + (1 - momentum) * num_foreground)

    focal = sigmoid_focal_loss_jit(pred_logits[valid_mask],
                                   cls_target[valid_mask],
                                   alpha=self.focal_loss_alpha,
                                   gamma=self.focal_loss_gamma,
                                   reduction="sum")
    loss_cls = focal / max(1, self.loss_normalizer)

    init_fg_mask = gt_init_objectness > 0
    strides = strides[None].repeat(pred_logits.shape[0], 1)

    # Coordinates are divided by 4x the stride before the smooth-L1,
    # making the regression loss scale-invariant across levels.
    init_norm = strides[init_fg_mask].unsqueeze(-1) * 4
    loss_localization_init = smooth_l1_loss(
        pred_init_boxes[init_fg_mask] / init_norm,
        gt_init_bboxes[init_fg_mask] / init_norm,
        0.11,
        reduction='sum') / max(1, gt_init_objectness.sum()) * 0.5

    refine_norm = strides[fg_mask].unsqueeze(-1) * 4
    loss_localization_refine = smooth_l1_loss(
        pred_refine_boxes[fg_mask] / refine_norm,
        gt_refine_bboxes[fg_mask] / refine_norm,
        0.11,
        reduction="sum") / max(1, self.loss_normalizer)

    return {
        "loss_cls": loss_cls,
        "loss_localization_init": loss_localization_init,
        "loss_localization_refine": loss_localization_refine
    }
def fcos_losses(
    labels,
    reg_targets,
    logits_pred,
    reg_pred,
    ctrness_pred,
    focal_loss_alpha,
    focal_loss_gamma,
    iou_loss,
    gt_inds,
):
    """Compute FCOS classification / regression / centerness losses.

    Args:
        labels: flattened-to-(R,) class targets; `num_classes` = background.
        reg_targets: (R, 4) regression targets.
        logits_pred: (R, num_classes) classification logits.
        reg_pred: (R, 4) regression predictions.
        ctrness_pred: (R,) centerness logits.
        focal_loss_alpha, focal_loss_gamma: focal-loss hyper-parameters.
        iou_loss: callable (pred, target, weights) -> scalar loss.
        gt_inds: (R,) ground-truth instance index per position.

    Returns:
        (losses, extras): losses maps loss_fcos_cls/loc/ctr to scalar
        tensors; extras carries the positive indices, their gt indices,
        centerness targets and the normalizer.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()
    # Background positions carry the label `num_classes`.
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Collective call: every rank must reach this reduce_sum.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=focal_loss_alpha,
        gamma=focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg

    # Restrict regression/centerness supervision to positive positions.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    gt_inds = gt_inds[pos_inds]

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    # GPU-averaged centerness normalizer, floored at 1e-6.  Computed outside
    # the branch below so the collective call happens on every rank.
    loss_denorm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    if pos_inds.numel() > 0:
        reg_loss = iou_loss(reg_pred, reg_targets,
                            ctrness_targets) / loss_denorm
        ctrness_loss = F.binary_cross_entropy_with_logits(
            ctrness_pred, ctrness_targets, reduction="sum") / num_pos_avg
    else:
        # No positives: zero-valued losses that keep the graph connected.
        reg_loss = reg_pred.sum() * 0
        ctrness_loss = ctrness_pred.sum() * 0

    losses = {
        "loss_fcos_cls": class_loss,
        "loss_fcos_loc": reg_loss,
        "loss_fcos_ctr": ctrness_loss
    }
    extras = {
        "pos_inds": pos_inds,
        "gt_inds": gt_inds,
        "gt_ctr": ctrness_targets,
        "loss_denorm": loss_denorm
    }
    return losses, extras
def loss(self, cate_preds, kernel_preds, ins_pred, targets):
    """SOLOv2-style loss: dice loss on dynamically generated instance masks
    plus sigmoid focal loss on the category grids.

    (Fix: removed a stray dead ``pass`` statement at the top of the body.)

    Args:
        cate_preds: per-level category logits; reshaped below to
            (-1, num_classes).
        kernel_preds: per-level dynamic-convolution kernel predictions.
        ins_pred: shared mask-feature tensor, indexed per image.
        targets: tuple (ins_label_list, cate_label_list, ins_ind_label_list,
            grid_order_list), one entry per image.

    Returns:
        dict with keys 'loss_ins' (weighted dice loss) and 'loss_cate'
        (weighted focal loss).
    """
    ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list = targets
    # ins: concatenate the GT masks of one level across the whole batch.
    ins_labels = [
        torch.cat([
            ins_labels_level_img for ins_labels_level_img in ins_labels_level
        ], 0) for ins_labels_level in zip(*ins_label_list)
    ]

    # Select, per image and level, only the kernels at occupied grid cells.
    kernel_preds = [[
        kernel_preds_level_img.view(kernel_preds_level_img.shape[0],
                                    -1)[:, grid_orders_level_img]
        for kernel_preds_level_img, grid_orders_level_img in zip(
            kernel_preds_level, grid_orders_level)
    ] for kernel_preds_level, grid_orders_level in zip(
        kernel_preds, zip(*grid_order_list))]

    # generate masks: each selected kernel acts as a 1x1 dynamic conv over
    # the image's mask features.
    ins_pred_list = []
    for b_kernel_pred in kernel_preds:
        b_mask_pred = []
        for idx, kernel_pred in enumerate(b_kernel_pred):
            if kernel_pred.size()[-1] == 0:
                # No instances at this level for this image.
                continue
            cur_ins_pred = ins_pred[idx, ...]
            H, W = cur_ins_pred.shape[-2:]
            N, I = kernel_pred.shape
            cur_ins_pred = cur_ins_pred.unsqueeze(0)
            kernel_pred = kernel_pred.permute(1, 0).view(I, -1, 1, 1)
            cur_ins_pred = F.conv2d(cur_ins_pred, kernel_pred,
                                    stride=1).view(-1, H, W)
            b_mask_pred.append(cur_ins_pred)
        if len(b_mask_pred) == 0:
            b_mask_pred = None
        else:
            b_mask_pred = torch.cat(b_mask_pred, 0)
        ins_pred_list.append(b_mask_pred)

    ins_ind_labels = [
        torch.cat([
            ins_ind_labels_level_img.flatten()
            for ins_ind_labels_level_img in ins_ind_labels_level
        ]) for ins_ind_labels_level in zip(*ins_ind_label_list)
    ]
    flatten_ins_ind_labels = torch.cat(ins_ind_labels)
    num_ins = flatten_ins_ind_labels.sum()

    # dice loss
    loss_ins = []
    for input, target in zip(ins_pred_list, ins_labels):
        if input is None:
            continue
        input = torch.sigmoid(input)
        loss_ins.append(dice_loss(input, target))
    loss_ins_mean = torch.cat(loss_ins).mean()
    loss_ins = loss_ins_mean * self.ins_loss_weight

    # cate
    cate_labels = [
        torch.cat([
            cate_labels_level_img.flatten()
            for cate_labels_level_img in cate_labels_level
        ]) for cate_labels_level in zip(*cate_label_list)
    ]
    flatten_cate_labels = torch.cat(cate_labels)
    cate_preds = [
        cate_pred.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
        for cate_pred in cate_preds
    ]
    flatten_cate_preds = torch.cat(cate_preds)

    # prepare one_hot (background label == self.num_classes)
    pos_inds = torch.nonzero(
        flatten_cate_labels != self.num_classes).squeeze(1)
    flatten_cate_labels_oh = torch.zeros_like(flatten_cate_preds)
    flatten_cate_labels_oh[pos_inds, flatten_cate_labels[pos_inds]] = 1

    # +1 avoids division by zero when there are no instances.
    loss_cate = self.focal_loss_weight * sigmoid_focal_loss_jit(
        flatten_cate_preds,
        flatten_cate_labels_oh,
        gamma=self.focal_loss_gamma,
        alpha=self.focal_loss_alpha,
        reduction="sum") / (num_ins + 1)
    return {'loss_ins': loss_ins, 'loss_cate': loss_cate}
def SMInst_losses(
        self,
        labels,
        reg_targets,
        logits_pred,
        reg_pred,
        ctrness_pred,
        mask_pred,
        mask_targets
):
    """Compute SMInst losses: classification, box regression, centerness
    and the encoded-mask losses (selected by self.mask_loss_type).

    Args:
        labels: class targets, flattened to (R,); `num_classes` = background.
        reg_targets: (R, 4) box regression targets.
        logits_pred: (R, num_classes) classification logits.
        reg_pred: (R, 4) box regression predictions.
        ctrness_pred: (R,) centerness logits.
        mask_pred: per-position mask-code predictions; positive rows must
            align 1:1 with mask_targets (asserted below).
        mask_targets: ground-truth masks for the positive positions.

    Returns:
        (losses, extras): dict of the four scalar losses, and an empty dict.
    """
    num_classes = logits_pred.size(1)
    labels = labels.flatten()
    pos_inds = torch.nonzero(labels != num_classes).squeeze(1)
    num_pos_local = pos_inds.numel()
    num_gpus = get_world_size()
    # Collective call: synced positive count across ranks, floored at 1.
    total_num_pos = reduce_sum(pos_inds.new_tensor([num_pos_local])).item()
    num_pos_avg = max(total_num_pos / num_gpus, 1.0)

    # prepare one_hot
    class_target = torch.zeros_like(logits_pred)
    class_target[pos_inds, labels[pos_inds]] = 1

    class_loss = sigmoid_focal_loss_jit(
        logits_pred,
        class_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg
    # print(mask_pred.size(), mask_tower_interm_outputs.size())

    # Keep only the positive rows for the remaining losses.
    reg_pred = reg_pred[pos_inds]
    reg_targets = reg_targets[pos_inds]
    ctrness_pred = ctrness_pred[pos_inds]
    mask_pred = mask_pred[pos_inds]
    # mask_activation_pred = mask_activation_pred[pos_inds]

    # NOTE(review): the assert message is `print(...)`, which returns None —
    # on failure this prints the text and raises AssertionError(None).
    # Probably intended to be a plain string.
    assert mask_pred.shape[0] == mask_targets.shape[0], \
        print("The number(positive) should be equal between "
              "masks_pred(prediction) and mask_targets(target).")

    ctrness_targets = compute_ctrness_targets(reg_targets)
    ctrness_targets_sum = ctrness_targets.sum()
    # GPU-averaged centerness normalizer, floored at 1e-6 (collective call).
    ctrness_norm = max(reduce_sum(ctrness_targets_sum).item() / num_gpus, 1e-6)

    reg_loss = self.iou_loss(
        reg_pred,
        reg_targets,
        ctrness_targets
    ) / ctrness_norm

    ctrness_loss = F.binary_cross_entropy_with_logits(
        ctrness_pred,
        ctrness_targets,
        reduction="sum"
    ) / num_pos_avg

    # Encode GT masks into code space; decode predictions back to mask
    # space (soft `mask_pred_` and binarized `mask_pred_bin`).
    mask_targets_ = self.mask_encoding.encoder(mask_targets)
    mask_pred_, mask_pred_bin = self.mask_encoding.decoder(mask_pred, is_train=True)

    # compute the loss for the activation code as binary classification
    # activation_targets = (torch.abs(mask_targets_) > 1e-4) * 1.
    # activation_loss = F.binary_cross_entropy_with_logits(
    #     mask_activation_pred,
    #     activation_targets,
    #     reduction='none'
    # )
    # activation_loss = activation_loss.sum(1) * ctrness_targets
    # activation_loss = activation_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
    # if self.thresh_with_active:
    #     mask_pred = mask_pred * torch.sigmoid(mask_activation_pred)

    total_mask_loss = 0.
    if self.loss_on_mask:
        # n_components predictions --> m*m mask predictions without sigmoid
        # as sigmoid function is combined in loss.
        # mask_pred_, mask_pred_bin = self.mask_encoding.decoder(mask_pred, is_train=True)
        if 'mask_mse' in self.mask_loss_type:
            # Per-pixel MSE in mask space, centerness-weighted.
            mask_loss = F.mse_loss(
                mask_pred_,
                mask_targets,
                reduction='none'
            )
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += mask_loss
        if 'mask_iou' in self.mask_loss_type:
            # Soft IoU on the binarized decoded masks.
            overlap_ = torch.sum(mask_pred_bin * 1. * mask_targets, 1)
            union_ = torch.sum((mask_pred_bin + mask_targets) >= 1., 1)
            iou_loss = (1. - overlap_ / (union_ + 1e-4)) * ctrness_targets * self.mask_size ** 2
            iou_loss = iou_loss.sum() / max(ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += iou_loss
        if 'mask_difference' in self.mask_loss_type:
            w_ = torch.abs(mask_pred_bin * 1. - mask_targets * 1)  # 1's are inconsistent pixels in hd_maps
            md_loss = torch.sum(w_, 1) * ctrness_targets
            md_loss = md_loss.sum() / max(ctrness_norm * self.mask_size ** 2, 1.0)
            total_mask_loss += md_loss
    if self.loss_on_code:
        # m*m mask labels --> n_components encoding labels
        # mask_targets_ = self.mask_encoding.encoder(mask_targets)
        if 'mse' in self.mask_loss_type:
            mask_loss = F.mse_loss(
                mask_pred,
                mask_targets_,
                reduction='none'
            )
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
            # Optional sparsity regularizer, combined with the MSE term.
            if self.mask_sparse_weight > 0.:
                if self.sparsity_loss_type == 'L1':
                    sparsity_loss = torch.sum(torch.abs(mask_pred), 1) * ctrness_targets
                    sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'L0':
                    w_ = (torch.abs(mask_targets_) >= 1e-2) * 1.  # the number of codes that are active
                    sparsity_loss = torch.sum(w_, 1) * ctrness_targets
                    sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_L1':
                    w_ = (torch.abs(mask_targets_) < 1e-2) * 1.  # inactive codes, put L1 regularization on them
                    sparsity_loss = torch.sum(torch.abs(mask_pred) * w_, 1) / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_L2':
                    w_ = (torch.abs(mask_targets_) < 1e-2) * 1.  # inactive codes, put L2 regularization on them
                    sparsity_loss = torch.sum(mask_pred ** 2. * w_, 1) / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                elif self.sparsity_loss_type == 'weighted_KL':
                    w_ = (torch.abs(mask_targets_) < 1e-2) * 1.  # inactive codes, put L2 regularization on them
                    kl_ = kl_divergence(
                        mask_pred,
                        self.kl_rho
                    )
                    sparsity_loss = torch.sum(kl_ * w_, 1) / torch.sum(w_, 1) \
                        * ctrness_targets * self.num_codes
                    sparsity_loss = sparsity_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
                    mask_loss = mask_loss * self.mask_loss_weight + \
                        sparsity_loss * self.mask_sparse_weight
                else:
                    raise NotImplementedError
            total_mask_loss += mask_loss
        if 'smooth' in self.mask_loss_type:
            mask_loss = F.smooth_l1_loss(
                mask_pred,
                mask_targets_,
                reduction='none'
            )
            mask_loss = mask_loss.sum(1) * ctrness_targets
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        if 'cosine' in self.mask_loss_type:
            mask_loss = loss_cos_sim(
                mask_pred,
                mask_targets_
            )
            mask_loss = mask_loss * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        # NOTE(review): the kl_* variants are mutually exclusive (if/elif),
        # unlike the additive branches above — confirm this is intentional.
        if 'kl_softmax' in self.mask_loss_type:
            mask_loss = loss_kl_div_softmax(
                mask_pred,
                mask_targets_
            )
            mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        elif 'kl_sigmoid' in self.mask_loss_type:
            mask_loss = loss_kl_div_sigmoid(
                mask_pred,
                mask_targets_
            )
            mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss
        elif 'kl' in self.mask_loss_type:
            mask_loss = kl_divergence(
                mask_pred,
                self.kl_rho
            )
            mask_loss = mask_loss.sum(1) * ctrness_targets * self.num_codes
            mask_loss = mask_loss.sum() / max(ctrness_norm * self.num_codes, 1.0)
            total_mask_loss += mask_loss

    losses = {
        "loss_SMInst_cls": class_loss,
        "loss_SMInst_loc": reg_loss,
        "loss_SMInst_ctr": ctrness_loss,
        "loss_SMInst_mask": total_mask_loss,
    }
    return losses, {}
def __call__(self, locations, box_cls, box_regression, centerness, targets):
    """Compute the FCOS classification, regression and centerness losses.

    Arguments:
        locations (list[BoxList])
        box_cls (list[Tensor])
        box_regression (list[Tensor])
        centerness (list[Tensor])
        targets (list[BoxList])

    Returns:
        cls_loss (Tensor)
        reg_loss (Tensor)
        centerness_loss (Tensor)
    """
    num_classes = box_cls[0].size(1)
    labels, reg_targets = self.prepare_targets(locations, targets)

    # Flatten every FPN level into single (R, ...) tensors.
    box_cls_flatten = []
    box_regression_flatten = []
    centerness_flatten = []
    labels_flatten = []
    reg_targets_flatten = []
    for l in range(len(labels)):
        box_cls_flatten.append(box_cls[l].permute(0, 2, 3, 1).reshape(
            -1, num_classes))
        box_regression_flatten.append(box_regression[l].permute(
            0, 2, 3, 1).reshape(-1, 4))
        labels_flatten.append(labels[l].reshape(-1))
        reg_targets_flatten.append(reg_targets[l].reshape(-1, 4))
        centerness_flatten.append(centerness[l].reshape(-1))
    box_cls_flatten = torch.cat(box_cls_flatten, dim=0)
    box_regression_flatten = torch.cat(box_regression_flatten, dim=0)
    centerness_flatten = torch.cat(centerness_flatten, dim=0)
    labels_flatten = torch.cat(labels_flatten, dim=0)
    reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0)

    # Background positions carry the label `num_classes`.
    # Fix: this was hard-coded to 80, which silently mislabels positives on
    # any dataset whose class count differs from COCO's.
    pos_inds = torch.nonzero(labels_flatten != num_classes).squeeze(1)

    box_regression_flatten = box_regression_flatten[pos_inds]
    reg_targets_flatten = reg_targets_flatten[pos_inds]
    centerness_flatten = centerness_flatten[pos_inds]

    # sync num_pos from all gpus (collective call; all ranks must reach it)
    total_num_pos = reduce_sum(pos_inds.new_tensor([pos_inds.numel()])).item()
    # Fix: floor at 1 — previously cls_loss/centerness_loss divided by a
    # possibly-zero count, producing inf/nan when an iteration had no
    # positives on any rank.
    num_pos_norm = max(total_num_pos, 1.0)

    gt_classes_target = torch.zeros_like(box_cls_flatten)
    gt_classes_target[pos_inds, labels_flatten[pos_inds]] = 1
    cls_loss = sigmoid_focal_loss_jit(
        box_cls_flatten,
        gt_classes_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_pos_norm

    if pos_inds.numel() > 0:
        centerness_targets = self.compute_centerness_targets(
            reg_targets_flatten)
        # Sum of centerness targets over all gpus, used to normalize the
        # centerness-weighted regression loss.  NOTE(review): unlike the
        # sibling `losses` implementation this sum is not divided by the
        # GPU count — kept as-is to preserve the existing loss scale.
        sum_centerness_targets_avg_per_gpu = \
            reduce_sum(centerness_targets.sum()).item()
        reg_loss = self.box_reg_loss_func(
            box_regression_flatten,
            reg_targets_flatten,
            centerness_targets) / sum_centerness_targets_avg_per_gpu
        centerness_loss = self.centerness_loss_func(
            centerness_flatten,
            centerness_targets) / num_pos_norm
    else:
        # No positives: zero-valued losses keep the graph connected; the
        # dummy reduce_sum keeps the collective-call count matching the
        # other ranks.
        reg_loss = box_regression_flatten.sum()
        reduce_sum(centerness_flatten.new_tensor([0.0]))
        centerness_loss = centerness_flatten.sum()

    return cls_loss, reg_loss, centerness_loss
def losses(self, labels, reg_targets, box_cls, box_regression, centerness):
    """FCOS losses: focal classification, centerness-weighted IoU box
    regression and centerness BCE, with counts synchronized across GPUs.

    Args:
        labels: per-level class targets; `self.num_classes` = background,
            negative values are ignored.
        reg_targets: per-level box regression targets (reshaped to (-1, 4)).
        box_cls: per-level classification logits, channel dim = num classes.
        box_regression: per-level box regression predictions.
        centerness: per-level centerness logits.

    Returns:
        dict with keys ``cls_loss``, ``reg_loss`` and ``centerness_loss``.
    """
    N, num_classes = box_cls[0].shape[:2]

    # Flatten every FPN level into single (R, ...) tensors.
    box_cls_flatten = []
    box_regression_flatten = []
    centerness_flatten = []
    labels_flatten = []
    reg_targets_flatten = []
    for l in range(len(labels)):
        box_cls_flatten.append(box_cls[l].permute(0, 2, 3, 1).reshape(
            -1, num_classes))
        box_regression_flatten.append(box_regression[l].permute(
            0, 2, 3, 1).reshape(-1, 4))
        labels_flatten.append(labels[l].reshape(-1))
        reg_targets_flatten.append(reg_targets[l].reshape(-1, 4))
        centerness_flatten.append(centerness[l].reshape(-1))
    box_cls_flatten = torch.cat(box_cls_flatten, dim=0)
    box_regression_flatten = torch.cat(box_regression_flatten, dim=0)
    centerness_flatten = torch.cat(centerness_flatten, dim=0)
    labels_flatten = torch.cat(labels_flatten, dim=0)
    reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0)

    # Positives: labelled (>= 0) and not background.
    pos_inds = torch.nonzero((labels_flatten >= 0) & (
        labels_flatten != self.num_classes)).squeeze(1)

    box_regression_flatten = box_regression_flatten[pos_inds]
    reg_targets_flatten = reg_targets_flatten[pos_inds]
    centerness_flatten = centerness_flatten[pos_inds]

    num_gpus = get_num_gpus()
    # sync num_pos from all gpus (collective call)
    total_num_pos = reduce_sum(pos_inds.new_tensor([pos_inds.numel()])).item()
    num_pos_avg_per_gpu = max(total_num_pos / float(num_gpus), 1.0)

    # One-hot classification targets; background rows stay all-zero.
    gt_classes_target = torch.zeros_like(box_cls_flatten)
    foreground_idxs = (labels_flatten >= 0) & (labels_flatten != self.num_classes)
    gt_classes_target[foreground_idxs, labels_flatten[foreground_idxs]] = 1

    cls_loss = sigmoid_focal_loss_jit(
        box_cls_flatten,
        gt_classes_target,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_pos_avg_per_gpu

    if pos_inds.numel() > 0:
        centerness_targets = compute_centerness_targets(
            reg_targets_flatten)
        # average sum_centerness_targets from all gpus,
        # which is used to normalize centerness-weighed reg loss
        sum_centerness_targets_avg_per_gpu = \
            reduce_sum(centerness_targets.sum()).item() / float(num_gpus)
        reg_loss = iou_loss(box_regression_flatten,
                            reg_targets_flatten,
                            centerness_targets,
                            loss_type=self.iou_loss_type
                            ) / sum_centerness_targets_avg_per_gpu
        centerness_loss = F.binary_cross_entropy_with_logits(
            centerness_flatten,
            centerness_targets,
            reduction='sum') / num_pos_avg_per_gpu
    else:
        # No positives: zero-valued losses keep the graph connected; the
        # dummy reduce_sum keeps the collective-call count in step with
        # the other ranks.
        reg_loss = box_regression_flatten.sum()
        reduce_sum(centerness_flatten.new_tensor([0.0]))
        centerness_loss = centerness_flatten.sum()

    return dict(cls_loss=cls_loss,
                reg_loss=reg_loss,
                centerness_loss=centerness_loss)
def run_focal_loss_jit() -> None:
    """Run one forward + backward pass of sigmoid_focal_loss_jit and wait
    for the CUDA device to finish (benchmark/smoke-test helper).

    NOTE(review): ``inputs``, ``targets`` and ``alpha`` are free variables —
    presumably bound in an enclosing benchmark scope; confirm they exist
    before calling this standalone. Requires a CUDA device for the
    synchronize call.
    """
    fl = sigmoid_focal_loss_jit(
        inputs, targets, gamma=0, alpha=alpha, reduction="mean"
    )
    fl.backward()
    torch.cuda.synchronize()