def test_iw_ssim_loss_reduction(x_rand: torch.Tensor, y_rand: torch.Tensor, device: str) -> None:
    """Supported reduction modes must run; unsupported ones must raise ``ValueError``."""
    x = x_rand.to(device)
    y = y_rand.to(device)
    accepted = ('mean', 'sum', 'none')
    rejected = (None, 'n', 2)

    for reduction in accepted:
        InformationWeightedSSIMLoss(reduction=reduction)(x, y)

    for reduction in rejected:
        with pytest.raises(ValueError):
            InformationWeightedSSIMLoss(reduction=reduction)(x, y)
def test_iw_ssim_loss_fails_for_incorrect_data_range(x_rand: torch.Tensor, y_rand: torch.Tensor, device: str) -> None:
    """A loss configured with ``data_range=1.`` must reject uint8 inputs scaled to [0, 255]."""
    # Scale to [0, 255]
    x_scaled, y_scaled = ((t * 255).type(torch.uint8).to(device) for t in (x_rand, y_rand))
    loss = InformationWeightedSSIMLoss(data_range=1.)
    with pytest.raises(AssertionError):
        loss(x_scaled, y_scaled)
def test_iw_ssim_loss_supports_different_data_ranges(x_rand: torch.Tensor, y_rand: torch.Tensor, data_range: int, device: str) -> None:
    """The loss on data in ``[0, data_range]`` must match the loss on the same data rescaled to [0, 1]."""
    x_scaled = (x_rand * data_range).type(torch.uint8).to(device)
    y_scaled = (y_rand * data_range).type(torch.uint8).to(device)
    divisor = float(data_range)

    measure_scaled = InformationWeightedSSIMLoss(data_range=data_range)(x_scaled, y_scaled)
    # Same integer data, but normalised so that data_range=1. applies.
    measure = InformationWeightedSSIMLoss(data_range=1.)(x_scaled / divisor, y_scaled / divisor)

    diff = (measure_scaled - measure).abs()
    assert (diff <= 1e-6).all(), f'Result for same tensor with different data_range should be the same, got {diff}'
def test_iw_ssim_loss_is_one_for_equal_tensors(x_rand: torch.Tensor, device: str) -> None:
    """Identical inputs must give a loss of 0.

    IW-SSIM of identical tensors is 1 (hence the test name). The loss is
    ``1 - IW-SSIM``, so the value actually checked here is 0.
    """
    x = x_rand.to(device)
    y = x.clone()
    loss = InformationWeightedSSIMLoss(data_range=1.)
    measure = loss(x, y)
    assert torch.allclose(measure, torch.zeros_like(measure), atol=1e-5), \
        f'If equal tensors are passed IW-SSIM loss must be equal to 0 ' \
        f'(considering floating point operation error up to 1 * 10^-5), got {measure}'
def test_iw_ssim_loss_raises_if_kernel_size_greater_than_image(x_rand: torch.Tensor, y_rand: torch.Tensor, device: str) -> None:
    """Inputs one pixel below the minimum spatial size for the kernel/levels combination must raise ``ValueError``."""
    kernel_size, levels = 11, 5
    # Minimum side length per the formula below; presumably each level halves the
    # resolution, so the kernel must still fit at the coarsest scale.
    min_size = (kernel_size - 1) * 2 ** (levels - 1) + 1
    side = min_size - 1

    x_small = x_rand[:, :, :side, :side].to(device)
    y_small = y_rand[:, :, :side, :side].to(device)

    loss = InformationWeightedSSIMLoss(data_range=1., kernel_size=kernel_size)
    with pytest.raises(ValueError):
        loss(x_small, y_small)
