def __init__(self, u, x, y, z, grad, prox, cost='auto', linear=None,
             beta_param=1.0, sigma_bar=1.0, auto_iterate=True,
             metric_call_period=5, metrics=None, **kwargs):
    """Initialise the POGM solver.

    Parameters
    ----------
    u, x, y, z : numpy.ndarray
        Initial values of the algorithm variables; each one is checked and
        copied, so the caller's arrays are never modified in place.
    grad : class instance
        Gradient operator; ``get_grad`` is called once here on ``x`` and the
        result is read back from its ``grad`` attribute.
    prox : class instance
        Proximity operator.
    cost : class instance, str or None, optional
        Cost function (default is ``'auto'``, which builds
        ``costObj([grad, prox])``).
    linear : class instance, optional
        Linear operator forwarded to the parent class; replaced by
        ``Identity()`` when ``None`` (default is ``None``).
    beta_param : float, optional
        Initial step size (default is ``1.0``); a truthy ``self.step_size``
        configured on the parent takes precedence.
    sigma_bar : float, optional
        Shrinking parameter, must lie in [0, 1] (default is ``1.0``).
    auto_iterate : bool, optional
        Call ``iterate()`` at the end of initialisation (default ``True``).
    metric_call_period : int, optional
        Forwarded to the parent class (default is ``5``).
    metrics : dict, optional
        Forwarded to the parent class; ``None`` (the default) is replaced by
        a new empty dict, so no dict instance is shared between solvers.

    Raises
    ------
    ValueError
        If ``sigma_bar`` is outside [0, 1].

    """
    # A fresh dict per call instead of a shared mutable default argument.
    if metrics is None:
        metrics = {}

    # Set default algorithm properties
    super(POGM, self).__init__(metric_call_period=metric_call_period,
                               metrics=metrics, linear=linear, **kwargs)

    # Set the initial variable values.  These checks used to be written as
    # bare generator expressions, which are lazy and were never consumed, so
    # the checks silently never ran; explicit loops actually execute them.
    for data in (u, x, y, z):
        self._check_input_data(data)

    self._u_old = np.copy(u)
    self._x_old = np.copy(x)
    self._y_old = np.copy(y)
    self._z = np.copy(z)

    # Set the algorithm operators
    for operator in (grad, prox, cost):
        self._check_operator(operator)

    self._grad = grad
    self._prox = prox
    self._linear = linear

    if cost == 'auto':
        self._cost_func = costObj([self._grad, self._prox])
    else:
        self._cost_func = cost

    # If linear is None, make it Identity for call of metrics
    if self._linear is None:
        self._linear = Identity()

    # Set the algorithm parameters
    for param in (beta_param, sigma_bar):
        self._check_param(param)
    if not (0 <= sigma_bar <= 1):
        raise ValueError('The sigma bar parameter needs to be in [0, 1]')

    self._beta = self.step_size or beta_param
    self._sigma_bar = sigma_bar
    self._xi = self._sigma = self._t_old = 1.0

    # Prime the gradient at the starting point; the first update needs it.
    self._grad.get_grad(self._x_old)
    self._g_old = self._grad.grad

    # Automatically run the algorithm
    if auto_iterate:
        self.iterate()
def get_operators(
    kspace_data,
    loc,
    mask,
    fourier_type=1,
    max_iter=80,
    regularisation=None,
    linear=None,
):
    """Create the various operators from the config file.

    Parameters
    ----------
    kspace_data : numpy.ndarray
        k-space data. A 2-D array is treated as single-coil; otherwise the
        first axis is taken as the coil axis. The last two axes give the
        image shape.
    loc : object
        Column locations, passed as ``mask_cols`` to the online k-space
        generators (``fourier_type`` 1 and 2).
    mask : object
        Sampling mask used by the offline generator and by ``FFT``.
    fourier_type : int, optional
        0 = offline, 1 = online type I, 2 = online type II (default is 1).
    max_iter : int, optional
        Only used by the offline k-space generator (default is 80).
    regularisation : dict, optional
        Config with a ``"class"`` key (``"LASSO"``, ``"GroupLASSO"``,
        ``"OWL"`` or ``"IdentityProx"``) plus that operator's parameters.
        The caller's dict is not modified.
    linear : dict, optional
        Config with a ``"class"`` key (``"WaveletN"`` or ``"Identity"``);
        remaining keys are forwarded to ``WaveletN``. The caller's dict is
        not modified.

    Returns
    -------
    tuple
        ``(kspace_generator, fourier_op, linear_op, prox_op)``

    Raises
    ------
    NotImplementedError
        For an unknown ``fourier_type``, linear class or regularisation class.

    """
    n_coils = 1 if kspace_data.ndim == 2 else kspace_data.shape[0]
    shape = kspace_data.shape[-2:]
    if fourier_type == 0:  # offline reconstruction
        kspace_generator = KspaceGeneratorBase(
            full_kspace=kspace_data, mask=mask, max_iter=max_iter)
        fourier_op = FFT(shape=shape, n_coils=n_coils, mask=mask)
    elif fourier_type == 1:  # online type I reconstruction
        kspace_generator = Column2DKspaceGenerator(
            full_kspace=kspace_data, mask_cols=loc)
        fourier_op = FFT(shape=shape, n_coils=n_coils, mask=mask)
    elif fourier_type == 2:  # online type II reconstruction
        kspace_generator = DataOnlyKspaceGenerator(
            full_kspace=kspace_data, mask_cols=loc)
        fourier_op = ColumnFFT(shape=shape, n_coils=n_coils)
    else:
        raise NotImplementedError(
            "Unknown fourier_type {!r}".format(fourier_type))

    if linear is None:
        linear_op = Identity()
    else:
        # Work on a copy: popping "class" from the caller's dict made a second
        # call with the same config fail with NotImplementedError.
        linear = dict(linear)
        lin_cls = linear.pop("class", None)
        if lin_cls == "WaveletN":
            linear_op = WaveletN(n_coils=n_coils, n_jobs=4, **linear)
            # Apply once on zeros; presumably this initialises
            # `coeffs_shape`, which the OWL branch below reads — confirm.
            linear_op.op(np.zeros_like(kspace_data))
        elif lin_cls == "Identity":
            linear_op = Identity()
        else:
            raise NotImplementedError(
                "Unknown linear operator class {!r}".format(lin_cls))

    prox_op = IdentityProx()
    if regularisation is not None:
        # Same reasoning as above: never mutate the caller's config.
        regularisation = dict(regularisation)
        reg_cls = regularisation.pop("class")
        # A single elif chain: the previous `if ... if ... elif` layout split
        # the dispatch into two independent statements.
        if reg_cls == "LASSO":
            prox_op = LASSO(weights=regularisation["weights"])
        elif reg_cls == "GroupLASSO":
            prox_op = GroupLASSO(weights=regularisation["weights"])
        elif reg_cls == "OWL":
            prox_op = OWL(**regularisation,
                          n_coils=n_coils,
                          bands_shape=linear_op.coeffs_shape)
        elif reg_cls == "IdentityProx":
            prox_op = IdentityProx()
            linear_op = Identity()
        else:
            # Previously an unknown (e.g. misspelled) class silently fell back
            # to IdentityProx; fail loudly, consistent with the linear branch.
            raise NotImplementedError(
                "Unknown regularisation class {!r}".format(reg_cls))
    return kspace_generator, fourier_op, linear_op, prox_op
def set_objective(self, X, y, lmbd):
    """Store the problem data and build the ForwardBackward solver.

    The solver starts from an all-zero coefficient vector and is created with
    ``auto_iterate=False``, so no iteration happens here. Restart tuning
    constants depend on ``self.restart_strategy`` (greedy vs. any other).
    """
    self.X, self.y, self.lmbd = X, y, lmbd
    n_features = self.X.shape[1]

    is_greedy = self.restart_strategy == 'greedy'
    min_beta = 1.0 if is_greedy else None
    s_greedy = 1.1 if is_greedy else None
    p_lazy = 1.0 if is_greedy else 1 / 30
    q_lazy = 1.0 if is_greedy else 1 / 10

    # Forward and adjoint operators; both read self.X lazily at call time.
    def forward(coef):
        return self.X @ coef

    def adjoint(residual):
        return self.X.T @ residual

    self.fb = ForwardBackward(
        x=np.zeros(n_features),  # this is the coefficient w
        grad=GradBasic(input_data=y, op=forward, trans_op=adjoint),
        prox=SparseThreshold(Identity(), lmbd),
        beta_param=1.0,
        min_beta=min_beta,
        metric_call_period=None,
        restart_strategy=self.restart_strategy,
        xi_restart=0.96,
        s_greedy=s_greedy,
        p_lazy=p_lazy,
        q_lazy=q_lazy,
        auto_iterate=False,
        progress=False,
        cost=None,
    )
def get_linear_n_regularization_operator(wavelet_name, image_shape,
                                         dimension=2, nb_scale=3, n_coils=1,
                                         n_jobs=1, verbose=0):
    """Build a wavelet linear operator and a matching soft-threshold prox.

    A ``WaveletN`` is tried first; if its construction raises ``ValueError``
    the name is reinterpreted as an undecimated-wavelet id for ``WaveletUD2``.
    The operator is applied once to a zero image of shape
    ``(n_coils, *image_shape)`` (squeezed) before its ``coeffs_shape`` is used
    to build a ``WeightedSparseThreshold`` with zero weights.

    Returns
    -------
    tuple
        ``(linear_op, regularizer_op)``
    """
    shared_kwargs = {
        'nb_scale': nb_scale,
        'n_coils': n_coils,
        'n_jobs': n_jobs,
        'verbose': verbose,
    }
    try:
        linear_op = WaveletN(wavelet_name=wavelet_name, dim=dimension,
                             **shared_kwargs)
    except ValueError:
        # TODO this is a hack and we need to have a separate WaveletUD2.
        # For Undecimated wavelets, the wavelet_name is wavelet_id
        linear_op = WaveletUD2(wavelet_id=wavelet_name, **shared_kwargs)

    # Presumably needed so that `coeffs_shape` is populated — confirm.
    blank_image = np.squeeze(np.zeros((n_coils, *image_shape)))
    linear_op.op(blank_image)

    regularizer_op = WeightedSparseThreshold(
        linear=Identity(),
        weights=0,
        coeffs_shape=linear_op.coeffs_shape,
        thresh_type="soft",
    )
    return linear_op, regularizer_op
def get_linear_n_regularization_operator(self, gradient_formulation,
                                         wavelet_name, dimension=2,
                                         nb_scale=3, n_coils=1, n_jobs=1,
                                         verbose=0):
    """Build a wavelet linear operator and a soft-threshold regulariser.

    Parameters
    ----------
    gradient_formulation : str
        ``'synthesis'`` (threshold applied through ``Identity()``) or
        ``'analysis'`` (threshold applied through the wavelet operator).
    wavelet_name : str
        Wavelet name for ``WaveletN``; if ``WaveletN`` raises ``ValueError``
        it is reused as the ``wavelet_id`` of ``WaveletUD2``.
    dimension, nb_scale, n_coils, n_jobs, verbose
        Forwarded to the wavelet constructor (``dimension`` only to
        ``WaveletN``).

    Returns
    -------
    tuple
        ``(linear_op, regularizer_op)``

    Raises
    ------
    ValueError
        If ``gradient_formulation`` is neither ``'synthesis'`` nor
        ``'analysis'``.

    """
    # Validate first: an unknown formulation used to fall through both
    # branches below and crash with UnboundLocalError on `regularizer_op`.
    if gradient_formulation not in ('synthesis', 'analysis'):
        raise ValueError(
            "gradient_formulation must be 'synthesis' or 'analysis', "
            "got {!r}".format(gradient_formulation))
    try:
        linear_op = WaveletN(
            nb_scale=nb_scale,
            wavelet_name=wavelet_name,
            dim=dimension,
            n_coils=n_coils,
            n_jobs=n_jobs,
            verbose=verbose,
        )
    except ValueError:
        # TODO this is a hack and we need to have a separate WaveletUD2.
        # For Undecimated wavelets, the wavelet_name is wavelet_id
        linear_op = WaveletUD2(
            wavelet_id=wavelet_name,
            nb_scale=nb_scale,
            n_coils=n_coils,
            n_jobs=n_jobs,
            verbose=verbose,
        )
    if gradient_formulation == 'synthesis':
        regularizer_op = SparseThreshold(Identity(), 0, thresh_type="soft")
    else:  # 'analysis'
        regularizer_op = SparseThreshold(linear_op, 0, thresh_type="soft")
    return linear_op, regularizer_op
def __init__(self, fourier_op, linear_op, regularizer_op,
             gradient_formulation, grad_class, init_gradient_op=True,
             verbose=0, **extra_grad_args):
    """Store the reconstruction operators and optionally build the gradient.

    Parameters
    ----------
    fourier_op, linear_op : object
        Fourier and linear operators, stored as attributes.
    regularizer_op : object or None
        Stored as ``prox_op``; ``None`` emits a warning and is replaced by
        ``Identity()``.
    gradient_formulation : str
        Stored as ``gradient_method``; for ``'synthesis'`` the linear operator
        is added to ``extra_grad_args`` under the key ``'linear_op'``.
    grad_class : object
        Stored as ``grad_class``.
    init_gradient_op : bool, optional
        When true, ``initialize_gradient_op(**extra_grad_args)`` is called.
    verbose : int, optional
        Stored as ``verbose``.
    **extra_grad_args
        Extra keyword arguments kept for the gradient initialisation.
    """
    self.fourier_op = fourier_op
    self.linear_op = linear_op
    self.prox_op = regularizer_op
    self.gradient_method = gradient_formulation
    self.grad_class = grad_class
    self.verbose = verbose
    self.extra_grad_args = extra_grad_args

    if self.prox_op is None:
        warnings.warn("The prox_op is not set. Setting to identity. "
                      "Note that optimization is just a gradient descent.")
        # NOTE(review): `Identity` looks like a linear operator; if the
        # optimizer calls `prox_op.op(x, extra_factor=...)`, an identity
        # *proximity* operator may be required here instead — confirm.
        self.prox_op = Identity()

    # TODO try to not use gradient_formulation and
    # rely on static attributes
    # The synthesis gradient needs the linear operator as well.
    needs_linear = gradient_formulation == 'synthesis'
    if needs_linear:
        self.extra_grad_args['linear_op'] = self.linear_op

    if not init_gradient_op:
        return
    self.initialize_gradient_op(**self.extra_grad_args)
