示例#1
0
def test_nufft_adj(xp, mode, phasing, verbose=False):
    """Check ``NufftBase.adj`` against the exact ``dtft_adj`` on a 4 x 8 grid.

    Six frequency samples are used (the second axis uses the first axis'
    frequencies in reversed order).  The adjoint is applied to three stacked
    copies of the same data and the last repetition is compared to the
    ``dtft_adj`` reference to 4 decimals.

    Parameters
    ----------
    xp : module
        Array module (``numpy`` or a compatible one).  ``on_gpu`` is enabled
        whenever ``xp`` is not ``numpy``.
    mode, phasing
        Forwarded unchanged to ``NufftBase``.
    verbose : bool, optional
        If True, print the maximum percent difference before asserting.
    """
    Nd = (4, 8)
    n_shift = [2.7, 3.1]  # random shifts to stress it

    # non-uniform sample locations, one column per image axis
    om_axis0 = 2 * np.pi * xp.array([0.0, 0.1, 0.3, 0.4, 0.7, 0.9])
    om_axis1 = om_axis0[::-1].copy()
    omega = xp.stack((om_axis0, om_axis1), axis=-1)

    st = NufftBase(
        omega=omega,
        Nd=Nd,
        Jd=[8, 8],
        Kd=2 * np.array(Nd),
        Ld=1024,
        n_shift=n_shift,
        phasing=phasing,
        mode=mode,
        on_gpu=xp != np,
    )

    # test spectrum: 1, 4, 9, ... (one value per frequency sample)
    data = xp.arange(1, om_axis0.size + 1).ravel()**2
    expected = dtft_adj(data, omega, shape=Nd, n_shift=n_shift, xp=xp)

    # repetitions are stacked along the last axis
    stacked = xp.stack([data, data, data], axis=-1)
    last_rep = st.adj(stacked)[:, :, -1]

    if verbose:
        print("nufft vs dtft max%%diff = %g" %
              max_percent_diff(expected, last_rep))
    xp.testing.assert_array_almost_equal(xp.squeeze(expected),
                                         xp.squeeze(last_rep),
                                         decimal=4)
示例#2
0
def test_nufft_2d(xp, mode, precision, phasing, Kd, Jd, order, verbose=False):
    """Compare 2D ``NufftBase`` forward/adjoint transforms to the exact DTFT.

    A complex random 16 x 16 image is pushed through ``A.fft`` and ``A.adj``
    and checked against ``dtft`` / ``dtft_adj`` (rtol=1e-3, atol=1e-5).  Both
    operations are also run on a stack of two identical repetitions, of which
    the first is compared to the same single-repetition reference.

    Parameters
    ----------
    xp : module
        Array module (``numpy`` or a compatible one).  ``on_gpu`` is enabled
        whenever ``xp`` is not ``numpy``.
    mode, precision, phasing, Kd, Jd
        Forwarded unchanged to ``NufftBase``.
    order : str
        Forwarded to ``NufftBase``.  ``"F"`` stacks repetitions along the last
        axis; any other value stacks them along the first axis.
    verbose : bool, optional
        If True, print the max percent differences of the single-repetition
        forward and adjoint results.
    """
    Nd = (16, 16)
    Ld = 1024
    n_shift = np.asarray(Nd) / 2
    omega = _perturbed_gridpoints(Nd, xp=xp)
    rstate = xp.random.RandomState(1234)

    rtol = 1e-3
    atol = 1e-5
    A = NufftBase(
        omega=omega,
        Nd=Nd,
        Jd=Jd,
        Kd=Kd,
        n_shift=n_shift,
        mode=mode,
        Ld=Ld,
        precision=precision,
        phasing=phasing,
        on_gpu=xp != np,
        order=order,
    )
    x = rstate.standard_normal(Nd)
    x = x + 1j * rstate.standard_normal(Nd)

    # Repetition axis and the index selecting the first repetition; these
    # depend only on ``order`` so they are shared by both stacked checks.
    if order == "F":
        rep_axis = -1
        sl1 = (Ellipsis, 0)
    else:
        rep_axis = 0
        sl1 = (0, Ellipsis)

    # forward
    y = A.fft(x)
    y_true = dtft(x, omega=omega, shape=Nd, n_shift=n_shift)
    xp.testing.assert_allclose(y, y_true, rtol=rtol, atol=atol)

    # multi-repetition forward
    x_reps = xp.stack((x, ) * 2, axis=rep_axis)
    y_fwd_reps = A.fft(x_reps)
    xp.testing.assert_allclose(y_fwd_reps[sl1], y_true, rtol=rtol, atol=atol)

    # adjoint
    x_adj = A.adj(y)
    x_adj_true = dtft_adj(y, omega=omega, shape=Nd, n_shift=n_shift)
    xp.testing.assert_allclose(x_adj, x_adj_true, rtol=rtol, atol=atol)

    # multi-repetition adjoint
    # The stacked result is kept in its own variable so that ``x_adj`` still
    # holds the single-repetition result for the verbose report below.
    y_reps = xp.stack((y, ) * 2, axis=rep_axis)
    x_adj_reps = A.adj(y_reps)
    xp.testing.assert_allclose(
        x_adj_reps[sl1], x_adj_true, rtol=rtol, atol=atol
    )

    if verbose:
        print(mode, precision, phasing)
        print(f"\t{max_percent_diff(y, y_true)} "
              f"{max_percent_diff(x_adj, x_adj_true)}")
