Пример #1
0
    def test_sense_model(self):
        """Cartesian Sense forward model equals the FFT of the coil-weighted image.

        Also checks that the Sense linop's adjoint is consistent.
        """
        img_shape = [16, 16]
        mps_shape = [8, 16, 16]

        # `np.complex` was only an alias of the builtin `complex` and was
        # removed in NumPy 1.24, so the builtin is used directly.
        img = sp.randn(img_shape, dtype=complex)
        mps = sp.randn(mps_shape, dtype=complex)

        A = linop.Sense(mps)

        check_linop_adjoint(A, dtype=complex)

        npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]), A * img)
Пример #2
0
        def test_sense_model_with_comm(self):
            """Distributed Sense adjoint matches the single-process adjoint.

            Each rank holds every ``comm.size``-th coil; the adjoint applied
            to that rank's k-space slice is compared against the full sum
            over coils of ``ifft(ksp) * conj(mps)``.
            """
            img_shape = [16, 16]
            mps_shape = [8, 16, 16]
            comm = sp.Communicator()

            # `np.complex` (an alias of builtin `complex`) was removed in
            # NumPy 1.24.
            img = sp.randn(img_shape, dtype=complex)
            mps = sp.randn(mps_shape, dtype=complex)
            # Reduce in place so every rank works with identical img / mps.
            comm.allreduce(img)
            comm.allreduce(mps)
            ksp = sp.fft(img * mps, axes=[-1, -2])

            A = linop.Sense(mps[comm.rank::comm.size], comm=comm)

            npt.assert_allclose(
                A.H(ksp[comm.rank::comm.size]),
                np.sum(sp.ifft(ksp, axes=[-1, -2]) * mps.conjugate(), 0))
Пример #3
0
    def test_noncart_sense_model_batch(self):
        """Non-Cartesian Sense with coil batching approximates the Cartesian FFT.

        The trajectory is the full 16x16 Cartesian grid (centered at 0), so
        the NUFFT-based forward model should match the FFT model to within
        the loose tolerances below.
        """
        img_shape = [16, 16]
        mps_shape = [8, 16, 16]

        # `np.complex` (an alias of builtin `complex`) was removed in
        # NumPy 1.24.
        img = sp.randn(img_shape, dtype=complex)
        mps = sp.randn(mps_shape, dtype=complex)

        y, x = np.mgrid[:16, :16]
        coord = np.stack([np.ravel(y - 8), np.ravel(x - 8)], axis=1)

        A = linop.Sense(mps, coord=coord, coil_batch_size=1)
        check_linop_adjoint(A, dtype=complex)
        npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]).ravel(),
                            (A * img).ravel(),
                            atol=0.1,
                            rtol=0.1)
Пример #4
0
    def test_st(self):
        """Check 'st' pulse profiles for every filter type.

        The magnitude of each pulse's (unnormalized) FFT is sampled at the
        center and 10 samples either side, and must equal [0, 1, 0] to two
        decimal places.
        """
        N = 128
        tb = 16
        center = N // 2
        # Sample points: left of the passband, its center, right of it.
        probe = [center - 10, center, center + 10]
        for filt in ['ls', 'ms', 'pm', 'min', 'max']:
            pulse = sp.mri.rf.dzrf(N,
                                   tb,
                                   ptype='st',
                                   ftype=filt,
                                   d1=0.01,
                                   d2=0.01)

            m = np.abs(sp.fft(pulse, norm=None))

            npt.assert_almost_equal(m[probe],
                                    np.array([0, 1, 0]),
                                    decimal=2,
                                    err_msg='ftype=' + filt)
Пример #5
0
def dzls(n=64, tb=4, d1=0.01, d2=0.01):
    """Design a least-squares FIR beta filter for an SLR pulse.

    A length-(n + 1) filter with a unit passband and zero stopband is
    designed with ``signal.firls``, shifted by half a sample so it is
    symmetric (as MATLAB's design is), and then truncated to ``n`` taps.

    Args:
        n (int): number of taps in the returned filter.
        tb (float): time-bandwidth product.
        d1 (float): passband ripple.
        d2 (float): stopband ripple.

    Returns:
        array: real-valued filter of length ``n``.
    """
    # Fractional transition width, from the ripple-dependent dinf factor.
    trans_width = dinf(d1, d2) / tb
    half_n = n / 2
    band_edges = np.asarray([
        0, (1 - trans_width) * (tb / 2), (1 + trans_width) * (tb / 2), half_n
    ]) / half_n
    amplitudes = [1, 1, 0, 0]
    band_weights = [1, d1 / d2]

    taps = signal.firls(n + 1, band_edges, amplitudes, band_weights)

    # Linear phase ramp in the DFT domain = half-sample shift in time.
    phase_step = 1j * 2 * np.pi / (2 * (n + 1))
    sample_idx = np.concatenate([np.arange(0, n / 2 + 1, 1),
                                 np.arange(-n / 2, 0, 1)])
    ramp = np.exp(phase_step * sample_idx)
    spectrum = sp.fft(taps, center=False)
    shifted = sp.ifft(spectrum * ramp, center=False)

    # Keep the real part and drop the extra (n + 1)-th sample.
    return np.real(shifted)[:n]
