コード例 #1
0
ファイル: test_losses.py プロジェクト: kuan-li/pytorch-tools
def test_multiclass_lovasz():
    """Multiclass Lovasz loss must not depend on the target's channel axis.

    The same random labels are fed once as (BS, 1, IM_SIZE, IM_SIZE) and once
    reshaped to (BS, IM_SIZE, IM_SIZE); both calls must give matching losses.
    """
    logits = torch.randn(BS, N_CLASSES, IM_SIZE, IM_SIZE).float()
    labels_with_channel = torch.randint(
        0, N_CLASSES, (BS, 1, IM_SIZE, IM_SIZE)
    ).float()
    lovasz = losses.LovaszLoss(mode="multiclass")

    loss_with_channel = lovasz(logits, labels_with_channel)
    # Identical labels, but without the singleton channel dimension.
    labels_flat = labels_with_channel.view(BS, IM_SIZE, IM_SIZE)
    loss_without_channel = lovasz(logits, labels_flat)

    assert torch.allclose(loss_with_channel, loss_without_channel)
コード例 #2
0
ファイル: test_losses.py プロジェクト: kuan-li/pytorch-tools
def test_binary_lovasz():
    """Binary Lovasz loss must not depend on the target's channel axis.

    A random 0/1 mask is passed as (BS, 1, IM_SIZE, IM_SIZE) and again as
    (BS, IM_SIZE, IM_SIZE); the two losses must match.
    """
    logits = torch.randn(BS, 1, IM_SIZE, IM_SIZE).float()
    mask_with_channel = torch.randint(0, 2, (BS, 1, IM_SIZE, IM_SIZE)).float()
    lovasz = losses.LovaszLoss(mode="binary")

    loss_with_channel = lovasz(logits, mask_with_channel)
    # Same mask with the singleton channel dimension removed.
    mask_flat = mask_with_channel.view(BS, IM_SIZE, IM_SIZE)
    loss_without_channel = lovasz(logits, mask_flat)

    assert torch.allclose(loss_with_channel, loss_without_channel)
コード例 #3
0
def test_multiclass_multilabel_lovasz():
    """Multiclass and multilabel Lovasz modes must agree on the shared fixtures.

    NOTE(review): this presumably relies on TARGET_IMG_MULTILABEL being the
    one-hot encoding of TARGET_IMG_MULTICLASS — confirm against the fixtures.
    """
    multiclass_criterion = losses.LovaszLoss(mode="multiclass")
    multiclass_value = multiclass_criterion(INP_IMG, TARGET_IMG_MULTICLASS)

    multilabel_criterion = losses.LovaszLoss(mode="multilabel")
    multilabel_value = multilabel_criterion(INP_IMG, TARGET_IMG_MULTILABEL)

    assert torch.allclose(multiclass_value, multilabel_value)
コード例 #4
0
def test_binary_lovasz():
    """Binary Lovasz loss on the module fixtures must survive a target reshape.

    The target is reshaped to (BS, IM_SIZE, IM_SIZE); presumably the fixture
    is (BS, 1, IM_SIZE, IM_SIZE) — confirm against its definition.
    """
    reference = losses.LovaszLoss(mode="binary")(
        INP_IMG_BINARY, TARGET_IMG_BINARY
    )
    # Same target values, viewed without the channel dimension.
    flat_target = TARGET_IMG_BINARY.view(BS, IM_SIZE, IM_SIZE)
    candidate = losses.LovaszLoss(mode="binary")(INP_IMG_BINARY, flat_target)
    assert torch.allclose(reference, candidate)
コード例 #5
0
ファイル: test_losses.py プロジェクト: kuan-li/pytorch-tools
def test_multilabel_lovasz():
    """Smoke-test the multilabel Lovasz loss on random logits and masks.

    Every one of the N_CLASSES channels gets an independent random 0/1 target
    mask. The resulting loss must be a finite, non-zero scalar.
    """
    inp = torch.randn(BS, N_CLASSES, IM_SIZE, IM_SIZE).float()
    target = torch.randint(0, 2, (BS, N_CLASSES, IM_SIZE, IM_SIZE)).float()
    criterion = losses.LovaszLoss(mode="multilabel")
    loss = criterion(inp, target)
    # A bare `assert loss` is too weak: bool(tensor(nan)) and bool(tensor(inf))
    # are both True, so a diverged loss would pass. Check finiteness explicitly,
    # then keep the original non-zero check (.item() requires a single-element
    # tensor, exactly as bool() on a tensor did).
    assert torch.isfinite(loss).all()
    assert loss.item() != 0