def _exp(b1):
    """Build the 2x2 complex matrix used to rotate by a complex B1 value.

    The magnitude of ``b1`` is used as the rotation angle (halved inside the
    trigonometric terms) and its phase sets the off-diagonal phase.

    Args:
        b1 (complex or array): complex B1 value in radians.

    Returns:
        array: nested 2x2 array ``[[c, upper], [lower, c]]`` on b1's device.
    """
    device = sp.get_device(b1)
    xp = device.xp
    with device:
        half_angle = xp.abs(b1) / 2
        phase = xp.angle(b1)

        c = xp.cos(half_angle)
        s = xp.sin(half_angle)
        cos_ph = xp.cos(phase)
        sin_ph = xp.sin(phase)

        upper = -1j * s * cos_ph - s * sin_ph
        lower = -1j * s * cos_ph + s * sin_ph

        return xp.array([[c, upper],
                         [lower, c]])
def ConvSense(img_ker_shape, mps_ker, coord=None, weights=None,
              grd_shape=None, comm=None):
    """Convolution linear operator with sensitivity maps kernel in k-space.

    The operator input is an image kernel of shape ``img_ker_shape``. It is
    convolved ('valid' mode, multi-channel) with ``mps_ker``. If ``coord``
    is given, the result is inverse-FFTed over the last
    ``len(img_ker_shape)`` axes of a per-coil grid and then sampled with a
    NUFFT at ``coord``.

    Args:
        img_ker_shape (tuple of ints): image kernel shape.
        mps_ker (array): sensitivity maps kernel; axis 0 indexes coils.
        coord (None or array): coordinates.
        weights (None or array): k-space weights; the output is multiplied
            by ``weights**0.5``.
        grd_shape (None or list): Shape of grid. Estimated from ``coord``
            when None.
        comm (None or `sigpy.Communicator`): when given, an
            ``AllReduceAdjoint`` over ``img_ker_shape`` is composed on the
            input side.

    Returns:
        Linop: the composed linear operator.
    """
    ndim = len(img_ker_shape)
    num_coils = mps_ker.shape[0]
    coil_ker = mps_ker.reshape((num_coils, 1) + mps_ker.shape[1:])

    add_channel = sp.linop.Reshape((1, ) + tuple(img_ker_shape),
                                   img_ker_shape)
    conv = sp.linop.ConvolveData(add_channel.oshape, coil_ker,
                                 mode='valid', multi_channel=True)
    A = conv * add_channel

    if coord is not None:
        base_shape = (sp.estimate_shape(coord) if grd_shape is None
                      else list(grd_shape))
        coil_grid = [num_coils] + base_shape
        ifft = sp.linop.IFFT(coil_grid, axes=range(-ndim, 0))
        nufft = sp.linop.NUFFT(coil_grid, coord)
        A = nufft * ifft * A

    if weights is not None:
        with sp.get_device(weights):
            sqrt_w = sp.linop.Multiply(A.oshape, weights**0.5)
            A = sqrt_w * A

    if comm is not None:
        reduce_op = sp.linop.AllReduceAdjoint(img_ker_shape, comm,
                                              in_place=True)
        A = A * reduce_op

    return A
