def test_optional_params(self, clip, grid, device, dtype):
    """Call ``equalize_clahe`` with only one, or both, optional arguments.

    When ``clip`` is ``None`` only ``grid_size`` is passed; when ``grid`` is
    ``None`` only ``clip_limit`` is passed; otherwise both are passed
    positionally. In every case the output must be a tensor with the
    input's shape.
    """
    channels, height, width = 1, 10, 20
    image = torch.rand(channels, height, width, device=device, dtype=dtype)

    if clip is not None and grid is not None:
        # Both values supplied: forward them positionally.
        out = enhance.equalize_clahe(image, clip, grid)
    elif clip is None:
        out = enhance.equalize_clahe(image, grid_size=grid)
    else:
        out = enhance.equalize_clahe(image, clip_limit=clip)

    assert isinstance(out, torch.Tensor)
    assert out.shape == image.shape
def test_cardinality(self, B, C, device, dtype):
    """Check that ``equalize_clahe`` preserves the input shape.

    A ``B`` of ``None`` selects an unbatched ``(C, H, W)`` input; any other
    value selects a batched ``(B, C, H, W)`` input.
    """
    height, width = 10, 20
    shape = (C, height, width) if B is None else (B, C, height, width)
    image = torch.rand(*shape, device=device, dtype=dtype)

    out = enhance.equalize_clahe(image)

    assert out.shape == image.shape
def test_smoke(self, device, dtype):
    """Run ``equalize_clahe`` with default arguments on a random image.

    The result must be a tensor that keeps the input's shape, device and
    dtype.
    """
    channels, height, width = 1, 10, 20
    image = torch.rand(channels, height, width, device=device, dtype=dtype)

    out = enhance.equalize_clahe(image)

    assert isinstance(out, torch.Tensor)
    assert out.shape == image.shape
    assert out.device == image.device
    assert out.dtype == image.dtype
def test_clahe(self, img):
    """Compare CLAHE output against stored reference values.

    Both the default path and the ``slow_and_differentiable=True`` path are
    run with ``clip_limit=2.0`` and an ``(8, 8)`` grid. Only index 0 along
    the second-to-last dimension of each result is compared.
    """
    limit: float = 2.0
    tiles: Tuple[int, int] = (8, 8)
    tol = 1e-04  # used for both atol and rtol

    res = enhance.equalize_clahe(img, clip_limit=limit, grid_size=tiles)
    res_diff = enhance.equalize_clahe(
        img, clip_limit=limit, grid_size=tiles, slow_and_differentiable=True
    )

    # NOTE: for next versions we need to improve the computation of the LUT
    # and test with a better image
    expected = torch.tensor(
        [[[
            0.1216, 0.8745, 0.9373, 0.9163, 0.8745,
            0.8745, 0.9373, 0.8745, 0.8745, 0.8118,
            0.9373, 0.8745, 0.8745, 0.8118, 0.8745,
            0.8745, 0.8327, 0.8118, 0.8745, 1.0000,
        ]]],
        dtype=res.dtype,
        device=res.device,
    )
    exp_diff = torch.tensor(
        [[[
            0.1250, 0.8752, 0.9042, 0.9167, 0.8401,
            0.8852, 0.9302, 0.9120, 0.8750, 0.8370,
            0.9620, 0.9077, 0.8750, 0.8754, 0.9204,
            0.9167, 0.8370, 0.8806, 0.9096, 1.0000,
        ]]],
        dtype=res.dtype,
        device=res.device,
    )

    fast_slice = res[..., 0, :]
    assert torch.allclose(fast_slice, expected, atol=tol, rtol=tol)

    diff_slice = res_diff[..., 0, :]
    assert torch.allclose(diff_slice, exp_diff, atol=tol, rtol=tol)
def test_clahe(self, img):
    """Compare CLAHE output against stored reference values.

    ``equalize_clahe`` is run with ``clip_limit=2.0`` and an ``(8, 8)``
    grid. Only index 0 along the second-to-last dimension of the result is
    compared.
    """
    limit: float = 2.0
    tiles: Tuple[int, int] = (8, 8)
    tol = 1e-04  # used for both atol and rtol

    res = enhance.equalize_clahe(img, clip_limit=limit, grid_size=tiles)

    # NOTE: for next versions we need to improve the computation of the LUT
    # and test with a better image
    expected = torch.tensor(
        [[[
            0.1216, 0.8745, 0.9373, 0.9137, 0.8745,
            0.8745, 0.9373, 0.8745, 0.8745, 0.8118,
            0.9373, 0.8745, 0.8745, 0.8118, 0.8745,
            0.8745, 0.8314, 0.8118, 0.8745, 1.0000,
        ]]],
        dtype=res.dtype,
        device=res.device,
    )

    assert torch.allclose(res[..., 0, :], expected, atol=tol, rtol=tol)
