def test_iw_ssim_reduction(x_rand: torch.Tensor, y_rand: torch.Tensor, device: str) -> None:
    """Check that valid ``reduction`` values are accepted and invalid ones raise ``ValueError``."""
    x, y = x_rand.to(device), y_rand.to(device)

    # Each of these reduction modes must run without raising.
    for valid_mode in ('mean', 'sum', 'none'):
        information_weighted_ssim(x, y, reduction=valid_mode)

    # Any other value (wrong string, wrong type, ``None``) is expected to be rejected.
    for invalid_mode in (None, 'n', 2):
        with pytest.raises(ValueError):
            information_weighted_ssim(x, y, reduction=invalid_mode)
def test_iw_ssim_raises_if_kernel_size_greater_than_image(x_rand: torch.Tensor, y_rand: torch.Tensor,
                                                          device: str) -> None:
    """Check that inputs one pixel smaller than the computed minimum size raise ``ValueError``."""
    kernel_size, levels = 11, 5
    # Minimum spatial size this test derives from the kernel size and the number of levels.
    # NOTE(review): ``levels`` presumably mirrors the metric's default number of scales -- confirm.
    min_size = 1 + (kernel_size - 1) * 2 ** (levels - 1)
    too_small = min_size - 1

    # Crop both spatial dimensions to just below the minimum size.
    cropped_x = x_rand[:, :, :too_small, :too_small].to(device)
    cropped_y = y_rand[:, :, :too_small, :too_small].to(device)

    with pytest.raises(ValueError):
        information_weighted_ssim(cropped_x, cropped_y, kernel_size=kernel_size)
def test_iw_ssim_supports_different_data_ranges(x_rand: torch.Tensor, y_rand: torch.Tensor, data_range: int,
                                                device: str) -> None:
    """Check that the score is the same (within 1e-6) for scaled inputs and their rescaled copies.

    The inputs are multiplied by ``data_range`` and cast to ``uint8``; the metric is then evaluated
    once on those tensors with ``data_range=data_range`` and once on the same tensors divided by
    ``data_range`` with ``data_range=1.0``.
    """
    # Both evaluations start from the very same uint8 tensors.
    x_scaled = (x_rand * data_range).to(torch.uint8).to(device)
    y_scaled = (y_rand * data_range).to(torch.uint8).to(device)

    measure_scaled = information_weighted_ssim(x_scaled, y_scaled, data_range=data_range)

    divisor = float(data_range)
    measure = information_weighted_ssim(x_scaled / divisor, y_scaled / divisor, data_range=1.0)

    diff = (measure_scaled - measure).abs()
    assert (diff <= 1e-6).all(), f'Result for same tensor with different data_range should be the same, got {diff}'
def test_iw_ssim_corresponds_to_matlab(test_images: List, device: str) -> None:
    """Compare IW-SSIM scores with hard-coded reference values for a gray and an RGB image pair.

    ``test_images[0]`` is unpacked as the gray-scale pair and ``test_images[1]`` as the RGB pair.
    The reference constants are named ``matlab_*``; presumably they were produced by the original
    MATLAB implementation (confirm against the source of the numbers).
    """
    x_gray, y_gray = test_images[0]
    x_rgb, y_rgb = test_images[1]
    matlab_gray = torch.tensor(0.886297251092821, device=device)
    matlab_rgb = torch.tensor(0.946804801436296, device=device)

    score_gray = information_weighted_ssim(x_gray.to(device), y_gray.to(device), data_range=255)

    # Report 8 decimals: the check uses atol=1e-5, so a 4-decimal message could print identical
    # "expected" and "got" values for a failing comparison.
    assert torch.isclose(score_gray, matlab_gray, atol=1e-5), \
        f'Expected {matlab_gray:.8f}, got {score_gray:.8f} for gray scale case.'

    score_rgb = information_weighted_ssim(x_rgb.to(device), y_rgb.to(device), data_range=255)

    assert torch.isclose(score_rgb, matlab_rgb, atol=1e-5), \
        f'Expected {matlab_rgb:.8f}, got {score_rgb:.8f} for rgb case.'
