def test_focal_loss_fn():
    """Compare focal loss with BCE and check the ordering over ``combine_thr``.

    With ``alpha=-1`` and ``gamma=0`` the focal loss is asserted to match
    ``F.binary_cross_entropy_with_logits``. This is checked on both the
    classification fixtures and the image fixtures. In between, the loss for
    ``combine_thr`` values 0, 0.2 and 0.8 is asserted to be strictly increasing.
    """
    # Classification fixtures: plain BCE vs. focal loss with focusing disabled.
    reference = F.binary_cross_entropy_with_logits(INP_BINARY, TARGET_BINARY.float())
    focal = pt_F.focal_loss_with_logits(INP_BINARY, TARGET_BINARY, alpha=-1, gamma=0)
    assert torch.allclose(reference, focal)

    # Smooth combination: a larger combine_thr is expected to give a larger loss.
    thresholds = (0, 0.2, 0.8)
    losses = [
        pt_F.focal_loss_with_logits(INP_BINARY, TARGET_BINARY, combine_thr=thr)
        for thr in thresholds
    ]
    assert losses[0] < losses[1]
    assert losses[1] < losses[2]

    # Image fixtures: the same BCE equivalence must hold.
    reference_img = F.binary_cross_entropy_with_logits(INP_IMG_BINARY, TARGET_IMG_BINARY)
    focal_img = pt_F.focal_loss_with_logits(
        INP_IMG_BINARY, TARGET_IMG_BINARY, alpha=-1, gamma=0
    )
    assert torch.allclose(reference_img, focal_img)
def test_focal_loss_fn_normalize():
    """On very confident correct logits, the normalized loss exceeds the plain one."""
    # Logit +5 where the target is 1 and -5 where it is 0.
    confident_logits = TARGET_BINARY * 5 - (1 - TARGET_BINARY) * 5
    loss_plain, loss_normalized = (
        pt_F.focal_loss_with_logits(confident_logits, TARGET_BINARY, normalized=flag)
        for flag in (False, True)
    )
    assert loss_normalized > loss_plain
def test_focal_loss_fn_good_vs_bad():
    """Check that accurate logits produce a lower focal loss than poor ones.

    This test previously reused the name ``test_focal_loss_fn_basic``. The
    corner-case value test defined later in this module has that same name and
    shadowed this one, so pytest never collected it. A distinct name makes it
    run again.
    """
    # Confident logits matching the target vs. logits that mostly disagree with it.
    input_good = torch.Tensor([10, -10, 10]).float()
    input_bad = torch.Tensor([-1, 2, 0]).float()
    target = torch.Tensor([1, 0, 1])

    loss_good = pt_F.focal_loss_with_logits(input_good, target)
    loss_bad = pt_F.focal_loss_with_logits(input_bad, target)
    assert loss_good < loss_bad
def test_focal_loss_fn_basic():
    """Pin the exact focal-loss value for two corner cases (near-perfect and poor logits)."""
    target = torch.Tensor([1, 0, 1])
    confident = torch.Tensor([10, -10, 10]).float()
    wrong = torch.Tensor([-1, 2, 0]).float()

    loss_confident = pt_F.focal_loss_with_logits(confident, target)
    loss_wrong = pt_F.focal_loss_with_logits(wrong, target)

    # Near-perfect logits: the loss should be numerically zero.
    assert torch.allclose(loss_confident, torch.tensor(0.0))
    # Poor logits: compare against a fixed reference value.
    assert torch.allclose(loss_wrong, torch.tensor(0.4854), atol=1e-4)
def test_focal_fn_is_scribtable():  # NOTE: "scribtable" is a typo for "scriptable"; name kept unchanged
    """Check that the TorchScript-compiled focal loss matches the eager result."""
    logits = torch.Tensor([-1, 2, 0]).float()
    target = torch.Tensor([1, 0, 1])

    eager_loss = pt_F.focal_loss_with_logits(logits, target)
    scripted_fn = torch.jit.script(pt_F.focal_loss_with_logits)
    scripted_loss = scripted_fn(logits, target)

    assert torch.allclose(eager_loss, scripted_loss)
def test_focal_loss_fn_reduction(reduction):
    """Check that focal loss with ``alpha=0.5, gamma=0`` equals half of BCE for a given reduction.

    ``reduction`` is supplied from outside this function. It is presumably a
    pytest parametrization or fixture that is not visible here (TODO confirm).
    """
    expected = F.binary_cross_entropy_with_logits(
        INP_BINARY, TARGET_BINARY.float(), reduction=reduction
    )
    half_weighted = pt_F.focal_loss_with_logits(
        INP_BINARY, TARGET_BINARY, alpha=0.5, gamma=0, reduction=reduction
    )
    # The test expects alpha=0.5 to scale the loss by one half, so doubling it recovers BCE.
    assert torch.allclose(expected, half_weighted * 2)