def Sense(mps, coord=None, weights=None, ishape=None, coil_batch_size=None):
    """Sense linear operator.

    Multiplies an image by the coil sensitivity maps, then applies an FFT
    (Cartesian, ``coord is None``) or a NUFFT at ``coord``.

    Args:
        mps (array): sensitivity maps of length = number of channels.
        coord (None or array): coordinates.
        weights (None or array): k-space weights; the output is multiplied
            by ``weights**0.5``.
        ishape (None or tuple): image shape. Defaults to ``mps.shape[1:]``.
        coil_batch_size (None or int): number of coils per sub-operator.
            When smaller than the number of coils, the operator is built as
            a vertical stack of per-batch Sense operators. When None, all
            coils are processed at once.

    Returns:
        Linop: the Sense operator.
    """
    num_coils = len(mps)
    if ishape is None:
        ishape = mps.shape[1:]
        img_ndim = mps.ndim - 1
    else:
        img_ndim = len(ishape)

    if coil_batch_size is None:
        coil_batch_size = num_coils

    if coil_batch_size < num_coils:
        num_coil_batches = (num_coils + coil_batch_size - 1) // coil_batch_size
        # Contiguous coil slices keep the stacked output in the same coil
        # order as ``mps`` and respect ``coil_batch_size``; a strided slice
        # (mps[c::num_coil_batches]) would permute the output coils.
        return sp.linop.Vstack([
            Sense(mps[c * coil_batch_size:(c + 1) * coil_batch_size],
                  coord=coord, weights=weights, ishape=ishape)
            for c in range(num_coil_batches)
        ], axis=0)

    S = sp.linop.Multiply(ishape, mps)
    if coord is None:
        F = sp.linop.FFT(S.oshape, axes=range(-img_ndim, 0))
    else:
        F = sp.linop.NUFFT(S.oshape, coord)

    A = F * S

    if weights is not None:
        with sp.get_device(weights):
            P = sp.linop.Multiply(F.oshape, weights**0.5)
            A = P * A

    A.repr_str = 'Sense'
    return A
def labels_to_scores(labels):
    """Convert labels to scores.

    Args:
        labels (array): One-dimensional integer label array.

    Returns:
        array: float32 one-hot score array of shape
        (len(labels), max(labels) + 1), on the same device as ``labels``.
        For an empty ``labels`` array, an empty (0, 0) array is returned.
    """
    device = sp.get_device(labels)
    xp = device.xp
    with device:
        num_labels = len(labels)
        if num_labels == 0:
            # max() of an empty array is undefined; return an empty score
            # matrix instead of raising.
            return xp.zeros([0, 0], dtype=np.float32)

        # Convert to a Python int so the shape entry is a plain integer on
        # every backend rather than a 0-d device array.
        num_classes = int(labels.max()) + 1
        scores = xp.zeros([num_labels, num_classes], dtype=np.float32)
        scores[xp.arange(num_labels), labels] = 1
        return scores
def ConvImage(mps_ker_shape, img_ker, coord=None, weights=None,
              grd_shape=None):
    """Convolution linear operator with image kernel in k-space.

    The operator input is a sensitivity-maps kernel of shape
    ``mps_ker_shape``. It is convolved ('valid' mode, multi-channel) with
    ``img_ker``. If ``coord`` is given, the result is inverse-FFTed over the
    last ``img_ker.ndim`` axes of a per-coil grid and then sampled with a
    NUFFT at ``coord``.

    Args:
        mps_ker_shape (tuple of ints): sensitivity maps kernel shape; entry 0
            is the number of coils.
        img_ker (array): image kernel.
        coord (None or array): coordinates.
        weights (None or array): k-space weights; the output is multiplied
            by ``weights**0.5``.
        grd_shape (None or list): Shape of grid. Estimated from ``coord``
            when None.

    Returns:
        Linop: the composed linear operator.
    """
    ndim = img_ker.ndim
    num_coils = mps_ker_shape[0]
    channel_ker = img_ker.reshape((1, ) + img_ker.shape)

    split_coils = sp.linop.Reshape(
        (num_coils, 1) + tuple(mps_ker_shape[1:]), mps_ker_shape)
    conv = sp.linop.ConvolveFilter(split_coils.oshape, channel_ker,
                                   mode='valid', multi_channel=True)
    A = conv * split_coils

    if coord is not None:
        base_shape = (sp.estimate_shape(coord) if grd_shape is None
                      else list(grd_shape))
        coil_grid = [num_coils] + base_shape
        ifft = sp.linop.IFFT(coil_grid, axes=range(-ndim, 0))
        nufft = sp.linop.NUFFT(coil_grid, coord)
        A = nufft * ifft * A

    if weights is not None:
        with sp.get_device(weights):
            sqrt_w = sp.linop.Multiply(A.oshape, weights**0.5)
            A = sqrt_w * A

    return A
