def test_sigpy_cupy(self):
    """Check interop between ``Device(0)`` and ``sigpy.Device(0)``."""
    import sigpy as sp

    # Direct equality against the sigpy device object.
    assert Device(0) == sp.Device(0)

    # Conversion through the ``spdevice`` property.
    converted = Device(0).spdevice
    assert converted == sp.Device(0)
def __init__(self, ksp, calib_width=24,
             thresh=0.01, kernel_width=6, crop=0.8,
             max_iter=100, device=sp.cpu_device,
             output_eigenvalue=False, show_pbar=True):
    """Set up ESPIRiT sensitivity-map calibration as a power-method app.

    Args:
        ksp (array): k-space of shape [num_coils] + image shape.
        calib_width (int): width of the calibration region along each image
            axis (extracted with ``sp.resize``).
        thresh (float): singular values of the calibration matrix at or
            below ``thresh * S.max()`` are discarded.
        kernel_width (int): kernel width along each image axis.
        crop (float): stored on the instance; not used in this method.
        max_iter (int): number of power-method iterations.
        device (Device): computing device.
        output_eigenvalue (bool): stored on the instance; not used in this
            method.
        show_pbar (bool): show progress bar.
    """
    self.device = sp.Device(device)
    self.output_eigenvalue = output_eigenvalue
    self.crop = crop

    img_ndim = ksp.ndim - 1
    num_coils = len(ksp)
    with sp.get_device(ksp):
        # Get calibration region
        calib_shape = [num_coils] + [calib_width] * img_ndim
        calib = sp.resize(ksp, calib_shape)
        calib = sp.to_device(calib, device)

    xp = self.device.xp
    with self.device:
        # Get calibration matrix: one row per kernel-sized block.
        kernel_shape = [num_coils] + [kernel_width] * img_ndim
        kernel_strides = [1] * (img_ndim + 1)
        mat = sp.array_to_blocks(calib, kernel_shape, kernel_strides)
        mat = mat.reshape([-1, sp.prod(kernel_shape)])

        # Perform SVD on calibration matrix and keep the dominant subspace.
        _, S, VH = xp.linalg.svd(mat, full_matrices=False)
        VH = VH[S > thresh * S.max(), :]

        # Get kernels
        num_kernels = len(VH)
        kernels = VH.reshape([num_kernels] + kernel_shape)
        img_shape = ksp.shape[1:]

        # Get covariance matrix in image domain (one num_coils x num_coils
        # matrix per pixel, pixel axes reversed by the transpose below).
        AHA = xp.zeros(img_shape[::-1] + (num_coils, num_coils),
                       dtype=ksp.dtype)
        for kernel in kernels:
            img_kernel = sp.ifft(sp.resize(kernel, ksp.shape),
                                 axes=range(-img_ndim, 0))
            aH = xp.expand_dims(img_kernel.T, axis=-1)
            a = xp.conj(aH.swapaxes(-1, -2))
            AHA += aH @ a

        AHA *= (sp.prod(img_shape) / kernel_width**img_ndim)
        self.mps = xp.ones(ksp.shape[::-1] + (1, ), dtype=ksp.dtype)

        # PowerMethod evaluates these later, outside this ``with`` block, so
        # each call re-enters the device that owns ``x``.
        def forward(x):
            with sp.get_device(x):
                return AHA @ x

        def norm_func(x):
            with sp.get_device(x):
                return xp.sum(xp.abs(x)**2, axis=-2, keepdims=True)**0.5

        alg = sp.alg.PowerMethod(forward, self.mps, norm_func=norm_func,
                                 max_iter=max_iter)

    super().__init__(alg, show_pbar=show_pbar)
def pipe_menon_dcf(coord, device=sp.cpu_device, max_iter=30,
                   n=128, beta=8, width=4, show_pbar=True):
    r"""Compute Pipe Menon density compensation factor.

    Perform the following iteration:

    .. math::

        w = \frac{w}{|G^H G w|}

    with :math:`G` as the gridding operator.

    Args:
        coord (array): k-space coordinates. Moved to ``device`` before use.
        device (Device): computing device.
        max_iter (int): number of iterations.
        n (int): Kaiser-Bessel sampling numbers for gridding operator.
        beta (float): Kaiser-Bessel kernel parameter.
        width (float): Kaiser-Bessel kernel width.
        show_pbar (bool): show progress bar.

    Returns:
        array: density compensation factor of shape ``coord.shape[:-1]``,
        located on ``device``.

    References:
        Pipe, James G., and Padmanabhan Menon.
        Sampling Density Compensation in MRI:
        Rationale and an Iterative Numerical Solution.
        Magnetic Resonance in Medicine 41, no. 1 (1999): 179–86.
    """
    device = sp.Device(device)
    xp = device.xp

    with device:
        # The gridding operator and the iterate ``w`` must live on the same
        # device; ``w`` and the kernel below are allocated on ``device``.
        coord = sp.to_device(coord, device)
        w = xp.ones(coord.shape[:-1], dtype=coord.dtype)
        img_shape = sp.estimate_shape(coord)

        # Get kernel: i0(beta * sqrt(1 - x^2)) sampled at n points in [0, 1).
        x = xp.arange(n, dtype=coord.dtype) / n
        kernel = xp.i0(beta * (1 - x**2)**0.5).astype(coord.dtype)
        kernel /= kernel.max()

        G = sp.linop.Gridding(img_shape, coord, width, kernel)
        with tqdm(total=max_iter, disable=not show_pbar) as pbar:
            for _ in range(max_iter):
                GHGw = G.H * G * w
                w /= xp.abs(GHGw)
                resid = xp.abs(GHGw - 1).max().item()

                pbar.set_postfix(resid='{0:.2E}'.format(resid))
                pbar.update()

    return w
def __init__(self, input, output, batch_size, mu,
             lamda=0, max_iter=100, max_inner_iter=100,
             device=sp.cpu_device, checkpoint_path=None):
    """Set up proximal-point learning of a linear map from mini-batches.

    The learned matrix ``self.mat`` has shape
    ``input.shape[1:] + output.shape[1:]``. Each proximal step solves a
    ``sp.app.LinearLeastSquares`` problem on the current batch.

    Args:
        input (array): inputs, first axis indexes samples.
        output (array): outputs, first axis indexes samples.
        batch_size (int): samples per batch; clamped to ``len(output)``.
        mu (float): proximal-point step parameter.
        lamda (float): regularization, divided by the number of batches
            inside each inner problem.
        max_iter (int): outer (proximal-point) iterations.
        max_inner_iter (int): iterations of each inner least-squares solve.
        device (Device): computing device.
        checkpoint_path: stored on the instance; not used in this method.

    Raises:
        ValueError: if ``output`` has no samples.
    """
    dtype = output.dtype
    num_data = len(output)
    if num_data == 0:
        raise ValueError('output must contain at least one sample.')

    # With batch_size > num_data there would be zero batches, making
    # ``lamda / num_batches`` in proxf divide by zero; clamp instead.
    batch_size = min(num_data, batch_size)
    num_batches = num_data // batch_size

    self.lamda = lamda
    self.batch_size = batch_size
    self.input = input
    self.output = output
    self.checkpoint_path = checkpoint_path
    self.device = sp.Device(device)

    xp = self.device.xp
    with self.device:
        self.mat = xp.zeros(input.shape[1:] + output.shape[1:], dtype=dtype)
        # Per-batch staging buffers.
        self.input_t = xp.empty((batch_size, ) + input.shape[1:], dtype=dtype)
        self.output_t = xp.empty((batch_size, ) + output.shape[1:],
                                 dtype=dtype)

    self.t_idx = sp.ShuffledNumbers(num_batches)
    self._get_A()

    def proxf(mu, x):
        return sp.app.LinearLeastSquares(self.A, self.output_t, x=x,
                                         lamda=self.lamda / num_batches,
                                         mu=1 / mu, z=x,
                                         max_iter=max_inner_iter).run()

    alg = sp.alg.ProximalPointMethod(proxf, mu, self.mat, max_iter=max_iter)
    super().__init__(alg)
