def test_loss_addition():
    """Weighted composite ``DiceLoss * 0.5 + CrossEntropyLoss * 5`` equals the manual weighted sum.

    Each loss is evaluated on its own and through the composite built with
    ``*`` and ``+``. The composite's output must have the same shape as a single
    loss output and be numerically close to ``0.5 * dice + 5 * ce``.

    NOTE(review): another ``test_loss_addition`` is defined immediately after
    this one. The later ``def`` rebinds the module-level name, so pytest only
    collects that later version unless one of the two is renamed.
    """
    dice = losses.DiceLoss("binary")
    cross_entropy = losses.CrossEntropyLoss(mode="binary")
    weighted = losses.DiceLoss("binary") * 0.5 + losses.CrossEntropyLoss(mode="binary") * 5

    dice_value = dice(INP_IMG_BINARY, TARGET_IMG_BINARY)
    ce_value = cross_entropy(INP_IMG_BINARY, TARGET_IMG_BINARY)
    combined_value = weighted(INP_IMG_BINARY, TARGET_IMG_BINARY)

    assert combined_value.shape == dice_value.shape
    expected = dice_value * 0.5 + ce_value * 5
    assert torch.allclose(combined_value, expected)
def test_loss_addition():
    """Weighted composite ``DiceLoss * 0.5 + BinaryFocalLoss * 5`` equals the manual weighted sum.

    Uses a small batch with an all-ones prediction and an all-zeros target of
    shape ``(2, 1, 8, 8)``. Each loss is evaluated separately and through the
    composite built with ``*`` and ``+``. The results are compared with a
    tolerance rather than exact equality.
    """
    inp = torch.ones(2, 1, 8, 8)
    label = torch.zeros(2, 1, 8, 8)

    dice = losses.DiceLoss('binary')
    focal = losses.BinaryFocalLoss()
    combined = losses.DiceLoss('binary') * 0.5 + losses.BinaryFocalLoss() * 5

    dice_value = dice(inp, label)
    focal_value = focal(inp, label)
    combined_value = combined(inp, label)

    assert combined_value.shape == dice_value.shape
    # Tolerance instead of `==`: the composite may accumulate its weighted
    # terms in a different order than this expression, and float addition is
    # not associative, so exact equality is brittle.
    assert torch.allclose(combined_value, dice_value * 0.5 + focal_value * 5)
def test_dice_loss_binary():
    """Check binary DiceLoss on probability inputs (``from_logits=False``, ``eps=1e-4``).

    Covers ideal cases, where the prediction equals the target, and worst
    cases, where the prediction is disjoint from the target. Every prediction
    and target is reshaped to ``(1, 1, 1, -1)`` so all cases use the same
    tensor rank.

    NOTE(review): a second ``test_dice_loss_binary`` is defined right after this
    one and rebinds the name, so pytest only collects that later version.
    """
    criterion = losses.DiceLoss(mode="binary", from_logits=False, eps=1e-4)

    # (prediction probabilities, target labels, expected loss)
    cases = [
        # Ideal case: prediction matches target exactly.
        ([1.0, 1.0, 1.0], [1, 1, 1], 0.0),
        ([1.0, 0.0, 1.0], [1, 0, 1], 0.0),
        ([0.0, 0.0, 0.0], [0, 0, 0], 0.0),
        # Worst case: prediction is disjoint from target.
        # All-zero target: it zeros loss if there is no y_true.
        ([1.0, 1.0, 1.0], [0, 0, 0], 0.0),
        ([1.0, 0.0, 1.0], [0, 1, 0], 1.0),
        ([0.0, 0.0, 0.0], [1, 1, 1], 1.0),
    ]

    for pred_values, true_values, expected in cases:
        y_pred = torch.tensor(pred_values).view(1, 1, 1, -1)
        y_true = torch.tensor(true_values).view(1, 1, 1, -1)
        loss = criterion(y_pred, y_true)
        # The message identifies which case failed, since all share one test.
        assert float(loss) == pytest.approx(expected, abs=EPS), (pred_values, true_values)
def test_dice_loss_binary():
    """Check binary DiceLoss on probability inputs (``from_logits=False``, default eps).

    Covers ideal cases, where the prediction equals the target, and worst
    cases, where the prediction is disjoint from the target. Every prediction
    and target is reshaped to ``(1, 1, 1, -1)`` so all cases use the same
    tensor rank.
    """
    criterion = losses.DiceLoss(mode="binary", from_logits=False)

    # (prediction probabilities, target labels, expected loss)
    cases = [
        # Ideal case: prediction matches target exactly.
        ([1.0, 1.0, 1.0], [1, 1, 1], 0.0),
        ([1.0, 0.0, 1.0], [1, 0, 1], 0.0),
        ([0.0, 0.0, 0.0], [0, 0, 0], 0.0),
        # Worst case: prediction is disjoint from target.
        # All-zero target: expected 1.0 (originally attributed to internal
        # smoothing). NOTE(review): the eps=1e-4 variant of this test defined
        # just before this one expects 0.0 for these same inputs. Confirm which
        # expectation matches the current DiceLoss.
        ([1.0, 1.0, 1.0], [0, 0, 0], 1.0),
        ([1.0, 0.0, 1.0], [0, 1, 0], 1.0),
        ([0.0, 0.0, 0.0], [1, 1, 1], 1.0),
    ]

    for pred_values, true_values, expected in cases:
        y_pred = torch.tensor(pred_values).view(1, 1, 1, -1)
        y_true = torch.tensor(true_values).view(1, 1, 1, -1)
        loss = criterion(y_pred, y_true)
        # The message identifies which case failed, since all share one test.
        assert float(loss) == pytest.approx(expected, abs=EPS), (pred_values, true_values)