Example #1
    def test_kernel_5x5_batch(self, device, dtype):
        batch_size = 3
        inp = torch.tensor(
            [
                [
                    [
                        [1.0, 1.0, 1.0, 1.0, 1.0],
                        [1.0, 1.0, 1.0, 1.0, 1.0],
                        [1.0, 1.0, 1.0, 1.0, 1.0],
                        [2.0, 2.0, 2.0, 2.0, 2.0],
                        [2.0, 2.0, 2.0, 2.0, 2.0],
                    ]
                ]
            ],
            device=device,
            dtype=dtype,
        ).repeat(batch_size, 1, 1, 1)

        kernel_size = (5, 5)
        expected = inp.sum((1, 2, 3)) / torch.mul(*kernel_size)

        actual = kornia.filters.box_blur(inp, kernel_size)

        tol_val: float = utils._get_precision_by_name(device, 'xla', 1e-1, 1e-4)
        assert_allclose(actual[:, 0, 2, 2], expected, rtol=tol_val, atol=tol_val)
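
These snippets all call a tolerance helper, utils._get_precision_by_name, whose definition is not shown. The sketch below is a hypothetical reconstruction based only on the call sites: the assumption is that it returns the looser tolerance when the current device matches the named backend (here 'xla') and a tighter default otherwise.

import torch


def _get_precision_by_name(device: torch.device, device_name: str,
                           tol_on_device: float, tol_default: float) -> float:
    # Hypothetical helper: use a looser comparison tolerance on the named
    # backend (e.g. XLA/TPU), otherwise fall back to the tighter default.
    if device.type == device_name:
        return tol_on_device
    return tol_default

Under this assumption, _get_precision_by_name(device, 'xla', 1e-1, 1e-4) yields 1e-1 on an XLA device and 1e-4 everywhere else, which matches how every example on this page uses it.
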
Example #2
    def test_normalized_mean_filter(self, device, dtype):
        kernel = torch.ones(1, 3, 3).to(device)
        input = torch.tensor(
            [[[
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 5.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
            ]]],
            device=device,
            dtype=dtype,
        ).expand(2, 2, -1, -1)

        nv: float = 5.0 / 9  # normalization value
        expected = torch.tensor(
            [[[
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, nv, nv, nv, 0.0],
                [0.0, nv, nv, nv, 0.0],
                [0.0, nv, nv, nv, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
            ]]],
            device=device,
            dtype=dtype,
        ).expand(2, 2, -1, -1)

        actual = kornia.filter2d(input, kernel, normalized=True)

        tol_val: float = utils._get_precision_by_name(device, 'xla', 1e-1,
                                                      1e-4)
        assert_allclose(actual, expected, rtol=tol_val, atol=tol_val)
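
With normalized=True, filter2d divides the kernel by its sum before convolving, so the 3x3 ones kernel acts as a mean filter and the single 5.0 pixel is spread as 5/9 over its 3x3 neighbourhood, which is exactly the expected tensor above. A quick standalone sanity check of that behaviour, assuming the same kornia.filter2d entry point used in this example:

import torch
import kornia

kernel = torch.ones(1, 3, 3)
inp = torch.zeros(1, 1, 5, 5)
inp[..., 2, 2] = 5.0

# normalized=True divides the kernel by its sum (9), i.e. a mean filter.
out = kornia.filter2d(inp, kernel, normalized=True)
print(out[0, 0, 2, 2].item())  # approx. 5 / 9 = 0.5556
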
Example #3
    def test_forth_and_back(self, device, dtype):
        data = torch.rand(3, 4, 5, device=device, dtype=dtype)
        lab = kornia.color.rgb_to_lab
        rgb = kornia.color.lab_to_rgb

        unclipped_data_out = rgb(lab(data), clip=False)
        tol_val: float = utils._get_precision_by_name(device, 'xla', 1e-1,
                                                      1e-4)
        assert_allclose(unclipped_data_out, data, rtol=tol_val, atol=tol_val)
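
The round trip above is checked with clip=False so that any out-of-gamut RGB values produced by lab_to_rgb are kept and the inverse stays exact; with the default clipping those values would be clamped to [0, 1] (compare expected and expected_unclipped in Example #12 below). The same check as a standalone sketch, assuming only the kornia.color functions already used here:

import torch
import kornia

rgb = torch.rand(3, 4, 5)  # CxHxW image with values in [0, 1]
lab = kornia.color.rgb_to_lab(rgb)
# clip=False preserves out-of-gamut values so the inverse is exact up to
# floating-point error.
rgb_back = kornia.color.lab_to_rgb(lab, clip=False)
print(torch.allclose(rgb, rgb_back, atol=1e-4))
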
Example #4
    def test_ssim(self, device, dtype, batch_shape, window_size, reduction_type):
        # input data
        img = torch.rand(batch_shape, device=device, dtype=dtype)

        loss = kornia.losses.ssim_loss(
            img, img, window_size, reduction=reduction_type)

        tol_val: float = utils._get_precision_by_name(device, 'xla', 1e-1, 1e-4)
        assert_allclose(loss.item(), 0.0, rtol=tol_val, atol=tol_val)
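
Tests like this one receive batch_shape, window_size and reduction_type as arguments, which points to pytest parametrization plus device/dtype fixtures supplied by the suite's conftest. A hedged sketch of how such a test could be parametrized as a standalone function; the parameter values and the test name are illustrative assumptions, not taken from the kornia test suite:

import pytest
import torch
import kornia


@pytest.mark.parametrize("batch_shape", [(1, 1, 10, 16), (2, 3, 8, 8)])
@pytest.mark.parametrize("window_size", [5, 11])
@pytest.mark.parametrize("reduction_type", ["mean", "sum"])
def test_ssim_self_similarity(batch_shape, window_size, reduction_type):
    # An image compared against itself should give (near) zero SSIM loss.
    img = torch.rand(batch_shape)
    loss = kornia.losses.ssim_loss(img, img, window_size, reduction=reduction_type)
    assert abs(loss.item()) < 1e-4
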
Example #5
    def test_unit(self, device, dtype):
        data = torch.tensor(
            [[[0.13088040, 0.54399723, 0.69396782, 0.63581685, 0.09902618],
              [0.59459005, 0.74215373, 0.89662376, 0.25920381, 0.89937686],
              [0.29857584, 0.28139791, 0.16441015, 0.55507519, 0.06124221],
              [0.40908658, 0.10261389, 0.01691456, 0.76006799, 0.32971736]],
             [[0.60354551, 0.76201361, 0.79009938, 0.91742945, 0.60044175],
              [0.42812678, 0.18552390, 0.04186043, 0.38030245, 0.15420346],
              [0.13552373, 0.53955473, 0.79102736, 0.49050815, 0.75271446],
              [0.39861023, 0.80680277, 0.82823833, 0.54438462, 0.22063386]],
             [[0.63231256, 0.18316011, 0.84317145, 0.59529881, 0.15297393],
              [0.59235313, 0.36617295, 0.34600773, 0.40304737, 0.61720451],
              [0.46040250, 0.42006640, 0.54765106, 0.48982632, 0.13914755],
              [0.58402964, 0.89597990, 0.98276161, 0.25019163, 0.69285921]]],
            device=device,
            dtype=dtype)

        # Reference output generated using skimage: rgb2lab(data)
        expected = torch.tensor(
            [[[58.02612517, 72.48876064, 79.75208576, 86.38913217, 55.25164186],
              [51.66668553, 43.81214392, 48.93865503, 39.03804484, 52.55152607],
              [23.7114063, 52.38661792, 72.54607218, 53.89587489, 67.94892652],
              [45.02897165, 75.98315061, 78.257619, 61.85069778, 33.77972627]],
             [[-28.63524124, -39.21783796, -5.40909568, -37.74958445, -55.02172792],
              [24.16049084, 58.53088654, 75.33566652, -9.65827726, 76.94753157],
              [36.53113547, -28.57665427, -54.16269089, 6.2586262, -67.69290198],
              [12.32708756, -33.04781428, -29.29282657, 13.46090338, 42.98737069]],
             [[-12.99941502, 63.48788307, -9.49591204, 32.9931831, 47.80929165],
              [-16.11189945, 7.72083678, 19.17820444, -6.90801653, -17.46468994],
              [-39.99097133, 9.92432127, 19.90687976, 2.40429413, 61.24066709],
              [-25.45166461, -22.94347485, -31.32259433, 47.2621717, -60.05694598]]],
            device=device,
            dtype=dtype)

        tol_val: float = utils._get_precision_by_name(device, 'xla', 1e-1,
                                                      1e-4)
        assert_allclose(kornia.color.rgb_to_lab(data),
                        expected,
                        rtol=tol_val,
                        atol=tol_val)
Example #6
    def test_ssim(self, device, dtype, batch_shape, window_size, reduction_type):
        if device.type == 'xla':
            pytest.skip("test highly unstable with tpu")

        # input data
        img = torch.rand(batch_shape, device=device, dtype=dtype)

        loss = kornia.losses.ssim_loss(img, img, window_size, reduction=reduction_type)

        tol_val: float = utils._get_precision_by_name(device, 'xla', 1e-1, 1e-4)
        assert_close(loss.item(), 0.0, rtol=tol_val, atol=tol_val)
Example #7
    def test_ssim_equal_none(self, device, dtype):
        # input data
        img1 = torch.rand(1, 1, 10, 16, device=device, dtype=dtype)
        img2 = torch.rand(1, 1, 10, 16, device=device, dtype=dtype)

        ssim1 = kornia.losses.ssim_loss(img1, img1, window_size=5, reduction="none")
        ssim2 = kornia.losses.ssim_loss(img2, img2, window_size=5, reduction="none")

        tol_val: float = utils._get_precision_by_name(device, 'xla', 1e-1, 1e-4)
        assert_close(ssim1, torch.zeros_like(img1), rtol=tol_val, atol=tol_val)
        assert_close(ssim2, torch.zeros_like(img2), rtol=tol_val, atol=tol_val)
Example #8
    def test_identity(self, input_shape, eps, device, dtype):
        """Assert that data can be recovered by the inverse transform."""

        data = torch.randn(*input_shape, device=device, dtype=dtype)

        zca = kornia.enhance.ZCAWhitening(compute_inv=True, eps=eps).fit(data)

        data_w = zca(data)

        data_hat = zca.inverse_transform(data_w)

        tol_val: float = utils._get_precision_by_name(device, 'xla', 1e-1,
                                                      1e-4)
        assert_close(data, data_hat, rtol=tol_val, atol=tol_val)
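
For context, ZCA whitening is built from the eigendecomposition of the data covariance, W = E diag((lambda + eps)^(-1/2)) E^T, and the inverse transform uses the reciprocal scaling, which is why fitting with compute_inv=True lets inverse_transform recover the input up to numerical error. A plain-torch sketch of that math (an illustration of the idea only, not kornia's implementation):

import torch

x = torch.randn(100, 8)            # N samples, D features
mean = x.mean(dim=0, keepdim=True)
xc = x - mean

eps = 1e-6
cov = xc.T @ xc / (x.shape[0] - 1)
evals, evecs = torch.linalg.eigh(cov)

w = evecs @ torch.diag((evals + eps).rsqrt()) @ evecs.T      # whitening matrix
w_inv = evecs @ torch.diag((evals + eps).sqrt()) @ evecs.T   # its inverse

x_white = xc @ w
x_back = x_white @ w_inv + mean
print(torch.allclose(x, x_back, atol=1e-4))
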
Example #9
    def test_kernel_3x3_nonormalize(self, device, dtype):
        inp = torch.tensor(
            [[[
                [1.0, 1.0, 1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0, 1.0, 1.0],
                [2.0, 2.0, 2.0, 2.0, 2.0],
                [2.0, 2.0, 2.0, 2.0, 2.0],
            ]]],
            device=device,
            dtype=dtype)

        kernel_size = (3, 3)
        actual = kornia.filters.box_blur(inp, kernel_size, normalized=False)

        tol_val: float = utils._get_precision_by_name(device, 'xla', 1e-1,
                                                      1e-4)
        assert_allclose(actual.sum(),
                        torch.tensor(35.).to(actual),
                        rtol=tol_val,
                        atol=tol_val)
Example #10
    def test_normalized_mean_filter(self, padding, device, dtype):
        kernel = torch.ones(1, 3, 3).to(device)
        input = torch.tensor(
            [[[
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 5.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
            ]]],
            device=device,
            dtype=dtype,
        ).expand(2, 2, -1, -1)

        nv: float = 5.0 / 9  # normalization value
        expected_same = torch.tensor(
            [[[
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, nv, nv, nv, 0.0],
                [0.0, nv, nv, nv, 0.0],
                [0.0, nv, nv, nv, 0.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
            ]]],
            device=device,
            dtype=dtype,
        ).expand(2, 2, -1, -1)

        expected_valid = torch.tensor(
            [[[[nv, nv, nv], [nv, nv, nv], [nv, nv, nv]]]],
            device=device,
            dtype=dtype).expand(2, 2, -1, -1)

        actual = kornia.filters.filter2d(input,
                                         kernel,
                                         normalized=True,
                                         padding=padding)

        tol_val: float = utils._get_precision_by_name(device, 'xla', 1e-1,
                                                      1e-4)
        if padding == 'same':
            assert_close(actual, expected_same, rtol=tol_val, atol=tol_val)
        else:
            assert_close(actual, expected_valid, rtol=tol_val, atol=tol_val)
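
The padding argument decides the output size: 'same' pads the input so the result stays 5x5, while 'valid' only applies the kernel where it fully fits, shrinking the output to (5 - 3 + 1) x (5 - 3 + 1) = 3x3, which is why expected_valid above is a 3x3 block of nv. A small shape check under that assumption, using the kornia.filters.filter2d call from the test:

import torch
import kornia

inp = torch.zeros(1, 1, 5, 5)
kernel = torch.ones(1, 3, 3)

# 'same' keeps the spatial size; 'valid' drops the border positions where the
# kernel would not fully overlap the input.
out_same = kornia.filters.filter2d(inp, kernel, padding='same')
out_valid = kornia.filters.filter2d(inp, kernel, padding='valid')
print(tuple(out_same.shape))   # (1, 1, 5, 5)
print(tuple(out_valid.shape))  # (1, 1, 3, 3)
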
Example #11
    def test_kernel_3x3(self, device, dtype):
        inp = torch.tensor(
            [[[
                [1.0, 1.0, 1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0, 1.0, 1.0],
                [1.0, 1.0, 1.0, 1.0, 1.0],
                [2.0, 2.0, 2.0, 2.0, 2.0],
                [2.0, 2.0, 2.0, 2.0, 2.0],
            ]]],
            device=device,
            dtype=dtype,
        )

        kernel_size = (3, 3)
        actual = kornia.filters.box_blur(inp, kernel_size)

        tol_val: float = utils._get_precision_by_name(device, 'xla', 1e-1,
                                                      1e-4)
        assert_close(actual.sum(),
                     torch.tensor(35.0).to(actual),
                     rtol=tol_val,
                     atol=tol_val)
Example #12
    def test_unit(self, device, dtype):
        data = torch.tensor(
            [[[[50.21928787, 23.29810143, 14.98279190, 62.50927353, 72.78904724],
               [70.86846924, 68.75330353, 52.81696701, 76.17090607, 88.63134003],
               [46.87160873, 72.38699341, 37.71450806, 82.57386780, 74.79967499],
               [77.33016968, 47.39180374, 61.76217651, 90.83254242, 86.96239471]],
              [[65.81327057, -3.69859719, 0.16971001, 14.86583614, -65.54960632],
               [-41.03258133, -19.52661896, 64.16155243, -58.53935242, -71.78411102],
               [112.05227661, -60.13330460, 43.07910538, -51.01456833, -58.25787354],
               [-62.37575531, 50.88882065, -39.27450943, 17.00958824, -24.93779755]],
              [[-69.53346252, -73.34986877, -11.47461891, 66.73863220, 70.43983459],
               [51.92737579, 58.77009583, 45.97863388, 24.44452858, 98.81991577],
               [-7.60597992, 78.97976685, -69.31867218, 67.33953857, 14.28889370],
               [92.31149292, -85.91405487, -32.83668518, -23.45091820, 69.99038696]]]],
            device=device,
            dtype=dtype,
        )

        # Reference output generated using skimage: lab2rgb(data)
        expected = torch.tensor(
            [[[[0.63513142, 0.0, 0.10660624, 0.79048697, 0.26823414],
               [0.48903025, 0.64529494, 0.91140099, 0.15877841, 0.45987959],
               [1.0, 0.36069696, 0.29236125, 0.55744393, 0.0],
               [0.41710863, 0.3198324, 0.0, 0.94256868, 0.82748892]],
              [[0.28210726, 0.26080003, 0.15027717, 0.54540429, 0.80323837],
               [0.748392, 0.68774842, 0.24204415, 0.83695682, 0.9902132],
               [0.0, 0.79101603, 0.26633725, 0.89223337, 0.82301254],
               [0.84857086, 0.34455393, 0.66555314, 0.86168397, 0.8948667]],
              [[0.94172458, 0.66390044, 0.21043296, 0.02453515, 0.04169043],
               [0.28233233, 0.20235374, 0.19803933, 0.55069441, 0.0],
               [0.50205101, 0.0, 0.79745394, 0.25376936, 0.6114783],
               [0.0, 1.0, 0.80867314, 1.0, 0.28778443]]]],
            device=device,
            dtype=dtype,
        )

        expected_unclipped = torch.tensor(
            [[[[0.63513142, -1.78708635, 0.10660624, 0.79048697, 0.26823414],
               [0.48903025, 0.64529494, 0.91140099, 0.15877841, 0.45987959],
               [1.01488435, 0.36069696, 0.29236125, 0.55744393, -0.28090181],
               [0.41710863, 0.3198324, -1.81087917, 0.94256868, 0.82748892]],
              [[0.28210726, 0.26080003, 0.15027717, 0.54540429, 0.80323837],
               [0.748392, 0.68774842, 0.24204415, 0.83695682, 0.9902132],
               [-1.37862046, 0.79101603, 0.26633725, 0.89223337, 0.82301254],
               [0.84857086, 0.34455393, 0.66555314, 0.86168397, 0.8948667]],
              [[0.94172458, 0.66390044, 0.21043296, 0.02453515, 0.04169043],
               [0.28233233, 0.20235374, 0.19803933, 0.55069441, -0.62707704],
               [0.50205101, -0.25005965, 0.79745394, 0.25376936, 0.6114783],
               [-0.55802926, 1.0223477, 0.80867314, 1.07334156, 0.28778443]]]],
            device=device,
            dtype=dtype,
        )

        tol_val: float = utils._get_precision_by_name(device, 'xla', 1e-1,
                                                      1e-4)
        assert_allclose(kornia.color.lab_to_rgb(data),
                        expected,
                        rtol=tol_val,
                        atol=tol_val)
        assert_allclose(kornia.color.lab_to_rgb(data, clip=False),
                        expected_unclipped,
                        rtol=tol_val,
                        atol=tol_val)