def _AHyH_L(self, t):
    """Project the adjoint of frame ``t``'s data onto each spatial factor.

    Computes ``A^H(y_t)`` for the TRs of frame ``t`` (density-compensated
    adjoint NUFFT per coil, combined with the conjugate sensitivity maps).
    Then, for every scale ``j``, it stores the sum over the last ``self.D``
    axes of ``B[j]^H(A^H y_t) * conj(L[j])`` into ``self.R[j][t]``
    (``keepdims=True``).

    Args:
        t (int): frame index; frame t covers TRs
            [t * tr_per_frame, (t + 1) * tr_per_frame).
    """
    # Download k-space arrays for this frame.
    tr_start = t * self.tr_per_frame
    tr_end = (t + 1) * self.tr_per_frame
    coord_t = sp.to_device(self.coord[tr_start:tr_end], self.device)
    dcf_t = sp.to_device(self.dcf[tr_start:tr_end], self.device)
    ksp_t = sp.to_device(self.ksp[:, tr_start:tr_end], self.device)

    # A^H(y_t): per-coil density-compensated gridding, combined with conj(mps).
    AHy_t = 0
    for c in range(self.C):
        mps_c = sp.to_device(self.mps[c], self.device)
        AHy_tc = sp.nufft_adjoint(dcf_t * ksp_t[c], coord_t,
                                  oshape=self.img_shape)
        AHy_tc *= self.xp.conj(mps_c)
        AHy_t += AHy_tc

    # The return value is ignored, so allreduce must update AHy_t in place
    # (presumably each process holds a subset of the coils).
    if self.comm is not None:
        self.comm.allreduce(AHy_t)

    # NumPy reductions accept only an int or a *tuple* of axes; a range object
    # raises TypeError on the CPU path, so materialize the axes as a tuple.
    spatial_axes = tuple(range(-self.D, 0))
    for j in range(self.J):
        AHy_tj = self.B[j].H(AHy_t)
        self.R[j][t] = self.xp.sum(AHy_tj * self.xp.conj(self.L[j]),
                                   axis=spatial_axes, keepdims=True)
def gradf(self, mrimg):
    """Gradient at ``mrimg`` of the data-consistency plus inter-bin penalty terms.

    The data term is the sum over bins ``b`` and coils ``c`` of
    ``A_bc^H W_b (A_bc x_b - y_bc)``, where ``A_bc`` applies coil map ``c``
    and then the NUFFT at the bin's coordinates, and ``W_b`` is the bin's
    density compensation. The penalty adds
    ``lamda * d / (|d| + eps)`` for the difference ``d`` between a bin and
    each existing neighbor bin (a smoothed sign, i.e. the gradient of a
    smoothed L1 difference penalty).

    Args:
        mrimg (array): stacked bin images; ``mrimg[b]`` is bin ``b``.

    Returns:
        array: gradient with the same shape and dtype as ``mrimg``.
    """
    grad = self.xp.zeros_like(mrimg)
    bin_shape = mrimg.shape[1:]

    # Data consistency, accumulated bin by bin and coil by coil.
    for b in range(self.B):
        coord_b = self.bcoord[b]
        for c in range(self.C):
            sens = sp.to_device(self.mps[c], self.device)
            resid = sp.nufft(mrimg[b] * sens, coord_b) - self.bksp[b][c]
            back = sp.nufft_adjoint(self.bdcf[b] * resid, coord_b,
                                    oshape=bin_shape)
            grad[b] += back * self.xp.conj(sens)

    if self.comm is not None:
        self.comm.allreduce(grad)

    # Smoothed sign of the difference to the previous, then the next, bin.
    eps = 1e-31
    for b in range(self.B):
        for nb in (b - 1, b + 1):
            if 0 <= nb < self.B:
                delta = mrimg[b] - mrimg[nb]
                sp.axpy(grad[b], self.lamda,
                        delta / (self.xp.abs(delta) + eps))
    return grad
