def __call__(self, img: np.ndarray) -> np.ndarray:
    """
    Compute the envelope of ``img`` along the configured axis.

    Args:
        img: numpy.ndarray containing input data. Must be real and in shape [channels, spatial1, spatial2, ...].

    Returns:
        np.ndarray containing envelope of data in img along the specified axis.
    """
    # A batch dimension is prepended below, so the user-facing axis index
    # must be shifted by one before it is handed to HilbertTransform.
    transform = HilbertTransform(self.axis + 1, self.n)

    # ascontiguousarray guards against non-contiguous (e.g. negatively
    # strided) numpy views, which torch.as_tensor cannot wrap directly.
    contiguous = np.ascontiguousarray(img)
    batched = torch.as_tensor(contiguous).unsqueeze(0)

    # Run the transform, drop the temporary batch dimension, and take the
    # magnitude of the result to obtain the envelope.
    transformed = transform(batched).squeeze(0)
    return np.abs(transformed.numpy())
def test_no_fft_module_error(self):
    """Applying a default HilbertTransform to a tensor must raise OptionalImportError.

    Presumably this runs where the required FFT module is unavailable; confirm
    against the skip decorator on the enclosing test class.
    """
    # Both the transform and the input are built before entering the
    # assertion context, so only the call itself is expected to raise.
    transform = HilbertTransform()
    signal = torch.randn(1, 1, 10)
    with self.assertRaises(OptionalImportError):
        transform(signal)
def test_invalid_pytorch_error(self):
    """Constructing HilbertTransform must raise InvalidPyTorchVersionError mentioning "version"."""
    message_pattern = "version"
    with self.assertRaisesRegex(InvalidPyTorchVersionError, message_pattern):
        HilbertTransform()
def test_value(self, arguments, image, expected_data, atol):
    """Compare HilbertTransform output with reference data to within ``atol``.

    Args:
        arguments: keyword arguments forwarded to the HilbertTransform constructor.
        image: input tensor passed to the transform.
        expected_data: reference values; squeezed before comparison.
        atol: absolute tolerance for ``np.testing.assert_allclose``.
    """
    transform = HilbertTransform(**arguments)
    output = transform(image)
    # Remove leading size-1 dimensions (squeeze(0) is a no-op on non-singleton
    # dims) and move to host memory so numpy can compare the values.
    actual = output.squeeze(0).squeeze(0).cpu().numpy()
    expected = expected_data.squeeze()
    np.testing.assert_allclose(actual, expected, atol=atol)