def test_iw_ssim_measure_is_one_for_equal_tensors(x_rand: torch.Tensor, device: str) -> None:
    """Check that IW-SSIM of a tensor with an exact copy of itself is (numerically) equal to 1."""
    x_rand = x_rand.to(device)
    y_rand = x_rand.clone()
    measure = information_weighted_ssim(x_rand, y_rand, data_range=1.)
    # NOTE(review): ``torch.allclose`` is called with its default tolerances, which may not match
    # the 1e-6 quoted in the message below -- confirm which one is intended.
    assert torch.allclose(measure, torch.ones_like(measure)), \
        'If equal tensors are passed IW-SSIM must be equal to 1 ' \
        f'(considering floating point operation error up to 1 * 10^-6), got {measure}'
def test_iw_ssim_raises_if_tensors_have_different_shapes(x_rand: torch.Tensor, y_rand: torch.Tensor,
                                                         scale_weights: torch.Tensor, device: str) -> None:
    """Check shape validation of the inputs and of ``scale_weights``.

    For every candidate shape built from ``dims``, the call must succeed when the shape equals
    that of ``x_rand`` and raise ``AssertionError`` otherwise. Afterwards the ``scale_weights``
    fixture must be accepted, while a ``(2, 2)`` tensor of weights must raise ``ValueError``.
    """
    dims = [[3], [2, 3], [160, 161], [160, 161]]

    for size in itertools.product(*dims):
        # ``.to(x_rand)`` matches dtype and device of the reference tensor.
        wrong_shape_x = torch.rand(size).to(x_rand)
        if wrong_shape_x.size() == x_rand.size():
            # Same shape as the reference: the call is expected to go through.
            information_weighted_ssim(wrong_shape_x.to(device), x_rand.to(device))
        else:
            with pytest.raises(AssertionError):
                information_weighted_ssim(wrong_shape_x.to(device), x_rand.to(device))

    information_weighted_ssim(x_rand.to(device), y_rand.to(device), scale_weights=scale_weights.to(device))

    wrong_scale_weights = torch.rand(2, 2)
    with pytest.raises(ValueError):
        information_weighted_ssim(x_rand.to(device), y_rand.to(device), scale_weights=wrong_scale_weights.to(device))
def test_iw_ssim_raises_if_tensors_have_different_types(x_rand: torch.Tensor, device: str) -> None:
    """Check that passing a plain Python list instead of a tensor raises ``AssertionError``."""
    not_a_tensor = [*range(10)]
    reference = x_rand.to(device)
    with pytest.raises(AssertionError):
        information_weighted_ssim(not_a_tensor, reference)
def test_iw_ssim_preserves_dtype(x_rand: torch.Tensor, y_rand: torch.Tensor, dtype: torch.dtype, device: str) -> None:
    """Check that the output dtype equals the dtype the inputs were cast to."""
    target = dict(device=device, dtype=dtype)
    x = x_rand.to(**target)
    y = y_rand.to(**target)
    result = information_weighted_ssim(x, y)
    assert result.dtype == dtype
def test_iw_ssim_fails_for_incorrect_data_range(x_rand: torch.Tensor, y_rand: torch.Tensor, device: str) -> None:
    """Check that uint8 inputs multiplied by 255 raise ``AssertionError`` when ``data_range=1.0`` is passed."""
    # Scale to [0, 255] (assumes the fixtures hold values in [0, 1] -- TODO confirm).
    x_scaled, y_scaled = ((tensor * 255).to(torch.uint8).to(device) for tensor in (x_rand, y_rand))
    with pytest.raises(AssertionError):
        information_weighted_ssim(x_scaled, y_scaled, data_range=1.0)