def __init__(self, y,
             mps_ker_width=16, ksp_calib_width=24,
             lamda=0, device=sp.cpu_device, comm=None,
             weights=None, coord=None, max_iter=10,
             max_inner_iter=10, normalize=True, show_pbar=True):
    """Store JSENSE settings, then build data, variables and the algorithm.

    When a communicator is supplied, only its rank 0 shows a progress bar.
    """
    # Measurements and sampling description.
    self.y = y
    self.mps_ker_width = mps_ker_width
    self.ksp_calib_width = ksp_calib_width
    self.lamda = lamda
    self.weights = weights
    self.coord = coord

    # Iteration control.
    self.max_iter = max_iter
    self.max_inner_iter = max_inner_iter
    self.normalize = normalize

    # Execution context.
    self.device = sp.Device(device)
    self.comm = comm

    # Quantities derived from the data.
    self.dtype = y.dtype
    self.num_coils = len(y)

    pbar_enabled = show_pbar if comm is None else (show_pbar and comm.rank == 0)

    self._get_data()
    self._get_vars()
    self._get_alg()
    super().__init__(self.alg, show_pbar=pbar_enabled)
def spdevice(self):
    """sigpy.Device: The equivalent ```sigpy.Device```."""
    if not env.sigpy_available():
        raise RuntimeError("`sigpy` not installed.")

    # Non-negative ids are accelerator devices; sigpy only handles CUDA ones.
    is_accelerator = self.id >= 0
    if is_accelerator and self.type != "cuda":
        raise RuntimeError(
            f"sigpy.Device does not support type {self.type}")

    return sp.Device(self.id)
def __init__(self, ksp, coord, dcf, mps, T, lamda,
             blk_widths=[32, 64, 128], alpha=1, beta=0.5, sgw=None,
             device=sp.cpu_device, comm=None, seed=0,
             max_epoch=60, decay_epoch=20, max_power_iter=5,
             show_pbar=True):
    """Set up the multi-scale low-rank reconstruction.

    Args:
        ksp (array): k-space of shape (C, num_tr, num_ro).
        coord (array): k-space coordinates.
        dcf (array): density compensation factor. Not modified: when
            ``sgw`` is given, a weighted copy is stored instead.
        mps (array): sensitivity maps; ``mps.shape[1:]`` is the image shape.
        T (int): number of frames; each has ``num_tr // T`` TRs.
        lamda (float): regularization parameter.
        blk_widths (list of int): block widths, one per scale; copied.
        alpha, beta (float): stored step-size parameters.
        sgw (array or None): weights multiplied into ``dcf`` along a new
            trailing axis.
        device (Device): computing device.
        comm (Communicator or None): only rank 0 shows a progress bar.
        seed (int): seed for both NumPy and ``device.xp`` RNGs.
        max_epoch, decay_epoch, max_power_iter (int): iteration settings.
        show_pbar (bool): show progress bar.
    """
    self.ksp = ksp
    self.coord = coord
    self.dcf = dcf
    self.mps = mps
    self.sgw = sgw
    # Copy so the shared default list (or the caller's list) is never
    # aliased by the instance.
    self.blk_widths = list(blk_widths)
    self.T = T
    self.lamda = lamda
    self.alpha = alpha
    self.beta = beta
    self.device = sp.Device(device)
    self.comm = comm
    self.seed = seed
    self.max_epoch = max_epoch
    self.decay_epoch = decay_epoch
    self.max_power_iter = max_power_iter
    self.show_pbar = show_pbar and (comm is None or comm.rank == 0)

    np.random.seed(self.seed)
    self.xp = self.device.xp
    with self.device:
        self.xp.random.seed(self.seed)

    self.dtype = self.ksp.dtype
    self.C, self.num_tr, self.num_ro = self.ksp.shape
    self.tr_per_frame = self.num_tr // self.T
    self.img_shape = self.mps.shape[1:]
    self.D = len(self.img_shape)
    self.J = len(self.blk_widths)
    if self.sgw is not None:
        # Copy first: an in-place product on ``self.dcf`` would overwrite
        # the caller's ``dcf`` array. The in-place multiply then keeps the
        # original dtype, as before.
        self.dcf = self.dcf.copy()
        self.dcf *= np.expand_dims(self.sgw, -1)

    self.B = [self._get_B(j) for j in range(self.J)]
    self.G = [self._get_G(j) for j in range(self.J)]

    self._normalize()
def jsens_calib(ksp, coord, dcf, ishape, device=sp.Device(-1),
                mps_ker_width=12, ksp_calib_width=32,
                max_iter=10, max_inner_iter=10):
    """Estimate coil sensitivity maps with JSENSE from non-Cartesian data.

    Per-coil images are formed with ``nft.nufft_adj`` (``id_channel=True``),
    transformed back to Cartesian k-space with ``sp.fft``, and passed to
    ``mr.app.JsenseRecon``.

    Args:
        ksp, coord, dcf: non-Cartesian k-space, coordinates and density
            compensation, forwarded to ``nft.nufft_adj`` as single-item lists.
        ishape: image shape forwarded to ``nft.nufft_adj``.
        device: computing device.
        mps_ker_width, ksp_calib_width, max_iter, max_inner_iter: forwarded
            to ``JsenseRecon``; defaults match the previously hard-coded
            values.

    Returns:
        Sensitivity maps returned by ``JsenseRecon.run()``.
    """
    img_s = nft.nufft_adj([ksp], [coord], [dcf], device=device,
                          ishape=ishape, id_channel=True)
    coil_imgs = np.asarray(img_s[0])
    # Axis 0 is assumed to index coils (id_channel=True); FFT every other
    # axis. For 3-D images this is axes (1, 2, 3), as before.
    ksp = sp.fft(input=coil_imgs, axes=tuple(range(1, coil_imgs.ndim)))
    mps = mr.app.JsenseRecon(ksp,
                             mps_ker_width=mps_ker_width,
                             ksp_calib_width=ksp_calib_width,
                             lamda=0,
                             device=device,
                             comm=sp.Communicator(),
                             max_iter=max_iter,
                             max_inner_iter=max_inner_iter).run()
    return mps
