예제 #1
0
def test_iw_ssim_loss_reduction(x_rand: torch.Tensor, y_rand: torch.Tensor, device: str) -> None:
    """Check that each supported ``reduction`` runs and unsupported ones raise ``ValueError``."""
    x_dev, y_dev = x_rand.to(device), y_rand.to(device)

    valid_modes = ('mean', 'sum', 'none')
    for reduction in valid_modes:
        InformationWeightedSSIMLoss(reduction=reduction)(x_dev, y_dev)

    invalid_modes = (None, 'n', 2)
    for reduction in invalid_modes:
        # Both construction and the forward call sit inside ``raises``,
        # so the error is accepted from either step.
        with pytest.raises(ValueError):
            InformationWeightedSSIMLoss(reduction=reduction)(x_dev, y_dev)
예제 #2
0
def test_iw_ssim_loss_fails_for_incorrect_data_range(x_rand: torch.Tensor, y_rand: torch.Tensor, device: str) -> None:
    """Inputs scaled by 255 must be rejected by a loss configured with ``data_range=1.``."""
    loss = InformationWeightedSSIMLoss(data_range=1.)
    # Scale by 255 and quantize to uint8 (x_rand / y_rand presumably lie in [0, 1] -- confirm fixture).
    x_scaled, y_scaled = ((t * 255).type(torch.uint8).to(device) for t in (x_rand, y_rand))
    with pytest.raises(AssertionError):
        loss(x_scaled, y_scaled)
예제 #3
0
def test_iw_ssim_loss_supports_different_data_ranges(x_rand: torch.Tensor, y_rand: torch.Tensor, data_range: int,
                                                     device: str) -> None:
    """A loss with ``data_range=k`` on k-scaled inputs must match ``data_range=1.`` on the same inputs divided by k."""
    scale = float(data_range)
    x_scaled = (x_rand * data_range).type(torch.uint8).to(device)
    y_scaled = (y_rand * data_range).type(torch.uint8).to(device)

    measure_scaled = InformationWeightedSSIMLoss(data_range=data_range)(x_scaled, y_scaled)
    # Dividing the uint8 tensors by a float promotes them back to floating point.
    measure = InformationWeightedSSIMLoss(data_range=1.)(x_scaled / scale, y_scaled / scale)

    diff = (measure_scaled - measure).abs()
    assert (diff <= 1e-6).all(), f'Result for same tensor with different data_range should be the same, got {diff}'
예제 #4
0
def test_iw_ssim_loss_is_one_for_equal_tensors(x_rand: torch.Tensor, device: str) -> None:
    """Identical inputs must yield a loss of ~0.

    The test name refers to the IW-SSIM index itself being one for equal tensors;
    the value actually asserted is the loss (``1 - IW-SSIM``), which is zero.
    """
    x_rand = x_rand.to(device)
    y_rand = x_rand.clone()
    loss = InformationWeightedSSIMLoss(data_range=1.)
    measure = loss(x_rand, y_rand)
    assert torch.allclose(measure, torch.zeros_like(measure), atol=1e-5), \
        'If equal tensors are passed IW-SSIM loss must be equal to 0 ' \
        f'(considering floating point operation error up to 1 * 10^-5), got {measure}'
예제 #5
0
def test_iw_ssim_loss_raises_if_kernel_size_greater_than_image(x_rand: torch.Tensor, y_rand: torch.Tensor,
                                                               device: str) -> None:
    """Inputs one pixel below the minimum spatial size must raise ``ValueError``."""
    kernel_size, levels = 11, 5
    # Minimum side for ``kernel_size`` at the coarsest of ``levels`` scales, presumably
    # halving resolution per level; ``levels`` is assumed to match the loss default -- confirm.
    min_size = (kernel_size - 1) * 2 ** (levels - 1) + 1
    too_small = slice(None, min_size - 1)
    wrong_size_x = x_rand[:, :, too_small, too_small].to(device)
    wrong_size_y = y_rand[:, :, too_small, too_small].to(device)
    loss = InformationWeightedSSIMLoss(data_range=1., kernel_size=kernel_size)
    with pytest.raises(ValueError):
        loss(wrong_size_x, wrong_size_y)
