def test_sense_model(self):
    """Cartesian Sense (no coord) must equal a per-coil 2-D FFT of img * mps.

    Also checks that the operator's adjoint is consistent.
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]

    # ``np.complex`` was a deprecated alias of the builtin ``complex`` and was
    # removed in NumPy 1.24; ``np.complex128`` is the identical dtype.
    img = sp.randn(img_shape, dtype=np.complex128)
    mps = sp.randn(mps_shape, dtype=np.complex128)

    A = linop.Sense(mps)

    check_linop_adjoint(A, dtype=np.complex128)

    npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]), A * img)
def test_sense_model_with_comm(self):
    """Sense split across communicator ranks: the adjoint applied to this
    rank's coils must match the full coil-combined adjoint."""
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]
    comm = sp.Communicator()

    # ``np.complex`` was removed in NumPy 1.24; ``np.complex128`` is the
    # same dtype it aliased.
    img = sp.randn(img_shape, dtype=np.complex128)
    mps = sp.randn(mps_shape, dtype=np.complex128)
    # presumably makes img/mps identical on every rank — confirm against
    # sp.Communicator.allreduce semantics
    comm.allreduce(img)
    comm.allreduce(mps)
    ksp = sp.fft(img * mps, axes=[-1, -2])

    # Each rank owns every comm.size-th coil, starting at its own rank.
    A = linop.Sense(mps[comm.rank::comm.size], comm=comm)

    npt.assert_allclose(
        A.H(ksp[comm.rank::comm.size]),
        np.sum(sp.ifft(ksp, axes=[-1, -2]) * mps.conjugate(), 0))
def test_noncart_sense_model_batch(self):
    """Non-Cartesian Sense on an integer grid, processed one coil at a time,
    must approximately match the Cartesian per-coil FFT."""
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]

    # ``np.complex`` was removed in NumPy 1.24; ``np.complex128`` is the
    # same dtype it aliased.
    img = sp.randn(img_shape, dtype=np.complex128)
    mps = sp.randn(mps_shape, dtype=np.complex128)

    # Full 16x16 Cartesian grid written as (y, x) coordinates in [-8, 7].
    y, x = np.mgrid[:16, :16]
    coord = np.stack([np.ravel(y - 8), np.ravel(x - 8)], axis=1)

    A = linop.Sense(mps, coord=coord, coil_batch_size=1)

    check_linop_adjoint(A, dtype=np.complex128)

    npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]).ravel(),
                        (A * img).ravel(), atol=0.1, rtol=0.1)
def test_st(self):
    """Small-tip ('st') dzrf pulses: for every filter design type the
    magnitude profile should be ~0 off-centre and ~1 at the centre,
    to within the d1/d2 ripple levels."""
    n_samp = 128
    tbw = 16
    # profile bins: 10 left of centre, centre, 10 right of centre
    probes = [int(n_samp / 2 - 10), int(n_samp / 2), int(n_samp / 2 + 10)]
    ideal = np.array([0, 1, 0])

    for ftype in ['ls', 'ms', 'pm', 'min', 'max']:
        pulse = sp.mri.rf.dzrf(n_samp, tbw, ptype='st', ftype=ftype,
                               d1=0.01, d2=0.01)
        profile = np.abs(sp.fft(pulse, norm=None))
        npt.assert_almost_equal(profile[probes], ideal, decimal=2)
def dzls(n=64, tb=4, d1=0.01, d2=0.01):
    """Design a least-squares FIR filter with scipy's ``firls``.

    Args:
        n (int): number of filter taps returned.
        tb (int): time-bandwidth product.
        d1 (float): passband ripple level.
        d2 (float): stopband ripple level.

    Returns:
        array: real-valued filter of length ``n``.
    """
    # fractional transition width
    ftw = dinf(d1, d2) / tb
    half = n / 2

    # band edges normalized so that n / 2 maps to 1 (firls' Nyquist)
    bands = np.asarray([0, (1 - ftw) * (tb / 2), (1 + ftw) * (tb / 2),
                        half]) / half
    desired = [1, 1, 0, 0]
    weights = [1, d1 / d2]
    taps = signal.firls(n + 1, bands, desired, weights)

    # Apply a half-sample shift in the frequency domain to make the
    # filter symmetric, like in MATLAB.
    freq_idx = np.concatenate([np.arange(0, n / 2 + 1, 1),
                               np.arange(-n / 2, 0, 1)])
    ramp = np.exp(1j * 2 * np.pi / (2 * (n + 1)) * freq_idx)
    shifted = sp.ifft(sp.fft(taps, center=False) * ramp, center=False)

    # keep the real part and drop the extra (n + 1)-th sample
    return np.real(shifted)[:n]