def __init__(self, fourier_op, linear_op, regularizer_op=None,
             opt='condatvu', verbose=0):
    """Set up an online reconstructor and its gradient operator.

    Parameters
    ----------
    fourier_op : object
        Fourier operator.
    linear_op : object
        Linear operator; replaced by ``Identity()`` when no regulariser is
        given.
    regularizer_op : object, optional
        Proximity operator; ``None`` (default) warns and uses
        ``IdentityProx()``.
    opt : str, optional
        Optimizer name; must be a key of ``OPTIMIZERS`` (default
        ``'condatvu'``). ``ANALYSIS_OPT`` maps it to a gradient formulation,
        defaulting to ``'synthesis'``.
    verbose : int, optional
        Verbosity level, also forwarded to the gradient operator.

    Raises
    ------
    ValueError
        If ``opt`` is not a known optimizer.
    RuntimeError
        If ``ANALYSIS_OPT`` yields an unknown gradient formulation.

    """
    self.fourier_op = fourier_op
    self.linear_op = linear_op
    self.verbose = verbose
    if regularizer_op is None:
        warnings.warn("The prox_op is not set. Setting to identity. "
                      "Note that optimization is just a gradient descent.")
        self.prox_op = IdentityProx()
        self.linear_op = Identity()
    else:
        self.prox_op = regularizer_op
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if opt not in OPTIMIZERS:
        raise ValueError("Unknown optimizer {!r}; expected one of {}".format(
            opt, list(OPTIMIZERS)))
    self.opt = opt
    grad_formulation = ANALYSIS_OPT.get(opt, 'synthesis')
    if grad_formulation == 'analysis':
        self.gradient_op = OnlineGradAnalysis(self.fourier_op,
                                              verbose=self.verbose,
                                              num_check_lips=0,
                                              lipschitz_cst=1.1)
    elif grad_formulation == 'synthesis':
        self.gradient_op = OnlineGradSynthesis(self.linear_op,
                                               self.fourier_op,
                                               verbose=self.verbose,
                                               num_check_lips=0,
                                               lipschitz_cst=1.1)
    else:
        raise RuntimeError("Unknown gradient formulation")
    self.grad_formulation = grad_formulation
def reco_wav(kspace, gradient_op, mu=1e-8, max_iter=10, nb_scales=4,
             wavelet_name='db4', crop_size=320):
    """Reconstruct an image with wavelet-regularised FISTA.

    Parameters
    ----------
    kspace : numpy.ndarray
        Observed k-space data. It is assigned to ``gradient_op.obs_data``,
        so ``gradient_op`` is modified in place.
    gradient_op : object
        Gradient operator passed to ``sparse_rec_fista``.
    mu : float, optional
        Passed as ``mu`` to ``sparse_rec_fista`` (presumably the
        regularisation weight; default ``1e-8``).
    max_iter : int, optional
        Maximum number of iterations (default 10).
    nb_scales : int, optional
        Number of wavelet scales for ``WaveletDecimated`` (default 4).
    wavelet_name : str, optional
        Wavelet name for ``WaveletDecimated`` (default ``'db4'``).
    crop_size : int, optional
        Size passed to ``crop_center`` for the returned image (default 320,
        the value that used to be hard-coded).

    Returns
    -------
    numpy.ndarray
        Magnitude of the reconstruction, cropped by ``crop_center``.

    """
    # for now this is only working with my fork of pysap-fastMRI
    # I will get it changed soon so that we don't need to ask for a specific
    # pysap-mri install
    from ..wavelets import WaveletDecimated
    from mri.numerics.reconstruct import sparse_rec_fista

    linear_op = WaveletDecimated(
        nb_scale=nb_scales,
        wavelet_name=wavelet_name,
        padding='periodization',
    )
    # The wavelet transform is folded into the prox, so FISTA itself runs
    # with an identity linear operator below.
    prox_op = LinearCompositionProx(
        linear_op=linear_op,
        prox_op=SparseThreshold(Identity(), None, thresh_type="soft"),
    )
    gradient_op.obs_data = kspace
    cost_op = None
    x_final, _, _, _ = sparse_rec_fista(
        gradient_op=gradient_op,
        linear_op=Identity(),
        prox_op=prox_op,
        cost_op=cost_op,
        xi_restart=0.96,
        s_greedy=1.1,
        mu=mu,
        restart_strategy='greedy',
        pov='analysis',
        max_nb_of_iter=max_iter,
        metrics=None,
        metric_call_period=1,
        verbose=0,
        progress=False,
    )
    x_final = np.abs(x_final)
    x_final = crop_center(x_final, crop_size)
    return x_final
def __init__(self, weights, coeffs_shape, weight_type='scale_based',
             zero_weight_coarse=True, linear=Identity(), **kwargs):
    """Configure a weighted sparse threshold.

    Parameters
    ----------
    weights : object
        Threshold weights; stored as ``mu`` and forwarded to the parent.
    coeffs_shape : object
        Coefficient shapes, stored as ``cf_shape``.
    weight_type : str, optional
        ``'scale_based'`` (default) or ``'custom'``.
    zero_weight_coarse : bool, optional
        Stored as-is (default ``True``); presumably controls whether the
        coarse scale is left unthresholded — confirm.
    linear : object, optional
        Linear operator forwarded to the parent. The default ``Identity()``
        is created once at definition time and shared by all instances.
    **kwargs
        Forwarded to the parent constructor.

    Raises
    ------
    ValueError
        If ``weight_type`` is not supported.

    """
    allowed_types = ('scale_based', 'custom')
    self.cf_shape = coeffs_shape
    self.weight_type = weight_type
    if self.weight_type not in allowed_types:
        message = 'Weight type must be one of ' + ' '.join(allowed_types)
        raise ValueError(message)
    self.zero_weight_coarse = zero_weight_coarse
    self.mu = weights
    super(WeightedSparseThreshold, self).__init__(
        weights=self.mu,
        linear=linear,
        **kwargs
    )
def set_objective(self, X, y, lmbd):
    """Store the problem data and build the POGM solver.

    All four POGM variables start from the same zero vector (POGM copies
    each of them), and the solver is created with ``auto_iterate=False``.
    """
    self.X, self.y, self.lmbd = X, y, lmbd
    n_features = self.X.shape[1]
    start = np.zeros(n_features)

    gradient = GradBasic(
        op=lambda coef: self.X @ coef,
        trans_op=lambda residual: self.X.T @ residual,
        data=y,
    )
    self.pogm = POGM(
        x=start,  # this is the coefficient w
        u=start,
        y=start,
        z=start,
        grad=gradient,
        prox=SparseThreshold(Identity(), lmbd),
        beta_param=1.0,
        metric_call_period=None,
        sigma_bar=0.96,
        auto_iterate=False,
        progress=False,
        cost=None,
    )
class GenForwardBackward(SetUp):
    """Generalized Forward-Backward Algorithm.

    This class implements algorithm 1 from :cite:`raguet2011`

    Parameters
    ----------
    x : list, tuple or numpy.ndarray
        Initial guess for the primal variable
    grad : class instance
        Gradient operator class
    prox_list : list or tuple
        Sequence of proximity operator class instances
    cost : class or str, optional
        Cost function class (default is 'auto'); Use 'auto' to automatically
        generate a costObj instance
    gamma_param : float, optional
        Initial value of the gamma parameter (default is ``1.0``)
    lambda_param : float, optional
        Initial value of the lambda parameter (default is ``1.0``)
    gamma_update : function, optional
        Gamma parameter update method (default is ``None``)
    lambda_update : function, optional
        Lambda parameter update method (default is ``None``)
    weights : list, tuple or numpy.ndarray, optional
        Proximity operator weights (default is ``None``, which gives every
        operator the same weight); must sum to one
    auto_iterate : bool, optional
        Option to automatically begin iterations upon initialisation (default
        is ``True``)
    metric_call_period : int, optional
        Forwarded to :class:`SetUp` (default is ``5``)
    metrics : dict, optional
        Forwarded to :class:`SetUp` (default is ``None``); when non-empty a
        ``linear`` operator must also be given
    linear : class instance, optional
        Linear operator whose ``adj_op`` is applied to the iterate passed to
        the metric observers (default is ``None``, replaced by
        ``Identity()``)

    Notes
    -----
    The `gamma_param` can also be set using the keyword `step_size`, which
    will override the value of `gamma_param`.

    See Also
    --------
    SetUp : parent class

    """

    def __init__(
        self,
        x,
        grad,
        prox_list,
        cost='auto',
        gamma_param=1.0,
        lambda_param=1.0,
        gamma_update=None,
        lambda_update=None,
        weights=None,
        auto_iterate=True,
        metric_call_period=5,
        metrics=None,
        linear=None,
        **kwargs,
    ):
        # Set default algorithm properties
        super().__init__(
            metric_call_period=metric_call_period,
            metrics=metrics,
            **kwargs,
        )

        # Set the initial variable values
        self._check_input_data(x)
        self._x_old = self.xp.copy(x)

        # Accept any sequence of proximity operators (e.g. a tuple): the list
        # concatenations below would otherwise raise TypeError.
        prox_list = list(prox_list)

        # Set the algorithm operators
        for operator in [grad, cost] + prox_list:
            self._check_operator(operator)

        self._grad = grad
        self._prox_list = self.xp.array(prox_list)
        self._linear = linear

        if cost == 'auto':
            self._cost_func = costObj([self._grad] + prox_list)
        else:
            self._cost_func = cost

        # Check if there is a linear op, needed for metrics in the FB algorithm
        if metrics and self._linear is None:
            raise ValueError(
                'When using metrics, you must pass a linear operator',
            )

        if self._linear is None:
            self._linear = Identity()

        # Set the algorithm parameters
        for param_val in (gamma_param, lambda_param):
            self._check_param(param_val)
        self._gamma = self.step_size or gamma_param
        self._lambda_param = lambda_param

        # Set the algorithm parameter update methods
        for param_update in (gamma_update, lambda_update):
            self._check_param_update(param_update)
        self._gamma_update = gamma_update
        self._lambda_update = lambda_update

        # Set the proximity weights
        self._set_weights(weights)

        # Set initial z: one copy of the starting point per proximity operator
        self._z = self.xp.array(
            [self._x_old for _ in range(self._prox_list.size)])

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate()

    def _set_weights(self, weights):
        """Set weights.

        This method sets weights on each of the proximity operators provided

        Parameters
        ----------
        weights : list, tuple or numpy.ndarray
            List of weights

        Raises
        ------
        TypeError
            For invalid input type
        ValueError
            If weights are not floats, do not match the number of proximity
            operators, or do not sum to one (up to floating-point tolerance)

        """
        if weights is None:
            weights = self.xp.repeat(
                1.0 / self._prox_list.size,
                self._prox_list.size,
            )
        elif not isinstance(weights, (list, tuple, np.ndarray)):
            raise TypeError('Weights must be provided as a list.')

        weights = self.xp.array(weights)

        if not np.issubdtype(weights.dtype, np.floating):
            raise ValueError('Weights must be list of float values.')

        if weights.size != self._prox_list.size:
            raise ValueError(
                'The number of weights must match the number of proximity '
                + 'operators.',
            )

        # Compare with a tolerance: exact equality rejected valid inputs,
        # e.g. weights (0.3, 0.6, 0.1) add up to 0.9999999999999999.
        expected_weight_sum = 1.0
        weight_sum = self.xp.sum(weights)
        if not self.xp.isclose(weight_sum, expected_weight_sum):
            raise ValueError(
                'Proximity operator weights must sum to 1.0. Current sum of '
                + 'weights = {0}'.format(weight_sum),
            )

        self._weights = weights

    def _update_param(self):
        """Update parameters.

        This method updates the values of the algorithm parameters with the
        methods provided

        """
        # Update the gamma parameter.
        if self._gamma_update is not None:
            self._gamma = self._gamma_update(self._gamma)

        # Update lambda parameter.
        if self._lambda_update is not None:
            self._lambda_param = self._lambda_update(self._lambda_param)

    def _update(self):
        """Update.

        This method updates the current reconstruction

        Notes
        -----
        Implements algorithm 1 from :cite:`raguet2011`

        """
        # Calculate gradient for current iteration.
        self._grad.get_grad(self._x_old)

        # Update z values.
        for i in range(self._prox_list.size):
            z_temp = (
                2 * self._x_old - self._z[i] - self._gamma * self._grad.grad
            )
            z_prox = self._prox_list[i].op(
                z_temp,
                extra_factor=self._gamma / self._weights[i],
            )
            self._z[i] += self._lambda_param * (z_prox - self._x_old)

        # Update current reconstruction: weighted combination of the z_i.
        self._x_new = self.xp.sum(
            [z_i * w_i for z_i, w_i in zip(self._z, self._weights)],
            axis=0,
        )

        # Update old values for next iteration.
        self.xp.copyto(self._x_old, self._x_new)

        # Update parameter values for next iteration.
        self._update_param()

        # Test cost function for convergence.
        if self._cost_func:
            self.converge = self._cost_func.get_cost(self._x_new)

    def iterate(self, max_iter=150):
        """Iterate.

        This method calls update until either convergence criteria is met or
        the maximum number of iterations is reached.

        Parameters
        ----------
        max_iter : int, optional
            Maximum number of iterations (default is ``150``)

        """
        self._run_alg(max_iter)

        # retrieve metrics results
        self.retrieve_outputs()

        self.x_final = self._x_new

    def get_notify_observers_kwargs(self):
        """Notify observers.

        Return the mapping between the metrics call and the iterated
        variables.

        Returns
        -------
        dict
            The mapping between the iterated variables

        """
        return {
            'x_new': self._linear.adj_op(self._x_new),
            'z_new': self._z,
            'idx': self.idx,
        }

    def retrieve_outputs(self):
        """Retrieve outputs.

        Declare the outputs of the algorithms as attributes: x_final,
        y_final, metrics.

        """
        metrics = {}
        for obs in self._observers['cv_metrics']:
            metrics[obs.name] = obs.retrieve_metrics()
        self.metrics = metrics
