def losses(self, gt_classes, gt_shifts_deltas, pred_class_logits,
           pred_shift_deltas, pred_filtering):
    """
    Compute the classification and box-regression losses, with the
    classification probabilities gated by the filtering branch.

    Args:
        For `gt_classes` and `gt_shifts_deltas` parameters, see
            :meth:`FCOS.get_ground_truth`.
        Their shapes are (N, R) and (N, R, 4), respectively, where R is
        the total number of shifts across levels, i.e. sum(Hi x Wi)
        For `pred_class_logits`, `pred_shift_deltas` and `pred_filtering`,
            see :meth:`FCOSHead.forward`.

    Returns:
        dict[str: Tensor]:
            mapping from a named loss to a scalar tensor
            storing the loss. Used during training only. The dict keys are:
            "loss_cls" and "loss_box_reg"
    """
    pred_class_logits, pred_shift_deltas, pred_filtering = \
        permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_class_logits, pred_shift_deltas, pred_filtering,
            self.num_classes
        )  # Shapes: (N x R, K) and (N x R, 4), respectively.

    gt_classes = gt_classes.flatten()
    gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)

    # Label == self.num_classes marks background; label < 0 marks ignore.
    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()

    # One-hot classification targets, zeros for background positions.
    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    # Average the foreground count across workers so every rank uses the
    # same loss normalizer.
    num_foreground = comm.all_reduce(num_foreground) / float(comm.get_world_size())

    # Gate classification scores by the filtering branch. Because the
    # result is already a probability, the plain `focal_loss_jit` (not the
    # sigmoid-logit variant) is used below.
    pred_class_logits = pred_class_logits.sigmoid() * pred_filtering.sigmoid()

    # logits loss
    loss_cls = focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / max(1.0, num_foreground)

    # regression loss
    loss_box_reg = iou_loss(
        pred_shift_deltas[foreground_idxs],
        gt_shifts_deltas[foreground_idxs],
        box_mode="ltrb",
        loss_type=self.iou_loss_type,
        reduction="sum",
    ) / max(1.0, num_foreground) * self.reg_weight

    return {
        "loss_cls": loss_cls,
        "loss_box_reg": loss_box_reg,
    }