def dz_recursive_rf(n_seg, tb, n, se_seq=False, tb_ref=8, z_pad_fact=4,
                    win_fact=1.75, cancel_alpha_phs=True, t1=np.inf,
                    tr_seg=60, use_mz=True, d1=0.01, d2=0.01, d1se=0.01,
                    d2se=0.01):
    r"""Recursive SLR pulse design.

    The first pulse is a least-squares SLR design.

    Each later segment's pulse is designed so that its Mxy magnitude equals
    the first pulse's Mxy (``mxy0``). When ``use_mz`` is true, this accounts
    for the Mz left behind by the previous pulse.

    Args:
        n_seg (int): number of segments designed by recursion.
        tb (int): time bandwidth product.
        n (int): pulse length.
        se_seq (bool): spin echo sequence.
        tb_ref (int): time bandwidth product of refocusing pulse.
        z_pad_fact (float): zero padding factor. ``z_pad_fact * n`` is
            truncated to an integer where it is used as an array length.
        win_fact (float): applied window factor.
        cancel_alpha_phs (bool): absorb the alpha phase profile from beta's
            profile, so they cancel for a flatter total phase
        t1 (float): t1
        tr_seg (int): length of tr
        use_mz (bool): design the pulses accounting for the actual Mz profile
        d1 (float): passband ripple level in :math:`M_0^{-1}`.
        d2 (float): stopband ripple level in :math:`M_0^{-1}`.
        d1se (float): passband ripple level for se
        d2se (float): stopband ripple level for se

    Returns:
        If se_seq is false, **rf** (*array*): rf pulses, one column per
        segment. When ``win_fact < z_pad_fact`` the columns are truncated to
        ``int(win_fact * n)`` samples.

        If se_seq=True, 2-element tuple containing

        - **rf** (*array*): rf pulse out.
        - **rf_ref** (*array*): rf refocusing pulse out.
    """
    # Zero-padded length. int() keeps array shapes valid when z_pad_fact
    # is a non-integer float (as documented).
    n_zp = int(z_pad_fact * n)

    # get refocusing pulse and its rotation parameters
    if se_seq:
        [bsf, d1se, d2se] = calc_ripples('se', d1se, d2se)
        b_ref = bsf * dzls(n, tb_ref, d1se, d2se)
        b_ref = np.concatenate(
            (np.zeros(int(z_pad_fact * n / 2 - n / 2)), b_ref,
             np.zeros(int(z_pad_fact * n / 2 - n / 2))))
        rf_ref = b2rf(b_ref)
        bref = sp.fft(b_ref, norm=None)
        bref /= np.max(np.abs(bref))
        bref_mag = np.abs(bref)
        aref_mag = np.abs(np.sqrt(1 - bref_mag**2))
        flip_ref = 2 * np.arcsin(bref_mag[int(z_pad_fact * n / 2)]) \
            * 180 / np.pi

    # get flip angles, working backwards from 90 deg on the last segment
    flip = np.zeros(n_seg)
    flip[n_seg - 1] = 90
    for jj in range(n_seg - 2, -1, -1):
        if not se_seq:
            flip[jj] = np.arctan(np.sin(flip[jj + 1] * np.pi / 180))
            flip[jj] = flip[jj] * 180 / np.pi  # deg
        else:
            flip[jj] = np.arctan(
                np.cos(flip_ref * np.pi / 180) *
                np.sin(flip[jj + 1] * np.pi / 180))
            flip[jj] = flip[jj] * 180 / np.pi  # deg

    # design first RF pulse
    b = np.zeros((n_zp, n_seg), dtype=complex)
    b[int(z_pad_fact * n / 2 - n / 2):int(z_pad_fact * n / 2 + n / 2), 0] = \
        dzls(n, tb, d1, d2)
    B = sp.fft(b[:, 0], norm=None)
    c = np.exp(-1j * 2 * np.pi / (n * z_pad_fact) / 2 *
               np.arange(-n * z_pad_fact / 2, n * z_pad_fact / 2, 1))
    B = np.multiply(B, c)
    b[:, 0] = sp.ifft(B / np.max(np.abs(B)), norm=None)
    b[:, 0] *= np.sin(flip[0] * (np.pi / 180) / 2)
    rf = np.zeros((n_zp, n_seg), dtype=complex)
    a = b2a(b[:, 0])
    if cancel_alpha_phs:
        # cancel a phase by absorbing into b
        # Note that this is the only time we need to do it
        b_a_phase = sp.fft(b[:, 0], center=False, norm=None) * \
            np.exp(-1j * np.angle(sp.fft(a[np.size(a)::-1], center=False,
                                         norm=None)))
        b[:, 0] = sp.ifft(b_a_phase, center=False, norm=None)
    rf[:, 0] = b2rf(b[:, 0])

    # get min-phase alpha and its response. norm=None matches every other
    # A/B transform in this function; sp.fft's default normalization would
    # mis-scale mxy0 when the windowing branch below does not recompute A.
    A = sp.fft(a, norm=None)

    # calculate beta filter response
    B = sp.fft(b[:, 0], norm=None)

    if win_fact < z_pad_fact:
        win_len = (win_fact - 1) * n
        npad = n * z_pad_fact - win_fact * n
        # blackman window (scipy.signal.blackman was removed in SciPy 1.13;
        # the windows submodule provides the same function)
        window = signal.windows.blackman(int((win_fact - 1) * n))
        # split in half; stick N ones in the middle
        window = np.concatenate((window[0:int(win_len / 2)], np.ones(n),
                                 window[int(win_len / 2):]))
        window = np.concatenate(
            (np.zeros(int(npad / 2)), window, np.zeros(int(npad / 2))))
        # apply windowing to first pulse for consistency
        b[:, 0] = np.multiply(b[:, 0], window)
        rf[:, 0] = b2rf(b[:, 0])
        # recalculate B and A
        B = sp.fft(b[:, 0], norm=None)
        A = sp.fft(b2a(b[:, 0]), norm=None)

    # use A and B to get Mxy
    if not se_seq:
        mxy0 = 2 * np.conj(A) * B
    else:
        mxy0 = 2 * A * np.conj(B) * bref**2

    # Amplitude of next pulse's Mxy profile will be
    # |Mz*2*a*b| = |Mz*2*sqrt(1-abs(B).^2)*B|.
    # If we set this = |Mxy_1|, we can solve for |B| via solving quadratic
    # equation 4*Mz^2*(1-B^2)*B^2 = |Mxy_1|^2.
    # Subsequently solve for |A|, and get phase of A via min-phase, and
    # then get phase of B by dividing phase of A from first pulse's Mxy phase.
    mz = np.ones(n_zp, dtype=complex)
    for jj in range(1, n_seg):
        # calculate Mz profile after previous pulse
        if not se_seq:
            mz = mz * (1 - 2 * np.abs(B)**2) * np.exp(-tr_seg / t1) + \
                (1 - np.exp(-tr_seg / t1))
        else:
            mz = mz * (1 - 2 * (np.abs(A * bref_mag)**2 +
                                np.abs(aref_mag * B)**2))
            # (second term is about 1%)

        if use_mz:
            # design the pulses accounting for the actual Mz profile
            # (the full method): set up quadratic equation to get |B|
            cq = -np.abs(mxy0)**2
            if not se_seq:
                bq = 4 * mz**2
                aq = -4 * mz**2
            else:
                bq = 4 * (bref_mag**4) * mz**2
                aq = -4 * (bref_mag**4) * mz**2
            bmag = np.sqrt(
                (-bq + np.real(np.sqrt(bq**2 - 4 * aq * cq))) / (2 * aq))
            bmag[np.isnan(bmag)] = 0
            # get A - easier to get complex A than complex B since |A| is
            # determined by |B|, and phase is gotten by min-phase relationship
            # Phase of B doesn't matter here since only profile mag is used by
            # b2a
            A = sp.fft(b2a(sp.ifft(bmag, norm=None)), norm=None)
            # trick: now we can get complex B from ratio of Mxy and A
            B = mxy0 / (2 * np.conj(A) * mz)
        else:
            # design assuming ideal Mz (conventional VFA)
            B *= np.sin(np.pi / 180 * flip[jj] / 2) \
                / np.sin(np.pi / 180 * flip[jj - 1] / 2)
            A = sp.fft(b2a(sp.ifft(B, norm=None)), norm=None)

        # get polynomial
        b[:, jj] = sp.ifft(B, norm=None)

        if win_fact < z_pad_fact:
            b[:, jj] *= window
            # recalculate B and A
            B = sp.fft(b[:, jj], norm=None)
            A = sp.fft(b2a(b[:, jj]), norm=None)

        rf[:, jj] = b2rf(b[:, jj])

    # truncate the RF
    if win_fact < z_pad_fact:
        pulse_len = int(win_fact * n)
        rf = rf[int(npad / 2):int(npad / 2 + pulse_len), :]

    if se_seq:
        return rf, rf_ref
    return rf
