def test_cross_entropy_from_logits():
    """Loss on raw input must equal loss on manually softmaxed input with ``from_logits=False``."""
    # Default construction: INP is handed over untouched.
    on_raw = losses.CrossEntropyLoss()(INP, TARGET)
    # Softmax over dim 1 is applied by hand, so the loss is told the input is not logits.
    on_probs = losses.CrossEntropyLoss(from_logits=False)(INP.softmax(1), TARGET)
    assert torch.allclose(on_raw, on_probs)
def test_binary_cross_entropy_from_logits():
    """Binary loss on raw input must equal loss on manually sigmoided input with ``from_logits=False``."""
    # Default construction in binary mode: INP_BINARY is handed over untouched.
    on_raw = losses.CrossEntropyLoss(mode="binary")(INP_BINARY, TARGET_BINARY)
    # Sigmoid is applied by hand, so the loss is told the input is not logits.
    on_probs = losses.CrossEntropyLoss(mode="binary", from_logits=False)(INP_BINARY.sigmoid(), TARGET_BINARY)
    assert torch.allclose(on_raw, on_probs)
Exemple #3
0
def test_loss_addition():
    """A weighted sum of loss objects must evaluate to the same weighted sum of their outputs."""
    dice = losses.DiceLoss("binary")
    bce = losses.CrossEntropyLoss(mode="binary")
    # Built from fresh instances so the composite does not share objects with the references.
    combined = losses.DiceLoss("binary") * 0.5 + losses.CrossEntropyLoss(mode="binary") * 5

    dice_out = dice(INP_IMG_BINARY, TARGET_IMG_BINARY)
    bce_out = bce(INP_IMG_BINARY, TARGET_IMG_BINARY)
    combined_out = combined(INP_IMG_BINARY, TARGET_IMG_BINARY)

    assert combined_out.shape == dice_out.shape
    expected = dice_out * 0.5 + bce_out * 5
    assert torch.allclose(combined_out, expected)
def test_cross_entropy_weight():
    """Class weights given as a tensor, a numpy int array or a list must all match torch's weighted CE."""
    as_tensor = torch.randint(1, 100, (N_CLASSES,)).float()
    as_array = as_tensor.numpy().astype(int)
    as_list = list(as_array)

    reference = torch.nn.CrossEntropyLoss(weight=as_tensor)(INP, TARGET)

    # Same weight values, three container types.
    for class_weight in (as_tensor, as_array, as_list):
        ours = losses.CrossEntropyLoss(weight=class_weight)(INP, TARGET)
        assert torch.allclose(reference, ours)
def test_binary_cross_entropy(reduction):
    """Binary CE must match ``F.binary_cross_entropy_with_logits`` for the given ``reduction``.

    Covers the classification tensors, the image tensors, and image tensors
    where the prediction and/or the target has been squeezed.
    """
    # classification inputs
    expected = F.binary_cross_entropy_with_logits(INP_BINARY, TARGET_BINARY, reduction=reduction)
    criterion = losses.CrossEntropyLoss(mode="binary", reduction=reduction)
    assert torch.allclose(expected, criterion(INP_BINARY, TARGET_BINARY))

    # image inputs
    expected = F.binary_cross_entropy_with_logits(INP_IMG_BINARY, TARGET_IMG_BINARY, reduction=reduction)
    assert torch.allclose(expected, criterion(INP_IMG_BINARY, TARGET_IMG_BINARY))

    # Squeezed on both sides first, then prediction and target with differing shapes.
    flat_expected = expected.squeeze()
    shape_variants = (
        (INP_IMG_BINARY.squeeze(), TARGET_IMG_BINARY.squeeze()),
        (INP_IMG_BINARY.squeeze(), TARGET_IMG_BINARY),
        (INP_IMG_BINARY, TARGET_IMG_BINARY.squeeze()),
    )
    for y_pred, y_true in shape_variants:
        got = criterion(y_pred, y_true)
        assert torch.allclose(flat_expected, got.squeeze())
