Example 1
 def test_focal_loss_star_equals_ce_loss_multi_class(self) -> None:
     """
     Focal loss* with predictions for multiple classes matches BCE loss
     when gamma = 1 and alpha = -1 (no easy/hard or pos/neg weighting).
     """
     inputs = logit(
         torch.tensor(
             [[
                 [0.95, 0.55, 0.12, 0.05],
                 [0.09, 0.95, 0.36, 0.11],
                 [0.06, 0.12, 0.56, 0.07],
                 [0.09, 0.15, 0.25, 0.45],
             ]],
             dtype=torch.float32,
         ))
     targets = torch.tensor(
         [[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]],
         dtype=torch.float32,
     )
     focal_loss_star = sigmoid_focal_loss_star(inputs,
                                               targets,
                                               gamma=1,
                                               alpha=-1,
                                               reduction="mean")
     ce_loss = F.binary_cross_entropy_with_logits(inputs,
                                                  targets,
                                                  reduction="mean")
     # Compare with allclose rather than exact equality: the sibling
     # tests in this file assert the same way, and bitwise float
     # equality is brittle across backends.
     self.assertTrue(torch.allclose(ce_loss, focal_loss_star))
Example 2
    def test_focal_loss_star_equals_ce_loss(self) -> None:
        """
        No weighting of easy/hard (gamma = 1) or positive/negative
        (alpha = -1): focal loss* reduces to BCE in both the forward
        value and the gradients.
        """
        inputs = logit(
            torch.tensor([[[0.95], [0.90], [0.98], [0.99]]],
                         dtype=torch.float32))
        targets = torch.tensor([[[1], [1], [1], [1]]], dtype=torch.float32)
        # Separate leaf tensors so each loss gets its own gradients.
        inputs_fl = inputs.clone().requires_grad_()
        targets_fl = targets.clone()
        inputs_ce = inputs.clone().requires_grad_()
        targets_ce = targets.clone()

        focal_loss_star = sigmoid_focal_loss_star(inputs_fl,
                                                  targets_fl,
                                                  gamma=1,
                                                  alpha=-1,
                                                  reduction="mean")
        ce_loss = F.binary_cross_entropy_with_logits(inputs_ce,
                                                     targets_ce,
                                                     reduction="mean")

        # ``Tensor.data`` is legacy autograd API; ``detach()`` is the
        # supported way to compare values outside the graph, and
        # allclose matches how the other tests in this file assert.
        self.assertTrue(
            torch.allclose(ce_loss.detach(), focal_loss_star.detach()))
        focal_loss_star.backward()
        ce_loss.backward()
        self.assertTrue(
            # pyre-ignore: .grad is Optional but populated by backward()
            torch.allclose(inputs_fl.grad, inputs_ce.grad))
Example 3
 def test_positives_ignored_focal_loss_star(self) -> None:
     """
     With alpha = 0, positive examples contribute zero focal loss*.
     """
     probs = torch.tensor([[[0.05], [0.12], [0.89], [0.79]]],
                          dtype=torch.float32)
     inputs = logit(probs)
     targets = torch.tensor([[[1], [1], [0], [0]]], dtype=torch.float32)
     fl_star = sigmoid_focal_loss_star(inputs, targets, gamma=3, alpha=0)
     fl_star = fl_star.squeeze().numpy()
     ce = F.binary_cross_entropy_with_logits(inputs, targets,
                                             reduction="none")
     ce = ce.squeeze().numpy()
     targets = targets.squeeze().numpy()
     # BCE is strictly positive on the positives, while the alpha=0
     # focal loss* zeroes them out entirely.
     self.assertTrue(np.all(ce[targets == 1] > 0))
     self.assertTrue(np.all(fl_star[targets == 1] == 0))
Example 4
 def test_focal_loss_star_positive_weights(self) -> None:
     """
     With alpha = 0.5 the loss of positive examples is halved relative
     to the unweighted (alpha = -1) loss.
     """
     n = 5
     inputs = logit(torch.rand(n))
     targets = torch.ones((n,)).float()
     unweighted = sigmoid_focal_loss_star(inputs, targets, gamma=2,
                                          alpha=-1)
     halved = sigmoid_focal_loss_star(inputs, targets, gamma=2,
                                      alpha=0.5)
     # alpha = 0.5 scales every positive example's loss by 0.5, so the
     # elementwise ratio should be exactly 2 everywhere.
     ratio = (unweighted / halved).squeeze()
     expected = torch.zeros((n,)).float() + 2.0
     self.assertTrue(np.allclose(ratio, expected))
