Ejemplo n.º 1
0
def test_multiclass_lovasz():
    """Multiclass LovaszLoss must not depend on the target layout.

    The same random labels are fed once as a (BS, 1, IM_SIZE, IM_SIZE) tensor
    and once viewed as (BS, IM_SIZE, IM_SIZE); both losses must match.
    """
    logits = torch.randn(BS, N_CLASSES, IM_SIZE, IM_SIZE).float()
    labels_with_channel = torch.randint(
        0, N_CLASSES, (BS, 1, IM_SIZE, IM_SIZE)
    ).float()
    lovasz = losses.LovaszLoss(mode="multiclass")
    loss_with_channel = lovasz(logits, labels_with_channel)

    # Same data, singleton channel dimension dropped via a view.
    labels_flat = labels_with_channel.view(BS, IM_SIZE, IM_SIZE)
    loss_flat = lovasz(logits, labels_flat)

    assert torch.allclose(loss_with_channel, loss_flat)
Ejemplo n.º 2
0
def test_binary_lovasz():
    """Binary LovaszLoss must agree for 4-D and 3-D views of the same mask.

    Random 0/1 labels are evaluated as (BS, 1, IM_SIZE, IM_SIZE) and again
    viewed as (BS, IM_SIZE, IM_SIZE); the two losses must be allclose.
    """
    logits = torch.randn(BS, 1, IM_SIZE, IM_SIZE).float()
    mask = torch.randint(0, 2, (BS, 1, IM_SIZE, IM_SIZE)).float()
    lovasz = losses.LovaszLoss(mode="binary")
    loss_4d = lovasz(logits, mask)

    # Re-evaluate with the singleton channel dimension removed.
    loss_3d = lovasz(logits, mask.view(BS, IM_SIZE, IM_SIZE))

    assert torch.allclose(loss_4d, loss_3d)
Ejemplo n.º 3
0
def test_multiclass_multilabel_lovasz():
    """Multiclass and multilabel LovaszLoss must agree on the shared fixtures.

    NOTE(review): presumably TARGET_IMG_MULTILABEL encodes the same labels as
    TARGET_IMG_MULTICLASS in per-class channel form — confirm against the
    fixture definitions.
    """
    multiclass_value = losses.LovaszLoss(mode="multiclass")(
        INP_IMG, TARGET_IMG_MULTICLASS
    )
    multilabel_value = losses.LovaszLoss(mode="multilabel")(
        INP_IMG, TARGET_IMG_MULTILABEL
    )
    assert torch.allclose(multiclass_value, multilabel_value)
Ejemplo n.º 4
0
def test_binary_lovasz():
    """Binary LovaszLoss on the shared fixtures must ignore the target layout.

    TARGET_IMG_BINARY is compared against its (BS, IM_SIZE, IM_SIZE) view.
    NOTE(review): presumably the fixture is (BS, 1, IM_SIZE, IM_SIZE) — confirm.
    """
    original_layout = losses.LovaszLoss(mode="binary")(
        INP_IMG_BINARY, TARGET_IMG_BINARY
    )
    flat_target = TARGET_IMG_BINARY.view(BS, IM_SIZE, IM_SIZE)
    flat_layout = losses.LovaszLoss(mode="binary")(INP_IMG_BINARY, flat_target)
    assert torch.allclose(original_layout, flat_layout)
Ejemplo n.º 5
0
def test_multilabel_lovasz():
    """Smoke-test LovaszLoss in "multilabel" mode on random inputs.

    Only checks that the criterion returns a single, finite value; the exact
    loss value is not asserted.
    """
    inp = torch.randn(BS, N_CLASSES, IM_SIZE, IM_SIZE).float()
    # Independent 0/1 label per channel, so a pixel may be positive for
    # several classes at once.
    target = torch.randint(0, 2, (BS, N_CLASSES, IM_SIZE, IM_SIZE)).float()
    criterion = losses.LovaszLoss(mode="multilabel")
    loss = criterion(inp, target)
    # A bare `assert loss` is a truthiness check: it passes for NaN
    # (bool(nan) is True), fails for an exact 0.0 loss, and raises an
    # "ambiguous" RuntimeError for multi-element tensors. Check the
    # intended properties explicitly instead.
    assert loss.numel() == 1
    assert torch.isfinite(loss).all()