def _exp(b1):
    """Build the 2x2 complex rotation matrix for a complex B1 value.

    The magnitude of ``b1`` is used as the rotation angle (halved inside the
    trigonometric terms) and its phase sets the off-diagonal phase.

    Args:
        b1 (complex or array): complex B1 value(s) in radians.

    Returns:
        array: complex128 array of shape ``b1.shape + (2, 2)``, which is
        ``(2, 2)`` for a scalar ``b1``. Batched inputs give one matrix per
        element, suitable for batched ``@``.
    """
    device = sp.get_device(b1)
    xp = device.xp
    with device:
        alpha = xp.abs(b1)
        phi = xp.angle(b1)
        cos_alpha = xp.cos(alpha / 2)
        sin_alpha = xp.sin(alpha / 2)
        cos_phi = xp.cos(phi)
        sin_phi = xp.sin(phi)

        # `xp.complex` was a deprecated alias of the builtin `complex` and was
        # removed in NumPy 1.24; complex128 is the identical dtype.
        R = xp.zeros(alpha.shape + (2, 2), dtype=xp.complex128)
        R[..., 0, 0] = cos_alpha
        R[..., 0, 1] = -1j * sin_alpha * cos_phi - sin_alpha * sin_phi
        R[..., 1, 0] = -1j * sin_alpha * cos_phi + sin_alpha * sin_phi
        R[..., 1, 1] = cos_alpha
        return R
def free_induction_decay(input, f0, t1, t2, dt):
    """Simulate free induction decay to input magnetization.

    Off-resonance, T1 recovery, and T2 relaxation array dimensions must be
    consistent with the input batch dimensions.

    Args:
        input (array): magnetization array.
        f0 (array): off-resonance frequency values.
        t1 (array): T1 recovery values.
        t2 (array): T2 relaxation values.
        dt (float): free induction decay duration.

    Returns:
        array: magnetization array after free induction decay,
        in representation consistent with input.
    """
    p = to_density_matrix(input)

    device = sp.get_device(input)
    xp = device.xp
    with device:
        e2 = xp.exp(-dt / t2)
        e1 = xp.exp(-dt / t1)
        e0 = xp.exp(-1j * dt * 2 * np.pi * f0)

        # Copy before the in-place updates below. NOTE(review): presumably
        # to_density_matrix can return its argument unchanged, in which case
        # this protects the caller's array — confirm.
        p = p.copy()
        # Diagonal entries decay by the T1 factor.
        p[..., 0, 0] *= e1
        p[..., 1, 1] *= e1
        # Off-diagonal entries pick up the off-resonance phase and T2 decay;
        # the two entries use conjugate phases.
        p[..., 1, 0] *= e0 * e2
        p[..., 0, 1] *= xp.conj(e0) * e2
        # Recovery term is added to the [0, 0] entry only (after scaling).
        p[..., 0, 0] += 1 - e1

    if is_bloch_vector(input):
        return to_bloch_vector(p)
    return p
def hard_pulse_rotation(input, b1):
    """Apply hard pulse rotation to input magnetization.

    The density-matrix form of ``input`` is sandwiched between
    ``_exp(b1)`` and ``_exp(-b1)``.

    Args:
        input (array): magnetization array.
        b1 (complex float): complex B1 value in radian.

    Returns:
        array: magnetization array after hard pulse rotation,
        in representation consistent with input.
    """
    rho = to_density_matrix(input)
    dev = sp.get_device(rho)
    with dev:
        b1_dev = sp.to_device(b1, dev)
        left = _exp(b1_dev)
        right = _exp(-b1_dev)
        rho = left @ rho @ right

    return to_bloch_vector(rho) if is_bloch_vector(input) else rho