Пример #6
0
def dz_recursive_rf(n_seg,
                    tb,
                    n,
                    se_seq=False,
                    tb_ref=8,
                    z_pad_fact=4,
                    win_fact=1.75,
                    cancel_alpha_phs=True,
                    t1=np.inf,
                    tr_seg=60,
                    use_mz=True,
                    d1=0.01,
                    d2=0.01,
                    d1se=0.01,
                    d2se=0.01):
    r"""Recursive SLR pulse design.

    Designs ``n_seg`` pulses in sequence. The first pulse's beta polynomial
    comes from :func:`dzls`. Each later pulse is derived so that its
    transverse magnetization profile reproduces the first pulse's, given
    the longitudinal magnetization (Mz) left behind by the earlier pulses.
    Nominal flip angles are chosen backwards from a 90 degree final segment.

    Args:
        n_seg (int): number of segments designed by recursion.
        tb (int): time bandwidth product.
        n (int): pulse length.
        se_seq (bool): spin echo sequence. If True, a refocusing pulse is
            also designed and accounted for in the recursion.
        tb_ref (int): time bandwidth product of refocusing pulse.
        z_pad_fact (float): zero padding factor. ``z_pad_fact * n`` should
            be an integer.
        win_fact (float): applied window factor. Windowing, and truncation
            of the output to ``int(win_fact * n)`` samples, is applied only
            when ``win_fact < z_pad_fact``.
        cancel_alpha_phs (bool): absorb the alpha phase
            profile from beta's profile, so they cancel for a flatter
            total phase
        t1 (float): T1 relaxation time used for Mz recovery between
            segments when ``se_seq`` is False (same units as ``tr_seg``).
            ``np.inf`` means no recovery.
        tr_seg (int): repetition time between segments.
        use_mz (bool): design the pulses accounting for the actual Mz
            profile. If False, assume ideal Mz and scale beta by the
            flip-angle ratio (conventional variable flip angle).
        d1 (float): passband ripple level in :math:`M_0^{-1}`.
        d2 (float): stopband ripple level in :math:`M_0^{-1}`.
        d1se (float): passband ripple level for se
        d2se (float): stopband ripple level for se

    Returns:
        If se_seq=False, **rf** (*array*): complex array of shape
        [L, n_seg], one column per segment. L is ``int(win_fact * n)``
        when windowing is applied and ``int(z_pad_fact * n)`` otherwise.

        If se_seq=True, 2-element tuple containing

        - **rf** (*array*): rf pulses out, as above.
        - **rf_ref** (*array*): rf refocusing pulse out.
    """
    # Length of every zero-padded beta / RF column. Computed once as an int
    # so a float z_pad_fact does not break the np.zeros / np.ones shapes.
    n_zp = int(z_pad_fact * n)

    # get refocusing pulse and its rotation parameters
    if se_seq is True:
        [bsf, d1se, d2se] = calc_ripples('se', d1se, d2se)
        b_ref = bsf * dzls(n, tb_ref, d1se, d2se)
        b_ref = np.concatenate(
            (np.zeros(int(z_pad_fact * n / 2 - n / 2)), b_ref,
             np.zeros(int(z_pad_fact * n / 2 - n / 2))))
        rf_ref = b2rf(b_ref)
        bref = sp.fft(b_ref, norm=None)
        bref /= np.max(np.abs(bref))
        bref_mag = np.abs(bref)
        aref_mag = np.abs(np.sqrt(1 - bref_mag**2))
        # flip angle (deg) of the refocusing pulse at the profile center
        flip_ref = 2 * np.arcsin(bref_mag[int(z_pad_fact * n / 2)]) \
            * 180 / np.pi

    # get flip angles, working backwards from 90 degrees in the last segment
    flip = np.zeros(n_seg)
    flip[n_seg - 1] = 90
    for jj in range(n_seg - 2, -1, -1):
        if se_seq is False:
            flip[jj] = np.arctan(np.sin(flip[jj + 1] * np.pi / 180))
            flip[jj] = flip[jj] * 180 / np.pi  # deg
        else:
            flip[jj] = np.arctan(
                np.cos(flip_ref * np.pi / 180) *
                np.sin(flip[jj + 1] * np.pi / 180))
            flip[jj] = flip[jj] * 180 / np.pi  # deg

    # design first RF pulse: dzls beta centered in a zero-padded column
    b = np.zeros((n_zp, n_seg), dtype=complex)
    b[int(z_pad_fact * n / 2 - n / 2):int(z_pad_fact * n / 2 + n / 2), 0] = \
        dzls(n, tb, d1, d2)
    B = sp.fft(b[:, 0], norm=None)

    # linear phase ramp (half-sample shift), then normalize max |B| to 1 and
    # scale by sin(flip / 2) of the first segment
    c = np.exp(-1j * 2 * np.pi / (n * z_pad_fact) / 2 *
               np.arange(-n * z_pad_fact / 2, n * z_pad_fact / 2, 1))
    B = np.multiply(B, c)
    b[:, 0] = sp.ifft(B / np.max(np.abs(B)), norm=None)
    b[:, 0] *= np.sin(flip[0] * (np.pi / 180) / 2)
    rf = np.zeros((n_zp, n_seg), dtype=complex)
    a = b2a(b[:, 0])
    if cancel_alpha_phs:
        # cancel a phase by absorbing into b
        # Note that this is the only time we need to do it
        b_a_phase = sp.fft(b[:, 0], center=False, norm=None) * \
            np.exp(-1j * np.angle(sp.fft(a[np.size(a)::-1],
                                         center=False, norm=None)))
        b[:, 0] = sp.ifft(b_a_phase, center=False, norm=None)
    rf[:, 0] = b2rf(b[:, 0])

    # Get min-phase alpha and its response. Use the same unnormalized FFT
    # (norm=None) as for B and every later recomputation of A, so A and B
    # share a scale when Mxy is formed below.
    A = sp.fft(a, norm=None)

    # calculate beta filter response
    B = sp.fft(b[:, 0], norm=None)

    if win_fact < z_pad_fact:
        win_len = (win_fact - 1) * n
        npad = n * z_pad_fact - win_fact * n
        # Blackman taper, split in half with n ones in the middle.
        # scipy.signal.blackman was removed in SciPy 1.13; the window
        # functions live in scipy.signal.windows.
        window = signal.windows.blackman(int((win_fact - 1) * n))
        window = np.concatenate((window[0:int(win_len / 2)], np.ones(n),
                                 window[int(win_len / 2):]))
        window = np.concatenate(
            (np.zeros(int(npad / 2)), window, np.zeros(int(npad / 2))))
        # apply windowing to first pulse for consistency
        b[:, 0] = np.multiply(b[:, 0], window)
        rf[:, 0] = b2rf(b[:, 0])
        # recalculate B and A
        B = sp.fft(b[:, 0], norm=None)
        A = sp.fft(b2a(b[:, 0]), norm=None)

    # use A and B to get the target Mxy profile of the first pulse
    if se_seq is False:
        mxy0 = 2 * np.conj(A) * B
    else:
        mxy0 = 2 * A * np.conj(B) * bref**2

    # Amplitude of next pulse's Mxy profile will be
    #               |Mz*2*a*b| = |Mz*2*sqrt(1-abs(B).^2)*B|.
    # If we set this = |Mxy_1|, we can solve for |B| via solving quadratic
    # equation 4*Mz^2*(1-B^2)*B^2 = |Mxy_1|^2.
    # Subsequently solve for |A|, and get phase of A via min-phase, and
    # then get phase of B by dividing phase of A from first pulse's Mxy phase.
    mz = np.ones(n_zp, dtype=complex)
    for jj in range(1, n_seg):

        # calculate Mz profile after previous pulse
        if se_seq is False:
            mz = mz * (1 - 2 * np.abs(B) ** 2) * np.exp(-tr_seg / t1) + \
                 (1 - np.exp(-tr_seg / t1))
        else:
            mz = mz * (1 - 2 *
                       (np.abs(A * bref_mag)**2 + np.abs(aref_mag * B)**2))
            # (second term is about 1%)

        if use_mz is True:  # design the pulses accounting for the
            # actual Mz profile (the full method)
            # set up quadratic equation to get |B|
            cq = -np.abs(mxy0)**2
            if se_seq is False:
                bq = 4 * mz**2
                aq = -4 * mz**2
            else:
                bq = 4 * (bref_mag**4) * mz**2
                aq = -4 * (bref_mag**4) * mz**2
            bmag = np.sqrt(
                (-bq + np.real(np.sqrt(bq**2 - 4 * aq * cq))) / (2 * aq))
            bmag[np.isnan(bmag)] = 0
            # get A - easier to get complex A than complex B since |A| is
            # determined by |B|, and phase is gotten by min-phase relationship
            # Phase of B doesn't matter here since only profile mag is used by
            # b2a
            A = sp.fft(b2a(sp.ifft(bmag, norm=None)), norm=None)
            # trick: now we can get complex B from ratio of Mxy and A
            B = mxy0 / (2 * np.conj(A) * mz)

        else:  # design assuming ideal Mz (conventional VFA)

            B *= np.sin(np.pi / 180 * flip[jj] / 2) \
                 / np.sin(np.pi / 180 * flip[jj - 1] / 2)
            A = sp.fft(b2a(sp.ifft(B, norm=None)), norm=None)

        # get polynomial
        b[:, jj] = sp.ifft(B, norm=None)

        if win_fact < z_pad_fact:
            b[:, jj] *= window
            # recalculate B and A
            B = sp.fft(b[:, jj], norm=None)
            A = sp.fft(b2a(b[:, jj]), norm=None)

        rf[:, jj] = b2rf(b[:, jj])

    # truncate the RF to the windowed support
    if win_fact < z_pad_fact:
        pulse_len = int(win_fact * n)
        rf = rf[int(npad / 2):int(npad / 2 + pulse_len), :]

    if se_seq is False:
        return rf
    else:
        return rf, rf_ref