예제 #6
0
def test_iw_ssim_loss_raises_if_tensors_have_different_shapes(x_rand: torch.Tensor, y_rand: torch.Tensor,
                                                              scale_weights: torch.Tensor, device: str) -> None:
    """Mismatched input shapes must raise ``AssertionError``; malformed ``scale_weights`` must raise ``ValueError``."""
    x_dev = x_rand.to(device)

    # Cartesian product of candidate sizes, presumably (batch, channels, height, width).
    dims = [[3], [2, 3], [160, 161], [160, 161]]
    loss = InformationWeightedSSIMLoss(data_range=1.)
    for size in itertools.product(*dims):
        # ``.to(x_rand)`` matches the reference tensor's dtype and device.
        wrong_shape_x = torch.rand(size).to(x_rand)
        if wrong_shape_x.size() == x_rand.size():
            # The one combination equal to the reference shape must run cleanly.
            loss(wrong_shape_x.to(device), x_dev)
        else:
            with pytest.raises(AssertionError):
                loss(wrong_shape_x.to(device), x_dev)

    loss = InformationWeightedSSIMLoss(data_range=1., scale_weights=scale_weights)
    loss(x_dev, y_rand.to(device))
    wrong_scale_weights = torch.rand(2, 2)
    loss = InformationWeightedSSIMLoss(data_range=1., scale_weights=wrong_scale_weights)
    with pytest.raises(ValueError):
        loss(x_dev, y_rand.to(device))
예제 #7
0
def test_iw_ssim_loss_corresponds_to_matlab(test_images: List, device: str) -> None:
    """Compare the loss against ``1 - IW-SSIM`` scores taken from a MATLAB reference.

    ``test_images[0]`` is a grayscale (x, y) pair and ``test_images[1]`` an RGB pair;
    given ``data_range=255`` they are presumably in [0, 255] -- TODO confirm against the fixture.
    """
    x_gray, y_gray = test_images[0]
    x_rgb, y_rgb = test_images[1]
    # The reference values are IW-SSIM scores; the loss is compared against 1 - score.
    matlab_gray = 1 - torch.tensor(0.886297251092821, device=device)
    matlab_rgb = 1 - torch.tensor(0.946804801436296, device=device)

    loss = InformationWeightedSSIMLoss(data_range=255)
    score_gray = loss(x_gray.to(device), y_gray.to(device))

    assert torch.isclose(score_gray, matlab_gray, atol=1e-5),\
        f'Expected {matlab_gray:.8f}, got {score_gray:.8f} for gray scale case.'

    score_rgb = loss(x_rgb.to(device), y_rgb.to(device))

    assert torch.isclose(score_rgb, matlab_rgb, atol=1e-5),\
        f'Expected {matlab_rgb:.8f}, got {score_rgb:.8f} for rgb case.'
예제 #8
0
def test_iw_ssim_loss_backprop(x_rand: torch.Tensor, y_rand: torch.Tensor, device: str) -> None:
    """Backpropagating through the loss must populate a finite gradient on the input."""
    # Gradients are tracked on the fixture tensor itself; ``.to(device)`` is
    # differentiable, so they flow back to ``x_rand`` whatever ``device`` is.
    x_rand.requires_grad_(True)
    loss = InformationWeightedSSIMLoss(data_range=1.)
    loss_value = loss(x_rand.to(device), y_rand.to(device))
    loss_value.backward()
    # Fail with a clear message rather than a TypeError from torch.isfinite(None).
    assert x_rand.grad is not None, 'Expected gradient to be populated for the input tensor.'
    assert torch.isfinite(x_rand.grad).all(), f'Expected finite gradient values, got {x_rand.grad}.'
예제 #9
0
def test_iw_ssim_loss_preserves_dtype(x_rand: torch.Tensor, y_rand: torch.Tensor, dtype: torch.dtype,
                                      device: str) -> None:
    """The loss output must have the same dtype as its (cast) inputs."""
    x_cast, y_cast = (t.to(device=device, dtype=dtype) for t in (x_rand, y_rand))
    result = InformationWeightedSSIMLoss(data_range=1.)(x_cast, y_cast)
    assert result.dtype == dtype
예제 #10
0
def test_iw_ssim_loss_raises_if_tensors_have_different_types(x_rand: torch.Tensor, device: str) -> None:
    """Passing a plain list instead of a tensor as the first input must raise ``AssertionError``."""
    loss = InformationWeightedSSIMLoss(data_range=1.)
    not_a_tensor = [*range(10)]
    with pytest.raises(AssertionError):
        loss(not_a_tensor, x_rand.to(device))