示例#3
0
def test_dtft_adj_3d(xp, verbose=False, test_Cython=False):
    """Validate ``dtft_adj`` on a 32 x 16 x 2 grid of uniform frequencies.

    The vectorized result is compared to the ``useloop=True`` variant and to
    a phase-shifted ``ifftn`` scaled by the number of grid points (both with
    atol=1e-7).

    Parameters
    ----------
    xp : module
        Array module providing ``random``, ``fft``, ``dot``, ``exp``, ``tile``
        and ``testing``.
    verbose : bool, optional
        If True, print the max percent differences.
    test_Cython : bool, optional
        If True, additionally import ``mrrt.nufft.cy_dtft.dtft_adj`` and print
        its run time for 1, 16 and 64 repetitions.
    """
    shape = (32, 16, 2)
    shifts = np.asarray([2, 1, 3]).reshape(3, 1)

    rng = xp.random.RandomState(1234)
    # complex test signal: the real part is drawn first, then the imaginary
    spectrum = rng.standard_normal(shape)
    spectrum = spectrum + 1j * rng.standard_normal(shape)

    # test with uniform frequency locations:
    freqs = _uniform_freqs(shape, xp=xp)

    x_direct = dtft_adj(spectrum, freqs, shape, shifts)
    x_loop = dtft_adj(spectrum, freqs, shape, shifts, useloop=True)
    xp.testing.assert_allclose(x_direct, x_loop, atol=1e-7)

    # reference: apply the shift as a phase ramp, then use an inverse FFT
    phase = xp.exp(-1j * xp.dot(freqs, xp.asarray(shifts)))
    phased = spectrum * phase.reshape(spectrum.shape, order="F")
    x_ifft = xp.fft.ifftn(phased) * np.prod(shape)
    xp.testing.assert_allclose(x_direct, x_ifft, atol=1e-7)

    if verbose:
        print("loop max %% difference = %g" %
              max_percent_diff(x_loop, x_direct))
        print("ifft max %% difference = %g" %
              max_percent_diff(x_ifft, x_direct))

    if test_Cython:
        import time
        from mrrt.nufft.cy_dtft import dtft_adj as cy_dtft_adj

        # single repetition
        tic = time.time()
        x_cy = cy_dtft_adj(spectrum.ravel(order="F"), freqs, shape, shifts)
        print("duration (1 rep) = {}".format(time.time() - tic))
        print("ifft max %% difference = %g" % max_percent_diff(x_ifft, x_cy))

        # 16 repetitions (input tiling is not included in the timing)
        reps16 = xp.tile(spectrum.ravel(order="F")[:, None], (1, 16))
        tic = time.time()
        cy_dtft_adj(reps16, freqs, shape, xp.asarray(shifts))
        print("duration (16 reps) = {}".format(time.time() - tic))

        # 64 repetitions (here the timing also covers the tiling and the
        # max_percent_diff call)
        tic = time.time()
        reps64 = xp.tile(spectrum.ravel(order="F")[:, None], (1, 64))
        x_cy64 = cy_dtft_adj(reps64, freqs, shape, xp.asarray(shifts))
        max_percent_diff(x_ifft, x_cy64[..., -1])
        print("duration (64 reps) = {}".format(time.time() - tic))
    # For interactive benchmarking of the non-Cython version, time
    # dtft_adj on the 16-repetition input with IPython's %timeit.