Пример #7
0
def dz_hadamard_b(n=128, g=5, gind=1, tb=4, d1=0.01, d2=0.01, shift=32):
    r"""Design a pulse with hadamard encoding

    For ``gind == 1`` this is a plain :func:`dzls` slab, since the first row
    of the Hadamard matrix is all ones and there are no sub-slices.

    Otherwise, the slab centered at ``shift`` is split into ``g`` sub-bands.
    Their signs follow row ``gind`` of the ``g x g`` Hadamard matrix. The
    positive and negative bands are designed as separate least-squares
    filters, then combined, Hilbert-transformed and demodulated to DC.

    Args:
        n (int): number of time points.
        g (int): order of the Hadamard matrix. It must be a power of 2 when
            ``gind > 1`` (a requirement of ``scipy.linalg.hadamard``).
        gind (int): 1-based index of the Hadamard row used for encoding
            (``1 <= gind <= g``).
        tb (int): time bandwidth product.
        d1 (float): passband ripple level in :math:`M_0^{-1}`.
        d2 (float): stopband ripple level in :math:`M_0^{-1}`.
        shift (int): n time points shift of pulse.

    Returns:
        b (array): SLR beta parameter, of length ``n``.

    Raises:
        ValueError: if ``gind`` is outside ``[1, g]``, or (for ``gind > 1``)
            if ``g`` is not a power of 2.

    References:
            Souza, S.P., Szumowski, J., Dumoulin, C.L., Plewes, D.P. &
            Glover, G. 'Sima: Simultaneous multislice acquisition of MR images
            by hadamard - encoded excitation. J.Comput.Assist.Tomogr. 12,
            1026–1030(1988).
    """
    # Without this check, gind=0 would silently pick the LAST Hadamard row
    # through negative indexing.
    if not 1 <= gind <= g:
        raise ValueError(
            'gind must satisfy 1 <= gind <= g; got gind={}, g={}'.format(
                gind, g))

    ftw = dinf(d1, d2) / tb  # fractional transition width of the slab profile

    if gind == 1:  # no sub-slices
        b = dzls(n, tb, d1, d2)
    else:
        # The Hadamard matrix is only needed for encoded sub-slices. Building
        # it here lets the default g=5 work for gind == 1.
        H = linalg.hadamard(g)
        encode = H[gind - 1, :]

        # left stopband
        f = np.asarray([0, shift - (1 + ftw) * (tb / 2)])
        m = np.asarray([0, 0])
        w = np.asarray([d1 / d2])
        # first sub-band
        ii = 1
        gcent = shift + (ii - g / 2 - 1 / 2) * tb / g  # first band center
        # first band left edge
        f = np.append(f, gcent - (tb / g / 2 - ftw * (tb / 2)))
        m = np.append(m, encode[ii - 1])
        if encode[ii - 1] != encode[ii]:
            # add the first band's right edge and its amplitude, and a weight
            f = np.append(f, gcent + (tb / g / 2 - ftw * (tb / 2)))
            m = np.append(m, encode[ii - 1])
            w = np.append(w, 1)
        # middle sub-bands
        for ii in range(2, g):
            gcent = shift + (ii - g / 2 - 1 / 2) * tb / g  # center of band
            if encode[ii - 1] != encode[ii - 2]:
                # add a left edge and amp for this band
                f = np.append(f, gcent - (tb / g / 2 - ftw * (tb / 2)))
                m = np.append(m, encode[ii - 1])
            if encode[ii - 1] != encode[ii]:
                # add a right edge and its amp, and a weight for this band
                f = np.append(f, gcent + (tb / g / 2 - ftw * (tb / 2)))
                m = np.append(m, encode[ii - 1])
                w = np.append(w, 1)
        # last sub-band
        ii = g
        gcent = shift + (ii - g / 2 - 1 / 2) * tb / g  # center of last band
        if encode[ii - 1] != encode[ii - 2]:
            # add a left edge and amp for the last band
            f = np.append(f, gcent - (tb / g / 2 - ftw * (tb / 2)))
            m = np.append(m, encode[ii - 1])
        # add a right edge and its amp, and a weight for the last band
        f = np.append(f, gcent + (tb / g / 2 - ftw * (tb / 2)))
        m = np.append(m, encode[ii - 1])
        w = np.append(w, 1)
        # right stop-band, then normalize band edges to [0, 1]
        f = np.append(f, (shift + (1 + ftw) * (tb / 2), (n / 2))) / (n / 2)
        m = np.append(m, [0, 0])
        w = np.append(w, d1 / d2)

        # separate the positive and negative bands
        mp = (m > 0).astype(float)
        mn = (m < 0).astype(float)

        # design the positive and negative filters
        c = np.exp(1j * 2 * np.pi / (2 * (n + 1)) * np.concatenate(
            [np.arange(0, n / 2 + 1, 1),
             np.arange(-n / 2, 0, 1)]))
        bp = signal.firls(n + 1, f, mp, w)  # the positive filter
        bn = signal.firls(n + 1, f, mn, w)  # the negative filter

        # combine the filters and apply the half-sample shift
        b = sp.ifft(np.multiply(sp.fft(bp - bn, center=False), c),
                    center=False)

        b = np.real(b[:n])
        # hilbert transform to suppress negative passband
        b = signal.hilbert(b)
        # demodulate to DC
        c_shift = np.exp(-1j * 2 * np.pi / n * shift * np.arange(0, n, 1)) / 2
        c_shift *= np.exp(-1j * np.pi / n * shift)
        b = np.multiply(b, c_shift)

    return b
