def test_focal_loss_star_equals_ce_loss_multi_class(self) -> None:
    """
    Focal loss with predictions for multiple classes matches ce loss.

    Uses gamma = 1 and alpha = -1 with "mean" reduction and compares the
    result against mean-reduced binary cross entropy on the same logits.
    """
    inputs = logit(
        torch.tensor(
            [
                [
                    [0.95, 0.55, 0.12, 0.05],
                    [0.09, 0.95, 0.36, 0.11],
                    [0.06, 0.12, 0.56, 0.07],
                    [0.09, 0.15, 0.25, 0.45],
                ]
            ],
            dtype=torch.float32,
        )
    )
    # One-hot targets: row i marks class i as the positive class.
    targets = torch.tensor(
        [[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]],
        dtype=torch.float32,
    )
    focal_loss_star = sigmoid_focal_loss_star(
        inputs, targets, gamma=1, alpha=-1, reduction="mean"
    )
    ce_loss = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="mean"
    )
    # The two scalars come from separate function calls, so bit-exact float
    # equality is a fragile check; compare within a tolerance instead.
    self.assertTrue(torch.allclose(ce_loss, focal_loss_star))
def test_focal_loss_star_equals_ce_loss(self) -> None:
    """
    No weighting of easy/hard (gamma = 1) or positive/negative (alpha = -1).

    Checks that both the mean-reduced loss value and the gradient with
    respect to the inputs agree with binary cross entropy.
    """
    inputs = logit(
        torch.tensor([[[0.95], [0.90], [0.98], [0.99]]], dtype=torch.float32)
    )
    targets = torch.tensor([[[1], [1], [1], [1]]], dtype=torch.float32)
    # Separate copies per loss so that each backward() fills its own .grad.
    inputs_fl = inputs.clone().requires_grad_()
    targets_fl = targets.clone()
    inputs_ce = inputs.clone().requires_grad_()
    targets_ce = targets.clone()

    focal_loss_star = sigmoid_focal_loss_star(
        inputs_fl, targets_fl, gamma=1, alpha=-1, reduction="mean"
    )
    ce_loss = F.binary_cross_entropy_with_logits(
        inputs_ce, targets_ce, reduction="mean"
    )

    # The two scalars come from separate function calls, so bit-exact float
    # equality is a fragile check; compare within a tolerance instead.
    self.assertTrue(
        torch.allclose(ce_loss.detach(), focal_loss_star.detach())
    )

    focal_loss_star.backward()
    ce_loss.backward()
    self.assertTrue(
        # pyre-ignore
        torch.allclose(inputs_fl.grad.data, inputs_ce.grad.data)
    )
def test_positives_ignored_focal_loss_star(self) -> None:
    """
    With alpha = 0 positive examples have focal loss of 0.
    """
    probabilities = torch.tensor(
        [[[0.05], [0.12], [0.89], [0.79]]], dtype=torch.float32
    )
    inputs = logit(probabilities)
    targets = torch.tensor([[[1], [1], [0], [0]]], dtype=torch.float32)

    star_tensor = sigmoid_focal_loss_star(inputs, targets, gamma=3, alpha=0)
    ce_tensor = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none"
    )

    # Drop singleton dimensions and move to numpy for boolean-mask indexing.
    star_values = star_tensor.squeeze().numpy()
    ce_values = ce_tensor.squeeze().numpy()
    is_positive = targets.squeeze().numpy() == 1

    # Cross entropy still penalises the positives ...
    self.assertTrue(np.all(ce_values[is_positive] > 0))
    # ... while the focal loss star for them is exactly zero.
    self.assertTrue(np.all(star_values[is_positive] == 0))
