예제 #1
0
def test_focal_class_is_scribtable():
    """Verify that a TorchScript-compiled FocalLoss agrees with the eager module."""
    inputs = torch.Tensor([-1, 2, 0]).float()
    targets = torch.Tensor([1, 0, 1])

    # Reference value from the regular (eager) module.
    eager_loss = losses.FocalLoss()(inputs, targets)

    # Same computation through a freshly scripted instance.
    scripted_criterion = torch.jit.script(losses.FocalLoss())
    scripted_loss = scripted_criterion(inputs, targets)

    assert torch.allclose(eager_loss, scripted_loss)
예제 #2
0
def test_binary_focal_3d():
    """Check binary FocalLoss on volumetric input against the same data with
    its three trailing dimensions flattened into one."""
    volume_shape = (BS, 1, IM_SIZE, IM_SIZE, IM_SIZE)
    flat_shape = (BS, 1, -1)

    inputs = torch.rand(volume_shape)
    targets = torch.randint(0, 2, volume_shape).float()

    loss_volume = losses.FocalLoss(mode="binary")(inputs, targets)
    loss_flat = losses.FocalLoss(mode="binary")(
        inputs.view(flat_shape),
        targets.view(flat_shape),
    )

    assert torch.allclose(loss_volume, loss_flat)
예제 #3
0
def test_focal_loss_modes():
    """Check FocalLoss consistency across modes and the ``ignore_label`` option.

    Three properties are verified:

    1. Summed multiclass loss matches summed multilabel loss when fed the
       corresponding targets (presumably the multilabel target is the one-hot
       encoding of the multiclass one -- confirm against the fixtures).
    2. In multiclass mode, marking a region of the target with ``ignore_label``
       removes exactly that region's contribution from the summed loss.
    3. The same holds in binary mode.
    """
    # check that multilabel == one-hot-encoded multiclass
    fl_multiclass = losses.FocalLoss(mode="multiclass",
                                     reduction="sum")(INP_IMG,
                                                      TARGET_IMG_MULTICLASS)
    fl_multilabel = losses.FocalLoss(mode="multilabel",
                                     reduction="sum")(INP_IMG,
                                                      TARGET_IMG_MULTILABEL)
    # The two losses are produced by different code paths, so compare with a
    # tolerance rather than exact float equality (consistent with the
    # remaining assertions in this test).
    assert torch.allclose(fl_multiclass, fl_multilabel)

    # check that ignore index works for multiclass
    fl = losses.FocalLoss(mode="multiclass",
                          reduction="none")(INP_IMG, TARGET_IMG_MULTICLASS)
    # Contribution of the region that is about to be ignored.
    loss_diff = fl[:, :, :2, :2].sum()
    # Clone so the shared module-level target is not mutated.
    y_true = TARGET_IMG_MULTICLASS.clone()
    y_true[:, :2, :2] = -100
    fl_i = losses.FocalLoss(mode="multiclass",
                            reduction="sum",
                            ignore_label=-100)(INP_IMG, y_true)
    assert torch.allclose(fl.sum() - loss_diff, fl_i)

    # check that ignore index works for binary
    fl = losses.FocalLoss(mode="binary", reduction="none")(INP_IMG_BINARY,
                                                           TARGET_IMG_BINARY)
    loss_diff = fl[:, :, :2, :2].sum()
    y_true = TARGET_IMG_BINARY.clone()
    y_true[:, :, :2, :2] = -100
    fl_i = losses.FocalLoss(mode="binary", reduction="sum",
                            ignore_label=-100)(INP_IMG_BINARY, y_true)
    assert torch.allclose(fl.sum() - loss_diff, fl_i)
예제 #4
0
def test_focal_incorrect_reduction():
    """Constructing FocalLoss with an unknown ``reduction`` must raise ValueError."""
    invalid_kwargs = {"reduction": "some_reduction"}
    with pytest.raises(ValueError):
        losses.FocalLoss(**invalid_kwargs)
예제 #5
0
def test_focal_incorrect_mode():
    """Constructing FocalLoss with an unknown ``mode`` must raise ValueError."""
    invalid_kwargs = {"mode": "some_mode"}
    with pytest.raises(ValueError):
        losses.FocalLoss(**invalid_kwargs)