Пример #8
0
def dz_gslider_b(n=128,
                 g=5,
                 gind=1,
                 tb=4,
                 d1=0.01,
                 d2=0.01,
                 phi=np.pi,
                 shift=32):
    r"""Design a g-slider pulse b

    Two least-squares filters are designed and summed: a "notch" slab
    filter that is zero over the chosen sub-slice, and a sub-slice filter
    multiplied by ``exp(1j * phi)``.

    For the centered sub-slice of an odd ``g``, both filters are designed
    at DC. Otherwise they are designed around ``shift``, Hilbert-transformed
    to suppress the negative band, and demodulated back to DC.

    Args:
        n (int): number of time points.
        g (int): number of sub-slices.
        gind (int): 1-based subslice index (``1 <= gind <= g``).
        tb (int): time bandwidth product.
        d1 (float): passband ripple level in :math:`M_0^{-1}`.
        d2 (float): stopband ripple level in :math:`M_0^{-1}`.
        phi (float): subslice phase.
        shift (int): n time points shift of pulse.

    Returns:
        b (array): SLR beta parameter, of length ``n``.

    Raises:
        ValueError: if ``gind`` is outside ``[1, g]``. Without this check,
            the band layout below would be left undefined.

    References:
        Setsompop, K. et al. 'High-resolution in vivo diffusion imaging of the
        human brain with generalized slice dithered enhanced resolution:
        Simultaneous multislice (gSlider-SMS). Magn. Reson. Med.79, 141–151
        (2018).
    """
    if not 1 <= gind <= g:
        raise ValueError(
            'gind must satisfy 1 <= gind <= g; got gind={}, g={}'.format(
                gind, g))

    ftw = dinf(d1, d2) / tb  # fractional transition width of the slab profile

    if np.fmod(g, 2) and gind == int(np.ceil(g / 2)):  # centered sub-slice
        if g == 1:  # no sub-slices, as a sanity check
            b = dzls(n, tb, d1, d2)
        else:
            # Design 2 filters, to allow arbitrary phases on the subslice.
            # The first is a wider notch filter with '0's where the subslice
            # appears, and the second is the subslice. Multiply the subslice
            # by its phase and add the filters.
            f = np.asarray([
                0, (1 / g - ftw) * (tb / 2), (1 / g + ftw) * (tb / 2),
                (1 - ftw) * (tb / 2), (1 + ftw) * (tb / 2), (n / 2)
            ])
            f = f / (n / 2)
            m_notch = [0, 0, 1, 1, 0, 0]
            m_sub = [1, 1, 0, 0, 0, 0]
            w = [1, 1, d1 / d2]

            b_notch = signal.firls(n + 1, f, m_notch, w)  # the notched filter
            b_sub = signal.firls(n + 1, f, m_sub, w)  # the subslice filter
            # add them with the subslice phase
            b = np.add(b_notch, np.multiply(np.exp(1j * phi), b_sub))
            # shift the filter half a sample to make it symmetric,
            # like in MATLAB
            c = np.exp(1j * 2 * np.pi / (2 * (n + 1)) * np.concatenate(
                [np.arange(0, n / 2 + 1, 1),
                 np.arange(-n / 2, 0, 1)]))
            b = sp.ifft(np.multiply(sp.fft(b, center=False), c), center=False)
            # lop off extra sample
            b = b[:n]

    else:
        # design filters for the slab and the subslice, hilbert xform them
        # to suppress their left bands,
        # then demodulate the result back to DC
        gcent = shift + (gind - g / 2 - 1 / 2) * tb / g
        if gind > 1 and gind < g:
            # separate transition bands for slab+slice
            f = np.asarray([
                0, shift - (1 + ftw) * (tb / 2), shift - (1 - ftw) * (tb / 2),
                gcent - (tb / g / 2 + ftw * (tb / 2)),
                gcent - (tb / g / 2 - ftw * (tb / 2)),
                gcent + (tb / g / 2 - ftw * (tb / 2)),
                gcent + (tb / g / 2 + ftw * (tb / 2)),
                shift + (1 - ftw) * (tb / 2), shift + (1 + ftw) * (tb / 2),
                (n / 2)
            ])
            f = f / (n / 2)
            m_notch = [0, 0, 1, 1, 0, 0, 1, 1, 0, 0]
            m_sub = [0, 0, 0, 0, 1, 1, 0, 0, 0, 0]
            w = [d1 / d2, 1, 1, 1, d1 / d2]
        elif gind == 1:
            # the slab and slice share a left transition band
            f = np.asarray([
                0, shift - (1 + ftw) * (tb / 2), shift - (1 - ftw) * (tb / 2),
                gcent + (tb / g / 2 - ftw * (tb / 2)),
                gcent + (tb / g / 2 + ftw * (tb / 2)),
                shift + (1 - ftw) * (tb / 2), shift + (1 + ftw) * (tb / 2),
                (n / 2)
            ])
            f = f / (n / 2)
            m_notch = [0, 0, 0, 0, 1, 1, 0, 0]
            m_sub = [0, 0, 1, 1, 0, 0, 0, 0]
            w = [d1 / d2, 1, 1, d1 / d2]
        else:  # gind == g (guaranteed by the range check above)
            # the slab and slice share a right transition band
            f = np.asarray([
                0, shift - (1 + ftw) * (tb / 2), shift - (1 - ftw) * (tb / 2),
                gcent - (tb / g / 2 + ftw * (tb / 2)),
                gcent - (tb / g / 2 - ftw * (tb / 2)),
                shift + (1 - ftw) * (tb / 2), shift + (1 + ftw) * (tb / 2),
                (n / 2)
            ])
            f = f / (n / 2)
            m_notch = [0, 0, 1, 1, 0, 0, 0, 0]
            m_sub = [0, 0, 0, 0, 1, 1, 0, 0]
            w = [d1 / d2, 1, 1, d1 / d2]

        # half-sample shift ramp (same construction as the centered case)
        c = np.exp(1j * 2 * np.pi / (2 * (n + 1)) * np.concatenate(
            [np.arange(0, n / 2 + 1, 1),
             np.arange(-n / 2, 0, 1)]))

        b_notch = signal.firls(n + 1, f, m_notch, w)  # the notched filter
        b_notch = sp.ifft(np.multiply(sp.fft(b_notch, center=False), c),
                          center=False)
        b_notch = np.real(b_notch[:n])
        # hilbert transform to suppress negative passband
        b_notch = signal.hilbert(b_notch)

        b_sub = signal.firls(n + 1, f, m_sub, w)  # the sub-band filter
        b_sub = sp.ifft(np.multiply(sp.fft(b_sub, center=False), c),
                        center=False)

        b_sub = np.real(b_sub[:n])
        # hilbert transform to suppress negative passband
        b_sub = signal.hilbert(b_sub)

        # add them with the subslice phase
        b = b_notch + np.exp(1j * phi) * b_sub

        # demodulate to DC
        c_shift = np.exp(-1j * 2 * np.pi / n * shift * np.arange(0, n, 1)) / 2
        c_shift *= np.exp(-1j * np.pi / n * shift)

        b = np.multiply(b, c_shift)

    return b
