コード例 #1
0
def test_ssim_raises_if_kernel_size_greater_than_image() -> None:
    right_kernel_sizes = list(range(1, 52, 2))
    for kernel_size in right_kernel_sizes:
        wrong_size_prediction = torch.rand(3, 3, kernel_size - 1, kernel_size - 1)
        wrong_size_target = torch.rand(3, 3, kernel_size - 1, kernel_size - 1)
        with pytest.raises(ValueError):
            ssim(wrong_size_prediction, wrong_size_target, kernel_size=kernel_size)
コード例 #2
0
def test_ssim_raises_if_tensors_have_different_shapes(target: torch.Tensor) -> None:
    dims = [[3], [2, 3], [255, 256], [255, 256]]
    for b, c, h, w in list(itertools.product(*dims)):
        wrong_shape_prediction = torch.rand(b, c, h, w)
        if wrong_shape_prediction.size() == target.size():
            try:
                ssim(wrong_shape_prediction, target)
            except Exception as e:
                pytest.fail(f"Unexpected error occurred: {e}")
        else:
            with pytest.raises(AssertionError):
                ssim(wrong_shape_prediction, target)
コード例 #3
0
def test_ssim_raise_if_wrong_value_is_estimated(prediction: torch.Tensor, target: torch.Tensor) -> None:
    photosynthesis_ssim = ssim(prediction, target, kernel_size=11, kernel_sigma=1.5, data_range=1., size_average=False)
    tf_prediction = tf.convert_to_tensor(prediction.permute(0, 2, 3, 1).numpy())
    tf_target = tf.convert_to_tensor(target.permute(0, 2, 3, 1).numpy())
    tf_ssim = torch.tensor(tf.image.ssim(tf_prediction, tf_target, max_val=1.).numpy())
    assert torch.isclose(photosynthesis_ssim, tf_ssim, atol=1e-6).all(), \
        f'The estimated value must be equal to tensorflow provided one' \
        f'(considering floating point operation error up to 1 * 10^-6), ' \
        f'got difference {photosynthesis_ssim - tf_ssim}'
コード例 #4
0
def test_ssim_raises_if_wrong_kernel_size_is_passed(prediction: torch.Tensor, target: torch.Tensor) -> None:
    wrong_kernel_sizes = list(range(0, 50, 2))
    for kernel_size in wrong_kernel_sizes:
        with pytest.raises(AssertionError):
            ssim(prediction, target, kernel_size=kernel_size)
コード例 #5
0
def test_ssim_raises_if_tensors_have_different_types(target: torch.Tensor) -> None:
    wrong_type_prediction = list(range(10))
    with pytest.raises(AssertionError):
        ssim(wrong_type_prediction, target)
コード例 #6
0
def test_ssim_measure_is_less_or_equal_to_one_cuda() -> None:
    ones = torch.ones((3, 3, 256, 256)).cuda()
    zeros = torch.zeros((3, 3, 256, 256)).cuda()
    measure = ssim(ones, zeros, data_range=1.)
    assert measure <= 1, f'SSIM must be <= 1, got {measure}'
コード例 #7
0
def test_ssim_measure_is_less_or_equal_to_one() -> None:
    # Create two maximally different tensors.
    ones = torch.ones((3, 3, 256, 256))
    zeros = torch.zeros((3, 3, 256, 256))
    measure = ssim(ones, zeros, data_range=1.)
    assert measure <= 1, f'SSIM must be <= 1, got {measure}'
コード例 #8
0
def test_ssim_measure_is_zero_for_equal_tensors(target: torch.Tensor) -> None:
    prediction = target.clone()
    measure = ssim(prediction, target, data_range=1.)
    measure -= 1.
    assert measure.sum() <= 1e-6, f'If equal tensors are passed SSIM must be equal to 0 ' \
                                  f'(considering floating point operation error up to 1 * 10^-6), got {measure}'
コード例 #9
0
def test_ssim_symmetry(prediction: torch.Tensor, target: torch.Tensor) -> None:
    measure = ssim(prediction, target, data_range=1.)
    reverse_measure = ssim(target, prediction, data_range=1.)
    assert measure == reverse_measure, f'Expect: SSIM(a, b) == SSIM(b, a), got {measure} != {reverse_measure}'