def losses(
    self,
    gt_classes,
    gt_shifts_deltas,
    gt_centerness,
    gt_classes_border,
    gt_deltas_border,
    pred_class_logits,
    pred_shift_deltas,
    pred_centerness,
    border_box_cls,
    border_bbox_reg,
):
    """
    Compute losses for both the dense (FCOS) branch and the border branch.

    Args:
        For `gt_classes`, `gt_shifts_deltas` and `gt_centerness` parameters,
            see :meth:`BorderDet.get_ground_truth`.
        Their shapes are (N, R) and (N, R, 4), respectively, where R is
        the total number of shifts across levels, i.e. sum(Hi x Wi)
        For `pred_class_logits`, `pred_shift_deltas` and `pred_centerness`,
            see :meth:`BorderHead.forward`.

    Returns:
        dict[str: Tensor]:
            mapping from a named loss to a scalar tensor
            storing the loss. Used during training only. The dict keys are:
            "loss_cls", "loss_box_reg", "loss_centerness",
            "loss_border_cls" and "loss_border_reg".
    """
    (
        pred_class_logits,
        pred_shift_deltas,
        pred_centerness,
        border_class_logits,
        border_shift_deltas,
    ) = permute_all_cls_and_box_to_N_HWA_K_and_concat(
        pred_class_logits, pred_shift_deltas, pred_centerness,
        border_box_cls, border_bbox_reg, self.num_classes
    )  # Shapes: (N x R, K) and (N x R, 4), respectively.

    # ---------------- fcos (dense) branch ----------------
    gt_classes = gt_classes.flatten()
    gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)
    gt_centerness = gt_centerness.view(-1, 1)

    # Label == self.num_classes marks background; label < 0 marks ignore.
    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()
    acc_centerness_num = gt_centerness[foreground_idxs].sum()

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    # Average normalizers across workers. Fixed: the original used in-place
    # `/=` on `num_foreground` (an integer tensor); in-place true division
    # of an integer tensor raises "result type Float can't be cast to ...
    # Long" on modern PyTorch. Out-of-place division yields a float tensor.
    dist.all_reduce(num_foreground)
    num_foreground = num_foreground / dist.get_world_size()
    dist.all_reduce(acc_centerness_num)
    acc_centerness_num = acc_centerness_num / dist.get_world_size()

    # logits loss
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / max(1, num_foreground)

    # regression loss, weighted by centerness and normalized by the total
    # centerness mass rather than the raw foreground count.
    loss_box_reg = iou_loss(
        pred_shift_deltas[foreground_idxs],
        gt_shifts_deltas[foreground_idxs],
        gt_centerness[foreground_idxs],
        box_mode="ltrb",
        loss_type=self.iou_loss_type,
        reduction="sum",
    ) / max(1, acc_centerness_num)

    # centerness loss
    loss_centerness = F.binary_cross_entropy_with_logits(
        pred_centerness[foreground_idxs],
        gt_centerness[foreground_idxs],
        reduction="sum",
    ) / max(1, num_foreground)

    # ---------------- borderdet branch ----------------
    gt_classes_border = gt_classes_border.flatten()
    gt_deltas_border = gt_deltas_border.view(-1, 4)

    valid_idxs_border = gt_classes_border >= 0
    foreground_idxs_border = (gt_classes_border >= 0) & (gt_classes_border != self.num_classes)
    num_foreground_border = foreground_idxs_border.sum()

    gt_classes_border_target = torch.zeros_like(border_class_logits)
    gt_classes_border_target[foreground_idxs_border,
                             gt_classes_border[foreground_idxs_border]] = 1

    # Same fix as above: out-of-place division on the integer count.
    dist.all_reduce(num_foreground_border)
    num_foreground_border = num_foreground_border / dist.get_world_size()
    num_foreground_border = max(num_foreground_border, 1.0)

    loss_border_cls = sigmoid_focal_loss_jit(
        border_class_logits[valid_idxs_border],
        gt_classes_border_target[valid_idxs_border],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_foreground_border

    # NOTE(review): `foreground_idxs_border` is a boolean mask, so
    # `.numel()` is the total number of positions and is always > 0 when
    # any anchors exist — the `else` branch is effectively dead. The guard
    # may have been meant as `.sum() > 0`; left unchanged to preserve
    # behavior (smooth_l1 over an empty selection already sums to 0).
    if foreground_idxs_border.numel() > 0:
        loss_border_reg = (
            smooth_l1_loss(
                border_shift_deltas[foreground_idxs_border],
                gt_deltas_border[foreground_idxs_border],
                beta=0,
                reduction="sum",
            ) / num_foreground_border
        )
    else:
        loss_border_reg = border_shift_deltas.sum()

    return {
        "loss_cls": loss_cls,
        "loss_box_reg": loss_box_reg,
        "loss_centerness": loss_centerness,
        "loss_border_cls": loss_border_cls,
        "loss_border_reg": loss_border_reg,
    }
def losses(self, gt_classes, gt_shifts_deltas, gt_centerness,
           pred_class_logits, pred_shift_deltas, pred_centerness):
    """
    Compute FCOS losses, plus an optional budget loss for the dynamic head.

    Args:
        For `gt_classes`, `gt_shifts_deltas` and `gt_centerness` parameters,
            see :meth:`FCOS.get_ground_truth`.
        Their shapes are (N, R) and (N, R, 4), respectively, where R is
        the total number of shifts across levels, i.e. sum(Hi x Wi)
        For `pred_class_logits`, `pred_shift_deltas` and `pred_centerness`,
            see :meth:`FCOSHead.forward`.

    Returns:
        dict[str: Tensor]:
            mapping from a named loss to a scalar tensor
            storing the loss. Used during training only. The dict keys are:
            "loss_cls", "loss_box_reg", "loss_centerness" and, when the
            dynamic head is enabled, "loss_budget".
    """
    pred_class_logits, pred_shift_deltas, pred_centerness = \
        permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_class_logits, pred_shift_deltas, pred_centerness,
            self.num_classes
        )  # Shapes: (N x R, K) and (N x R, 4), respectively.

    gt_classes = gt_classes.flatten()
    gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)
    gt_centerness = gt_centerness.view(-1, 1)

    # Label == self.num_classes marks background; label < 0 marks ignore.
    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    # Average normalizers across workers for consistent loss scaling.
    num_foreground = comm.all_reduce(num_foreground) / float(comm.get_world_size())
    num_foreground_centerness = gt_centerness[foreground_idxs].sum()
    num_targets = comm.all_reduce(num_foreground_centerness) / float(comm.get_world_size())

    # logits loss
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / max(1.0, num_foreground)

    # regression loss, centerness-weighted and normalized by the total
    # centerness mass.
    loss_box_reg = iou_loss(
        pred_shift_deltas[foreground_idxs],
        gt_shifts_deltas[foreground_idxs],
        gt_centerness[foreground_idxs],
        box_mode="ltrb",
        loss_type=self.iou_loss_type,
        reduction="sum",
    ) / max(1.0, num_targets)

    # centerness loss. Fixed: normalizer literal unified to 1.0 to match
    # the other losses in this method (numerically identical).
    loss_centerness = F.binary_cross_entropy_with_logits(
        pred_centerness[foreground_idxs],
        gt_centerness[foreground_idxs],
        reduction="sum",
    ) / max(1.0, num_foreground)

    loss = {
        "loss_cls": loss_cls,
        "loss_box_reg": loss_box_reg,
        "loss_centerness": loss_centerness,
    }

    # budget loss: penalize the soft execution cost of the dynamic head.
    if self.is_dynamic_head and self.budget_loss_lambda != 0:
        soft_cost, used_cost, full_cost = get_module_running_cost(self)
        loss_budget = (soft_cost / full_cost).mean() * self.budget_loss_lambda
        storage = get_event_storage()
        # Fixed typo in the logged metric key: "complxity_ratio" ->
        # "complexity_ratio".
        storage.put_scalar("complexity_ratio", (used_cost / full_cost).mean())
        loss.update({"loss_budget": loss_budget})

    return loss