def dz_hadamard_b(n=128, g=4, gind=1, tb=4, d1=0.01, d2=0.01, shift=32):
    r"""Design a pulse with hadamard encoding

    Args:
        n (int): number of time points.
        g (int): order of the Hadamard matrix. Must be a power of 2, as
            required by :func:`scipy.linalg.hadamard`.
        gind (int): 1-based index of the Hadamard row used for encoding,
            ``1 <= gind <= g``. ``gind == 1`` designs the plain slab
            (no sub-slice encoding).
        tb (int): time bandwidth product.
        d1 (float): passband ripple level in :math:`M_0^{-1}`.
        d2 (float): stopband ripple level in :math:`M_0^{-1}`.
        shift (int): n time points shift of pulse.

    Returns:
        b (array): SLR beta parameter.

    Raises:
        ValueError: if ``gind`` is outside ``[1, g]`` or ``g`` is not a
            power of 2.

    References:
        Souza, S.P., Szumowski, J., Dumoulin, C.L., Plewes, D.P. & Glover, G.
        'Sima: Simultaneous multislice acquisition of MR images by hadamard
        - encoded excitation. J.Comput.Assist.Tomogr. 12, 1026–1030(1988).
    """
    # Reject out-of-range rows explicitly: gind == 0 would otherwise index
    # H[-1] and silently encode with the last Hadamard row.
    if not 1 <= gind <= g:
        raise ValueError(
            'gind must satisfy 1 <= gind <= g, got gind={} and g={}'.format(
                gind, g))

    H = linalg.hadamard(g)
    encode = H[gind - 1, :]

    ftw = dinf(d1, d2) / tb  # fractional transition width of the slab profile

    if gind == 1:  # no sub-slices
        b = dzls(n, tb, d1, d2)
    else:
        # Build firls band edges (f), amplitudes (m) and weights (w).
        # Neighbouring sub-bands with the same encoding sign are merged, so
        # an edge is only added where the sign changes.

        # left stopband
        f = np.asarray([0, shift - (1 + ftw) * (tb / 2)])
        m = np.asarray([0, 0])
        w = np.asarray([d1 / d2])

        # first sub-band
        ii = 1
        gcent = shift + (ii - g / 2 - 1 / 2) * tb / g  # first band center
        # first band left edge
        f = np.append(f, gcent - (tb / g / 2 - ftw * (tb / 2)))
        m = np.append(m, encode[ii - 1])
        if encode[ii - 1] != encode[ii]:
            # add the first band's right edge and its amplitude, and a weight
            f = np.append(f, gcent + (tb / g / 2 - ftw * (tb / 2)))
            m = np.append(m, encode[ii - 1])
            w = np.append(w, 1)

        # middle sub-bands
        for ii in range(2, g):
            gcent = shift + (ii - g / 2 - 1 / 2) * tb / g  # center of band
            if encode[ii - 1] != encode[ii - 2]:
                # add a left edge and amp for this band
                f = np.append(f, gcent - (tb / g / 2 - ftw * (tb / 2)))
                m = np.append(m, encode[ii - 1])
            if encode[ii - 1] != encode[ii]:
                # add a right edge and its amp, and a weight for this band
                f = np.append(f, gcent + (tb / g / 2 - ftw * (tb / 2)))
                m = np.append(m, encode[ii - 1])
                w = np.append(w, 1)

        # last sub-band
        ii = g
        gcent = shift + (ii - g / 2 - 1 / 2) * tb / g  # center of last band
        if encode[ii - 1] != encode[ii - 2]:
            # add a left edge and amp for the last band
            f = np.append(f, gcent - (tb / g / 2 - ftw * (tb / 2)))
            m = np.append(m, encode[ii - 1])
        # add a right edge and its amp, and a weight for the last band
        f = np.append(f, gcent + (tb / g / 2 - ftw * (tb / 2)))
        m = np.append(m, encode[ii - 1])
        w = np.append(w, 1)

        # right stop-band; then normalize edges so n / 2 maps to 1
        f = np.append(f, (shift + (1 + ftw) * (tb / 2), (n / 2))) / (n / 2)
        m = np.append(m, [0, 0])
        w = np.append(w, d1 / d2)

        # separate the positive and negative bands
        mp = (m > 0).astype(float)
        mn = (m < 0).astype(float)

        # design the positive and negative filters
        c = np.exp(1j * 2 * np.pi / (2 * (n + 1)) * np.concatenate(
            [np.arange(0, n / 2 + 1, 1), np.arange(-n / 2, 0, 1)]))
        bp = signal.firls(n + 1, f, mp, w)  # the positive filter
        bn = signal.firls(n + 1, f, mn, w)  # the negative filter

        # combine the filters and apply the half-sample shift
        b = sp.ifft(np.multiply(sp.fft(bp - bn, center=False), c),
                    center=False)
        b = np.real(b[:n])
        # hilbert transform to suppress negative passband
        b = signal.hilbert(b)
        # demodulate to DC
        c_shift = np.exp(-1j * 2 * np.pi / n * shift * np.arange(0, n, 1)) / 2
        c_shift *= np.exp(-1j * np.pi / n * shift)
        b = np.multiply(b, c_shift)

    return b
