def autofov(ksp, coord, dcf, num_ro=100, device=sp.cpu_device, thresh=0.1):
    """Automatic estimation of FOV.

    FOV is estimated by thresholding a low resolution gridded image.
    coord will be modified in-place.

    Args:
        ksp (array): k-space measurements of shape (C, num_tr, num_ro),
            where C is the number of channels, num_tr is the number of TRs,
            and num_ro is the number of readout points.
        coord (array): k-space coordinates of shape (num_tr, num_ro, D),
            where D is the number of spatial dimensions.
        dcf (array): density compensation factor of shape (num_tr, num_ro).
        num_ro (int): number of leading read-out points used for the
            low-resolution estimate.
        device (Device): computing device.
        thresh (float): threshold between 0 and 1, relative to the maximum
            of the coil-combined low-resolution image (for 3D images, the
            maximum of its central slice along the second axis).

    Raises:
        ValueError: if no voxel of the low-resolution image exceeds the
            threshold (e.g. all-zero data), so no FOV can be estimated.
    """
    device = sp.Device(device)
    xp = device.xp
    with device:
        # Only the first num_ro readout points are used: a low-resolution
        # image is enough to locate the object support.
        kspc = ksp[:, :, :num_ro]
        coordc = coord[:, :num_ro, :]
        dcfc = dcf[:, :num_ro]
        # Scaling the coordinates by 2 doubles the estimated grid
        # (imgc2_shape vs imgc_shape), leaving room to detect support
        # outside the nominal FOV.
        coordc2 = sp.to_device(coordc * 2, device)

        num_coils = len(kspc)
        imgc_shape = np.array(sp.estimate_shape(coordc))
        imgc2_shape = sp.estimate_shape(coordc2)
        imgc2_center = [i // 2 for i in imgc2_shape]
        imgc2 = sp.nufft_adjoint(
            sp.to_device(dcfc * kspc, device), coordc2,
            [num_coils] + imgc2_shape)
        # Root-sum-of-squares coil combination.
        imgc2 = xp.sum(xp.abs(imgc2)**2, axis=0)**0.5
        if imgc2.ndim == 3:
            imgc2_cor = imgc2[:, imgc2.shape[1] // 2, :]
            thresh *= imgc2_cor.max()
        else:
            thresh *= imgc2.max()

        boxc = sp.to_device(imgc2 > thresh)
        boxc_idx = np.nonzero(boxc)
        if boxc_idx[0].size == 0:
            raise ValueError(
                'autofov: no voxel exceeds the threshold; '
                'cannot estimate the FOV (is the k-space data all zero?).')

        # Support extent per axis: twice the largest distance of any
        # above-threshold voxel from the grid center.
        boxc_shape = np.array([
            int(np.abs(boxc_idx[i] - imgc2_center[i]).max()) * 2
            for i in range(imgc2.ndim)
        ])

        img_scale = boxc_shape / imgc_shape
        coord *= img_scale
def _get_data(self):
    """Restrict the measurements to the calibration region and stage them.

    Cartesian data (``self.coord is None``) is resized to a
    ``ksp_calib_width`` cube per spatial axis; non-Cartesian data keeps only
    samples whose largest absolute coordinate is below
    ``ksp_calib_width / 2``. ``self.img_shape`` is recorded before the
    cropping. The data is then normalised by its peak magnitude (and
    multiplied by ``sqrt(weights)`` when weights are given), moved to
    ``self.device``, and ``self.weights`` is replaced by the result of
    ``_estimate_weights``.
    """
    calib_width = self.ksp_calib_width

    if self.coord is not None:
        self.img_shape = sp.estimate_shape(self.coord)
        keep = np.amax(np.abs(self.coord), axis=-1) < calib_width / 2
        self.coord = self.coord[keep]
        self.y = self.y[:, keep]
        if self.weights is not None:
            self.weights = self.weights[keep]
    else:
        self.img_shape = list(self.y.shape[1:])
        calib_shape = len(self.img_shape) * [calib_width]
        self.y = sp.resize(self.y, [self.num_coils] + calib_shape)
        if self.weights is not None:
            self.weights = sp.resize(self.weights, calib_shape)

    # Normalise by the peak magnitude; weighted data is pre-multiplied by
    # sqrt(weights).
    peak = np.abs(self.y).max()
    if self.weights is not None:
        normalized = self.weights**0.5 * self.y / peak
    else:
        normalized = self.y / peak
    self.y = sp.to_device(normalized, self.device)

    if self.coord is not None:
        self.coord = sp.to_device(self.coord, self.device)
    if self.weights is not None:
        self.weights = sp.to_device(self.weights, self.device)

    self.weights = _estimate_weights(self.y, self.weights, self.coord)
def pipe_menon_dcf(coord, device=sp.cpu_device, max_iter=30, n=128, beta=8,
                   width=4, show_pbar=True, img_shape=None):
    r"""Compute Pipe Menon density compensation factor.

    Perform the following iteration:

    .. math::

        w = \frac{w}{|G^H G w|}

    with :math:`G` as the gridding operator.

    Args:
        coord (array): k-space coordinates. Moved to ``device`` if needed.
        device (Device): computing device.
        max_iter (int): number of iterations.
        n (int): Kaiser-Bessel sampling numbers for gridding operator.
        beta (float): Kaiser-Bessel kernel parameter.
        width (float): Kaiser-Bessel kernel width.
        show_pbar (bool): show progress bar.
        img_shape (None or list of ints): grid shape of the gridding
            operator. Defaults to ``sp.estimate_shape(coord)``.

    Returns:
        array: density compensation factor, of shape ``coord.shape[:-1]``,
        on ``device``.

    References:
        Pipe, James G., and Padmanabhan Menon.
        Sampling Density Compensation in MRI:
        Rationale and an Iterative Numerical Solution.
        Magnetic Resonance in Medicine 41, no. 1 (1999): 179–86.
    """
    device = sp.Device(device)
    xp = device.xp

    with device:
        # The gridding operator and the weights must live on the same device.
        coord = sp.to_device(coord, device)
        w = xp.ones(coord.shape[:-1], dtype=coord.dtype)
        if img_shape is None:
            img_shape = sp.estimate_shape(coord)

        # Kaiser-Bessel kernel sampled at n points on [0, 1), peak-normalised.
        x = xp.arange(n, dtype=coord.dtype) / n
        kernel = xp.i0(beta * (1 - x**2)**0.5).astype(coord.dtype)
        kernel /= kernel.max()

        G = sp.linop.Gridding(img_shape, coord, width, kernel)
        with tqdm(total=max_iter, disable=not show_pbar) as pbar:
            for _ in range(max_iter):
                GHGw = G.H * G * w
                w /= xp.abs(GHGw)
                # Residual of the fixed point G^H G w = 1 (previous iterate).
                resid = xp.abs(GHGw - 1).max().item()

                pbar.set_postfix(resid='{0:.2E}'.format(resid))
                pbar.update()

    return w
def ConvSense(img_ker_shape, mps_ker, coord=None, weights=None, comm=None,
              grd_shape=None):
    """Convolution linear operator with sensitivity maps kernel in k-space.

    Args:
        img_ker_shape (tuple of ints): image kernel shape.
        mps_ker (array): sensitivity maps kernel; its first axis is the
            coil axis.
        coord (array): coordinates. When given, the convolution output is
            followed by an inverse FFT and a NUFFT onto these coordinates.
        weights (array): optional weights; the output is multiplied by
            ``weights**0.5``.
        comm (Communicator): optional communicator; when given, an
            ``AllReduceAdjoint`` on ``img_ker_shape`` is applied to the
            input first.
        grd_shape (None or list): shape of the grid used with ``coord``.
            Defaults to ``sp.estimate_shape(coord)``.

    Returns:
        Linop: the composed linear operator.
    """
    ndim = len(img_ker_shape)
    A = sp.linop.ConvolveInput(img_ker_shape, mps_ker, mode='valid',
                               output_multi_channel=True)

    if coord is not None:
        num_coils = mps_ker.shape[0]
        if grd_shape is None:
            grd_shape = sp.estimate_shape(coord)
        grd_shape = [num_coils] + list(grd_shape)
        iF = sp.linop.IFFT(grd_shape, axes=range(-ndim, 0))
        N = sp.linop.NUFFT(grd_shape, coord)
        A = N * iF * A

    if weights is not None:
        with sp.get_device(weights):
            P = sp.linop.Multiply(A.oshape, weights**0.5)
        A = P * A

    if comm is not None:
        C = sp.linop.AllReduceAdjoint(img_ker_shape, comm, in_place=True)
        A = A * C

    return A
def ConvImage(mps_ker_shape, img_ker, coord=None, weights=None, grd_shape=None):
    """Convolution linear operator with image kernel in k-space.

    Args:
        mps_ker_shape (tuple of ints): sensitivity maps kernel shape; its
            first entry is the number of coils.
        img_ker (array): image kernel.
        coord (array): coordinates. When given, the convolution output is
            followed by an inverse FFT and a NUFFT onto these coordinates.
        weights (array): optional weights; the output is multiplied by
            ``weights**0.5``.
        grd_shape (None or list): shape of the grid used with ``coord``.
            Defaults to ``sp.estimate_shape(coord)``.

    Returns:
        Linop: the composed linear operator.
    """
    ndim = img_ker.ndim
    A = sp.linop.ConvolveFilter(mps_ker_shape, img_ker, mode='valid',
                                output_multi_channel=True)

    if coord is not None:
        num_coils = mps_ker_shape[0]
        if grd_shape is None:
            grd_shape = sp.estimate_shape(coord)
        grd_shape = [num_coils] + list(grd_shape)
        iF = sp.linop.IFFT(grd_shape, axes=range(-ndim, 0))
        N = sp.linop.NUFFT(grd_shape, coord)
        A = N * iF * A

    if weights is not None:
        with sp.get_device(weights):
            P = sp.linop.Multiply(A.oshape, weights**0.5)
        A = P * A

    return A
def nufft1(img, traj, device=sp.Device(-1), smap=None, dcf=None, batch=40000,
           id_channel=False):
    """Forward NUFFT of a series of images, processed in batches of TRs.

    Args:
        img (array): images indexed by phase along the first axis. With
            ``id_channel`` each ``img[p]`` has a leading coil axis.
        traj (array): per-phase k-space coordinates; ``traj[p]`` has shape
            (num_tr, num_ro, D).
        device (Device or int): computing device.
        smap (None or array): sensitivity maps indexed by coil along the
            first axis; each coil image is ``img[p] * smap[c]``. Ignored
            when ``id_channel`` is set. When None, a single unit coil is used.
        dcf (None or array): optional per-phase factors of shape
            (num_tr, num_ro), multiplied into the k-space output.
        batch (int): number of TRs transformed per NUFFT call.
        id_channel (bool): if true, ``img`` already holds per-coil images
            and ``smap`` is not applied.

    Returns:
        numpy.ndarray: complex64 k-space of shape
        (N_phase, num_coils, num_tr, num_ro).
    """
    device = sp.Device(device)
    ksp = []
    for num_ph in range(img.shape[0]):
        img_t = img[num_ph]
        coord_t = traj[num_ph]
        dcf_t = None if dcf is None else dcf[num_ph]
        num_tr, num_ro, _ = coord_t.shape

        if id_channel:
            num_coils = img_t.shape[0]
        elif smap is None:
            num_coils = 1
        else:
            num_coils = smap.shape[0]

        # TR ranges handled per NUFFT call, bounding device memory use.
        segments = [slice(start, min(start + batch, num_tr))
                    for start in range(0, num_tr, batch)]

        with device:
            ksp_t = np.zeros((num_coils, num_tr, num_ro), dtype=np.complex64)
            for c in range(num_coils):
                if id_channel:
                    img_c = img_t[c, ...]
                elif smap is None:
                    img_c = img_t
                else:
                    img_c = img_t * smap[c, ...]
                img_tc = sp.to_device(img_c, device)

                for seg in segments:
                    coord_tt = sp.to_device(coord_t[seg, ...], device)
                    ksp_tt = sp.nufft(img_tc, coord_tt)
                    if dcf_t is not None:
                        ksp_tt = sp.to_device(dcf_t[seg, ...], device) * ksp_tt
                    ksp_t[c, seg, ...] = sp.to_device(ksp_tt)
        ksp.append(ksp_t)
    return np.asarray(ksp)
def _get_vars(self):
    """Allocate the initial image kernel and sensitivity-map kernel.

    ``self.img_ker`` is a Dirac array whose shape is the grid shape (the
    data shape for Cartesian input, ``sp.estimate_shape(self.coord)``
    otherwise) enlarged by ``mps_ker_width - 1`` along every axis.
    ``self.mps_ker`` is a zero array of shape
    ``[num_coils] + [mps_ker_width] * len(img_shape)`` on ``self.device``.
    """
    width = self.mps_ker_width

    if self.coord is not None:
        base_shape = sp.estimate_shape(self.coord)
    else:
        base_shape = self.y.shape[1:]
    img_ker_shape = [dim + width - 1 for dim in base_shape]

    mps_ker_shape = [self.num_coils] + [width] * len(self.img_shape)

    self.img_ker = sp.dirac(img_ker_shape, dtype=self.dtype, device=self.device)
    with self.device:
        self.mps_ker = self.device.xp.zeros(mps_ker_shape, dtype=self.dtype)
def ConvSense(img_ker_shape, mps_ker, coord=None, weights=None, grd_shape=None, comm=None):
    """Convolution linear operator with sensitivity maps kernel in k-space.

    Args:
        img_ker_shape (tuple of ints): image kernel shape.
        mps_ker (array): sensitivity maps kernel; first axis is the coil axis.
        coord (array): coordinates. When given, the output is followed by an
            inverse FFT and a NUFFT onto these coordinates.
        weights (array): optional weights; output is scaled by ``weights**0.5``.
        grd_shape (None or list): Shape of grid. Defaults to
            ``sp.estimate_shape(coord)``.
        comm (Communicator): optional communicator; an ``AllReduceAdjoint``
            on ``img_ker_shape`` is applied to the input first.

    Returns:
        Linop: the composed linear operator.
    """
    ndim = len(img_ker_shape)
    num_coils = mps_ker.shape[0]

    # Give the image kernel a singleton channel axis and the maps kernel a
    # matching (num_coils, 1, ...) layout for the multi-channel convolution.
    reshape_op = sp.linop.Reshape((1, ) + tuple(img_ker_shape), img_ker_shape)
    filt = mps_ker.reshape((num_coils, 1) + mps_ker.shape[1:])
    conv_op = sp.linop.ConvolveData(reshape_op.oshape, filt, mode='valid',
                                    multi_channel=True)
    op = conv_op * reshape_op

    if coord is not None:
        grid = sp.estimate_shape(coord) if grd_shape is None else list(grd_shape)
        grid = [num_coils] + grid
        ifft_op = sp.linop.IFFT(grid, axes=range(-ndim, 0))
        nufft_op = sp.linop.NUFFT(grid, coord)
        op = nufft_op * ifft_op * op

    if weights is not None:
        with sp.get_device(weights):
            sqrt_w = sp.linop.Multiply(op.oshape, weights**0.5)
        op = sqrt_w * op

    if comm is not None:
        reduce_op = sp.linop.AllReduceAdjoint(img_ker_shape, comm, in_place=True)
        op = op * reduce_op

    return op
def gridding_recon(ksp, coord, dcf, T=1, device=sp.cpu_device):
    """ Gridding reconstruction.

    Each frame is the root-sum-of-squares over coils of the density
    compensated adjoint NUFFT of that frame's TRs.

    Args:
        ksp (array): k-space measurements of shape (C, num_tr, num_ro),
            where C is the number of channels, num_tr is the number of TRs,
            and num_ro is the number of readout points.
        coord (array): k-space coordinates of shape (num_tr, num_ro, D),
            where D is the number of spatial dimensions.
        dcf (array): density compensation factor of shape (num_tr, num_ro).
        T (int): number of frames. Each frame uses ``num_tr // T``
            consecutive TRs; any remaining ``num_tr % T`` TRs are not used.
        device (Device): computing device.

    Returns:
        img (array): image of shape (T, N_D, ..., N_1), where
            (N_D, ..., N_1) is ``sp.estimate_shape(coord)``.
    """
    device = sp.Device(device)
    xp = device.xp

    num_coils, num_tr, _ = ksp.shape
    tr_per_frame = num_tr // T
    img_shape = sp.estimate_shape(coord)

    with device:
        img = []
        for t in range(T):
            tr_start = t * tr_per_frame
            tr_end = (t + 1) * tr_per_frame
            coord_t = sp.to_device(coord[tr_start:tr_end], device)
            dcf_t = sp.to_device(dcf[tr_start:tr_end], device)

            # Accumulate |coil image|^2, then take the square root (RSS).
            img_t = 0
            for c in range(num_coils):
                logging.info('Reconstructing time %s, coil %s', t, c)
                ksp_tc = sp.to_device(ksp[c, tr_start:tr_end, :], device)

                img_t += xp.abs(
                    sp.nufft_adjoint(ksp_tc * dcf_t, coord_t, img_shape))**2

            img_t = img_t**0.5
            img.append(sp.to_device(img_t))

    img = np.stack(img)
    return img
def ConvImage(mps_ker_shape, img_ker, coord=None, weights=None, grd_shape=None):
    """Convolution linear operator with image kernel in k-space.

    Args:
        mps_ker_shape (tuple of ints): sensitivity maps kernel shape; first
            entry is the number of coils.
        img_ker (array): image kernel.
        coord (array): coordinates. When given, the output is followed by an
            inverse FFT and a NUFFT onto these coordinates.
        weights (array): optional weights; output is scaled by ``weights**0.5``.
        grd_shape (None or list): Shape of grid. Defaults to
            ``sp.estimate_shape(coord)``.

    Returns:
        Linop: the composed linear operator.
    """
    # Spatial rank is taken from the kernel before its channel axis is added.
    ndim = img_ker.ndim
    num_coils = mps_ker_shape[0]

    # View the maps kernel as (num_coils, 1, ...) and the image kernel as
    # (1, ...) so the multi-channel convolution pairs them.
    filt = img_ker.reshape((1, ) + img_ker.shape)
    reshape_op = sp.linop.Reshape((num_coils, 1) + tuple(mps_ker_shape[1:]),
                                  mps_ker_shape)
    conv_op = sp.linop.ConvolveFilter(reshape_op.oshape, filt, mode='valid',
                                      multi_channel=True)
    op = conv_op * reshape_op

    if coord is not None:
        grid = sp.estimate_shape(coord) if grd_shape is None else list(grd_shape)
        grid = [num_coils] + grid
        ifft_op = sp.linop.IFFT(grid, axes=range(-ndim, 0))
        nufft_op = sp.linop.NUFFT(grid, coord)
        op = nufft_op * ifft_op * op

    if weights is not None:
        with sp.get_device(weights):
            sqrt_w = sp.linop.Multiply(op.oshape, weights**0.5)
        op = sqrt_w * op

    return op
def nufft_adj1(data, traj, dcf, device=sp.Device(-1), smap=None, batch=40000,
               id_channel=False, ishape=None):
    """Adjoint NUFFT of a series of k-space datasets, processed in TR batches.

    Args:
        data (array): k-space indexed by phase; ``data[p]`` has shape
            (num_coils, num_tr, num_ro).
        traj (array): per-phase coordinates; ``traj[p]`` has shape
            (num_tr, num_ro, D).
        dcf (array): per-phase density compensation of shape (num_tr, num_ro).
        device (Device or int): computing device.
        smap (None or array): sensitivity maps indexed by coil along the first
            axis; its trailing D axes give the image shape. When given (and
            ``id_channel`` is false) coil images are combined as
            ``sum_c img_c * conj(smap_c)``.
        batch (int): number of TRs transformed per adjoint NUFFT call.
        id_channel (bool): if true, return per-coil images without combining.
        ishape (None or list): image shape used when ``smap`` is None.
            Defaults to ``sp.estimate_shape(traj[p])``.

    Returns:
        numpy.ndarray: images with a leading phase axis. With ``id_channel``
        each entry is complex64 of shape (num_coils, *img_shape). Without
        ``smap`` the coils are combined as the sum of squared magnitudes
        (NOTE(review): no square root is taken — confirm this is intended).
    """
    device = sp.Device(device)
    xp = device.xp
    img = []
    for num_ph in range(data.shape[0]):
        ksp_t = data[num_ph]
        coord_t = traj[num_ph]
        dcf_t = dcf[num_ph]
        ndim = coord_t.shape[-1]

        if smap is not None:
            # Image axes are the trailing ndim axes of smap (2D or 3D).
            img_shape = list(smap.shape[-ndim:])
        elif ishape is not None:
            img_shape = ishape
        else:
            img_shape = sp.estimate_shape(coord_t)

        num_coils, num_tr, _ = ksp_t.shape
        if id_channel:
            img_t = np.zeros((num_coils, ) + tuple(img_shape), dtype=np.complex64)
        else:
            img_t = 0

        # TR ranges handled per adjoint NUFFT call, bounding device memory use.
        segments = [slice(start, min(start + batch, num_tr))
                    for start in range(0, num_tr, batch)]

        with device:
            for c in range(num_coils):
                img_tt = 0
                for seg in segments:
                    ksp_ttc = sp.to_device(ksp_t[c, seg, ...], device)
                    coord_tt = sp.to_device(coord_t[seg, ...], device)
                    dcf_tt = sp.to_device(dcf_t[seg, ...], device)
                    img_tt += sp.nufft_adjoint(ksp_ttc * dcf_tt, coord_tt, img_shape)

                if id_channel:
                    img_t[c, ...] = sp.to_device(img_tt)
                elif smap is None:
                    img_t += xp.abs(img_tt)**2
                else:
                    smap_c = sp.to_device(smap[c, ...], device)
                    img_t += sp.to_device(img_tt * xp.conj(smap_c))
        img.append(sp.to_device(img_t))
    return np.asarray(img)
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(
        description='Estimate the FOV of a non-Cartesian acquisition and '
        'rescale its k-space coordinates, overwriting coord_file.')
    parser.add_argument('--num_ro', type=int, default=100,
                        help='number of leading readout points used for the '
                        'low-resolution estimate')
    parser.add_argument('--device', type=int, default=-1,
                        help='computing device id (-1 for CPU)')
    parser.add_argument('--thresh', type=float, default=0.1,
                        help='threshold between 0 and 1, relative to the '
                        'low-resolution image maximum')
    parser.add_argument('ksp_file', type=str,
                        help='k-space .npy file, shape (C, num_tr, num_ro)')
    parser.add_argument('coord_file', type=str,
                        help='coordinate .npy file, shape (num_tr, num_ro, D); '
                        'overwritten with the rescaled coordinates')
    parser.add_argument('dcf_file', type=str,
                        help='density compensation .npy file, shape '
                        '(num_tr, num_ro)')
    args = parser.parse_args()

    ksp = np.load(args.ksp_file)
    coord = np.load(args.coord_file)
    dcf = np.load(args.dcf_file)

    # autofov rescales coord in place.
    autofov(ksp, coord, dcf,
            num_ro=args.num_ro, device=args.device, thresh=args.thresh)

    logging.info('Image shape: %s', sp.estimate_shape(coord))
    logging.info('Saving data.')
    # Overwrites the input file; np.save appends '.npy' if the name lacks it.
    np.save(args.coord_file, coord)