Example 5
 def run_focal_loss_star() -> None:
     # Benchmark closure: ``inputs``, ``targets`` and ``alpha`` are free
     # variables captured from an enclosing scope not visible here —
     # presumably a timing harness sets them up; verify against caller.
     fl = sigmoid_focal_loss_star(inputs,
                                  targets,
                                  gamma=1,
                                  alpha=alpha,
                                  reduction="mean")
     # Forward + backward, then block until all CUDA kernels finish so
     # the caller's timing measures the actual work.
     fl.backward()
     torch.cuda.synchronize()
Example 6
 def test_easy_ex_focal_loss_star_less_than_ce_loss(self) -> None:
     """
     With gamma = 3 the loss of easy examples is heavily downweighted.
     """
     probs = torch.tensor([0.75, 0.8, 0.12, 0.05], dtype=torch.float32)
     inputs = logit(probs)
     targets = torch.tensor([1, 1, 0, 0], dtype=torch.float32)
     fl_star = sigmoid_focal_loss_star(inputs, targets, gamma=3,
                                       alpha=-1)
     ce = F.binary_cross_entropy_with_logits(inputs, targets,
                                             reduction="none")
     # Every example here is predicted on the correct side ("easy"),
     # so BCE should exceed focal loss* by more than 10x elementwise.
     ratio = (ce / fl_star).squeeze()
     self.assertTrue(torch.all(ratio > 10.0))
Example 7
 def test_sum_focal_loss_star_equals_ce_loss(self) -> None:
     """
     Sum of focal loss* across all examples matches summed BCE loss
     when gamma = 1 and alpha = -1.
     """
     probs = torch.tensor([[[0.05], [0.12], [0.89], [0.79]]],
                          dtype=torch.float32)
     inputs = logit(probs)
     targets = torch.tensor([[[1], [1], [0], [0]]], dtype=torch.float32)
     fl_star = sigmoid_focal_loss_star(
         inputs, targets, gamma=1, alpha=-1, reduction="sum")
     ce = F.binary_cross_entropy_with_logits(
         inputs, targets, reduction="sum")
     self.assertTrue(torch.allclose(ce, fl_star))
Example 8
 def test_hard_ex_focal_loss_star_similar_to_ce_loss(self) -> None:
     """
     With gamma = 2 the loss of hard examples is roughly unchanged
     relative to BCE loss.
     """
     probs = torch.tensor([0.05, 0.12, 0.91, 0.85], dtype=torch.float64)
     inputs = logit(probs)
     targets = torch.tensor([1, 1, 0, 0], dtype=torch.float64)
     fl_star = sigmoid_focal_loss_star(inputs, targets, gamma=2,
                                       alpha=-1)
     ce = F.binary_cross_entropy_with_logits(inputs, targets,
                                             reduction="none")
     # All four examples are badly mispredicted ("hard"), so the
     # BCE / focal ratio should stay close to 1 for each of them.
     ratio = (ce / fl_star).squeeze()
     expected = torch.ones(4, dtype=torch.float64)
     self.assertTrue(torch.allclose(ratio, expected, atol=0.1))
Example 9
 def test_mean_focal_loss_star_equals_ce_loss(self) -> None:
     """
     Mean focal loss* across all examples matches mean BCE loss when
     gamma = 1 and alpha = -1.
     """
     probs = torch.tensor(
         [[0.05, 0.9], [0.52, 0.45], [0.89, 0.8], [0.39, 0.5]],
         dtype=torch.float32,
     )
     inputs = logit(probs)
     targets = torch.tensor([[1, 0], [1, 0], [1, 1], [0, 1]],
                            dtype=torch.float32)
     fl_star = sigmoid_focal_loss_star(
         inputs, targets, gamma=1, alpha=-1, reduction="mean")
     ce = F.binary_cross_entropy_with_logits(
         inputs, targets, reduction="mean")
     self.assertTrue(torch.allclose(ce, fl_star))