class ForwardBackward(SetUp):
    """Forward-Backward optimisation.

    This class implements standard forward-backward optimisation with the
    option to use the FISTA speed-up

    Parameters
    ----------
    x : numpy.ndarray
        Initial guess for the primal variable
    grad : class
        Gradient operator class
    prox : class
        Proximity operator class
    cost : class or str, optional
        Cost function class (default is 'auto'); Use 'auto' to automatically
        generate a costObj instance
    beta_param : float, optional
        Initial value of the beta parameter (default is ``1.0``)
    lambda_param : float, optional
        Initial value of the lambda parameter (default is ``1.0``)
    beta_update : function, optional
        Beta parameter update method (default is ``None``); replaced by the
        FISTA beta update when ``lambda_update='fista'``
    lambda_update : function or str, optional
        Lambda parameter update method (default is 'fista')
    auto_iterate : bool, optional
        Option to automatically begin iterations upon initialisation (default
        is ``True``)
    metric_call_period : int, optional
        Forwarded to :class:`SetUp` (default is ``5``)
    metrics : dict, optional
        Forwarded to :class:`SetUp`; when non-empty ``linear`` is required
    linear : class instance, optional
        Linear operator used to map iterates for the metric observers
        (default ``None``, replaced by ``Identity()``)

    Notes
    -----
    The `beta_param` can also be set using the keyword `step_size`, which will
    override the value of `beta_param`.

    See Also
    --------
    FISTA : complementary class
    SetUp : parent class

    """

    def __init__(
        self,
        x,
        grad,
        prox,
        cost='auto',
        beta_param=1.0,
        lambda_param=1.0,
        beta_update=None,
        lambda_update='fista',
        auto_iterate=True,
        metric_call_period=5,
        metrics=None,
        linear=None,
        **kwargs,
    ):
        super().__init__(
            metric_call_period=metric_call_period,
            metrics=metrics,
            **kwargs,
        )

        # Primal iterate and extrapolated iterate both start at x.
        self._check_input_data(x)
        self._x_old = self.copy_data(x)
        self._z_old = self.copy_data(x)

        # Operators
        for candidate in (grad, prox, cost):
            self._check_operator(candidate)
        self._grad = grad
        self._prox = prox
        self._linear = linear
        self._cost_func = (
            costObj([self._grad, self._prox]) if cost == 'auto' else cost
        )

        # Metrics need a linear operator; otherwise fall back to Identity.
        if self._linear is None:
            if metrics:
                raise ValueError(
                    'When using metrics, you must pass a linear operator',
                )
            self._linear = Identity()

        # Scalar parameters
        for value in (beta_param, lambda_param):
            self._check_param(value)
        self._beta = self.step_size or beta_param
        self._lambda = lambda_param

        # Parameter update rules
        self._check_param_update(beta_update)
        self._beta_update = beta_update
        wants_fista = isinstance(lambda_update, str) and lambda_update == 'fista'
        if not wants_fista:
            self._check_param_update(lambda_update)
            self._lambda_update = lambda_update
            self._is_restart = lambda *_args, **_kwargs: False
        else:
            accelerator = FISTA(**kwargs)
            self._lambda_update = accelerator.update_lambda
            self._is_restart = accelerator.is_restart
            self._beta_update = accelerator.update_beta

        if auto_iterate:
            self.iterate()

    def _update_param(self):
        """Apply the configured beta and lambda update rules, if any."""
        if self._beta_update is not None:
            self._beta = self._beta_update(self._beta)
        if self._lambda_update is not None:
            self._lambda = self._lambda_update(self._lambda)

    def _update(self):
        """Run one iteration.

        Implements algorithm 10.7 (or 10.5) from :cite:`bauschke2009`.
        """
        # Step 1 from alg.10.7: gradient step from the extrapolated point.
        self._grad.get_grad(self._z_old)
        forward_point = self._z_old - self._beta * self._grad.grad

        # Step 2 from alg.10.7: proximal (backward) step.
        self._x_new = self._prox.op(forward_point, extra_factor=self._beta)

        # Step 5 from alg.10.7: extrapolation.
        step = self._x_new - self._x_old
        self._z_new = self._x_old + self._lambda * step

        # Restarting step from alg.4-5 in [L2018]
        if self._is_restart(self._z_old, self._x_new, self._x_old):
            self._z_new = self._x_new

        # Keep independent copies for the next iteration.
        self._x_old = self.xp.copy(self._x_new)
        self._z_old = self.xp.copy(self._z_new)

        self._update_param()

        if self._cost_func:
            self.converge = (
                self.any_convergence_flag()
                or self._cost_func.get_cost(self._x_new)
            )

    def iterate(self, max_iter=150):
        """Run up to ``max_iter`` updates, then collect outputs.

        Parameters
        ----------
        max_iter : int, optional
            Maximum number of iterations (default is ``150``)

        """
        self._run_alg(max_iter)
        self.retrieve_outputs()
        # The extrapolated iterate is exposed as the final result.
        self.x_final = self._z_new

    def get_notify_observers_kwargs(self):
        """Return the variables handed to the metric observers.

        Returns
        -------
        dict
            ``x_new`` (mapped through ``linear.adj_op``), ``z_new`` and
            ``idx``

        """
        return {
            'x_new': self._linear.adj_op(self._x_new),
            'z_new': self._z_new,
            'idx': self.idx,
        }

    def retrieve_outputs(self):
        """Retrieve outputs.

        Collect every ``cv_metrics`` observer's results into ``self.metrics``
        keyed by observer name.

        """
        self.metrics = {
            observer.name: observer.retrieve_metrics()
            for observer in self._observers['cv_metrics']
        }
def __init__(
    self,
    x,
    grad,
    prox,
    cost='auto',
    beta_param=1.0,
    lambda_param=1.0,
    beta_update=None,
    lambda_update='fista',
    auto_iterate=True,
    metric_call_period=5,
    metrics=None,
    linear=None,
    **kwargs,
):
    """Initialise a forward-backward solver.

    ``x`` is validated and copied into both the primal and extrapolated
    iterates. ``cost='auto'`` builds ``costObj([grad, prox])``. ``linear`` is
    required when ``metrics`` is non-empty and otherwise defaults to
    ``Identity()``. With ``lambda_update='fista'`` the lambda, beta and
    restart rules come from a ``FISTA(**kwargs)`` instance (overriding
    ``beta_update``). A truthy ``self.step_size`` overrides ``beta_param``.
    ``iterate()`` is called at the end when ``auto_iterate`` is true.
    """
    super().__init__(
        metric_call_period=metric_call_period,
        metrics=metrics,
        **kwargs,
    )

    # Primal iterate and extrapolated iterate both start at x.
    self._check_input_data(x)
    self._x_old = self.copy_data(x)
    self._z_old = self.copy_data(x)

    # Operators
    for candidate in (grad, prox, cost):
        self._check_operator(candidate)
    self._grad = grad
    self._prox = prox
    self._linear = linear
    self._cost_func = (
        costObj([self._grad, self._prox]) if cost == 'auto' else cost
    )

    # Metrics need a linear operator; otherwise fall back to Identity.
    if self._linear is None:
        if metrics:
            raise ValueError(
                'When using metrics, you must pass a linear operator',
            )
        self._linear = Identity()

    # Scalar parameters
    for value in (beta_param, lambda_param):
        self._check_param(value)
    self._beta = self.step_size or beta_param
    self._lambda = lambda_param

    # Parameter update rules
    self._check_param_update(beta_update)
    self._beta_update = beta_update
    wants_fista = isinstance(lambda_update, str) and lambda_update == 'fista'
    if not wants_fista:
        self._check_param_update(lambda_update)
        self._lambda_update = lambda_update
        self._is_restart = lambda *_args, **_kwargs: False
    else:
        accelerator = FISTA(**kwargs)
        self._lambda_update = accelerator.update_lambda
        self._is_restart = accelerator.is_restart
        self._beta_update = accelerator.update_beta

    if auto_iterate:
        self.iterate()
def polychromatic_psf_field_est_2(im_stack_in,spectrums,wvl,D,opt_shift_est,nb_comp,field_pos=None,nb_iter=4,nb_subiter=100,mu=0.3,\
                                  tol = 0.1,sig_supp = 3,sig=None,shifts=None,flux=None,nsig_shift_est=4,pos_en = True,simplex_en=False,\
                                  wvl_en=True,wvl_opt=None,nsig=3,graph_cons_en=False):
    """ Main LambdaRCA function.

    NOTE: this function uses Python 2 syntax (print statements).

    Alternates, for ``nb_iter`` outer iterations, between estimating the
    transport plans ``P_stack`` (Condat solver) and the coefficients ``A``
    (Condat solver, or ForwardBackward when ``graph_cons_en``), with a noise,
    normalisation and flux update after each pass.

    Parameters (descriptions limited to what the body shows)
    ----------
    im_stack_in : array
        Observed stack; copied, and images are indexed along the last axis
        (``im_stack[:,:,k]``).
    spectrums : array
        Forwarded to the gradient operators.
    wvl : array
        Rescaled to ``t`` in [0, 1] by min-max normalisation (presumably
        wavelengths).
    D : int
        The working grid is ``D`` times the observed image size.
    opt_shift_est : object
        Forwarded to ``utils.im_gauss_nois_est_cube``.
    nb_comp : int
        Number of components (slices of ``P_stack``, rows of ``A``).
    field_pos : object, optional
        Forwarded to ``analysis`` when ``graph_cons_en``.
    nb_iter : int, optional
        Number of outer alternating iterations.
    nb_subiter : int, optional
        Iterations given to each inner solver.
    mu, tol : float, optional
        NOTE(review): not referenced in this body.
    sig_supp : float, optional
        Passed as ``sig`` to ``utils.diagonally_dominated_mat_stack``.
    sig, shifts, flux : optional
        Estimated from the data when ``None``.
    nsig_shift_est : float, optional
        Threshold multiplier (times ``sig``) for shift estimation.
    pos_en, simplex_en, wvl_en, wvl_opt, nsig, graph_cons_en : optional
        Select the operators and proximity operators built below.

    Returns
    -------
    psf_est, P_stack, A, res
        Reconstructed field, transport plans, coefficients and final
        residual ``im_stack - obs_est``.

    Calls:

    * :func:`utils.get_noise_arr`
    * :func:`utils.diagonally_dominated_mat_stack`
    * :func:`psf_learning_utils.full_displacement`
    * :func:`utils.im_gauss_nois_est_cube`
    * :func:`utils.thresholding_3D`
    * :func:`utils.shift_est`
    * :func:`utils.shift_ker_stack`
    * :func:`utils.flux_estimate_stack`
    * :func:`optim_utils.analysis`
    * :func:`utils.cube_svd`
    * :func:`grads.polychrom_eigen_psf`
    * :func:`grads.polychrom_eigen_psf_coeff_graph`
    * :func:`grads.polychrom_eigen_psf_coeff`
    * :func:`psf_learning_utils.field_reconstruction`
    * :func:`operators.transport_plan_lin_comb_wavelet`
    * :func:`operators.transport_plan_marg_wavelet`
    * :func:`operators.transport_plan_lin_comb`
    * :func:`operators.transport_plan_lin_comb_coeff`
    * :func:`proxs.simplex_threshold`
    * :func:`proxs.Simplex`
    * :func:`proxs.KThreshold`
    """
    im_stack = copy(im_stack_in)
    if wvl_en:
        from utils import get_noise_arr

    print "--------------- Transport architecture setting ------------------"
    nb_im = im_stack.shape[-1]
    shap_obs = im_stack.shape
    # Working grid, D times the observed image size.
    shap = (shap_obs[0]*D,shap_obs[1]*D)
    P_stack = utils.diagonally_dominated_mat_stack(shap,nb_comp,sig=sig_supp,thresh_en=True)
    # Support = positions where the first transport plan is positive.
    i,j = where(P_stack[:,:,0]>0)
    supp = transpose(array([i,j]))
    # Min-max rescaling of wvl to [0, 1].
    t = (wvl-wvl.min()).astype(float)/(wvl.max()-wvl.min())

    neighbors_graph,weights_neighbors,cent,coord_map,knn = psf_learning_utils.full_displacement(shap,supp,t,\
        pol_en=True,cent=None,theta_param=1,pol_mod=True,coord_map=None,knn=None)

    print "------------------- Forward operator parameters estimation ------------------------"
    centroids = None
    if sig is None:
        sig,filters = utils.im_gauss_nois_est_cube(copy(im_stack),opt=opt_shift_est)

    if shifts is None:
        # NOTE(review): `map` shadows the builtin within this function.
        map = ones(im_stack.shape)
        for i in range(0,shap_obs[2]):
            map[:,:,i] *= nsig_shift_est*sig[i]
        print 'Shifts estimation...'
        psf_stack_shift = utils.thresholding_3D(copy(im_stack),map,0)
        shifts,centroids = utils.shift_est(psf_stack_shift)
        print 'Done...'
    else:
        print "---------- /!\ Warning: shifts provided /!\ ---------"
    ker,ker_rot = utils.shift_ker_stack(shifts,D)
    # Normalise each image by its relative noise level (in place on the copy).
    sig /=sig.min()
    for k in range(0,shap_obs[2]):
        im_stack[:,:,k] = im_stack[:,:,k]/sig[k]
    print " ------ ref energy: ",(im_stack**2).sum()," ------- "
    if flux is None:
        flux = utils.flux_estimate_stack(copy(im_stack),rad=4)

    if graph_cons_en:
        print "-------------------- Spatial constraint setting -----------------------"
        e_opt,p_opt,weights,comp_temp,data,basis,alph = analysis(im_stack,0.1*prod(shap_obs)*sig.min()**2,field_pos,nb_max=nb_comp)

    print "------------- Coeff init ------------"
    A,comp,cube_est = utils.cube_svd(im_stack,nb_comp=nb_comp)

    # NOTE(review): this reset has no visible effect (i is re-bound by the loops).
    i=0
    print " --------- Optimization instances setting ---------- "

    # Data fidelity related instances
    polychrom_grad = grad.polychrom_eigen_psf(im_stack, supp, neighbors_graph, \
        weights_neighbors, spectrums, A, flux, sig, ker, ker_rot, D)

    if graph_cons_en:
        polychrom_grad_coeff = grad.polychrom_eigen_psf_coeff_graph(im_stack, supp, neighbors_graph, \
            weights_neighbors, spectrums, P_stack, flux, sig, ker, ker_rot, D, basis)
    else:
        polychrom_grad_coeff = grad.polychrom_eigen_psf_coeff(im_stack, supp, neighbors_graph, \
            weights_neighbors, spectrums, P_stack, flux, sig, ker, ker_rot, D)

    # Dual variable related linear operators instances
    dual_var_coeff = zeros((supp.shape[0],nb_im))
    if wvl_en and pos_en:
        lin_com = lambdaops.transport_plan_lin_comb_wavelet(A,supp,weights_neighbors,neighbors_graph,shap,wavelet_opt=wvl_opt)
    else:
        if wvl_en:
            lin_com = lambdaops.transport_plan_marg_wavelet(supp,weights_neighbors,neighbors_graph,shap,wavelet_opt=wvl_opt)
        else:
            lin_com = lambdaops.transport_plan_lin_comb(A, supp,shap)

    if not graph_cons_en:
        lin_com_coeff = lambdaops.transport_plan_lin_comb_coeff(P_stack, supp)

    # Proximity operators related instances
    id_prox = Identity()
    if wvl_en and pos_en:
        # Here lin_com.op(...) is indexed with [1] (second output).
        noise_map = get_noise_arr(lin_com.op(polychrom_grad.MtX(im_stack))[1])
        dual_var_plan = np.array([zeros((supp.shape[0],nb_im)),zeros(noise_map.shape)])
        dual_prox_plan = lambdaprox.simplex_threshold(lin_com, nsig*noise_map,pos_en=(not simplex_en))
    else:
        if wvl_en:
            # Noise estimation (no [1] indexing on this path)
            noise_map = get_noise_arr(lin_com.op(polychrom_grad.MtX(im_stack)))
            dual_var_plan = zeros(noise_map.shape)
            dual_prox_plan = prox.SparseThreshold(lin_com, nsig*noise_map)
        else:
            dual_var_plan = zeros((supp.shape[0],nb_im))
            if simplex_en:
                dual_prox_plan = lambdaprox.Simplex()
            else:
                dual_prox_plan = prox.Positivity()

    # prox_coeff is only defined (and used) when graph_cons_en;
    # dual_prox_coeff only when it is not.
    if graph_cons_en:
        iter_func = lambda x: floor(sqrt(x))
        prox_coeff = lambdaprox.KThreshold(iter_func)
    else:
        if simplex_en:
            dual_prox_coeff = lambdaprox.Simplex()
        else:
            dual_prox_coeff = prox.Positivity()

    # ---- (Re)Setting hyperparameters
    # Step sizes derived from the gradient's inverse spectral radius and the
    # linear operator norm, scaled by w = 0.9.
    delta = (polychrom_grad.inv_spec_rad**(-1)/2)**2 + 4*lin_com.mat_norm**2
    w = 0.9
    sigma_P = w*(np.sqrt(delta)-polychrom_grad.inv_spec_rad**(-1)/2)/(2*lin_com.mat_norm**2)
    tau_P = sigma_P
    rho_P = 1

    # Cost function instance
    cost_op = costObj([polychrom_grad])

    condat_min = optimalg.Condat(P_stack, dual_var_plan, polychrom_grad, id_prox, dual_prox_plan, lin_com, cost=cost_op,\
        rho=rho_P, sigma=sigma_P, tau=tau_P, rho_update=None, sigma_update=None, tau_update=None, auto_iterate=False)
    print "------------------- Transport plans estimation ------------------"

    condat_min.iterate(max_iter=nb_subiter) # ! actually runs optimisation

    P_stack = condat_min.x_final
    dual_var_plan = condat_min.y_final

    obs_est = polychrom_grad.MX(P_stack)
    res = im_stack - obs_est

    # Outer alternating loop: coefficients, then transport plans.
    for i in range(0,nb_iter):
        print "----------------Iter ",i+1,"/",nb_iter,"-------------------"

        # Parameters update
        polychrom_grad_coeff.set_P(P_stack)
        if not graph_cons_en:
            lin_com_coeff.set_P_stack(P_stack)
            # ---- (Re)Setting hyperparameters
            delta = (polychrom_grad_coeff.inv_spec_rad**(-1)/2)**2 + 4*lin_com_coeff.mat_norm**2
            w = 0.9
            sigma_coeff = w*(np.sqrt(delta)-polychrom_grad_coeff.inv_spec_rad**(-1)/2)/(2*lin_com_coeff.mat_norm**2)
            tau_coeff = sigma_coeff
            rho_coeff = 1

        # Coefficients cost function instance
        cost_op_coeff = costObj([polychrom_grad_coeff])

        if graph_cons_en:
            beta_param = polychrom_grad_coeff.inv_spec_rad# set stepsize to inverse spectral radius of coefficient gradient
            min_coeff = optimalg.ForwardBackward(alph, polychrom_grad_coeff, prox_coeff, beta_param=beta_param, cost=cost_op_coeff,auto_iterate=False)
        else:
            min_coeff = optimalg.Condat(A, dual_var_coeff, polychrom_grad_coeff, id_prox, dual_prox_coeff, lin_com_coeff, cost=cost_op_coeff,\
                rho=rho_coeff, sigma=sigma_coeff, tau=tau_coeff, rho_update=None, sigma_update=None,\
                tau_update=None, auto_iterate=False)
        print "------------------- Coefficients estimation ----------------------"
        min_coeff.iterate(max_iter=nb_subiter) # ! actually runs optimisation
        if graph_cons_en:
            prox_coeff.reset_iter()
            # With the graph constraint, A is expressed in the `basis` returned by analysis.
            alph = min_coeff.x_final
            A = alph.dot(basis)
        else:
            A = min_coeff.x_final
            dual_var_coeff = min_coeff.y_final

        # Parameters update
        polychrom_grad.set_A(A)
        if not wvl_en:
            lin_com.set_A(A)
        if wvl_en:
            # Noise estimate update
            # NOTE(review): [1] is applied whenever wvl_en, but the initial
            # setup only indexes [1] when pos_en is also true; with
            # wvl_en and not pos_en this may be inconsistent — confirm.
            noise_map = get_noise_arr(lin_com.op(polychrom_grad.MtX(im_stack))[1])
            dual_prox_plan.update_weights(noise_map)

        # ---- (Re)Setting hyperparameters
        delta = (polychrom_grad.inv_spec_rad**(-1)/2)**2 + 4*lin_com.mat_norm**2
        w = 0.9
        sigma_P = w*(np.sqrt(delta)-polychrom_grad.inv_spec_rad**(-1)/2)/(2*lin_com.mat_norm**2)
        tau_P = sigma_P
        rho_P = 1

        # Cost function instance (cost_op from the initial setup is reused)
        condat_min = optimalg.Condat(P_stack, dual_var_plan, polychrom_grad, id_prox, dual_prox_plan, lin_com, cost=cost_op,\
            rho=rho_P, sigma=sigma_P, tau=tau_P, rho_update=None, sigma_update=None, tau_update=None, auto_iterate=False)
        print "------------------- Transport plans estimation ------------------"

        condat_min.iterate(max_iter=nb_subiter) # ! actually runs optimisation

        P_stack = condat_min.x_final
        dual_var_plan = condat_min.y_final

        # Normalization: rescale each plan by its l1 mass and move that mass into A.
        for j in range(0,nb_comp):
            # NOTE(review): if `sum` here is Python's builtin rather than
            # numpy's, a 2-D slice yields a per-column array, not a scalar — confirm.
            l1_P = sum(abs(P_stack[:,:,j]))
            P_stack[:,:,j]/= l1_P
            A[j,:] *= l1_P
            if graph_cons_en:
                alph[j,:] *= l1_P
        polychrom_grad.set_A(A)
        # Flux update: per-image least-squares scale between model and data.
        obs_est = polychrom_grad.MX(P_stack)
        err_ref = 0.5*sum((obs_est-im_stack)**2)
        flux_new = (obs_est*im_stack).sum(axis=(0,1))/(obs_est**2).sum(axis=(0,1))
        print "Flux correction: ",flux_new
        polychrom_grad.set_flux(polychrom_grad.get_flux()*flux_new)
        polychrom_grad_coeff.set_flux(polychrom_grad_coeff.get_flux()*flux_new)

        obs_est = polychrom_grad.MX(P_stack)
        res = im_stack - obs_est
        err_rec = 0.5*sum(res**2)
        print "err_ref : ",err_ref," ; err_rec : ", err_rec

    # Computing residual
    psf_est = psf_learning_utils.field_reconstruction(P_stack,shap,supp,neighbors_graph,weights_neighbors,A)

    return psf_est,P_stack,A,res