示例#4
0
def test_dtft_adj_1d(xp, verbose=False):
    """Validate ``dtft_adj`` on a length-16 grid of uniform frequencies.

    The vectorized result is compared to the ``useloop=True`` variant and to
    a phase-shifted ``ifftn`` scaled by the number of grid points (both with
    atol=1e-7).

    Parameters
    ----------
    xp : module
        Array module providing ``random``, ``fft``, ``dot``, ``exp`` and
        ``testing``.
    verbose : bool, optional
        If True, print the max percent difference to the ifft reference.
    """
    shape = (16, )
    shifts = np.asarray([2]).reshape(1, 1)

    rng = xp.random.RandomState(1234)
    # complex test signal: the real part is drawn first, then the imaginary
    spectrum = rng.standard_normal(shape)
    spectrum = spectrum + 1j * rng.standard_normal(shape)

    # test with uniform frequency locations:
    freqs = _uniform_freqs(shape, xp=xp)

    x_direct = dtft_adj(spectrum, freqs, shape, shifts)
    x_loop = dtft_adj(spectrum, freqs, shape, shifts, useloop=True)
    xp.testing.assert_allclose(x_direct, x_loop, atol=1e-7)

    # reference: apply the shift as a phase ramp, then use an inverse FFT
    phase = xp.exp(-1j * xp.dot(freqs, xp.asarray(shifts)))
    phased = spectrum * phase.reshape(spectrum.shape, order="F")
    x_ifft = xp.fft.ifftn(phased) * np.prod(shape)
    xp.testing.assert_allclose(x_direct, x_ifft, atol=1e-7)

    if verbose:
        print("ifft max %% difference = %g" %
              max_percent_diff(x_ifft, x_direct))
示例#5
0
def test_nufft_3d(xp, mode, precision, phasing, order, verbose=False):
    """Compare 3D ``NufftBase`` forward/adjoint transforms to the exact DTFT.

    A complex random 8 x 8 x 8 volume is pushed through ``A.fft`` and
    ``A.adj`` and checked against ``dtft`` / ``dtft_adj`` (rtol=1e-2,
    atol=1e-4).  Both operations are also run on a stack of four identical
    repetitions, of which the first is compared to the same reference.

    Parameters
    ----------
    xp : module
        Array module (``numpy`` or a compatible one).  ``on_gpu`` is enabled
        whenever ``xp`` is not ``numpy``.
    mode, precision, phasing
        Forwarded unchanged to ``NufftBase``.
    order : str
        Forwarded to ``NufftBase``.  ``"C"`` stacks repetitions along the
        first axis; any other value stacks them along the last axis.
    verbose : bool, optional
        If True, print the max percent differences of the single-repetition
        forward and adjoint results.
    """
    ndim = 3
    Nd = [8] * ndim
    Kd = [16] * ndim
    Jd = [6] * ndim  # kernel size along each axis
    Ld = 1024
    n_shift = np.asarray(Nd) / 2
    omega = _perturbed_gridpoints(Nd, xp=xp)

    rtol = 1e-2
    atol = 1e-4
    rstate = xp.random.RandomState(1234)
    A = NufftBase(
        omega=omega,
        Nd=Nd,
        Jd=Jd,
        Kd=Kd,
        n_shift=n_shift,
        mode=mode,
        Ld=Ld,
        precision=precision,
        phasing=phasing,
        order=order,
        on_gpu=xp != np,
    )
    x = rstate.standard_normal(Nd)
    x = x + 1j * rstate.standard_normal(Nd)

    # forward
    y = A.fft(x)
    y_true = dtft(x, omega=omega, shape=Nd, n_shift=n_shift)
    xp.testing.assert_allclose(y, y_true, rtol=rtol, atol=atol)

    # TODO: fix case with multiple additional axes at start or end, i.e.
    # reshaping the repetitions to (2, 2) + x.shape for order "C" or to
    # x.shape + (2, 2) otherwise.  Until then a single repetition axis of
    # length 4 is used.
    if order == "C":
        x_reps = xp.stack((x, ) * 4, axis=0).reshape((4, ) + x.shape)
        sl1 = (0, Ellipsis)
    else:
        x_reps = xp.stack((x, ) * 4, axis=-1).reshape(x.shape + (4, ))
        sl1 = (Ellipsis, 0)
    y_fwd_reps = A.fft(x_reps)
    xp.testing.assert_allclose(y_fwd_reps[sl1], y_true, rtol=rtol, atol=atol)

    # adjoint
    x_adj = A.adj(y)
    x_adj_true = dtft_adj(y, omega=omega, shape=Nd, n_shift=n_shift)
    xp.testing.assert_allclose(x_adj, x_adj_true, rtol=rtol, atol=atol)

    # TODO: fix case with multiple additional axes at start or end, i.e.
    # reshaping the repetitions to (2, 2) + y.shape for order "C" or to
    # y.shape + (2, 2) otherwise.
    if order == "C":
        y_reps = xp.stack((y, ) * 4, axis=0).reshape((4, ) + y.shape)
    else:
        y_reps = xp.stack((y, ) * 4, axis=-1).reshape(y.shape + (4, ))
    # The stacked result is kept in its own variable so that ``x_adj`` still
    # holds the single-repetition result for the verbose report below.
    x_adj_reps = A.adj(y_reps)
    xp.testing.assert_allclose(
        x_adj_reps[sl1], x_adj_true, rtol=rtol, atol=atol
    )

    if verbose:
        print(mode, precision, phasing)
        print(f"\t{max_percent_diff(y, y_true)} "
              f"{max_percent_diff(x_adj, x_adj_true)}")
