Beispiel #1
0
def test_masker():
    """Round-trip ``masker``/``embed`` in Fortran order.

    Covers both ``squeeze_output=True`` (1D result) and
    ``squeeze_output=False`` (trailing singleton axis kept).
    """
    expected_vals = [0, 8, 1, 7]

    # squeezed: a flat vector holding the masked values
    flat = masker(arr, mask, order="F", squeeze_output=True)
    testing.assert_array_almost_equal(flat, expected_vals)
    assert flat.shape == (nmask, )
    expected_embed = arr * mask
    testing.assert_array_equal(expected_embed, embed(flat, mask, order="F"))

    # TODO: the order="C" variant is disabled; it would expect the masked
    # values in the order [0, 1, 7, 8].

    # unsqueezed: the same values in a (nmask, 1) column
    column = masker(arr, mask, order="F", squeeze_output=False)
    testing.assert_array_almost_equal(column[:, 0], expected_vals)
    assert column.shape == (nmask, 1)
    testing.assert_array_equal(expected_embed, embed(column, mask, order="F"))
Beispiel #2
0
 def _norm_single_rep(self, x):
     """Chain the forward and adjoint transforms for a single repetition.

     When ``self.sample_mask`` is set, the forward output is embedded back
     onto the full grid before being reshaped to ``self.shape_in``.
     """
     fwd = self._forward_single_rep(x)
     if self.sample_mask is not None:
         fwd = embed(fwd, mask=self.sample_mask, order=self.order,
                     xp=self.xp)
     full = fwd.reshape(self.shape_in, order=self.order)
     return self._adjoint_single_rep(full)
Beispiel #3
0
def test_masker_multidim():
    """``masker``/``embed`` with a trailing repetition axis (Fortran order)."""
    nreps = 8
    stacked = np.tile(arr[..., np.newaxis], (1, 1, nreps))
    masked = masker(stacked, mask, order="F")
    # one column of masked values per repetition
    assert masked.shape == (nmask, nreps)
    testing.assert_array_almost_equal(masked[:, 0], [0, 8, 1, 7])
    testing.assert_array_almost_equal(
        stacked * mask[..., np.newaxis],
        embed(masked, mask, order="F"),
    )
Beispiel #4
0
 def _norm_single_rep(self, x):
     """Chain the forward and adjoint transforms for a single repetition.

     The forward output (embedded onto the full grid when
     ``self.sample_mask`` is set) is reshaped to ``self.shape_in1`` plus an
     axis of length ``self.Ncoils``. That axis leads for ``order="C"`` and
     trails otherwise.
     """
     fwd = self._forward_single_rep(x)
     if self.sample_mask is not None:
         fwd = embed(fwd, mask=self.sample_mask, order=self.order)
     extra = (self.Ncoils, )
     target_shape = (extra + self.shape_in1 if self.order == "C"
                     else self.shape_in1 + extra)
     return self._adjoint_single_rep(
         fwd.reshape(target_shape, order=self.order))
Beispiel #5
0
def test_embed_demo(show_figure=False):
    """Embed 1D, column-stacked 2D and 3D value arrays into a 2D mask.

    Extra axes of the input become trailing axes of the output. Every slice
    along them must match the embedding of the plain 1D vector.
    """
    shape = (512, 500)
    # Use the builtin ``bool``: the ``np.bool`` alias was removed in
    # NumPy 1.24 (AttributeError through 1.26).
    mask = np.ones(shape, dtype=bool)
    mask[:8, :8] = 0

    x = np.arange(mask.sum())
    y1 = embed(x, mask)

    y2 = embed(np.column_stack((x, x, x, x)), mask)

    np.testing.assert_allclose(y1, y2[:, :, 1])

    if show_figure:
        import matplotlib.pyplot as plt

        plt.figure(), plt.imshow(y1.T), plt.show()

    x3 = np.tile(x[:, None, None], (1, 2, 3))
    y3 = embed(x3, mask)
    # one mask-shaped output per element of the trailing input axes
    assert y3.ndim == mask.ndim + x3.ndim - 1

    np.testing.assert_allclose(y1, y3[:, :, -1, -1])