class POGM(SetUp):
    """Proximal Optimised Gradient Method.

    This class implements algorithm 3 from :cite:`kim2017`.

    Parameters
    ----------
    u : numpy.ndarray
        Initial guess for the u variable
    x : numpy.ndarray
        Initial guess for the x variable (primal)
    y : numpy.ndarray
        Initial guess for the y variable
    z : numpy.ndarray
        Initial guess for the z variable
    grad : class
        Gradient operator class
    prox : class
        Proximity operator class
    cost : class or str, optional
        Cost function class (default is 'auto'); Use 'auto' to automatically
        generate a costObj instance
    linear : class instance, optional
        Linear operator class (default is ``None``)
    beta_param : float, optional
        Initial value of the beta parameter (default is ``1.0``). This
        corresponds to (1 / L) in :cite:`kim2017`
    sigma_bar : float, optional
        Value of the shrinking parameter sigma bar (default is ``1.0``)
    auto_iterate : bool, optional
        Option to automatically begin iterations upon initialisation (default
        is ``True``)
    metric_call_period : int, optional
        Period, in iterations, at which metrics are evaluated (default is
        ``5``); forwarded to the parent class
    metrics : dict, optional
        Convergence metrics (default is ``None``); forwarded to the parent
        class

    Raises
    ------
    ValueError
        If ``sigma_bar`` is not in [0, 1]

    Notes
    -----
    The `beta_param` can also be set using the keyword `step_size`, which will
    override the value of `beta_param`.

    The restart (step 11) and gamma-decreasing (step 13) tests use the real
    part of ``vdot``, i.e. the real inner product obtained by viewing C^n as
    R^2n, so complex-valued variables are handled consistently.

    See Also
    --------
    SetUp : parent class

    """

    def __init__(
        self,
        u,
        x,
        y,
        z,
        grad,
        prox,
        cost='auto',
        linear=None,
        beta_param=1.0,
        sigma_bar=1.0,
        auto_iterate=True,
        metric_call_period=5,
        metrics=None,
        **kwargs,
    ):
        # Set default algorithm properties
        super().__init__(
            metric_call_period=metric_call_period,
            metrics=metrics,
            linear=linear,
            **kwargs,
        )

        # Set the initial variable values
        for input_data in (u, x, y, z):
            self._check_input_data(input_data)

        self._u_old = self.xp.copy(u)
        self._x_old = self.xp.copy(x)
        self._y_old = self.xp.copy(y)
        self._z = self.xp.copy(z)

        # Set the algorithm operators
        for operator in (grad, prox, cost):
            self._check_operator(operator)

        self._grad = grad
        self._prox = prox
        self._linear = linear

        if cost == 'auto':
            self._cost_func = costObj([self._grad, self._prox])
        else:
            self._cost_func = cost

        # If linear is None, make it Identity for call of metrics
        if self._linear is None:
            self._linear = Identity()

        # Set the algorithm parameters
        for param_val in (beta_param, sigma_bar):
            self._check_param(param_val)
        if sigma_bar < 0 or sigma_bar > 1:
            raise ValueError('The sigma bar parameter needs to be in [0, 1]')

        self._beta = self.step_size or beta_param
        self._sigma_bar = sigma_bar

        self._xi = 1.0
        self._sigma = 1.0
        self._t_old = 1.0
        self._grad.get_grad(self._x_old)
        # Keep a private copy: ``_g_old`` is overwritten in place with
        # ``copyto`` in ``_update``, so it must never alias the gradient
        # operator's own ``grad`` buffer.
        self._g_old = self.xp.copy(self._grad.grad)

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate()

    def _update(self):
        """Update.

        This method updates the current reconstruction.

        Notes
        -----
        Implements algorithm 3 from :cite:`kim2017`

        """
        # Step 4 from alg. 3
        self._grad.get_grad(self._x_old)
        self._u_new = self._x_old - self._beta * self._grad.grad

        # Step 5 from alg. 3
        self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old**2))

        # Step 6 from alg. 3
        t_shifted_ratio = (self._t_old - 1) / self._t_new
        sigma_t_ratio = self._sigma * self._t_old / self._t_new
        beta_xi_t_shifted_ratio = t_shifted_ratio * self._beta / self._xi
        self._z = -beta_xi_t_shifted_ratio * (self._x_old - self._z)
        self._z += self._u_new
        self._z += t_shifted_ratio * (self._u_new - self._u_old)
        self._z += sigma_t_ratio * (self._u_new - self._x_old)

        # Step 7 from alg. 3 (uses the previous xi in step 6 above)
        self._xi = self._beta * (1 + t_shifted_ratio + sigma_t_ratio)

        # Step 8 from alg. 3
        self._x_new = self._prox.op(self._z, extra_factor=self._xi)

        # Restarting and gamma-Decreasing
        # Step 9 from alg. 3
        self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi

        # Step 10 from alg 3.
        self._y_new = self._x_old - self._beta * self._g_new

        # Step 11 from alg. 3
        # Real part of the Hermitian inner product = real inner product, so
        # the sign test is well defined for complex-valued data.
        restart_crit = self.xp.real(
            self.xp.vdot(-self._g_new, self._y_new - self._y_old),
        ) < 0
        if restart_crit:
            self._t_new = 1
            self._sigma = 1

        # Step 13 from alg. 3
        elif self.xp.real(self.xp.vdot(self._g_new, self._g_old)) < 0:
            self._sigma *= self._sigma_bar

        # updating variables
        self._t_old = self._t_new
        self.xp.copyto(self._u_old, self._u_new)
        self.xp.copyto(self._x_old, self._x_new)
        self.xp.copyto(self._g_old, self._g_new)
        self.xp.copyto(self._y_old, self._y_new)

        # Test cost function for convergence.
        if self._cost_func:
            self.converge = (
                self.any_convergence_flag()
                or self._cost_func.get_cost(self._x_new)
            )

    def iterate(self, max_iter=150):
        """Iterate.

        This method calls update until either convergence criteria is met or
        the maximum number of iterations is reached.

        Parameters
        ----------
        max_iter : int, optional
            Maximum number of iterations (default is ``150``)

        """
        self._run_alg(max_iter)

        # retrieve metrics results
        self.retrieve_outputs()
        # rename outputs as attributes
        self.x_final = self._x_new

    def get_notify_observers_kwargs(self):
        """Notify observers.

        Return the mapping between the metrics call and the iterated
        variables.

        Returns
        -------
        dict
            The mapping between the iterated variables

        """
        return {
            'u_new': self._u_new,
            'x_new': self._linear.adj_op(self._x_new),
            'y_new': self._y_new,
            'z_new': self._z,
            'xi': self._xi,
            'sigma': self._sigma,
            't': self._t_new,
            'idx': self.idx,
        }

    def retrieve_outputs(self):
        """Retrieve outputs.

        Collect the values recorded by every convergence-metric observer into
        the ``metrics`` attribute, keyed by observer name.

        """
        metrics = {}
        for obs in self._observers['cv_metrics']:
            metrics[obs.name] = obs.retrieve_metrics()
        self.metrics = metrics