def proposals_losses(self, gt_classes, gt_shifts_deltas, gt_centerness,
                     gt_inds, im_inds, pred_class_logits, pred_shift_deltas,
                     pred_centerness, pred_inst_params, fpn_levels, shifts):
    """
    Compute the dense detection losses and gather the foreground proposals
    (with their predicted instance parameters) for downstream use.

    Returns:
        tuple[dict[str, Tensor], Instances]:
            the loss dict with keys "loss_cls", "loss_box_reg" and
            "loss_centerness", and an `Instances` object holding the
            selected foreground proposals.
    """
    pred_class_logits, pred_shift_deltas, pred_centerness, pred_inst_params = \
        permute_all_to_N_HWA_K_and_concat(
            pred_class_logits, pred_shift_deltas, pred_centerness,
            pred_inst_params, self.num_gen_params, self.num_classes
        )  # Shapes: (N x R, K) and (N x R, 4), respectively.

    # Flatten all per-image, per-location targets to (N x R, ...).
    gt_classes = gt_classes.flatten()
    gt_shifts_deltas = gt_shifts_deltas.reshape(-1, 4)
    gt_centerness = gt_centerness.reshape(-1, 1)
    fpn_levels = fpn_levels.flatten()
    im_inds = im_inds.flatten()
    gt_inds = gt_inds.flatten()

    # Label == self.num_classes marks background; label < 0 marks ignore.
    valid_mask = gt_classes >= 0
    fg_mask = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = fg_mask.sum()

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[fg_mask, gt_classes[fg_mask]] = 1

    # Average the normalizers across workers.
    num_foreground = comm.all_reduce(num_foreground) / float(
        comm.get_world_size())
    num_foreground_centerness = gt_centerness[fg_mask].sum()
    num_targets = comm.all_reduce(num_foreground_centerness) / float(
        comm.get_world_size())

    # Classification: focal loss over all non-ignored positions.
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_mask],
        gt_classes_target[valid_mask],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / max(1.0, num_foreground)

    # Box regression: centerness-weighted IoU loss over foreground.
    loss_box_reg = iou_loss(
        pred_shift_deltas[fg_mask],
        gt_shifts_deltas[fg_mask],
        gt_centerness[fg_mask],
        box_mode="ltrb",
        loss_type=self.iou_loss_type,
        reduction="sum",
        smooth=self.iou_smooth,
    ) / max(1e-6, num_targets)

    # Centerness: BCE over foreground.
    loss_centerness = F.binary_cross_entropy_with_logits(
        pred_centerness[fg_mask],
        gt_centerness[fg_mask],
        reduction="sum",
    ) / max(1, num_foreground)

    proposals_losses = {
        "loss_cls": loss_cls,
        "loss_box_reg": loss_box_reg,
        "loss_centerness": loss_centerness
    }

    # Concatenate shifts over levels, then over images, to align with the
    # flattened (N x R) layout above.
    all_shifts = torch.cat([torch.cat(shift) for shift in shifts])

    proposals = Instances((0, 0))
    # NOTE(review): "inst_parmas" (sic) is an externally visible attribute
    # name — downstream consumers read it, so the misspelling is kept.
    proposals.inst_parmas = pred_inst_params[fg_mask]
    proposals.fpn_levels = fpn_levels[fg_mask]
    proposals.shifts = all_shifts[fg_mask]
    proposals.gt_inds = gt_inds[fg_mask]
    proposals.im_inds = im_inds[fg_mask]

    # select_instances for saving memory
    # NOTE(review): reconstructed nesting — select_instances is assumed to
    # run only in top-k mode; confirm against the original formatting.
    if len(proposals):
        if self.topk_proposals_per_im != -1:
            proposals.gt_cls = gt_classes[fg_mask]
            proposals.pred_logits = pred_class_logits[fg_mask]
            proposals.pred_centerness = pred_centerness[fg_mask]
            proposals = self.select_instances(proposals)

    return proposals_losses, proposals
