def test_ifft2(shape, named):
    """Compare ``transforms.ifft2`` with a centered, orthonormal NumPy inverse FFT.

    The input is built with a trailing axis of size 2 appended to ``shape``;
    the two entries of that axis are recombined below as real and imaginary
    parts. The NumPy reference is ``ifftshift -> ifft2(norm="ortho") ->
    fftshift`` over the last two axes of the complex array.

    Args:
        shape: Leading dimensions of the test input. It is concatenated with
            ``[2]``, so it is presumably a list supplied by the test
            parametrization (confirm against the caller).
        named: Forwarded to ``create_input`` as ``named=``. When truthy the
            transform is addressed by the names ``("height", "width")``,
            otherwise by the positional axes ``(-2, -1)``.
    """
    data = create_input(shape + [2], named=named)
    dim = ("height", "width") if named else (-2, -1)

    # Result under test: fold the trailing (real, imag) axis into complex values.
    result = transforms.ifft2(data, dim=dim).numpy()
    result_complex = result[..., 0] + 1j * result[..., 1]

    # Reference: the same centered inverse transform done in NumPy.
    spatial_axes = (-2, -1)
    reference = tensor_to_complex_numpy(data)
    reference = np.fft.ifftshift(reference, spatial_axes)
    reference = np.fft.ifft2(reference, norm="ortho")
    reference = np.fft.fftshift(reference, spatial_axes)

    assert np.allclose(result_complex, reference)
mul_names = mul.names mr_forward = torch.where( sampling_mask.rename(None) == 0, torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device), transforms.fft2(mul).rename(None), ) error = mr_forward - torch.where( sampling_mask.rename(None) == 0, torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device), masked_kspace.rename(None), ) error = error.refine_names(*mul_names) mr_backward = transforms.ifft2(error) out = transforms.complex_multiplication(transforms.conjugate(sensitivity_map), mr_backward).sum("coil") # numpy # mul_numpy = sensitivity_map_numpy * input_image_numpy # mr_forward_numpy = sampling_mask_numpy * numpy_fft(mul_numpy) # error_numpy = mr_forward_numpy - sampling_mask_numpy * masked_kspace_numpy # mr_backward_numpy = numpy_ifft(error_numpy) # out_numpy = (sensitivity_map_numpy.conjugate() * mr_backward_numpy).sum(1) # np.allclose(tensor_to_complex_numpy(out), out_numpy) # numpy 2 mr_backward_numpy = numpy_ifft(