コード例 #1
0
def test_partial_fourier_3d(xp, dtype, pf_fractions):
    """Partial-Fourier reconstruction of a 3D phantom should beat zero-filling.

    k-space is truncated to the leading ``int(f * n)`` samples along the first
    two axes (using ``pf_fractions[0]`` and ``pf_fractions[1]``).  The third
    axis is left fully sampled.  The test checks two things:

    * ``mri_partial_fourier_nd`` preserves the k-space dtype.
    * Its error relative to the fully-sampled reconstruction is below a
      quarter of the error of the zero-filled reconstruction.
    """
    shape = (128, 128, 64)
    ig = ImageGeometry(shape, distances=(1, ) * len(shape), offsets="dsp")
    obj = mri_object_3d(ig.fov)
    coords = ig.fgrid()

    # fully-sampled k-space
    kspace_full = xp.asarray(obj.kspace(*coords), dtype=dtype)

    # partial-Fourier k-space: keep the leading samples along axes 0 and 1
    nkeep = [int(f * s) for f, s in zip(pf_fractions, shape)]
    # use the builtin bool: ``numpy.bool`` does not exist in NumPy 1.24-1.26
    pf_mask = xp.zeros(shape, dtype=bool)
    pf_mask[:nkeep[0], :nkeep[1]] = True
    kspace_pf = kspace_full[:nkeep[0], :nkeep[1]]

    # direct reconstruction using zero-filled k-space
    direct_recon = ifftnc(kspace_full * pf_mask)

    # partial Fourier reconstruction
    pf_recon = mri_partial_fourier_nd(kspace_pf, pf_mask)
    # dtype is preserved
    assert pf_recon.dtype == dtype

    # reference: recon from fully sampled k-space
    x_full = xp.asarray(ifftnc(kspace_full))

    # Error of partial-Fourier recon should be much less than for zero-filling
    mse_pf = xp.mean(xp.abs(x_full - pf_recon)**2)
    mse_direct = xp.mean(xp.abs(x_full - direct_recon)**2)
    assert mse_pf < 0.25 * mse_direct
コード例 #2
0
def test_mri_object_3d(dtype):
    """The inverse FFT of the sampled 3D k-space should approximate the image.

    The analytic k-space samples are divided by the voxel volume before the
    centered inverse FFT.  The relative L2 error against the analytic image
    must stay below 0.2.
    """
    phantom = mri_object_3d(fov=(120, 120, 40), units="mm", dtype=dtype)
    geom = ImageGeometry((240, 240, 80), distances=(0.5, 0.5, 0.5))

    image = phantom.image(*geom.grid())
    voxel_volume = np.prod(geom.distances)
    kspace = phantom.kspace(*geom.fgrid()) / voxel_volume

    residual = image - ifftnc(kspace)
    rel_err = np.linalg.norm(residual) / np.linalg.norm(image)
    assert rel_err < 0.2
コード例 #3
0
def test_mri_object_1d(dtype):
    """The inverse FFT of the sampled 1D k-space should approximate the image.

    The analytic k-space samples are divided by the sample spacing before the
    centered inverse FFT.  The relative L2 error against the analytic image
    must stay below 0.1.
    """
    phantom = mri_object_1d(fov=(240, ), units="mm", dtype=dtype)
    geom = ImageGeometry((240, ), distances=(1.0, ))

    image = phantom.image(*geom.grid())
    spacing = geom.distances[0]
    kspace = phantom.kspace(*geom.fgrid()) / spacing

    residual = image - ifftnc(kspace)
    rel_err = np.linalg.norm(residual) / np.linalg.norm(image)
    assert rel_err < 0.1
