def get_linear_n_regularization_operator(
        self, gradient_formulation, wavelet_name, dimension=2, nb_scale=3,
        n_coils=1, n_jobs=1, verbose=0):
    """Build the wavelet (linear) operator and its regularization operator.

    Parameters
    ----------
    gradient_formulation: str
        Either 'synthesis' or 'analysis'. In the synthesis case the soft
        threshold is built on ``Identity()``; in the analysis case it is
        built on the wavelet operator itself.
    wavelet_name: str
        Wavelet name passed to ``WaveletN``. If ``WaveletN`` raises a
        ``ValueError`` for it, the same value is passed as ``wavelet_id`` to
        ``WaveletUD2`` (undecimated wavelets).
    dimension: int, default 2
        Passed as ``dim`` to ``WaveletN`` only; ``WaveletUD2`` does not
        receive it.
    nb_scale: int, default 3
        Number of wavelet scales.
    n_coils, n_jobs, verbose: int
        Forwarded unchanged to the wavelet operator constructor.

    Returns
    -------
    linear_op:
        The ``WaveletN`` or ``WaveletUD2`` instance.
    regularizer_op:
        A soft-thresholding ``SparseThreshold`` with a zero weight.

    Raises
    ------
    ValueError
        If ``gradient_formulation`` is not 'synthesis' or 'analysis'.
    """
    try:
        linear_op = WaveletN(
            nb_scale=nb_scale,
            wavelet_name=wavelet_name,
            dim=dimension,
            n_coils=n_coils,
            n_jobs=n_jobs,
            verbose=verbose,
        )
    except ValueError:
        # TODO this is a hack and we need to have a separate WaveletUD2.
        # For undecimated wavelets, wavelet_name is actually a wavelet_id.
        linear_op = WaveletUD2(
            wavelet_id=wavelet_name,
            nb_scale=nb_scale,
            n_coils=n_coils,
            n_jobs=n_jobs,
            verbose=verbose,
        )

    if gradient_formulation == 'synthesis':
        regularizer_op = SparseThreshold(Identity(), 0, thresh_type="soft")
    elif gradient_formulation == "analysis":
        regularizer_op = SparseThreshold(linear_op, 0, thresh_type="soft")
    else:
        # Previously this fell through and failed on the return statement
        # with UnboundLocalError; report the bad argument explicitly instead.
        raise ValueError(
            "Unsupported gradient formulation '{0}': expected 'synthesis' "
            "or 'analysis'.".format(gradient_formulation))
    return linear_op, regularizer_op
class simplex_threshold(object):
    """Simplex Threshold proximity operator.

    Stacks two proximity operators: a ``Simplex`` projection, applied to the
    first entry of the input, and a ``SparseThreshold``, applied to the
    second entry.

    Calls:

    * :func:`proxs.Simplex`
    """

    def __init__(self, linop, weights, mass=None, pos_en=False):
        self.linop = linop
        self.weights = weights
        # The threshold operator is built from the stored linop / weights.
        self.thresh = SparseThreshold(linop, weights)
        self.simplex = Simplex(mass=mass, pos_en=pos_en)

    def update_weights(self, weights):
        """Update weights

        Store new thresholding weights and rebuild the ``SparseThreshold``
        operator so that it uses them.

        Parameters
        ----------
        weights : np.ndarray
            Input array of weights
        """
        self.weights = weights
        self.thresh = SparseThreshold(self.linop, self.weights)

    def op(self, data, extra_factor=1.0):
        """Apply the simplex prox to ``data[0]`` and the threshold to
        ``data[1]`` (the latter scaled by ``extra_factor``), and stack both
        results with ``np.array``."""
        simplex_part = self.simplex.op(data[0])
        thresh_part = self.thresh.op(data[1], extra_factor=extra_factor)
        return np.array([simplex_part, thresh_part])