def _normalize(self):
    """Rescale ``self.dcf`` and ``self.ksp`` in place.

    ``self.dcf`` is divided by the maximum eigenvalue (power iteration) of
    ``F^H W F`` built from the first frame's coordinates and weights, so that
    operator has unit largest eigenvalue. ``self.ksp`` is then divided by the
    L2 norm of the coil-combined adjoint image of all data (computed with the
    rescaled dcf), so that image has unit norm.
    """
    with self.device:
        first = slice(None, self.tr_per_frame)

        # Largest eigenvalue of the first frame's density-weighted normal op.
        nufft_op = sp.linop.NUFFT(
            self.img_shape, sp.to_device(self.coord[first], self.device))
        weight_op = sp.linop.Multiply(
            nufft_op.oshape, sp.to_device(self.dcf[first], self.device))
        power_iter = sp.app.MaxEig(
            nufft_op.H * weight_op * nufft_op,
            max_iter=self.max_power_iter, dtype=self.dtype,
            device=self.device, show_pbar=self.show_pbar)
        self.dcf /= power_iter.run()

        # Coil-combined adjoint image over every TR.
        dcf_dev = sp.to_device(self.dcf, self.device)
        coord_dev = sp.to_device(self.coord, self.device)
        adjoint = 0
        for c in range(self.C):
            sens = sp.to_device(self.mps[c], self.device)
            ksp_c = sp.to_device(self.ksp[c], self.device)
            coil_img = sp.nufft_adjoint(ksp_c * dcf_dev, coord_dev,
                                        self.img_shape)
            coil_img *= self.xp.conj(sens)
            adjoint += coil_img

        if self.comm is not None:
            self.comm.allreduce(adjoint)

        self.ksp /= self.xp.linalg.norm(adjoint).item()
def _normalize(self):
    """Rescale step size and regularization using the first bin.

    Computes the coil-combined adjoint image of bin 0 and the maximum
    eigenvalue of ``F^H W F`` for bin 0. ``self.alpha`` is divided by that
    eigenvalue. ``self.lamda`` is multiplied by the eigenvalue times the peak
    magnitude of the adjoint image.

    Uses the instance's device, coil maps and communicator. Earlier versions
    read module-level ``device``/``mps``/``comm``/``ksp``, which raise
    NameError unless such globals happen to exist.
    """
    with self.device:
        # Normalize using first phase: coil-combined adjoint image of bin 0.
        mrimg_adj = 0
        for c in range(self.C):
            mrimg_c = sp.nufft_adjoint(self.bksp[0][c] * self.bdcf[0],
                                       self.bcoord[0], self.img_shape)
            mrimg_c *= self.xp.conj(sp.to_device(self.mps[c], self.device))
            mrimg_adj += mrimg_c

        if self.comm is not None:
            self.comm.allreduce(mrimg_adj)

        # Get maximum eigenvalue of the first bin's density-weighted normal op.
        F = sp.linop.NUFFT(self.img_shape, self.bcoord[0])
        W = sp.linop.Multiply(F.oshape, self.bdcf[0])
        # k-space dtype taken from the first bin's data (assumes bksp keeps the
        # original k-space dtype — confirm).
        max_eig = sp.app.MaxEig(F.H * W * F, max_iter=self.max_power_iter,
                                dtype=self.bksp[0][0].dtype,
                                device=self.device,
                                show_pbar=self.show_pbar).run()

        # Normalize
        self.alpha /= max_eig
        self.lamda *= max_eig * self.xp.abs(mrimg_adj).max().item()
