def test_modulus(shape):
    """Check ``transforms.modulus`` against ``np.abs`` on the same data.

    The requested ``shape`` is extended with a trailing axis of size 2 before
    the input is created (presumably the real/imaginary pair -- confirm
    against ``create_input``).
    """
    tensor = create_input(shape + [2])

    actual = transforms.modulus(tensor).numpy()
    # NumPy reference: magnitude of the complex-valued view of the same input.
    expected = np.abs(tensor_to_complex_numpy(tensor))

    assert np.allclose(actual, expected)
def test_conjugate(shape):
    """Check ``transforms.conjugate`` against ``np.conjugate``.

    The input is a deterministic complex array of the given ``shape``: the
    real part counts 0, 1, 2, ... in C order and the imaginary part is the
    real part plus one.
    """
    # np.prod is used instead of np.product: the latter was deprecated in
    # NumPy 1.25 and removed in NumPy 2.0.
    ramp = np.arange(np.prod(shape)).reshape(shape)
    data = ramp + 1j * (ramp + 1)

    torch_tensor = add_names(transforms.to_tensor(data), named=True)
    out_torch = tensor_to_complex_numpy(transforms.conjugate(torch_tensor))
    out_numpy = np.conjugate(data)

    assert np.allclose(out_torch, out_numpy)
def test_root_sum_of_squares_complex(shape, dims):
    """Check ``transforms.root_sum_of_squares`` against a NumPy reference.

    The reference is the square root of the summed squared magnitudes of the
    complex-converted input. ``dims`` is passed to the transform unchanged;
    on the NumPy side the string ``"coils"`` is translated to axis 0.
    """
    tensor = create_input(shape + [2], named=True)  # noqa

    actual = transforms.root_sum_of_squares(tensor, dims).numpy()

    squared_magnitude = np.abs(tensor_to_complex_numpy(tensor)) ** 2
    # NumPy has no named axes, so the "coils" name maps to the leading axis.
    reduce_axis = 0 if dims == "coils" else dims
    expected = np.sqrt(np.sum(squared_magnitude, reduce_axis))

    assert np.allclose(actual, expected)
def test_complex_multiplication(shape):
    """Check ``transforms.complex_multiplication`` against NumPy's ``*``.

    Two deterministic complex arrays of the given ``shape`` are built: the
    first has real part 0, 1, 2, ... (C order) and imaginary part equal to
    the real part plus one; the second is the first shifted by ``0.5 + 1j``.
    """
    # np.prod is used instead of np.product: the latter was deprecated in
    # NumPy 1.25 and removed in NumPy 2.0.
    ramp = np.arange(np.prod(shape)).reshape(shape)
    data_0 = ramp + 1j * (ramp + 1)
    data_1 = data_0 + 0.5 + 1j

    torch_tensor_0 = add_names(transforms.to_tensor(data_0), named=True)
    torch_tensor_1 = add_names(transforms.to_tensor(data_1), named=True)

    out_torch = tensor_to_complex_numpy(
        transforms.complex_multiplication(torch_tensor_0, torch_tensor_1)
    )
    out_numpy = data_0 * data_1

    assert np.allclose(out_torch, out_numpy)
def test_ifft2(shape, named):
    """Check ``transforms.ifft2`` against a centred, orthonormal NumPy IFFT.

    ``shape`` is extended with a trailing axis of size 2; index 0 and 1 of
    that axis are recombined below as the real and imaginary parts. The
    NumPy reference applies ``ifftshift`` -> ``ifft2(norm="ortho")`` ->
    ``fftshift`` over the last two axes of the complex array.

    Args:
        shape: list of leading dimension sizes for the input.
        named: whether the input is created with named dimensions; selects
            between dimension names and integer indices for ``dim``.
    """
    data = create_input(shape + [2], named=named)

    # The last torch axis is the real/imag pair, so the two axes that match
    # axes (-2, -1) of the complex NumPy array are (-3, -2). The previous
    # value (-2, -1) included the complex axis and disagreed with test_fft2.
    dim = ("height", "width") if named else (-3, -2)

    out_torch = transforms.ifft2(data, dim=dim).numpy()
    out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]

    input_numpy = tensor_to_complex_numpy(data)
    input_numpy = np.fft.ifftshift(input_numpy, (-2, -1))
    out_numpy = np.fft.ifft2(input_numpy, norm="ortho")
    out_numpy = np.fft.fftshift(out_numpy, (-2, -1))

    assert np.allclose(out_torch, out_numpy)
def test_fft2(shape, named):
    """Check ``transforms.fft2`` against a centred, orthonormal NumPy FFT.

    ``shape`` is extended with a trailing axis of size 2 whose two entries
    are recombined as real and imaginary parts of the transform output. The
    reference applies ``ifftshift`` -> ``fft2(norm="ortho")`` -> ``fftshift``
    over the last two axes of the complex NumPy array.
    """
    tensor = create_input(shape + [2], named=named)
    dim = ("height", "width") if named else (-3, -2)

    transformed = transforms.fft2(tensor, dim=dim).numpy()
    actual = transformed[..., 0] + 1j * transformed[..., 1]

    spatial_axes = (-2, -1)
    reference_input = np.fft.ifftshift(
        tensor_to_complex_numpy(tensor), spatial_axes
    )
    expected = np.fft.fftshift(
        np.fft.fft2(reference_input, norm="ortho"), spatial_axes
    )

    # Print the extremes of the residual; pytest shows captured stdout when
    # the assertion below fails.
    residual = actual - expected
    print(
        residual.real.max(),
        residual.real.min(),
        residual.imag.max(),
        residual.imag.min(),
    )
    assert np.allclose(actual, expected)
import numpy as np
import torch

from direct.data.transforms import tensor_to_complex_numpy
from direct.nn.rim.mri_models import MRILogLikelihood
from direct.data import transforms

# Dimension-name layouts used for the image and the per-coil tensors below.
_IMAGE_NAMES = ("batch", "height", "width", "complex")
_COIL_NAMES = ("batch", "coil", "height", "width", "complex")

# Test inputs. The order of these calls is kept as-is so any random draws
# happen in the same sequence.
input_image = create_input([1, 4, 4, 2]).rename(*_IMAGE_NAMES)
sensitivity_map = create_input([1, 15, 4, 4, 2]) * 0.1
masked_kspace = create_input([1, 15, 4, 4, 2]) + 0.33

# Random 0/1 mask (p=0.5) with singleton second and last axes, given the same
# dimension names as the sensitivity map.
_mask_draw = np.random.binomial(size=(1, 1, 4, 4, 1), n=1, p=0.5)
sampling_mask = torch.from_numpy(_mask_draw).refine_names(
    *sensitivity_map.names
)

# NumPy counterparts of the inputs.
input_image_numpy = tensor_to_complex_numpy(input_image)
sensitivity_map_numpy = tensor_to_complex_numpy(sensitivity_map)
masked_kspace_numpy = tensor_to_complex_numpy(masked_kspace)
sampling_mask_numpy = sampling_mask.numpy()[..., 0]  # drop the last axis

# Torch
input_image = input_image.align_to(*_IMAGE_NAMES)
sensitivity_map = sensitivity_map.align_to(*_COIL_NAMES)
masked_kspace = masked_kspace.align_to(*_COIL_NAMES)

# Align the image to the sensitivity map's dimensions before multiplying.
_image_like_coils = input_image.align_as(sensitivity_map)
mul = transforms.complex_multiplication(sensitivity_map, _image_like_coils)
mul_names = mul.names