def get_lla_assignments_and_losses(self, shifts, targets, box_cls, box_delta, box_iou):
    """
    Assign ground-truth boxes to anchors by a per-pair cost (classification
    focal loss + weighted IoU loss + out-of-box penalty), pick the top-k
    lowest-cost anchors per gt, then compute classification, regression and
    IoU-branch losses on the resulting assignment.

    Args:
        shifts: per-image lists of per-level shift tensors.
        targets: per-image `Instances` with gt_boxes / gt_classes.
        box_cls, box_delta, box_iou: per-level head outputs.

    Returns:
        dict[str, Tensor] with keys "loss_cls", "loss_box_reg", "loss_iou".
    """
    # NOTE(review): this initial value is overwritten inside the loop below
    # and never used as a list — looks like a leftover; confirm.
    gt_classes = []

    # Reshape each level to (N, H*W*A, K) and concatenate over levels.
    box_cls = [permute_to_N_HWA_K(x, self.num_classes) for x in box_cls]
    box_delta = [permute_to_N_HWA_K(x, 4) for x in box_delta]
    box_iou = [permute_to_N_HWA_K(x, 1) for x in box_iou]
    box_cls = torch.cat(box_cls, dim=1)
    box_delta = torch.cat(box_delta, dim=1)
    box_iou = torch.cat(box_iou, dim=1)

    losses_cls = []
    losses_box_reg = []
    losses_iou = []
    num_fg = 0

    for shifts_per_image, targets_per_image, box_cls_per_image, \
            box_delta_per_image, box_iou_per_image in zip(
                shifts, targets, box_cls, box_delta, box_iou):
        shifts_over_all = torch.cat(shifts_per_image, dim=0)

        gt_boxes = targets_per_image.gt_boxes
        gt_classes = targets_per_image.gt_classes

        # (num_gt, num_anchors, 4) ltrb deltas; an anchor is "inside" a gt
        # box when all four deltas exceed the 0.01 margin.
        deltas = self.shift2box_transform.get_deltas(
            shifts_over_all, gt_boxes.tensor.unsqueeze(1))
        is_in_boxes = deltas.min(dim=-1).values > 0.01

        shape = (len(targets_per_image), len(shifts_over_all), -1)

        # Keep the unexpanded (per-anchor) views for the final losses.
        box_cls_per_image_unexpanded = box_cls_per_image
        box_delta_per_image_unexpanded = box_delta_per_image

        box_cls_per_image = box_cls_per_image.unsqueeze(0).expand(shape)
        # torch.max(..., zeros) clamps ignore labels (< 0) to class 0 so
        # one_hot does not fail; ignores are filtered out via valid_flag below.
        gt_cls_per_image = F.one_hot(
            torch.max(gt_classes, torch.zeros_like(gt_classes)),
            self.num_classes).float().unsqueeze(1).expand(shape)

        # Assignment costs are computed without gradients.
        # NOTE(review): block extent reconstructed from mangled formatting —
        # the matching (cost + top-k) is under no_grad, the final losses are
        # not; confirm against the original source.
        with torch.no_grad():
            # Per-(gt, anchor) classification cost.
            loss_cls = sigmoid_focal_loss_jit(
                box_cls_per_image, gt_cls_per_image,
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma).sum(dim=-1)
            # Per-anchor background cost (all-zero target).
            loss_cls_bg = sigmoid_focal_loss_jit(
                box_cls_per_image_unexpanded,
                torch.zeros_like(box_cls_per_image_unexpanded),
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma).sum(dim=-1)

            box_delta_per_image = box_delta_per_image.unsqueeze(0).expand(
                shape)
            gt_delta_per_image = self.shift2box_transform.get_deltas(
                shifts_over_all, gt_boxes.tensor.unsqueeze(1))
            # Per-(gt, anchor) regression cost and plain IoUs (the latter
            # later serve as targets for the IoU branch).
            loss_delta = iou_loss(box_delta_per_image,
                                  gt_delta_per_image,
                                  box_mode="ltrb",
                                  loss_type='iou')
            ious = get_ious(box_delta_per_image,
                            gt_delta_per_image,
                            box_mode="ltrb",
                            loss_type='iou')

            # Total cost; anchors outside the gt box get a huge penalty so
            # they are effectively never selected.
            loss = loss_cls + self.reg_cost * loss_delta + 1e3 * (
                1 - is_in_boxes.float())
            # Append the background row: shape (num_gt + 1, num_anchors).
            loss = torch.cat([loss, loss_cls_bg.unsqueeze(0)], dim=0)

            num_gt = loss.shape[0] - 1
            num_anchor = loss.shape[1]

            # Topk: each gt selects its k lowest-cost anchors.
            matching_matrix = torch.zeros_like(loss)
            _, topk_idx = torch.topk(loss[:-1],
                                     k=self.topk,
                                     dim=1,
                                     largest=False)
            matching_matrix[torch.arange(num_gt).unsqueeze(1).
                            repeat(1, self.topk).view(-1),
                            topk_idx.view(-1)] = 1.

            # make sure one anchor with one gt: an anchor claimed by
            # several gts keeps only its minimum-cost gt.
            anchor_matched_gt = matching_matrix.sum(0)
            if (anchor_matched_gt > 1).sum() > 0:
                loss_min, loss_argmin = torch.min(
                    loss[:-1, anchor_matched_gt > 1], dim=0)
                matching_matrix[:, anchor_matched_gt > 1] *= 0.
                matching_matrix[loss_argmin, anchor_matched_gt > 1] = 1.
                anchor_matched_gt = matching_matrix.sum(0)
            num_fg += matching_matrix.sum()

            # Unmatched anchors are assigned to the background row.
            matching_matrix[
                -1] = 1. - anchor_matched_gt  # assignment for Background
            assigned_gt_inds = torch.argmax(matching_matrix, dim=0)

            # Build per-anchor one-hot class targets, with an all-zero row
            # appended for background assignments.
            gt_cls_per_image_bg = gt_cls_per_image.new_zeros(
                (gt_cls_per_image.size(1),
                 gt_cls_per_image.size(2))).unsqueeze(0)
            gt_cls_per_image_with_bg = torch.cat(
                [gt_cls_per_image, gt_cls_per_image_bg], dim=0)
            cls_target_per_image = gt_cls_per_image_with_bg[
                assigned_gt_inds, torch.arange(num_anchor)]

            # Dealing with Crowdhuman ignore label: anchors assigned to an
            # ignored gt (label < 0) are excluded from every loss.
            gt_classes_ = torch.cat([gt_classes, gt_classes.new_zeros(1)])
            anchor_cls_labels = gt_classes_[assigned_gt_inds]
            valid_flag = anchor_cls_labels >= 0
            pos_mask = assigned_gt_inds != len(
                targets_per_image)  # get foreground mask
            valid_fg = pos_mask & valid_flag
            assigned_fg_inds = assigned_gt_inds[valid_fg]
            range_fg = torch.arange(num_anchor)[valid_fg]
            ious_fg = ious[assigned_fg_inds, range_fg]

        # Final (gradient-carrying) losses on the assigned anchors.
        anchor_loss_cls = sigmoid_focal_loss_jit(
            box_cls_per_image_unexpanded[valid_flag],
            cls_target_per_image[valid_flag],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma).sum(dim=-1)

        delta_target = gt_delta_per_image[assigned_fg_inds, range_fg]
        anchor_loss_delta = 2. * iou_loss(
            box_delta_per_image_unexpanded[valid_fg],
            delta_target,
            box_mode="ltrb",
            loss_type=self.iou_loss_type)

        # IoU branch regresses the actual IoU of the assigned prediction.
        anchor_loss_iou = 0.5 * F.binary_cross_entropy_with_logits(
            box_iou_per_image.squeeze(1)[valid_fg],
            ious_fg,
            reduction='none')

        losses_cls.append(anchor_loss_cls.sum())
        losses_box_reg.append(anchor_loss_delta.sum())
        losses_iou.append(anchor_loss_iou.sum())

    # Optionally synchronize the foreground-count normalizer across workers.
    if self.norm_sync:
        dist.all_reduce(num_fg)
        num_fg = num_fg.float() / dist.get_world_size()

    return {
        'loss_cls': torch.stack(losses_cls).sum() / num_fg,
        'loss_box_reg': torch.stack(losses_box_reg).sum() / num_fg,
        'loss_iou': torch.stack(losses_iou).sum() / num_fg
    }