def set_objective(self, X, y, lmbd):
    """Store the Lasso problem data and build a ModOpt ForwardBackward solver.

    The solver starts from a zero coefficient vector, uses ``X @ w`` /
    ``X.T @ res`` as forward / adjoint operators, and soft-thresholds with
    ``lmbd``. Restart parameters depend on ``self.restart_strategy``.
    """
    self.X, self.y, self.lmbd = X, y, lmbd
    n_features = self.X.shape[1]

    # Parameter sets selected by the restart strategy.
    if self.restart_strategy != 'greedy':
        min_beta, s_greedy = None, None
        p_lazy, q_lazy = 1 / 30, 1 / 10
    else:
        min_beta, s_greedy = 1.0, 1.1
        p_lazy, q_lazy = 1.0, 1.0

    coefficients_init = np.zeros(n_features)  # this is the coefficient w
    data_fidelity = GradBasic(
        input_data=y,
        op=lambda w: self.X @ w,
        trans_op=lambda res: self.X.T @ res,
    )
    self.fb = ForwardBackward(
        x=coefficients_init,
        grad=data_fidelity,
        prox=SparseThreshold(Identity(), lmbd),
        beta_param=1.0,
        min_beta=min_beta,
        metric_call_period=None,
        restart_strategy=self.restart_strategy,
        xi_restart=0.96,
        s_greedy=s_greedy,
        p_lazy=p_lazy,
        q_lazy=q_lazy,
        auto_iterate=False,
        progress=False,
        cost=None,
    )
def update_weights(self, weights):
    """Update weights

    Replace the stored thresholding weights and rebuild ``self.thresh`` as a
    ``SparseThreshold`` over ``self.linop`` with those weights.

    Parameters
    ----------
    weights : np.ndarray
        Input array of weights
    """
    self.weights = weights
    self.thresh = SparseThreshold(self.linop, self.weights)
def reco_wav(kspace, gradient_op, mu=1 * 1e-8, max_iter=10, nb_scales=4,
             wavelet_name='db4'):
    """Reconstruct an image from k-space with wavelet-regularized FISTA.

    NOTE: per the original author, this currently requires a fork of
    pysap-fastMRI / a specific pysap-mri install.

    Parameters
    ----------
    kspace: np.ndarray
        Observed k-space data; assigned to ``gradient_op.obs_data``
        (mutates ``gradient_op``).
    gradient_op:
        Gradient operator handed to ``sparse_rec_fista``.
    mu: float, default 1e-8
        Regularization weight.
    max_iter: int, default 10
        Maximum number of FISTA iterations.
    nb_scales: int, default 4
        Number of scales of the decimated wavelet.
    wavelet_name: str, default 'db4'
        Wavelet name.

    Returns
    -------
    np.ndarray
        Magnitude of the reconstruction, passed through
        ``crop_center(..., 320)``.
    """
    from ..wavelets import WaveletDecimated
    from mri.numerics.reconstruct import sparse_rec_fista

    wavelet = WaveletDecimated(
        nb_scale=nb_scales,
        wavelet_name=wavelet_name,
        padding='periodization',
    )
    # Wrap the soft threshold together with the wavelet in a
    # LinearCompositionProx; the solver itself gets Identity() as linear_op.
    composed_prox = LinearCompositionProx(
        linear_op=wavelet,
        prox_op=SparseThreshold(Identity(), None, thresh_type="soft"),
    )
    gradient_op.obs_data = kspace

    estimate, _, _, _ = sparse_rec_fista(
        gradient_op=gradient_op,
        linear_op=Identity(),
        prox_op=composed_prox,
        cost_op=None,
        xi_restart=0.96,
        s_greedy=1.1,
        mu=mu,
        restart_strategy='greedy',
        pov='analysis',
        max_nb_of_iter=max_iter,
        metrics=None,
        metric_call_period=1,
        verbose=0,
        progress=False,
    )
    return crop_center(np.abs(estimate), 320)