Пример #9
0
#img = find2ndGrad2D(img)

# Scratch/debug script. Each section ends with quit(), which stops the
# interpreter, so only the first section below runs. The later sections are
# reached only after the preceding quit() is removed.
# img, img2 and dataTrue are defined earlier in the script (outside this view).
doubleShow(img, img2, dataTrue, 1, slice=12)  #, slice=200)
quit()

# --- Section 2 (unreachable while the quit() above is present) ---
img = np.load("./stefan_data/outputImage3D_train3_Index5.npy")

# Clip intensities to the range [0, 1].
img[img > 1.0] = 1.0
img[img < 0.0] = 0.0

img = find2ndGrad(img)
plt.imshow(img[:, 100, :], cmap='gray')
plt.show()
quit()

# --- Section 3 (unreachable): inspect the k-space of img ---
ksp = sp.fft(img)  #, axes=[0, 1, 2])
#img = np.sum(np.abs(sp.ifft(ksp, axes=[-1, -2, -3]))**2, axis=0)**0.5

#img = np.fft.ifft(img)
#img = np.abs(img)
#print (np.sum(img))
#print (img.shape)

#img[img>1.0] = 1.0
#img[img<0.0] = 0.0
# Log-magnitude of one k-space slice. NOTE(review): any exactly-zero sample
# gives -inf under np.log; consider np.log1p or a small epsilon.
plt.imshow(np.log(np.abs(ksp)[:, :, 160]), cmap='gray')
#plt.imshow(np.log(np.abs(ksp)[:, 128, :]), cmap='gray')
plt.show()
quit()