def dz_gslider_b(n=128, g=5, gind=1, tb=4, d1=0.01, d2=0.01, phi=np.pi,
                 shift=32):
    r"""Design a g-slider pulse b

    Args:
        n (int): number of time points.
        g (int): number of sub-slices.
        gind (int): 1-based subslice index, ``1 <= gind <= g``.
        tb (int): time bandwidth product.
        d1 (float): passband ripple level in :math:`M_0^{-1}`.
        d2 (float): stopband ripple level in :math:`M_0^{-1}`.
        phi (float): subslice phase.
        shift (int): n time points shift of pulse.

    Returns:
        b (array): SLR beta parameter.

    Raises:
        ValueError: if ``gind`` is outside ``[1, g]``.

    References:
        Setsompop, K. et al. 'High-resolution in vivo diffusion imaging of the
        human brain with generalized slice dithered enhanced resolution:
        Simultaneous multislice (gSlider-SMS). Magn. Reson. Med.79, 141–151
        (2018).
    """
    # Without this check an out-of-range gind matches none of the band
    # layouts below and fails later with an UnboundLocalError on `f`.
    if not 1 <= gind <= g:
        raise ValueError(
            'gind must satisfy 1 <= gind <= g, got gind={} and g={}'.format(
                gind, g))

    ftw = dinf(d1, d2) / tb  # fractional transition width of the slab profile

    if np.fmod(g, 2) and gind == int(np.ceil(g / 2)):  # centered sub-slice
        if g == 1:  # no sub-slices, as a sanity check
            b = dzls(n, tb, d1, d2)
        else:
            # Design 2 filters to allow arbitrary phases on the subslice.
            # The first is a wider notch filter with '0's where the subslice
            # appears, and the second is the subslice. Multiply the subslice
            # by its phase and add the filters.
            f = np.asarray([
                0, (1 / g - ftw) * (tb / 2), (1 / g + ftw) * (tb / 2),
                (1 - ftw) * (tb / 2), (1 + ftw) * (tb / 2), (n / 2)
            ])
            f = f / (n / 2)
            m_notch = [0, 0, 1, 1, 0, 0]
            m_sub = [1, 1, 0, 0, 0, 0]
            w = [1, 1, d1 / d2]

            b_notch = signal.firls(n + 1, f, m_notch, w)  # the notched filter
            b_sub = signal.firls(n + 1, f, m_sub, w)  # the subslice filter
            # add them with the subslice phase
            b = np.add(b_notch, np.multiply(np.exp(1j * phi), b_sub))
            # shift the filter half a sample to make it symmetric,
            # like in MATLAB
            c = np.exp(1j * 2 * np.pi / (2 * (n + 1)) * np.concatenate(
                [np.arange(0, n / 2 + 1, 1), np.arange(-n / 2, 0, 1)]))
            b = sp.ifft(np.multiply(sp.fft(b, center=False), c),
                        center=False)
            # lop off extra sample
            b = b[:n]
    else:
        # design filters for the slab and the subslice, hilbert xform them
        # to suppress their left bands,
        # then demodulate the result back to DC
        gcent = shift + (gind - g / 2 - 1 / 2) * tb / g
        if gind > 1 and gind < g:
            # separate transition bands for slab+slice
            f = np.asarray([
                0, shift - (1 + ftw) * (tb / 2), shift - (1 - ftw) * (tb / 2),
                gcent - (tb / g / 2 + ftw * (tb / 2)),
                gcent - (tb / g / 2 - ftw * (tb / 2)),
                gcent + (tb / g / 2 - ftw * (tb / 2)),
                gcent + (tb / g / 2 + ftw * (tb / 2)),
                shift + (1 - ftw) * (tb / 2), shift + (1 + ftw) * (tb / 2),
                (n / 2)
            ])
            f = f / (n / 2)
            m_notch = [0, 0, 1, 1, 0, 0, 1, 1, 0, 0]
            m_sub = [0, 0, 0, 0, 1, 1, 0, 0, 0, 0]
            w = [d1 / d2, 1, 1, 1, d1 / d2]
        elif gind == 1:
            # the slab and slice share a left transition band
            f = np.asarray([
                0, shift - (1 + ftw) * (tb / 2), shift - (1 - ftw) * (tb / 2),
                gcent + (tb / g / 2 - ftw * (tb / 2)),
                gcent + (tb / g / 2 + ftw * (tb / 2)),
                shift + (1 - ftw) * (tb / 2), shift + (1 + ftw) * (tb / 2),
                (n / 2)
            ])
            f = f / (n / 2)
            m_notch = [0, 0, 0, 0, 1, 1, 0, 0]
            m_sub = [0, 0, 1, 1, 0, 0, 0, 0]
            w = [d1 / d2, 1, 1, d1 / d2]
        elif gind == g:
            # the slab and slice share a right transition band
            f = np.asarray([
                0, shift - (1 + ftw) * (tb / 2), shift - (1 - ftw) * (tb / 2),
                gcent - (tb / g / 2 + ftw * (tb / 2)),
                gcent - (tb / g / 2 - ftw * (tb / 2)),
                shift + (1 - ftw) * (tb / 2), shift + (1 + ftw) * (tb / 2),
                (n / 2)
            ])
            f = f / (n / 2)
            m_notch = [0, 0, 1, 1, 0, 0, 0, 0]
            m_sub = [0, 0, 0, 0, 1, 1, 0, 0]
            w = [d1 / d2, 1, 1, d1 / d2]

        # half-sample shift applied to both filters
        c = np.exp(1j * 2 * np.pi / (2 * (n + 1)) * np.concatenate(
            [np.arange(0, n / 2 + 1, 1), np.arange(-n / 2, 0, 1)]))

        b_notch = signal.firls(n + 1, f, m_notch, w)  # the notched filter
        b_notch = sp.ifft(np.multiply(sp.fft(b_notch, center=False), c),
                          center=False)
        b_notch = np.real(b_notch[:n])
        # hilbert transform to suppress negative passband
        b_notch = signal.hilbert(b_notch)

        b_sub = signal.firls(n + 1, f, m_sub, w)  # the sub-band filter
        b_sub = sp.ifft(np.multiply(sp.fft(b_sub, center=False), c),
                        center=False)
        b_sub = np.real(b_sub[:n])
        # hilbert transform to suppress negative passband
        b_sub = signal.hilbert(b_sub)

        # add them with the subslice phase
        b = b_notch + np.exp(1j * phi) * b_sub

        # demodulate to DC
        c_shift = np.exp(-1j * 2 * np.pi / n * shift * np.arange(0, n, 1)) / 2
        c_shift *= np.exp(-1j * np.pi / n * shift)
        b = np.multiply(b, c_shift)

    return b
