def test_focal_loss_fn():
    """With alpha=-1 and gamma=0 focal loss must reduce to plain BCE."""
    # classification test: compare against torch's reference BCE-with-logits
    expected = F.binary_cross_entropy_with_logits(INP_BINARY,
                                                  TARGET_BINARY.float())
    actual = pt_F.focal_loss_with_logits(INP_BINARY,
                                         TARGET_BINARY,
                                         alpha=-1,
                                         gamma=0)
    assert torch.allclose(expected, actual)

    # check that smooth combination works: a larger combine_thr
    # should monotonically increase the loss value
    loss_by_thr = [
        pt_F.focal_loss_with_logits(INP_BINARY, TARGET_BINARY, combine_thr=thr)
        for thr in (0, 0.2, 0.8)
    ]
    assert loss_by_thr[0] < loss_by_thr[1]
    assert loss_by_thr[1] < loss_by_thr[2]

    # images test: same alpha=-1/gamma=0 equivalence on image-shaped tensors
    expected = F.binary_cross_entropy_with_logits(INP_IMG_BINARY,
                                                  TARGET_IMG_BINARY)
    actual = pt_F.focal_loss_with_logits(INP_IMG_BINARY,
                                         TARGET_IMG_BINARY,
                                         alpha=-1,
                                         gamma=0)
    assert torch.allclose(expected, actual)
def test_focal_loss_fn_normalize():
    """Normalization should increase the loss for very confident predictions."""
    # simulate very accurate predictions: logits of +5 where the target is 1
    # and -5 where it is 0
    logits = TARGET_BINARY * 5 - (1 - TARGET_BINARY) * 5
    plain = pt_F.focal_loss_with_logits(logits, TARGET_BINARY, normalized=False)
    normalized = pt_F.focal_loss_with_logits(logits,
                                             TARGET_BINARY,
                                             normalized=True)
    assert normalized > plain
# Example #3
def test_focal_loss_fn_ordering():
    """A confident correct prediction must yield a lower loss than a poor one.

    NOTE(review): this test was previously named ``test_focal_loss_fn_basic``,
    the same name as another test later in the file, so the later definition
    shadowed this one and it was never collected by pytest. Renamed so both
    tests actually run.
    """
    input_good = torch.Tensor([10, -10, 10]).float()
    input_bad = torch.Tensor([-1, 2, 0]).float()
    target = torch.Tensor([1, 0, 1])

    loss_good = pt_F.focal_loss_with_logits(input_good, target)
    loss_bad = pt_F.focal_loss_with_logits(input_bad, target)
    assert loss_good < loss_bad
def test_focal_loss_fn_basic():
    """Explicit value checks for two corner cases.

    A near-perfect prediction should give ~0 loss; a poor prediction
    should match a precomputed reference value.
    """
    confident_logits = torch.Tensor([10, -10, 10]).float()
    poor_logits = torch.Tensor([-1, 2, 0]).float()
    labels = torch.Tensor([1, 0, 1])

    loss_confident = pt_F.focal_loss_with_logits(confident_logits, labels)
    loss_poor = pt_F.focal_loss_with_logits(poor_logits, labels)
    assert torch.allclose(loss_confident, torch.tensor(0.0))
    # 0.4854 is a precomputed reference value for these logits/labels
    assert torch.allclose(loss_poor, torch.tensor(0.4854), atol=1e-4)
def test_focal_fn_is_scriptable():
    """TorchScript-compiled focal loss gives the same result as eager mode.

    NOTE(review): renamed from ``test_focal_fn_is_scribtable`` to fix the
    typo in the test name ("scribtable" -> "scriptable").
    """
    input_bad = torch.Tensor([-1, 2, 0]).float()
    target = torch.Tensor([1, 0, 1])
    eager_loss = pt_F.focal_loss_with_logits(input_bad, target)
    scripted_fn = torch.jit.script(pt_F.focal_loss_with_logits)
    scripted_loss = scripted_fn(input_bad, target)
    assert torch.allclose(eager_loss, scripted_loss)
def test_focal_loss_fn_reduction(reduction):
    """With alpha=0.5 and gamma=0, focal loss equals half of BCE
    under any reduction mode.

    # NOTE(review): `reduction` is presumably injected by a
    # @pytest.mark.parametrize decorator outside this chunk -- confirm.
    """
    bce = F.binary_cross_entropy_with_logits(INP_BINARY,
                                             TARGET_BINARY.float(),
                                             reduction=reduction)
    focal = pt_F.focal_loss_with_logits(INP_BINARY,
                                        TARGET_BINARY,
                                        alpha=0.5,
                                        gamma=0,
                                        reduction=reduction)
    # alpha=0.5 scales the loss by one half, hence the factor of 2
    assert torch.allclose(bce, focal * 2)