def generate_test_FFT(self, shape, field_scale):
    """Shared driver for the 2D and 3D ORCFFTWrapper-vs-FFT consistency tests.

    For every (number of interpolators, repetition, coil count) combination,
    draw a random Cartesian mask and a spatially constant field map.

    With ``time_vec`` set to ones, the wrapped operator must reproduce the
    plain ``FFT`` (up to ``self.rtol``) once its input is multiplied by
    ``exp(-2j*pi*field_shift)`` in the forward direction, or by
    ``exp(2j*pi*field_shift)`` in the adjoint direction.
    """
    def draw_complex(coils):
        # Real part is drawn before the imaginary part, which keeps the same
        # RNG consumption order as ``randn(...) + 1j * randn(...)``.
        real = np.random.randn(coils, *shape)
        imag = np.random.randn(coils, *shape)
        return np.squeeze(real + 1j * imag)

    combos = product(self.L, range(self.max_iter), self.n_coils)
    for n_interp, _, coils in combos:
        sampling = np.random.randint(2, size=shape)
        shift = field_scale * np.random.randint(-150, 150)

        # Reference Cartesian operator and its field-map wrapper
        ref_op = FFT(mask=sampling, shape=shape, n_coils=coils)
        orc_op = ORCFFTWrapper(
            ref_op,
            field_map=shift * np.ones(shape),
            time_vec=np.ones(shape[0]),
            mask=np.ones(shape),
            num_interpolators=n_interp,
            n_bins=self.n_bins,
        )

        # Forward direction
        image = draw_complex(coils)
        expected_k = ref_op.op(image)
        actual_k = orc_op.op(image * np.exp(-2j * np.pi * shift))
        np.testing.assert_allclose(expected_k, actual_k, rtol=self.rtol)

        # Adjoint direction
        kdata = draw_complex(coils)
        expected_img = ref_op.adj_op(kdata)
        actual_img = orc_op.adj_op(kdata * np.exp(2j * np.pi * shift))
        np.testing.assert_allclose(expected_img, actual_img, rtol=self.rtol)
def get_operators(
        kspace_data,
        loc,
        mask,
        fourier_type=1,
        max_iter=80,
        regularisation=None,
        linear=None,
):
    """Create the various operators from the config file.

    Parameters
    ----------
    kspace_data : numpy.ndarray
        Full k-space. A 2D array is treated as single-coil; otherwise
        ``kspace_data.shape[0]`` is used as the number of coils. The last two
        axes give the image shape.
    loc :
        Column locations, forwarded as ``mask_cols`` to the online k-space
        generators (``fourier_type`` 1 and 2).
    mask :
        Sampling mask, forwarded to ``FFT`` (types 0 and 1) and to
        ``KspaceGeneratorBase`` (type 0).
    fourier_type : int, default 1
        0: offline reconstruction, 1: online type I, 2: online type II.
    max_iter : int, default 80
        Forwarded to ``KspaceGeneratorBase`` (offline mode only).
    regularisation : dict or None
        Proximity-operator config. Its mandatory ``"class"`` entry selects
        "LASSO", "GroupLASSO", "OWL" or "IdentityProx". Any other class name
        keeps the default ``IdentityProx``. Not modified by this function.
    linear : dict or None
        Linear-operator config. Its optional ``"class"`` entry selects
        "WaveletN" or "Identity"; the remaining entries are forwarded to
        ``WaveletN``. Not modified by this function.

    Returns
    -------
    tuple
        ``(kspace_generator, fourier_op, linear_op, prox_op)``.

    Raises
    ------
    NotImplementedError
        For an unknown ``fourier_type`` or linear ``"class"``.
    KeyError
        If ``regularisation`` is given without a ``"class"`` entry.
    """
    n_coils = 1 if kspace_data.ndim == 2 else kspace_data.shape[0]
    shape = kspace_data.shape[-2:]
    if fourier_type == 0:  # offline reconstruction
        kspace_generator = KspaceGeneratorBase(full_kspace=kspace_data, mask=mask, max_iter=max_iter)
        fourier_op = FFT(shape=shape, n_coils=n_coils, mask=mask)
    elif fourier_type == 1:  # online type I reconstruction
        kspace_generator = Column2DKspaceGenerator(full_kspace=kspace_data, mask_cols=loc)
        fourier_op = FFT(shape=shape, n_coils=n_coils, mask=mask)
    elif fourier_type == 2:  # online type II reconstruction
        kspace_generator = DataOnlyKspaceGenerator(full_kspace=kspace_data, mask_cols=loc)
        fourier_op = ColumnFFT(shape=shape, n_coils=n_coils)
    else:
        raise NotImplementedError

    if linear is None:
        linear_op = Identity()
    else:
        # Work on a copy so the caller's config keeps its "class" key and can
        # be reused for another call.
        linear_kwargs = dict(linear)
        lin_cls = linear_kwargs.pop("class", None)
        if lin_cls == "WaveletN":
            linear_op = WaveletN(n_coils=n_coils, n_jobs=4, **linear_kwargs)
            # NOTE(review): a dummy forward pass, presumably to initialise
            # ``coeffs_shape`` (read by the OWL branch below) — confirm.
            linear_op.op(np.zeros_like(kspace_data))
        elif lin_cls == "Identity":
            linear_op = Identity()
        else:
            raise NotImplementedError

    prox_op = IdentityProx()
    if regularisation is not None:
        # Copy for the same reason as above; a missing "class" still raises KeyError.
        reg_kwargs = dict(regularisation)
        reg_cls = reg_kwargs.pop("class")
        if reg_cls == "LASSO":
            prox_op = LASSO(weights=reg_kwargs["weights"])
        elif reg_cls == "GroupLASSO":
            prox_op = GroupLASSO(weights=reg_kwargs["weights"])
        elif reg_cls == "OWL":
            # Requires a linear operator exposing ``coeffs_shape`` (e.g. WaveletN).
            prox_op = OWL(**reg_kwargs, n_coils=n_coils, bands_shape=linear_op.coeffs_shape)
        elif reg_cls == "IdentityProx":
            # No regularisation: the sparsifying transform is dropped as well.
            prox_op = IdentityProx()
            linear_op = Identity()
    return kspace_generator, fourier_op, linear_op, prox_op