def losses(self, shifts, gt_instances, box_cls, box_delta, box_center):
    """
    Compute the positive-bag, negative-bag and normalization losses from
    per-level head outputs, weighting positives by a learned per-class
    Gaussian (self.mu / self.sigma) over normalized box-center offsets.

    Args:
        shifts: per-image lists of per-level shift tensors.
        gt_instances: per-image `Instances` with gt_boxes / gt_classes.
        box_cls, box_delta, box_center: per-level head outputs.

    Returns:
        dict[str, Tensor] with keys "loss_pos", "loss_neg", "loss_norm".
    """
    # Reshape each level to (N, H*W*A, K) and concatenate over levels.
    box_cls_flattened = [
        permute_to_N_HWA_K(x, self.num_classes) for x in box_cls
    ]
    box_delta_flattened = [permute_to_N_HWA_K(x, 4) for x in box_delta]
    box_center_flattened = [permute_to_N_HWA_K(x, 1) for x in box_center]
    pred_class_logits = cat(box_cls_flattened, dim=1)
    pred_shift_deltas = cat(box_delta_flattened, dim=1)
    pred_obj_logits = cat(box_center_flattened, dim=1)

    pred_class_probs = pred_class_logits.sigmoid()
    pred_obj_probs = pred_obj_logits.sigmoid()
    pred_box_probs = []
    num_foreground = pred_class_logits.new_zeros(1)
    num_background = pred_class_logits.new_zeros(1)
    positive_losses = []
    gaussian_norm_losses = []

    for shifts_per_image, gt_instances_per_image, \
            pred_class_probs_per_image, pred_shift_deltas_per_image, \
            pred_obj_probs_per_image in zip(
                shifts, gt_instances, pred_class_probs, pred_shift_deltas,
                pred_obj_probs):
        locations = torch.cat(shifts_per_image, dim=0)
        labels = gt_instances_per_image.gt_classes
        gt_boxes = gt_instances_per_image.gt_boxes

        # (num_gt, num_locations, 4) ltrb targets; a location is inside a
        # gt box when all four deltas are positive.
        target_shift_deltas = self.shift2box_transform.get_deltas(
            locations, gt_boxes.tensor.unsqueeze(1))
        is_in_boxes = target_shift_deltas.min(dim=-1).values > 0

        foreground_idxs = torch.nonzero(is_in_boxes, as_tuple=True)

        # NOTE(review): no_grad extent reconstructed from mangled
        # formatting — it is assumed to end after pred_box_probs.append;
        # confirm against the original source.
        with torch.no_grad():
            # predicted_boxes_per_image: a_{j}^{loc}, shape: [j, 4]
            predicted_boxes_per_image = self.shift2box_transform.apply_deltas(
                pred_shift_deltas_per_image, locations)
            # gt_pred_iou: IoU_{ij}^{loc}, shape: [i, j]
            gt_pred_iou = pairwise_iou(
                gt_boxes, Boxes(predicted_boxes_per_image)).max(
                    dim=0, keepdim=True).values.repeat(
                        len(gt_instances_per_image), 1)

            # pred_box_prob_per_image: P{a_{j} \in A_{+}}, shape: [j, c]
            pred_box_prob_per_image = torch.zeros_like(
                pred_class_probs_per_image)
            box_prob = 1 / (1 - gt_pred_iou[foreground_idxs]).clamp_(1e-12)
            # Normalize box probabilities within each gt's foreground set.
            for i in range(len(gt_instances_per_image)):
                idxs = foreground_idxs[0] == i
                if idxs.sum() > 0:
                    box_prob[idxs] = normalize(box_prob[idxs])
            pred_box_prob_per_image[foreground_idxs[1],
                                    labels[foreground_idxs[0]]] = box_prob
            pred_box_probs.append(pred_box_prob_per_image)

        # Per-level Gaussian weights over center offsets, normalized by the
        # level stride; product over the two (x, y) components.
        normal_probs = []
        for stride, shifts_i in zip(self.fpn_strides, shifts_per_image):
            gt_shift_deltas = self.shift2box_transform.get_deltas(
                shifts_i, gt_boxes.tensor.unsqueeze(1))
            distances = (gt_shift_deltas[..., :2] -
                         gt_shift_deltas[..., 2:]) / 2
            normal_probs.append(
                normal_distribution(distances / stride,
                                    self.mu[labels].unsqueeze(1),
                                    self.sigma[labels].unsqueeze(1)))
        normal_probs = torch.cat(normal_probs, dim=1).prod(dim=-1)

        # Joint classification * objectness probability per (location, gt).
        composed_cls_prob = pred_class_probs_per_image[:, labels] * pred_obj_probs_per_image

        # matched_gt_shift_deltas: P_{ij}^{loc}
        loss_box_reg = iou_loss(pred_shift_deltas_per_image.unsqueeze(0),
                                target_shift_deltas,
                                box_mode="ltrb",
                                loss_type=self.iou_loss_type,
                                reduction="none") * self.reg_weight
        pred_reg_probs = (-loss_box_reg).exp()

        # positive_losses: { -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) }
        positive_losses.append(
            positive_bag_loss(
                composed_cls_prob.permute(1, 0) * pred_reg_probs,
                is_in_boxes.float(), normal_probs))

        num_foreground += len(gt_instances_per_image)
        # NOTE(review): the negative-loss normalizer accumulates the
        # Gaussian mass over foreground locations — verify this matches
        # the intended n||B|| denominator.
        num_background += normal_probs[foreground_idxs].sum().item()

        gaussian_norm_losses.append(
            len(gt_instances_per_image) /
            normal_probs[foreground_idxs].sum().clamp_(1e-12))

    # Average normalizers across workers when running distributed.
    if dist.is_initialized():
        dist.all_reduce(num_foreground)
        num_foreground /= dist.get_world_size()
        dist.all_reduce(num_background)
        num_background /= dist.get_world_size()

    # positive_loss: \sum_{i}{ -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) } / ||B||
    positive_loss = torch.cat(positive_losses).sum() / max(
        1, num_foreground)
    # pred_box_probs: P{a_{j} \in A_{+}}
    pred_box_probs = torch.stack(pred_box_probs, dim=0)
    # negative_loss: \sum_{j}{ FL( (1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg}) ) } / n||B||
    negative_loss = negative_bag_loss(
        pred_class_probs * pred_obj_probs * (1 - pred_box_probs),
        self.focal_loss_gamma).sum() / max(1, num_background)

    loss_pos = positive_loss * self.focal_loss_alpha
    loss_neg = negative_loss * (1 - self.focal_loss_alpha)
    loss_norm = torch.stack(gaussian_norm_losses).mean() * (
        1 - self.focal_loss_alpha)

    return {
        "loss_pos": loss_pos,
        "loss_neg": loss_neg,
        "loss_norm": loss_norm,
    }