def g(input):
    """Return ``lamda`` times the L1 norm of ``W(input)`` as a Python scalar.

    ``lamda`` and ``W`` come from the enclosing scope.
    """
    dev = sp.get_device(input)
    with dev:
        l1_norm = dev.xp.sum(dev.xp.abs(W(input))).item()
        return lamda * l1_norm
def _estimate_weights(y, weights, coord):
    """Return ``weights``, deriving a binary mask when none can be inferred.

    When both ``weights`` and ``coord`` are None, the weights are 1 where the
    root-sum-of-squares of ``y`` over axis 0 is positive and 0 elsewhere,
    cast to ``y.dtype``. Otherwise ``weights`` is returned unchanged.
    """
    if weights is not None or coord is not None:
        return weights

    with sp.get_device(y):
        sampled = sp.rss(y, axes=(0, )) > 0
        return sampled.astype(y.dtype)
def Sense(mps, coord=None, weights=None, tseg=None, ishape=None,
          coil_batch_size=None, comm=None, transp_nufft=False):
    """Sense linear operator.

    Args:
        mps (array): sensitivity maps of length = number of channels.
        coord (None or array): coordinates.
        weights (None or array): k-space weights.
            Useful for soft-gating or density compensation.
        tseg (None or Dictionary): parameters for time-segmented off-resonance
            correction. Parameters are 'b0' (array), 'dt' (float),
            'lseg' (int), and 'n_bins' (int). Lseg is the number of
            time segments used, and n_bins is the number of histogram bins.
            Requires ``coord``.
        ishape (None or tuple): image shape.
        coil_batch_size (None or int): batch size for processing multi-channel.
            When None, process all coils at the same time.
            Useful for saving memory.
        comm (None or `sigpy.Communicator`): communicator
            for distributed computing.
        transp_nufft (bool): if True, the k-space transform is built as
            ``NUFFT(S.oshape, -coord).H`` instead of ``NUFFT(S.oshape, coord)``.

    Returns:
        Linop: the Sense operator.
    """
    # Get image shape and dimension.
    num_coils = len(mps)
    if ishape is None:
        ishape = mps.shape[1:]
        img_ndim = mps.ndim - 1
    else:
        img_ndim = len(ishape)

    # Serialize linop if coil_batch_size is smaller than num_coils.
    if coil_batch_size is None:
        coil_batch_size = num_coils

    if coil_batch_size < num_coils:
        num_coil_batches = (num_coils + coil_batch_size - 1) // coil_batch_size
        # Contiguous coil slices keep the stacked output in the same coil
        # order as ``mps``; tseg and transp_nufft are forwarded so every
        # batch builds the same kind of operator as the unbatched path.
        A = sp.linop.Vstack([
            Sense(mps[c * coil_batch_size:(c + 1) * coil_batch_size],
                  coord=coord, weights=weights, tseg=tseg, ishape=ishape,
                  transp_nufft=transp_nufft)
            for c in range(num_coil_batches)
        ], axis=0)

        if comm is not None:
            C = sp.linop.AllReduceAdjoint(ishape, comm, in_place=True)
            A = A * C

        return A

    # Create Sense linear operator
    S = sp.linop.Multiply(ishape, mps)

    # Cartesian FFT only without coordinates and without time segmentation;
    # time-segmented correction always uses a NUFFT.
    if tseg is None and coord is None:
        F = sp.linop.FFT(S.oshape, axes=range(-img_ndim, 0))
    elif transp_nufft:
        F = sp.linop.NUFFT(S.oshape, -coord).H
    else:
        F = sp.linop.NUFFT(S.oshape, coord)

    if tseg is None:
        A = F * S
    else:
        # Time-segmented off-resonance compensation:
        # A = sum_ii Bi * F * S * Cti
        time = len(coord) * tseg['dt']
        b, ct = sp.mri.util.tseg_off_res_b_ct(tseg['b0'], tseg['n_bins'],
                                              tseg['lseg'], tseg['dt'], time)
        for ii in range(tseg['lseg']):
            Bi = sp.linop.Multiply(F.oshape, b[:, ii])
            Cti = sp.linop.Multiply(S.ishape, ct[:, ii].reshape(S.ishape))
            term = Bi * F * S * Cti
            A = term if ii == 0 else A + term

    if weights is not None:
        with sp.get_device(weights):
            P = sp.linop.Multiply(F.oshape, weights**0.5)
            A = P * A

    if comm is not None:
        C = sp.linop.AllReduceAdjoint(ishape, comm, in_place=True)
        A = A * C

    A.repr_str = 'Sense'
    return A