# --- Section 4 (unreachable) ---
img_lr_bf = np.load("./stefan_data/img.npy")
Пример #10
0
 def time_fft(self):
     """Benchmark ``sp.fft`` on ``self.x`` with default (centered) settings."""
     sp.fft(self.x)
Пример #11
0
 def time_fft_non_centered(self):
     """Benchmark ``sp.fft`` on ``self.x`` with ``center=False``."""
     sp.fft(self.x, center=False)
Пример #12
0
def circulant_precond(mps,
                      weights=None,
                      coord=None,
                      lamda=0,
                      device=sp.cpu_device):
    r"""Compute circulant preconditioner.

    Considers the optimization problem:

    .. math::
        \min_P \| A^H A - F P F^H  \|_2^2

    where A is the Sense operator,
    and F is a unitary Fourier transform operator.

    Args:
        mps (array): sensitivity maps of shape [num_coils] + image shape.
        weights (array): k-space weights.
        coord (array): k-space coordinates of shape [...] + [ndim].
        lamda (float): regularization.
        device (Device): device on which the preconditioner is computed.
            ``coord``, ``weights`` and each coil map are moved there.

    Returns:
        array: circulant preconditioner of image shape.

    """
    # Move the optional inputs onto the compute device up front.
    coord = None if coord is None else sp.to_device(coord, device)
    weights = None if weights is None else sp.to_device(weights, device)

    dtype = mps.dtype
    device = sp.Device(device)
    xp = device.xp

    img_shape = list(mps.shape)[1:]
    ndim = len(img_shape)
    # Work on a grid twice the image size in every dimension.
    img2_shape = [s * 2 for s in img_shape]

    scale = sp.prod(img2_shape)**1.5 / sp.prod(img_shape)**2
    # Every other sample of the 2x grid corresponds to the original grid.
    even = (slice(None, None, 2), ) * ndim

    with device:
        # Point-spread function of the sampling pattern on the 2x grid.
        if coord is not None:
            density = xp.ones(coord.shape[:-1], dtype=dtype)
            if weights is not None:
                density *= weights**0.5

            psf = sp.nufft_adjoint(density, coord * 2, img2_shape)
        else:
            density = xp.zeros(img2_shape, dtype=dtype)
            density[even] = 1 if weights is None else weights**0.5

            psf = sp.ifft(density)

        # Accumulate each coil's contribution to the diagonal (in Fourier
        # space) of the normal operator.
        p_inv = 0
        for coil in mps:
            coil = sp.to_device(coil, device)
            xcorr = sp.ifft(xp.abs(sp.fft(xp.conj(coil), img2_shape))**2)
            xcorr *= psf
            term = sp.fft(xcorr)[even]
            term *= scale
            if weights is not None:
                term *= weights**0.5

            p_inv += term

        p_inv += lamda
        # Avoid dividing by zero: leave those entries unscaled.
        p_inv[p_inv == 0] = 1

        return (1 / p_inv).astype(dtype)
