def test_cross_entropy_from_logits():
    """Logit input and explicit-softmax input must yield the same cross-entropy.

    The default-constructed loss (described as from_logits=True) receives raw
    ``INP``; a ``from_logits=False`` loss receives ``INP.softmax(1)``.
    """
    logits_loss = losses.CrossEntropyLoss()
    expected = logits_loss(INP, TARGET)

    probs_loss = losses.CrossEntropyLoss(from_logits=False)
    probabilities = INP.softmax(1)
    actual = probs_loss(probabilities, TARGET)

    assert torch.allclose(expected, actual)
def test_binary_cross_entropy_from_logits():
    """Binary mode: logit input and explicit-sigmoid input must yield the same loss.

    The default binary loss (described as from_logits=True) receives raw
    ``INP_BINARY``; a ``from_logits=False`` loss receives ``INP_BINARY.sigmoid()``.
    """
    logits_loss = losses.CrossEntropyLoss(mode="binary")
    expected = logits_loss(INP_BINARY, TARGET_BINARY)

    probs_loss = losses.CrossEntropyLoss(mode="binary", from_logits=False)
    probabilities = INP_BINARY.sigmoid()
    actual = probs_loss(probabilities, TARGET_BINARY)

    assert torch.allclose(expected, actual)
def test_loss_addition():
    """A weighted sum of loss modules must equal the weighted sum of their outputs.

    Builds ``DiceLoss * 0.5 + CrossEntropyLoss * 5`` and checks both the result
    shape and its value against combining the individual losses by hand.
    """
    dice = losses.DiceLoss("binary")
    bce = losses.CrossEntropyLoss(mode="binary")
    combined = losses.DiceLoss("binary") * 0.5 + losses.CrossEntropyLoss(mode="binary") * 5

    dice_value = dice(INP_IMG_BINARY, TARGET_IMG_BINARY)
    bce_value = bce(INP_IMG_BINARY, TARGET_IMG_BINARY)
    combined_value = combined(INP_IMG_BINARY, TARGET_IMG_BINARY)

    assert combined_value.shape == dice_value.shape
    manual_sum = dice_value * 0.5 + bce_value * 5
    assert torch.allclose(combined_value, manual_sum)
def test_cross_entropy_weight():
    """Class weights given as a float tensor, int ndarray, or list must all match torch.

    The reference is ``torch.nn.CrossEntropyLoss`` with the float-tensor weight,
    evaluated on the module-level ``INP`` / ``TARGET``.
    """
    # NOTE(review): another function in this module is also named
    # `test_cross_entropy_weight`; Python rebinds the name, so pytest only
    # collects the last definition. Consider renaming one of them.
    w_tensor = torch.randint(1, 100, (N_CLASSES,)).float()
    w_array = w_tensor.numpy().astype(int)
    w_list = list(w_array)

    reference = torch.nn.CrossEntropyLoss(weight=w_tensor)(INP, TARGET)
    for weight in (w_tensor, w_array, w_list):
        result = losses.CrossEntropyLoss(weight=weight)(INP, TARGET)
        assert torch.allclose(reference, result)
def test_binary_cross_entropy(reduction):
    """Binary-mode CrossEntropyLoss must agree with F.binary_cross_entropy_with_logits.

    Covers flat classification inputs and image inputs, including cases where
    the prediction and/or target have had their singleton dims squeezed.
    ``reduction`` is forwarded to both implementations.
    """
    # NOTE(review): `reduction` is presumably supplied by a pytest parametrize
    # decorator or fixture that is not visible here - confirm.
    # NOTE(review): this name is redefined later in this module, so pytest only
    # collects the last definition; this version never runs as-is.
    loss_fn = losses.CrossEntropyLoss(mode="binary", reduction=reduction)

    # classification
    expected = F.binary_cross_entropy_with_logits(INP_BINARY, TARGET_BINARY, reduction=reduction)
    assert torch.allclose(expected, loss_fn(INP_BINARY, TARGET_BINARY))

    # images
    expected = F.binary_cross_entropy_with_logits(INP_IMG_BINARY, TARGET_IMG_BINARY, reduction=reduction)
    assert torch.allclose(expected, loss_fn(INP_IMG_BINARY, TARGET_IMG_BINARY))

    # squeezed variants: both sides squeezed, then only one side (mismatched y_true shape)
    squeezed_pred = INP_IMG_BINARY.squeeze()
    squeezed_true = TARGET_IMG_BINARY.squeeze()
    variants = (
        (squeezed_pred, squeezed_true),
        (squeezed_pred, TARGET_IMG_BINARY),
        (INP_IMG_BINARY, squeezed_true),
    )
    for pred, true in variants:
        assert torch.allclose(expected.squeeze(), loss_fn(pred, true).squeeze())