def test_FFT(self):
    """Test the adjoint operator for the 2D Cartesian Fourier transform.

    For random sampling masks and each coil count in ``self.num_channels``,
    check the adjoint identity ``<x, A^H y> == <A x, y>`` between
    ``FFT.op`` and ``FFT.adj_op``. ``np.vdot`` conjugates its first argument.
    """
    for i, num_coil in product(range(self.max_iter), self.num_channels):
        _mask = np.random.randint(2, size=(self.N, self.N))
        _samples = convert_mask_to_locations(_mask)
        # Format the index into the message. Passing ``i`` as a second print
        # argument would leave the "{0}" placeholder in the output.
        print("Process FFT test '{0}'...".format(i))
        fourier_op_dir = FFT(samples=_samples, shape=(self.N, self.N),
                             n_coils=num_coil)
        fourier_op_adj = FFT(samples=_samples, shape=(self.N, self.N),
                             n_coils=num_coil)
        Img = np.squeeze(
            np.random.randn(num_coil, self.N, self.N)
            + 1j * np.random.randn(num_coil, self.N, self.N))
        f = np.squeeze(
            np.random.randn(num_coil, self.N, self.N)
            + 1j * np.random.randn(num_coil, self.N, self.N))
        f_p = fourier_op_dir.op(Img)
        I_p = fourier_op_adj.adj_op(f)
        x_d = np.vdot(Img, I_p)
        x_ad = np.vdot(f_p, f)
        np.testing.assert_allclose(x_d, x_ad, rtol=1e-10)
    print(" FFT adjoint test passes")
# View Input # image.show() # mask.show() ############################################################################# # Generate the kspace # ------------------- # # From the 3D Orange volume and the acquisition mask, we retrospectively # undersample the k-space using a cartesian acquisition mask # We then reconstruct the zero order solution as a baseline # Get the locations of the kspace samples kspace_loc = convert_mask_to_locations(mask.data) # Generate the subsampled kspace fourier_op = FFT(samples=kspace_loc, shape=image.shape) kspace_data = fourier_op.op(image) # Zero order solution image_rec0 = pysap.Image(data=fourier_op.adj_op(kspace_data), metadata=image.metadata) # image_rec0.show() # Calculate SSIM base_ssim = ssim(image_rec0, image) print(base_ssim) ############################################################################# # FISTA optimization # ------------------ #
# View Input # image.show() # mask.show() ############################################################################# # Generate the kspace # ------------------- # # From the 2D brain slice and the acquisition mask, we retrospectively # undersample the k-space using a cartesian acquisition mask # We then reconstruct the zero order solution as a baseline # Get the locations of the kspace samples and the associated observations fourier_op = FFT(samples=kspace_loc, shape=image.shape, n_coils=cartesian_ref_image.shape[0]) kspace_obs = fourier_op.op(cartesian_ref_image) # Zero Filled reconstruction zero_filled = fourier_op.adj_op(kspace_obs) image_rec0 = pysap.Image(data=np.sqrt(np.sum(np.abs(zero_filled)**2, axis=0))) # image_rec0.show() base_ssim = ssim(image_rec0, image) print('The Base SSIM is : ' + str(base_ssim)) ############################################################################# # FISTA optimization # ------------------ # # We now want to refine the zero order solution using a FISTA optimization.
from mri.operators.utils import convert_mask_to_locations
from mri.reconstructors import SingleChannelReconstructor
from mri.scripts.gridsearch import launch_grid
from pysap.data import get_sample_data
from modopt.math.metrics import ssim
from modopt.opt.proximity import SparseThreshold
from modopt.opt.linear import Identity
import numpy as np

# Load MR data and obtain kspace
# NOTE(review): FFT is used below but its import is not visible in this
# chunk — confirm it is imported (e.g. from mri.operators) above.
image = get_sample_data('2d-mri')
mask = get_sample_data("cartesian-mri-mask")
# Sample locations derived from the Cartesian mask
kspace_loc = convert_mask_to_locations(mask.data)
fourier_op = FFT(samples=kspace_loc, shape=image.shape)
# Retrospectively undersampled kspace of the reference image
kspace_data = fourier_op.op(image.data)

