from typing import Any

import pytest
import torch
from skimage.io import imread  # image reader assumed (skimage.io); any reader returning HxW / HxWxC arrays works

from piq import DISTS


def test_dists_similar_to_official_implementation() -> None:
    # Baseline scores from: https://github.com/dingkeyan93/DISTS
    loss = DISTS()

    # Greyscale images: add batch and channel dims, rescale to [0, 1]
    goldhill = torch.tensor(
        imread('tests/assets/goldhill.gif'))[None, None, ...] / 255.0
    goldhill_jpeg = torch.tensor(
        imread('tests/assets/goldhill_jpeg.gif'))[None, None, ...] / 255.0

    loss_value = loss(goldhill_jpeg, goldhill)
    baseline_value = torch.tensor(0.19509)
    assert torch.isclose(loss_value, baseline_value, atol=1e-3), \
        f'Expected PIQ loss to match the official value, got {loss_value} vs {baseline_value}'

    # RGB images: HWC -> CHW, add batch dim, rescale to [0, 1]
    I01 = torch.tensor(
        imread('tests/assets/I01.BMP')).permute(2, 0, 1)[None, ...] / 255.0
    i1_01_5 = torch.tensor(
        imread('tests/assets/i01_01_5.bmp')).permute(2, 0, 1)[None, ...] / 255.0

    loss_value = loss(i1_01_5, I01)
    baseline_value = torch.tensor(0.17321)

    assert torch.isclose(loss_value, baseline_value, atol=1e-3), \
        f'Expected PIQ loss to match the official value, got {loss_value} vs {baseline_value}'
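

# The special-cases test below is parametrised in the real suite.  The decorator
# here is a hedged stand-in so the function is collectable on its own: it only
# covers an invalid-input case (mismatched shapes), because the reference values
# for well-formed pairs come from the piq test-suite and are not reproduced here.
# `Exception` is a deliberately broad placeholder, since the exact error raised
# by piq's input validation is not pinned down in this snippet.
@pytest.mark.parametrize(
    'x, y, expectation, value',
    [
        # Shape mismatch should raise; the value check is then skipped, so 0.0
        # is only a dummy.
        (torch.rand(2, 3, 96, 96), torch.rand(2, 3, 128, 128),
         pytest.raises(Exception), 0.0),
    ],
)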
def test_dists_loss_forward_for_special_cases(x: torch.Tensor, y: torch.Tensor,
                                              expectation: Any, value: float) -> None:
    loss = DISTS()
    with expectation:
        loss_value = loss(x, y)
        assert torch.isclose(loss_value, torch.tensor(value), atol=1e-6), \
            f'Expected loss value to be equal to target value. Got {loss_value} and {value}'
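

# Assumed scaffolding: the remaining tests rely on `x`/`y` tensor fixtures, a
# `device` fixture and a NONE_GRAD_ERR_MSG constant that are not shown in this
# snippet (in piq they live in shared test fixtures).  The definitions below are
# minimal stand-ins so the snippet is self-contained; sizes, the error-message
# wording and the CUDA handling are assumptions, not the library's own code.
NONE_GRAD_ERR_MSG = 'Expected non None gradient of leaf variable'


@pytest.fixture
def x() -> torch.Tensor:
    # Random batch of RGB images in [0, 1], large enough for the VGG16 backbone.
    return torch.rand(2, 3, 96, 96)


@pytest.fixture
def y() -> torch.Tensor:
    return torch.rand(2, 3, 96, 96)


@pytest.fixture(params=['cpu'] + (['cuda'] if torch.cuda.is_available() else []))
def device(request) -> str:
    # Run each test on CPU and, when a GPU is available, on CUDA as well.
    return request.param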


def test_dists_computes_grad(x: torch.Tensor, y: torch.Tensor,
                             device: str) -> None:
    x.requires_grad_()
    loss_value = DISTS()(x.to(device), y.to(device))
    loss_value.backward()
    assert x.grad is not None, NONE_GRAD_ERR_MSG


def test_dists_loss_forward(x: torch.Tensor, y: torch.Tensor,
                            device: str) -> None:
    loss = DISTS()
    loss(x.to(device), y.to(device))
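

# With the stand-in scaffolding above, this module can be collected and run
# directly, e.g. `pytest -k dists`.  The 'tests/assets/...' paths assume the
# piq repository layout referenced by the image loaders in the first test.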