#img = find2ndGrad2D(img) doubleShow(img, img2, dataTrue, 1, slice=12) #, slice=200) quit() img = np.load("./stefan_data/outputImage3D_train3_Index5.npy") img[img > 1.0] = 1.0 img[img < 0.0] = 0.0 img = find2ndGrad(img) plt.imshow(img[:, 100, :], cmap='gray') plt.show() quit() ksp = sp.fft(img) #, axes=[0, 1, 2]) #img = np.sum(np.abs(sp.ifft(ksp, axes=[-1, -2, -3]))**2, axis=0)**0.5 #img = np.fft.ifft(img) #img = np.abs(img) #print (np.sum(img)) #print (img.shape) #img[img>1.0] = 1.0 #img[img<0.0] = 0.0 plt.imshow(np.log(np.abs(ksp)[:, :, 160]), cmap='gray') #plt.imshow(np.log(np.abs(ksp)[:, 128, :]), cmap='gray') plt.show() quit() img_lr_bf = np.load("./stefan_data/img.npy")
def time_fft(self):
    """Benchmark ``sp.fft`` on ``self.x`` with sigpy's default ``center``."""
    # The result is intentionally discarded; only the call is timed.
    sp.fft(self.x)
def time_fft_non_centered(self):
    """Benchmark ``sp.fft`` on ``self.x`` with ``center=False``."""
    # The result is intentionally discarded; only the call is timed.
    sp.fft(self.x, center=False)