def set_objective(self, X, y, lmbd):
    """Store the Lasso problem data and build a ModOpt POGM solver.

    All four POGM state variables start from the same zero vector of length
    ``X.shape[1]``; the prox soft-thresholds with ``lmbd``.
    """
    self.X, self.y, self.lmbd = X, y, lmbd
    zeros = np.zeros(self.X.shape[1])  # initial coefficient vector w

    # NOTE(review): recent ModOpt releases name GradBasic's data argument
    # ``input_data``; confirm ``data=`` matches the installed version.
    data_fidelity = GradBasic(
        op=lambda w: self.X @ w,
        trans_op=lambda res: self.X.T @ res,
        data=y,
    )
    self.pogm = POGM(
        x=zeros,
        u=zeros,
        y=zeros,
        z=zeros,
        grad=data_fidelity,
        prox=SparseThreshold(Identity(), lmbd),
        beta_param=1.0,
        metric_call_period=None,
        sigma_bar=0.96,
        auto_iterate=False,
        progress=False,
        cost=None,
    )
def generate_operators(data, wavelet_name, samples, nb_scales=4,
                       non_cartesian=False, uniform_data_shape=None,
                       gradient_space="analysis"):
    """ Create a set of operators commonly used for a reconstruction.

    .. note:: At the moment, supports only 2D data.

    Parameters
    ----------
    data: ndarray
        the data to reconstruct: observations are expected in Fourier space.
    wavelet_name: str
        the wavelet name to be used during the decomposition.
    samples: np.ndarray
        the mask samples in the Fourier domain.
    nb_scales: int, default 4
        the number of scales in the wavelet decomposition.
    non_cartesian: bool (optional, default False)
        if set, use the nfftw rather than the fftw. Expects a 1D input
        dataset.
    uniform_data_shape: tuple (optional, default None)
        the shape of the matrix containing the uniform data. Only required
        for non-cartesian reconstructions.
    gradient_space: str (optional, default 'analysis')
        the space where the gradient operator is defined: 'analysis' or
        'synthesis'.

    Returns
    -------
    gradient_op: instance of class GradBase
        ``GradSynthesis2`` when ``gradient_space`` is 'synthesis',
        ``GradAnalysis2`` otherwise.
    linear_op: instance of LinearBase
        the linear operator used to seek sparsity (a ``Wavelet2``).
    prox_op: instance of ProximityParent
        a soft-thresholding ``SparseThreshold`` with ``None`` weights.
    cost_op: instance of costObj or None
        a ``DualGapCost`` in the analysis case, ``None`` in the synthesis
        case.

    Raises
    ------
    ValueError
        if ``gradient_space`` is unsupported, if the data dimensionality
        does not match the cartesian / non-cartesian option, or if
        ``uniform_data_shape`` is missing for a non-cartesian
        reconstruction.
    """
    # Local imports
    from mri.numerics.cost import DualGapCost
    from mri.numerics.linear import Wavelet2
    from mri.numerics.fourier import FFT2
    from mri.numerics.fourier import NFFT
    from mri.numerics.gradient import GradAnalysis2
    from mri.numerics.gradient import GradSynthesis2
    from modopt.opt.proximity import SparseThreshold

    # Check input parameters
    if gradient_space not in ("analysis", "synthesis"):
        raise ValueError(
            "Unsupported gradient space '{0}'.".format(gradient_space))
    if non_cartesian and data.ndim != 1:
        raise ValueError("Expect 1D data with the non-cartesian option.")
    elif non_cartesian and uniform_data_shape is None:
        raise ValueError("Need to set the 'uniform_data_shape' parameter with "
                         "the non-cartesian option.")
    elif not non_cartesian and data.ndim != 2:
        raise ValueError("At the moment, this function only supports 2D "
                         "data.")

    # Define the gradient/linear/fourier operators
    linear_op = Wavelet2(nb_scale=nb_scales, wavelet_name=wavelet_name)
    if non_cartesian:
        fourier_op = NFFT(samples=samples, shape=uniform_data_shape)
    else:
        fourier_op = FFT2(samples=samples, shape=data.shape)
    if gradient_space == "synthesis":
        gradient_op = GradSynthesis2(data=data, linear_op=linear_op,
                                     fourier_op=fourier_op)
    else:
        gradient_op = GradAnalysis2(data=data, fourier_op=fourier_op)

    # Define the proximity dual/primal operator (weights set later by caller)
    prox_op = SparseThreshold(linear_op, None, thresh_type="soft")

    # Define the cost function: only the analysis formulation gets one
    if gradient_space == "synthesis":
        cost_op = None
    else:
        cost_op = DualGapCost(linear_op=linear_op, initial_cost=1e6,
                              tolerance=1e-4, cost_interval=1, test_range=4,
                              verbose=0, plot_output=None)

    return gradient_op, linear_op, prox_op, cost_op
