コード例 #1
0
    def __call__(self, img: np.ndarray) -> np.ndarray:
        """
        Compute the envelope of ``img`` along ``self.axis`` using a Hilbert transform.

        Args:
            img: numpy.ndarray containing input data. Must be real and in shape [channels, spatial1, spatial2, ...].

        Returns:
            np.ndarray containing envelope of data in img along the specified axis.

        """
        data = np.ascontiguousarray(img)
        # A batch axis is prepended at dimension 0 below. Resolve a negative axis
        # against the un-batched input first; otherwise e.g. ``-1 + 1`` would
        # become 0 and point at the batch axis instead of the intended one.
        axis = self.axis + data.ndim if self.axis < 0 else self.axis
        # add one to transform axis because a batch axis will be added at dimension 0
        hilbert_transform = HilbertTransform(axis + 1, self.n)
        # convert to Tensor and add Batch axis expected by HilbertTransform
        input_data = torch.as_tensor(data).unsqueeze(0)
        # magnitude of the transform output, with the temporary batch axis removed
        return np.abs(hilbert_transform(input_data).squeeze(0).numpy())
コード例 #2
0
 def test_no_fft_module_error(self):
     """Check that using HilbertTransform raises OptionalImportError.

     The test name suggests this runs when the FFT module is unavailable.
     """
     # Construct the transform inside the guarded region too. If the missing
     # optional import surfaces at construction time rather than at call time,
     # the assertion still catches it instead of the test erroring out.
     with self.assertRaises(OptionalImportError):
         HilbertTransform()(torch.randn(1, 1, 10))
コード例 #3
0
 def test_invalid_pytorch_error(self):
     """Constructing HilbertTransform raises InvalidPyTorchVersionError whose message matches "version"."""
     # Callable form: assertRaisesRegex invokes HilbertTransform() with no arguments.
     self.assertRaisesRegex(InvalidPyTorchVersionError, "version", HilbertTransform)
コード例 #4
0
 def test_value(self, arguments, image, expected_data, atol):
     """Compare HilbertTransform output on ``image`` against ``expected_data`` within ``atol``."""
     transform = HilbertTransform(**arguments)
     # drop up to two leading singleton dimensions before comparing
     output = transform(image).squeeze(0).squeeze(0)
     actual = output.cpu().numpy()
     np.testing.assert_allclose(actual, expected_data.squeeze(), atol=atol)