def test_partial_fourier_3d(xp, dtype, pf_fractions):
    """Partial-Fourier recon of a 3D phantom must beat zero-filled recon.

    The fully sampled k-space of ``mri_object_3d`` is truncated along the
    first two axes to the fractions given by ``pf_fractions``. The result of
    ``mri_partial_fourier_nd`` is then compared, in mean squared error
    relative to the fully sampled reconstruction, against a direct inverse
    FFT of the zero-filled data.
    """
    shape = (128, 128, 64)
    ig = ImageGeometry(shape, distances=(1, ) * len(shape), offsets="dsp")
    obj = mri_object_3d(ig.fov)
    coords = ig.fgrid()

    # fully-sampled k-space
    kspace_full = xp.asarray(obj.kspace(*coords), dtype=dtype)

    # partial-Fourier k-space: keep only the leading samples on axes 0 and 1
    nkeep = [int(f * s) for f, s in zip(pf_fractions, shape)]
    # Use the builtin ``bool``: the ``np.bool`` alias does not exist in
    # NumPy 1.24-1.26, so ``xp.bool`` raises AttributeError there.
    pf_mask = xp.zeros(shape, dtype=bool)
    pf_mask[:nkeep[0], :nkeep[1]] = True
    kspace_pf = kspace_full[:nkeep[0], :nkeep[1]]

    # direct reconstruction using zero-filled k-space
    direct_recon = ifftnc(kspace_full * pf_mask)

    # partial Fourier reconstruction
    pf_recon = mri_partial_fourier_nd(kspace_pf, pf_mask)

    # dtype is preserved
    assert pf_recon.dtype == dtype

    # reference: recon from fully sampled k-space
    x_full = xp.asarray(ifftnc(kspace_full))

    # Error of partial-Fourier recon should be much less than for zero-filling
    mse_pf = xp.mean(xp.abs(x_full - pf_recon)**2)
    mse_direct = xp.mean(xp.abs(x_full - direct_recon)**2)
    assert mse_pf < 0.25 * mse_direct
def test_mri_object_3d(dtype):
    """Analytic k-space of the 3D phantom should invert to its image.

    The relative L2 error between the phantom image and the centered inverse
    FFT of its (spacing-scaled) k-space samples must stay below 20%.
    """
    phantom = mri_object_3d(fov=(120, 120, 40), units="mm", dtype=dtype)
    geom = ImageGeometry((240, 240, 80), distances=(0.5, 0.5, 0.5))

    image = phantom.image(*geom.grid())
    # scale the k-space samples by the product of the grid spacings
    kspace = phantom.kspace(*geom.fgrid()) / np.prod(geom.distances)

    residual_norm = np.linalg.norm(image - ifftnc(kspace))
    relative_error = residual_norm / np.linalg.norm(image)
    assert relative_error < 0.2
def test_mri_object_1d(dtype):
    """Analytic k-space of the 1D phantom should invert to its image.

    The relative L2 error between the phantom image and the centered inverse
    FFT of its (spacing-scaled) k-space samples must stay below 10%.
    """
    phantom = mri_object_1d(fov=(240, ), units="mm", dtype=dtype)
    geom = ImageGeometry((240, ), distances=(1.0, ))

    image = phantom.image(*geom.grid())
    # scale the k-space samples by the grid spacing of the single axis
    kspace = phantom.kspace(*geom.fgrid()) / geom.distances[0]

    residual_norm = np.linalg.norm(image - ifftnc(kspace))
    relative_error = residual_norm / np.linalg.norm(image)
    assert relative_error < 0.1