def circulant_precond(mps, weights=None, coord=None, lamda=0,
                      device=sp.cpu_device):
    r"""Compute circulant preconditioner.

    Considers the optimization problem:

    .. math::
        \min_P \| A^H A - F P F^H \|_2^2

    where A is the Sense operator,
    and F is a unitary Fourier transform operator.

    Args:
        mps (array): sensitivity maps of shape [num_coils] + image shape.
        weights (array): k-space weights. In the Cartesian case
            (``coord is None``) these must have image shape.
        coord (array): k-space coordinates of shape [...] + [ndim].
        lamda (float): regularization, added to the inverse preconditioner
            before it is inverted.
        device (Device): device used for the computation; ``coord``,
            ``weights`` and each coil map are moved to it.

    Returns:
        array: circulant preconditioner of image shape.

    """
    # Move the optional inputs to the compute device before any arithmetic.
    if coord is not None:
        coord = sp.to_device(coord, device)

    if weights is not None:
        weights = sp.to_device(weights, device)

    dtype = mps.dtype
    device = sp.Device(device)
    xp = device.xp

    mps_shape = list(mps.shape)
    img_shape = mps_shape[1:]
    # Work on a grid twice the image size along every dimension.
    img2_shape = [i * 2 for i in img_shape]
    ndim = len(img_shape)

    scale = sp.prod(img2_shape)**1.5 / sp.prod(img_shape)**2
    with device:
        # Every-other-sample slice: selects an image-shaped subgrid of the
        # 2x grid.
        idx = (slice(None, None, 2), ) * ndim
        if coord is None:
            # Cartesian: ones (or sqrt of the weights) on the even samples
            # of the 2x grid, zeros elsewhere.
            ones = xp.zeros(img2_shape, dtype=dtype)
            if weights is None:
                ones[idx] = 1
            else:
                ones[idx] = weights**0.5

            psf = sp.ifft(ones)
        else:
            # Non-Cartesian: adjoint NUFFT of ones (times sqrt weights) at
            # the doubled coordinates onto the 2x grid.
            coord2 = coord * 2
            ones = xp.ones(coord.shape[:-1], dtype=dtype)
            if weights is not None:
                ones *= weights**0.5

            psf = sp.nufft_adjoint(ones, coord2, img2_shape)

        # NOTE(review): the in-place ops below (*=, +=) keep each array's
        # existing dtype; switching them to out-of-place expressions can
        # change precision — verify before refactoring.
        p_inv = 0
        for mps_i in mps:
            mps_i = sp.to_device(mps_i, device)
            xcorr_fourier = xp.abs(sp.fft(xp.conj(mps_i), img2_shape))**2
            xcorr = sp.ifft(xcorr_fourier)
            xcorr *= psf
            p_inv_i = sp.fft(xcorr)
            p_inv_i = p_inv_i[idx]
            p_inv_i *= scale
            if weights is not None:
                # NOTE(review): p_inv_i has image shape here, but in the
                # non-Cartesian branch weights has coord.shape[:-1]; this
                # only broadcasts if those shapes happen to agree — confirm.
                p_inv_i *= weights**0.5

            p_inv += p_inv_i

        p_inv += lamda
        # Replace exact zeros so the reciprocal below does not divide by 0.
        p_inv[p_inv == 0] = 1
        p = 1 / p_inv

        return p.astype(dtype)