def test_iw_ssim_loss_raises_if_tensors_have_different_shapes(x_rand: torch.Tensor, y_rand: torch.Tensor, scale_weights: torch.Tensor, device: str) -> None:
    """Mismatched input shapes and malformed ``scale_weights`` must be rejected.

    Every shape from the cartesian product of ``dims`` is tried against ``x_rand``.
    Only the shape equal to ``x_rand.size()`` may succeed; all others must raise
    ``AssertionError``. Afterwards, the ``scale_weights`` fixture must be accepted,
    while a random (2, 2) weight tensor must raise ``ValueError``.
    """
    dims = [[3], [2, 3], [160, 161], [160, 161]]
    loss = InformationWeightedSSIMLoss(data_range=1.)
    reference = x_rand.to(device)  # loop-invariant, hoisted out of the loop
    for size in itertools.product(*dims):
        # `.to(x_rand)` matches x_rand's dtype/device before moving to the test device.
        wrong_shape_x = torch.rand(size).to(x_rand).to(device)
        if wrong_shape_x.size() == x_rand.size():
            loss(wrong_shape_x, reference)
        else:
            with pytest.raises(AssertionError):
                loss(wrong_shape_x, reference)

    # The scale weights supplied by the fixture are accepted.
    loss = InformationWeightedSSIMLoss(data_range=1., scale_weights=scale_weights)
    loss(x_rand.to(device), y_rand.to(device))

    # A (2, 2) tensor of weights is rejected.
    wrong_scale_weights = torch.rand(2, 2)
    loss = InformationWeightedSSIMLoss(data_range=1., scale_weights=wrong_scale_weights)
    with pytest.raises(ValueError):
        loss(x_rand.to(device), y_rand.to(device))
def test_iw_ssim_loss_corresponds_to_matlab(test_images: List, device: str):
    """Compare the loss with reference IW-SSIM scores from MATLAB for a gray and an RGB pair."""
    (x_gray, y_gray), (x_rgb, y_rgb) = test_images[0], test_images[1]
    # Reference values are similarity scores; the loss is 1 - IW-SSIM.
    matlab_gray = 1 - torch.tensor(0.886297251092821, device=device)
    matlab_rgb = 1 - torch.tensor(0.946804801436296, device=device)
    loss = InformationWeightedSSIMLoss(data_range=255)

    cases = (
        (x_gray, y_gray, matlab_gray, 'gray scale'),
        (x_rgb, y_rgb, matlab_rgb, 'rgb'),
    )
    for x, y, expected, label in cases:
        score = loss(x.to(device), y.to(device))
        assert torch.isclose(score, expected, atol=1e-5), \
            f'Expected {expected:.8f}, got {score:.8f} for {label} case.'
def test_iw_ssim_loss_backprop(x_rand: torch.Tensor, y_rand: torch.Tensor, device: str) -> None:
    """Gradients of the loss w.r.t. the first input must be populated and finite."""
    # Use a detached copy rather than calling requires_grad_ on the fixture itself:
    # mutating it (requires_grad flag, accumulated .grad) could leak into other tests
    # if the fixture has a scope wider than a single test.
    x = x_rand.detach().clone().to(device).requires_grad_(True)
    loss = InformationWeightedSSIMLoss(data_range=1.)
    score = loss(x, y_rand.to(device))
    score.backward()
    # Check for None first; torch.isfinite(None) would fail with an unhelpful TypeError.
    assert x.grad is not None, 'Expected gradient to be populated for the input tensor.'
    assert torch.isfinite(x.grad).all(), f'Expected finite gradient values, got {x.grad}.'
def test_iw_ssim_loss_preserves_dtype(x_rand: torch.Tensor, y_rand: torch.Tensor, dtype: torch.dtype, device: str) -> None:
    """The loss output must have the same dtype as its inputs."""
    inputs = [t.to(device=device, dtype=dtype) for t in (x_rand, y_rand)]
    output = InformationWeightedSSIMLoss(data_range=1.)(*inputs)
    assert output.dtype == dtype
def test_iw_ssim_loss_raises_if_tensors_have_different_types(x_rand: torch.Tensor, device: str) -> None:
    """Passing a plain Python list in place of a tensor must raise ``AssertionError``."""
    not_a_tensor = list(range(10))
    reference = x_rand.to(device)
    loss = InformationWeightedSSIMLoss(data_range=1.)
    with pytest.raises(AssertionError):
        loss(not_a_tensor, reference)