def test_binary_cross_entropy(reduction):
    """Binary-mode CrossEntropyLoss must agree with F.binary_cross_entropy_with_logits.

    Covers flat classification inputs (float and long targets), image inputs
    with singleton dims squeezed on either or both sides, and 3-D volumes.
    ``reduction`` is forwarded to both implementations.
    """
    # NOTE(review): `reduction` is presumably supplied by a pytest parametrize
    # decorator or fixture that is not visible here - confirm.
    # NOTE(review): this name is redefined later in this module, so pytest only
    # collects the last definition; this version never runs as-is.
    loss_fn = losses.CrossEntropyLoss(mode="binary", reduction=reduction)

    # classification; test that long targets work the same as float targets
    expected = F.binary_cross_entropy_with_logits(INP_BINARY, TARGET_BINARY, reduction=reduction)
    for true in (TARGET_BINARY, TARGET_BINARY.long()):
        assert torch.allclose(expected, loss_fn(INP_BINARY, true))

    # images
    expected = F.binary_cross_entropy_with_logits(INP_IMG_BINARY, TARGET_IMG_BINARY, reduction=reduction)
    assert torch.allclose(expected, loss_fn(INP_IMG_BINARY, TARGET_IMG_BINARY))

    # squeezed variants: both sides squeezed, then only one side (mismatched y_true shape)
    squeezed_pred = INP_IMG_BINARY.squeeze()
    squeezed_true = TARGET_IMG_BINARY.squeeze()
    variants = (
        (squeezed_pred, squeezed_true),
        (squeezed_pred, TARGET_IMG_BINARY),
        (INP_IMG_BINARY, squeezed_true),
    )
    for pred, true in variants:
        assert torch.allclose(expected.squeeze(), loss_fn(pred, true).squeeze())

    # 3-D volumes of shape (BS, 1, IM_SIZE, IM_SIZE, IM_SIZE)
    volume = torch.rand(BS, 1, IM_SIZE, IM_SIZE, IM_SIZE)
    volume_true = torch.randint(0, 2, (BS, 1, IM_SIZE, IM_SIZE, IM_SIZE)).float()
    expected = F.binary_cross_entropy_with_logits(volume, volume_true, reduction=reduction)
    assert torch.allclose(expected, loss_fn(volume, volume_true))
def test_cross_entropy_weight():
    """Class weights given as a float tensor, int ndarray, or list must all match torch.

    Uses freshly drawn logits of shape (BS, N_CLASSES) and integer labels; the
    reference is ``torch.nn.CrossEntropyLoss`` with the float-tensor weight.
    """
    logits = torch.randn(BS, N_CLASSES)
    labels = torch.randint(0, N_CLASSES, (BS,)).long()
    w_tensor = torch.randint(1, 100, (N_CLASSES,)).float()
    w_array = w_tensor.numpy().astype(int)
    w_list = list(w_array)

    reference = torch.nn.CrossEntropyLoss(weight=w_tensor)(logits, labels)
    for weight in (w_tensor, w_array, w_list):
        result = losses.CrossEntropyLoss(weight=weight)(logits, labels)
        assert torch.allclose(reference, result)
def test_cross_entropy():
    """CrossEntropyLoss must match torch for index and one-hot targets.

    Also checks that the loss differs from torch's when the targets are mixed
    mixup-style (a convex blend of the one-hot targets with a row-permuted copy),
    and that enabling label smoothing changes the result.
    """
    # NOTE(review): this name is redefined later in this module, so pytest only
    # collects the last definition; this version never runs as-is.
    # Beta(0.4, 0.4) has density diverging at 0 and 1. With c ~= 1 the mixed
    # target is nearly identical to the original, and the `assert not allclose`
    # below would fail intermittently (~1% of runs). Clip c away from the ends.
    c = float(np.clip(np.random.beta(0.4, 0.4), 0.1, 0.9))
    perm = torch.randperm(BS)
    tar_one_hot_2 = TARGET_MULTILABEL * c + (1 - c) * TARGET_MULTILABEL[perm, :]

    my_ce_loss = losses.CrossEntropyLoss()
    torch_ce = torch.nn.CrossEntropyLoss()(INP, TARGET)

    # index targets
    my_ce = my_ce_loss(INP, TARGET)
    assert torch.allclose(torch_ce, my_ce)

    # one-hot targets must give the same value as index targets
    my_ce_oh = my_ce_loss(INP, TARGET_MULTILABEL)
    assert torch.allclose(torch_ce, my_ce_oh)

    # mixed (soft) targets must give a different value
    my_ce_oh_2 = my_ce_loss(INP, tar_one_hot_2)
    assert not torch.allclose(torch_ce, my_ce_oh_2)

    # label smoothing must change the loss
    my_ce_sm = losses.CrossEntropyLoss(smoothing=0.1)(INP, TARGET)
    assert not torch.allclose(my_ce_sm, my_ce)