def kspace_precond(mps, weights=None, coord=None, lamda=0,
                   device=sp.cpu_device, oversamp=1.25):
    r"""Compute a diagonal preconditioner in k-space.

    Considers the optimization problem:

    .. math::
        \min_P \| P A A^H - I \|_F^2

    where A is the Sense operator.

    Args:
        mps (array): sensitivity maps of shape [num_coils] + image shape.
        weights (array): k-space weights.
        coord (array): k-space coordinates of shape [...] + [ndim].
        lamda (float): regularization. The preconditioner's inverse becomes
            ``(|p_inv| + lamda) / (1 + lamda)``.
        device (Device): device used for the computation; ``coord``,
            ``weights`` and each coil map are moved to it.
        oversamp (float): oversampling factor passed to the NUFFT calls
            (non-Cartesian case only).

    Returns:
        array: k-space preconditioner of same shape as k-space.

    """
    dtype = mps.dtype

    # Move coord to the compute device as circulant_precond does; otherwise
    # a NumPy coord would be mixed with device arrays in the NUFFT calls.
    # On the CPU this is a no-op when coord is already there.
    if coord is not None:
        coord = sp.to_device(coord, device)

    if weights is not None:
        weights = sp.to_device(weights, device)

    device = sp.Device(device)
    xp = device.xp

    mps_shape = list(mps.shape)
    img_shape = mps_shape[1:]
    # Work on a grid twice the image size along every dimension.
    img2_shape = [i * 2 for i in img_shape]
    ndim = len(img_shape)

    scale = sp.prod(img2_shape)**1.5 / sp.prod(img_shape)
    with device:
        if coord is None:
            # Every-other-sample slice: image-shaped subgrid of the 2x grid.
            idx = (slice(None, None, 2), ) * ndim

            ones = xp.zeros(img2_shape, dtype=dtype)
            if weights is None:
                ones[idx] = 1
            else:
                ones[idx] = weights**0.5

            psf = sp.ifft(ones)
        else:
            coord2 = coord * 2
            ones = xp.ones(coord.shape[:-1], dtype=dtype)
            if weights is not None:
                ones *= weights**0.5

            psf = sp.nufft_adjoint(ones, coord2, img2_shape,
                                   oversamp=oversamp)

        # One k-space-shaped inverse preconditioner per coil.
        p_inv = []
        for mps_i in mps:
            mps_i = sp.to_device(mps_i, device)
            mps_i_norm2 = xp.linalg.norm(mps_i)**2
            # Accumulate |FFT(mps_i * conj(mps_j))|^2 over all coils j.
            xcorr_fourier = 0
            for mps_j in mps:
                mps_j = sp.to_device(mps_j, device)
                xcorr_fourier += xp.abs(
                    sp.fft(mps_i * xp.conj(mps_j), img2_shape))**2

            xcorr = sp.ifft(xcorr_fourier)
            xcorr *= psf
            if coord is None:
                p_inv_i = sp.fft(xcorr)[idx]
            else:
                p_inv_i = sp.nufft(xcorr, coord2, oversamp=oversamp)

            if weights is not None:
                p_inv_i *= weights**0.5

            p_inv.append(p_inv_i * scale / mps_i_norm2)

        p_inv = (xp.abs(xp.stack(p_inv, axis=0)) + lamda) / (1 + lamda)
        # Replace exact zeros so the reciprocal below does not divide by 0.
        p_inv[p_inv == 0] = 1
        p = 1 / p_inv

        return p.astype(dtype)