def autofov(ksp, coord, dcf, num_ro=100, device=sp.cpu_device, thresh=0.1):
    """Automatic estimation of FOV.

    FOV is estimated by thresholding a low resolution gridded image.
    coord will be modified in-place.

    Args:
        ksp (array): k-space measurements of shape (C, num_tr, num_ro),
            where C is the number of channels, num_tr is the number of TRs,
            and num_ro is the number of readout points.
        coord (array): k-space coordinates of shape (num_tr, num_ro, D),
            where D is the number of spatial dimensions. Scaled in place by
            the estimated FOV ratio.
        dcf (array): density compensation factor of shape (num_tr, num_ro).
        num_ro (int): number of leading read-out points of every TR used to
            form the low resolution image (presumably the k-space center for
            center-out trajectories).
        device (Device): computing device.
        thresh (float): threshold between 0 and 1, relative to the maximum of
            the coil-combined low resolution image. For 3D images the maximum
            is taken over the central slice along the second axis.

    Raises:
        ValueError: if no voxel of the low resolution image exceeds the
            threshold, so no FOV can be estimated.
    """
    device = sp.Device(device)
    xp = device.xp
    with device:
        # Keep only the first num_ro readout points of every TR.
        kspc = ksp[:, :, :num_ro]
        coordc = coord[:, :num_ro, :]
        dcfc = dcf[:, :num_ro]
        # Doubled coordinates grid onto a 2x larger image covering twice the
        # nominal FOV, so objects extending past the current FOV are visible.
        coordc2 = sp.to_device(coordc * 2, device)

        num_coils = len(kspc)
        imgc_shape = np.array(sp.estimate_shape(coordc))
        imgc2_shape = sp.estimate_shape(coordc2)
        imgc2_center = [i // 2 for i in imgc2_shape]
        imgc2 = sp.nufft_adjoint(sp.to_device(dcfc * kspc, device), coordc2,
                                 [num_coils] + imgc2_shape)
        # Root-sum-of-squares coil combination.
        imgc2 = xp.sum(xp.abs(imgc2)**2, axis=0)**0.5
        if imgc2.ndim == 3:
            imgc2_cor = imgc2[:, imgc2.shape[1] // 2, :]
            thresh *= imgc2_cor.max()
        else:
            thresh *= imgc2.max()

        boxc = sp.to_device(imgc2 > thresh)
        boxc_idx = np.nonzero(boxc)
        if boxc_idx[0].size == 0:
            raise ValueError(
                'autofov: no voxel exceeds the threshold; cannot estimate FOV.')

        # Extent of the thresholded region: twice its largest distance from
        # the image center along each axis.
        boxc_shape = np.array([
            int(np.abs(boxc_idx[i] - imgc2_center[i]).max()) * 2
            for i in range(imgc2.ndim)
        ])

        img_scale = boxc_shape / imgc_shape
        coord *= img_scale
def gridding_recon(ksp, coord, dcf, T=1, device=sp.cpu_device):
    """Gridding reconstruction.

    Each frame is reconstructed coil by coil with a density-compensated
    adjoint NUFFT, and the coils are combined by root-sum-of-squares.

    Args:
        ksp (array): k-space measurements of shape (C, num_tr, num_ro),
            where C is the number of channels, num_tr is the number of TRs,
            and num_ro is the number of readout points.
        coord (array): k-space coordinates of shape (num_tr, num_ro, D),
            where D is the number of spatial dimensions.
        dcf (array): density compensation factor of shape (num_tr, num_ro).
        T (int): number of frames. Consecutive blocks of num_tr // T TRs form
            one frame; the trailing num_tr % T TRs are not used.
        device (Device): computing device.

    Returns:
        img (array): magnitude image of shape (T, N_D, ..., N_1) on the CPU,
            where (N_D, ..., N_1) is estimated from coord.

    Raises:
        ValueError: if T is not between 1 and num_tr.
    """
    device = sp.Device(device)
    xp = device.xp

    num_coils, num_tr, _ = ksp.shape
    if not 1 <= T <= num_tr:
        raise ValueError(f'T must be between 1 and num_tr={num_tr}, got {T}.')

    tr_per_frame = num_tr // T
    img_shape = sp.estimate_shape(coord)

    with device:
        img = []
        for t in range(T):
            tr_start = t * tr_per_frame
            tr_end = (t + 1) * tr_per_frame
            coord_t = sp.to_device(coord[tr_start:tr_end], device)
            dcf_t = sp.to_device(dcf[tr_start:tr_end], device)

            # Root-sum-of-squares over coils.
            img_t = 0
            for c in range(num_coils):
                logging.info(f'Reconstructing time {t}, coil {c}')
                ksp_tc = sp.to_device(ksp[c, tr_start:tr_end, :], device)
                img_t += xp.abs(
                    sp.nufft_adjoint(ksp_tc * dcf_t, coord_t, img_shape))**2

            img_t = img_t**0.5
            img.append(sp.to_device(img_t))

    return np.stack(img)
def time_nufft_adjoint(self):
    """Timing target: apply ``sp.nufft_adjoint`` to ``self.x`` at ``self.coord``."""
    # Only the call's runtime matters, so the result is discarded.
    sp.nufft_adjoint(self.x, self.coord)
def nufft_adj1(data, traj, dcf, device=sp.Device(-1), smap=None, batch=40000,
               id_channel=False, ishape=None):
    """Per-phase density-compensated adjoint NUFFT, processed in TR batches.

    Args:
        data (array): k-space, indexed as ``data[phase][coil, tr, ro]``.
        traj (array): coordinates, indexed as ``traj[phase][tr, ..., ndim]``.
        dcf (array): density compensation, indexed as ``dcf[phase][tr, ...]``.
        device: computing device; anything accepted by ``sp.Device``
            (a Device object or an integer id).
        smap (array or None): sensitivity maps indexed as ``smap[coil, ...]``.
            When given, its last three dimensions set the image shape.
            Ignored when ``id_channel`` is True.
        batch (int): maximum number of TRs per adjoint NUFFT call; must be
            positive.
        id_channel (bool): if True, return the individual coil images
            (complex64) instead of a coil combination.
        ishape (sequence or None): image shape used when ``smap`` is None.
            If both are None, the shape is estimated from each phase's
            coordinates.

    Returns:
        numpy.ndarray: per-phase results stacked along axis 0, on the CPU.
        Each phase holds one of:

        - ``id_channel`` True: coil images of shape ``(C,) + image shape``.
        - ``smap`` given: sum over coils of ``img_c * conj(smap_c)``.
        - otherwise: sum over coils of ``|img_c|**2``. This is a sum of
          squares; no square root is taken.
    """
    device = sp.Device(device)
    xp = device.xp

    img = []
    for ph in range(data.shape[0]):
        ksp_t = data[ph]
        coord_t = traj[ph]
        dcf_t = dcf[ph]

        if smap is not None:
            img_shape = list(smap.shape[-3:])
        elif ishape is not None:
            img_shape = ishape
        else:
            img_shape = sp.estimate_shape(coord_t)

        num_coils, num_tr, _ = ksp_t.shape

        if id_channel is True:
            img_t = np.zeros((num_coils, ) + tuple(img_shape),
                             dtype=np.complex64)
        else:
            img_t = 0

        with device:
            for c in range(num_coils):
                # Adjoint NUFFT of this coil, accumulated over TR batches.
                img_tc = 0
                for start in range(0, num_tr, batch):
                    tr_sl = slice(start, start + batch)
                    ksp_b = sp.to_device(ksp_t[c, tr_sl, ...], device)
                    coord_b = sp.to_device(coord_t[tr_sl, ...], device)
                    dcf_b = sp.to_device(dcf_t[tr_sl, ...], device)
                    img_tc += sp.nufft_adjoint(ksp_b * dcf_b, coord_b,
                                               img_shape)

                if id_channel is True:
                    img_t[c, ...] = sp.to_device(img_tc)
                elif smap is None:
                    img_t += xp.abs(img_tc)**2
                else:
                    smap_c = sp.to_device(smap[c, ...], device)
                    img_t += sp.to_device(img_tc * xp.conj(smap_c))

        img.append(sp.to_device(img_t))

    return np.asarray(img)
def circulant_precond(mps, weights=None, coord=None, lamda=0,
                      device=sp.cpu_device):
    r"""Compute circulant preconditioner.

    Considers the optimization problem:

    .. math::
        \min_P \| A^H A - F P F^H  \|_2^2

    where A is the Sense operator,
    and F is a unitary Fourier transform operator.

    Args:
        mps (array): sensitivity maps of shape [num_coils] + image shape.
        weights (array): k-space weights.
        coord (array): k-space coordinates of shape [...] + [ndim].
            If None, Cartesian sampling on the image grid is assumed.
        lamda (float): regularization, added to the inverse preconditioner.
        device (Device): computing device.

    Returns:
        array: circulant preconditioner of image shape, cast to mps.dtype.

    """
    # Move optional inputs to the computing device.
    if coord is not None:
        coord = sp.to_device(coord, device)

    if weights is not None:
        weights = sp.to_device(weights, device)

    dtype = mps.dtype

    device = sp.Device(device)
    xp = device.xp

    mps_shape = list(mps.shape)
    img_shape = mps_shape[1:]
    # Work on a 2x grid (presumably so the FFT-based correlations below do
    # not wrap around).
    img2_shape = [i * 2 for i in img_shape]
    ndim = len(img_shape)

    # Scale factor applied to every coil's contribution.
    scale = sp.prod(img2_shape)**1.5 / sp.prod(img_shape)**2

    with device:
        # Every other sample along each axis: maps the 2x grid to image shape.
        idx = (slice(None, None, 2), ) * ndim
        if coord is None:
            # Cartesian: PSF from the (weighted) sampling pattern on the 2x grid.
            ones = xp.zeros(img2_shape, dtype=dtype)
            if weights is None:
                ones[idx] = 1
            else:
                ones[idx] = weights**0.5

            psf = sp.ifft(ones)
        else:
            # Non-Cartesian: PSF by gridding (weighted) ones at doubled coords.
            coord2 = coord * 2
            ones = xp.ones(coord.shape[:-1], dtype=dtype)
            if weights is not None:
                ones *= weights**0.5

            psf = sp.nufft_adjoint(ones, coord2, img2_shape)

        # Per coil: |F conj(mps_i)|^2 -> inverse FFT, multiplied by the PSF,
        # FFT back, subsampled to image shape and summed over coils.
        p_inv = 0
        for mps_i in mps:
            mps_i = sp.to_device(mps_i, device)
            xcorr_fourier = xp.abs(sp.fft(xp.conj(mps_i), img2_shape))**2
            xcorr = sp.ifft(xcorr_fourier)
            xcorr *= psf
            p_inv_i = sp.fft(xcorr)
            p_inv_i = p_inv_i[idx]
            p_inv_i *= scale
            if weights is not None:
                # NOTE(review): p_inv_i has image shape here; when coord is
                # given, weights is presumably shaped like coord.shape[:-1]
                # and this broadcast may fail — confirm.
                p_inv_i *= weights**0.5

            p_inv += p_inv_i

        p_inv += lamda
        # Avoid division by zero where the estimate vanishes.
        p_inv[p_inv == 0] = 1
        p = 1 / p_inv

        return p.astype(dtype)
def kspace_precond(mps, weights=None, coord=None, lamda=0,
                   device=sp.cpu_device, oversamp=1.25):
    r"""Compute a diagonal preconditioner in k-space.

    Considers the optimization problem:

    .. math::
        \min_P \| P A A^H - I \|_F^2

    where A is the Sense operator.

    Args:
        mps (array): sensitivity maps of shape [num_coils] + image shape.
        weights (array): k-space weights.
        coord (array): k-space coordinates of shape [...] + [ndim].
            If None, Cartesian sampling on the image grid is assumed.
            Moved to ``device`` if it lives elsewhere.
        lamda (float): regularization.
        device (Device): computing device.
        oversamp (float): oversampling factor passed to the NUFFT calls.

    Returns:
        array: k-space preconditioner stacked over coils, of shape
            [num_coils] + image shape (coord is None) or
            [num_coils] + coord.shape[:-1] (otherwise), cast to mps.dtype.
    """
    dtype = mps.dtype

    # Move optional inputs to the computing device, so a CPU-resident coord
    # is not mixed with device arrays inside the NUFFT calls.
    if coord is not None:
        coord = sp.to_device(coord, device)

    if weights is not None:
        weights = sp.to_device(weights, device)

    device = sp.Device(device)
    xp = device.xp

    mps_shape = list(mps.shape)
    img_shape = mps_shape[1:]
    img2_shape = [i * 2 for i in img_shape]
    ndim = len(img_shape)

    scale = sp.prod(img2_shape)**1.5 / sp.prod(img_shape)
    with device:
        if coord is None:
            # Every other sample along each axis of the 2x grid.
            idx = (slice(None, None, 2), ) * ndim

            ones = xp.zeros(img2_shape, dtype=dtype)
            if weights is None:
                ones[idx] = 1
            else:
                ones[idx] = weights**0.5

            psf = sp.ifft(ones)
        else:
            coord2 = coord * 2
            ones = xp.ones(coord.shape[:-1], dtype=dtype)
            if weights is not None:
                ones *= weights**0.5

            psf = sp.nufft_adjoint(ones, coord2, img2_shape,
                                   oversamp=oversamp)

        p_inv = []
        for mps_i in mps:
            mps_i = sp.to_device(mps_i, device)
            mps_i_norm2 = xp.linalg.norm(mps_i)**2
            # Sum over coils j of |F(mps_i * conj(mps_j))|^2 on the 2x grid.
            xcorr_fourier = 0
            for mps_j in mps:
                mps_j = sp.to_device(mps_j, device)
                xcorr_fourier += xp.abs(
                    sp.fft(mps_i * xp.conj(mps_j), img2_shape))**2

            xcorr = sp.ifft(xcorr_fourier)
            xcorr *= psf
            if coord is None:
                p_inv_i = sp.fft(xcorr)[idx]
            else:
                p_inv_i = sp.nufft(xcorr, coord2, oversamp=oversamp)

            if weights is not None:
                p_inv_i *= weights**0.5

            p_inv.append(p_inv_i * scale / mps_i_norm2)

        p_inv = (xp.abs(xp.stack(p_inv, axis=0)) + lamda) / (1 + lamda)
        # Avoid division by zero where the estimate vanishes.
        p_inv[p_inv == 0] = 1
        p = 1 / p_inv

        return p.astype(dtype)
def _update(self, t):
    """Take one gradient step on the factors using frame ``t``.

    Forms the frame image ``sum_j B[j](L[j] * R[j][t])``. It then computes the
    density-weighted data-consistency residual for the frame's TRs and updates
    ``self.L[j]`` and ``self.R[j][t]`` for every scale ``j`` with
    preconditioned gradient steps.

    Args:
        t (int): frame index.

    Returns:
        float: half of the frame's data-consistency loss plus the
            regularization terms (each evaluated before that scale's update).

    Raises:
        OverflowError: if the accumulated loss becomes inf or NaN.
    """
    # Form image.
    img_t = 0
    for j in range(self.J):
        img_t += self.B[j](self.L[j] * self.R[j][t])

    # Download k-space arrays.
    tr_start = t * self.tr_per_frame
    tr_end = (t + 1) * self.tr_per_frame
    coord_t = sp.to_device(self.coord[tr_start:tr_end], self.device)
    dcf_t = sp.to_device(self.dcf[tr_start:tr_end], self.device)
    ksp_t = sp.to_device(self.ksp[:, tr_start:tr_end], self.device)

    # Data consistency.
    e_t = 0
    loss_t = 0
    for c in range(self.C):
        mps_c = sp.to_device(self.mps[c], self.device)
        e_tc = sp.nufft(img_t * mps_c, coord_t)
        e_tc -= ksp_t[c]
        # sqrt(dcf) weighting for the loss, then a second sqrt(dcf) so the
        # backprojected residual carries the full dcf weighting.
        e_tc *= dcf_t**0.5
        loss_t += self.xp.linalg.norm(e_tc)**2
        e_tc *= dcf_t**0.5
        e_tc = sp.nufft_adjoint(e_tc, coord_t, oshape=self.img_shape)
        e_tc *= self.xp.conj(mps_c)
        e_t += e_tc

    if self.comm is not None:
        self.comm.allreduce(e_t)
        self.comm.allreduce(loss_t)

    loss_t = loss_t.item()

    # NumPy reductions accept only an int or a *tuple* of axes; a range object
    # raises TypeError on the CPU path, so materialize the axes as a tuple.
    spatial_axes = tuple(range(-self.D, 0))

    # Compute gradient.
    for j in range(self.J):
        lamda_j = self.lamda * self.G[j]

        # Loss.
        loss_t += lamda_j / self.T * self.xp.linalg.norm(
            self.L[j]).item()**2
        loss_t += lamda_j * self.xp.linalg.norm(self.R[j][t]).item()**2
        if np.isinf(loss_t) or np.isnan(loss_t):
            raise OverflowError

        # L gradient, scaled by T (presumably because each frame carries 1/T
        # of the L regularization).
        g_L_j = self.B[j].H(e_t)
        g_L_j *= self.xp.conj(self.R[j][t])
        g_L_j += lamda_j / self.T * self.L[j]
        g_L_j *= self.T

        # R gradient.
        g_R_jt = self.B[j].H(e_t)
        g_R_jt *= self.xp.conj(self.L[j])
        g_R_jt = self.xp.sum(g_R_jt, axis=spatial_axes, keepdims=True)
        g_R_jt += lamda_j * self.R[j][t]

        # Precondition.
        g_L_j /= self.J * self.sigma[j] + lamda_j
        g_R_jt /= self.J * self.sigma[j] + lamda_j

        # Add. The L step size decays by beta every decay_epoch epochs;
        # the R step size does not.
        self.L[j] -= self.alpha * self.beta**(self.epoch // self.decay_epoch) * g_L_j
        self.R[j][t] -= self.alpha * g_R_jt

    loss_t /= 2
    return loss_t
def show_data_info(data, name):
    # Print an array's name, shape and dtype on one line.
    print("{}: shape={}, dtype={}".format(name, data.shape, data.dtype))


# Load the raw k-space samples and their trajectory.
ksp = xp.load(ksp_file)
coord = xp.load(coord_file)

# Ramp density compensation: |k| measured over the first two coordinate axes.
dcf = (coord[..., 0]**2 + coord[..., 1]**2)**0.5
pl.ScatterPlot(coord, dcf, title='Density compensation')

show_data_info(ksp, "ksp")
show_data_info(coord, "coord")
show_data_info(dcf, "dcf")

# Per-coil density-compensated gridding.
img_grid = sp.nufft_adjoint(ksp * dcf, coord)
pl.ImagePlot(img_grid, z=0, title='Multi-channel Gridding')

#%% md
## Estimate sensitivity maps using JSENSE
# Here we use [JSENSE](https://onlinelibrary.wiley.com/doi/full/10.1002/mrm.21245) to estimate sensitivity maps.

#%%
mps = mr.app.JsenseRecon(ksp, coord=coord, device=device).run()

#%% md
## CG
################################################################################
# Estimate coil maps using Walsh's method from temporal averaged data
#
with device:
    # Assume the spiral sampling pattern repeats every n_full_arms arms.
    # Integer floor division avoids a float round trip through the array
    # library for a plain scalar; int() keeps the slice bound an int.
    n_avr = int(ns // n_full_arms)
    # Average the n_avr complete repetitions:
    # (n_avr * n_full_arms, nc, nk) -> (n_avr, n_full_arms, nc, nk)
    # -> mean over repetitions -> (n_full_arms, nc, nk) -> (nc, n_full_arms, nk).
    kdata_avr = kdata[:n_avr * n_full_arms, :, :].reshape(
        n_avr, n_full_arms, nc, nk)
    kdata_avr = xp.mean(kdata_avr, axis=0)
    kdata_avr = xp.transpose(kdata_avr, (1, 0, 2))
    kloc_avr = kloc[:n_full_arms, :, :]
    # kweight is broadcast along the readout (last) axis.
    avr_img = sp.nufft_adjoint(kdata_avr * kweight[xp.newaxis, xp.newaxis, :],
                               kloc_avr)

    # Walsh estimation needs to run on the CPU (device -1).
    sens_map = estimate_coilmap_walsh(sp.to_device(avr_img, -1),
                                      smoothing=20, thresh=0.0)
    # Copy the maps back to the compute device (device_id; presumably a GPU).
    sens_map = sp.to_device(sens_map, device_id)
    # pl.ImagePlot(xp.squeeze(avr_img), z=0, title='Multi-channel Time Averaged Image')
    # pl.ImagePlot(xp.squeeze(xp.abs(sens_map)), z=0, title='Walsh (Python)')
    # TODO Espirit coil map estimation needs to be improved

################################################################################
# Reshape Data
def test_shepp_logan_dcf(self):
    """Pipe-Menon DCF gridding should roughly match the setup's reference image."""
    img, coord, ksp = self.shepp_logan_setup()
    weights = dcf.pipe_menon_dcf(coord, show_pbar=False)

    recon = sp.nufft_adjoint(ksp * weights, coord, oshape=img.shape)
    # Normalize to unit peak magnitude before comparing.
    recon /= np.abs(recon).max()

    # NOTE(review): atol=1 is a loose bound for unit-peak images.
    npt.assert_allclose(img, recon, atol=1, rtol=1e-1)