def launch_grid(kspace_data, reconstructor_class, reconstructor_kwargs,
                fourier_op=None, linear_params=None, regularizer_params=None,
                optimizer_params=None, compare_metric_details=None,
                n_jobs=1, verbose=0):
    """Launch reconstructions over a parameter grid given as kwarg dicts.

    Dictionary Convention
    ---------------------
    Each parameter dictionary follows the same convention.

    - Key `init_class` holds the initialization class of the operator. The
      exception is `optimizer_params`, which has no `init_class`.
    - Key `kwargs` holds the input arguments for that class, passed as a
      keyword dictionary.
    - Each value in `kwargs` must be a list (or tuple / np.ndarray) of all the
      values to search. A scalar value is treated as a one-element list.

    This function builds the search space of parameters (the cartesian
    product of every `kwargs` value list). It then creates one test case per
    point and runs the reconstructions in parallel.
    The dictionaries passed in by the caller are not modified.
    Please check the example code for more details.

    Parameters
    ----------
    kspace_data: np.ndarray
        the kspace data for reconstruction
    reconstructor_class: class
        reconstructor class
    reconstructor_kwargs: dict
        extra kwargs for reconstructor
    fourier_op: object of class FFT
        This defines the fourier operator. If it is a NonCartesianFFT, it is
        replaced by its construction parameters (samples, shape) so that each
        parallel job can build its own instance.
    linear_params: dict, default None
        dictionary for linear operator parameters;
        if None, a sym8 wavelet with 4 scales is chosen
    regularizer_params: dict, default None
        dictionary for regularizer operator parameters;
        if None, weights=0, ie no regularization is done
    optimizer_params: dict, default None
        dictionary for optimizer key word arguments;
        if None, a FISTA optimization is done for 100 iterations
    compare_metric_details: dict default None
        dictionary that holds the metric to be compared and metric direction;
        please refer to `gather_result` documentation.
        If None, all raw_results are returned and best_idx is None
    n_jobs: int, default 1
        number of parallel jobs for execution
    verbose: int default 0
        Verbosity level
        0 => No debug prints
        1 => Print the number of cases and the best result if present

    Returns
    -------
    results: list
        one reconstruction result per grid point
    cross_product_list: list of tuple
        the parameter values of each grid point, in `key_names` order
    key_names: list of list of str
        the kwarg names for the linear, regularizer and optimizer dicts
    best_idx: int or None
        index of the best grid point (None if no comparison was requested)
    """
    if linear_params is None:
        linear_params = {
            'init_class': WaveletN,
            'kwargs':
                {
                    'wavelet_name': 'sym8',
                    'nb_scale': 4,
                }
        }
    if regularizer_params is None:
        regularizer_params = {
            'init_class': SparseThreshold,
            'kwargs':
                {
                    'linear': Identity(),
                    'weights': [0],
                }
        }
    if optimizer_params is None:
        optimizer_params = {
            # Just following convention
            'kwargs':
                {
                    'optimization_alg': 'fista',
                    'num_iterations': 100,
                }
        }
    # Build, per parameter dict, a fresh kwargs mapping whose values are all
    # sequences. The caller's dictionaries are left untouched.
    search_spaces = []
    for specific_params in (linear_params, regularizer_params,
                            optimizer_params):
        search_kwargs = {}
        for key, value in specific_params['kwargs'].items():
            if not isinstance(value, (list, tuple, np.ndarray)):
                value = [value]
            search_kwargs[key] = value
        search_spaces.append(search_kwargs)
    # Only the linear and regularizer dicts carry an initialization class
    init_classes = [
        linear_params['init_class'],
        regularizer_params['init_class'],
    ]
    key_names = [list(search_kwargs) for search_kwargs in search_spaces]

    # Create Search space
    cross_product_list = list(itertools.product(*(
        values
        for search_kwargs in search_spaces
        for values in search_kwargs.values()
    )))
    test_cases = []
    number_of_test_cases = len(cross_product_list)
    if verbose > 0:
        print('Total number of gridsearch cases : ' +
              str(number_of_test_cases))
    # Split every flat grid point back into one kwargs dict per operator,
    # following the order recorded in key_names.
    for test_case in cross_product_list:
        values = iter(test_case)
        all_kwargs_values = [
            {key: next(values) for key in individual_param_names}
            for individual_param_names in key_names
        ]
        test_cases.append(
            _TestCase(kspace_data, *init_classes, *all_kwargs_values))
    if isinstance(fourier_op, NonCartesianFFT):
        # Pass construction parameters instead of the operator itself so
        # each parallel job builds its own NonCartesianFFT.
        fourier_params = {
            'init_class': NonCartesianFFT,
            'kwargs':
                {
                    'samples': fourier_op.samples,
                    'shape': fourier_op.shape,
                }
        }
        fourier_op = None
    else:
        fourier_params = None
    # Call for reconstruction
    results = Parallel(n_jobs=n_jobs)(
        delayed(test_case.reconstruct_case)(
            fourier_op=fourier_op,
            reconstructor_class=reconstructor_class,
            reconstructor_kwargs=reconstructor_kwargs,
            fourier_params=fourier_params,
        )
        for test_case in test_cases
    )
    best_idx = None
    if compare_metric_details is not None:
        best_value, best_idx = \
            gather_result(
                **compare_metric_details,
                results=results,
            )
        if verbose > 0:
            print('The best result of grid search is: ' +
                  str(cross_product_list[best_idx]))
            print('The best value of metric is : ' + str(best_value))
    return results, cross_product_list, key_names, best_idx
class ForwardBackward(SetUp):
    r"""Forward-Backward optimisation

    This class implements standard forward-backward optimisation with an the
    option to use the FISTA speed-up

    Parameters
    ----------
    x : np.ndarray
        Initial guess for the primal variable
    grad : class
        Gradient operator class
    prox : class
        Proximity operator class
    cost : class or str, optional
        Cost function class (default is 'auto'); Use 'auto' to automatically
        generate a costObj instance
    beta_param : float, optional
        Initial value of the beta parameter (default is 1.0)
    lambda_param : float, optional
        Initial value of the lambda parameter (default is 1.0)
    beta_update : function, optional
        Beta parameter update method (default is None)
    lambda_update : function or string, optional
        Lambda parameter update method (default is 'fista')
    auto_iterate : bool, optional
        Option to automatically begin iterations upon initialisation (default
        is 'True')
    metric_call_period : int, optional
        Period, in iterations, at which metrics are evaluated (default is 5)
    metrics : dict, optional
        Convergence metrics (default is None, i.e. no metrics). When given, a
        ``linear`` operator is required.
    linear : class instance, optional
        Linear operator used when reporting ``x_new`` to the metrics
        (default is None, replaced by ``Identity()``)

    Raises
    ------
    ValueError
        If ``metrics`` is non-empty and ``linear`` is None
    """

    def __init__(self, x, grad, prox, cost='auto', beta_param=1.0,
                 lambda_param=1.0, beta_update=None, lambda_update='fista',
                 auto_iterate=True, metric_call_period=5, metrics=None,
                 linear=None):

        # A fresh dict per call avoids sharing one mutable default object
        # between all instances.
        if metrics is None:
            metrics = {}

        # Set default algorithm properties
        super(ForwardBackward, self).__init__(
            metric_call_period=metric_call_period,
            metrics=metrics,
            linear=linear,
        )

        # Set the initial variable values
        self._check_input_data(x)
        self._x_old = np.copy(x)
        self._z_old = np.copy(x)

        # Set the algorithm operators. Plain loops are used so that the
        # checks actually execute; an unconsumed generator expression would
        # never call them.
        for operator in (grad, prox):
            self._check_operator(operator)

        self._grad = grad
        self._prox = prox
        self._linear = linear

        if cost == 'auto':
            self._cost_func = costObj([self._grad, self._prox])
        else:
            self._check_operator(cost)
            self._cost_func = cost

        # Check if there is a linear op, needed for metrics in the FB
        # algorithm
        if metrics and self._linear is None:
            raise ValueError('When using metrics, you must pass a linear '
                             'operator')

        if self._linear is None:
            self._linear = Identity()

        # Set the algorithm parameters
        for param_val in (beta_param, lambda_param):
            self._check_param(param_val)
        self._beta = beta_param
        self._lambda = lambda_param

        # Set the algorithm parameter update methods
        if isinstance(lambda_update, str) and lambda_update == 'fista':
            self._lambda_update = FISTA().update_lambda
        else:
            self._check_param_update(lambda_update)
            self._lambda_update = lambda_update
        self._check_param_update(beta_update)
        self._beta_update = beta_update

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate()

    def _update_param(self):
        r"""Update parameters

        This method updates the values of the algorithm parameters with the
        methods provided (skipped for any update method that is None)

        """
        # Update the gamma parameter.
        if self._beta_update is not None:
            self._beta = self._beta_update(self._beta)

        # Update lambda parameter.
        if self._lambda_update is not None:
            self._lambda = self._lambda_update(self._lambda)

    def _update(self):
        r"""Update

        This method updates the current reconstruction

        Notes
        -----
        Implements algorithm 10.7 (or 10.5) from [B2011]_

        """
        # Step 1 from alg.10.7.
        self._grad.get_grad(self._z_old)
        y_old = self._z_old - self._beta * self._grad.grad

        # Step 2 from alg.10.7.
        self._x_new = self._prox.op(y_old, extra_factor=self._beta)

        # Step 5 from alg.10.7.
        self._z_new = self._x_old + self._lambda * (self._x_new - self._x_old)

        # Update old values for next iteration.
        np.copyto(self._x_old, self._x_new)
        np.copyto(self._z_old, self._z_new)

        # Update parameter values for next iteration.
        self._update_param()

        # Test cost function for convergence.
        if self._cost_func:
            self.converge = self.any_convergence_flag() or \
                self._cost_func.get_cost(self._x_new)

    def iterate(self, max_iter=150):
        r"""Iterate

        This method calls update until either convergence criteria is met or
        the maximum number of iterations is reached

        Parameters
        ----------
        max_iter : int, optional
            Maximum number of iterations (default is ``150``)

        """
        self._run_alg(max_iter)

        # retrieve metrics results
        self.retrieve_outputs()
        # rename outputs as attributes
        self.x_final = self._z_new

    def get_notify_observers_kwargs(self):
        """Return the mapping between the metrics call and the iterated
        variables.

        Returns
        -------
        notify_observers_kwargs: dict
            the mapping between the iterated variables.
        """
        return {
            'x_new': self._linear.adj_op(self._x_new),
            'z_new': self._z_new,
            'idx': self.idx,
        }

    def retrieve_outputs(self):
        """Collect the values recorded by every convergence-metric observer
        into the ``metrics`` attribute, keyed by observer name.
        """
        metrics = {}
        for obs in self._observers['cv_metrics']:
            metrics[obs.name] = obs.retrieve_metrics()
        self.metrics = metrics
base_ssim = ssim(image_rec0, image)
print('The Base SSIM is : {!s}'.format(base_ssim))

#############################################################################
# FISTA optimization
# ------------------
#
# We now refine the zero order solution with a FISTA optimization.
# The cost function is the sum of the proximity cost and the gradient cost.

# Operators: a 4-scale sym8 wavelet and a soft-thresholding regularizer
linear_op = WaveletN(wavelet_name='sym8', nb_scale=4)
regularizer_op = SparseThreshold(Identity(), 1.5e-8, thresh_type="soft")

# Reconstructor in the synthesis formulation
reconstructor = SelfCalibrationReconstructor(
    fourier_op=fourier_op,
    linear_op=linear_op,
    regularizer_op=regularizer_op,
    gradient_formulation='synthesis',
    kspace_portion=0.01,
    verbose=1,
)
x_final, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_obs, optimization_alg='fista', num_iterations=10,
)
image_rec = pysap.Image(data=x_final)
def __init__(
    self,
    x,
    grad,
    prox_list,
    cost='auto',
    gamma_param=1.0,
    lambda_param=1.0,
    gamma_update=None,
    lambda_update=None,
    weights=None,
    auto_iterate=True,
    metric_call_period=5,
    metrics=None,
    linear=None,
    **kwargs,
):
    """Initialise the generalised forward-backward solver.

    Parameters
    ----------
    x : array
        Initial guess for the primal variable
    grad : class instance
        Gradient operator
    prox_list : iterable of class instances
        Proximity operators. Any iterable (list, tuple, ...) is accepted.
    cost : class instance or str, optional
        Cost function; ``'auto'`` builds a ``costObj`` from the gradient and
        all proximity operators (default is 'auto')
    gamma_param, lambda_param : float, optional
        Initial parameter values; ``step_size`` (if set on the parent class)
        overrides ``gamma_param``
    gamma_update, lambda_update : callable, optional
        Parameter update methods (default is None)
    weights : optional
        Proximity weights, passed to ``_set_weights``
    auto_iterate : bool, optional
        Run ``iterate()`` at the end of initialisation (default is True)
    metric_call_period : int, optional
        Period at which metrics are evaluated; forwarded to the parent class
    metrics : dict, optional
        Convergence metrics; requires ``linear`` when non-empty
    linear : class instance, optional
        Linear operator (default is None, replaced by ``Identity()``)

    Raises
    ------
    ValueError
        If ``metrics`` is given without a ``linear`` operator
    """
    # Set default algorithm properties
    super().__init__(
        metric_call_period=metric_call_period,
        metrics=metrics,
        **kwargs,
    )

    # Set the initial variable values
    self._check_input_data(x)
    self._x_old = self.xp.copy(x)

    # Materialise the operators once, so tuples and other iterables work
    # with the list concatenations below (``list + tuple`` raises).
    prox_list = list(prox_list)

    # Set the algorithm operators
    for operator in [grad, cost, *prox_list]:
        self._check_operator(operator)

    self._grad = grad
    self._prox_list = self.xp.array(prox_list)
    self._linear = linear

    if cost == 'auto':
        self._cost_func = costObj([self._grad, *prox_list])
    else:
        self._cost_func = cost

    # Check if there is a linear op, needed for metrics in the FB algorithm
    if metrics and self._linear is None:
        raise ValueError(
            'When using metrics, you must pass a linear operator',
        )

    if self._linear is None:
        self._linear = Identity()

    # Set the algorithm parameters
    for param_val in (gamma_param, lambda_param):
        self._check_param(param_val)
    self._gamma = self.step_size or gamma_param
    self._lambda_param = lambda_param

    # Set the algorithm parameter update methods
    for param_update in (gamma_update, lambda_update):
        self._check_param_update(param_update)
    self._gamma_update = gamma_update
    self._lambda_update = lambda_update

    # Set the proximity weights
    self._set_weights(weights)

    # Set initial z: one copy of x per proximity operator
    self._z = self.xp.array(
        [self._x_old for _ in range(self._prox_list.size)])

    # Automatically run the algorithm
    if auto_iterate:
        self.iterate()