def normalize(x):
    """Return the L2 norm of ``x`` along axis -2, keeping that axis.

    ``xp`` comes from the enclosing scope.
    """
    with sp.get_device(x):
        power = xp.sum(xp.abs(x)**2, axis=-2, keepdims=True)
        return power**0.5
def espirit_maps(ksp, calib_width=24, thresh=0.001, kernel_width=6,
                 crop=0.8, max_power_iter=30, device=sp.cpu_device,
                 output_eigenvalue=False):
    """Generate ESPIRiT maps from k-space.

    Currently only supports outputting one set of maps.

    Args:
        ksp (array): k-space array of shape [num_coils, n_ndim, ..., n_1]
        calib_width (int): width of the calibration region along each
            image dimension.
        thresh (float): relative threshold on the singular values of the
            calibration matrix; singular vectors with
            ``S > thresh * S.max()`` are kept as kernels.
        kernel_width (int): kernel width for the calibration matrix.
        crop (float): cropping threshold; map entries where the estimated
            eigenvalue is not greater than ``crop`` are set to zero.
        max_power_iter (int): number of power iterations; must be >= 1.
        device (Device): computing device.
        output_eigenvalue (bool): if True, also return the eigenvalue map.

    Returns:
        array: ESPIRiT maps of the same shape as ksp. If
        ``output_eigenvalue`` is True, a tuple ``(maps, eigenvalues)``
        where ``eigenvalues`` has shape ``[1] + list(ksp.shape[1:])``.

    Raises:
        ValueError: if ``max_power_iter`` is smaller than 1.

    References:
        Martin Uecker, Peng Lai, Mark J. Murphy, Patrick Virtue, Michael Elad,
        John M. Pauly, Shreyas S. Vasanawala, and Michael Lustig
        ESPIRIT - An Eigenvalue Approach to Autocalibrating Parallel MRI:
        Where SENSE meets GRAPPA.
        Magnetic Resonance in Medicine, 71:990-1001 (2014)
    """
    # The eigenvalue estimate only exists after at least one power
    # iteration; without this check the function fails later with an
    # UnboundLocalError.
    if max_power_iter < 1:
        raise ValueError(
            'max_power_iter must be >= 1, got {}'.format(max_power_iter))

    img_ndim = ksp.ndim - 1
    num_coils = len(ksp)
    with sp.get_device(ksp):
        # Get calibration region
        calib_shape = [num_coils] + [calib_width] * img_ndim
        calib = sp.resize(ksp, calib_shape)
        calib = sp.to_device(calib, device)

    device = sp.Device(device)
    xp = device.xp
    with device:
        # Get calibration matrix
        kernel_shape = [num_coils] + [kernel_width] * img_ndim
        kernel_strides = [1] * (img_ndim + 1)
        mat = sp.array_to_blocks(calib, kernel_shape, kernel_strides)
        mat = mat.reshape([-1, sp.prod(kernel_shape)])

        # Perform SVD on calibration matrix
        _, S, VH = xp.linalg.svd(mat, full_matrices=False)
        VH = VH[S > thresh * S.max(), :]

        # Get kernels
        num_kernels = len(VH)
        kernels = VH.reshape([num_kernels] + kernel_shape)
        img_shape = ksp.shape[1:]

        # Get covariance matrix in image domain: one (num_coils x num_coils)
        # matrix per pixel, stored with reversed image axes first.
        AHA = xp.zeros(img_shape[::-1] + (num_coils, num_coils),
                       dtype=ksp.dtype)
        for kernel in kernels:
            img_kernel = sp.ifft(sp.resize(kernel, ksp.shape),
                                 axes=range(-img_ndim, 0))
            aH = xp.expand_dims(img_kernel.T, axis=-1)
            a = xp.conj(aH.swapaxes(-1, -2))
            AHA += aH @ a

        AHA *= (sp.prod(img_shape) / kernel_width**img_ndim)

        # Power Iteration to compute top eigenvector
        mps = xp.ones(ksp.shape[::-1] + (1, ), dtype=ksp.dtype)
        for _ in range(max_power_iter):
            sp.copyto(mps, AHA @ mps)
            eig_value = xp.sum(xp.abs(mps)**2, axis=-2, keepdims=True)**0.5
            mps /= eig_value

        # Normalize phase with respect to first channel.
        # NOTE(review): yields NaN wherever mps[0] is exactly zero.
        mps = mps.T[0]
        mps *= xp.conj(mps[0] / xp.abs(mps[0]))

        # Crop maps by thresholding eigenvalue
        eig_value = eig_value.T[0]
        mps *= eig_value > crop

        if output_eigenvalue:
            return mps, eig_value
        else:
            return mps