pl.ImagePlot(admm_img) #%% md ## ADMM with circulant preconditioner #%% rho = 1 circ_precond = mr.circulant_precond(mps, coord=coord, device=device, lamda=rho) img_shape = mps.shape[1:] G = sp.linop.FiniteDifference(img_shape) g = G.H * G * sp.dirac(img_shape) g = sp.fft(g) g = sp.to_device(g, device=device) circ_precond = 1 / (1 / circ_precond + lamda * g) img_shape = mps.shape[1:] D = sp.linop.Multiply(img_shape, circ_precond) P = sp.linop.IFFT(img_shape) * D * sp.linop.FFT(img_shape) admm_cp_app = mr.app.TotalVariationRecon( ksp, mps, lamda=lamda, coord=coord, max_iter=max_iter // max_cg_iter, P=P, rho=rho, solver='ADMM', max_cg_iter=max_cg_iter, device=device, save_objective_values=True) admm_cp_img = admm_cp_app.run() pl.ImagePlot(admm_cp_img)
mesh[:, :, 1] = m2 return mesh.astype(np.float) name = 'img.jpg' image = Image.open(name).convert('L') arr = np.array(image) + 1j traj = cartisian2D(arr.shape, [1, 1], 1) plt.ScatterPlot(traj, title='Trajectory') image.close() arr = arr / np.max(arr[...]) print(traj.shape) kspaceNUFFT = sp.nufft(arr, traj) plt.ImagePlot(np.log(kspaceNUFFT), title='k-space data from NUFFT') kspaceFFT = sp.fft(arr) plt.ImagePlot(np.log(kspaceFFT), title='k-space data from FFT') print(kspaceFFT.shape) print(kspaceNUFFT.shape) sumNUFFT = np.sum(kspaceNUFFT) sumFFT = np.sum(kspaceFFT) if (np.allclose(kspaceNUFFT, kspaceFFT, rtol=10, atol=10) and np.isclose(sumNUFFT, sumFFT, rtol=50, atol=50)): print('Outputs are similar!') else: print('Outputs are NOT similar!')