Exemple #6
0
def test_binary_cross_entropy(reduction):
    """Binary CE must match ``F.binary_cross_entropy_with_logits`` for the given ``reduction``.

    Covers classification tensors (float and long targets), image tensors with
    matching and mismatching squeezed shapes, and freshly generated 5-D tensors.
    """
    # classification inputs
    expected = F.binary_cross_entropy_with_logits(INP_BINARY, TARGET_BINARY, reduction=reduction)
    criterion = losses.CrossEntropyLoss(mode="binary", reduction=reduction)
    assert torch.allclose(expected, criterion(INP_BINARY, TARGET_BINARY))

    # test that long targets would also work
    assert torch.allclose(expected, criterion(INP_BINARY, TARGET_BINARY.long()))

    # image inputs
    expected = F.binary_cross_entropy_with_logits(INP_IMG_BINARY, TARGET_IMG_BINARY, reduction=reduction)
    assert torch.allclose(expected, criterion(INP_IMG_BINARY, TARGET_IMG_BINARY))

    # Squeezed on both sides first, then prediction and target with differing shapes.
    flat_expected = expected.squeeze()
    shape_variants = (
        (INP_IMG_BINARY.squeeze(), TARGET_IMG_BINARY.squeeze()),
        (INP_IMG_BINARY.squeeze(), TARGET_IMG_BINARY),
        (INP_IMG_BINARY, TARGET_IMG_BINARY.squeeze()),
    )
    for y_pred, y_true in shape_variants:
        got = criterion(y_pred, y_true)
        assert torch.allclose(flat_expected, got.squeeze())

    # 3d volumes: (BS, 1, IM_SIZE, IM_SIZE, IM_SIZE) prediction with a 0/1 float target
    volume_shape = (BS, 1, IM_SIZE, IM_SIZE, IM_SIZE)
    volume_inp = torch.rand(*volume_shape)
    volume_target = torch.randint(0, 2, volume_shape).float()
    expected = F.binary_cross_entropy_with_logits(volume_inp, volume_target, reduction=reduction)
    assert torch.allclose(expected, criterion(volume_inp, volume_target))
Exemple #7
0
def test_cross_entropy_weight():
    """Class weights given as a tensor, a numpy int array or a list must all match torch's weighted CE."""
    logits = torch.randn(BS, N_CLASSES)
    labels = torch.randint(0, N_CLASSES, (BS,)).long()

    as_tensor = torch.randint(1, 100, (N_CLASSES,)).float()
    as_array = as_tensor.numpy().astype(int)
    as_list = list(as_array)

    reference = torch.nn.CrossEntropyLoss(weight=as_tensor)(logits, labels)

    # Same weight values, three container types.
    for class_weight in (as_tensor, as_array, as_list):
        ours = losses.CrossEntropyLoss(weight=class_weight)(logits, labels)
        assert torch.allclose(reference, ours)
Exemple #8
0
def test_cross_entropy():
    """Compare the project cross entropy with torch's for index, multilabel and blended targets."""
    mix = np.random.beta(0.4, 0.4)
    shuffle = torch.randperm(BS)
    # Blend every target row with the row picked by a random permutation of the batch.
    blended_target = TARGET_MULTILABEL * mix + (1 - mix) * TARGET_MULTILABEL[shuffle, :]

    criterion = losses.CrossEntropyLoss()
    reference = torch.nn.CrossEntropyLoss()(INP, TARGET)

    # Index targets must reproduce torch's value.
    from_indices = criterion(INP, TARGET)
    assert torch.allclose(reference, from_indices)

    # So must the multilabel form of the target.
    from_multilabel = criterion(INP, TARGET_MULTILABEL)
    assert torch.allclose(reference, from_multilabel)

    # The blended target is expected to give a different value.
    from_blended = criterion(INP, blended_target)
    assert not torch.allclose(reference, from_blended)

    # Smoothing is expected to change the loss as well.
    smoothed = losses.CrossEntropyLoss(smoothing=0.1)(INP, TARGET)
    assert not torch.allclose(smoothed, from_indices)