def _test_mri_multi(
    ndim=3,
    N0=8,
    grid_os_factor=1.5,
    J0=4,
    Ld=4096,
    n_coils=1,
    fieldmap_segments=None,
    precisions=["single", "double"],
    phasings=["real", "complex"],
    recon_cases=["CPU,Tab0", "CPU,Tab", "CPU,Sp"],
    rtol=1e-3,
    compare_to_exact=False,
    show_figures=False,
    nufft_kwargs={},
    navg_time=1,
    n_creation=1,
    return_errors=False,
    gpu_memflags=None,
    verbose=False,
    return_operator=False,
    spectral_offsets=None,
):
    """Run a batch of NUFFT tests.

    For every combination of ``recon_cases`` x ``precisions`` x ``phasings``
    an operator is created with ``generate_sim_data``; its forward, norm and
    adjoint operations are timed; the dtypes of the operator's interpolator
    and of the simulated data are checked; and, optionally, the results are
    compared against the exact ``dtft`` / ``dtft_adj``.

    Parameters
    ----------
    ndim, N0, J0, grid_os_factor, fieldmap_segments, Ld, n_coils, nufft_kwargs
        Forwarded unchanged to ``generate_sim_data``.
    precisions : sequence of str
        ``"single"`` expects float32/complex64; any other value expects
        float64/complex128.
    phasings : sequence of str
        ``"complex"`` expects a complex interpolator dtype, anything else a
        real one.
    recon_cases : sequence of str
        Case labels forwarded to ``generate_sim_data``.  A label containing
        ``"CPU"`` uses the CPU averaging count, others the GPU one.  A label
        containing ``"Tab"`` has ``Gn.Gnufft.h[0].dtype`` checked, others
        ``Gn.Gnufft.p.dtype``.
    rtol : float
        Upper bound on the relative errors when ``compare_to_exact`` is set.
    compare_to_exact : bool
        Compare forward and adjoint results to ``dtft`` / ``dtft_adj``.  Not
        supported together with ``spectral_offsets``.
    show_figures : bool
        Display the adjoint result with ``pyvolplot.volshow``.
    navg_time : int or pair of int
        Number of timed repetitions; a non-scalar is unpacked as
        ``(cpu, gpu)``.
    n_creation : int
        Number of times the operator is created for ``"Tab"`` cases (all
        other cases create it once).
    return_errors, return_operator : bool
        Select what is returned (see Returns).
    gpu_memflags
        Passed to ``generate_sim_data`` inside ``MRI_object_kwargs``.
    verbose : bool
        Print the case being run and the adjoint errors.
    spectral_offsets : sequence or None
        Forwarded to ``generate_sim_data``; when given, the adjoint result
        is reshaped with an extra trailing axis of that length.

    Returns
    -------
    alltimes : dict
        Maps ``"<recon_case>, <precision>, <phasing>"`` to that case's timing
        dict.  Preceded by the last created operator if ``return_operator``
        and followed by ``all_err_forward, all_err_adj`` if
        ``return_errors``.

    Raises
    ------
    NotImplementedError
        If ``compare_to_exact`` is combined with ``spectral_offsets``.
    AssertionError
        If a dtype check fails or a relative error is not below ``rtol``.

    Notes
    -----
    The list/dict defaults are only read in this function.
    NOTE(review): ``nufft_kwargs`` is handed to ``generate_sim_data``; confirm
    that it is not mutated there.
    """
    all_err_forward = np.zeros(
        (len(recon_cases), len(precisions), len(phasings))
    )
    all_err_adj = np.zeros((len(recon_cases), len(precisions), len(phasings)))
    alltimes = {}
    if not np.isscalar(navg_time):
        navg_time_cpu, navg_time_gpu = navg_time
    else:
        navg_time_cpu = navg_time_gpu = navg_time
    for i, recon_case in enumerate(recon_cases):
        # use a local instead of overwriting the ``navg_time`` parameter
        if "CPU" in recon_case:
            n_avg = navg_time_cpu
        else:
            n_avg = navg_time_gpu

        for j, precision in enumerate(precisions):
            for k, phasing in enumerate(phasings):
                if verbose:
                    print(
                        "phasing={}, precision={}, type={}".format(
                            phasing, precision, recon_case
                        )
                    )

                if "Tab" in recon_case:
                    # may want to create twice when benchmarking GPU case
                    # because the custom kernels are compiled the first time
                    ncr_max = n_creation
                else:
                    ncr_max = 1
                for _ in range(ncr_max):
                    (
                        Gn,
                        wi_full,
                        xTrue,
                        ig,
                        data_true,
                        times,
                    ) = generate_sim_data(
                        recon_case=recon_case,
                        ndim=ndim,
                        N0=N0,
                        J0=J0,
                        grid_os_factor=grid_os_factor,
                        fieldmap_segments=fieldmap_segments,
                        Ld=Ld,
                        n_coils=n_coils,
                        precision=precision,
                        phasing=phasing,
                        nufft_kwargs=nufft_kwargs,
                        MRI_object_kwargs=dict(gpu_memflags=gpu_memflags),
                        spectral_offsets=spectral_offsets,
                    )

                xp = Gn.xp

                # time the forward operator
                sim_data = Gn * xTrue  # dry run
                tstart = time.time()
                for _ in range(n_avg):
                    sim_data = Gn * xTrue
                    sim_data += 0.0
                # stop the clock before any post-processing of the result
                t_for = (time.time() - tstart) / n_avg
                times["MRI: forward"] = t_for
                sim_data = xp.squeeze(sim_data)  # TODO: should be 1D already?

                # time the norm operator
                Gn.norm(xTrue)  # dry run
                tstart = time.time()
                for _ in range(n_avg):
                    Gn.norm(xTrue)
                t_norm = (time.time() - tstart) / n_avg
                times["MRI: norm"] = t_norm

                if precision == "single":
                    dtype_real = np.float32
                    dtype_cplx = np.complex64
                else:
                    dtype_real = np.float64
                    dtype_cplx = np.complex128

                # the interpolator must match the requested precision/phasing
                if "Tab" in recon_case:
                    if phasing == "complex":
                        assert_equal(Gn.Gnufft.h[0].dtype, dtype_cplx)
                    else:
                        assert_equal(Gn.Gnufft.h[0].dtype, dtype_real)
                else:
                    if phasing == "complex":
                        assert_equal(Gn.Gnufft.p.dtype, dtype_cplx)
                    else:
                        assert_equal(Gn.Gnufft.p.dtype, dtype_real)
                assert_equal(sim_data.dtype, dtype_cplx)

                if compare_to_exact:
                    # compare_to_exact only currently for single-coil,
                    # no fieldmap case
                    if spectral_offsets is not None:
                        raise NotImplementedError(
                            "compare_to_exact doesn't currently support "
                            "spectral offsets"
                        )
                    nshift_exact = tuple(s / 2 for s in Gn.Nd)
                    sim_data2 = dtft(
                        xTrue, Gn.omega, shape=Gn.Nd, n_shift=nshift_exact
                    )

                    sd2_norm = xp.linalg.norm(sim_data2)
                    rel_err = xp.linalg.norm(sim_data - sim_data2) / sd2_norm
                    # device scalars expose ``get``; move them to the host
                    # before storing into the NumPy error array
                    if hasattr(rel_err, "get"):
                        rel_err = rel_err.get()
                    all_err_forward[i, j, k] = rel_err
                    print(
                        "{},{},{}: forward error = {}".format(
                            recon_case, precision, phasing, rel_err
                        )
                    )
                    rel_err_mag = (
                        xp.linalg.norm(xp.abs(sim_data) - xp.abs(sim_data2))
                        / sd2_norm
                    )
                    print(
                        f"{recon_case},{precision},{phasing}: "
                        f"forward mag diff error = {rel_err_mag}"
                    )
                    assert rel_err < rtol

                # TODO: update DiagonalOperatorMulti to auto-set loc_in,
                #       loc_out appropriately
                if xp is np:
                    diag_args = dict(loc_in="cpu", loc_out="cpu")
                else:
                    diag_args = dict(loc_in="gpu", loc_out="gpu")
                diag_op = DiagonalOperatorMulti(wi_full, **diag_args)
                # weight the reference data for a single coil, otherwise the
                # data simulated above
                if n_coils == 1:
                    data_dcf = diag_op * data_true
                else:
                    data_dcf = diag_op * sim_data

                # time the adjoint operation
                im_est = Gn.H * data_dcf  # dry run
                tstart = time.time()
                for _ in range(n_avg):
                    im_est = Gn.H * data_dcf
                t_adj = (time.time() - tstart) / n_avg
                times["MRI: adjoint"] = t_adj

                if hasattr(Gn, "mask") and Gn.mask is not None:
                    im_est = embed(im_est, Gn.mask)
                else:
                    if spectral_offsets is None:
                        im_est = im_est.reshape(Gn.Nd, order=Gn.order)
                    else:
                        im_est = im_est.reshape(
                            tuple(Gn.Nd) + (len(spectral_offsets),),
                            order=Gn.order,
                        )

                if compare_to_exact:
                    im_est_exact = dtft_adj(
                        data_dcf, Gn.omega, shape=Gn.Nd, n_shift=nshift_exact
                    )
                    ex_norm = xp.linalg.norm(im_est_exact)
                    rel_err = xp.linalg.norm(im_est - im_est_exact) / ex_norm
                    # same host transfer as for the forward error
                    if hasattr(rel_err, "get"):
                        rel_err = rel_err.get()
                    all_err_adj[i, j, k] = rel_err
                    if verbose:
                        print(
                            "{},{},{}: adjoint error = {}".format(
                                recon_case, precision, phasing, rel_err
                            )
                        )
                    rel_err_mag = (
                        xp.linalg.norm(xp.abs(im_est) - xp.abs(im_est_exact))
                        / ex_norm
                    )
                    if verbose:
                        print(
                            "{},{},{}: adjoint mag diff error = {}".format(
                                recon_case, precision, phasing, rel_err_mag
                            )
                        )
                    assert_(rel_err < rtol)

                title = ", ".join([recon_case, precision, phasing])
                if show_figures:
                    from matplotlib import pyplot as plt
                    from pyvolplot import volshow

                    if compare_to_exact:
                        volshow(
                            [
                                im_est_exact,
                                im_est,
                                im_est_exact - im_est,
                                xp.abs(im_est_exact) - xp.abs(im_est),
                            ]
                        )
                    else:
                        volshow(im_est)
                        plt.title(title)
                alltimes[title] = times

    if return_operator:
        if return_errors:
            return Gn, alltimes, all_err_forward, all_err_adj
        return Gn, alltimes

    if return_errors:
        return alltimes, all_err_forward, all_err_adj
    return alltimes