#############################################################################
# FISTA optimization
# ------------------
#
# We now want to refine the zero order solution using a FISTA optimization.
# The cost function is set to Proximity Cost + Gradient Cost

# Setup the operators
# ``nb_scale`` (singular) is the keyword this file uses for WaveletN in its
# other calls; the previous ``nb_scales`` spelling did not match it and was
# presumably forwarded as an unrecognised extra keyword.
linear_op = WaveletN(
    wavelet_name="sym8",
    nb_scale=4,
    dim=3,
    padding_mode="periodization",
)
regularizer_op = SparseThreshold(Identity(), 2 * 1e-11, thresh_type="soft")
# Setup Reconstructor
reconstructor = SingleChannelReconstructor(
    fourier_op=fourier_op,
    linear_op=linear_op,
    regularizer_op=regularizer_op,
    gradient_formulation='synthesis',
    verbose=1,
)
# Start Reconstruction
x_final, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_data,
    optimization_alg='fista',
    num_iterations=200,
)
image_rec = pysap.Image(data=np.abs(x_final))
def __init__(self, linop, weights, mass=None, pos_en=False):
    """Store the linear operator and weights, then build the two stacked
    proximity operators: a ``SparseThreshold`` over ``linop`` with
    ``weights`` and a ``Simplex`` with ``mass`` / ``pos_en``."""
    self.linop, self.weights = linop, weights
    self.thresh = SparseThreshold(linop, weights)
    self.simplex = Simplex(mass=mass, pos_en=pos_en)
base_ssim = ssim(image_rec0, image)
print('The Base SSIM is : ' + str(base_ssim))

#############################################################################
# FISTA optimization
# ------------------
#
# Refine the zero-order solution with a FISTA optimization. The cost
# function is Proximity Cost + Gradient Cost.

# Sparsifying wavelet and its soft-threshold regularizer
linear_op = WaveletN(
    nb_scale=4,
    wavelet_name='sym8',
)
regularizer_op = SparseThreshold(Identity(), 1.5e-8, thresh_type="soft")