コード例 #4
0
def test_fftnc_ifftnc(xp, norm, axes, pre_shift, post_shift, dtype):
    """Check ``fftnc`` against a NumPy reference and ``ifftnc`` round trips.

    The reference is ``fftshift(np.fft.fftn(ifftshift(x)))`` computed on the
    host.  When ``pre_shift``/``post_shift`` is None, the shifts default to
    the transformed ``axes``.
    """
    rng = xp.random.RandomState(1234)
    # include at least one odd-length axis so shift conventions matter
    shape = (15, 32, 18)
    x = rng.standard_normal(shape).astype(dtype, copy=False)
    if x.dtype.kind == "c":
        x += 1j * rng.standard_normal(shape).astype(dtype, copy=False)

    single_precision = dtype in [np.float32, np.complex64]
    rtol = atol = 1e-4 if single_precision else 1e-12

    # by default, every transformed axis is also shifted
    pre_axes = axes if pre_shift is None else pre_shift
    post_axes = axes if post_shift is None else post_shift

    # reference result computed with numpy.fft.fftn on the CPU
    host_x = x if xp is np else x.get()
    shifted_in = ifftshift(host_x, axes=pre_axes)
    spectrum = np.fft.fftn(shifted_in, axes=axes, norm=norm)
    expected = fftshift(spectrum, axes=post_axes)

    forward = fftnc(
        x,
        axes=axes,
        norm=norm,
        pre_shift_axes=pre_shift,
        post_shift_axes=post_shift,
    )
    xp.testing.assert_allclose(forward, expected, rtol=rtol, atol=atol)

    # the inverse swaps the roles of the pre/post shift axes
    restored = ifftnc(
        forward,
        axes=axes,
        norm=norm,
        pre_shift_axes=post_shift,
        post_shift_axes=pre_shift,
    )
    xp.testing.assert_allclose(x, restored, rtol=rtol, atol=atol)