def test_binary_cross_entropy():
    """Binary-mode CrossEntropyLoss must match torch's BCE-with-logits.

    Covers flat classification inputs, image inputs with and without a
    singleton channel dim, and image inputs whose target carries a different
    (squeezed / unsqueezed) shape than the prediction.
    """
    IM_SIZE = 10  # local image side length for the image cases below

    def check(inp, target, ref_inp, ref_target):
        # The reference is computed on shape-aligned tensors; our loss receives
        # the (possibly mismatched) raw tensors and must handle the alignment.
        torch_ce = torch.nn.functional.binary_cross_entropy_with_logits(ref_inp, ref_target)
        my_ce = losses.CrossEntropyLoss(mode="binary")(inp, target)
        assert torch.allclose(torch_ce, my_ce)

    # classification test; the input length must follow BS so it matches the
    # (BS,) target (previously hard-coded to 16, which broke for BS != 16)
    inp = torch.randn(BS).float()
    target = torch.randint(0, 2, (BS,)).float()
    check(inp, target, inp, target)

    # images with a singleton channel dim
    inp = torch.randn(BS, 1, IM_SIZE, IM_SIZE).float()
    target = torch.randint(0, 2, (BS, 1, IM_SIZE, IM_SIZE)).float()
    check(inp, target, inp, target)

    # images without a channel dim
    inp = torch.randn(BS, IM_SIZE, IM_SIZE).float()
    target = torch.randint(0, 2, (BS, IM_SIZE, IM_SIZE)).float()
    check(inp, target, inp, target)

    # different y_true shape: prediction has the channel dim, target does not
    inp = torch.randn(BS, 1, IM_SIZE, IM_SIZE).float()
    target = torch.randint(0, 2, (BS, IM_SIZE, IM_SIZE)).float()
    check(inp, target, inp.squeeze(), target)

    # different y_true shape: target has the channel dim, prediction does not
    inp = torch.randn(BS, IM_SIZE, IM_SIZE).float()
    target = torch.randint(0, 2, (BS, 1, IM_SIZE, IM_SIZE)).float()
    check(inp, target, inp, target.squeeze())
def test_cross_entropy():
    """CrossEntropyLoss must match torch for index and one-hot targets.

    Also checks that the loss differs from torch's when the targets are mixed
    mixup-style (a convex blend of the one-hot targets with a row-permuted copy),
    and that enabling label smoothing changes the result.
    """
    inp = torch.randn(BS, N_CLASSES)
    target = torch.randint(0, N_CLASSES, (BS,)).long()

    # one-hot encode the index targets
    tar_one_hot = torch.zeros(target.size(0), N_CLASSES, dtype=torch.float)
    tar_one_hot.scatter_(1, target.unsqueeze(1), 1.0)

    # Beta(0.4, 0.4) has density diverging at 0 and 1. With c ~= 1 the mixed
    # target is nearly identical to the original, and the `assert not allclose`
    # below would fail intermittently (~1% of runs). Clip c away from the ends.
    c = float(np.clip(np.random.beta(0.4, 0.4), 0.1, 0.9))
    perm = torch.randperm(BS)
    tar_one_hot_2 = tar_one_hot * c + (1 - c) * tar_one_hot[perm, :]

    torch_ce = torch.nn.CrossEntropyLoss()(inp, target)

    # index targets
    my_ce = losses.CrossEntropyLoss()(inp, target)
    assert torch.allclose(torch_ce, my_ce)

    # one-hot targets must give the same value as index targets
    my_ce_oh = losses.CrossEntropyLoss()(inp, tar_one_hot)
    assert torch.allclose(torch_ce, my_ce_oh)

    # mixed (soft) targets must give a different value
    my_ce_oh_2 = losses.CrossEntropyLoss()(inp, tar_one_hot_2)
    assert not torch.allclose(torch_ce, my_ce_oh_2)

    # label smoothing must change the loss
    my_ce_sm = losses.CrossEntropyLoss(smoothing=0.1)(inp, target)
    assert not torch.allclose(my_ce_sm, my_ce)
def test_cross_entropy_reduction(reduction):
    """Multiclass CrossEntropyLoss on images must match F.cross_entropy for ``reduction``."""
    # NOTE(review): `reduction` is presumably supplied by a pytest parametrize
    # decorator or fixture that is not visible here - confirm.
    reference = F.cross_entropy(INP_IMG, TARGET_IMG_MULTICLASS, reduction=reduction)

    loss_fn = losses.CrossEntropyLoss(mode="multiclass", reduction=reduction)
    # Our loss receives the multilabel target: the current implementation
    # would not one-hot encode the multiclass index target itself.
    result = loss_fn(INP_IMG, TARGET_IMG_MULTILABEL)

    assert torch.allclose(reference, result)