# Define the keyword dictionaries based on convention
ref = image
# Metric specification: each entry gives the metric callable, the mapping
# from optimiser variables to metric arguments, constant keyword arguments,
# and whether the metric may trigger early stopping.
metrics = {
    'ssim': {
        'metric': ssim,
        'mapping': {
            'x_new': 'test',
            'y_new': None
        },
        'cst_kwargs': {
            'ref': image,
            'mask': None
        },
        'early_stopping': True,
def _central_square_mask(shape):
    """Return a float array of ``shape`` that is 1 on a centred square and 0 elsewhere.

    The square has side ``min(shape)`` and spans the first two axes. Its start
    index on each axis is ``(size - side) // 2``, so it always covers exactly
    ``side`` rows and columns, including when ``side`` is odd.
    """
    mask = np.zeros(shape)
    side = min(shape)
    row_start = (shape[0] - side) // 2
    col_start = (shape[1] - side) // 2
    mask[row_start:row_start + side, col_start:col_start + side] = 1
    return mask


def create_cartesian_metrics(online_pb, real_img, final_mask, final_k, estimates=None):
    """Build the metrics configuration for monitoring an online Cartesian reconstruction.

    Parameters
    ----------
    online_pb : object
        Online problem. Its ``gradient_op``, ``linear_op``, ``prox_op`` and
        ``opt`` attributes are read.
    real_img : numpy.ndarray
        Reference image for the PSNR/SSIM metrics. Its shape also sets the
        shape of the evaluation mask (the central square of side
        ``min(real_img.shape)``).
    final_mask : numpy.ndarray
        Mask passed to the ``FFT`` operator used for the offline data residual.
    final_k : numpy.ndarray
        K-space used as observed data for the offline data residual. When it
        is 3D, ``final_k.shape[0]`` is used as the number of coils.
    estimates : optional
        Forwarded unchanged as ``estimate_call_period``.

    Returns
    -------
    dict
        Keys ``'metrics'``, ``'cost_op_kwargs'``, ``'metric_call_period'`` and
        ``'estimate_call_period'``.
    """
    metrics_fourier_op = FFT(
        shape=final_k.shape[-2:],
        n_coils=final_k.shape[0] if final_k.ndim == 3 else 1,
        mask=final_mask,
    )
    metrics_gradient_op = OnlineGradAnalysis(fourier_op=metrics_fourier_op)
    metrics_gradient_op.obs_data = final_k

    # PSNR/SSIM are evaluated only on the central square of the image.
    square_mask = _central_square_mask(real_img.shape)

    def data_res_on(x):
        # With a synthesis gradient, the cost is evaluated on linear_op.op(x)
        # rather than on x directly.
        if isinstance(online_pb.gradient_op, OnlineGradSynthesis):
            return online_pb.gradient_op.cost(online_pb.linear_op.op(x))
        return online_pb.gradient_op.cost(x)

    metrics = {
        'psnr': {
            'metric': psnr_ssos,
            'mapping': {'x_new': 'test'},
            'early_stopping': False,
            'cst_kwargs': {'ref': real_img, 'mask': square_mask},
        },
        'ssim': {
            'metric': ssim_ssos,
            'mapping': {'x_new': 'test'},
            'cst_kwargs': {'ref': real_img, 'mask': square_mask},
            'early_stopping': False,
        },
        'data_res_off': {
            # Offline data residual against final_k (bound method; no wrapper needed)
            'metric': metrics_gradient_op.cost,
            'mapping': {'x_new': 'x'},
            'early_stopping': False,
            'cst_kwargs': dict(),
        },
        'data_res_on': {
            'metric': data_res_on,
            'mapping': {'x_new': 'x'},
            'early_stopping': False,
            'cst_kwargs': dict(),
        },
        'x_new': {
            'metric': ssos,
            'mapping': {'x_new': 'x'},
            'early_stopping': False,
            'cst_kwargs': dict(),
        },
    }
    # The lambdas below look up online_pb.prox_op / linear_op at call time,
    # so later reassignment of those attributes is still honoured.
    if online_pb.opt == 'condatvu':
        # Condat-Vu: the regularisation cost is computed on the 'y_new' iterate.
        metrics['reg_res'] = {
            'metric': lambda y: online_pb.prox_op.cost(y),
            'mapping': {'y_new': 'y'},
            'early_stopping': False,
            'cst_kwargs': dict(),
        }
    else:
        metrics['reg_res'] = {
            'metric': lambda x: online_pb.prox_op.cost(online_pb.linear_op.op(x)),
            'mapping': {'x_new': 'x'},
            'early_stopping': False,
            'cst_kwargs': dict(),
        }
    metrics_config = {
        'metrics': metrics,
        'cost_op_kwargs': {"cost_interval": 1},
        'metric_call_period': 1,
        'estimate_call_period': estimates,
    }
    return metrics_config