def sparse_rec_condatvu(gradient_op, linear_op, std_est=None,
                        std_est_method=None, std_thr=2.,
                        mu=1e-6, tau=None, sigma=None, relaxation_factor=1.0,
                        nb_of_reweights=1, max_nb_of_iter=150,
                        add_positivity=False, atol=1e-4, verbose=0):
    """ The Condat-Vu sparse reconstruction with reweightings.

    .. note:: At the moment, supports only 2D data.

    Parameters
    ----------
    gradient_op: instance of a gradient operator
        the gradient operator; its ``fourier_op.shape`` sets the shape of the
        primal variable.
    linear_op: instance of a linear operator
        the linear operator (ie. a wavelet transform) used for the dual
        variable.
    std_est: float, default None
        the noise std estimate.
        If None use the MAD as a consistent estimator for the std.
    std_est_method: str, default None
        if the standard deviation is not set, estimate this parameter using
        the mad routine in the image ('primal') or in the sparse wavelet
        decomposition ('dual') domain. If None, manual regularization mode
        (weights = mu) without reweighting is used.
    std_thr: float, default 2.
        use this threshold expressed as a number of sigma in the residual
        proximity operator during the thresholding.
    mu: float, default 1e-6
        regularization hyperparameter.
    tau, sigma: float, default None
        parameters of the Condat-Vu proximal-dual splitting algorithm.
        If None estimates these parameters.
    relaxation_factor: float, default 1.0
        parameter of the Condat-Vu proximal-dual splitting algorithm.
        If 1, no relaxation.
    nb_of_reweights: int, default 1
        the number of reweightings (forced to 0 when std_est_method is None).
    max_nb_of_iter: int, default 150
        the maximum number of iterations in the Condat-Vu proximal-dual
        splitting algorithm.
    add_positivity: bool, default False
        by setting this option, set the proximity operator to identity or
        positive.
    atol: float, default 1e-4
        currently unused by this function.
    verbose: int, default 0
        the verbosity level.

    Returns
    -------
    x_final: ndarray
        the estimated CONDAT-VU solution.
    transform: a WaveletTransformBase derived instance
        the wavelet transformation instance.

    Raises
    ------
    ValueError
        if std_est_method is not None, 'primal' or 'dual'.
    """
    # Check inputs
    # time.clock() was removed in Python 3.8; perf_counter is the
    # recommended replacement for measuring elapsed time.
    start = time.perf_counter()
    if std_est_method not in (None, "primal", "dual"):
        raise ValueError(
            "Unrecognize std estimation method '{0}'.".format(std_est_method))

    # Define the initial primal and dual solutions
    # (np.complex was only an alias of the builtin complex and has been
    # removed from NumPy.)
    x_init = np.zeros(gradient_op.fourier_op.shape, dtype=complex)
    weights = linear_op.op(x_init)

    # Define the weights used during the thresholding in the dual domain,
    # the reweighting strategy, and the prox dual operator

    # Case1: estimate the noise std in the image domain
    if std_est_method == "primal":
        if std_est is None:
            std_est = sigma_mad(gradient_op.MtX(gradient_op.y))
        weights[...] = std_thr * std_est
        reweight_op = cwbReweight(weights)
        prox_dual_op = Threshold(reweight_op.weights)

    # Case2: estimate the noise std in the sparse wavelet domain
    elif std_est_method == "dual":
        if std_est is None:
            std_est = 0.0
        weights[...] = std_thr * std_est
        reweight_op = mReweight(weights, linear_op, thresh_factor=std_thr)
        prox_dual_op = Threshold(reweight_op.weights)

    # Case3: manual regularization mode, no reweighting
    else:
        weights[...] = mu
        reweight_op = None
        prox_dual_op = Threshold(weights)
        nb_of_reweights = 0

    # Define the Condat Vu optimizer: define the tau and sigma in the
    # Condat-Vu proximal-dual splitting algorithm if not already provided.
    # Check also that the combination of values will lead to convergence.
    norm = linear_op.l2norm(gradient_op.fourier_op.shape)
    lipschitz_cst = gradient_op.spec_rad
    if sigma is None:
        sigma = 0.5
    if tau is None:
        # to avoid numerics troubles with the convergence bound
        eps = 1.0e-8
        # due to the convergence bound
        tau = 1.0 / (lipschitz_cst / 2 + sigma * norm**2 + eps)
    convergence_test = (
        1.0 / tau - sigma * norm**2 >= lipschitz_cst / 2.0)

    # Define initial primal and dual solutions
    primal = np.zeros(gradient_op.fourier_op.shape, dtype=complex)
    dual = linear_op.op(primal)
    dual[...] = 0.0

    # Welcome message
    if verbose > 0:
        print(condatvu_logo())
        print(" - mu: ", mu)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - tau: ", tau)
        print(" - sigma: ", sigma)
        print(" - rho: ", relaxation_factor)
        print(" - std: ", std_est)
        print(" - 1/tau - sigma||L||^2 >= beta/2: ", convergence_test)
        print(" - data: ", gradient_op.obs_data.shape)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - number of reweights: ", nb_of_reweights)
        print(" - primal variable shape: ", primal.shape)
        print(" - dual variable shape: ", dual.shape)
        print("-" * 40)

    # Define the proximity operator
    if add_positivity:
        prox_op = Positivity()
    else:
        prox_op = Identity()

    # Define the cost function
    cost_op = DualGapCost(
        linear_op=linear_op,
        initial_cost=1e6,
        tolerance=1e-4,
        cost_interval=1,
        test_range=4,
        verbose=0,
        plot_output=None)

    # Define the optimizer
    opt = Condat(
        x=primal,
        y=dual,
        grad=gradient_op,
        prox=prox_op,
        prox_dual=prox_dual_op,
        linear=linear_op,
        cost=cost_op,
        rho=relaxation_factor,
        sigma=sigma,
        tau=tau,
        rho_update=None,
        sigma_update=None,
        tau_update=None,
        auto_iterate=False)

    # Perform the first reconstruction: exactly max_nb_of_iter raw updates,
    # without the convergence test performed by opt.iterate().
    if verbose > 0:
        print("Starting optimization...")

    for _ in range(max_nb_of_iter):
        opt._update()

    opt.x_final = opt._x_new
    opt.y_final = opt._y_new

    # Loop through the number of reweightings
    for reweight_index in range(nb_of_reweights):

        # Generate the new weights following reweighting prescription
        if std_est_method == "primal":
            reweight_op.reweight(linear_op.op(opt._x_new))
        else:
            std_est = reweight_op.reweight(opt._x_new)

        # Welcome message
        if verbose > 0:
            print(" - reweight: ", reweight_index + 1)
            print(" - std: ", std_est)

        # Update the weights in the dual proximity operator
        prox_dual_op.weights = reweight_op.weights

        # Perform optimisation with new weights
        opt.iterate(max_iter=max_nb_of_iter)

    # Goodbye message
    end = time.perf_counter()
    if verbose > 0:
        print(" - final iteration number: ", cost_op._iteration)
        print(" - final cost value: ", cost_op.cost)
        print(" - converged: ", opt.converge)
        print("Done.")
        print("Execution time: ", end - start, " seconds")
        print("-" * 40)

    # Get the final solution
    x_final = opt.x_final
    linear_op.transform.analysis_data = unflatten(
        opt.y_final, linear_op.coeffs_shape)

    return x_final, linear_op.transform
def test_gridsearch_single_channel(self):
    """Sanity-check the gridsearch script on a single-channel problem.

    Only verifies that ``launch_grid`` runs end to end and selects the
    expected case; it is not a check of reconstruction quality in general.
    """
    image = get_sample_data('2d-mri')
    # Fully sampled Cartesian grid expressed as non-Cartesian locations
    kspace_loc = convert_mask_to_locations(np.ones(image.shape))
    fourier_op = NonCartesianFFT(samples=kspace_loc, shape=image.shape)
    kspace_data = fourier_op.op(image.data)

    # Keyword dictionaries following the launch_grid convention
    ssim_details = dict(
        metric=ssim,
        mapping={'x_new': 'test', 'y_new': None},
        cst_kwargs={'ref': image, 'mask': None},
        early_stopping=True,
    )
    metrics = {'ssim': ssim_details}
    linear_params = dict(
        init_class=WaveletN,
        kwargs={'wavelet_name': 'sym8', 'nb_scale': 4},
    )
    regularizer_params = dict(
        init_class=SparseThreshold,
        kwargs={'linear': Identity(), 'weights': [0, 1e-5]},
    )
    # The optimizer dict has no init_class, only kwargs
    optimizer_params = dict(
        kwargs={
            'optimization_alg': 'fista',
            'num_iterations': 10,
            'metrics': metrics,
        },
    )

    raw_results, _cases, _names, best_idx = launch_grid(
        kspace_data=kspace_data,
        fourier_op=fourier_op,
        linear_params=linear_params,
        regularizer_params=regularizer_params,
        optimizer_params=optimizer_params,
        reconstructor_kwargs={'gradient_formulation': 'synthesis'},
        reconstructor_class=SingleChannelReconstructor,
        compare_metric_details={'metric': 'ssim'},
        n_jobs=self.n_jobs,
        verbose=1,
    )

    # k-space is not undersampled here, so the unregularized case
    # (weights=0, the first grid point) must win.
    np.testing.assert_equal(best_idx, 0)
    np.testing.assert_allclose(raw_results[best_idx][0], image, atol=1e-7)
def __init__(self, x, grad, prox, cost='auto', beta_param=1.0,
             lambda_param=1.0, beta_update=None, lambda_update='fista',
             auto_iterate=True, metric_call_period=5, metrics=None,
             linear=None):
    """Set up the forward-backward solver.

    Parameters
    ----------
    x : np.ndarray
        Initial guess for the primal variable
    grad : class instance
        Gradient operator
    prox : class instance
        Proximity operator
    cost : class instance or str, optional
        Cost function; 'auto' builds a costObj from grad and prox
    beta_param, lambda_param : float, optional
        Initial parameter values (default 1.0)
    beta_update : callable, optional
        Beta update method (default None)
    lambda_update : callable or 'fista', optional
        Lambda update method; 'fista' uses FISTA().update_lambda
    auto_iterate : bool, optional
        Run iterate() at the end of initialisation (default True)
    metric_call_period : int, optional
        Period at which metrics are evaluated; forwarded to the parent class
    metrics : dict, optional
        Convergence metrics (default None, i.e. no metrics)
    linear : class instance, optional
        Linear operator; required when metrics are given

    Raises
    ------
    ValueError
        If metrics is non-empty and linear is None
    """
    # A fresh dict per call instead of one shared mutable default.
    if metrics is None:
        metrics = {}

    # Set default algorithm properties
    super(ForwardBackward, self).__init__(
        metric_call_period=metric_call_period,
        metrics=metrics,
        linear=linear,
    )

    # Set the initial variable values
    self._check_input_data(x)
    self._x_old = np.copy(x)
    self._z_old = np.copy(x)

    # Set the algorithm operators. Plain loops so the checks execute; an
    # unconsumed generator expression would never call them.
    for operator in (grad, prox):
        self._check_operator(operator)

    self._grad = grad
    self._prox = prox
    self._linear = linear

    if cost == 'auto':
        self._cost_func = costObj([self._grad, self._prox])
    else:
        self._check_operator(cost)
        self._cost_func = cost

    # Check if there is a linear op, needed for metrics in the FB algorithm
    if metrics and self._linear is None:
        raise ValueError('When using metrics, you must pass a linear '
                         'operator')

    if self._linear is None:
        self._linear = Identity()

    # Set the algorithm parameters
    for param_val in (beta_param, lambda_param):
        self._check_param(param_val)
    self._beta = beta_param
    self._lambda = lambda_param

    # Set the algorithm parameter update methods
    if isinstance(lambda_update, str) and lambda_update == 'fista':
        self._lambda_update = FISTA().update_lambda
    else:
        self._check_param_update(lambda_update)
        self._lambda_update = lambda_update
    self._check_param_update(beta_update)
    self._beta_update = beta_update

    # Automatically run the algorithm
    if auto_iterate:
        self.iterate()
def sparse_rec_condatvu(gradient_op, linear_op, prox_dual_op, cost_op,
                        std_est=None, std_est_method=None, std_thr=2.,
                        mu=1e-6, tau=None, sigma=None, relaxation_factor=1.0,
                        nb_of_reweights=1, max_nb_of_iter=150,
                        add_positivity=False, metric_call_period=5,
                        metrics=None, verbose=False, progress=True):
    """ The Condat-Vu sparse reconstruction with reweightings.

    Parameters
    ----------
    gradient_op: instance of class GradBase
        the gradient operator.
    linear_op: instance of LinearBase
        the linear operator: seek the sparsity, ie. a wavelet transform.
    prox_dual_op: instance of ProximityParent
        the proximal dual operator; its ``weights`` attribute is overwritten.
    cost_op: instance of costObj
        the cost function used to check for convergence during the
        optimization.
    std_est: float, default None
        the noise std estimate.
        If None use the MAD as a consistent estimator for the std.
    std_est_method: str, default None
        None (manual regularization with weights = mu, no reweighting) or
        'dual' (estimate the std in the sparse wavelet decomposition domain).
    std_thr: float, default 2.
        use this threshold expressed as a number of sigma in the residual
        proximity operator during the thresholding.
    mu: float, default 1e-6
        regularization hyperparameter.
    tau, sigma: float, default None
        parameters of the Condat-Vu proximal-dual splitting algorithm.
        If None estimates these parameters.
    relaxation_factor: float, default 1.0
        parameter of the Condat-Vu proximal-dual splitting algorithm.
        If 1, no relaxation.
    nb_of_reweights: int, default 1
        the number of reweightings (forced to 0 when std_est_method is None).
    max_nb_of_iter: int, default 150
        the maximum number of iterations in the Condat-Vu proximal-dual
        splitting algorithm.
    add_positivity: bool, default False
        by setting this option, set the proximity operator to identity or
        positive.
    metric_call_period: int (default 5)
        the period on which the metrics are computed.
    metrics: dict (optional, default None)
        the list of desired convergence metrics:
        {'metric_name': [@metric, metric_parameter]}.
        See modopt for the metrics API.
    verbose: bool, default False
        the verbosity level.
    progress: bool, optional
        Activation key for progression bar displaying

    Returns
    -------
    x_final: np.ndarray((m,n)) or np.ndarray((m,n,p))
        the estimated CONDAT-VU solution.
    transform_output: a WaveletTransformBase derived instance or an array
        the wavelet transformation instance or the transformation
        coefficients.
    costs: list of float or None
        the cost function values (None if the cost object has no ``cost``).
    metrics: dict
        the requested metrics values during the optimization.

    Raises
    ------
    ValueError
        if std_est_method is neither None nor 'dual'.
    """
    # Check inputs
    start = time.perf_counter()
    if std_est_method not in (None, "dual"):
        raise ValueError(
            "Unrecognized std estimation method '{}'.".format(std_est_method))

    # Define the initial primal and dual solutions
    # (np.complex was only an alias of the builtin complex and has been
    # removed from NumPy.)
    x_init = np.zeros(gradient_op.fourier_op.shape, dtype=complex)
    weights = linear_op.op(x_init)

    # Define the weights used during the thresholding in the dual domain,
    # the reweighting strategy, and the prox dual operator

    # case1: estimate the noise std in the sparse wavelet domain
    if std_est_method == "dual":
        if std_est is None:
            std_est = 0.0
        weights[...] = std_thr * std_est
        reweight_op = mReweight(weights, linear_op, thresh_factor=std_thr)
        prox_dual_op.weights = reweight_op.weights

    # Case2: manual regularization mode, no reweighting
    else:
        weights[...] = mu
        reweight_op = None
        prox_dual_op.weights = weights
        nb_of_reweights = 0

    # Define the Condat Vu optimizer: define the tau and sigma in the
    # Condat-Vu proximal-dual splitting algorithm if not already provided.
    # Check also that the combination of values will lead to convergence.
    norm = linear_op.l2norm(gradient_op.fourier_op.shape)
    lipschitz_cst = gradient_op.spec_rad
    if sigma is None:
        sigma = 0.5
    if tau is None:
        # to avoid numerics troubles with the convergence bound
        eps = 1.0e-8
        # due to the convergence bound
        tau = 1.0 / (lipschitz_cst / 2 + sigma * norm ** 2 + eps)
    convergence_test = (
        1.0 / tau - sigma * norm ** 2 >= lipschitz_cst / 2.0)

    # Define initial primal and dual solutions
    primal = np.zeros(gradient_op.fourier_op.shape, dtype=complex)
    dual = linear_op.op(primal)
    dual[...] = 0.0

    # Welcome message
    if verbose:
        print(" - mu: ", mu)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - tau: ", tau)
        print(" - sigma: ", sigma)
        print(" - rho: ", relaxation_factor)
        print(" - std: ", std_est)
        print(" - 1/tau - sigma||L||^2 >= beta/2: ", convergence_test)
        print(" - data: ", gradient_op.fourier_op.shape)
        if hasattr(linear_op, "nb_scale"):
            print(" - wavelet: ", linear_op, "-", linear_op.nb_scale)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - number of reweights: ", nb_of_reweights)
        print(" - primal variable shape: ", primal.shape)
        print(" - dual variable shape: ", dual.shape)
        print("-" * 40)

    # Define the proximity operator
    if add_positivity:
        prox_op = Positivity()
    else:
        prox_op = Identity()

    # Define the optimizer
    opt = Condat(
        x=primal,
        y=dual,
        grad=gradient_op,
        prox=prox_op,
        prox_dual=prox_dual_op,
        linear=linear_op,
        cost=cost_op,
        rho=relaxation_factor,
        sigma=sigma,
        tau=tau,
        rho_update=None,
        sigma_update=None,
        tau_update=None,
        auto_iterate=False,
        metric_call_period=metric_call_period,
        metrics=metrics or {},
        progress=progress)
    # Use the optimizer's cost object (it may have built one from 'auto')
    cost_op = opt._cost_func

    # Perform the first reconstruction
    if verbose:
        print("Starting optimization...")
    opt.iterate(max_iter=max_nb_of_iter)

    # Loop through the number of reweightings
    for reweight_index in range(nb_of_reweights):

        # Generate the new weights following reweighting prescription
        std_est = reweight_op.reweight(opt._x_new)

        # Welcome message
        if verbose:
            print(" - reweight: ", reweight_index + 1)
            print(" - std: ", std_est)

        # Update the weights in the dual proximity operator
        prox_dual_op.weights = reweight_op.weights

        # Perform optimisation with new weights
        opt.iterate(max_iter=max_nb_of_iter)

    # Goodbye message
    end = time.perf_counter()
    if verbose:
        if hasattr(cost_op, "cost"):
            print(" - final iteration number: ", cost_op._iteration)
            print(" - final cost value: ", cost_op.cost)
        print(" - converged: ", opt.converge)
        print("Done.")
        print("Execution time: ", end - start, " seconds")
        print("-" * 40)

    # Get the final solution
    x_final = opt.x_final
    if hasattr(linear_op, "transform"):
        linear_op.transform.analysis_data = unflatten(
            opt.y_final, linear_op.coeffs_shape)
        transform_output = linear_op.transform
    else:
        linear_op.coeff = opt.y_final
        transform_output = linear_op.coeff

    if hasattr(cost_op, "cost"):
        costs = cost_op._cost_list
    else:
        costs = None

    return x_final, transform_output, costs, opt.metrics
