def test_multiclass_lovasz():
    """Check that multiclass Lovasz loss ignores a singleton channel axis in the target.

    The same random labels are passed once shaped (BS, 1, IM_SIZE, IM_SIZE)
    and once shaped (BS, IM_SIZE, IM_SIZE); both losses must match.
    """
    logits = torch.randn(BS, N_CLASSES, IM_SIZE, IM_SIZE).float()
    labels_4d = torch.randint(0, N_CLASSES, (BS, 1, IM_SIZE, IM_SIZE)).float()
    lovasz = losses.LovaszLoss(mode="multiclass")
    loss_4d = lovasz(logits, labels_4d)

    # Same labels with the singleton channel axis dropped.
    labels_3d = labels_4d.view(BS, IM_SIZE, IM_SIZE)
    loss_3d = lovasz(logits, labels_3d)

    assert torch.allclose(loss_4d, loss_3d)
def test_binary_lovasz_random_input():
    """Check that binary Lovasz loss ignores a singleton channel axis in the target.

    Uses freshly sampled random logits and {0, 1} labels. The labels are
    passed once shaped (BS, 1, IM_SIZE, IM_SIZE) and once shaped
    (BS, IM_SIZE, IM_SIZE); both losses must match.

    NOTE: this test carries a distinct name because another test in this
    module is also called ``test_binary_lovasz``. With identical names the
    later definition replaces the earlier one at import time, so this test
    would never be collected by pytest.
    """
    inp = torch.randn(BS, 1, IM_SIZE, IM_SIZE).float()
    target = torch.randint(0, 2, (BS, 1, IM_SIZE, IM_SIZE)).float()
    criterion = losses.LovaszLoss(mode="binary")
    loss = criterion(inp, target)

    # Try the other supported target shape (no channel axis).
    target = target.view(BS, IM_SIZE, IM_SIZE)
    loss2 = criterion(inp, target)

    assert torch.allclose(loss, loss2)
def test_multiclass_multilabel_lovasz():
    """Check that multiclass and multilabel modes give the same loss on the shared fixtures.

    Both criteria receive ``INP_IMG``; the targets are the module-level
    ``TARGET_IMG_MULTICLASS`` and ``TARGET_IMG_MULTILABEL`` fixtures
    (presumably two encodings of the same labels -- verify against their
    definitions).
    """
    multiclass_criterion = losses.LovaszLoss(mode="multiclass")
    multiclass_loss = multiclass_criterion(INP_IMG, TARGET_IMG_MULTICLASS)

    multilabel_criterion = losses.LovaszLoss(mode="multilabel")
    multilabel_loss = multilabel_criterion(INP_IMG, TARGET_IMG_MULTILABEL)

    assert torch.allclose(multiclass_loss, multilabel_loss)
def test_binary_lovasz():
    """Check that binary Lovasz loss on the shared fixtures ignores the target's channel axis.

    ``TARGET_IMG_BINARY`` is used as-is and then reshaped to
    (BS, IM_SIZE, IM_SIZE); the loss against ``INP_IMG_BINARY`` must be the
    same either way.
    """
    reference_loss = losses.LovaszLoss(mode="binary")(INP_IMG_BINARY, TARGET_IMG_BINARY)

    # Same target viewed without the channel axis, scored by a fresh criterion.
    squeezed_target = TARGET_IMG_BINARY.view(BS, IM_SIZE, IM_SIZE)
    squeezed_loss = losses.LovaszLoss(mode="binary")(INP_IMG_BINARY, squeezed_target)

    assert torch.allclose(reference_loss, squeezed_loss)
def test_multilabel_lovasz():
    """Smoke-test multilabel Lovasz loss on random logits and random {0, 1} targets.

    Both tensors are shaped (BS, N_CLASSES, IM_SIZE, IM_SIZE). The loss must
    be finite and non-zero.
    """
    inp = torch.randn(BS, N_CLASSES, IM_SIZE, IM_SIZE).float()
    target = torch.randint(0, 2, (BS, N_CLASSES, IM_SIZE, IM_SIZE)).float()
    criterion = losses.LovaszLoss(mode="multilabel")
    loss = criterion(inp, target)

    # A bare truthiness check passes for NaN (NaN is truthy), so check
    # finiteness explicitly before the original non-zero check.
    assert torch.isfinite(loss).all()
    assert loss