def test_focal_loss_star_positive_weights(self) -> None:
    """
    With alpha = 0.5 loss of positive examples is downweighted.
    """
    N = 5
    inputs = logit(torch.rand(N))
    # Every example is a positive.
    targets = torch.ones((N,)).float()

    unweighted = sigmoid_focal_loss_star(inputs, targets, gamma=2, alpha=-1)
    half_weighted = sigmoid_focal_loss_star(
        inputs, targets, gamma=2, alpha=0.5
    )

    # The test expects the alpha = -1 loss to be twice the alpha = 0.5 loss.
    expected_ratio = torch.zeros((N,)).float() + 2.0
    observed_ratio = (unweighted / half_weighted).squeeze()
    self.assertTrue(np.allclose(observed_ratio, expected_ratio))
def run_focal_loss_star() -> None:
    """
    Run one forward and backward pass of the mean-reduced focal loss star.

    `inputs`, `targets` and `alpha` are taken from the enclosing scope.
    """
    loss = sigmoid_focal_loss_star(
        inputs,
        targets,
        gamma=1,
        alpha=alpha,
        reduction="mean",
    )
    loss.backward()
    # Block until queued CUDA work has finished before returning.
    torch.cuda.synchronize()
def test_easy_ex_focal_loss_star_less_than_ce_loss(self) -> None:
    """
    With gamma = 3 loss of easy examples is downweighted.
    """
    probabilities = torch.tensor([0.75, 0.8, 0.12, 0.05], dtype=torch.float32)
    inputs = logit(probabilities)
    targets = torch.tensor([1, 1, 0, 0], dtype=torch.float32)

    ce_per_example = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none"
    )
    star_per_example = sigmoid_focal_loss_star(
        inputs, targets, gamma=3, alpha=-1
    )

    # Cross entropy must exceed the focal loss star by more than 10x on
    # every example.
    ce_over_star = (ce_per_example / star_per_example).squeeze()
    self.assertTrue((ce_over_star > 10.0).all())
def test_sum_focal_loss_star_equals_ce_loss(self) -> None:
    """
    Sum of focal loss across all examples matches ce loss.
    """
    probabilities = torch.tensor(
        [[[0.05], [0.12], [0.89], [0.79]]], dtype=torch.float32
    )
    inputs = logit(probabilities)
    targets = torch.tensor([[[1], [1], [0], [0]]], dtype=torch.float32)

    # Both losses use "sum" reduction, so each is a single scalar.
    summed_ce = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="sum"
    )
    summed_star = sigmoid_focal_loss_star(
        inputs, targets, gamma=1, alpha=-1, reduction="sum"
    )

    losses_match = torch.allclose(summed_ce, summed_star)
    self.assertTrue(losses_match)
def test_hard_ex_focal_loss_star_similar_to_ce_loss(self) -> None:
    """
    With gamma = 2 loss of hard examples is roughly unchanged.
    """
    probabilities = torch.tensor([0.05, 0.12, 0.91, 0.85], dtype=torch.float64)
    inputs = logit(probabilities)
    targets = torch.tensor([1, 1, 0, 0], dtype=torch.float64)

    ce_per_example = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none"
    )
    star_per_example = sigmoid_focal_loss_star(
        inputs, targets, gamma=2, alpha=-1
    )

    # The per-example ratio should sit within 0.1 (absolute) of 1.
    ce_over_star = (ce_per_example / star_per_example).squeeze()
    unit_ratio = torch.ones(4, dtype=torch.float64)
    self.assertTrue(torch.allclose(ce_over_star, unit_ratio, atol=0.1))
def test_mean_focal_loss_star_equals_ce_loss(self) -> None:
    """
    Mean value of focal loss across all examples matches ce loss.
    """
    probabilities = torch.tensor(
        [[0.05, 0.9], [0.52, 0.45], [0.89, 0.8], [0.39, 0.5]],
        dtype=torch.float32,
    )
    inputs = logit(probabilities)
    targets = torch.tensor(
        [[1, 0], [1, 0], [1, 1], [0, 1]], dtype=torch.float32
    )

    # Both losses use "mean" reduction, so each is a single scalar.
    mean_ce = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="mean"
    )
    mean_star = sigmoid_focal_loss_star(
        inputs, targets, gamma=1, alpha=-1, reduction="mean"
    )

    losses_match = torch.allclose(mean_ce, mean_star)
    self.assertTrue(losses_match)