def __init__(self, x, y, grad, prox, prox_dual, linear=None, cost='auto',
             reweight=None, rho=0.5, sigma=1.0, tau=1.0, rho_update=None,
             sigma_update=None, tau_update=None, auto_iterate=True,
             metric_call_period=5, metrics=None):
    """Initialise the Condat-Vu primal-dual optimiser.

    Parameters
    ----------
    x : numpy.ndarray
        Initial guess for the primal variable (a copy is stored).
    y : numpy.ndarray
        Initial guess for the dual variable (a copy is stored).
    grad : class instance
        Gradient operator class
    prox : class instance
        Proximity primal operator class
    prox_dual : class instance
        Proximity dual operator class
    linear : class instance, optional
        Linear operator class; ``None`` (default) is replaced by
        ``Identity()``
    cost : class or str, optional
        Cost function class; ``'auto'`` (default) builds a ``costObj`` from
        ``grad``, ``prox`` and ``prox_dual``
    reweight : class instance, optional
        Reweighting class (stored as-is)
    rho : float, optional
        Relaxation parameter (default is ``0.5``)
    sigma : float, optional
        Proximal dual parameter (default is ``1.0``)
    tau : float, optional
        Proximal primal parameter (default is ``1.0``)
    rho_update, sigma_update, tau_update : function, optional
        Parameter update methods (default is ``None``)
    auto_iterate : bool, optional
        Call ``self.iterate()`` at the end of initialisation (default is
        ``True``)
    metric_call_period : int, optional
        Forwarded to the parent class (default is ``5``)
    metrics : dict, optional
        Forwarded to the parent class; ``None`` (default) means an empty
        dict
    """
    # Use a fresh dict per call instead of a shared mutable default.
    if metrics is None:
        metrics = {}

    # Set default algorithm properties
    super(Condat, self).__init__(
        metric_call_period=metric_call_period,
        metrics=metrics,
    )

    # Set the initial variable values. Plain loops are required here: a
    # bare generator expression is never consumed, so its checks never run.
    for input_data in (x, y):
        self._check_input_data(input_data)
    self._x_old = np.copy(x)
    self._y_old = np.copy(y)

    # Set the algorithm operators
    for operator in (grad, prox, prox_dual, linear, cost):
        self._check_operator(operator)
    self._grad = grad
    self._prox = prox
    self._prox_dual = prox_dual
    self._reweight = reweight
    if linear is None:
        self._linear = Identity()
    else:
        self._linear = linear
    if cost == 'auto':
        self._cost_func = costObj(
            [self._grad, self._prox, self._prox_dual])
    else:
        self._cost_func = cost

    # Set the algorithm parameters
    for param_val in (rho, sigma, tau):
        self._check_param(param_val)
    self._rho = rho
    self._sigma = sigma
    self._tau = tau

    # Set the algorithm parameter update methods
    for param_update in (rho_update, sigma_update, tau_update):
        self._check_param_update(param_update)
    self._rho_update = rho_update
    self._sigma_update = sigma_update
    self._tau_update = tau_update

    # Automatically run the algorithm
    if auto_iterate:
        self.iterate()
def condatvu_online(
    kspace_generator,
    gradient_op,
    linear_op,
    prox_op,
    cost_op,
    max_nb_of_iter=150,
    tau=None,
    sigma=None,
    relaxation_factor=1.0,
    x_init=None,
    std_est=None,
    nb_run=1,
    metric_call_period=5,
    metrics=None,
    estimate_call_period=None,
    verbose=0,
):
    """The online Condat-Vu sparse reconstruction.

    Parameters
    ----------
    kspace_generator: instance of class KspaceGenerator
        the observed data (ie kspace) generated for each iteration of the
        algorithm; forwarded to ``online_algorithm``.
    gradient_op: instance of class GradBase
        the gradient operator.
    linear_op: instance of LinearBase
        the linear operator: seek the sparsity, ie. a wavelet transform.
    prox_op: instance of ProximityParent
        the dual regularization operator.
    cost_op: instance of costObj
        the cost function used to check for convergence during the
        optimization.
    max_nb_of_iter: int, default 150
        only reported in the verbose output by this function; it is not
        forwarded to ``online_algorithm``.
    tau, sigma: float, default None
        parameters of the Condat-Vu proximal-dual splitting algorithm.
        If None, sigma defaults to 0.5 and tau is estimated from the
        convergence bound 1/tau - sigma * ||L||^2 >= lipschitz / 2.
    relaxation_factor: float, default 1.0
        parameter of the Condat-Vu proximal-dual splitting algorithm.
        If 1, no relaxation.
    x_init: np.ndarray (optional, default None)
        the initial guess of image; zeros of shape
        ``(n_coils, *fourier_op.shape)`` (squeezed) when None.
    std_est: float, default None
        the noise std estimate; only reported in the verbose output.
    nb_run: int, default 1
        forwarded to ``online_algorithm``.
    metric_call_period: int (default 5)
        the period on which the metrics are computed.
    metrics: dict (optional, default None)
        the list of desired convergence metrics:
        {'metric_name': [@metric, metric_parameter]}.
        See modopt for the metrics API.
    estimate_call_period: int (optional, default None)
        forwarded to ``online_algorithm``.
    verbose: int, default 0
        the verbosity level.

    Returns
    -------
    Whatever ``online_algorithm(opt, kspace_generator, ...)`` returns.
    """
    # Check inputs
    if metrics is None:
        metrics = {}

    # Define the initial primal and dual solutions
    if x_init is None:
        x_init = np.squeeze(
            np.zeros((gradient_op.fourier_op.n_coils,
                      *gradient_op.fourier_op.shape),
                     dtype=np.complex128))
    primal = x_init
    dual = linear_op.op(primal)

    # Define the Condat Vu optimizer: define the tau and sigma in the
    # Condat-Vu proximal-dual splitting algorithm if not already provided.
    # Check also that the combination of values will lead to convergence.
    norm = linear_op.l2norm(x_init.shape)
    lipschitz_cst = gradient_op.spec_rad
    if sigma is None:
        sigma = 0.5
    if tau is None:
        # to avoid numerics troubles with the convergence bound
        eps = 1.0e-8
        # due to the convergence bound
        tau = 1.0 / (lipschitz_cst / 2 + sigma * norm**2 + eps)
    convergence_test = (1.0 / tau - sigma * norm**2 >= lipschitz_cst / 2.0)

    # Welcome message
    if verbose > 0:
        print(" - mu: ", prox_op.weights)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - tau: ", tau)
        print(" - sigma: ", sigma)
        print(" - rho: ", relaxation_factor)
        print(" - std: ", std_est)
        print(" - 1/tau - sigma||L||^2 >= beta/2: ", convergence_test)
        print(" - data: ", gradient_op.fourier_op.shape)
        if hasattr(linear_op, "nb_scale"):
            print(" - wavelet: ", linear_op, "-", linear_op.nb_scale)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - primal variable shape: ", primal.shape)
        print(" - dual variable shape: ", dual.shape)
        print("-" * 40)

    prox_primal = Identity()

    # Define the optimizer. The step sizes computed (or supplied) above are
    # now actually passed on; previously they were only printed and Condat
    # ran with its own defaults. They are cast to the built-in float since
    # the estimates may be NumPy scalars -- NOTE(review): presumably
    # Condat's parameter check expects a float; confirm.
    opt = Condat(
        x=primal,
        y=dual,
        grad=gradient_op,
        prox=prox_primal,
        prox_dual=prox_op,
        linear=linear_op,
        cost=cost_op,
        rho=relaxation_factor,
        sigma=float(sigma),
        tau=float(tau),
        rho_update=None,
        sigma_update=None,
        tau_update=None,
        auto_iterate=False,
        metric_call_period=metric_call_period,
        metrics=metrics)
    return online_algorithm(opt, kspace_generator,
                            estimate_call_period=estimate_call_period,
                            nb_run=nb_run)
class Condat(SetUp):
    """Condat optimisation.

    This class implements algorithm 3.1 from :cite:`condat2013`

    Parameters
    ----------
    x : numpy.ndarray
        Initial guess for the primal variable
    y : numpy.ndarray
        Initial guess for the dual variable
    grad : class instance
        Gradient operator class
    prox : class instance
        Proximity primal operator class
    prox_dual : class instance
        Proximity dual operator class
    linear : class instance, optional
        Linear operator class (default is ``None``, replaced by
        ``Identity()``)
    cost : class or str, optional
        Cost function class (default is 'auto'); Use 'auto' to
        automatically generate a costObj instance
    reweight : class instance, optional
        Reweighting class
    rho : float, optional
        Relaxation parameter (default is ``0.5``)
    sigma : float, optional
        Proximal dual parameter (default is ``1.0``)
    tau : float, optional
        Proximal primal parameter (default is ``1.0``)
    rho_update : function, optional
        Relaxation parameter update method (default is ``None``)
    sigma_update : function, optional
        Proximal dual parameter update method (default is ``None``)
    tau_update : function, optional
        Proximal primal parameter update method (default is ``None``)
    auto_iterate : bool, optional
        Option to automatically begin iterations upon initialisation
        (default is ``True``)
    max_iter : int, optional
        Maximum number of iterations (default is ``150``)
    n_rewightings : int, optional
        Number of reweightings to perform (default is ``1``)

    Notes
    -----
    The `tau` parameter can also be set using the keyword `step_size`
    (passed through ``**kwargs`` to the parent class); a truthy
    `step_size` overrides the value of `tau`.

    See Also
    --------
    SetUp : parent class

    """

    def __init__(
        self,
        x,
        y,
        grad,
        prox,
        prox_dual,
        linear=None,
        cost='auto',
        reweight=None,
        rho=0.5,
        sigma=1.0,
        tau=1.0,
        rho_update=None,
        sigma_update=None,
        tau_update=None,
        auto_iterate=True,
        max_iter=150,
        n_rewightings=1,
        metric_call_period=5,
        metrics=None,
        **kwargs,
    ):

        # Set default algorithm properties
        super().__init__(
            metric_call_period=metric_call_period,
            metrics=metrics,
            **kwargs,
        )

        # Set the initial variable values
        for input_data in (x, y):
            self._check_input_data(input_data)

        self._x_old = self.xp.copy(x)
        self._y_old = self.xp.copy(y)

        # Set the algorithm operators
        for operator in (grad, prox, prox_dual, linear, cost):
            self._check_operator(operator)

        self._grad = grad
        self._prox = prox
        self._prox_dual = prox_dual
        self._reweight = reweight
        if linear is None:
            self._linear = Identity()
        else:
            self._linear = linear
        if cost == 'auto':
            self._cost_func = costObj([
                self._grad,
                self._prox,
                self._prox_dual,
            ])
        else:
            self._cost_func = cost

        # Set the algorithm parameters
        for param_val in (rho, sigma, tau):
            self._check_param(param_val)

        self._rho = rho
        self._sigma = sigma
        # ``step_size`` (set by the parent from **kwargs) wins when truthy.
        self._tau = self.step_size or tau

        # Set the algorithm parameter update methods
        for param_update in (rho_update, sigma_update, tau_update):
            self._check_param_update(param_update)

        self._rho_update = rho_update
        self._sigma_update = sigma_update
        self._tau_update = tau_update

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate(max_iter=max_iter, n_rewightings=n_rewightings)

    def _update_param(self):
        """Update parameters.

        This method updates the values of the algorithm parameters with the
        update methods provided (each one is skipped when ``None``).

        """
        # Update relaxation parameter.
        if self._rho_update is not None:
            self._rho = self._rho_update(self._rho)

        # Update proximal dual parameter.
        if self._sigma_update is not None:
            self._sigma = self._sigma_update(self._sigma)

        # Update proximal primal parameter.
        if self._tau_update is not None:
            self._tau = self._tau_update(self._tau)

    def _update(self):
        """Update.

        This method updates the current reconstruction

        Notes
        -----
        Implements equation 9 (algorithm 3.1) from :cite:`condat2013`

        - the primal step applies whichever ``prox`` operator was supplied
          at construction

        """
        # Step 1 from eq.9.
        self._grad.get_grad(self._x_old)

        x_prox = self._prox.op(
            self._x_old - self._tau * self._grad.grad - self._tau
            * self._linear.adj_op(self._y_old),
        )

        # Step 2 from eq.9.
        y_temp = (self._y_old + self._sigma
                  * self._linear.op(2 * x_prox - self._x_old))

        # Dual prox computed from the primal prox via the Moreau identity.
        y_prox = (y_temp - self._sigma * self._prox_dual.op(
            y_temp / self._sigma,
            extra_factor=(1.0 / self._sigma),
        ))

        # Step 3 from eq.9.
        self._x_new = self._rho * x_prox + (1 - self._rho) * self._x_old
        self._y_new = self._rho * y_prox + (1 - self._rho) * self._y_old

        del x_prox, y_prox, y_temp

        # Update old values for next iteration.
        self.xp.copyto(self._x_old, self._x_new)
        self.xp.copyto(self._y_old, self._y_new)

        # Update parameter values for next iteration.
        self._update_param()

        # Test cost function for convergence.
        if self._cost_func:
            self.converge = (
                self.any_convergence_flag()
                or self._cost_func.get_cost(self._x_new, self._y_new)
            )

    def iterate(self, max_iter=150, n_rewightings=1):
        """Iterate.

        This method calls update until either convergence criteria is met or
        the maximum number of iterations is reached. When a reweighting
        object was supplied, the run is repeated ``n_rewightings`` times
        after reweighting on ``linear.op(x_new)``.

        Parameters
        ----------
        max_iter : int, optional
            Maximum number of iterations (default is ``150``)
        n_rewightings : int, optional
            Number of reweightings to perform (default is ``1``)

        """
        self._run_alg(max_iter)

        if self._reweight is not None:
            for _ in range(n_rewightings):
                self._reweight.reweight(self._linear.op(self._x_new))
                self._run_alg(max_iter)

        # retrieve metrics results
        self.retrieve_outputs()
        # rename outputs as attributes
        self.x_final = self._x_new
        self.y_final = self._y_new

    def get_notify_observers_kwargs(self):
        """Notify observers.

        Return the mapping between the metrics call and the iterated
        variables.

        Returns
        -------
        notify_observers_kwargs : dict,
           The mapping between the iterated variables

        """
        return {'x_new': self._x_new, 'y_new': self._y_new, 'idx': self.idx}

    def retrieve_outputs(self):
        """Retrieve outputs.

        Collect the values recorded by every ``cv_metrics`` observer into
        the ``metrics`` attribute (``x_final`` and ``y_final`` are set by
        :meth:`iterate`).

        """
        metrics = {}
        for obs in self._observers['cv_metrics']:
            metrics[obs.name] = obs.retrieve_metrics()
        self.metrics = metrics
