def run_focal_loss_star_jit() -> None:
    # `inputs`, `targets`, and `alpha` are free variables captured from the
    # enclosing scope; the final synchronize ensures all launched CUDA
    # kernels finish before the closure returns (important if it is timed).
    fl = sigmoid_focal_loss_star_jit(
        inputs, targets, gamma=1, alpha=alpha, reduction="mean"
    )
    fl.backward()
    torch.cuda.synchronize()
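# The closure above assumes `inputs`, `targets`, and `alpha` already exist.
# A minimal sketch of a setup wrapper that would provide them, reusing this
# module's existing imports (the name `focal_loss_star_jit_with_init` and
# its signature are assumptions for illustration, not part of this file):
def focal_loss_star_jit_with_init(N: int, alpha: float = 0):
    device = torch.device("cuda:0")
    inputs = logit(torch.rand(N)).to(device)  # logit() as used in the test below
    targets = torch.randint(0, 2, (N,)).float().to(device)
    inputs.requires_grad = True  # so fl.backward() has a leaf to populate
    torch.cuda.synchronize()

    def run() -> None:
        fl = sigmoid_focal_loss_star_jit(
            inputs, targets, gamma=1, alpha=alpha, reduction="mean"
        )
        fl.backward()
        torch.cuda.synchronize()

    return run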
def test_focal_loss_star_equals_ce_loss_jit(self) -> None:
    """
    No weighting of easy/hard (gamma = 1) or positive/negative (alpha = 0).
    """
    device = torch.device("cuda:0")
    N = 5
    inputs = logit(torch.rand(N)).to(device)
    targets = torch.randint(0, 2, (N,)).float().to(device)
    focal_loss_star = sigmoid_focal_loss_star_jit(inputs, targets, gamma=1)
    ce_loss = F.binary_cross_entropy_with_logits(
        inputs.cpu(), targets.cpu(), reduction="none"
    )
    self.assertTrue(np.allclose(ce_loss, focal_loss_star.cpu()))
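# Why gamma = 1 makes the "star" variant collapse to cross entropy: the
# focal-loss-star form (per the RetinaNet paper's appendix with beta = 0;
# treating this closed form as an assumption about the jit implementation) is
#
#     FL*(x, t) = -log(sigmoid(gamma * x * (2t - 1))) / gamma
#
# and x * (2t - 1) flips the sign of the logit for negative targets, so at
# gamma = 1 this is exactly binary cross entropy with logits. A standalone
# sanity check of that identity (illustrative, not part of the test suite):
def _star_gamma1_matches_bce_demo() -> None:
    x = torch.randn(8)
    t = torch.randint(0, 2, (8,)).float()
    star = -F.logsigmoid(x * (2 * t - 1))  # FL* at gamma=1, no alpha weighting
    bce = F.binary_cross_entropy_with_logits(x, t, reduction="none")
    assert torch.allclose(star, bce, atol=1e-6)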
def losses(
    self,
    gt_class_info,
    gt_delta_info,
    gt_mask_info,
    num_fg,
    pred_logits,
    pred_deltas,
    pred_masks,
):
    """
    Args:
        For `gt_class_info`, `gt_delta_info`, `gt_mask_info` and `num_fg`
        parameters, see :meth:`TensorMask.get_ground_truth`.
        For `pred_logits`, `pred_deltas` and `pred_masks`, see
        :meth:`TensorMaskHead.forward`.

    Returns:
        losses (dict[str: Tensor]): mapping from a named loss to a scalar
            tensor storing the loss. Used during training only. The
            potential dict keys are: "loss_cls", "loss_box_reg" and
            "loss_mask".
    """
    gt_classes_target, gt_valid_inds = gt_class_info
    gt_deltas, gt_fg_inds = gt_delta_info
    gt_masks, gt_mask_inds = gt_mask_info
    # Normalize all losses by the number of foregrounds, clamped to at
    # least 1 so images with no foreground anchors do not divide by zero.
    loss_normalizer = torch.tensor(
        max(1, num_fg), dtype=torch.float32, device=self.device
    )

    # classification and regression
    pred_logits, pred_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
        pred_logits, pred_deltas, self.num_classes
    )
    loss_cls = (
        sigmoid_focal_loss_star_jit(
            pred_logits[gt_valid_inds],
            gt_classes_target[gt_valid_inds],
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        )
        / loss_normalizer
    )

    if num_fg == 0:
        # Dummy zero loss that still references pred_deltas, keeping the
        # regression branch connected to the autograd graph.
        loss_box_reg = pred_deltas.sum() * 0
    else:
        loss_box_reg = (
            smooth_l1_loss(pred_deltas[gt_fg_inds], gt_deltas, beta=0.0, reduction="sum")
            / loss_normalizer
        )
    losses = {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}

    # mask prediction
    if self.mask_on:
        loss_mask = 0
        for lvl in range(self.num_levels):
            cur_level_factor = 2 ** lvl if self.bipyramid_on else 1
            for anc in range(self.num_anchors):
                cur_gt_mask_inds = gt_mask_inds[lvl][anc]
                if cur_gt_mask_inds is None:
                    # No ground-truth masks at this (level, anchor): emit a
                    # zero loss that still touches the prediction tensor.
                    loss_mask += pred_masks[lvl][anc][0, 0, 0, 0] * 0
                else:
                    cur_mask_size = self.mask_sizes[anc] * cur_level_factor
                    # TODO maybe there are numerical issues when mask sizes are large
                    cur_size_divider = torch.tensor(
                        self.mask_loss_weight / (cur_mask_size ** 2),
                        dtype=torch.float32,
                        device=self.device,
                    )

                    cur_pred_masks = pred_masks[lvl][anc][
                        cur_gt_mask_inds[:, 0],  # N
                        :,  # V x U
                        cur_gt_mask_inds[:, 1],  # H
                        cur_gt_mask_inds[:, 2],  # W
                    ]

                    loss_mask += F.binary_cross_entropy_with_logits(
                        cur_pred_masks.view(-1, cur_mask_size, cur_mask_size),  # V, U
                        gt_masks[lvl][anc].to(dtype=torch.float32),
                        reduction="sum",
                        weight=cur_size_divider,
                        pos_weight=self.mask_pos_weight,
                    )
        losses["loss_mask"] = loss_mask / loss_normalizer
    return losses
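# The mask gather above relies on PyTorch advanced indexing: with
# pred_masks[lvl][anc] of shape (N, V*U, H, W), indexing with per-mask
# (image, y, x) triples and a slice in the channel slot returns one
# (V*U)-vector of mask logits per ground-truth mask. A minimal standalone
# illustration of that indexing pattern (shapes and names here are
# assumptions for the demo, not taken from TensorMask):
def _mask_gather_demo() -> None:
    N, VU, H, W = 2, 16, 8, 8
    pm = torch.arange(N * VU * H * W, dtype=torch.float32).view(N, VU, H, W)
    inds = torch.tensor([[0, 3, 5], [1, 2, 7]])  # one (image, h, w) per mask
    out = pm[inds[:, 0], :, inds[:, 1], inds[:, 2]]
    # Because the advanced indices are separated by a slice, the gathered
    # dimension moves to the front: one row of V*U logits per mask.
    assert out.shape == (2, VU)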