def get_ground_truth(self, shifts, targets, box_cls, box_delta):
    """
    Args:
        shifts (list[list[Tensor]]): a list of N=#image elements. Each is a
            list of #feature level tensors. The tensors contains shifts of
            this image on the specific feature level.
        targets (list[Instances]): a list of N `Instances`s. The i-th
            `Instances` contains the ground-truth per-instance annotations
            for the i-th input image.  Specify `targets` during training only.
        box_cls, box_delta: per-level head predictions, used to build the
            prediction-aware assignment cost.

    Returns:
        gt_classes (Tensor):
            An integer tensor of shape (N, R) storing ground-truth
            labels for each shift.
            R is the total number of shifts, i.e. the sum of Hi x Wi for all levels.
            Shifts in the valid boxes are assigned their corresponding label in the
            [0, K-1] range. Shifts in the background are assigned the label "K".
            Shifts in the ignore areas are assigned a label "-1", i.e. ignore.
        gt_shifts_deltas (Tensor):
            Shape (N, R, 4).
            The last dimension represents ground-truth shift2box transform
            targets (dl, dt, dr, db) that map each shift to its matched
            ground-truth box. The values in the tensor are meaningful only
            when the corresponding shift is labeled as foreground.
    """
    gt_classes = []
    gt_shifts_deltas = []

    # Reshape each level to (N, H*W*A, K) and concatenate over levels.
    box_cls = torch.cat(
        [permute_to_N_HWA_K(x, self.num_classes) for x in box_cls], dim=1)
    box_delta = torch.cat([permute_to_N_HWA_K(x, 4) for x in box_delta],
                          dim=1)

    num_fg = 0
    num_gt = 0

    for shifts_per_image, targets_per_image, box_cls_per_image, box_delta_per_image in zip(
            shifts, targets, box_cls, box_delta):
        shifts_over_all_feature_maps = torch.cat(shifts_per_image, dim=0)
        gt_boxes = targets_per_image.gt_boxes

        # Cost matrix: per-(gt, shift) focal classification loss plus the
        # weighted IoU regression loss.
        # NOTE(review): the costs below are not wrapped in torch.no_grad();
        # results feed a CPU Hungarian solver, so gradients are unused —
        # confirm whether the original relied on that.
        shape = (len(targets_per_image), len(shifts_over_all_feature_maps),
                 -1)
        gt_cls_per_image = F.one_hot(targets_per_image.gt_classes,
                                     self.num_classes).float()

        loss_cls = sigmoid_focal_loss_jit(
            box_cls_per_image.unsqueeze(0).expand(shape),
            gt_cls_per_image.unsqueeze(1).expand(shape),
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
        ).sum(dim=2)

        gt_delta_per_image = self.shift2box_transform.get_deltas(
            shifts_over_all_feature_maps, gt_boxes.tensor.unsqueeze(1))

        loss_delta = iou_loss(
            box_delta_per_image.unsqueeze(0).expand(shape),
            gt_delta_per_image,
            box_mode="ltrb",
            loss_type=self.iou_loss_type,
        ) * self.reg_weight

        loss = loss_cls + loss_delta

        # A prohibitively large cost makes shifts outside the allowed
        # region effectively unassignable.
        INF = 1e8
        deltas = self.shift2box_transform.get_deltas(
            shifts_over_all_feature_maps, gt_boxes.tensor.unsqueeze(1))

        if self.center_sampling_radius > 0:
            # Restrict candidates to a stride-scaled box around each gt
            # center, clipped to the gt box itself.
            centers = gt_boxes.get_centers()
            is_in_boxes = []
            for stride, shifts_i in zip(self.fpn_strides,
                                        shifts_per_image):
                radius = stride * self.center_sampling_radius
                center_boxes = torch.cat((
                    torch.max(centers - radius, gt_boxes.tensor[:, :2]),
                    torch.min(centers + radius, gt_boxes.tensor[:, 2:]),
                ), dim=-1)
                center_deltas = self.shift2box_transform.get_deltas(
                    shifts_i, center_boxes.unsqueeze(1))
                is_in_boxes.append(center_deltas.min(dim=-1).values > 0)
            is_in_boxes = torch.cat(is_in_boxes, dim=1)
        else:
            # no center sampling, it will use all the locations within a ground-truth box
            is_in_boxes = deltas.min(dim=-1).values > 0

        loss[~is_in_boxes] = INF

        # One-to-one bipartite matching between gts and shifts that
        # minimizes the total cost.
        gt_idxs, shift_idxs = linear_sum_assignment(loss.cpu().numpy())

        num_fg += len(shift_idxs)
        num_gt += len(targets_per_image)

        # Default every shift to background (label == num_classes) and
        # zero regression targets, then fill in the matched ones.
        gt_classes_i = shifts_over_all_feature_maps.new_full(
            (len(shifts_over_all_feature_maps), ),
            self.num_classes,
            dtype=torch.long)
        gt_shifts_reg_deltas_i = shifts_over_all_feature_maps.new_zeros(
            len(shifts_over_all_feature_maps), 4)
        if len(targets_per_image) > 0:
            # ground truth classes
            gt_classes_i[shift_idxs] = targets_per_image.gt_classes[
                gt_idxs]
            # ground truth box regression
            gt_shifts_reg_deltas_i[
                shift_idxs] = self.shift2box_transform.get_deltas(
                    shifts_over_all_feature_maps[shift_idxs],
                    gt_boxes[gt_idxs].tensor)

        gt_classes.append(gt_classes_i)
        gt_shifts_deltas.append(gt_shifts_reg_deltas_i)

    # Log the average number of foreground shifts per ground-truth box.
    get_event_storage().put_scalar("num_fg_per_gt", num_fg / num_gt)

    return torch.stack(gt_classes), torch.stack(gt_shifts_deltas)