Beispiel #6
0
    def adjoint(self, y):
        """Apply the adjoint transform to ``y``, one repetition at a time.

        If ``self.sample_mask`` is set, ``y`` holds only the sampled values.
        It is reshaped to ``(nmask, -1)`` (F order) or ``(-1, nmask)`` (C
        order) when needed and then embedded onto the full grid. Each
        repetition is reshaped to ``self.shape_in`` and passed to
        ``self._adjoint_single_rep``. Repetitions are stacked on the leading
        axis for ``order="C"`` and on the trailing axis otherwise. The result
        is scaled by ``self.scale_ortho`` when ``self.ortho`` is set, cast to
        ``self.result_dtype``, and reduced to its real part when
        ``self.force_real_image`` is set.
        """
        # TODO: add test for this case and the coils + sample_mask case
        # if y.ndim == 1 or y.shape[-1] == 1:
        xp = self.xp
        # number of repetitions; self.shape[0] is presumably the operator's
        # output length for one repetition (verify). int() truncates the
        # float quotient.
        nreps = int(y.size / self.shape[0])
        if self.sample_mask is not None:
            if y.ndim == 1 and self.ndim > 1:
                y = y[:, np.newaxis]
            # NOTE(review): on GPU this may be a 0-d device array rather than
            # a Python int; confirm reshape/comparison below accept it.
            nmask = self.sample_mask.sum()
            if y.shape[0] != nmask:
                # put the masked-sample axis where embed expects it
                if self.order == "C":
                    y = y.reshape((-1, nmask), order=self.order)
                else:
                    y = y.reshape((nmask, -1), order=self.order)
            y = embed(y, mask=self.sample_mask, order=self.order)
        if nreps == 1:
            # 1D or single repetition or single coil nD
            y = y.reshape(self.shape_in, order=self.order)

            x = self._adjoint_single_rep(y)
        else:
            if self.order == "C":
                shape_tmp = (nreps, ) + self.shape_in
            else:
                shape_tmp = self.shape_in + (nreps, )

            y = y.reshape(shape_tmp, order=self.order)
            # output is at least complex64 even for real y
            x = xp.zeros(shape_tmp, dtype=xp.result_type(y, np.complex64))
            if self.order == "C":
                for rep in range(nreps):
                    x[rep, ...] = self._adjoint_single_rep(y[rep, ...])
            else:
                for rep in range(nreps):
                    x[..., rep] = self._adjoint_single_rep(y[..., rep])

        #        if self.im_mask:
        #            x = masker(x, self.im_mask, order=self.order)

        if self.ortho:
            x *= self.scale_ortho
        if x.dtype != self.result_dtype:
            x = x.astype(self.result_dtype)
        if self.force_real_image:
            x = x.real.astype(self.result_dtype)
        return x