コード例 #5
0
def test_partial_FFT_with_im_mask(xp, nd_in, order, shift):
    """ masked FFT with missing samples and masked image domain """
    c = get_data(xp)
    rstate = xp.random.RandomState(1234)
    # random k-space sampling pattern with the same shape as the test image
    sample_mask = rstate.rand(*c.shape) > 0.5
    x, y = xp.meshgrid(
        xp.arange(-c.shape[0] // 2, c.shape[0] // 2),
        xp.arange(-c.shape[1] // 2, c.shape[1] // 2),
        indexing="ij",
        sparse=True,
    )

    # make a circular mask
    im_mask = xp.sqrt(x * x + y * y) < c.shape[0] // 2

    nd_out = False
    FTop = FFT_Operator(
        c.shape,
        order=order,
        im_mask=im_mask,
        use_fft_shifts=shift,
        nd_input=nd_in,
        nd_output=nd_out,
        sample_mask=sample_mask,
        gpu_force_reinit=False,
        mask_kspace_on_gpu=(not shift),
        **get_loc(xp),
    )

    # reference transforms matching the operator's shift convention
    if shift:
        fwd, inv = fftnc, ifftnc
    else:
        fwd, inv = fftn, ifftn

    # create new linear operator for forward followed by inverse transform
    FtF = FTop.H * FTop
    assert isinstance(FtF, LinearOperatorMulti)

    # test forward only
    forw = embed(FTop * masker(c, im_mask, order=order),
                 sample_mask,
                 order=order)
    expected_forw = sample_mask * fwd(c * im_mask)
    xp.testing.assert_allclose(forw, expected_forw, rtol=1e-7, atol=1e-4)

    # test roundtrip
    roundtrip = FTop.H * (FTop * masker(c, im_mask, order=order))
    expected_roundtrip = masker(inv(sample_mask * fwd(c * im_mask)),
                                im_mask,
                                order=order)
    xp.testing.assert_allclose(roundtrip,
                               expected_roundtrip,
                               rtol=1e-7,
                               atol=1e-4)

    # test roundtrip with 2 reps stacked along a trailing axis
    c2 = xp.stack([c] * 2, axis=-1)
    roundtrip = FTop.H * (FTop * masker(c2, im_mask, order=order))
    expected_roundtrip = masker(
        inv(
            sample_mask[..., xp.newaxis] *
            fwd(c2 * im_mask[..., xp.newaxis], axes=(0, 1)),
            axes=(0, 1),
        ),
        im_mask,
        order=order,
    )
    xp.testing.assert_allclose(roundtrip,
                               expected_roundtrip,
                               rtol=1e-7,
                               atol=1e-4)
コード例 #6
0
def mri_partial_fourier_nd(
    partial_kspace,
    pf_mask,
    niter=5,
    tw_inner=8,
    tw_outer=3,
    fill_conj=False,
    init=None,
    verbose=False,
    show=False,
    return_all_estimates=False,
    xp=None,
):
    """Partial Fourier reconstruction.

    Parameters
    ----------
    partial_kspace : array
        The acquired k-space samples.  Its total size must equal the number
        of nonzero entries in ``pf_mask``.  The samples are placed onto the
        full k-space grid with ``embed(partial_kspace, pf_mask)``.
    pf_mask : array
        Mask over the full k-space grid marking the acquired locations.  It
        is converted to boolean, and every dimension must have even length.
    niter : int, optional
        Number of phase-constrained (POCS-style) iterations.
    tw_inner : int, optional
        Taper width passed to ``hanning_apodization_window`` for the
        symmetric low-resolution window used to estimate the image phase.
    tw_outer : int, optional
        Taper width passed to ``hanning_apodization_window`` for the window
        that blends acquired data with synthesized k-space.
    fill_conj : bool, optional
        Currently unused.
    init : array, optional
        Currently unused.  The initial estimate is always the zero-filled
        reconstruction.
    verbose : bool, optional
        If True, print the relative change introduced by the phase
        constraint at each iteration.
    show : bool, optional
        If True and the data are 2D, plot the masks, windows and initial
        estimate with matplotlib.
    return_all_estimates : bool, optional
        If True, return every image estimate instead of only the last one.
    xp : module, optional
        Array module (e.g. numpy or cupy).  If None, it is determined by
        ``get_array_module`` from ``partial_kspace``.

    Returns
    -------
    img_est : array or list of arrays
        The reconstructed image on the ``pf_mask.shape`` grid.  If
        ``return_all_estimates`` is True, a list of ``niter + 1`` estimates
        is returned instead: the zero-filled estimate followed by one
        estimate per iteration.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        * A dimension of ``pf_mask`` has odd length.
        * The size of ``partial_kspace`` does not match the number of
          sampled locations.
        * ``pf_mask`` has no sampled locations.

    Notes
    -----
    The implementation is based on a multi-dimensional iterative reconstruction
    technique as described in [1]_.  This is an extension of the 1D iterative
    methods described in [2]_ and [3]_.  The concept of partial Fourier imaging
    was first proposed in [4]_, [5]_.

    NOTE(review): along each axis, the sampled extent is assumed to reach the
    k-space center index ``n // 2``.  Otherwise the low-resolution window
    width computed here is negative.

    NOTE(review): a fully sampled axis has ``nz_min == 0``.  The
    data-consistency window is therefore still tapered at the high-index end
    of that axis.  Confirm this is intended.

    References
    ----------
    .. [1] Xu, Y. and Haacke, E. M.  Partial Fourier imaging in
        multi-dimensions: A means to save a full factor of two in time.
        J. Magn. Reson. Imaging, 2001; 14:628–635.
        doi:10.1002/jmri.1228

    .. [2] Haacke, E.; Lindskog, E. & Lin, W. A fast, iterative,
        partial-Fourier technique capable of local phase recovery.
        Journal of Magnetic Resonance, 1991; 92:126-145.

    .. [3] Liang, Z.-P.; Boada, F.; Constable, R. T.; Haacke, E.; Lauterbur,
        P. & Smith, M. Constrained reconstruction methods in MR imaging.
        Rev Magn Reson Med, 1992; 4: 67-185

    .. [4] Margosian, P.; Schmitt, F. & Purdy, D. Faster MR imaging: imaging
        with half the data. Health Care Instrum, 1986; 1:195.

    .. [5] Feinberg, D. A.; Hale, J. D.; Watts, J. C.; Kaufman, L. & Mark, A.
        Halving MR imaging time by conjugation: demonstration at 3.5 kG.
        Radiology, 1986, 161, 527-531

    """
    xp, on_gpu = get_array_module(partial_kspace, xp)
    partial_kspace = xp.asarray(partial_kspace)

    pf_mask = xp.asarray(pf_mask)
    # builtin ``bool``: ``numpy.bool`` does not exist in NumPy 1.24-1.26
    if pf_mask.dtype != bool:
        pf_mask = pf_mask.astype(bool)
    im_shape = pf_mask.shape
    ndim = pf_mask.ndim

    if any(n % 2 for n in im_shape):
        raise ValueError(
            "This function assumes all k-space dimensions have even length.")

    if partial_kspace.size != int(xp.count_nonzero(pf_mask)):
        raise ValueError(
            "partial kspace should have total size equal to the number of "
            "non-zeros in pf_mask")
    if partial_kspace.size == 0:
        raise ValueError("pf_mask does not contain any sampled locations")

    # zero-filled k-space and the corresponding initial image estimate
    kspace_init = embed(partial_kspace, pf_mask)
    img_est = ifftnc(kspace_init)

    # Per-axis geometry of the sampled region:
    #   lr_*   symmetric low-resolution band centered on n // 2 (phase estimate)
    #   pf_*   bounding slice of the sampled region
    #   win2_* data-consistency window covering the sampled region
    nz = xp.where(pf_mask)
    lr_slices = []
    lr_shape = []
    pf_slices = []
    win2_slices = []
    win2_shape = []
    for d in range(ndim):
        # int() turns the 0-d reductions (possibly on the GPU) into host ints
        nz_min = int(xp.min(nz[d]))
        nz_max = int(xp.max(nz[d]))
        i_mid = im_shape[d] // 2
        if nz_min == 0:
            # sampling starts at index 0: the band is limited by nz_max
            width = nz_max - i_mid
        else:
            # sampling ends at the far side: the band is limited by nz_min
            width = i_mid - nz_min
        lr_slices.append(slice(i_mid - width, i_mid + width + 1))
        lr_shape.append(2 * width + 1)

        pf_len = nz_max - nz_min + 1
        pf_slices.append(slice(nz_min, nz_max + 1))
        win2_shape.append(pf_len + tw_outer)
        # Crop a window that is tw_outer samples longer than the sampled
        # region, so that one of its ends is cut off.  Presumably this keeps
        # a taper only at the interior edge of the sampled region.
        if nz_min == 0:
            win2_slices.append(slice(tw_outer, tw_outer + pf_len))
        else:
            win2_slices.append(slice(pf_len))

    lr_slices = tuple(lr_slices)
    pf_slices = tuple(pf_slices)
    win2_slices = tuple(win2_slices)

    # phase estimate from the windowed, symmetrically sampled center
    lr_win = hanning_apodization_window(lr_shape, tw_inner, xp)
    lr_kspace = xp.zeros_like(kspace_init)
    lr_kspace[lr_slices] = kspace_init[lr_slices] * lr_win
    phi = xp.angle(ifftnc(lr_kspace))

    # weights for blending acquired data with the synthesized k-space
    pf_win = hanning_apodization_window(win2_shape, tw_outer, xp)[win2_slices]
    win2_mask = xp.zeros(im_shape, dtype=xp.float32)
    win2_mask[pf_slices] = pf_win

    if show and ndim == 2:
        from matplotlib import pyplot as plt

        def _host(arr):
            # matplotlib needs host (NumPy) arrays
            return arr.get() if on_gpu else arr

        lr_mask = xp.zeros(im_shape, dtype=xp.float32)
        lr_mask[lr_slices] = lr_win

        _, axes = plt.subplots(2, 2)
        axes = axes.ravel()
        axes[0].imshow(_host(pf_mask))
        axes[0].set_title("PF mask")
        axes[1].imshow(_host(lr_mask))
        axes[1].set_title("LR Filter")
        axes[2].imshow(_host(win2_mask))
        axes[2].set_title("Filter")
        axes[3].imshow(_host(xp.abs(img_est)).T)
        axes[3].set_title("Initial Estimate")
        for ax in axes:
            ax.set_xticklabels("")
            ax.set_yticklabels("")

    if verbose:
        norm0 = xp.linalg.norm(img_est)
        max0 = xp.max(xp.abs(img_est))
    if return_all_estimates:
        all_img_est = [img_est]

    # POCS iterations (step numbers as in the original implementation)
    for _ in range(niter):
        # step 5: impose the low-resolution phase on the current magnitude
        rho1 = xp.abs(img_est) * xp.exp(1j * phi)

        if verbose:
            change2 = xp.linalg.norm(rho1 - img_est) / norm0
            change1 = xp.max(xp.abs(rho1 - img_est)) / max0
            print(f"change = {100 * change2}% {100 * change1}%")

        # step 6: back to k-space
        s1 = fftnc(rho1)

        # steps 7 & 8: keep acquired data (weighted by win2_mask) and fill
        # the remainder from the phase-constrained estimate
        full_kspace = win2_mask * kspace_init + (1 - win2_mask) * s1

        # step 9
        img_est = ifftnc(full_kspace)
        if return_all_estimates:
            all_img_est.append(img_est)

    if return_all_estimates:
        return all_img_est
    return img_est