def Sense(mps, coord=None, weights=None, ishape=None,
          coil_batch_size=None, comm=None):
    """Sense linear operator.

    Args:
        mps (array): sensitivity maps of length = number of channels.
        coord (None or array): coordinates.
        weights (None or array): k-space weights.
            Useful for soft-gating or density compensation.
        ishape (None or tuple): image shape.
        coil_batch_size (None or int): batch size for processing multi-channel.
            When None, process all coils at the same time.
            Useful for saving memory.
        comm (None or `sigpy.Communicator`): communicator
            for distributed computing.

    Returns:
        Linop: the Sense operator.
    """
    # Get image shape and dimension.
    num_coils = len(mps)
    if ishape is None:
        ishape = mps.shape[1:]
        img_ndim = mps.ndim - 1
    else:
        img_ndim = len(ishape)

    # Serialize linop if coil_batch_size is smaller than num_coils.
    if coil_batch_size is None:
        coil_batch_size = num_coils

    if coil_batch_size < num_coils:
        num_coil_batches = (num_coils + coil_batch_size - 1) // coil_batch_size
        # Contiguous coil slices keep the stacked output in the same coil
        # order as ``mps`` and respect ``coil_batch_size``; a strided slice
        # (mps[c::num_coil_batches]) would permute the output coils.
        A = sp.linop.Vstack([
            Sense(mps[c * coil_batch_size:(c + 1) * coil_batch_size],
                  coord=coord, weights=weights, ishape=ishape)
            for c in range(num_coil_batches)
        ], axis=0)

        if comm is not None:
            C = sp.linop.AllReduceAdjoint(ishape, comm, in_place=True)
            A = A * C

        return A

    # Create Sense linear operator
    S = sp.linop.Multiply(ishape, mps)
    if coord is None:
        F = sp.linop.FFT(S.oshape, axes=range(-img_ndim, 0))
    else:
        F = sp.linop.NUFFT(S.oshape, coord)

    A = F * S

    if weights is not None:
        with sp.get_device(weights):
            P = sp.linop.Multiply(F.oshape, weights**0.5)
            A = P * A

    if comm is not None:
        C = sp.linop.AllReduceAdjoint(ishape, comm, in_place=True)
        A = A * C

    A.repr_str = 'Sense'
    return A
def g(x):
    """Return ``lamda`` times the L1 norm of ``x`` as a Python scalar.

    ``lamda`` comes from the enclosing scope.
    """
    dev = sp.get_device(x)
    with dev:
        l1_norm = dev.xp.sum(dev.xp.abs(x)).item()
        return lamda * l1_norm