Пример #13
0
def kspace_precond(mps,
                   weights=None,
                   coord=None,
                   lamda=0,
                   device=sp.cpu_device,
                   oversamp=1.25):
    r"""Compute a diagonal preconditioner in k-space.

    Considers the optimization problem:

    .. math::
        \min_P \| P A A^H - I \|_F^2

    where A is the Sense operator.

    Args:
        mps (array): sensitivity maps of shape [num_coils] + image shape.
        weights (array): k-space weights.
        coord (array): k-space coordinates of shape [...] + [ndim].
        lamda (float): regularization.
        device (Device): device on which the preconditioner is computed.
            ``weights``, ``coord`` and each coil map are moved there.
        oversamp (float): oversampling factor passed to the NUFFT calls.
            Only used when ``coord`` is given.

    Returns:
        array: k-space preconditioner of same shape as k-space, stacked over
        coils (leading axis of length num_coils).

    """
    dtype = mps.dtype

    if weights is not None:
        weights = sp.to_device(weights, device)

    # coord is combined inside nufft_adjoint / nufft with arrays allocated
    # on `device`, so it has to live there too. Without this, a NumPy coord
    # is mixed with device arrays when `device` is a GPU.
    if coord is not None:
        coord = sp.to_device(coord, device)

    device = sp.Device(device)
    xp = device.xp

    mps_shape = list(mps.shape)
    img_shape = mps_shape[1:]
    # 2x grid in every dimension
    img2_shape = [i * 2 for i in img_shape]
    ndim = len(img_shape)

    scale = sp.prod(img2_shape)**1.5 / sp.prod(img_shape)
    with device:
        # point-spread function of the sampling pattern on the 2x grid
        if coord is None:
            idx = (slice(None, None, 2), ) * ndim

            ones = xp.zeros(img2_shape, dtype=dtype)
            if weights is None:
                ones[idx] = 1
            else:
                ones[idx] = weights**0.5

            psf = sp.ifft(ones)
        else:
            coord2 = coord * 2
            ones = xp.ones(coord.shape[:-1], dtype=dtype)
            if weights is not None:
                ones *= weights**0.5

            psf = sp.nufft_adjoint(ones, coord2, img2_shape, oversamp=oversamp)

        p_inv = []
        for mps_i in mps:
            mps_i = sp.to_device(mps_i, device)
            mps_i_norm2 = xp.linalg.norm(mps_i)**2
            # sum over all coils j of |FFT(mps_i * conj(mps_j))|^2
            xcorr_fourier = 0
            for mps_j in mps:
                mps_j = sp.to_device(mps_j, device)
                xcorr_fourier += xp.abs(
                    sp.fft(mps_i * xp.conj(mps_j), img2_shape))**2

            xcorr = sp.ifft(xcorr_fourier)
            xcorr *= psf
            if coord is None:
                p_inv_i = sp.fft(xcorr)[idx]
            else:
                p_inv_i = sp.nufft(xcorr, coord2, oversamp=oversamp)

            if weights is not None:
                p_inv_i *= weights**0.5

            p_inv.append(p_inv_i * scale / mps_i_norm2)

        p_inv = (xp.abs(xp.stack(p_inv, axis=0)) + lamda) / (1 + lamda)
        # avoid division by zero: leave those entries unscaled
        p_inv[p_inv == 0] = 1
        p = 1 / p_inv

        return p.astype(dtype)