#############################################################################
# FISTA optimization
# ------------------
#
# We now refine the zero-order solution with a FISTA optimization. The cost
# function is the sum of the proximity cost and the gradient cost.

# Sparsifying transform: a 3D 'sym8' wavelet over 4 scales, with periodic
# boundary handling.
linear_op = WaveletN(
    wavelet_name="sym8",
    dim=3,
    nb_scales=4,
    padding_mode="periodization",
)
# Soft thresholding applied directly (Identity) to the coefficients.
regularizer_op = SparseThreshold(
    linear=Identity(),
    weights=2 * 1e-11,
    thresh_type="soft",
)

# Reconstructor in the synthesis formulation.
reconstructor = SingleChannelReconstructor(
    fourier_op=fourier_op,
    linear_op=linear_op,
    regularizer_op=regularizer_op,
    verbose=1,
    gradient_formulation='synthesis',
)

# Run 200 FISTA iterations and keep the magnitude image.
x_final, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_data,
    num_iterations=200,
    optimization_alg='fista',
)
image_rec = pysap.Image(data=np.abs(x_final))
def __init__(
    self,
    x,
    y,
    grad,
    prox,
    prox_dual,
    linear=None,
    cost='auto',
    reweight=None,
    rho=0.5,
    sigma=1.0,
    tau=1.0,
    rho_update=None,
    sigma_update=None,
    tau_update=None,
    auto_iterate=True,
    max_iter=150,
    n_rewightings=1,
    metric_call_period=5,
    metrics=None,
    **kwargs,
):
    """Build the Condat-Vu primal-dual solver.

    The inputs are validated, the starting points are copied, and the
    operators, step parameters and their update rules are stored. When
    ``auto_iterate`` is true, ``iterate`` runs immediately with
    ``max_iter`` and ``n_rewightings``. ``**kwargs`` go to the parent
    constructor; a truthy ``self.step_size`` set there takes precedence
    over ``tau``.
    """
    super().__init__(
        metric_call_period=metric_call_period,
        metrics=metrics,
        **kwargs,
    )

    # Validate both starting points, then keep private copies of them.
    self._check_input_data(x)
    self._check_input_data(y)
    self._x_old = self.xp.copy(x)
    self._y_old = self.xp.copy(y)

    # Validate every operator (linear and cost included) before storing.
    for candidate in (grad, prox, prox_dual, linear, cost):
        self._check_operator(candidate)
    self._grad, self._prox, self._prox_dual = grad, prox, prox_dual
    self._reweight = reweight
    self._linear = Identity() if linear is None else linear
    if cost != 'auto':
        self._cost_func = cost
    else:
        self._cost_func = costObj([
            self._grad,
            self._prox,
            self._prox_dual,
        ])

    # Step parameters.
    for value in (rho, sigma, tau):
        self._check_param(value)
    self._rho, self._sigma = rho, sigma
    self._tau = self.step_size or tau

    # Optional per-iteration update rules for the step parameters.
    for rule in (rho_update, sigma_update, tau_update):
        self._check_param_update(rule)
    self._rho_update, self._sigma_update, self._tau_update = (
        rho_update, sigma_update, tau_update,
    )

    if not auto_iterate:
        return
    self.iterate(max_iter=max_iter, n_rewightings=n_rewightings)
def condatvu(gradient_op, linear_op, dual_regularizer, cost_op,
             max_nb_of_iter=150, tau=None, sigma=None, relaxation_factor=1.0,
             x_init=None, std_est=None, std_est_method=None, std_thr=2.,
             nb_of_reweights=1, metric_call_period=5, metrics=None,
             verbose=0):
    """The Condat-Vu sparse reconstruction with reweightings.

    Parameters
    ----------
    gradient_op: instance of class GradBase
        the gradient operator.
    linear_op: instance of LinearBase
        the linear operator: seek the sparsity, ie. a wavelet transform.
    dual_regularizer: instance of ProximityParent
        the dual regularization operator; its ``weights`` attribute is
        overwritten when ``std_est_method`` is "primal" or "dual".
    cost_op: instance of costObj
        the cost function used to check for convergence during the
        optimization.
    max_nb_of_iter: int, default 150
        the maximum number of iterations of each Condat-Vu run (the first
        run and every reweighting run).
    tau, sigma: float, default None
        parameters of the Condat-Vu proximal-dual splitting algorithm.
        If None, sigma defaults to 0.5 and tau is estimated from the
        convergence bound 1/tau - sigma * ||L||^2 >= lipschitz / 2.
    relaxation_factor: float, default 1.0
        parameter of the Condat-Vu proximal-dual splitting algorithm.
        If 1, no relaxation.
    x_init: np.ndarray (optional, default None)
        the initial guess of image; zeros of shape
        ``(n_coils, *fourier_op.shape)`` (squeezed) when None.
    std_est: float, default None
        the noise std estimate. If None use the MAD as a consistent
        estimator for the std.
    std_est_method: str, default None
        if the standard deviation is not set, estimate this parameter using
        the mad routine in the image ('primal') or in the sparse wavelet
        decomposition ('dual') domain. With None, no reweighting is done.
    std_thr: float, default 2.
        use this threshold expressed as a number of sigma in the residual
        proximity operator during the thresholding.
    nb_of_reweights: int, default 1
        the number of reweightings (forced to 0 when ``std_est_method`` is
        None).
    metric_call_period: int (default 5)
        the period on which the metrics are computed.
    metrics: dict (optional, default None)
        the list of desired convergence metrics:
        {'metric_name': [@metric, metric_parameter]}.
        See modopt for the metrics API. None means no metrics.
    verbose: int, default 0
        the verbosity level.

    Returns
    -------
    x_final: ndarray
        the estimated CONDAT-VU solution.
    costs: list of float or None
        the cost function values (None if the cost object has no ``cost``).
    metrics: dict
        the requested metrics values during the optimization.
    y_final: ndarray
        the estimated dual CONDAT-VU solution

    Raises
    ------
    ValueError
        if ``std_est_method`` is not None, "primal" or "dual".
    """
    # Check inputs. time.clock() no longer exists (removed in Python 3.8).
    start = time.perf_counter()
    if metrics is None:
        metrics = {}
    if std_est_method not in (None, "primal", "dual"):
        raise ValueError(
            "Unrecognize std estimation method '{0}'.".format(std_est_method))

    # Define the initial primal and dual solutions
    # (np.complex was removed from NumPy; complex128 is the same dtype).
    if x_init is None:
        x_init = np.squeeze(
            np.zeros((linear_op.n_coils, *gradient_op.fourier_op.shape),
                     dtype=np.complex128))
    primal = x_init
    dual = linear_op.op(primal)

    # Define the weights used during the thresholding in the dual domain,
    # the reweighting strategy, and the prox dual operator. The weights are
    # a separate array shaped like the dual variable: filling an alias of
    # ``dual`` would silently overwrite the dual starting point.
    # Case1: estimate the noise std in the image domain
    if std_est_method == "primal":
        if std_est is None:
            std_est = sigma_mad(gradient_op.MtX(gradient_op.obs_data))
        weights = np.full_like(dual, std_thr * std_est)
        reweight_op = cwbReweight(weights)
        dual_regularizer.weights = reweight_op.weights

    # Case2: estimate the noise std in the sparse wavelet domain
    elif std_est_method == "dual":
        if std_est is None:
            std_est = 0.0
        weights = np.full_like(dual, std_thr * std_est)
        reweight_op = mReweight(weights, linear_op, thresh_factor=std_thr)
        dual_regularizer.weights = reweight_op.weights

    # Case3: manual regularization mode, no reweighting
    else:
        reweight_op = None
        nb_of_reweights = 0

    # Define the Condat Vu optimizer: define the tau and sigma in the
    # Condat-Vu proximal-dual splitting algorithm if not already provided.
    # Check also that the combination of values will lead to convergence.
    norm = linear_op.l2norm(x_init.shape)
    lipschitz_cst = gradient_op.spec_rad
    if sigma is None:
        sigma = 0.5
    if tau is None:
        # to avoid numerics troubles with the convergence bound
        eps = 1.0e-8
        # due to the convergence bound
        tau = 1.0 / (lipschitz_cst / 2 + sigma * norm**2 + eps)
    convergence_test = (1.0 / tau - sigma * norm**2 >= lipschitz_cst / 2.0)

    # Welcome message
    if verbose > 0:
        print(" - mu: ", dual_regularizer.weights)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - tau: ", tau)
        print(" - sigma: ", sigma)
        print(" - rho: ", relaxation_factor)
        print(" - std: ", std_est)
        print(" - 1/tau - sigma||L||^2 >= beta/2: ", convergence_test)
        print(" - data: ", gradient_op.fourier_op.shape)
        if hasattr(linear_op, "nb_scale"):
            print(" - wavelet: ", linear_op, "-", linear_op.nb_scale)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - number of reweights: ", nb_of_reweights)
        print(" - primal variable shape: ", primal.shape)
        print(" - dual variable shape: ", dual.shape)
        print("-" * 40)

    prox_op = Identity()

    # Define the optimizer
    opt = Condat(x=primal, y=dual, grad=gradient_op, prox=prox_op,
                 prox_dual=dual_regularizer, linear=linear_op,
                 cost=cost_op, rho=relaxation_factor, sigma=sigma, tau=tau,
                 rho_update=None, sigma_update=None, tau_update=None,
                 auto_iterate=False, metric_call_period=metric_call_period,
                 metrics=metrics)
    cost_op = opt._cost_func

    # Perform the first reconstruction
    if verbose > 0:
        print("Starting optimization...")
    opt.iterate(max_iter=max_nb_of_iter)

    # Loop through the number of reweightings
    for reweight_index in range(nb_of_reweights):

        # Generate the new weights following reweighting prescription
        if std_est_method == "primal":
            reweight_op.reweight(linear_op.op(opt._x_new))
        else:
            std_est = reweight_op.reweight(opt._x_new)

        # Welcome message
        if verbose > 0:
            print(" - reweight: ", reweight_index + 1)
            print(" - std: ", std_est)

        # Update the weights in the dual proximity operator
        dual_regularizer.weights = reweight_op.weights

        # Perform optimisation with new weights
        opt.iterate(max_iter=max_nb_of_iter)

    # Goodbye message
    end = time.perf_counter()
    if verbose > 0:
        if hasattr(cost_op, "cost"):
            print(" - final iteration number: ", cost_op._iteration)
            print(" - final cost value: ", cost_op.cost)
        print(" - converged: ", opt.converge)
        print("Done.")
        print("Execution time: ", end - start, " seconds")
        print("-" * 40)

    # Get the final solution
    x_final = opt.x_final
    y_final = opt.y_final
    if hasattr(cost_op, "cost"):
        costs = cost_op._cost_list
    else:
        costs = None

    return x_final, costs, opt.metrics, y_final
'mask': None }, 'early_stopping': True, }, } linear_params = { 'init_class': WaveletN, 'kwargs': { 'wavelet_name': ['sym8', 'sym12'], 'nb_scale': [3, 4] } } regularizer_params = { 'init_class': SparseThreshold, 'kwargs': { 'linear': Identity(), 'weights': np.logspace(-8, -6, 5), } } optimizer_params = { # Just following convention 'kwargs': { 'optimization_alg': 'fista', 'num_iterations': 20, 'metrics': metrics, } } # Call the launch grid function and obtain results raw_results, test_cases, key_names, best_idx = launch_grid( kspace_data=kspace_data, fourier_op=fourier_op,