def check_linop_adjoint(A, dtype=float, device=sp.cpu_device):
    """Assert that ``A.H`` is the adjoint of the linear operator ``A``.

    Random ``x`` (shape ``A.ishape``) and ``y`` (shape ``A.oshape``) are
    drawn and the identity <A x, y> == <x, A^H y> is checked.

    Args:
        A (Linop): operator under test.
        dtype: dtype of the random test vectors. Defaults to the builtin
            ``float``; ``np.float`` was only an alias for it and was removed
            in NumPy 1.24, which made this default fail at import time.
        device (Device): device on which the vectors are created.

    Raises:
        AssertionError: if the two inner products differ beyond
            ``atol=1e-5, rtol=1e-5``.
    """
    device = sp.Device(device)
    x = sp.randn(A.ishape, dtype=dtype, device=device)
    y = sp.randn(A.oshape, dtype=dtype, device=device)

    xp = device.xp
    with device:
        # vdot conjugates its first argument, giving the complex inner product.
        lhs = xp.vdot(A * x, y)
        rhs = xp.vdot(x, A.H * y)

        xp.testing.assert_allclose(lhs, rhs, atol=1e-5, rtol=1e-5)
def nufft1(img, traj, device=sp.Device(-1), smap=None, dcf=None,
           batch=40000, id_channel=False):
    """Forward NUFFT of one image per phase, processed in TR batches.

    Args:
        img: indexed ``img[num_ph]`` per phase. When ``id_channel is True``
            each phase image is indexed ``[coil, ...]``; otherwise it is
            multiplied by each ``smap[c, ...]``.
        traj: ``traj[num_ph]`` has shape (num_tr, num_ro, ndim).
        device: computing device (anything ``sp.Device`` accepts).
        smap: optional coil maps indexed ``smap[c, ...]``; a single unit
            "coil" is used when it is None and ``id_channel`` is not True.
        dcf: optional per-phase weights ``dcf[num_ph]``, sliced along TRs
            and multiplied into the k-space samples.
        batch (int): number of TRs per ``sp.nufft`` call.
        id_channel (bool): treat the first image axis as coils.

    Returns:
        np.ndarray built from the per-phase complex64 arrays of shape
        (num_coils, num_tr, num_ro).
    """
    # Accept ints / cupy devices too; the original called ``device.xp`` and
    # ``with device`` on the raw argument.
    device = sp.Device(device)
    ksp = []
    for num_ph in range(img.shape[0]):
        img_t = img[num_ph]
        coord_t = traj[num_ph]
        dcf_t = None if dcf is None else dcf[num_ph]
        num_tr, num_ro, _ = coord_t.shape

        if id_channel is True:
            num_coils = img_t.shape[0]
        elif smap is None:
            smap = np.array([1])
            num_coils = 1
        else:
            num_coils = smap.shape[0]

        with device:
            ksp_t = np.zeros((num_coils, num_tr, num_ro), dtype=np.complex64)
            for c in range(num_coils):
                if id_channel is True:
                    img_tc = sp.to_device(img_t[c, ...], device)
                else:
                    img_tc = sp.to_device(img_t * smap[c, ...], device)

                for start in range(0, num_tr, batch):
                    seg = slice(start, min(start + batch, num_tr))
                    coord_tt = sp.to_device(coord_t[seg, ...], device)
                    ksp_tt = sp.nufft(img_tc, coord_tt)
                    if dcf_t is not None:
                        dcf_tt = sp.to_device(dcf_t[seg, ...], device)
                        ksp_tt = dcf_tt * ksp_tt
                    # Results are gathered on the host.
                    ksp_t[c, seg, ...] = sp.to_device(ksp_tt)
        ksp.append(ksp_t)
    return np.asarray(ksp)
def process_slice(kspace, args, calib_method='jsense'):
    """Estimate coil sensitivities for one slice and coil-combine it.

    Args:
        kspace: array of shape (nky, nkz, nechoes, ncoils); only echo 0 is
            used for calibration.
        args: namespace providing ``num_emaps``, ``ncalib``, ``crop_value``
            and ``device`` (-1 selects the CPU).
        calib_method (str): 'espirit' or 'jsense'.

    Returns:
        tuple: ``(image, maps)``, i.e. the coil-combined image converted with
        ``cplx.to_numpy`` and the maps reshaped to
        (nky, nkz, 1, ncoils, nmaps).

    Raises:
        ValueError: if ``calib_method`` is not implemented.
    """
    # get data dimensions
    nky, nkz, nechoes, ncoils = kspace.shape

    # ESPIRiT parameters
    nmaps = args.num_emaps
    calib_size = args.ncalib
    crop_value = args.crop_value

    # Use ``==``, not ``is``: object identity of ints and strings is an
    # interpreter detail, so ``is`` can reject a valid value (e.g. a string
    # built at runtime from the command line).
    if args.device == -1:
        device = sp.cpu_device
    else:
        device = sp.Device(args.device)

    # compute sensitivity maps (SigPy); SigPy expects (coils, nkz, nky)
    ksp = np.transpose(kspace[:, :, 0, :], [2, 1, 0])
    if calib_method == 'espirit':
        maps = app.EspiritCalib(ksp, calib_width=calib_size, crop=crop_value,
                                device=device, show_pbar=False).run()
    elif calib_method == 'jsense':
        maps = app.JsenseRecon(ksp, mps_ker_width=6,
                               ksp_calib_width=calib_size,
                               device=device, show_pbar=False).run()
    else:
        raise ValueError('%s calibration method not implemented...'
                         % calib_method)
    maps = np.reshape(np.transpose(maps, [2, 1, 0]),
                      (nky, nkz, 1, ncoils, nmaps))

    # Convert everything to tensors
    kspace_tensor = cplx.to_tensor(kspace).unsqueeze(0)
    maps_tensor = cplx.to_tensor(maps).unsqueeze(0)

    # Do coil combination using sensitivity maps (PyTorch)
    A = T.SenseModel(maps_tensor)
    im_tensor = A(kspace_tensor, adjoint=True)

    # Convert tensor back to numpy array
    image = cplx.to_numpy(im_tensor.squeeze(0))

    return image, maps