Beispiel #7
0
def test_TV_masked(xp):
    """TV_Operator on 1D (masked) inputs with a circular image-domain mask.

    Runs with and without an output mask (Fortran order only). It checks
    the size of the 1D output before exercising the adjoint.
    """
    # alternative: a random image-domain mask
    # im_mask = xp.random.standard_normal(c.shape) > -0.2

    # test 1D in/outputs but using masked values for the input array
    img = xp.random.randn(16, 16)
    nx, ny = img.shape
    gx, gy = xp.meshgrid(
        xp.arange(-nx // 2, nx // 2),
        xp.arange(-ny // 2, ny // 2),
        indexing="ij",
        sparse=True,
    )

    # circular support centred in the image
    im_mask = xp.sqrt(gx * gx + gy * gy) < nx // 2

    # TODO: order='C' case is currently broken
    for order, mask_out in product(["F"], [None, im_mask]):
        img_masked = masker(img, im_mask, order=order)
        Wm = TV_Operator(
            img.shape,
            order=order,
            mask_in=im_mask,
            mask_out=mask_out,
            nd_input=True,
            nd_output=False,
            random_shift=True,
            **get_loc(xp),
        )
        out = Wm * img_masked
        assert_(out.ndim == 1)

        if mask_out is None:
            assert_(out.size == img.ndim * img.size)
            out = out.reshape(img.shape + (img.ndim, ), order=order)
        else:
            n_out = mask_out.sum()
            if xp is not np:
                n_out = n_out.get()  # device scalar -> host
            assert_(out.size == Wm.ndim * n_out)
            out = embed(out, mask_out, order=order)

        Wm.H * (Wm * img_masked)
Beispiel #8
0
def test_FiniteDifference_masked(xp, shape):
    """FiniteDifferenceOperator on masked input with a circular mask.

    Runs with and without an output mask (Fortran order only). It checks
    the 1D output size against ``Wm.num_offsets`` before exercising the
    adjoint.
    """
    rng = xp.random.RandomState(1234)
    img = rng.randn(*shape)
    nx, ny = img.shape[0], img.shape[1]
    gx, gy = xp.meshgrid(
        xp.arange(-nx // 2, nx // 2),
        xp.arange(-ny // 2, ny // 2),
        indexing="ij",
        sparse=True,
    )

    # circular support centred in the image
    im_mask = xp.sqrt(gx ** 2 + gy ** 2) < nx // 2

    # TODO: order='C' case is currently broken
    for order, mask_out in product(["F"], [None, im_mask]):
        img_masked = masker(img, im_mask, order=order)
        Wm = FiniteDifferenceOperator(
            img.shape,
            order=order,
            use_corners=True,
            mask_in=im_mask,
            mask_out=mask_out,
            nd_input=True,
            nd_output=mask_out is not None,
            random_shift=True,
            **get_loc(xp),
        )
        out = Wm * img_masked
        assert_(out.ndim == 1)

        if mask_out is None:
            assert_(out.size == Wm.num_offsets * img.size)
            out = out.reshape(img.shape + (Wm.num_offsets,), order=order)
        else:
            assert_(out.size == Wm.num_offsets * mask_out.sum())
            out = embed(out, mask_out, order=order)

        Wm.H * (Wm * img_masked)
Beispiel #9
0
    def forward(self, x):
        """Apply the forward transform, looping over repetitions if present.

        Raises ValueError when ``x.size`` is not a multiple of
        ``self.nargin``. If ``self.mask_in`` is set and ``self.nd_input`` is
        False, ``x`` is first embedded onto the full grid. If
        ``self.mask_out`` is set and ``self.nd_output`` is False, the result
        is compressed with ``masker``. Repetitions are stacked on the leading
        axis for ``order="C"`` and on the trailing axis otherwise.
        """
        xp = self.xp
        if x.size % self.nargin:
            raise ValueError("shape mismatch for forward DWT")
        if self.mask_in is not None and not self.nd_input:
            x = embed(x, self.mask_in, order=self.order)

        n_reps = x.size // self.nargin
        if n_reps <= 1:
            y = self._forward1(x)
        elif self.order == "C":
            y = xp.zeros((n_reps, ) + self.coeff_arr_shape, dtype=x.dtype)
            for rep in range(n_reps):
                y[rep, ...] = self._forward1(x[rep, ...])
        else:
            y = xp.zeros(self.coeff_arr_shape + (n_reps, ), dtype=x.dtype)
            for rep in range(n_reps):
                y[..., rep] = self._forward1(x[..., rep])

        if self.mask_out is not None and not self.nd_output:
            y = masker(y, self.mask_out, order=self.order)
        return y
Beispiel #10
0
 def adjoint(self, coeffs):
     """Apply the adjoint transform to ``coeffs``, looping over repetitions.

     Raises ValueError when ``coeffs.size`` is not a multiple of
     ``self.nargout``. If ``self.mask_out`` is set and ``self.nd_output`` is
     False, ``coeffs`` is embedded onto the full coefficient grid and then
     flattened. If ``self.mask_in`` is set and ``self.nd_input`` is False,
     the result is compressed with ``masker``.
     """
     xp = self.xp
     if coeffs.size % self.nargout != 0:
         raise ValueError("shape mismatch for adjoint DWT")
     # repetition count is taken from the (possibly masked) input size
     nrepetitions = coeffs.size // self.nargout
     if (self.mask_out is not None) and (not self.nd_output):
         coeffs = embed(coeffs, self.mask_out, order=self.order)
         # NOTE(review): coeffs is 1D after this ravel, so the
         # coeffs[rep, ...] / coeffs[..., rep] indexing below would select
         # single elements when nrepetitions > 1 -- confirm this path.
         coeffs = coeffs.ravel(order=self.order)
     if nrepetitions > 1:
         # repetitions on the leading axis for C order, trailing otherwise
         if self.order == "C":
             x = xp.zeros((nrepetitions, ) + self.arr_shape,
                          dtype=coeffs.dtype)
             for rep in range(nrepetitions):
                 x[rep, ...] = self._adjoint1(coeffs[rep, ...])
         else:
             x = xp.zeros(self.arr_shape + (nrepetitions, ),
                          dtype=coeffs.dtype)
             for rep in range(nrepetitions):
                 x[..., rep] = self._adjoint1(coeffs[..., rep])
     else:
         x = self._adjoint1(coeffs)
     if (self.mask_in is not None) and (not self.nd_input):
         x = masker(x, self.mask_in, order=self.order)
     return x
Beispiel #11
0
    def norm(self, x):
        """Apply the normal operator (adjoint after forward) to ``x``.

        When a Toeplitz kernel ``self.Q`` exists, the result is computed as:
        a zero-padded FFT of ``x`` to ``self.Q.shape``, a pointwise product
        with ``self.Q``, an inverse FFT, and a crop back to ``self.Nd``.
        Otherwise ``self.H * (self * x)`` is evaluated directly.

        Parameters
        ----------
        x : array
            Input image(s), converted with ``complexify``. When ``self.masked``
            is set, ``x`` is first embedded onto the full grid with
            ``self.mask``. Its size must be a multiple of ``self.nargin``;
            each multiple is one repetition.

        Returns
        -------
        y : array
            The normal-operator output.

        Raises
        ------
        ValueError
            If ``x.size`` is not a multiple of ``self.nargin``.
        """
        # if not hasattr(self, 'Q') or self.Q is None:
        #     warnings.warn("Toeplitz Q did not exist, creating it...")
        #     self.prep_toeplitz()

        x = complexify(x)
        if self.masked:
            x = embed(x, self.mask, order=self.order)
        if x.size % self.nargin != 0:
            raise ValueError("wrong size input")
        Nrepetitions = x.size // self.nargin

        if not hasattr(self, "Q"):
            return self.H * (self * x)

        ndim_img = len(self.Nd)
        # Crop the zero-padded result back to the image extent. A tuple
        # shorter than the array's ndim leaves trailing axes untouched.
        # Indexing with a *list* of slices is an error in NumPy >= 1.23.
        crop = tuple(slice(n) for n in self.Nd)

        if Nrepetitions == 1:
            y = fftn(x, s=self.Q.shape)
            y *= self.Q
            return ifftn(y)[crop]

        if self.order == "C":
            # repetitions on the leading axis: transform only the image axes
            # that follow it, and broadcast Q over that leading axis
            x_shape = (Nrepetitions,) + tuple(self.Nd)
            fft_axes = tuple(range(1, ndim_img + 1))
            Q = self.Q[np.newaxis, ...]
            crop = (slice(None),) + crop
        else:
            # repetitions on the trailing axis
            x_shape = tuple(self.Nd) + (Nrepetitions,)
            fft_axes = tuple(range(ndim_img))
            Q = self.Q[..., np.newaxis]
        x = x.reshape(x_shape, order=self.order)
        y = fftn(x, s=self.Q.shape, axes=fft_axes)
        y *= Q
        return ifftn(y, axes=fft_axes)[crop]
Beispiel #12
0
def test_partial_FFT_with_im_mask(xp, nd_in, order, shift):
    """Masked FFT with missing k-space samples and a masked image domain."""
    c = get_data(xp)
    rstate = xp.random.RandomState(1234)
    sample_mask = rstate.rand(128, 127) > 0.5
    nx, ny = c.shape[0], c.shape[1]
    gx, gy = xp.meshgrid(
        xp.arange(-nx // 2, nx // 2),
        xp.arange(-ny // 2, ny // 2),
        indexing="ij",
        sparse=True,
    )

    # circular image-domain support
    im_mask = xp.sqrt(gx * gx + gy * gy) < nx // 2

    FTop = FFT_Operator(
        c.shape,
        order=order,
        im_mask=im_mask,
        use_fft_shifts=shift,
        nd_input=nd_in,
        nd_output=False,
        sample_mask=sample_mask,
        gpu_force_reinit=False,
        mask_kspace_on_gpu=not shift,
        **get_loc(xp),
    )

    # composing forward and adjoint must yield a multi-operator
    FtF = FTop.H * FTop
    assert isinstance(FtF, LinearOperatorMulti)

    # reference transforms: centred (shifted) or plain FFTs
    if shift:
        fwd, inv = fftnc, ifftnc
    else:
        fwd, inv = fftn, ifftn
    tols = dict(rtol=1e-7, atol=1e-4)

    # forward only
    forw = embed(FTop * masker(c, im_mask, order=order),
                 sample_mask,
                 order=order)
    xp.testing.assert_allclose(forw, sample_mask * fwd(c * im_mask), **tols)

    # forward followed by adjoint
    roundtrip = FTop.H * (FTop * masker(c, im_mask, order=order))
    expected = masker(inv(sample_mask * fwd(c * im_mask)),
                      im_mask,
                      order=order)
    xp.testing.assert_allclose(roundtrip, expected, **tols)

    # same roundtrip with 2 repetitions stacked on a trailing axis
    c2 = xp.stack([c] * 2, axis=-1)
    roundtrip = FTop.H * (FTop * masker(c2, im_mask, order=order))
    kspace2 = fwd(c2 * im_mask[..., xp.newaxis], axes=(0, 1))
    expected = masker(
        inv(sample_mask[..., xp.newaxis] * kspace2, axes=(0, 1)),
        im_mask,
        order=order,
    )
    xp.testing.assert_allclose(roundtrip, expected, **tols)
Beispiel #13
0
    def adjoint(self, y):
        """Apply the multi-coil / multi-map adjoint transform to ``y``.

        If ``self.sample_mask`` is set, ``y`` holds only the sampled values.
        It is reshaped to ``(nmask, -1)`` (F order) or ``(-1, nmask)`` (C
        order) when needed and then embedded onto the full grid. Each
        repetition is reshaped to ``self.shape_in1`` plus an axis of length
        ``self.Ncoils``. ``self._adjoint_single_rep`` is then called once per
        map index (``self.Nmaps`` maps). The map and repetition axes lead for
        ``order="C"`` and trail otherwise. The result is scaled by
        ``self.scale_ortho`` when ``self.ortho`` is set and cast to
        ``self.result_dtype``. Its imaginary part is zeroed in place when
        ``self.force_real_image`` is set and the dtype is complex.
        """
        # TODO: add test for this case and the coils + sample_mask case
        # if y.ndim == 1 or y.shape[-1] == 1:
        xp = self.xp
        ncoils = self.Ncoils
        # number of repetitions; self.shape[0] is presumably the operator's
        # output length for one repetition (verify). int() truncates the
        # float quotient.
        nreps = int(y.size / self.shape[0])
        if self.sample_mask is not None:
            if y.ndim == 1 and self.ndim > 1:
                y = y[:, np.newaxis]
            nmask = xp.count_nonzero(self.sample_mask)
            if self.on_gpu:
                # device scalar -> host int for the reshapes below
                nmask = nmask.get()
            if y.shape[0] != nmask:
                if self.order == "C":
                    y = y.reshape((-1, nmask), order=self.order)
                else:
                    y = y.reshape((nmask, -1), order=self.order)
            y = embed(y, mask=self.sample_mask, order=self.order)
        if nreps == 1:
            # 1D or single repetition or single coil nD
            if self.order == "C":
                y = y.reshape((ncoils, ) + self.shape_in1, order=self.order)
            else:
                y = y.reshape(self.shape_in1 + (ncoils, ), order=self.order)
            if self.Nmaps == 1:
                x = self._adjoint_single_rep(y, i_map=0)
            else:
                # one output slice per map: leading axis (C) / trailing (F)
                x = xp.zeros(self.shape_inM,
                             dtype=xp.result_type(y, xp.complex64))
                if self.order == "C":
                    for i_map in range(self.Nmaps):
                        x[i_map, ...] = self._adjoint_single_rep(y,
                                                                 i_map=i_map)
                else:
                    for i_map in range(self.Nmaps):
                        x[..., i_map] = self._adjoint_single_rep(y,
                                                                 i_map=i_map)
        else:
            if self.order == "C":
                y = y.reshape((nreps, ncoils) + self.shape_in1,
                              order=self.order)  # or shape_out?
                x_shape = (nreps, ) + self.shape_inM
            else:
                y = y.reshape(self.shape_in1 + (ncoils, nreps),
                              order=self.order)  # or shape_out?
                x_shape = self.shape_inM + (nreps, )

            # NOTE(review): the indexing below assumes shape_inM carries a
            # map axis (leading for C, trailing for F) even when Nmaps == 1;
            # confirm against how shape_inM is built.
            x = xp.zeros(x_shape, dtype=xp.result_type(y, xp.complex64))
            for i_map in range(self.Nmaps):
                if self.order == "C":
                    for rep in range(nreps):
                        x[rep, i_map,
                          ...] = self._adjoint_single_rep(y[rep, ...],
                                                          i_map=i_map)
                else:
                    for rep in range(nreps):
                        x[..., i_map,
                          rep] = self._adjoint_single_rep(y[..., rep],
                                                          i_map=i_map)
        #        if self.im_mask:
        #            x = masker(x, self.im_mask, order=self.order)
        if self.ortho:
            x *= self.scale_ortho
        if x.dtype != self.result_dtype:
            x = x.astype(self.result_dtype)
        if self.force_real_image:
            # x = x.real.astype(self.result_dtype)
            # zero the imaginary part in place, keeping the complex dtype
            if x.dtype in [np.complex64, np.complex128]:
                x.imag[:] = 0
        return x
Beispiel #14
0
def mri_partial_fourier_nd(
    partial_kspace,
    pf_mask,
    niter=5,
    tw_inner=8,
    tw_outer=3,
    fill_conj=False,
    init=None,
    verbose=False,
    show=False,
    return_all_estimates=False,
    xp=None,
):
    """Partial Fourier reconstruction.

    Iteratively estimates the unsampled part of k-space. Each iteration:

    1. Imposes the phase of a low-resolution image, reconstructed from the
       central k-space region lying within the sampled extent, on the
       current image magnitude.
    2. Transforms the result to k-space.
    3. Blends it with the measured samples through an apodized window.

    Parameters
    ----------
    partial_kspace : array
        The measured k-space samples. Its total size must equal the number
        of non-zero entries in ``pf_mask``. The samples are placed onto the
        full grid with ``embed(partial_kspace, pf_mask)``.
    pf_mask : array
        Mask of the sampled k-space locations, converted to boolean. Its
        shape defines the output shape and every dimension must be even.
    niter : int, optional
        Number of iterations.
    tw_inner : int, optional
        Transition width passed to ``hanning_apodization_window`` for the
        window applied to the central region used for the phase estimate.
    tw_outer : int, optional
        Transition width of the window that blends the measured k-space with
        the k-space of the current estimate.
    fill_conj : bool, optional
        Currently unused.
    init : optional
        Currently unused.
    verbose : bool, optional
        If True, print the relative change of the estimate each iteration.
    show : bool, optional
        If True and the data is 2D, plot the mask, the windows and the
        initial estimate with matplotlib.
    return_all_estimates : bool, optional
        If True, return a list holding the initial estimate followed by the
        estimate after every iteration.
    xp : module, optional
        Array module (NumPy or CuPy). Inferred from ``partial_kspace`` when
        None.

    Returns
    -------
    img_est : array or list of arrays
        The final image estimate (output of ``ifftnc``, shape
        ``pf_mask.shape``). When ``return_all_estimates`` is True, the list
        of all estimates is returned instead.

    Raises
    ------
    ValueError
        If any dimension of ``pf_mask`` has odd length, or if
        ``partial_kspace.size`` differs from the number of non-zeros in
        ``pf_mask``.

    Notes
    -----
    The implementation is based on a multi-dimensional iterative reconstruction
    technique as described in [1]_.  This is an extension of the 1D iterative
    methods described in [2]_ and [3]_.  The concept of partial Fourier imaging
    was first proposed in [4]_, [5]_.

    References
    ----------
    .. [1] Xu, Y. and Haacke, E. M.  Partial Fourier imaging in
        multi-dimensions: A means to save a full factor of two in time.
        J. Magn. Reson. Imaging, 2001; 14:628–635.
        doi:10.1002/jmri.1228

    .. [2] Haacke, E.; Lindskogj, E. & Lin, W. A fast, iterative,
        partial-Fourier technique capable of local phase recovery.
        Journal of Magnetic Resonance, 1991; 92:126-145.

    .. [3] Liang, Z.-P.; Boada, F.; Constable, R. T.; Haacke, E.; Lauterbur,
        P. & Smith, M. Constrained reconstruction methods in MR imaging.
        Rev Magn Reson Med, 1992; 4: 67-185

    .. [4] Margosian, P.; Schmitt, F. & Purdy, D. Faster MR imaging: imaging
        with half the data Health Care Instrum, 1986; 1:195.

    .. [5] Feinberg, D. A.; Hale, J. D.; Watts, J. C.; Kaufman, L. & Mark, A.
        Halving MR imaging time by conjugation: demonstration at 3.5 kG.
        Radiology, 1986, 161, 527-531

    """
    xp, on_gpu = get_array_module(partial_kspace, xp)
    partial_kspace = xp.asarray(partial_kspace)

    pf_mask = xp.asarray(pf_mask)
    if pf_mask.dtype != bool:
        # Use the builtin ``bool``: the ``np.bool`` alias is missing in
        # NumPy 1.24-1.26.
        pf_mask = pf_mask.astype(bool)
    im_shape = pf_mask.shape
    ndim = pf_mask.ndim

    if any(s % 2 for s in im_shape):
        raise ValueError(
            "This function assumes all k-space dimensions have even length.")

    if partial_kspace.size != xp.count_nonzero(pf_mask):
        raise ValueError(
            "partial kspace should have total size equal to the number of "
            "non-zeros in pf_mask")

    kspace_init = embed(partial_kspace, pf_mask)
    img_est = ifftnc(kspace_init)

    lr_kspace = xp.zeros_like(kspace_init)

    nz = xp.where(pf_mask)
    lr_slices = [slice(None)] * ndim
    pf_slices = [slice(None)] * ndim
    win2_slices = [slice(None)] * ndim
    lr_shape = [0] * ndim
    win2_shape = [0] * ndim
    for d in range(ndim):
        nz_min = xp.min(nz[d])
        nz_max = xp.max(nz[d])
        if hasattr(nz_min, "get"):
            # 0-dim GPU array to scalar
            nz_min, nz_max = nz_min.get(), nz_max.get()
        i_mid = im_shape[d] // 2
        # Half-width of the region around the k-space centre (i_mid) that
        # stays inside the sampled extent on both sides.
        if nz_min == 0:
            width = nz_max - i_mid
        else:
            width = i_mid - nz_min
        lr_slices[d] = slice(i_mid - width, i_mid + width + 1)
        lr_shape[d] = 2 * width + 1

        # Sampled extent along this axis. The blending window is tw_outer
        # samples longer and is cropped so its taper falls on the edge
        # facing the unsampled side.
        pf_len = nz_max - nz_min + 1
        pf_slices[d] = slice(nz_min, nz_max + 1)
        win2_shape[d] = pf_len + tw_outer
        if nz_min == 0:
            win2_slices[d] = slice(tw_outer, tw_outer + pf_len)
        else:
            win2_slices[d] = slice(pf_len)

    lr_slices = tuple(lr_slices)
    win2_slices = tuple(win2_slices)
    pf_slices = tuple(pf_slices)
    lr_win = hanning_apodization_window(lr_shape, tw_inner, xp)
    lr_kspace[lr_slices] = kspace_init[lr_slices] * lr_win

    # phase estimate from the low-resolution image
    img_lr = ifftnc(lr_kspace)
    phi = xp.angle(img_lr)

    pf_win = hanning_apodization_window(win2_shape, tw_outer, xp)[win2_slices]

    lr_mask = xp.zeros(pf_mask.shape, dtype=xp.float32)
    lr_mask[lr_slices] = lr_win

    win2_mask = xp.zeros(pf_mask.shape, dtype=xp.float32)
    win2_mask[pf_slices] = pf_win

    if show and ndim == 2:
        from matplotlib import pyplot as plt

        def _host(a):
            # matplotlib needs host (NumPy) arrays; copy device arrays back
            return a.get() if hasattr(a, "get") else a

        fig, axes = plt.subplots(2, 2)
        axes = axes.ravel()
        axes[0].imshow(_host(pf_mask))
        axes[0].set_title("PF mask")
        axes[1].imshow(_host(lr_mask))
        axes[1].set_title("LR Filter")
        axes[2].imshow(_host(win2_mask))
        axes[2].set_title("Filter")
        axes[3].imshow(_host(xp.abs(img_est)).T)
        axes[3].set_title("Initial Estimate")
        for ax in axes:
            ax.set_xticklabels("")
            ax.set_yticklabels("")

    if verbose:
        norm0 = xp.linalg.norm(img_est)
        max0 = xp.max(xp.abs(img_est))
    if return_all_estimates:
        all_img_est = [img_est]

    # POCS iterations
    for _ in range(niter):
        # step 5: impose the low-resolution phase on the current magnitude
        rho1 = xp.abs(img_est) * xp.exp(1j * phi)

        if verbose:
            change2 = xp.linalg.norm(rho1 - img_est) / norm0
            change1 = xp.max(xp.abs(rho1 - img_est)) / max0
            # str.format needs a single '%'; '%%' would print literally
            print("change = {}% {}%".format(100 * change2, 100 * change1))

        # step 6
        s1 = fftnc(rho1)

        # steps 7 & 8: measured data where sampled, estimate elsewhere,
        # with a smooth transition given by win2_mask
        full_kspace = win2_mask * kspace_init + (1 - win2_mask) * s1

        # step 9
        img_est = ifftnc(full_kspace)
        if return_all_estimates:
            all_img_est.append(img_est)

    if return_all_estimates:
        return all_img_est
    return img_est
Beispiel #15
0
def test_DWT_masked(xp, order, decimation):
    """MDWT_Operator with an image-domain mask, then with both masks.

    First only the input is masked (1D in/out), and the masked image must
    be recovered exactly. Then an output mask selecting the non-zero
    coefficients is added, and reconstruction is checked again.
    """

    def _count(mask):
        # number of True entries as a host scalar (CuPy sums stay on device)
        n = mask.sum()
        return n if xp is np else n.get()

    rstate = xp.random.RandomState(5)
    c = rstate.randn(128, 128)
    # alternative: a random image-domain mask
    # im_mask = xp.random.standard_normal(c.shape) > -0.2

    # test 1D in/outputs but using masked values for the input array
    # (builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24)
    im_mask = xp.ones(c.shape, dtype=bool)
    im_mask[16:64, 16:64] = 0  # mask out a region
    c_masked = masker(c, im_mask, order=order)
    fb = filters.pywt_as_filterbank("db2", xp=xp, decimation=decimation)
    Wm = MDWT_Operator(
        c.shape,
        level=2,
        filterbank=fb,
        order=order,
        mode="periodization",
        mask_in=im_mask,
        mask_out=None,
        nd_input=False,
        nd_output=False,
        random_shift=True,
        **get_loc(xp),
    )
    coeffs = Wm * c_masked
    assert_(coeffs.ndim == 1)
    if decimation == 2:
        # critically sampled: as many coefficients as pixels
        assert_(coeffs.size == c.size)
        coeffs = coeffs.reshape(c.shape, order=order)
    else:
        assert_(coeffs.size > c.size)

    c_recon = Wm.H * coeffs
    assert_(c_recon.ndim == 1)
    assert_(c_recon.size == _count(im_mask))
    c_recon = embed(c_recon, im_mask, order=order)
    xp.testing.assert_allclose(c_recon, c * im_mask, rtol=1e-9, atol=1e-9)

    # test 1D in/outputs but using masked values for both input and output
    # arrays
    coeffs_mask = coeffs != 0  # mask out regions of zero-valued coeffs
    Wm2 = MDWT_Operator(
        c.shape,
        level=2,
        filterbank=fb,
        order=order,
        mode="periodization",
        mask_in=im_mask,
        mask_out=coeffs_mask,
        nd_input=False,
        nd_output=False,
        random_shift=True,
        **get_loc(xp),
    )
    coeffs = Wm2 * c_masked
    assert_(coeffs.ndim == 1)
    if decimation == 2:
        assert_(coeffs.size == _count(coeffs_mask))
    c_recon = Wm2.H * coeffs
    c_recon = embed(c_recon, im_mask, order=order)
    xp.testing.assert_allclose(c_recon, c * im_mask, rtol=1e-9, atol=1e-9)
def _test_mri_multi(
    ndim=3,
    N0=8,
    grid_os_factor=1.5,
    J0=4,
    Ld=4096,
    n_coils=1,
    fieldmap_segments=None,
    precisions=("single", "double"),
    phasings=("real", "complex"),
    recon_cases=("CPU,Tab0", "CPU,Tab", "CPU,Sp"),
    rtol=1e-3,
    compare_to_exact=False,
    show_figures=False,
    nufft_kwargs=None,
    navg_time=1,
    n_creation=1,
    return_errors=False,
    gpu_memflags=None,
    verbose=False,
    return_operator=False,
    spectral_offsets=None,
):
    """Run a batch of NUFFT tests.

    For every combination of ``recon_cases`` x ``precisions`` x ``phasings``
    a simulated MRI operator is created with ``generate_sim_data``. Its
    forward, norm and adjoint operations are timed, averaged over the
    requested number of calls after one dry run, and the output dtypes are
    checked. When ``compare_to_exact`` is set, the forward and adjoint
    results are also compared with the exact ``dtft`` / ``dtft_adj``.

    Parameters
    ----------
    ndim, N0, grid_os_factor, J0, Ld, n_coils, fieldmap_segments
        Forwarded to ``generate_sim_data``.
    precisions : sequence of {"single", "double"}
    phasings : sequence of {"real", "complex"}
    recon_cases : sequence of str
        Case names. Names containing "CPU" use the CPU averaging count. Names
        containing "Tab" are created ``n_creation`` times (e.g. to exclude
        one-time kernel compilation from benchmarks).
    rtol : float
        Maximum relative error allowed in the exact comparisons.
    compare_to_exact : bool
        Compare against the exact DTFT (no spectral offsets supported).
    show_figures : bool
        Display the adjoint estimates with ``pyvolplot.volshow``.
    nufft_kwargs : dict or None
        Forwarded to ``generate_sim_data``; None means an empty dict.
    navg_time : int or pair of int
        Number of timed calls; a pair gives separate (CPU, GPU) counts.
    n_creation : int
        Number of operator creations for "Tab" cases.
    return_errors : bool
        Also return the forward/adjoint relative-error arrays.
    gpu_memflags
        Passed to the operator via ``MRI_object_kwargs``.
    verbose : bool
        Print progress and the adjoint errors.
    return_operator : bool
        Also return the last operator created.
    spectral_offsets : sequence or None
        Forwarded to ``generate_sim_data``.

    Returns
    -------
    alltimes : dict
        Timing dictionaries keyed by "recon_case, precision, phasing".
        Depending on the flags this is preceded by the operator and/or
        followed by ``all_err_forward, all_err_adj``.
    """

    def _to_host(v):
        # GPU reductions return device scalars; NumPy arrays need host values
        return v.get() if hasattr(v, "get") else v

    if nufft_kwargs is None:
        nufft_kwargs = {}
    err_shape = (len(recon_cases), len(precisions), len(phasings))
    all_err_forward = np.zeros(err_shape)
    all_err_adj = np.zeros(err_shape)
    alltimes = {}
    if np.isscalar(navg_time):
        navg_time_cpu = navg_time_gpu = navg_time
    else:
        navg_time_cpu, navg_time_gpu = navg_time
    for i, recon_case in enumerate(recon_cases):
        n_avg = navg_time_cpu if "CPU" in recon_case else navg_time_gpu

        for j, precision in enumerate(precisions):
            for k, phasing in enumerate(phasings):
                if verbose:
                    print(
                        "phasing={}, precision={}, type={}".format(
                            phasing, precision, recon_case
                        )
                    )

                # may want to create twice when benchmarking GPU case
                # because the custom kernels are compiled the first time
                ncr_max = n_creation if "Tab" in recon_case else 1
                for _ in range(ncr_max):
                    (
                        Gn,
                        wi_full,
                        xTrue,
                        ig,
                        data_true,
                        times,
                    ) = generate_sim_data(
                        recon_case=recon_case,
                        ndim=ndim,
                        N0=N0,
                        J0=J0,
                        grid_os_factor=grid_os_factor,
                        fieldmap_segments=fieldmap_segments,
                        Ld=Ld,
                        n_coils=n_coils,
                        precision=precision,
                        phasing=phasing,
                        nufft_kwargs=nufft_kwargs,
                        MRI_object_kwargs=dict(gpu_memflags=gpu_memflags),
                        spectral_offsets=spectral_offsets,
                    )

                xp = Gn.xp

                # time the forward operator
                sim_data = Gn * xTrue  # dry run
                tstart = time.perf_counter()
                for _ in range(n_avg):
                    sim_data = Gn * xTrue
                    sim_data += 0.0
                sim_data = xp.squeeze(sim_data)  # TODO: should be 1D already?
                t_for = (time.perf_counter() - tstart) / n_avg
                times["MRI: forward"] = t_for

                # time the norm operator
                Gn.norm(xTrue)  # dry run
                tstart = time.perf_counter()
                for _ in range(n_avg):
                    Gn.norm(xTrue)
                t_norm = (time.perf_counter() - tstart) / n_avg

                times["MRI: norm"] = t_norm
                if precision == "single":
                    dtype_real = np.float32
                    dtype_cplx = np.complex64
                else:
                    dtype_real = np.float64
                    dtype_cplx = np.complex128

                # real phasing should keep real interpolation coefficients
                if "Tab" in recon_case:
                    if phasing == "complex":
                        assert_equal(Gn.Gnufft.h[0].dtype, dtype_cplx)
                    else:
                        assert_equal(Gn.Gnufft.h[0].dtype, dtype_real)
                else:
                    if phasing == "complex":
                        assert_equal(Gn.Gnufft.p.dtype, dtype_cplx)
                    else:
                        assert_equal(Gn.Gnufft.p.dtype, dtype_real)
                assert_equal(sim_data.dtype, dtype_cplx)

                if compare_to_exact:
                    # compare_to_exact only currently for single-coil,
                    # no fieldmap case
                    if spectral_offsets is not None:
                        raise NotImplementedError(
                            "compare_to_exact doesn't currently support "
                            "spectral offsets"
                        )
                    nshift_exact = tuple(s / 2 for s in Gn.Nd)
                    sim_data2 = dtft(
                        xTrue, Gn.omega, shape=Gn.Nd, n_shift=nshift_exact
                    )

                    sd2_norm = xp.linalg.norm(sim_data2)
                    rel_err = xp.linalg.norm(sim_data - sim_data2) / sd2_norm
                    rel_err = _to_host(rel_err)
                    all_err_forward[i, j, k] = rel_err
                    print(
                        "{},{},{}: forward error = {}".format(
                            recon_case, precision, phasing, rel_err
                        )
                    )
                    rel_err_mag = (
                        xp.linalg.norm(xp.abs(sim_data) - xp.abs(sim_data2))
                        / sd2_norm
                    )
                    print(
                        f"{recon_case},{precision},{phasing}: "
                        f"forward mag diff error = {rel_err_mag}"
                    )
                    assert rel_err < rtol

                # TODO: update DiagonalOperatorMulti to auto-set loc_in,
                #       loc_out appropriately
                if xp is np:
                    diag_args = dict(loc_in="cpu", loc_out="cpu")
                else:
                    diag_args = dict(loc_in="gpu", loc_out="gpu")
                diag_op = DiagonalOperatorMulti(wi_full, **diag_args)
                if n_coils == 1:
                    data_dcf = diag_op * data_true
                else:
                    data_dcf = diag_op * sim_data

                # time the adjoint operation
                im_est = Gn.H * data_dcf  # dry run
                tstart = time.perf_counter()
                for _ in range(n_avg):
                    im_est = Gn.H * data_dcf
                t_adj = (time.perf_counter() - tstart) / n_avg
                times["MRI: adjoint"] = t_adj

                if hasattr(Gn, "mask") and Gn.mask is not None:
                    im_est = embed(im_est, Gn.mask)
                elif spectral_offsets is None:
                    im_est = im_est.reshape(Gn.Nd, order=Gn.order)
                else:
                    im_est = im_est.reshape(
                        tuple(Gn.Nd) + (len(spectral_offsets),),
                        order=Gn.order,
                    )

                if compare_to_exact:
                    im_est_exact = dtft_adj(
                        data_dcf, Gn.omega, shape=Gn.Nd, n_shift=nshift_exact
                    )
                    ex_norm = xp.linalg.norm(im_est_exact)
                    rel_err = xp.linalg.norm(im_est - im_est_exact) / ex_norm
                    # move to host before storing in the NumPy error array
                    rel_err = _to_host(rel_err)
                    all_err_adj[i, j, k] = rel_err
                    if verbose:
                        print(
                            "{},{},{}: adjoint error = {}".format(
                                recon_case, precision, phasing, rel_err
                            )
                        )
                    rel_err_mag = (
                        xp.linalg.norm(xp.abs(im_est) - xp.abs(im_est_exact))
                        / ex_norm
                    )
                    if verbose:
                        print(
                            "{},{},{}: adjoint mag diff error = {}".format(
                                recon_case, precision, phasing, rel_err_mag
                            )
                        )
                    assert_(rel_err < rtol)

                title = ", ".join([recon_case, precision, phasing])
                if show_figures:
                    from matplotlib import pyplot as plt
                    from pyvolplot import volshow

                    if compare_to_exact:
                        volshow(
                            [
                                im_est_exact,
                                im_est,
                                im_est_exact - im_est,
                                xp.abs(im_est_exact) - xp.abs(im_est),
                            ]
                        )
                    else:
                        volshow(im_est)
                        plt.title(title)
                alltimes[title] = times

    if return_operator:
        if return_errors:
            return Gn, alltimes, all_err_forward, all_err_adj
        return Gn, alltimes

    if return_errors:
        return alltimes, all_err_forward, all_err_adj
    return alltimes