# Reconstructor that also estimates coil sensitivities from a portion of
# k-space (kspace_portion=0.01)
reconstructor = SelfCalibrationReconstructor(
    kspace_portion=0.01,
    gradient_formulation='synthesis',
    regularizer_op=regularizer_op,
    linear_op=linear_op,
    fourier_op=fourier_op,
    verbose=1,
)
x_final, costs, metrics = reconstructor.reconstruct(
    num_iterations=10,
    optimization_alg='fista',
    kspace_data=kspace_obs,
)
image_rec = pysap.Image(data=x_final)
def sparse_deconv_condatvu(data, psf, n_iter=300, n_reweights=1):
    """Sparse Deconvolution with Condat-Vu

    Deconvolve ``data`` by ``psf`` with the Condat-Vu primal-dual algorithm,
    using a positivity constraint in the primal domain and a weighted sparse
    threshold on à-trous wavelet coefficients, followed by reweighting
    passes.

    Parameters
    ----------
    data : np.ndarray
        Input data, 2D image
    psf : np.ndarray
        Input PSF, 2D image
    n_iter : int, optional
        Maximum number of iterations (per run, including each reweighting)
    n_reweights : int, optional
        Number of reweightings

    Returns
    -------
    np.ndarray
        deconvolved image
    """
    # Print the algorithm set-up
    print(condatvu_logo())

    wavelet_filters = get_cospy_filters(
        data.shape, transform_name='LinearWaveletTransformATrousAlgorithm')
    reweighter = cwbReweight(get_weights(data, psf, wavelet_filters))

    # Both primal and dual variables start at one
    x_start = np.ones(data.shape)
    y_start = np.ones(wavelet_filters.shape)

    def forward(x):
        return psf_convolve(x, psf)

    def adjoint(x):
        return psf_convolve(x, psf, psf_rot=True)

    gradient = GradBasic(data, forward, adjoint)
    wavelet = WaveletConvolve2(wavelet_filters)
    positivity = Positivity()
    # The threshold shares reweighter.weights, so reweighting updates it
    # through that array (assuming reweight modifies it in place - verify).
    threshold = SparseThreshold(wavelet, reweighter.weights)
    cost = costObj([gradient, positivity, threshold], tolerance=1e-6,
                   cost_interval=1, plot_output=True, verbose=False)

    solver = Condat(x_start, y_start, gradient, positivity, threshold, wavelet,
                    cost, rho=0.8, sigma=0.5, tau=0.5, auto_iterate=False)
    solver.iterate(max_iter=n_iter)

    # Reweighting passes, each followed by a further optimisation run
    for pass_number in range(1, n_reweights + 1):
        print(' - Reweighting: {}'.format(pass_number))
        reweighter.reweight(wavelet.op(solver.x_final))
        solver.iterate(max_iter=n_iter)

    return solver.x_final