def autofov(ksp, coord, dcf, num_ro=100, device=sp.cpu_device, thresh=0.1):
    """Automatic estimation of FOV.

    FOV is estimated by thresholding a low resolution gridded image.
    coord will be modified in-place.

    Args:
        ksp (array): k-space measurements of shape (C, num_tr, num_ro),
            where C is the number of channels, num_tr is the number of TRs
            and num_ro is the number of readout points.
        coord (array): k-space coordinates of shape (num_tr, num_ro, D),
            where D is the number of spatial dimensions. Rescaled in-place.
        dcf (array): density compensation factor of shape (num_tr, num_ro).
        num_ro (int): number of leading read-out points used for the
            low-resolution estimate.
        device (Device): computing device.
        thresh (float): threshold between 0 and 1, relative to the maximum
            of the low-resolution image (of its central coronal slice in 3-D).

    Raises:
        ValueError: if no pixel exceeds the threshold (e.g. all-zero data).
    """
    device = sp.Device(device)
    xp = device.xp
    with device:
        kspc = ksp[:, :, :num_ro]
        coordc = coord[:, :num_ro, :]
        dcfc = dcf[:, :num_ro]
        # Doubling the coordinates doubles the gridded FOV, leaving margin
        # around the object.
        coordc2 = sp.to_device(coordc * 2, device)

        num_coils = len(kspc)
        imgc_shape = np.array(sp.estimate_shape(coordc))
        imgc2_shape = sp.estimate_shape(coordc2)
        imgc2_center = [i // 2 for i in imgc2_shape]
        imgc2 = sp.nufft_adjoint(sp.to_device(dcfc * kspc, device), coordc2,
                                 [num_coils] + imgc2_shape)
        # Root-sum-of-squares coil combination.
        imgc2 = xp.sum(xp.abs(imgc2)**2, axis=0)**0.5
        if imgc2.ndim == 3:
            imgc2_cor = imgc2[:, imgc2.shape[1] // 2, :]
            thresh *= imgc2_cor.max()
        else:
            thresh *= imgc2.max()

        boxc = sp.to_device(imgc2 > thresh)
        boxc_idx = np.nonzero(boxc)
        if boxc_idx[0].size == 0:
            raise ValueError('autofov: no pixel exceeds the threshold; '
                             'cannot estimate FOV.')

        # Twice the largest distance from the center along each axis.
        boxc_shape = np.array([
            int(np.abs(boxc_idx[i] - imgc2_center[i]).max()) * 2
            for i in range(imgc2.ndim)
        ])

        img_scale = boxc_shape / imgc_shape
        coord *= img_scale
def _get_params(self):
    """Derive device, dtype, sizes and the coefficient-map shape ``R_shape``."""
    self.device = sp.Device(self.device)
    self.dtype = self.y.dtype
    self.num_data = len(self.y)
    self.filt_width = self.L.shape[-1]
    self.num_filters = self.L.shape[self.multi_channel]
    self.data_ndim = self.y.ndim - self.multi_channel - 1

    # In 'full' mode each spatial axis of R is filt_width - 1 smaller than
    # the data; in any other mode it is filt_width - 1 larger.
    if self.mode == 'full':
        delta = 1 - self.filt_width
    else:
        delta = self.filt_width - 1
    spatial = [n + delta for n in self.y.shape[-self.data_ndim:]]
    self.R_shape = [self.num_data, self.num_filters] + spatial
def interp(I, M_field, device=sp.Device(-1), k_id=1, deblur=True):
    """Warp image ``I`` by the displacement field ``M_field``.

    Sampling positions are the pixel grid plus ``M_field``; the image is
    resampled there with ``sp.interp.interpolate`` on ``device``.

    Args:
        I: 2-D or 3-D image; its rank must equal ``M_field.shape[-1]``.
        M_field: displacement of shape ``I.shape + (ndim,)``, presumably in
            pixel units (it is added to integer grid indices).
        device: device on which the interpolation runs.
        k_id (int): 0 selects a cubic B-spline kernel (width 4) with
            optional deblurring; any other value selects a linearly decaying
            kernel (width 2) and disables deblurring.
        deblur (bool): with ``k_id == 0``, convolve the result with a
            separable 3-tap sharpening filter.

    Returns:
        The warped image, on the same device as the input ``I``.
    """
    # b spline interpolation
    N = 64  # kernel samples per unit distance
    if k_id == 0:
        kernel = ([(3 * (x / N)**3 - 6 * (x / N)**2 + 4) / 6
                   for x in range(0, N)]
                  + [(2 - x / N)**3 / 6 for x in range(N, 2 * N)])
        dkernel = np.array([-.2, 1.4, -.2])
        k_wid = 4
    else:
        kernel = [1 - x / (2 * N) for x in range(0, 2 * N)]
        dkernel = np.array([0, 1, 0])
        deblur = False
        k_wid = 2
    kernel = np.asarray(kernel)

    c_device = sp.get_device(I)
    ndim = M_field.shape[-1]
    # 2d/3d; ``==`` instead of ``is`` (identity of ints is not guaranteed).
    if ndim == 3:
        dkernel = (dkernel[:, None, None] * dkernel[None, :, None]
                   * dkernel[None, None, :])
        Nx, Ny, Nz = I.shape
        my, mx, mz = np.meshgrid(np.arange(Ny), np.arange(Nx), np.arange(Nz))
        m = np.stack((mx, my, mz), axis=-1)
    else:
        dkernel = dkernel[:, None] * dkernel[None, :]
        Nx, Ny = I.shape
        my, mx = np.meshgrid(np.arange(Ny), np.arange(Nx))
        # Only two grid components exist in 2-D (``mz`` was undefined here).
        m = np.stack((mx, my), axis=-1)
    M_field = M_field + m
    # TODO remove out of range values

    # image warp
    g_device = device
    I = sp.to_device(input=I, device=g_device)
    I = sp.interp.interpolate(I, k_wid, kernel, M_field.astype(np.float64))

    # deconv: the result was previously discarded, so deblurring never took
    # effect. ``sp.conv.convolve`` defaults to a 'full' convolution (each
    # axis grows by 2), so center-crop back to the input shape.
    if deblur is True:
        dk = sp.to_device(dkernel.astype(I.dtype), g_device)
        I = sp.resize(sp.conv.convolve(I, dk), I.shape)

    I = sp.to_device(input=I, device=c_device)
    return I
def __init__(self, ksp, coord, dcf, mps, resp, B,
             lamda=1e-6, alpha=1, beta=0.5,
             max_power_iter=10, max_iter=300,
             device=sp.cpu_device, margin=10,
             coil_batch_size=None, comm=None, show_pbar=True,
             **kwargs):
    """Bin the data by the signal ``resp`` and prepare a motion-resolved recon.

    ``B`` bins are delimited by percentiles of ``resp`` evenly spaced between
    ``margin`` and ``100 - margin``; sample ``i`` goes to bin ``b`` when
    ``bins[b] <= resp[i] < bins[b + 1]``. Samples outside that range are
    dropped. Each bin's k-space, coordinates and DCF are moved to ``device``.

    Args:
        ksp, coord, dcf: data indexed along the axis selected by the bin
            mask (``ksp[:, idx]``, ``coord[idx]``, ``dcf[idx]``).
        mps (array): sensitivity maps; ``len(mps)`` coils, image shape
            ``mps.shape[1:]``.
        resp (array): binning signal, one value per indexed sample.
        B (int): number of bins.
        lamda, alpha, beta (float): stored algorithm parameters.
        max_power_iter, max_iter (int): iteration settings.
        device (Device): computing device.
        margin (float): percentile margin excluded at both ends.
        coil_batch_size: accepted for API compatibility; unused here.
        comm (Communicator or None): only rank 0 shows a progress bar.
        show_pbar (bool): show progress bar.
        **kwargs: ignored.
    """
    self.B = B
    self.C = len(mps)
    self.mps = mps
    self.device = sp.Device(device)
    # Use the normalized device; the raw argument may be an int without
    # an ``xp`` attribute.
    self.xp = self.device.xp
    self.alpha = alpha
    self.beta = beta
    self.lamda = lamda
    self.max_iter = max_iter
    self.max_power_iter = max_power_iter
    self.comm = comm
    # Previously ``show_pbar`` was only stored when a communicator was given.
    if comm is not None:
        self.show_pbar = show_pbar and comm.rank == 0
    else:
        self.show_pbar = show_pbar

    self.img_shape = list(mps.shape[1:])

    bins = np.percentile(resp, np.linspace(0 + margin, 100 - margin, B + 1))
    self.bksp = []
    self.bcoord = []
    self.bdcf = []
    for b in range(B):
        idx = (resp >= bins[b]) & (resp < bins[b + 1])
        self.bksp.append(sp.to_device(ksp[:, idx], self.device))
        self.bcoord.append(sp.to_device(coord[idx], self.device))
        self.bdcf.append(sp.to_device(dcf[idx], self.device))

    self._normalize()
def gridding_recon(ksp, coord, dcf, T=1, device=sp.cpu_device):
    """ Gridding reconstruction.

    Each frame is the root-sum-of-squares over coils of the DCF-weighted
    adjoint NUFFT of that frame's TRs.

    Args:
        ksp (array): k-space measurements of shape (C, num_tr, num_ro),
            where C is the number of channels, num_tr is the number of TRs
            and num_ro is the number of readout points.
        coord (array): k-space coordinates of shape (num_tr, num_ro, D),
            where D is the number of spatial dimensions.
        dcf (array): density compensation factor of shape (num_tr, num_ro).
        T (int): number of frames. Each frame uses num_tr // T consecutive
            TRs; the last num_tr % T TRs are not used.
        device (Device): computing device.

    Returns:
        img (array): magnitude images of shape (T, N_D, ..., N_1) on the
            CPU, where (N_D, ..., N_1) = sp.estimate_shape(coord).
    """
    device = sp.Device(device)
    xp = device.xp

    num_coils, num_tr, num_ro = ksp.shape
    tr_per_frame = num_tr // T
    img_shape = sp.estimate_shape(coord)

    with device:
        img = []
        for t in range(T):
            tr_start = t * tr_per_frame
            tr_end = (t + 1) * tr_per_frame
            coord_t = sp.to_device(coord[tr_start:tr_end], device)
            dcf_t = sp.to_device(dcf[tr_start:tr_end], device)

            img_t = 0
            for c in range(num_coils):
                # Lazy %-formatting: the message is only built if emitted.
                logging.info('Reconstructing time %s, coil %s', t, c)
                ksp_tc = sp.to_device(ksp[c, tr_start:tr_end, :], device)

                img_t += xp.abs(
                    sp.nufft_adjoint(ksp_tc * dcf_t, coord_t, img_shape))**2

            img_t = img_t**0.5
            img.append(sp.to_device(img_t))

    img = np.stack(img)
    return img
def _get_params(self):
    """Normalize settings and derive batch counts and the L / R_t shapes."""
    self.device = sp.Device(self.device)
    self.dtype = self.y.dtype
    self.data_ndim = self.y.ndim - self.multi_channel - 1

    if self.checkpoint_path is not None:
        self.checkpoint_path = pathlib.Path(self.checkpoint_path)
        self.checkpoint_path.mkdir(parents=True, exist_ok=True)

    # A batch can never be larger than the dataset.
    num_data = len(self.y)
    self.batch_size = min(num_data, self.batch_size)
    self.num_batches = num_data // self.batch_size

    filter_shape = [self.num_filters] + [self.filt_width] * self.data_ndim
    if self.multi_channel:
        filter_shape = [self.y.shape[1]] + filter_shape
    self.L_shape = filter_shape

    # 'full' mode: R_t is filt_width - 1 smaller per axis; else larger.
    offset = 1 - self.filt_width if self.mode == 'full' else self.filt_width - 1
    self.R_t_shape = [self.batch_size, self.num_filters] + [
        n + offset for n in self.y.shape[-self.data_ndim:]
    ]
def init_bloch_vector(shape=None, dtype=complex, device=sp.cpu_device):
    """Initialize magnetization in Bloch vector representation.

    Every batch element is set to the vector (0, 0, 1).

    Args:
        shape (tuple): batch shape. None means no batch axes.
        dtype (Dtype): data type. Defaults to the builtin ``complex``;
            ``np.complex`` was only an alias for it and was removed in
            NumPy 1.24, which made this default fail at import time.
        device (Device): device.

    Returns:
        array: magnetization array of shape ``list(shape) + [3]``.

    """
    device = sp.Device(device)
    xp = device.xp
    with device:
        batch_shape = [] if shape is None else list(shape)
        m = xp.zeros(batch_shape + [3], dtype=dtype)
        m[..., 2] = 1

    return m
def init_density_matrix(shape=None, dtype=complex, device=sp.cpu_device):
    """Initialize magnetization in density matrix representation.

    Every batch element is set to the 2x2 matrix [[1, 0], [0, 0]].

    Args:
        shape (tuple): batch shape. None means no batch axes.
        dtype (Dtype): data type. Defaults to the builtin ``complex``;
            ``np.complex`` was only an alias for it and was removed in
            NumPy 1.24, which made this default fail at import time.
        device (Device): device.

    Returns:
        array: magnetization array of shape ``list(shape) + [2, 2]``.

    """
    device = sp.Device(device)
    xp = device.xp
    with device:
        batch_shape = [] if shape is None else list(shape)
        p = xp.zeros(batch_shape + [2, 2], dtype=dtype)
        p[..., 0, 0] = 1

    return p
def NFTs(ishape, coord, device=sp.Device(-1)):
    """Build a block-diagonal multi-channel NUFFT operator.

    ``ishape[0]`` is the channel count. A single ``sp.linop.NUFFT`` over
    ``ishape[1:]`` is shared by all channels, each copy wrapped by ``DLD``
    for ``device``, and the blocks are combined with ``Diags``.
    """
    num_channels = ishape[0]
    oshape = [num_channels, *coord.shape[:-1]]
    nufft_op = sp.linop.NUFFT(ishape[1:], coord=coord)
    per_channel = [DLD(nufft_op, device=device) for _ in range(num_channels)]
    return Diags(per_channel, oshape, ishape)
def DLD(Linop, device=sp.Device(-1)):
    """Wrap ``Linop`` so it takes and returns CPU arrays but runs on ``device``.

    The input is moved CPU -> ``device`` before ``Linop``. A second
    ``ToDevice`` (CPU -> ``device`` on ``Linop.oshape``) is applied through
    its adjoint, moving the result back off ``device``.
    """
    to_device_in = sp.linop.ToDevice(Linop.ishape,
                                     idevice=sp.Device(-1), odevice=device)
    to_device_out = sp.linop.ToDevice(Linop.oshape,
                                      idevice=sp.Device(-1), odevice=device)
    return to_device_out.H * Linop * to_device_in
def __init__(self, rdr, tbl, mps, psf, phi, spr='W', cft=True, \
             lmb=1e-5, mit=30, alp=0.25, tol=1e-3, dev=-1):
    """Build the forward model and an accelerated L1-regularized solver.

    The normal operator ``AHA`` is assembled from sparsifying, coil,
    resize, FFT, PSF and kernel linops on 8-D arrays, and
    ``sp.alg.GradientMethod`` (accelerated, with ``sp.prox.L1Reg``) is
    run on it.

    Args (meanings inferred from usage here; NOTE(review): confirm):
        rdr: cast to int32; only ``rdr.shape[0]`` is used in this method.
        tbl: normalized by its maximum absolute value.
        mps, psf, phi: passed through ``self._broadcast_check`` (defined
            elsewhere) and read along 8 axes, e.g. mps axes 3..7 give
            (md, nc, sz, sy, sx), psf axis 7 gives wx, phi axes 1, 2 give
            (tk, tf).
        spr: 'W' -> Wavelet, 'T' -> FiniteDifference, else Identity.
        cft: forwarded as ``center`` to the FFT linops.
        lmb: L1 weight. mit: max iterations. tol: tolerance.
        alp: step size, divided below by the max eigenvalue of AHA.
        dev: device id for ``sp.Device`` and ``sp.app.MaxEig``.
    """
    self.cpu = -1
    self.max_dims = 8
    self.device = sp.Device(dev)
    self.xp = self.device.xp
    self.center = cft
    with self.device:
        # Copy inputs onto the selected device.
        self.rdr = self.xp.array(rdr).astype(self.xp.int32)
        self.tbl = self.xp.array(tbl)
        self.tbl = self.tbl / self.xp.max(self.xp.abs(self.tbl))
        self.mps = self.xp.array(self._broadcast_check(mps))
        self.psf = self.xp.array(self._broadcast_check(psf))
        self.phi = self.xp.array(self._broadcast_check(phi))
        self.lmb = lmb # Lambda.
        self.mit = mit # Max-Iter
        self.alp = alp # Step size.
        self.tol = tol # Tolerance.
        # Problem sizes read from fixed axes of the 8-D arrays.
        self.wx = self.psf.shape[7]
        self.sx = self.mps.shape[7]
        self.sy = self.mps.shape[6]
        self.sz = self.mps.shape[5]
        self.nc = self.mps.shape[4]
        self.md = self.mps.shape[3]
        self.tf = self.phi.shape[2]
        self.tk = self.phi.shape[1]
        self.net_acceleration = (self.tf * self.sy * self.sz) / self.rdr.shape[0]
        # NOTE(review): assert is stripped under ``python -O``.
        assert (self.md == 1 ) # Until multiple ESPIRiT maps is implemented.
        # Sparsifying transform over the spatial axes.
        self.S = None
        if (spr == 'W'):
            # Wavelet only along spatial axes with more than one sample.
            wavelet_axes = tuple( [k for k in range(5, 8) if self.mps.shape[k] > 1])
            self.S = sp.linop.Wavelet( [1, self.tk, 1, self.md, 1, self.sz, self.sy, self.sx], axes=wavelet_axes)
        elif (spr == 'T'):
            self.S = sp.linop.FiniteDifference( [1, self.tk, 1, self.md, 1, self.sz, self.sy, self.sx], axes=(5, 6, 7))
        else:
            self.S = sp.linop.Identity( [1, self.tk, 1, self.md, 1, self.sz, self.sy, self.sx])
        # Coil expansion with the sensitivity maps.
        self.E = sp.linop.Multiply( [1, self.tk, 1, self.md, 1, self.sz, self.sy, self.sx], self.mps)
        # Resize the last axis between sx and wx.
        self.R = sp.linop.Resize([1, self.tk, 1, 1, self.nc, self.sz, self.sy, self.wx], \
                                 [1, self.tk, 1, 1, self.nc, self.sz, self.sy, self.sx])
        self.Fx = sp.linop.FFT([1, self.tk, 1, 1, self.nc, self.sz, self.sy, self.wx], axes=(7,), \
                               center=self.center)
        self.Psf = sp.linop.Multiply( [1, self.tk, 1, 1, self.nc, self.sz, self.sy, self.wx], self.psf)
        self.Fyz = sp.linop.FFT([1, self.tk, 1, 1, self.nc, self.sz, self.sy, self.wx], axes=(5, 6), \
                                center=self.center)
        # K: multiply by self.kernel (built by _construct_kernel, defined
        # elsewhere) into a [tk, tk, ...] array, sum over axis 1, and
        # reshape back to the [1, tk, ...] layout.
        self.K = None
        self._construct_kernel()
        self.K = sp.linop.Reshape( [ 1, self.tk, 1, 1, self.nc, self.sz, self.sy, self.wx], \
                                   [ self.tk, 1, 1, self.nc, self.sz, self.sy, self.wx]) * \
                 sp.linop.Sum( [self.tk, self.tk, 1, 1, self.nc, self.sz, self.sy, self.wx], axes=(1,)) * \
                 sp.linop.Multiply([ 1, self.tk, 1, 1, self.nc, self.sz, self.sy, self.wx], self.kernel)
        self._construct_AHb()
        # Zero initial estimate with the shape of AHb.
        self.res = 0 * self.AHb
        self.AHA = self.S * self.E.H * self.R.H * self.Fx.H * self.Psf.H * self.Fyz.H * self.K * \
                   self.Fyz * self.Psf * self.Fx * self.R * self.E * self.S.H
        # Scale the step size by the largest eigenvalue of AHA.
        self.mxevl = sp.app.MaxEig(self.AHA, self.xp.complex64, device=dev).run()
        self.alp = self.alp / self.mxevl
        self.gradf = lambda x: self.AHA(x) - self.AHb
        self.proxg = sp.prox.L1Reg(self.res.shape, lmb)
        alg = sp.alg.GradientMethod(self.gradf, self.res, self.alp, proxg=self.proxg, accelerate=True, tol=self.tol, max_iter=self.mit)
    super().__init__(alg)
def use_device(self, device):
    """Record ``device``, normalized through ``sp.Device``, as the active device."""
    normalized = sp.Device(device)
    self.device = normalized
def nufft_adj1(data, traj, dcf, device=sp.Device(-1), smap=None,
               batch=40000, id_channel=False, ishape=None):
    """Adjoint NUFFT per phase, processed in TR batches and combined over coils.

    Args:
        data: ``data[num_ph]`` has shape (num_coils, num_tr, num_ro).
        traj: ``traj[num_ph]`` holds coordinates sliced along TRs.
        dcf: ``dcf[num_ph]`` weights, sliced along TRs and multiplied into
            the k-space samples.
        device: computing device (anything ``sp.Device`` accepts).
        smap: optional coil maps indexed ``smap[c, ...]``. If given, the
            image shape is ``smap.shape[-3:]``.
        batch (int): number of TRs per ``sp.nufft_adjoint`` call.
        id_channel (bool): if True, keep one image per coil.
        ishape: image shape used when ``smap`` is None; otherwise it is
            estimated from the coordinates.

    Returns:
        np.ndarray stacking per-phase results:
          * ``id_channel is True``: complex64 array (num_coils, *img_shape);
          * ``smap`` given: sum over coils of ``img_c * conj(smap[c])``;
          * otherwise: sum over coils of ``|img_c| ** 2``.
            NOTE(review): no square root is taken here, unlike a
            root-sum-of-squares combine; confirm callers expect this.
    """
    # Accept ints / cupy devices too; ``device.xp`` and ``with device``
    # need a ``sp.Device``.
    device = sp.Device(device)
    xp = device.xp
    img = []
    for num_ph in range(data.shape[0]):
        ksp_t = data[num_ph]
        coord_t = traj[num_ph]
        dcf_t = dcf[num_ph]
        if smap is not None:
            img_shape = list(smap.shape[-3:])
        elif ishape is not None:
            img_shape = ishape
        else:
            img_shape = sp.estimate_shape(coord_t)
        num_coils, num_tr, _ = ksp_t.shape

        if id_channel is True:
            img_t = np.zeros((num_coils, ) + tuple(img_shape),
                             dtype=np.complex64)
        else:
            img_t = 0

        with device:
            for c in range(num_coils):
                img_tt = 0
                for start in range(0, num_tr, batch):
                    seg = slice(start, min(start + batch, num_tr))
                    ksp_ttc = sp.to_device(ksp_t[c, seg, ...], device)
                    coord_tt = sp.to_device(coord_t[seg, ...], device)
                    dcf_tt = sp.to_device(dcf_t[seg, ...], device)
                    img_tt += sp.nufft_adjoint(ksp_ttc * dcf_tt, coord_tt,
                                               img_shape)
                if id_channel is True:
                    img_t[c, ...] = sp.to_device(img_tt)
                elif smap is None:
                    img_t += xp.abs(img_tt)**2
                else:
                    img_t += sp.to_device(
                        img_tt * xp.conj(sp.to_device(smap[c, ...], device)))
        img_t = sp.to_device(img_t)
        img.append(img_t)
    return np.asarray(img)
def circulant_precond(mps, weights=None, coord=None, lamda=0,
                      device=sp.cpu_device):
    r"""Compute circulant preconditioner.

    Considers the optimization problem:

    .. math::
        \min_P \| A^H A - F P F^H  \|_2^2

    where A is the Sense operator,
    and F is a unitary Fourier transform operator.

    Args:
        mps (array): sensitivity maps of shape [num_coils] + image shape.
        weights (array): k-space weights.
        coord (array): k-space coordinates of shape [...] + [ndim].
        lamda (float): regularization.
        device (Device): computing device.

    Returns:
        array: circulant preconditioner of image shape, with ``mps.dtype``.
        Entries whose inverse would be exactly zero are set to 1.
    """
    if coord is not None:
        coord = sp.to_device(coord, device)
    if weights is not None:
        weights = sp.to_device(weights, device)

    dtype = mps.dtype
    device = sp.Device(device)
    xp = device.xp

    img_shape = list(mps.shape)[1:]
    img2_shape = [2 * s for s in img_shape]
    ndim = len(img_shape)
    scale = sp.prod(img2_shape)**1.5 / sp.prod(img_shape)**2

    with device:
        # Every second sample of the 2x oversampled grid.
        even = (slice(None, None, 2), ) * ndim

        # Point-spread function on the oversampled grid.
        if coord is None:
            sampling = xp.zeros(img2_shape, dtype=dtype)
            sampling[even] = 1 if weights is None else weights**0.5
            psf = sp.ifft(sampling)
        else:
            sampling = xp.ones(coord.shape[:-1], dtype=dtype)
            if weights is not None:
                sampling *= weights**0.5
            psf = sp.nufft_adjoint(sampling, coord * 2, img2_shape)

        p_inv = 0
        for coil_map in mps:
            coil_map = sp.to_device(coil_map, device)
            xcorr = sp.ifft(xp.abs(sp.fft(xp.conj(coil_map), img2_shape))**2)
            xcorr *= psf
            p_inv_c = sp.fft(xcorr)[even]
            p_inv_c *= scale
            if weights is not None:
                p_inv_c *= weights**0.5
            p_inv += p_inv_c

        p_inv += lamda
        p_inv[p_inv == 0] = 1
        p = 1 / p_inv

        return p.astype(dtype)
np.int(np.max(traj[..., 1]) - np.min(traj[..., 1])),
np.int(np.max(traj[..., 2]) - np.min(traj[..., 2])))
# NOTE(review): the statement above starts before this fragment; it closes
# a tuple of per-axis extents of ``traj``. ``np.int`` was removed in
# NumPy 1.24 (the builtin ``int`` is the drop-in replacement).

## calibration
print('Calibration...')
# Reorder axes with (2, 1, 0, 3, 4, 5) and flatten so the coil axis comes
# first; assumes ``data`` is 6-D with coils on axis 2 (TODO confirm).
ksp = np.reshape(np.transpose(data, (2, 1, 0, 3, 4, 5)), (nCoil, nphase * npe, nfe))
# Squared DCF, flattened with the same axis order.
dcf2 = np.reshape(np.transpose(dcf**2, (2, 1, 0, 3, 4, 5)), (nphase * npe, nfe))
coord = np.reshape(np.transpose(traj, (2, 1, 0, 3, 4, 5)), (nphase * npe, nfe, 3))
# Coil sensitivity maps on image shape ``tshape``.
mps = ext.jsens_calib(ksp, coord, dcf2, device=sp.Device(device), ishape=tshape)
S = sp.linop.Multiply(tshape, mps)
imgL = cfl.read_cfl(fname + '_mrL')
imgL = np.squeeze(imgL)
## registration
print('Registration...')
M_fields = []
iM_fields = []
# NOTE(review): ``is 1`` tests identity, not equality; ``== 1`` is intended.
if reg_flag is 1:
    for i in range(nphase):
        # Register the magnitude of phase ``i`` with reference phase ``n_ref``.
        M_field, iM_field = reg.ANTsReg(np.abs(imgL[n_ref]), np.abs(imgL[i]))
        M_fields.append(M_field)
# Keep only the first ``nf_e`` readout samples.
traj = traj[..., :nf_e, :]
data = data[..., :nf_e, :]
dcf = dcf[..., :nf_e, :]
nphase, nEcalib, nCoil, npe, nfe, _ = data.shape
# Matrix size = extent of the trajectory along each of the 3 axes. The
# builtin ``int`` replaces ``np.int``, which NumPy 1.24 removed.
tshape = tuple(int(np.max(traj[..., k]) - np.min(traj[..., k]))
               for k in range(3))

## calibration
print('Calibration...')
# Coil axis first; phase and TR axes merged into one.
ksp = np.reshape(np.transpose(data, (2, 1, 0, 3, 4, 5)),
                 (nCoil, nphase * npe, nfe))
dcf2 = np.reshape(np.transpose(dcf**2, (2, 1, 0, 3, 4, 5)),
                  (nphase * npe, nfe))
coord = np.reshape(np.transpose(traj, (2, 1, 0, 3, 4, 5)),
                   (nphase * npe, nfe, 3))
mps = ext.jsens_calib(ksp, coord, dcf2, device=sp.Device(device),
                      ishape=tshape)
S = sp.linop.Multiply(tshape, mps)
imgL = cfl.read_cfl(fname + '_mrL')
imgL = np.squeeze(imgL)

## registration
print('Registration...')
M_fields = []
iM_fields = []
# ``==`` rather than ``is``: identity of ints is an interpreter detail
# (and ``is`` with a literal is a SyntaxWarning on Python 3.8+).
if reg_flag == 1:
    for i in range(nphase):
        M_field, iM_field = reg.ANTsReg(np.abs(imgL[n_ref]), np.abs(imgL[i]))
        M_fields.append(M_field)
        iM_fields.append(iM_field)
    M_fields = np.asarray(M_fields)
def use_device(self, device):
    """Switch to ``device`` and move every L and R factor onto it."""
    target = sp.Device(device)
    self.device = target
    self.L = [sp.to_device(factor, target) for factor in self.L]
    self.R = [sp.to_device(factor, target) for factor in self.R]
import sigpy as sp
import sigpy.mri as mr
import sigpy.plot as pl
import matplotlib.pyplot as plt
import numpy as np

# Set parameters and load dataset
max_iter = 30
max_cg_iter = 5
lamda = 0.001

ksp_file = 'data/liver/ksp.npy'
coord_file = 'data/liver/coord.npy'

device = sp.Device(-1)
xp = device.xp
device.use()

# Load datasets.
ksp = xp.load(ksp_file)
coord = xp.load(coord_file)

print(f'K-space shape: {ksp.shape}')
print(f'K-space dtype: {ksp.dtype}')
# ``xp.abs`` keeps this device-agnostic if ``device`` is changed to a GPU.
print(f'K-space (min, max): ({xp.abs(ksp).min()}, {xp.abs(ksp).max()})')
print(f'Coord shape: {coord.shape}')  # (na, ns, 2)
# This line prints the dtype; it was previously mislabeled "Coord shape".
print(f'Coord dtype: {coord.dtype}')
print(f'Coord (min, max): ({coord.min()}, {coord.max()})')

plt.ion()
def Demons(If, Im, level, device=-1, rho=0.7, sigmas_f=[2, 2, 2, 3], sigmas_e=[2, 2, 2, 2], sigmas_s=[.5, .5, 1, 1], iters=[40, 40, 40, 20, 20]):
    """Multi-resolution Demons registration between ``If`` and ``Im``.

    Both images are converted to magnitude and divided by ``max(|Im|)``.
    Level ``k`` works on images downsampled by ``2**(level - k - 1)``
    (``ndimage.zoom``, order 2) and smoothed with ``sigmas_s[k]``. At every
    iteration, ``Im`` is warped by ``Mt`` and ``If`` by ``-Mt`` (``interp``
    with ``k_id=1``). A demons-style update is computed from their
    difference and mean gradient, clipped to [-1, 1], smoothed with
    ``sigmas_f[k]``, and blended with the previous update using ``rho``.
    The accumulated field is then smoothed with ``sigmas_e[k]``.

    Args:
        If, Im: 3-D images (three gradient components are used).
        level (int): number of pyramid levels. ``sigmas_f``, ``sigmas_e``,
            ``sigmas_s`` and ``iters`` must each have at least ``level``
            entries (they are indexed by level).
        device: device id passed to ``sp.Device`` for interpolation.
        rho (float): weight of the new update vs. the previous one.

    Returns:
        ``M_scale(Mt * 2, Im.shape)``. The factor 2 is a placeholder per the
        TODO below; presumably it approximates combining the two half-way
        warps. NOTE(review): confirm.
    """
    ### normalization??
    Im = np.abs(Im)
    m_scale = np.max(Im)
    Im = Im / m_scale
    If = np.abs(If)
    If = If / m_scale
    ### registration
    # NOTE(review): ``M`` is overwritten before it is read.
    M = np.zeros(Im.shape + (3, ))
    Mt = np.zeros(Im.shape + (3, ))
    for k in range(level):
        print('Demons Level:{}'.format(k))
        ### hyperparameter assignment
        scale = 2**(level - k - 1)
        sigma_f = sigmas_f[k]
        sigma_e = sigmas_e[k]
        sigma_s = sigmas_s[k]
        iter_each_level = iters[k]
        ### Downsample and smooth both images for this level.
        Ift = ndimage.zoom(If, zoom=1 / scale, order=2)
        Ift = ndimage.gaussian_filter(Ift, sigma=sigma_s, truncate=2.0)
        Imt = ndimage.zoom(Im, zoom=1 / scale, order=2)
        Imt = ndimage.gaussian_filter(Imt, sigma=sigma_s, truncate=2.0)
        Imask = pmask(Imt + Ift, 1e-2)
        Isizet = Ift.shape
        # Bring the field from the previous level to this level's size.
        Mt = M_scale(Mt, Isizet)
        # Previous update, used for the rho blend.
        uo = np.zeros_like(Mt)
        for i in range(iter_each_level):
            # Symmetric warp: moving by +Mt, fixed by -Mt.
            Imm = interp(Imt, Mt, device=sp.Device(device), k_id=1)
            Ifm = interp(Ift, -Mt, device=sp.Device(device), k_id=1)
            dI = Ifm - Imm
            Is = (Ifm + Imm) / 2
            # Is = ndimage.gaussian_filter((Ifm+Imm)/2,sigma=sigma_s,truncate=2.0)
            gIx, gIy, gIz = imgrad3d(Is)
            gI = np.sqrt(np.abs(gIx**2 + gIy**2 + gIz**2) + 1e-6)
            discriminator = gI**2 + np.abs(dI)**2
            # NOTE(review): the meaning of the 3.0 gain is not documented.
            dI = dI * 3.0
            ux = -dI * gIx / discriminator
            uy = -dI * gIy / discriminator
            uz = -dI * gIz / discriminator
            # Zero the update at NaNs, flat regions, and outside the mask.
            mask = (gI < 1e-4) | (~Imask)
            ux[np.isnan(ux) | mask] = 0
            uy[np.isnan(uy) | mask] = 0
            uz[np.isnan(uz) | mask] = 0
            # Clip each component to [-1, 1].
            ux = np.maximum(np.minimum(ux, 1), -1)
            uy = np.maximum(np.minimum(uy, 1), -1)
            uz = np.maximum(np.minimum(uz, 1), -1)
            # Smooth the update with sigma_f.
            ux = ndimage.gaussian_filter(ux, sigma=sigma_f)
            uy = ndimage.gaussian_filter(uy, sigma=sigma_f)
            uz = ndimage.gaussian_filter(uz, sigma=sigma_f)
            # Accumulate: rho * new update + (1 - rho) * previous update.
            Mt[..., 0] = Mt[..., 0] + rho * ux + (1 - rho) * uo[..., 0]
            Mt[..., 1] = Mt[..., 1] + rho * uy + (1 - rho) * uo[..., 1]
            Mt[..., 2] = Mt[..., 2] + rho * uz + (1 - rho) * uo[..., 2]
            uo[..., 0] = ux
            uo[..., 1] = uy
            uo[..., 2] = uz
            # Smooth the accumulated field with sigma_e.
            Mt[..., 0] = ndimage.gaussian_filter(Mt[..., 0], sigma=sigma_e)
            Mt[..., 1] = ndimage.gaussian_filter(Mt[..., 1], sigma=sigma_e)
            Mt[..., 2] = ndimage.gaussian_filter(Mt[..., 2], sigma=sigma_e)
    ### TODO inverse combination (right now just double)
    M = M_scale(Mt * 2, Im.shape)
    return M