def __init__(self, ksp, calib_width=24, thresh=0.02, kernel_width=6,
             crop=0.95, max_iter=100, device=sp.cpu_device,
             output_eigenvalue=False, show_pbar=True):
    """Set up ESPIRiT calibration as a power-method app.

    Builds an image-domain covariance matrix ``AHA`` (one
    ``num_coils x num_coils`` matrix per pixel) from the k-space calibration
    region, then hands a ``sp.alg.PowerMethod`` over ``AHA`` to the parent
    class constructor.

    Args:
        ksp (array): k-space array; axis 0 indexes coils.
        calib_width (int): width of the calibration region along each
            image dimension.
        thresh (float): relative threshold on singular values; singular
            vectors with ``S > thresh * S.max()`` are kept as kernels.
        kernel_width (int): kernel width for the calibration matrix.
        crop (float): stored on ``self.crop``. NOTE(review): presumably an
            eigenvalue cropping threshold applied after the power method —
            its use is not visible here.
        max_iter (int): maximum number of power-method iterations.
        device (Device): computing device.
        output_eigenvalue (bool): stored on ``self.output_eigenvalue``.
        show_pbar (bool): forwarded to the parent constructor.
    """
    self.device = sp.Device(device)
    self.output_eigenvalue = output_eigenvalue
    self.crop = crop

    img_ndim = ksp.ndim - 1
    num_coils = len(ksp)
    with sp.get_device(ksp):
        # Get calibration region
        calib_shape = [num_coils] + [calib_width] * img_ndim
        calib = sp.resize(ksp, calib_shape)
        calib = sp.to_device(calib, device)

    xp = self.device.xp
    with self.device:
        # Get calibration matrix: every kernel_shape patch of calib
        # (stride 1) becomes one row.
        kernel_shape = [num_coils] + [kernel_width] * img_ndim
        kernel_strides = [1] * (img_ndim + 1)
        mat = sp.array_to_blocks(calib, kernel_shape, kernel_strides)
        mat = mat.reshape([-1, sp.prod(kernel_shape)])

        # Perform SVD on calibration matrix and keep right singular vectors
        # whose singular value exceeds the relative threshold.
        _, S, VH = xp.linalg.svd(mat, full_matrices=False)
        VH = VH[S > thresh * S.max(), :]

        # Get kernels
        num_kernels = len(VH)
        kernels = VH.reshape([num_kernels] + kernel_shape)
        img_shape = ksp.shape[1:]

        # Get covariance matrix in image domain, stored with reversed image
        # axes first so the coil axes are last for batched matmul.
        AHA = xp.zeros(img_shape[::-1] + (num_coils, num_coils),
                       dtype=ksp.dtype)
        for kernel in kernels:
            img_kernel = sp.ifft(sp.resize(kernel, ksp.shape),
                                 axes=range(-img_ndim, 0))
            aH = xp.expand_dims(img_kernel.T, axis=-1)
            a = xp.conj(aH.swapaxes(-1, -2))
            AHA += aH @ a

        AHA *= (sp.prod(img_shape) / kernel_width**img_ndim)
        # Initial power-method vector: all ones, one coil column per pixel.
        self.mps = xp.ones(ksp.shape[::-1] + (1, ), dtype=ksp.dtype)

        def forward(x):
            # Apply the per-pixel covariance matrices.
            with self.device:
                return AHA @ x

        def normalize(x):
            # Per-pixel L2 norm over the coil axis (-2).
            with self.device:
                return xp.sum(xp.abs(x)**2,
                              axis=-2, keepdims=True)**0.5

        alg = sp.alg.PowerMethod(forward, self.mps, norm_func=normalize,
                                 max_iter=max_iter)

    super().__init__(alg, show_pbar=show_pbar)
def forward(x):
    """Multiply ``x`` by the enclosing-scope ``AHA`` on ``x``'s device."""
    dev = sp.get_device(x)
    with dev:
        product = AHA @ x
        return product