Example #1
 # Benchmark closure: inputs, targets, and alpha are CUDA tensors/values defined
 # in the enclosing benchmark setup (see the sketch below).
 def run_focal_loss_star_jit() -> None:
     fl = sigmoid_focal_loss_star_jit(inputs,
                                      targets,
                                      gamma=1,
                                      alpha=alpha,
                                      reduction="mean")
     fl.backward()
     # Block until the asynchronous CUDA kernels finish so timings are meaningful.
     torch.cuda.synchronize()
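
The benchmark closure above captures inputs, targets, and alpha from its enclosing scope. A minimal, hypothetical setup that would make it runnable (shapes, dtypes, and the alpha value are placeholders, not taken from the original benchmark):

 import torch
 from fvcore.nn import sigmoid_focal_loss_star_jit

 device = torch.device("cuda:0")
 N = 1000                                                     # hypothetical sample count
 inputs = torch.randn(N, device=device, requires_grad=True)   # raw logits
 targets = torch.randint(0, 2, (N,), device=device).float()   # binary labels
 alpha = 0.25                                                 # hypothetical alpha weighting

 run_focal_loss_star_jit()  # one timed forward + backward pass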
Example #2
 def test_focal_loss_star_equals_ce_loss_jit(self) -> None:
     """
     No weighting of easy/hard (gamma = 1) or positive/negative (alpha = 0).
     """
     device = torch.device("cuda:0")
     N = 5
     # logit (defined elsewhere in the test module) maps probabilities to logits.
     inputs = logit(torch.rand(N)).to(device)
     targets = torch.randint(0, 2, (N, )).float().to(device)
     # With gamma = 1 and no alpha weighting, focal loss star reduces to plain BCE.
     focal_loss_star = sigmoid_focal_loss_star_jit(inputs, targets, gamma=1)
     ce_loss = F.binary_cross_entropy_with_logits(inputs.cpu(),
                                                  targets.cpu(),
                                                  reduction="none")
     self.assertTrue(np.allclose(ce_loss, focal_loss_star.cpu()))
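
For reference, the identity this test exercises: with gamma = 1 and no alpha weighting, the focal-loss-star form FL*(x, y) = -log(sigmoid(gamma * x * (2y - 1))) / gamma collapses to -log(sigmoid(x * (2y - 1))), which is exactly binary cross-entropy with logits. A small sketch of that equality (the closed form is assumed from the Focal Loss paper, not copied from fvcore's source):

 import torch
 import torch.nn.functional as F

 x = torch.randn(8)                      # logits
 y = torch.randint(0, 2, (8,)).float()   # binary targets

 # Focal loss star with gamma = 1: -log(sigmoid(x * (2y - 1)))
 fl_star_gamma1 = -F.logsigmoid(x * (2 * y - 1))
 bce = F.binary_cross_entropy_with_logits(x, y, reduction="none")
 assert torch.allclose(fl_star_gamma1, bce)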
Example #3
    def losses(
        self,
        gt_class_info,
        gt_delta_info,
        gt_mask_info,
        num_fg,
        pred_logits,
        pred_deltas,
        pred_masks,
    ):
        """
        Args:
            For `gt_class_info`, `gt_delta_info`, `gt_mask_info` and `num_fg` parameters, see
                :meth:`TensorMask.get_ground_truth`.
            For `pred_logits`, `pred_deltas` and `pred_masks`, see
                :meth:`TensorMaskHead.forward`.

        Returns:
            losses (dict[str: Tensor]): mapping from a named loss to a scalar tensor
                storing the loss. Used during training only. The potential dict keys are:
                "loss_cls", "loss_box_reg" and "loss_mask".
        """
        gt_classes_target, gt_valid_inds = gt_class_info
        gt_deltas, gt_fg_inds = gt_delta_info
        gt_masks, gt_mask_inds = gt_mask_info
        loss_normalizer = torch.tensor(max(1, num_fg), dtype=torch.float32, device=self.device)

        # classification and regression
        pred_logits, pred_deltas = permute_all_cls_and_box_to_N_HWA_K_and_concat(
            pred_logits, pred_deltas, self.num_classes
        )
        loss_cls = (
            sigmoid_focal_loss_star_jit(
                pred_logits[gt_valid_inds],
                gt_classes_target[gt_valid_inds],
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
                reduction="sum",
            )
            / loss_normalizer
        )

        if num_fg == 0:
            # No foreground anchors: produce a zero loss that still depends on the
            # predictions, keeping the autograd graph connected.
            loss_box_reg = pred_deltas.sum() * 0
        else:
            loss_box_reg = (
                smooth_l1_loss(pred_deltas[gt_fg_inds], gt_deltas, beta=0.0, reduction="sum")
                / loss_normalizer
            )
        losses = {"loss_cls": loss_cls, "loss_box_reg": loss_box_reg}

        # mask prediction
        if self.mask_on:
            loss_mask = 0
            for lvl in range(self.num_levels):
                cur_level_factor = 2 ** lvl if self.bipyramid_on else 1
                for anc in range(self.num_anchors):
                    cur_gt_mask_inds = gt_mask_inds[lvl][anc]
                    if cur_gt_mask_inds is None:
                        # No foreground instances at this (level, anchor): add a dummy
                        # zero term so the mask predictions stay in the autograd graph.
                        loss_mask += pred_masks[lvl][anc][0, 0, 0, 0] * 0
                    else:
                        cur_mask_size = self.mask_sizes[anc] * cur_level_factor
                        # TODO maybe there are numerical issues when mask sizes are large
                        cur_size_divider = torch.tensor(
                            self.mask_loss_weight / (cur_mask_size ** 2),
                            dtype=torch.float32,
                            device=self.device,
                        )

                        cur_pred_masks = pred_masks[lvl][anc][
                            cur_gt_mask_inds[:, 0],  # N
                            :,  # V x U
                            cur_gt_mask_inds[:, 1],  # H
                            cur_gt_mask_inds[:, 2],  # W
                        ]

                        loss_mask += F.binary_cross_entropy_with_logits(
                            cur_pred_masks.view(-1, cur_mask_size, cur_mask_size),  # V, U
                            gt_masks[lvl][anc].to(dtype=torch.float32),
                            reduction="sum",
                            weight=cur_size_divider,
                            pos_weight=self.mask_pos_weight,
                        )
            losses["loss_mask"] = loss_mask / loss_normalizer
        return losses
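
A note on the mask-selection indexing above: pred_masks[lvl][anc] has shape [N, V*U, H, W], and indexing it with integer arrays on dims 0, 2, and 3 while keeping a full slice on dim 1 moves the broadcast index dimension to the front, yielding one flattened V*U vector per foreground instance; the subsequent .view(-1, cur_mask_size, cur_mask_size) turns each vector back into a 2-D mask. A small, self-contained sketch of that PyTorch indexing behavior with arbitrary shapes:

 import torch

 N, V, U, H, W = 2, 4, 4, 8, 8               # arbitrary sketch shapes
 pred = torch.randn(N, V * U, H, W)          # per-location flattened V x U masks
 inds = torch.tensor([[0, 3, 5],             # (image, y, x) of each foreground sample
                      [1, 2, 7]])

 # Advanced indices on dims 0, 2, 3 with a slice on dim 1: the broadcast index
 # dimension is placed first, giving a [K, V*U] tensor (K = number of samples).
 picked = pred[inds[:, 0], :, inds[:, 1], inds[:, 2]]
 assert picked.shape == (2, V * U)
 masks = picked.view(-1, V, U)               # one V x U mask per sampled instance
 assert masks.shape == (2, V, U)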