def sparse_rec_fista(data, wavelet_name, samples, mu, nb_scales=4,
                     lambda_init=1.0, max_nb_of_iter=300, atol=1e-4,
                     non_cartesian=False, uniform_data_shape=None, verbose=0):
    """ The FISTA sparse reconstruction without reweightings.

    .. note:: At the moment, supports only 2D data.

    Parameters
    ----------
    data: ndarray
        the data to reconstruct (observations are expected in the Fourier
        space).
    wavelet_name: str
        the wavelet name to be used during the decomposition.
    samples: np.ndarray
        the mask samples in the Fourier domain.
    mu: float
        coefficient of regularization (uniform soft-threshold weight).
    nb_scales: int, default 4
        the number of scales in the wavelet decomposition.
    lambda_init: float, (default 1.0)
        initial value for the FISTA step, forwarded to ``ForwardBackward``
        as ``lambda_param``.
    max_nb_of_iter: int (optional, default 300)
        the maximum number of iterations of the FISTA (forward-backward)
        algorithm.
    atol: float (optional, default 1e-4)
        tolerance threshold for convergence.
        NOTE(review): not currently used by this function.
    non_cartesian: bool (optional, default False)
        if set, use the nfftw rather than the fftw. Expects a 1D input
        dataset.
    uniform_data_shape: tuple (optional, default None)
        the shape of the matrix containing the uniform data. Only required
        for non-cartesian reconstructions.
    verbose: int (optional, default 0)
        the verbosity level.

    Returns
    -------
    x_final: ndarray
        the estimated FISTA solution (image domain, obtained by applying
        the wavelet adjoint to the optimized coefficients).
    transform: a WaveletTransformBase derived instance
        the wavelet transformation instance.

    Raises
    ------
    ValueError
        if the data dimensionality does not match the cartesian /
        non-cartesian option, or if ``uniform_data_shape`` is missing for a
        non-cartesian reconstruction.
    """
    # Check inputs. time.clock() was removed in Python 3.8; perf_counter is
    # the portable timer for measuring elapsed time.
    start = time.perf_counter()
    if non_cartesian and data.ndim != 1:
        raise ValueError("Expect 1D data with the non-cartesian option.")
    elif non_cartesian and uniform_data_shape is None:
        raise ValueError("Need to set the 'uniform_data_shape' parameter with "
                         "the non-cartesian option.")
    elif not non_cartesian and data.ndim != 2:
        raise ValueError("At the moment, this function only supports 2D "
                         "data.")

    # Define the gradient/linear/fourier operators
    linear_op = Wavelet2(
        nb_scale=nb_scales,
        wavelet_name=wavelet_name)
    if non_cartesian:
        fourier_op = NFFT2(
            samples=samples,
            shape=uniform_data_shape)
    else:
        fourier_op = FFT2(
            samples=samples,
            shape=data.shape)
    gradient_op = GradSynthesis2(
        data=data,
        linear_op=linear_op,
        fourier_op=fourier_op)

    # Define the initial primal and dual solutions. The builtin ``complex``
    # replaces ``np.complex`` (an alias of it removed in NumPy 1.24).
    x_init = np.zeros(fourier_op.shape, dtype=complex)
    alpha = linear_op.op(x_init)
    alpha[...] = 0.0

    # Welcome message
    if verbose > 0:
        print(fista_logo())
        print(" - mu: ", mu)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - data: ", data.shape)
        print(" - wavelet: ", wavelet_name, "-", nb_scales)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - image variable shape: ", x_init.shape)
        print(" - alpha variable shape: ", alpha.shape)
        print("-" * 40)

    # Define the proximity dual operator: uniform weight ``mu`` on every
    # wavelet coefficient
    weights = copy.deepcopy(alpha)
    weights[...] = mu
    prox_op = SparseThreshold(linear_op, weights, thresh_type="soft")

    # Define the optimizer; lambda_init was previously documented but never
    # passed on.
    cost_op = None
    opt = ForwardBackward(
        x=alpha,
        grad=gradient_op,
        prox=prox_op,
        cost=cost_op,
        lambda_param=lambda_init,
        auto_iterate=False)

    # Perform the reconstruction
    if verbose > 0:
        print("Starting optimization...")
    opt.iterate(max_iter=max_nb_of_iter)
    # Stop the timer only after the optimization has run, so the reported
    # execution time includes it.
    end = time.perf_counter()
    if verbose > 0:
        print(" - converged: ", opt.converge)
        print("Done.")
        print("Execution time: ", end - start, " seconds")
        print("-" * 40)
    x_final = linear_op.adj_op(opt.x_final)

    return x_final, linear_op.transform
