def smooth_l1_loss(self):
    """
    Compute the smooth L1 loss for box regression.

    Returns:
        scalar Tensor
    """
    if self._no_instances:
        return 0.0 * self.pred_proposal_deltas.sum()
    gt_proposal_deltas = self.box2box_transform.get_deltas(
        self.proposals.tensor, self.gt_boxes.tensor)
    box_dim = gt_proposal_deltas.size(1)  # 4 or 5
    cls_agnostic_bbox_reg = self.pred_proposal_deltas.size(1) == box_dim
    device = self.pred_proposal_deltas.device

    bg_class_ind = self.pred_class_logits.shape[1] - 1

    # Box delta loss is only computed between the prediction for the gt class k
    # (if 0 <= k < bg_class_ind) and the target; there is no loss defined on
    # predictions for non-gt classes and background.
    # Empty fg_inds produces a valid loss of zero as long as the size_average
    # arg to smooth_l1_loss is False (otherwise it uses torch.mean internally
    # and would produce a nan loss).
    fg_inds = torch.nonzero(
        (self.gt_classes >= 0) & (self.gt_classes < bg_class_ind),
        as_tuple=False).squeeze(1)
    if cls_agnostic_bbox_reg:
        # pred_proposal_deltas only corresponds to foreground class for agnostic
        gt_class_cols = torch.arange(box_dim, device=device)
    else:
        fg_gt_classes = self.gt_classes[fg_inds]
        # pred_proposal_deltas for class k are located in columns [b * k : b * k + b],
        # where b is the dimension of box representation (4 or 5).
        # Note that compared to Detectron1, we do not perform bounding box
        # regression for background classes.
        gt_class_cols = box_dim * fg_gt_classes[:, None] + torch.arange(
            box_dim, device=device)

    loss_box_reg = smooth_l1_loss(
        self.pred_proposal_deltas[fg_inds[:, None], gt_class_cols],
        gt_proposal_deltas[fg_inds],
        self.smooth_l1_beta,
        reduction="sum",
    )
    # The loss is normalized using the total number of regions (R), not the number
    # of foreground regions, even though the box regression loss is only defined on
    # foreground regions. Why? Because doing so gives equal training influence to
    # each foreground example. To see how, consider two different minibatches:
    #  (1) Contains a single foreground region
    #  (2) Contains 100 foreground regions
    # If we normalize by the number of foreground regions, the single example in
    # minibatch (1) will be given 100 times as much influence as each foreground
    # example in minibatch (2). Normalizing by the total number of regions, R,
    # means that the single example in minibatch (1) and each of the 100 examples
    # in minibatch (2) are given equal influence.
    loss_box_reg = loss_box_reg / self.gt_classes.numel()
    return loss_box_reg
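
# For reference, the elementwise `smooth_l1_loss` primitive used throughout this
# file follows the fvcore/detectron2 definition sketched below (a minimal sketch,
# assuming the fvcore signature; not the verbatim library source). It also makes
# the beta behavior the docstrings rely on concrete: beta -> 0 gives pure L1,
# while beta -> +inf drives the loss toward 0.
import torch

def _smooth_l1_loss_sketch(input, target, beta, reduction="none"):
    if beta < 1e-5:
        # avoid dividing by a tiny beta; this limit is exactly L1
        loss = torch.abs(input - target)
    else:
        n = torch.abs(input - target)
        # |x| < beta: quadratic (L2) region; otherwise linear (L1) region
        loss = torch.where(n < beta, 0.5 * n ** 2 / beta, n - 0.5 * beta)
    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    return loss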
def losses(self, gt_classes, gt_anchors_deltas, pred_class_logits,
           pred_anchor_deltas):
    """
    Args:
        For `gt_classes` and `gt_anchors_deltas` parameters, see
            :meth:`RetinaNet.get_ground_truth`.
        Their shapes are (N, R) and (N, R, 4), respectively, where R is
        the total number of anchors across levels, i.e. sum(Hi x Wi x A)
        For `pred_class_logits` and `pred_anchor_deltas`, see
            :meth:`RetinaNetHead.forward`.

    Returns:
        dict[str: Tensor]:
            mapping from a named loss to a scalar tensor storing the loss.
            Used during training only. The dict keys are:
            "loss_cls" and "loss_box_reg"
    """
    pred_class_logits, pred_anchor_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
        pred_class_logits, pred_anchor_deltas, self.num_classes
    )  # Shapes: (N x R, K) and (N x R, 4), respectively.

    gt_classes = gt_classes.flatten()
    gt_anchors_deltas = gt_anchors_deltas.view(-1, 4)

    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    self.loss_normalizer = self.loss_normalizer_momentum * self.loss_normalizer + (
        1 - self.loss_normalizer_momentum) * max(num_foreground.item(), 1)

    # logits loss
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / self.loss_normalizer

    # regression loss
    loss_box_reg = smooth_l1_loss(
        pred_anchor_deltas[foreground_idxs],
        gt_anchors_deltas[foreground_idxs],
        beta=self.smooth_l1_loss_beta,
        reduction="sum",
    ) / self.loss_normalizer

    return {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
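
# Several losses above call `permute_all_cls_and_box_to_N_HWA_K_and_concat`, which
# reshapes per-level head outputs from (N, A*K, Hi, Wi) into flat (N*R, K) and
# (N*R, 4) tensors so anchors can be indexed uniformly. A minimal sketch in the
# detectron2 style (an illustrative reconstruction for the two-output case; the
# BorderDet variant below takes five inputs and is not covered here):
import torch

def permute_to_N_HWA_K(tensor, K):
    # (N, A*K, H, W) -> (N, H*W*A, K)
    assert tensor.dim() == 4, tensor.shape
    N, _, H, W = tensor.shape
    tensor = tensor.view(N, -1, K, H, W)    # (N, A, K, H, W)
    tensor = tensor.permute(0, 3, 4, 1, 2)  # (N, H, W, A, K)
    return tensor.reshape(N, -1, K)         # (N, H*W*A, K)

def permute_all_cls_and_box_to_N_HWA_K_and_concat(box_cls, box_delta, num_classes):
    box_cls_flattened = [permute_to_N_HWA_K(x, num_classes) for x in box_cls]
    box_delta_flattened = [permute_to_N_HWA_K(x, 4) for x in box_delta]
    # concatenate across feature levels, then flatten the batch dimension
    pred_class_logits = torch.cat(box_cls_flattened, dim=1).view(-1, num_classes)
    pred_anchor_deltas = torch.cat(box_delta_flattened, dim=1).view(-1, 4)
    return pred_class_logits, pred_anchor_deltas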
def losses(self, gt_classes, gt_anchors_deltas, pred_class_logits,
           pred_anchor_deltas):
    """
    Args:
        For `gt_classes` and `gt_anchors_deltas` parameters, see
            :meth:`EfficientDet.get_ground_truth`.
        Their shapes are (N, R) and (N, R, 4), respectively, where R is
        the total number of anchors across levels, i.e. sum(Hi x Wi x A)
        For `pred_class_logits` and `pred_anchor_deltas`, see
            :meth:`EfficientDetHead.forward`.

    Returns:
        dict[str: Tensor]:
            mapping from a named loss to a scalar tensor storing the loss.
            Used during training only. The dict keys are:
            "loss_cls" and "loss_box_reg"
    """
    pred_class_logits, pred_anchor_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
        pred_class_logits, pred_anchor_deltas, self.num_classes
    )  # Shapes: (N x R, K) and (N x R, 4), respectively.

    gt_classes = gt_classes.flatten()
    gt_anchors_deltas = gt_anchors_deltas.view(-1, 4)

    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    # Classification loss
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / max(1, num_foreground)

    # Regression loss, refer to the official released code.
    # See: https://github.com/google/automl/blob/master/efficientdet/det_model_fn.py
    loss_box_reg = self.box_loss_weight * self.smooth_l1_loss_beta * smooth_l1_loss(
        pred_anchor_deltas[foreground_idxs],
        gt_anchors_deltas[foreground_idxs],
        beta=self.smooth_l1_loss_beta,
        reduction="sum",
    ) / max(1, num_foreground * self.regress_norm)

    return {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}
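
# `sigmoid_focal_loss_jit` (used above and in most losses in this file) is the
# focal loss of Lin et al., "Focal Loss for Dense Object Detection". A minimal
# non-jit sketch following the fvcore definition (an illustrative reconstruction,
# not the verbatim library source):
import torch
import torch.nn.functional as F

def sigmoid_focal_loss_sketch(inputs, targets, alpha=0.25, gamma=2.0,
                              reduction="none"):
    p = torch.sigmoid(inputs)
    # plain BCE, kept per-element so it can be reweighted below
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)  # prob assigned to the true label
    loss = ce_loss * ((1 - p_t) ** gamma)        # down-weight easy examples
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    if reduction == "sum":
        loss = loss.sum()
    elif reduction == "mean":
        loss = loss.mean()
    return loss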
def losses(self, gt_classes, gt_anchors_deltas, pred_class_logits,
           pred_anchor_deltas):
    """
    Args:
        For `gt_classes` and `gt_anchors_deltas` parameters, see
            :meth:`RetinaNet.get_ground_truth`.
        Their shapes are (N, R) and (N, R, 4), respectively, where R is
        the total number of anchors across levels, i.e. sum(Hi x Wi x A)
        For `pred_class_logits` and `pred_anchor_deltas`, see
            :meth:`RetinaNetHead.forward`.

    Returns:
        dict[str: Tensor]:
            mapping from a named loss to a scalar tensor storing the loss.
            Used during training only. The dict keys are:
            "loss_cls" and "loss_box_reg"
    """
    pred_class_logits, pred_anchor_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
        pred_class_logits, pred_anchor_deltas, self.num_classes
    )  # Shapes: (N x R, K) and (N x R, 4), respectively.

    gt_classes = gt_classes.flatten()
    gt_anchors_deltas = gt_anchors_deltas.view(-1, 4)

    pos_inds = torch.nonzero(
        (gt_classes >= 0) & (gt_classes != self.num_classes)).squeeze(1)

    retinanet_regression_loss = smooth_l1_loss(
        pred_anchor_deltas[pos_inds],
        gt_anchors_deltas[pos_inds],
        beta=self.smooth_l1_loss_beta,
        # size_average=False,
        reduction="sum",
    ) / max(1, pos_inds.numel() * self.regress_norm)

    labels = torch.ones_like(gt_classes)
    # convert labels from 0~79 to 1~80
    labels[pos_inds] += gt_classes[pos_inds]
    labels[gt_classes == -1] = gt_classes[gt_classes == -1]
    labels[gt_classes == self.num_classes] = 0
    labels = labels.int()

    retinanet_cls_loss = self.box_cls_loss_func(pred_class_logits, labels)

    return {
        "loss_cls": retinanet_cls_loss,
        "loss_box_reg": retinanet_regression_loss
    }
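
# Worked example of the label remapping above (illustrative values, assuming
# num_classes == 80). Raw `gt_classes` uses 0..79 for foreground, num_classes (80)
# for background, and -1 for ignore; the classification loss here expects 1..80
# for foreground, 0 for background, and -1 for ignore:
#
#   gt_classes = [-1,  0, 12, 79, 80]
#   labels     = [-1,  1, 13, 80,  0]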
def rpn_losses(
    gt_objectness_logits,
    gt_anchor_deltas,
    pred_objectness_logits,
    pred_anchor_deltas,
    smooth_l1_beta,
):
    """
    Args:
        gt_objectness_logits (Tensor): shape (N,), each element in {-1, 0, 1}
            representing ground-truth objectness labels with:
            -1 = ignore; 0 = not object; 1 = object.
        gt_anchor_deltas (Tensor): shape (N, box_dim), row i represents
            ground-truth box2box transform targets (dx, dy, dw, dh) or
            (dx, dy, dw, dh, da) that map anchor i to its matched ground-truth box.
        pred_objectness_logits (Tensor): shape (N,), each element is a
            predicted objectness logit.
        pred_anchor_deltas (Tensor): shape (N, box_dim), each row is a predicted
            box2box transform (dx, dy, dw, dh) or (dx, dy, dw, dh, da)
        smooth_l1_beta (float): The transition point between L1 and L2 loss in
            the smooth L1 loss function. When set to 0, the loss becomes L1.
            When set to +inf, the loss becomes constant 0.

    Returns:
        objectness_loss, localization_loss, both unnormalized (summed over samples).
    """
    pos_masks = gt_objectness_logits == 1
    localization_loss = smooth_l1_loss(pred_anchor_deltas[pos_masks],
                                       gt_anchor_deltas[pos_masks],
                                       smooth_l1_beta,
                                       reduction="sum")

    valid_masks = gt_objectness_logits >= 0
    objectness_loss = F.binary_cross_entropy_with_logits(
        pred_objectness_logits[valid_masks],
        gt_objectness_logits[valid_masks].to(torch.float32),
        reduction="sum",
    )
    return objectness_loss, localization_loss
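
# A minimal smoke test for `rpn_losses` above (illustrative shapes and values
# only, assuming `torch`, `F`, and `smooth_l1_loss` are imported as in this
# module): 8 anchors with box_dim == 4 and a mix of positive (1), negative (0),
# and ignored (-1) labels.
import torch

gt_objectness = torch.tensor([1, 0, -1, 1, 0, 0, -1, 1])
gt_deltas = torch.randn(8, 4)
pred_objectness = torch.randn(8)
pred_deltas = torch.randn(8, 4)

obj_loss, loc_loss = rpn_losses(
    gt_objectness, gt_deltas, pred_objectness, pred_deltas, smooth_l1_beta=0.0)
# Both are summed, unnormalized scalars; the caller is expected to divide by its
# own normalizer (e.g. the number of anchors sampled per batch).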
def losses(
    self,
    gt_classes,
    gt_shifts_deltas,
    gt_centerness,
    gt_classes_border,
    gt_deltas_border,
    pred_class_logits,
    pred_shift_deltas,
    pred_centerness,
    border_box_cls,
    border_bbox_reg,
):
    """
    Args:
        For `gt_classes`, `gt_shifts_deltas` and `gt_centerness` parameters,
            see :meth:`BorderDet.get_ground_truth`.
        Their shapes are (N, R), (N, R, 4) and (N, R), respectively, where R
        is the total number of shifts across levels, i.e. sum(Hi x Wi)
        For `pred_class_logits`, `pred_shift_deltas` and `pred_centerness`, see
            :meth:`BorderHead.forward`.

    Returns:
        dict[str: Tensor]:
            mapping from a named loss to a scalar tensor storing the loss.
            Used during training only. The dict keys are: "loss_cls",
            "loss_box_reg", "loss_centerness", "loss_border_cls" and
            "loss_border_reg"
    """
    (
        pred_class_logits,
        pred_shift_deltas,
        pred_centerness,
        border_class_logits,
        border_shift_deltas,
    ) = permute_all_cls_and_box_to_N_HWA_K_and_concat(
        pred_class_logits, pred_shift_deltas, pred_centerness,
        border_box_cls, border_bbox_reg, self.num_classes
    )  # Shapes: (N x R, K) and (N x R, 4), respectively.

    # fcos
    gt_classes = gt_classes.flatten()
    gt_shifts_deltas = gt_shifts_deltas.view(-1, 4)
    gt_centerness = gt_centerness.view(-1, 1)

    valid_idxs = gt_classes >= 0
    foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
    num_foreground = foreground_idxs.sum()
    acc_centerness_num = gt_centerness[foreground_idxs].sum()

    gt_classes_target = torch.zeros_like(pred_class_logits)
    gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1

    dist.all_reduce(num_foreground)
    num_foreground /= dist.get_world_size()
    dist.all_reduce(acc_centerness_num)
    acc_centerness_num /= dist.get_world_size()

    # logits loss
    loss_cls = sigmoid_focal_loss_jit(
        pred_class_logits[valid_idxs],
        gt_classes_target[valid_idxs],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / max(1, num_foreground)

    # regression loss
    loss_box_reg = iou_loss(
        pred_shift_deltas[foreground_idxs],
        gt_shifts_deltas[foreground_idxs],
        gt_centerness[foreground_idxs],
        box_mode="ltrb",
        loss_type=self.iou_loss_type,
        reduction="sum",
    ) / max(1, acc_centerness_num)

    # centerness loss
    loss_centerness = F.binary_cross_entropy_with_logits(
        pred_centerness[foreground_idxs],
        gt_centerness[foreground_idxs],
        reduction="sum",
    ) / max(1, num_foreground)

    # borderdet
    gt_classes_border = gt_classes_border.flatten()
    gt_deltas_border = gt_deltas_border.view(-1, 4)

    valid_idxs_border = gt_classes_border >= 0
    foreground_idxs_border = (gt_classes_border >= 0) & (
        gt_classes_border != self.num_classes)
    num_foreground_border = foreground_idxs_border.sum()

    gt_classes_border_target = torch.zeros_like(border_class_logits)
    gt_classes_border_target[
        foreground_idxs_border,
        gt_classes_border[foreground_idxs_border]] = 1

    dist.all_reduce(num_foreground_border)
    num_foreground_border /= dist.get_world_size()
    num_foreground_border = max(num_foreground_border, 1.0)

    loss_border_cls = sigmoid_focal_loss_jit(
        border_class_logits[valid_idxs_border],
        gt_classes_border_target[valid_idxs_border],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / num_foreground_border

    if foreground_idxs_border.numel() > 0:
        loss_border_reg = (
            smooth_l1_loss(border_shift_deltas[foreground_idxs_border],
                           gt_deltas_border[foreground_idxs_border],
                           beta=0,
                           reduction="sum") / num_foreground_border)
    else:
        loss_border_reg = border_shift_deltas.sum()

    return {
        "loss_cls": loss_cls,
        "loss_box_reg": loss_box_reg,
        "loss_centerness": loss_centerness,
        "loss_border_cls": loss_border_cls,
        "loss_border_reg": loss_border_reg,
    }
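
# `iou_loss` with box_mode="ltrb" above operates on (left, top, right, bottom)
# distances from a shift to the four box borders, as in FCOS. A minimal sketch of
# the core computation (an illustrative reconstruction assuming the FCOS-style
# IoU loss; the cvpods version additionally supports GIoU and other box modes):
import torch

def iou_loss_ltrb_sketch(pred, target, weight=None, reduction="sum"):
    # areas from border distances: (l + r) * (t + b)
    pred_area = (pred[:, 0] + pred[:, 2]) * (pred[:, 1] + pred[:, 3])
    target_area = (target[:, 0] + target[:, 2]) * (target[:, 1] + target[:, 3])
    # intersection width/height: smaller border distance on each side
    w_intersect = torch.min(pred[:, 0], target[:, 0]) + torch.min(pred[:, 2], target[:, 2])
    h_intersect = torch.min(pred[:, 1], target[:, 1]) + torch.min(pred[:, 3], target[:, 3])
    area_intersect = w_intersect * h_intersect
    area_union = pred_area + target_area - area_intersect
    ious = (area_intersect + 1.0) / (area_union + 1.0)  # +1 guards empty boxes
    losses = -torch.log(ious)                           # loss_type == "iou"
    if weight is not None:
        # e.g. centerness weighting, paired with the acc_centerness_num normalizer
        losses = losses * weight.view(-1)
    return losses.sum() if reduction == "sum" else losses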
def losses(
    self,
    gt_class_info,
    gt_delta_info,
    gt_mask_info,
    num_fg,
    pred_logits,
    pred_deltas,
    pred_masks,
):
    """
    Args:
        For `gt_class_info`, `gt_delta_info`, `gt_mask_info` and `num_fg`
            parameters, see :meth:`TensorMask.get_ground_truth`.
        For `pred_logits`, `pred_deltas` and `pred_masks`, see
            :meth:`TensorMaskHead.forward`.

    Returns:
        losses (dict[str: Tensor]): mapping from a named loss to a scalar
            tensor storing the loss. Used during training only. The potential
            dict keys are: "loss_cls", "loss_box_reg" and "loss_mask".
    """
    gt_classes_target, gt_valid_inds = gt_class_info
    gt_deltas, gt_fg_inds = gt_delta_info
    gt_masks, gt_mask_inds = gt_mask_info
    loss_normalizer = torch.tensor(max(1, num_fg),
                                   dtype=torch.float32,
                                   device=self.device)

    # classification and regression
    pred_logits, pred_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
        pred_logits, pred_deltas, self.num_classes)
    loss_cls = (sigmoid_focal_loss_star_jit(
        pred_logits[gt_valid_inds],
        gt_classes_target[gt_valid_inds],
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    ) / loss_normalizer)

    if num_fg == 0:
        loss_box_reg = pred_deltas.sum() * 0
    else:
        loss_box_reg = (smooth_l1_loss(
            pred_deltas[gt_fg_inds], gt_deltas, beta=0.0,
            reduction="sum") / loss_normalizer)
    losses = {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}

    # mask prediction
    if self.mask_on:
        loss_mask = 0
        for lvl in range(self.num_levels):
            cur_level_factor = 2 ** lvl if self.bipyramid_on else 1
            for anc in range(self.num_anchors):
                cur_gt_mask_inds = gt_mask_inds[lvl][anc]
                if cur_gt_mask_inds is None:
                    loss_mask += pred_masks[lvl][anc][0, 0, 0, 0] * 0
                else:
                    cur_mask_size = self.mask_sizes[anc] * cur_level_factor
                    # TODO maybe there are numerical issues when mask sizes are large
                    cur_size_divider = torch.tensor(
                        self.mask_loss_weight / (cur_mask_size ** 2),
                        dtype=torch.float32,
                        device=self.device,
                    )

                    cur_pred_masks = pred_masks[lvl][anc][
                        cur_gt_mask_inds[:, 0],  # N
                        :,                       # V x U
                        cur_gt_mask_inds[:, 1],  # H
                        cur_gt_mask_inds[:, 2],  # W
                    ]

                    loss_mask += F.binary_cross_entropy_with_logits(
                        cur_pred_masks.view(-1, cur_mask_size, cur_mask_size),  # V, U
                        gt_masks[lvl][anc].to(dtype=torch.float32),
                        reduction="sum",
                        weight=cur_size_divider,
                        pos_weight=self.mask_pos_weight,
                    )
        losses["loss_mask"] = loss_mask / loss_normalizer
    return losses
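
# `sigmoid_focal_loss_star_jit` above is the "star" variant of the focal loss
# introduced in the TensorMask paper. A minimal non-jit sketch following the
# fvcore definition (an illustrative reconstruction, not the verbatim source):
import torch
import torch.nn.functional as F

def sigmoid_focal_loss_star_sketch(inputs, targets, alpha=0.25, gamma=1.0,
                                   reduction="none"):
    # fold targets {0, 1} into a sign in {-1, +1}, then scale the margin by gamma
    shifted_inputs = gamma * (inputs * (2 * targets - 1))
    loss = -F.logsigmoid(shifted_inputs) / gamma
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = loss * alpha_t
    if reduction == "sum":
        loss = loss.sum()
    elif reduction == "mean":
        loss = loss.mean()
    return loss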
def losses(self, center_pts, cls_outs, pts_outs_init, pts_outs_refine, targets):
    """
    Args:
        center_pts (list[list[Tensor]]): a list of N=#image elements. Each is
            a list of #feature level tensors. The tensors contain shifts of
            this image on the specific feature level.
        cls_outs: List[Tensor], each item in list with
            shape: [N, num_classes, H, W]
        pts_outs_init: List[Tensor], each item in list with
            shape: [N, num_points*2, H, W]
        pts_outs_refine: List[Tensor], each item in list with
            shape: [N, num_points*2, H, W]
        targets (list[Instances]): a list of N `Instances`s. The i-th
            `Instances` contains the ground-truth per-instance annotations
            for the i-th input image. Specify `targets` during training only.

    Returns:
        dict[str: Tensor]: mapping from a named loss to scalar tensor
    """
    featmap_sizes = [featmap.size()[-2:] for featmap in cls_outs]
    assert len(featmap_sizes) == len(center_pts[0])

    pts_dim = 2 * self.num_points
    cls_outs = [
        cls_out.permute(0, 2, 3, 1).reshape(cls_out.size(0), -1,
                                            self.num_classes)
        for cls_out in cls_outs
    ]
    pts_outs_init = [
        pts_out_init.permute(0, 2, 3, 1).reshape(pts_out_init.size(0), -1,
                                                 pts_dim)
        for pts_out_init in pts_outs_init
    ]
    pts_outs_refine = [
        pts_out_refine.permute(0, 2, 3, 1).reshape(pts_out_refine.size(0), -1,
                                                   pts_dim)
        for pts_out_refine in pts_outs_refine
    ]
    cls_outs = torch.cat(cls_outs, dim=1)
    pts_outs_init = torch.cat(pts_outs_init, dim=1)
    pts_outs_refine = torch.cat(pts_outs_refine, dim=1)

    pts_strides = []
    for i, s in enumerate(center_pts[0]):
        pts_strides.append(
            cls_outs.new_full((s.size(0), ), self.fpn_strides[i]))
    pts_strides = torch.cat(pts_strides, dim=0)

    center_pts = [
        torch.cat(c_pts, dim=0).to(self.device) for c_pts in center_pts
    ]

    pred_cls = []
    pred_init = []
    pred_refine = []

    target_cls = []
    target_init = []
    target_refine = []

    num_pos_init = 0
    num_pos_refine = 0

    for img, (per_center_pts, cls_prob, pts_init, pts_refine,
              per_targets) in enumerate(
                  zip(center_pts, cls_outs, pts_outs_init, pts_outs_refine,
                      targets)):
        assert per_center_pts.shape[:-1] == cls_prob.shape[:-1]

        gt_bboxes = per_targets.gt_boxes.to(cls_prob.device)
        gt_labels = per_targets.gt_classes.to(cls_prob.device)

        pts_init_bbox_targets, pts_init_labels_targets = \
            self.point_targets(per_center_pts, pts_strides,
                               gt_bboxes.tensor, gt_labels)

        # per_center_pts, shape: [N, 18]
        per_center_pts_repeat = per_center_pts.repeat(1, self.num_points)

        normalize_term = self.point_base_scale * pts_strides
        normalize_term = normalize_term.reshape(-1, 1)

        # bbox_center = torch.cat([per_center_pts, per_center_pts], dim=1)
        per_pts_strides = pts_strides.reshape(-1, 1)

        pts_init_coordinate = pts_init * per_pts_strides + \
            per_center_pts_repeat
        init_bbox_pred = self.pts_to_bbox(pts_init_coordinate)

        foreground_idxs = (pts_init_labels_targets >= 0) & \
            (pts_init_labels_targets != self.num_classes)

        pred_init.append(init_bbox_pred[foreground_idxs] /
                         normalize_term[foreground_idxs])
        target_init.append(pts_init_bbox_targets[foreground_idxs] /
                           normalize_term[foreground_idxs])
        num_pos_init += foreground_idxs.sum()

        # Another way to convert predicted offsets to bboxes:
        # bbox_pred_init = self.pts_to_bbox(pts_init.detach()) * \
        #     per_pts_strides
        # init_bbox_pred = bbox_center + bbox_pred_init

        pts_refine_bbox_targets, pts_refine_labels_targets = \
            self.bbox_targets(init_bbox_pred, gt_bboxes, gt_labels)

        pts_refine_coordinate = pts_refine * per_pts_strides + \
            per_center_pts_repeat
        refine_bbox_pred = self.pts_to_bbox(pts_refine_coordinate)
        # bbox_pred_refine = self.pts_to_bbox(pts_refine) * per_pts_strides
        # refine_bbox_pred = bbox_center + bbox_pred_refine

        foreground_idxs = (pts_refine_labels_targets >= 0) & \
            (pts_refine_labels_targets != self.num_classes)

        pred_refine.append(refine_bbox_pred[foreground_idxs] /
                           normalize_term[foreground_idxs])
        target_refine.append(pts_refine_bbox_targets[foreground_idxs] /
                             normalize_term[foreground_idxs])
        num_pos_refine += foreground_idxs.sum()

        gt_classes_target = torch.zeros_like(cls_prob)
        gt_classes_target[foreground_idxs,
                          pts_refine_labels_targets[foreground_idxs]] = 1
        pred_cls.append(cls_prob)
        target_cls.append(gt_classes_target)

    pred_cls = torch.cat(pred_cls, dim=0)
    pred_init = torch.cat(pred_init, dim=0)
    pred_refine = torch.cat(pred_refine, dim=0)

    target_cls = torch.cat(target_cls, dim=0)
    target_init = torch.cat(target_init, dim=0)
    target_refine = torch.cat(target_refine, dim=0)

    loss_cls = sigmoid_focal_loss_jit(
        pred_cls,
        target_cls,
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum") / max(
            1, num_pos_refine.item()) * self.loss_cls_weight

    loss_pts_init = smooth_l1_loss(
        pred_init, target_init, beta=0.11, reduction='sum') / max(
            1, num_pos_init.item()) * self.loss_bbox_init_weight

    loss_pts_refine = smooth_l1_loss(
        pred_refine, target_refine, beta=0.11, reduction='sum') / max(
            1, num_pos_refine.item()) * self.loss_bbox_refine_weight

    return {
        "loss_cls": loss_cls,
        "loss_pts_init": loss_pts_init,
        "loss_pts_refine": loss_pts_refine
    }
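
# `self.pts_to_bbox` above converts a flattened point set into a bounding box.
# RepPoints implementations offer several such transforms (min-max, partial
# min-max, moment-based); the min-max variant is sketched below as a hypothetical
# stand-in (an assumption: this method may use a different transform, and the
# (x, y)-interleaved layout is also assumed):
import torch

def pts_to_bbox_minmax_sketch(pts):
    # pts: (R, 2 * num_points) laid out as (x1, y1, ..., xn, yn) -> (R, 4) boxes
    pts = pts.view(pts.size(0), -1, 2)
    xy_min = pts.min(dim=1).values  # per-row (min_x, min_y)
    xy_max = pts.max(dim=1).values  # per-row (max_x, max_y)
    return torch.cat([xy_min, xy_max], dim=1)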
def losses(self, anchors, gt_instances, box_cls, box_delta):
    anchors = [Boxes.cat(anchors_i) for anchors_i in anchors]

    box_cls_flattened = [
        permute_to_N_HWA_K(x, self.num_classes) for x in box_cls
    ]
    box_delta_flattened = [permute_to_N_HWA_K(x, 4) for x in box_delta]
    pred_class_logits = cat(box_cls_flattened, dim=1)
    pred_anchor_deltas = cat(box_delta_flattened, dim=1)

    pred_class_probs = pred_class_logits.sigmoid()
    pred_box_probs = []
    num_foreground = 0
    positive_losses = []

    for anchors_per_image, \
        gt_instances_per_image, \
        pred_class_probs_per_image, \
        pred_anchor_deltas_per_image in zip(
            anchors, gt_instances, pred_class_probs, pred_anchor_deltas):
        gt_classes_per_image = gt_instances_per_image.gt_classes

        with torch.no_grad():
            # predicted_boxes_per_image: a_{j}^{loc}, shape: [j, 4]
            predicted_boxes_per_image = self.box2box_transform.apply_deltas(
                pred_anchor_deltas_per_image, anchors_per_image.tensor)
            # gt_pred_iou: IoU_{ij}^{loc}, shape: [i, j]
            gt_pred_iou = pairwise_iou(gt_instances_per_image.gt_boxes,
                                       Boxes(predicted_boxes_per_image))

            t1 = self.bbox_threshold
            t2 = gt_pred_iou.max(dim=1, keepdim=True).values.clamp_(
                min=t1 + torch.finfo(torch.float32).eps)
            # gt_pred_prob: P{a_{j} -> b_{i}}, shape: [i, j]
            gt_pred_prob = ((gt_pred_iou - t1) / (t2 - t1)).clamp_(min=0,
                                                                   max=1)

            # pred_box_prob_per_image: P{a_{j} \in A_{+}}, shape: [j, c]
            nonzero_idxs = torch.nonzero(gt_pred_prob, as_tuple=True)
            pred_box_prob_per_image = torch.zeros_like(
                pred_class_probs_per_image)
            pred_box_prob_per_image[nonzero_idxs[1],
                                    gt_classes_per_image[nonzero_idxs[0]]] \
                = gt_pred_prob[nonzero_idxs]
            pred_box_probs.append(pred_box_prob_per_image)

        # construct bags for objects
        match_quality_matrix = pairwise_iou(
            gt_instances_per_image.gt_boxes, anchors_per_image)
        _, foreground_idxs = torch.topk(match_quality_matrix,
                                        self.pos_anchor_topk,
                                        dim=1,
                                        sorted=False)

        # matched_pred_class_probs_per_image: P_{ij}^{cls}
        matched_pred_class_probs_per_image = torch.gather(
            pred_class_probs_per_image[foreground_idxs], 2,
            gt_classes_per_image.view(-1, 1, 1).repeat(
                1, self.pos_anchor_topk, 1)).squeeze(2)

        # matched_gt_anchor_deltas_per_image: P_{ij}^{loc}
        matched_gt_anchor_deltas_per_image = self.box2box_transform.get_deltas(
            anchors_per_image.tensor[foreground_idxs],
            gt_instances_per_image.gt_boxes.tensor.unsqueeze(1))
        loss_box_reg = smooth_l1_loss(
            pred_anchor_deltas_per_image[foreground_idxs],
            matched_gt_anchor_deltas_per_image,
            beta=self.smooth_l1_loss_beta,
            reduction="none").sum(dim=-1) * self.reg_weight
        matched_pred_reg_probs_per_image = (-loss_box_reg).exp()

        # positive_losses: { -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) }
        num_foreground += len(gt_instances_per_image)
        positive_losses.append(
            positive_bag_loss(matched_pred_class_probs_per_image *
                              matched_pred_reg_probs_per_image,
                              dim=1))

    # positive_loss: \sum_{i}{ -log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) ) } / ||B||
    positive_loss = torch.cat(positive_losses).sum() / max(1, num_foreground)

    # pred_box_probs: P{a_{j} \in A_{+}}
    pred_box_probs = torch.stack(pred_box_probs, dim=0)
    # negative_loss: \sum_{j}{ FL( (1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg}) ) } / n||B||
    negative_loss = negative_bag_loss(
        pred_class_probs * (1 - pred_box_probs),
        self.focal_loss_gamma).sum() / max(
            1, num_foreground * self.pos_anchor_topk)

    loss_pos = positive_loss * self.focal_loss_alpha
    loss_neg = negative_loss * (1 - self.focal_loss_alpha)
    return {"loss_pos": loss_pos, "loss_neg": loss_neg}
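
# `positive_bag_loss` and `negative_bag_loss` above implement the Mean-max bag
# likelihood and the focal-weighted background term from the FreeAnchor paper. A
# minimal sketch following common open-source implementations (an illustrative
# reconstruction, not the verbatim source; despite the argument name, `logits`
# here carries probabilities in [0, 1], as in the caller above):
import torch
import torch.nn.functional as F

def positive_bag_loss(logits, dim):
    # Mean-max: weights concentrate on the most confident anchor in each bag
    weight = 1.0 / (1.0 - logits).clamp(min=1e-12)
    weight = weight / weight.sum(dim=dim, keepdim=True)
    bag_prob = (weight * logits).sum(dim)
    # equivalent to -log(bag_prob)
    return F.binary_cross_entropy(bag_prob,
                                  torch.ones_like(bag_prob),
                                  reduction="none")

def negative_bag_loss(logits, gamma):
    # focal-style down-weighting of easy negatives: prob^gamma * -log(1 - prob)
    return (logits ** gamma) * F.binary_cross_entropy(
        logits, torch.zeros_like(logits), reduction="none")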