def test_fftnc_ifftnc(xp, norm, axes, pre_shift, post_shift, dtype):
    """Compare ``fftnc`` with a shifted ``numpy.fft.fftn`` and round-trip it.

    The forward result is checked against an explicit
    ``fftshift(fftn(ifftshift(x)))`` computed on the CPU, and ``ifftnc`` of
    that result must reproduce the input.
    """
    rng = xp.random.RandomState(1234)

    # test with at least one odd-sized axis
    dims = (15, 32, 18)
    x = rng.standard_normal(dims).astype(dtype, copy=False)
    if x.dtype.kind == "c":
        x += 1j * rng.standard_normal(dims).astype(dtype, copy=False)

    # single precision gets a looser tolerance
    tol = 1e-4 if dtype in [np.float32, np.complex64] else 1e-12

    # default shift behavior is to shift all axes that are transformed
    expected_pre_shift = axes if pre_shift is None else pre_shift
    expected_post_shift = axes if post_shift is None else post_shift

    # compare to numpy.fft.fftn result on CPU
    x_cpu = x if xp is np else x.get()
    unshifted = ifftshift(x_cpu, axes=expected_pre_shift)
    spectrum = np.fft.fftn(unshifted, axes=axes, norm=norm)
    expected_result = fftshift(spectrum, axes=expected_post_shift)

    # Verify agreement with expected numpy.fft.fftn result
    forward = fftnc(
        x,
        axes=axes,
        norm=norm,
        pre_shift_axes=pre_shift,
        post_shift_axes=post_shift,
    )
    xp.testing.assert_allclose(forward, expected_result, rtol=tol, atol=tol)

    # Verify correct round trip
    # note: swap pre/post shift vars relative to the forward transform
    recovered = ifftnc(
        forward,
        axes=axes,
        norm=norm,
        pre_shift_axes=post_shift,
        post_shift_axes=pre_shift,
    )
    xp.testing.assert_allclose(x, recovered, rtol=tol, atol=tol)