pl.ImagePlot(admm_img)

#%% md

## ADMM with circulant preconditioner

#%%

# Notebook cell. mps, coord, device, ksp, lamda, max_iter and max_cg_iter
# are defined in earlier cells (outside this view).
rho = 1
# Circulant preconditioner for the Sense normal operator, with rho passed as
# its constant regularization term.
circ_precond = mr.circulant_precond(mps, coord=coord, device=device, lamda=rho)

img_shape = mps.shape[1:]
# Fourier-domain response of G^H G (finite differences), obtained by
# applying the operator to an impulse and taking the FFT.
# NOTE(review): confirm that sp.fft's normalization here gives G^H G's
# eigenvalues at the scale expected by the line that folds them in below.
G = sp.linop.FiniteDifference(img_shape)
g = G.H * G * sp.dirac(img_shape)
g = sp.fft(g)
g = sp.to_device(g, device=device)
# Fold the finite-difference term into the preconditioner.
# NOTE(review): an ADMM x-update typically inverts A^H A + rho * G^H G;
# confirm whether this should be `rho * g` rather than `lamda * g`.
circ_precond = 1 / (1 / circ_precond + lamda * g)

img_shape = mps.shape[1:]
# Apply the diagonal preconditioner in the Fourier domain: P = F^H D F.
D = sp.linop.Multiply(img_shape, circ_precond)
P = sp.linop.IFFT(img_shape) * D * sp.linop.FFT(img_shape)

# Outer ADMM iterations are max_iter // max_cg_iter, so the total inner CG
# budget is comparable to max_iter.
admm_cp_app = mr.app.TotalVariationRecon(
        ksp, mps, lamda=lamda, coord=coord, max_iter=max_iter // max_cg_iter,
        P=P, rho=rho,
        solver='ADMM', max_cg_iter=max_cg_iter, device=device, save_objective_values=True)
admm_cp_img = admm_cp_app.run()

pl.ImagePlot(admm_cp_img)
Пример #15
0
    mesh[:, :, 1] = m2
    return mesh.astype(np.float)


# Compare the NUFFT of an image on the trajectory from cartisian2D
# (presumably a Cartesian grid, given the name; confirm) against a plain FFT.
# NOTE(review): plt here appears to be sigpy.plot rather than matplotlib,
# since it provides ScatterPlot / ImagePlot.
name = 'img.jpg'
# Load the image as 8-bit grayscale ('L' mode).
image = Image.open(name).convert('L')

# NOTE(review): "+ 1j" adds a constant imaginary offset of 1 to every pixel,
# not just a complex cast. Confirm this is intended; "+ 0j" or
# .astype(complex) would only change the dtype.
arr = np.array(image) + 1j
traj = cartisian2D(arr.shape, [1, 1], 1)
plt.ScatterPlot(traj, title='Trajectory')
# This closes the converted copy returned by convert(), not the object
# returned by Image.open().
image.close()
# Normalize by the largest entry. NumPy orders complex values by real part
# first, then imaginary part.
arr = arr / np.max(arr[...])

print(traj.shape)

kspaceNUFFT = sp.nufft(arr, traj)

plt.ImagePlot(np.log(kspaceNUFFT), title='k-space data from NUFFT')

kspaceFFT = sp.fft(arr)

plt.ImagePlot(np.log(kspaceFFT), title='k-space data from FFT')
print(kspaceFFT.shape)
print(kspaceNUFFT.shape)
sumNUFFT = np.sum(kspaceNUFFT)
sumFFT = np.sum(kspaceFFT)
# Very loose tolerances (rtol/atol of 10 and 50): a rough sanity check only.
if (np.allclose(kspaceNUFFT, kspaceFFT, rtol=10, atol=10)
        and np.isclose(sumNUFFT, sumFFT, rtol=50, atol=50)):
    print('Outputs are  similar!')
else:
    print('Outputs are NOT similar!')