Exemple #9
0
def test_binary_cross_entropy():
    """Check binary cross entropy against ``binary_cross_entropy_with_logits``.

    Every section draws a fresh random prediction and a 0/1 float target and
    compares the project loss (``mode='binary'``, default reduction) with
    torch's functional implementation. Sections cover a 1-D batch, image
    tensors with and without a singleton channel dimension, and pairs where
    only one of prediction/target carries that singleton dimension.
    """
    IM_SIZE = 10
    torch_bce = torch.nn.functional.binary_cross_entropy_with_logits

    # classification test
    # The prediction is sized by BS (it was a hard-coded 16) so it always
    # matches the target length.
    inp = torch.randn(BS).float()
    target = torch.randint(0, 2, (BS, )).float()

    torch_ce = torch_bce(inp, target)
    my_ce = losses.CrossEntropyLoss(mode='binary')(inp, target)
    assert torch.allclose(torch_ce, my_ce)

    # test for images, both tensors with a singleton channel dimension
    inp = torch.randn(BS, 1, IM_SIZE, IM_SIZE).float()
    target = torch.randint(0, 2, (BS, 1, IM_SIZE, IM_SIZE)).float()

    torch_ce = torch_bce(inp, target)
    my_ce = losses.CrossEntropyLoss(mode='binary')(inp, target)
    assert torch.allclose(torch_ce, my_ce)

    # test for images, both tensors without the channel dimension
    inp = torch.randn(BS, IM_SIZE, IM_SIZE).float()
    target = torch.randint(0, 2, (BS, IM_SIZE, IM_SIZE)).float()

    torch_ce = torch_bce(inp, target)
    my_ce = losses.CrossEntropyLoss(mode='binary')(inp, target)
    assert torch.allclose(torch_ce, my_ce)

    # test for images with different y_true shape
    # Only the prediction has the channel dimension; the torch reference gets
    # it squeezed, the project loss gets the tensors as they are.
    inp = torch.randn(BS, 1, IM_SIZE, IM_SIZE).float()
    target = torch.randint(0, 2, (BS, IM_SIZE, IM_SIZE)).float()
    torch_ce = torch_bce(inp.squeeze(), target)
    my_ce = losses.CrossEntropyLoss(mode='binary')(inp, target)
    assert torch.allclose(torch_ce, my_ce)

    # Only the target has the channel dimension.
    inp = torch.randn(BS, IM_SIZE, IM_SIZE).float()
    target = torch.randint(0, 2, (BS, 1, IM_SIZE, IM_SIZE)).float()
    torch_ce = torch_bce(inp, target.squeeze())
    my_ce = losses.CrossEntropyLoss(mode='binary')(inp, target)
    assert torch.allclose(torch_ce, my_ce)
Exemple #10
0
def test_cross_entropy():
    """Compare the project cross entropy with torch's for index, one-hot and blended targets."""
    logits = torch.randn(BS, N_CLASSES)
    labels = torch.randint(0, N_CLASSES, (BS, )).long()

    # One-hot encode the labels by scattering 1.0 into a zero matrix.
    one_hot = torch.zeros(labels.size(0), N_CLASSES, dtype=torch.float)
    one_hot.scatter_(1, labels.unsqueeze(1), 1.0)

    # Blend every one-hot row with the row picked by a random permutation of the batch.
    mix = np.random.beta(0.4, 0.4)
    shuffle = torch.randperm(BS)
    blended = one_hot * mix + (1 - mix) * one_hot[shuffle, :]

    reference = torch.nn.CrossEntropyLoss()(logits, labels)

    # Index targets must reproduce torch's value.
    from_indices = losses.CrossEntropyLoss()(logits, labels)
    assert torch.allclose(reference, from_indices)

    # So must the one-hot form of the same labels.
    from_one_hot = losses.CrossEntropyLoss()(logits, one_hot)
    assert torch.allclose(reference, from_one_hot)

    # The blended target is expected to give a different value.
    from_blended = losses.CrossEntropyLoss()(logits, blended)
    assert not torch.allclose(reference, from_blended)

    # Smoothing is expected to change the loss as well.
    smoothed = losses.CrossEntropyLoss(smoothing=0.1)(logits, labels)
    assert not torch.allclose(smoothed, from_indices)
Exemple #11
0
def test_cross_entropy_reduction(reduction):
    """Multiclass CE on image tensors must match ``F.cross_entropy`` for the given ``reduction``."""
    expected = F.cross_entropy(INP_IMG, TARGET_IMG_MULTICLASS, reduction=reduction)
    criterion = losses.CrossEntropyLoss(mode="multiclass", reduction=reduction)
    # The project loss receives the multilabel target because, in the current
    # implementation, it wouldn't perform OHE on the target itself.
    got = criterion(INP_IMG, TARGET_IMG_MULTILABEL)
    assert torch.allclose(expected, got)