def test_partial_FFT_with_im_mask(xp, nd_in, order, shift):
    """Masked FFT with missing samples and a masked image domain.

    Checks the forward transform, a forward/adjoint round trip, and a round
    trip with two stacked repetitions against explicit FFT expressions that
    use the same sampling mask and image mask.
    """
    data = get_data(xp)
    rng = xp.random.RandomState(1234)
    sample_mask = rng.rand(128, 127) > 0.5

    # make a circular mask
    row_coords = xp.arange(-data.shape[0] // 2, data.shape[0] // 2)
    col_coords = xp.arange(-data.shape[1] // 2, data.shape[1] // 2)
    x, y = xp.meshgrid(row_coords, col_coords, indexing="ij", sparse=True)
    im_mask = xp.sqrt(x * x + y * y) < data.shape[0] // 2

    operator = FFT_Operator(
        data.shape,
        order=order,
        im_mask=im_mask,
        use_fft_shifts=shift,
        nd_input=nd_in,
        nd_output=False,
        sample_mask=sample_mask,
        gpu_force_reinit=False,
        mask_kspace_on_gpu=(not shift),
        **get_loc(xp),
    )

    # create new linear operator for forward followed by inverse transform
    normal_op = operator.H * operator
    assert isinstance(normal_op, LinearOperatorMulti)

    # reference transforms: centered variants when shifts are enabled
    forward_fft = fftnc if shift else fftn
    inverse_fft = ifftnc if shift else ifftn
    tolerances = dict(rtol=1e-7, atol=1e-4)

    # test forward only
    forw = embed(
        operator * masker(data, im_mask, order=order),
        sample_mask,
        order=order,
    )
    expected_forw = sample_mask * forward_fft(data * im_mask)
    xp.testing.assert_allclose(forw, expected_forw, **tolerances)

    # test roundtrip
    roundtrip = operator.H * (operator * masker(data, im_mask, order=order))
    expected_roundtrip = masker(
        inverse_fft(sample_mask * forward_fft(data * im_mask)),
        im_mask,
        order=order,
    )
    xp.testing.assert_allclose(roundtrip, expected_roundtrip, **tolerances)

    # test roundtrip with 2 reps
    data2 = xp.stack([data] * 2, axis=-1)
    roundtrip = operator.H * (operator * masker(data2, im_mask, order=order))
    expected_roundtrip = masker(
        inverse_fft(
            sample_mask[..., xp.newaxis]
            * forward_fft(data2 * im_mask[..., xp.newaxis], axes=(0, 1)),
            axes=(0, 1),
        ),
        im_mask,
        order=order,
    )
    xp.testing.assert_allclose(roundtrip, expected_roundtrip, **tolerances)
def mri_partial_fourier_nd(
    partial_kspace,
    pf_mask,
    niter=5,
    tw_inner=8,
    tw_outer=3,
    fill_conj=False,
    init=None,
    verbose=False,
    show=False,
    return_all_estimates=False,
    xp=None,
):
    """Partial Fourier reconstruction.

    Parameters
    ----------
    partial_kspace : array-like
        The acquired k-space samples. Its total size must equal the number of
        non-zero entries in `pf_mask`; it is placed onto the full k-space grid
        via ``embed(partial_kspace, pf_mask)``.
    pf_mask : array-like
        Mask over the full k-space grid marking the acquired samples. It is
        cast to boolean if necessary. Its shape defines the shape of the
        reconstructed image and every dimension must have even length.
    niter : int, optional
        Number of POCS iterations.
    tw_inner : int, optional
        Second argument passed to ``hanning_apodization_window`` when building
        the window for the symmetric low-resolution region used for the phase
        estimate (presumably the transition width in samples).
    tw_outer : int, optional
        Second argument passed to ``hanning_apodization_window`` when building
        the window over the acquired region. The window is also generated
        `tw_outer` samples longer than the acquired extent on each axis and
        then cropped back to that extent.
    fill_conj : bool, optional
        Currently unused by this implementation.
    init : array-like, optional
        Currently unused by this implementation.
    verbose : bool, optional
        If True, print the relative change (L2 norm and maximum) introduced by
        the phase constraint at each iteration.
    show : bool, optional
        If True and the data is 2D, plot the mask, the filters and the initial
        estimate with matplotlib.
    return_all_estimates : bool, optional
        If True, return the list of all image estimates instead of only the
        last one.
    xp : module, optional
        Array module, forwarded to ``get_array_module``.

    Returns
    -------
    img_est : array or list of array
        The final image estimate with shape ``pf_mask.shape``. If
        `return_all_estimates` is True, a list of ``niter + 1`` estimates is
        returned instead; the first entry is the zero-filled reconstruction.

    Raises
    ------
    ValueError
        If any dimension of `pf_mask` has odd length, if the size of
        `partial_kspace` does not match the number of non-zeros in `pf_mask`,
        or if `pf_mask` has no non-zero entries.

    Notes
    -----
    The implementation is based on a multi-dimensional iterative
    reconstruction technique as described in [1]_. This is an extension of the
    1D iterative methods described in [2]_ and [3]_. The concept of partial
    Fourier imaging was first proposed in [4]_, [5]_.

    References
    ----------
    .. [1] Xu, Y. and Haacke, E. M. Partial Fourier imaging in
        multi-dimensions: A means to save a full factor of two in time.
        J. Magn. Reson. Imaging, 2001; 14:628–635.
        doi:10.1002/jmri.1228
    .. [2] Haacke, E.; Lindskog, E. & Lin, W. A fast, iterative,
        partial-Fourier technique capable of local phase recovery.
        Journal of Magnetic Resonance, 1991; 92:126-145.
    .. [3] Liang, Z.-P.; Boada, F.; Constable, R. T.; Haacke, E.;
        Lauterbur, P. & Smith, M. Constrained reconstruction methods in MR
        imaging. Rev Magn Reson Med, 1992; 4: 67-185
    .. [4] Margosian, P.; Schmitt, F. & Purdy, D. Faster MR imaging: imaging
        with half the data Health Care Instrum, 1986; 1:195.
    .. [5] Feinberg, D. A.; Hale, J. D.; Watts, J. C.; Kaufman, L. & Mark, A.
        Halving MR imaging time by conjugation: demonstration at 3.5 kG.
        Radiology, 1986, 161, 527-531
    """
    xp, _ = get_array_module(partial_kspace, xp)
    partial_kspace = xp.asarray(partial_kspace)

    pf_mask = xp.asarray(pf_mask)
    # Builtin ``bool`` rather than ``xp.bool``: the ``np.bool`` alias does not
    # exist in NumPy 1.24-1.26.
    if pf_mask.dtype != bool:
        pf_mask = pf_mask.astype(bool)
    im_shape = pf_mask.shape

    if any(n % 2 for n in im_shape):
        raise ValueError(
            "This function assumes all k-space dimensions have even length.")

    n_sampled = int(xp.count_nonzero(pf_mask))
    if partial_kspace.size != n_sampled:
        raise ValueError(
            "partial kspace should have total size equal to the number of "
            "non-zeros in pf_mask")
    if n_sampled == 0:
        # Without this the min/max reductions below fail with an opaque error.
        raise ValueError("pf_mask must contain at least one non-zero entry")

    # zero-filled k-space and the corresponding initial image estimate
    kspace_init = embed(partial_kspace, pf_mask)
    img_est = ifftnc(kspace_init)

    nz = xp.where(pf_mask)

    lr_slices = []    # symmetric low-resolution region around the center
    lr_shape = []
    pf_slices = []    # extent of the acquired samples
    win2_slices = []  # crop applied to the extended outer window
    win2_shape = []
    for d, n in enumerate(im_shape):
        # int() also converts 0-dim GPU arrays to Python scalars
        nz_min = int(xp.min(nz[d]))
        nz_max = int(xp.max(nz[d]))
        i_mid = n // 2

        # half-width of the region sampled symmetrically about the center
        if nz_min == 0:
            width = nz_max - i_mid
        else:
            width = i_mid - nz_min
        lr_slices.append(slice(i_mid - width, i_mid + width + 1))
        lr_shape.append(2 * width + 1)

        n_acquired = nz_max - nz_min + 1
        pf_slices.append(slice(nz_min, nz_max + 1))
        win2_shape.append(n_acquired + tw_outer)
        if nz_min == 0:
            # samples start at the array edge: drop the leading part of the
            # extended window
            win2_slices.append(slice(tw_outer, tw_outer + n_acquired))
        else:
            # otherwise drop the trailing part of the extended window
            win2_slices.append(slice(n_acquired))
    lr_slices = tuple(lr_slices)
    pf_slices = tuple(pf_slices)
    win2_slices = tuple(win2_slices)

    # phase estimate from the apodized, symmetric low-resolution k-space
    lr_win = hanning_apodization_window(lr_shape, tw_inner, xp)
    lr_kspace = xp.zeros_like(kspace_init)
    lr_kspace[lr_slices] = kspace_init[lr_slices] * lr_win
    img_lr = ifftnc(lr_kspace)
    phi = xp.angle(img_lr)

    # weighting used to merge measured and synthesized k-space
    pf_win = hanning_apodization_window(win2_shape, tw_outer, xp)[win2_slices]
    win2_mask = xp.zeros(pf_mask.shape, dtype=xp.float32)
    win2_mask[pf_slices] = pf_win

    if show and pf_mask.ndim == 2:
        from matplotlib import pyplot as plt

        # only needed for display
        lr_mask = xp.zeros(pf_mask.shape, dtype=xp.float32)
        lr_mask[lr_slices] = lr_win

        fig, axes = plt.subplots(2, 2)
        axes = axes.ravel()
        axes[0].imshow(pf_mask)
        axes[0].set_title("PF mask")
        axes[1].imshow(lr_mask)
        axes[1].set_title("LR Filter")
        axes[2].imshow(win2_mask)
        axes[2].set_title("Filter")
        axes[3].imshow(xp.abs(img_est).T)
        axes[3].set_title("Initial Estimate")
        for ax in axes:
            ax.set_xticklabels("")
            ax.set_yticklabels("")

    if verbose:
        norm0 = xp.linalg.norm(img_est)
        max0 = xp.max(xp.abs(img_est))

    if return_all_estimates:
        all_img_est = [img_est]

    # loop-invariant terms, computed once
    phase = xp.exp(1j * phi)
    measured_term = win2_mask * kspace_init
    synth_weight = 1 - win2_mask

    # POCS iterations
    for _ in range(niter):
        # step 5: keep the magnitude, impose the low-resolution phase
        rho1 = xp.abs(img_est) * phase

        if verbose:
            change2 = xp.linalg.norm(rho1 - img_est) / norm0
            change1 = xp.max(xp.abs(rho1 - img_est)) / max0
            # str.format does not treat '%' specially, so a single '%' is
            # what prints a percent sign here
            print("change = {}% {}%".format(100 * change2, 100 * change1))

        # step 6
        s1 = fftnc(rho1)

        # steps 7 & 8: blend measured data with the synthesized k-space
        full_kspace = measured_term + synth_weight * s1

        # step 9
        img_est = ifftnc(full_kspace)

        if return_all_estimates:
            all_img_est.append(img_est)

    if return_all_estimates:
        return all_img_est
    return img_est