def sparse_rec_condatvu(data, wavelet_name, samples, nb_scales=4,
                        std_est=None, std_est_method=None, std_thr=2.,
                        mu=1e-6, tau=None, sigma=None, relaxation_factor=1.0,
                        nb_of_reweights=1, max_nb_of_iter=150,
                        add_positivity=False, atol=1e-4, non_cartesian=False,
                        uniform_data_shape=None, verbose=0):
    """ The Condat-Vu sparse reconstruction with reweightings.

    .. note:: At the moment, supports only 2D data.

    Parameters
    ----------
    data: ndarray
        the data to reconstruct: observations are expected in Fourier space.
    wavelet_name: str
        the wavelet name to be used during the decomposition.
    samples: np.ndarray
        the mask samples in the Fourier domain.
    nb_scales: int, default 4
        the number of scales in the wavelet decomposition.
    std_est: float, default None
        the noise std estimate. If None, it is estimated according to
        ``std_est_method`` ('primal': MAD of ``gradient_op.MtX(data)``;
        'dual': starts at 0.0 and is updated by the reweighting).
    std_est_method: str, default None
        if the standard deviation is not set, estimate this parameter using
        the mad routine in the image ('primal') or in the sparse wavelet
        decomposition ('dual') domain. If None, use the manual
        regularization ``mu`` without reweighting.
    std_thr: float, default 2.
        use this threshold expressed as a number of sigma in the residual
        proximity operator during the thresholding.
    mu: float, default 1e-6
        regularization hyperparameter (used only when ``std_est_method`` is
        None).
    tau, sigma: float, default None
        parameters of the Condat-Vu proximal-dual splitting algorithm.
        If None, ``sigma`` defaults to 0.5 and ``tau`` is derived from the
        convergence bound.
    relaxation_factor: float, default 1.0
        parameter of the Condat-Vu proximal-dual splitting algorithm.
        If 1, no relaxation.
    nb_of_reweights: int, default 1
        the number of reweightings (forced to 0 when ``std_est_method`` is
        None).
    max_nb_of_iter: int, default 150
        the maximum number of iterations in the Condat-Vu proximal-dual
        splitting algorithm (per run, including each reweighting).
    add_positivity: bool, default False
        by setting this option, set the proximity operator to identity or
        positive.
    atol: float, default 1e-4
        tolerance threshold for convergence.
        NOTE(review): not currently used by this function.
    non_cartesian: bool (optional, default False)
        if set, use the nfftw rather than the fftw. Expects a 1D input
        dataset.
    uniform_data_shape: tuple (optional, default None)
        the shape of the matrix containing the uniform data. Only required
        for non-cartesian reconstructions.
    verbose: int, default 0
        the verbosity level.

    Returns
    -------
    x_final: ndarray
        the estimated CONDAT-VU solution.
    transform: a WaveletTransformBase derived instance
        the wavelet transformation instance, with ``analysis_data`` set from
        the final dual variable.

    Raises
    ------
    ValueError
        if the data dimensionality does not match the cartesian /
        non-cartesian option, if ``uniform_data_shape`` is missing for a
        non-cartesian reconstruction, or if ``std_est_method`` is not one of
        None, 'primal', 'dual'.
    """
    # Check inputs. time.clock() was removed in Python 3.8; perf_counter is
    # the portable timer for measuring elapsed time.
    start = time.perf_counter()
    if non_cartesian and data.ndim != 1:
        raise ValueError("Expect 1D data with the non-cartesian option.")
    elif non_cartesian and uniform_data_shape is None:
        raise ValueError("Need to set the 'uniform_data_shape' parameter with "
                         "the non-cartesian option.")
    elif not non_cartesian and data.ndim != 2:
        raise ValueError("At the moment, this function only supports 2D "
                         "data.")
    if std_est_method not in (None, "primal", "dual"):
        raise ValueError(
            "Unrecognized std estimation method '{0}'.".format(
                std_est_method))

    # Define the gradient/linear/fourier operators
    linear_op = Wavelet2(
        nb_scale=nb_scales,
        wavelet_name=wavelet_name)
    if non_cartesian:
        data_shape = uniform_data_shape
        fourier_op = NFFT2(
            samples=samples,
            shape=uniform_data_shape)
    else:
        data_shape = data.shape
        fourier_op = FFT2(
            samples=samples,
            shape=data.shape)
    gradient_op = GradAnalysis2(
        data=data,
        fourier_op=fourier_op)

    # Allocate an array shaped like the wavelet coefficients to hold the
    # thresholding weights. The builtin ``complex`` replaces ``np.complex``
    # (an alias of it removed in NumPy 1.24).
    x_init = np.zeros(fourier_op.shape, dtype=complex)
    weights = linear_op.op(x_init)

    # Define the weights used during the thresholding in the dual domain,
    # the reweighting strategy, and the prox dual operator
    # Case1: estimate the noise std in the image domain
    if std_est_method == "primal":
        if std_est is None:
            std_est = sigma_mad(gradient_op.MtX(data))
        weights[...] = std_thr * std_est
        reweight_op = cwbReweight(weights)
        prox_dual_op = SparseThreshold(linear_op, reweight_op.weights,
                                       thresh_type="soft")

    # Case2: estimate the noise std in the sparse wavelet domain
    elif std_est_method == "dual":
        if std_est is None:
            std_est = 0.0
        weights[...] = std_thr * std_est
        reweight_op = mReweight(weights, linear_op, thresh_factor=std_thr)
        prox_dual_op = SparseThreshold(linear_op, weights=reweight_op.weights,
                                       thresh_type="soft")

    # Case3: manual regularization mode, no reweighting
    else:
        weights[...] = mu
        reweight_op = None
        prox_dual_op = SparseThreshold(linear_op, weights, thresh_type="soft")
        nb_of_reweights = 0

    # Define the Condat Vu optimizer: define the tau and sigma in the
    # Condat-Vu proximal-dual splitting algorithm if not already provided.
    # Check also that the combination of values will lead to convergence.
    norm = linear_op.l2norm(data_shape)
    lipschitz_cst = gradient_op.spec_rad
    if sigma is None:
        sigma = 0.5
    if tau is None:
        # to avoid numerical troubles with the convergence bound
        eps = 1.0e-8
        # due to the convergence bound
        tau = 1.0 / (lipschitz_cst / 2 + sigma * norm ** 2 + eps)
    convergence_test = (
        1.0 / tau - sigma * norm ** 2 >= lipschitz_cst / 2.0)

    # Define initial primal and dual solutions
    primal = np.zeros(data_shape, dtype=complex)
    dual = linear_op.op(primal)
    dual[...] = 0.0

    # Welcome message
    if verbose > 0:
        print(condatvu_logo())
        print(" - mu: ", mu)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - tau: ", tau)
        print(" - sigma: ", sigma)
        print(" - rho: ", relaxation_factor)
        print(" - std: ", std_est)
        print(" - 1/tau - sigma||L||^2 >= beta/2: ", convergence_test)
        print(" - data: ", data.shape)
        print(" - wavelet: ", wavelet_name, "-", nb_scales)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - number of reweights: ", nb_of_reweights)
        print(" - primal variable shape: ", primal.shape)
        print(" - dual variable shape: ", dual.shape)
        print("-" * 40)

    # Define the proximity operator
    if add_positivity:
        prox_op = Positive()
    else:
        prox_op = Identity()

    # Define the cost function
    cost_op = DualGapCost(
        linear_op=linear_op,
        initial_cost=1e6,
        tolerance=1e-4,
        cost_interval=1,
        test_range=4,
        verbose=0,
        plot_output=None)

    # Define the optimizer
    opt = Condat(
        x=primal,
        y=dual,
        grad=gradient_op,
        prox=prox_op,
        prox_dual=prox_dual_op,
        linear=linear_op,
        cost=cost_op,
        rho=relaxation_factor,
        sigma=sigma,
        tau=tau,
        rho_update=None,
        sigma_update=None,
        tau_update=None,
        auto_iterate=False)

    # Perform the first reconstruction
    if verbose > 0:
        print("Starting optimization...")
    opt.iterate(max_iter=max_nb_of_iter)

    # Loop through the number of reweightings
    for reweight_index in range(nb_of_reweights):

        # Generate the new weights following reweighting prescription
        if std_est_method == "primal":
            reweight_op.reweight(linear_op.op(opt._x_new))
        else:
            std_est = reweight_op.reweight(opt._x_new)

        # Welcome message
        if verbose > 0:
            print(" - reweight: ", reweight_index + 1)
            print(" - std: ", std_est)

        # Update the weights in the dual proximity operator
        prox_dual_op.weights = reweight_op.weights

        # Perform optimisation with new weights
        opt.iterate(max_iter=max_nb_of_iter)

    # Goodbye message
    end = time.perf_counter()
    if verbose > 0:
        print(" - final iteration number: ", cost_op._iteration)
        print(" - final cost value: ", cost_op.cost)
        print(" - converged: ", opt.converge)
        print("Done.")
        print("Execution time: ", end - start, " seconds")
        print("-" * 40)

    # Get the final solution
    x_final = opt.x_final
    linear_op.transform.analysis_data = unflatten(
        opt.y_final, linear_op.coeffs_shape)

    return x_final, linear_op.transform