def test_ahe(self, img):
    """Compare ``equalize_clahe`` with ``clip_limit=0.`` against references.

    An ``(8, 8)`` grid is used. Only index 0 along the second-to-last
    dimension of the result is compared.
    """
    limit: float = 0.0
    tiles: Tuple[int, int] = (8, 8)
    tol = 1e-04  # used for both atol and rtol

    res = enhance.equalize_clahe(img, clip_limit=limit, grid_size=tiles)

    # NOTE: for next versions we need to improve the computation of the LUT
    # and test with a better image
    expected = torch.tensor(
        [[[
            0.2471, 0.4980, 0.7490, 0.6667, 0.4980,
            0.4980, 0.7490, 0.4980, 0.4980, 0.2471,
            0.7490, 0.4980, 0.4980, 0.2471, 0.4980,
            0.4980, 0.3333, 0.2471, 0.4980, 1.0000,
        ]]],
        dtype=res.dtype,
        device=res.device,
    )

    assert torch.allclose(res[..., 0, :], expected, atol=tol, rtol=tol)
def test_he(self, img):
    """Compare ``equalize_clahe`` with a single tile and no clipping.

    ``clip_limit=0.`` and a ``(1, 1)`` grid are used. Only index 0 along
    the second-to-last dimension of the result is compared.
    """
    # Should be similar to enhance.equalize but slower. Similar because the
    # lut is computed in a different way.
    limit: float = 0.0
    tiles: Tuple[int, int] = (1, 1)
    tol = 1e-04  # used for both atol and rtol

    res = enhance.equalize_clahe(img, clip_limit=limit, grid_size=tiles)

    # NOTE: for next versions we need to improve the computation of the LUT
    # and test with a better image
    expected = torch.tensor(
        [[[
            0.0471, 0.0980, 0.1490, 0.2000, 0.2471,
            0.2980, 0.3490, 0.3490, 0.4471, 0.4471,
            0.5490, 0.5490, 0.6471, 0.6471, 0.6980,
            0.7490, 0.8000, 0.8471, 0.8980, 1.0000,
        ]]],
        dtype=res.dtype,
        device=res.device,
    )

    assert torch.allclose(res[..., 0, :], expected, atol=tol, rtol=tol)
def grad_rot(input, a, b, c):
    """Apply ``rotate`` to ``input`` and feed the result to ``equalize_clahe``.

    ``a``, ``b`` and ``c`` are forwarded positionally to
    ``enhance.equalize_clahe`` after the rotated tensor. ``device`` is not a
    parameter; it is resolved from the enclosing scope.
    """
    # Scalar 30. tensor matching the input's dtype (presumably an angle in
    # degrees -- confirm against ``rotate``).
    angle = torch.tensor(30., dtype=input.dtype, device=device)
    rotated = rotate(input, angle)
    return enhance.equalize_clahe(rotated, a, b, c)
def test_exception_tensor_type(self):
    """A plain Python list must be rejected with ``TypeError``."""
    not_a_tensor = [1, 2, 3]
    with pytest.raises(TypeError):
        enhance.equalize_clahe(not_a_tensor)
def test_exception_tensor_dims(self, dims):
    """A random tensor built from ``dims`` must be rejected with ``ValueError``."""
    bad_input = torch.rand(dims)
    with pytest.raises(ValueError):
        enhance.equalize_clahe(bad_input)
def test_exception(self, B, clip, grid, exception_type):
    """Expect ``exception_type`` for the given ``clip`` / ``grid`` arguments.

    The input is a random ``(B, 1, 10, 20)`` tensor; ``clip`` and ``grid``
    are passed positionally to ``equalize_clahe``.
    """
    channels, height, width = 1, 10, 20
    image = torch.rand(B, channels, height, width)

    with pytest.raises(exception_type):
        enhance.equalize_clahe(image, clip, grid)