def losses(self, gt_classes, gt_shifts_deltas, gt_ious, pred_class_logits,
           pred_shift_deltas, pred_ious):
    """
    Compute classification, box-regression and IoU-branch losses.

    Args:
        For `gt_classes`, `gt_shifts_deltas` and `gt_ious` parameters, see
            :meth:`FCOS.get_ground_truth`.
        Their shapes are (N, R) and (N, R, 4), respectively, where R is
        the total number of shifts across levels, i.e. sum(Hi x Wi)
        For `pred_class_logits`, `pred_shift_deltas` and `pred_ious`, see
            :meth:`FCOSHead.forward`.

    Returns:
        dict[str: Tensor]:
            mapping from a named loss to a scalar tensor
            storing the loss. Used during training only. The dict keys are:
            "loss_cls", "loss_box_reg" and "loss_iou".
    """
    pred_class_logits, pred_shift_deltas, pred_ious = \
        permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_class_logits, pred_shift_deltas, pred_ious,
            self.num_classes
        )  # Shapes: (N x R, K) and (N x R, 4), respectively.

    gt_classes = gt_classes.flatten()
    gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)
    gt_ious = gt_ious.view(-1, 1)

    # Label == self.num_classes marks background; label < 0 marks ignore.
    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()
    # NOTE(review): the original also computed
    # `num_target = gt_ious[foreground_idxs].sum()` but never used it —
    # removed as dead code. If the regression loss was meant to be
    # normalized by the IoU mass (as in the centerness-weighted variants),
    # that is a separate intentional change to make.

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    if self.norm_sync:
        dist.all_reduce(num_foreground)
        # Fixed: the original used in-place `/=` on this integer tensor,
        # which raises "result type Float can't be cast to ... Long" on
        # modern PyTorch. Out-of-place division yields a float tensor,
        # matching the `num_fg.float() / world_size` pattern used by the
        # LLA loss in this file.
        num_foreground = num_foreground / dist.get_world_size()

    # logits loss
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / max(1, num_foreground)

    # regression loss
    loss_box_reg = 2. * iou_loss(
        pred_shift_deltas[foreground_idxs],
        gt_shifts_deltas[foreground_idxs],
        box_mode="ltrb",
        loss_type=self.iou_loss_type,
        reduction="sum",
    ) / max(1, num_foreground)

    # iou branch loss
    loss_iou = 0.5 * F.binary_cross_entropy_with_logits(
        pred_ious[foreground_idxs],
        gt_ious[foreground_idxs],
        reduction="sum",
    ) / max(1, num_foreground)

    return {
        "loss_cls": loss_cls,
        "loss_box_reg": loss_box_reg,
        "loss_iou": loss_iou
    }