def get_test_devices() -> Dict[str, torch.device]:
    """Build the mapping of device names to the devices the tests run on.

    The CPU entry is always present. A ``"cuda"`` entry (``cuda:0``) is added
    when ``torch.cuda.is_available()`` is true, and a ``"tpu"`` entry is added
    when ``kornia.xla_is_available()`` is true.

    Returns:
        Dictionary mapping a device name to its device object.
    """
    available: Dict[str, torch.device] = {"cpu": torch.device("cpu")}

    if torch.cuda.is_available():
        available["cuda"] = torch.device("cuda:0")

    if kornia.xla_is_available():
        # Imported lazily: torch_xla is only needed when XLA is reported available.
        import torch_xla.core.xla_model as xm

        available["tpu"] = xm.xla_device()

    return available
class TestPosterize(BaseTester):
    """Test suite for ``kornia.enhance.posterize``."""

    f = kornia.enhance.posterize

    def test_smoke(self, device, dtype):
        """Check that posterizing a random batch returns a tensor."""
        B, C, H, W = 2, 3, 4, 5
        img = torch.rand(B, C, H, W, device=device, dtype=dtype)
        assert isinstance(TestPosterize.f(img, 8), torch.Tensor)

    @pytest.mark.parametrize(
        "batch_size, height, width, bits",
        [
            (1, 4, 5, 8),
            (2, 4, 5, 1),
            (2, 4, 5, 0),
            (1, 4, 5, torch.tensor(8)),
            (2, 4, 5, torch.tensor(8)),
            (2, 4, 5, torch.tensor([0, 8])),
            (3, 4, 5, torch.tensor([0, 1, 8])),
        ],
    )
    @pytest.mark.parametrize("channels", [1, 3, 5])
    def test_cardinality(self, batch_size, channels, height, width, bits, device, dtype):
        """Check that the output shape equals the input shape.

        ``bits`` is exercised as a Python int, a scalar tensor and a
        per-sample tensor.
        """
        inputs = torch.ones(batch_size, channels, height, width, device=device, dtype=dtype)
        assert TestPosterize.f(inputs, bits).shape == torch.Size(
            [batch_size, channels, height, width])

    # TODO(jian): add better assertions
    def test_exception(self, device, dtype):
        """Check that a non-tensor image or a float ``bits`` raises ``TypeError``."""
        img = torch.ones(2, 3, 4, 5, device=device, dtype=dtype)

        with pytest.raises(TypeError):
            assert TestPosterize.f([1.0], 0.0)

        with pytest.raises(TypeError):
            assert TestPosterize.f(img, 1.0)

    # TODO(jian): add better cases
    @pytest.mark.skipif(kornia.xla_is_available(), reason="issues with xla device")
    def test_value(self, device, dtype):
        """Compare the 1-bit output with reference values; check 0 and 8 bits.

        With 0 bits the expected output is all zeros, and with 8 bits the
        expected output is the unchanged input.
        """
        torch.manual_seed(0)
        inputs = torch.rand(1, 1, 3, 3).to(device=device, dtype=dtype)

        # Output generated is similar (1e-2 due to the uint8 conversions) to the below output:
        # img = PIL.Image.fromarray((255*inputs[0,0]).byte().numpy())
        # en = ImageOps.posterize(img, 1)
        # np.array(en) / 255.
        expected = torch.tensor(
            [[[[0.0, 0.50196078, 0.0], [0.0, 0.0, 0.50196078], [0.0, 0.50196078, 0.0]]]],
            device=device, dtype=dtype)

        assert_allclose(TestPosterize.f(inputs, 1), expected)
        assert_allclose(TestPosterize.f(inputs, 0), torch.zeros_like(inputs))
        assert_allclose(TestPosterize.f(inputs, 8), inputs)

    @pytest.mark.skip(reason="IndexError: tuple index out of range")
    @pytest.mark.grad
    def test_gradcheck(self, device, dtype):
        """Run ``gradcheck`` on posterize with ``bits=0`` (currently skipped)."""
        bs, channels, height, width = 2, 3, 4, 5
        inputs = torch.rand(bs, channels, height, width, device=device, dtype=dtype)
        inputs = tensor_to_gradcheck_var(inputs)
        assert gradcheck(TestPosterize.f, (inputs, 0), raise_exception=True)

    # TODO: implement me
    @pytest.mark.skip(reason="union type input")
    @pytest.mark.jit
    def test_jit(self, device, dtype):
        """Compare the eager op with its TorchScript-compiled version (currently skipped)."""
        # Keep the eager and scripted callables under separate names so the
        # comparison really is eager-vs-scripted.
        op = kornia.enhance.adjust.posterize
        op_script = torch.jit.script(op)

        inputs = torch.rand(2, 1, 3, 3).to(device=device, dtype=dtype)
        # Pass the local tensor `inputs`, not the builtin `input`.
        expected = op(inputs, 8)
        actual = op_script(inputs, 8)
        assert_allclose(actual, expected)

    # TODO: implement me
    @pytest.mark.skip(reason="Not having it yet.")
    @pytest.mark.nn
    def test_module(self, device, dtype):
        """Placeholder for a module-API test; only builds an input (currently skipped)."""
        img = torch.ones(2, 3, 4, 4, device=device, dtype=dtype)
# NOTE(review): in test_jit below, `inputs` is created but both calls pass the
# builtin `input`, and `op_script` is not assigned anywhere in the method, so
# the body cannot run as written unless `op_script` exists at module level.
# The eager op and the scripted op should be bound to two separate names.
def test_jit(self, device, dtype): op = torch.jit.script(kornia.enhance.adjust.sharpness) inputs = torch.rand(2, 1, 3, 3).to(device=device, dtype=dtype) expected = op(input, 0.8) actual = op_script(input, 0.8) assert_allclose(actual, expected) @pytest.mark.skip(reason="Not having it yet.") @pytest.mark.nn def test_module(self, device, dtype): img = torch.ones(2, 3, 4, 4, device=device, dtype=dtype) # gray_ops = kornia.enhance.sharpness().to(device, dtype) # assert_allclose(gray_ops(img), f(img)) @pytest.mark.skipif(kornia.xla_is_available(), reason="issues with xla device") class TestSolarize(BaseTester): f = kornia.enhance.solarize def test_smoke(self, device, dtype): B, C, H, W = 2, 3, 4, 5 img = torch.rand(B, C, H, W, device=device, dtype=dtype) assert isinstance(TestSolarize.f(img, 0.8), torch.Tensor) @pytest.mark.parametrize( "batch_size, height, width, thresholds, additions", [ (1, 4, 5, 0.8, None), (1, 4, 5, 0.8, 0.4), (2, 4, 5, 0.8, None),