Esempio n. 1
0
def get_test_devices() -> Dict[str, torch.device]:
    """Build the mapping of device names to devices the tests should run on.

    The CPU is always included. A CUDA device (``cuda:0``) is added only when
    ``torch.cuda.is_available()`` reports one, and a TPU entry is added only
    when ``kornia.xla_is_available()`` is true.

    Return:
        dict(str, torch.device): mapping from device name (``"cpu"``,
        ``"cuda"``, ``"tpu"``) to the corresponding device.
    """
    available: Dict[str, torch.device] = {"cpu": torch.device("cpu")}
    if torch.cuda.is_available():
        available["cuda"] = torch.device("cuda:0")
    if kornia.xla_is_available():
        # Imported lazily: torch_xla is only importable when XLA is present.
        import torch_xla.core.xla_model as xm
        available["tpu"] = xm.xla_device()
    return available
Esempio n. 2
0
class TestPosterize(BaseTester):
    """Tests for ``kornia.enhance.posterize``."""

    # Function under test.
    f = kornia.enhance.posterize

    def test_smoke(self, device, dtype):
        """Posterize returns a tensor for a random batch with ``bits=8``."""
        B, C, H, W = 2, 3, 4, 5
        img = torch.rand(B, C, H, W, device=device, dtype=dtype)
        assert isinstance(TestPosterize.f(img, 8), torch.Tensor)

    @pytest.mark.parametrize(
        "batch_size, height, width, bits",
        [
            (1, 4, 5, 8),
            (2, 4, 5, 1),
            (2, 4, 5, 0),
            (1, 4, 5, torch.tensor(8)),
            (2, 4, 5, torch.tensor(8)),
            # In the next cases the bits tensor has one entry per batch
            # element (presumably per-sample bits -- confirm against posterize).
            (2, 4, 5, torch.tensor([0, 8])),
            (3, 4, 5, torch.tensor([0, 1, 8])),
        ],
    )
    @pytest.mark.parametrize("channels", [1, 3, 5])
    def test_cardinality(self, batch_size, channels, height, width, bits,
                         device, dtype):
        """Output shape equals input shape for int and tensor ``bits``."""
        inputs = torch.ones(batch_size,
                            channels,
                            height,
                            width,
                            device=device,
                            dtype=dtype)
        assert TestPosterize.f(inputs, bits).shape == torch.Size(
            [batch_size, channels, height, width])

    # TODO(jian): add better assertions
    def test_exception(self, device, dtype):
        """A non-tensor image or a float ``bits`` raises ``TypeError``."""
        img = torch.ones(2, 3, 4, 5, device=device, dtype=dtype)

        with pytest.raises(TypeError):
            assert TestPosterize.f([1.0], 0.0)

        with pytest.raises(TypeError):
            assert TestPosterize.f(img, 1.0)

    # TODO(jian): add better cases
    @pytest.mark.skipif(kornia.xla_is_available(),
                        reason="issues with xla device")
    def test_value(self, device, dtype):
        """Compare posterize against reference values for bits 1, 0 and 8."""
        torch.manual_seed(0)

        # Sample on the default device first so the seeded values do not
        # depend on the device under test.
        inputs = torch.rand(1, 1, 3, 3).to(device=device, dtype=dtype)

        # Output generated is similar (1e-2 due to the uint8 conversions) to the below output:
        # img = PIL.Image.fromarray((255*inputs[0,0]).byte().numpy())
        # en = ImageOps.posterize(img, 1)
        # np.array(en) / 255.
        expected = torch.tensor(
            [[[[0.0, 0.50196078, 0.0], [0.0, 0.0, 0.50196078],
               [0.0, 0.50196078, 0.0]]]],
            device=device,
            dtype=dtype)

        assert_allclose(TestPosterize.f(inputs, 1), expected)
        # bits=0 clears everything; bits=8 leaves the image unchanged.
        assert_allclose(TestPosterize.f(inputs, 0), torch.zeros_like(inputs))
        assert_allclose(TestPosterize.f(inputs, 8), inputs)

    @pytest.mark.skip(reason="IndexError: tuple index out of range")
    @pytest.mark.grad
    def test_gradcheck(self, device, dtype):
        """Numerical gradient check of posterize (currently skipped)."""
        bs, channels, height, width = 2, 3, 4, 5
        inputs = torch.rand(bs,
                            channels,
                            height,
                            width,
                            device=device,
                            dtype=dtype)
        inputs = tensor_to_gradcheck_var(inputs)
        assert gradcheck(TestPosterize.f, (inputs, 0), raise_exception=True)

    # TODO: enable once posterize can be scripted.
    @pytest.mark.skip(reason="union type input")
    @pytest.mark.jit
    def test_jit(self, device, dtype):
        """The TorchScript-compiled posterize matches the eager function.

        Skipped per the skip reason: ``bits`` accepts a union of types
        (int or tensor, see ``test_cardinality``), which scripting rejects.
        """
        op = kornia.enhance.adjust.posterize
        op_script = torch.jit.script(op)
        inputs = torch.rand(2, 1, 3, 3).to(device=device, dtype=dtype)
        expected = op(inputs, 8)
        actual = op_script(inputs, 8)
        assert_allclose(actual, expected)

    # TODO: implement me
    @pytest.mark.skip(reason="Not having it yet.")
    @pytest.mark.nn
    def test_module(self, device, dtype):
        """Placeholder for testing a module wrapper of posterize (skipped)."""
        img = torch.ones(2, 3, 4, 4, device=device, dtype=dtype)
Esempio n. 3
0
    def test_jit(self, device, dtype):
        """The TorchScript-compiled ``sharpness`` matches the eager function."""
        op = kornia.enhance.adjust.sharpness
        op_script = torch.jit.script(op)
        inputs = torch.rand(2, 1, 3, 3).to(device=device, dtype=dtype)
        # Eager result is the reference; the scripted op must reproduce it.
        expected = op(inputs, 0.8)
        actual = op_script(inputs, 0.8)
        assert_allclose(actual, expected)

    @pytest.mark.skip(reason="Not having it yet.")
    @pytest.mark.nn
    def test_module(self, device, dtype):
        """Placeholder for testing a module counterpart of ``sharpness``.

        Skipped; per the skip reason, no such module is available yet.
        """
        # TODO: once a module wrapper exists, check that applying it to
        # ``img`` gives the same result as the functional sharpness op.
        img = torch.ones((2, 3, 4, 4), dtype=dtype, device=device)


@pytest.mark.skipif(kornia.xla_is_available(), reason="issues with xla device")
class TestSolarize(BaseTester):

    f = kornia.enhance.solarize

    def test_smoke(self, device, dtype):
        B, C, H, W = 2, 3, 4, 5
        img = torch.rand(B, C, H, W, device=device, dtype=dtype)
        assert isinstance(TestSolarize.f(img, 0.8), torch.Tensor)

    @pytest.mark.parametrize(
        "batch_size, height, width, thresholds, additions",
        [
            (1, 4, 5, 0.8, None),
            (1, 4, 5, 0.8, 0.4),
            (2, 4, 5, 0.8, None),