def get_operators(
    kspace_data,
    loc,
    mask,
    fourier_type=1,
    max_iter=80,
    regularisation=None,
    linear=None,
):
    """Create the various operators from the config file.

    Parameters
    ----------
    kspace_data : numpy.ndarray
        Full k-space data. A 2D array is treated as single-coil data;
        otherwise the first axis is used as the number of coils. The last
        two axes give the shape handed to the Fourier operator.
    loc : object
        Forwarded as ``mask_cols`` to the online k-space generators
        (``fourier_type`` 1 and 2).
    mask : object
        Forwarded as ``mask`` to ``KspaceGeneratorBase`` and ``FFT``.
    fourier_type : int, optional
        ``0`` for offline reconstruction, ``1`` for online type I
        reconstruction, ``2`` for online type II reconstruction
        (default is ``1``).
    max_iter : int, optional
        Forwarded to ``KspaceGeneratorBase``; only used when
        ``fourier_type`` is ``0`` (default is ``80``).
    regularisation : dict, optional
        Configuration of the proximity operator. Must hold a ``"class"``
        key among ``"LASSO"``, ``"GroupLASSO"``, ``"OWL"`` and
        ``"IdentityProx"``; any other value leaves the default
        ``IdentityProx`` in place. ``"LASSO"`` and ``"GroupLASSO"`` read the
        ``"weights"`` key, ``"OWL"`` receives all remaining keys as keyword
        arguments (default is ``None``, giving ``IdentityProx``).
    linear : dict, optional
        Configuration of the linear operator. Its ``"class"`` key must be
        ``"WaveletN"`` (remaining keys are passed as keyword arguments) or
        ``"Identity"`` (default is ``None``, giving ``Identity``).

    Returns
    -------
    tuple
        ``(kspace_generator, fourier_op, linear_op, prox_op)``.

    Raises
    ------
    NotImplementedError
        If ``fourier_type`` or the linear operator class is not supported.
    KeyError
        If ``regularisation`` is given without a ``"class"`` key.

    Notes
    -----
    The ``linear`` and ``regularisation`` dictionaries are copied before
    being consumed, so the caller's configuration is left untouched and can
    be reused for further calls.
    """
    # 2D data is single-coil; otherwise the leading axis counts the coils.
    n_coils = 1 if kspace_data.ndim == 2 else kspace_data.shape[0]
    shape = kspace_data.shape[-2:]

    if fourier_type == 0:  # offline reconstruction
        kspace_generator = KspaceGeneratorBase(
            full_kspace=kspace_data, mask=mask, max_iter=max_iter
        )
        fourier_op = FFT(shape=shape, n_coils=n_coils, mask=mask)
    elif fourier_type == 1:  # online type I reconstruction
        kspace_generator = Column2DKspaceGenerator(
            full_kspace=kspace_data, mask_cols=loc
        )
        fourier_op = FFT(shape=shape, n_coils=n_coils, mask=mask)
    elif fourier_type == 2:  # online type II reconstruction
        kspace_generator = DataOnlyKspaceGenerator(
            full_kspace=kspace_data, mask_cols=loc
        )
        fourier_op = ColumnFFT(shape=shape, n_coils=n_coils)
    else:
        raise NotImplementedError(
            f"Unsupported fourier_type: {fourier_type!r}"
        )

    if linear is None:
        linear_op = Identity()
    else:
        # Shallow copy: popping "class" must not alter the caller's dict.
        linear = dict(linear)
        lin_cls = linear.pop("class", None)
        if lin_cls == "WaveletN":
            linear_op = WaveletN(n_coils=n_coils, n_jobs=4, **linear)
            # Apply the transform once on zeros; presumably this initialises
            # attributes such as ``coeffs_shape`` read below -- TODO confirm.
            linear_op.op(np.zeros_like(kspace_data))
        elif lin_cls == "Identity":
            linear_op = Identity()
        else:
            raise NotImplementedError(
                f"Unsupported linear operator class: {lin_cls!r}"
            )

    prox_op = IdentityProx()
    if regularisation is not None:
        # Shallow copy for the same reason as ``linear`` above.
        regularisation = dict(regularisation)
        reg_cls = regularisation.pop("class")
        if reg_cls == "LASSO":
            prox_op = LASSO(weights=regularisation["weights"])
        elif reg_cls == "GroupLASSO":
            prox_op = GroupLASSO(weights=regularisation["weights"])
        elif reg_cls == "OWL":
            prox_op = OWL(
                **regularisation,
                n_coils=n_coils,
                bands_shape=linear_op.coeffs_shape,
            )
        elif reg_cls == "IdentityProx":
            # No regularisation: the linear operator is reset to Identity too.
            prox_op = IdentityProx()
            linear_op = Identity()
        # Any other class name keeps the default IdentityProx set above.

    return kspace_generator, fourier_op, linear_op, prox_op
class Condat(SetUp):
    """Condat optimisation.

    This class implements algorithm 3.1 from :cite:`condat2013`

    Parameters
    ----------
    x : numpy.ndarray
        Initial guess for the primal variable
    y : numpy.ndarray
        Initial guess for the dual variable
    grad : class instance
        Gradient operator class
    prox : class instance
        Proximity primal operator class
    prox_dual : class instance
        Proximity dual operator class
    linear : class instance, optional
        Linear operator class (default is ``None``, in which case an
        ``Identity`` operator is used)
    cost : class or str, optional
        Cost function class (default is 'auto'); Use 'auto' to automatically
        generate a costObj instance
    reweight : class instance, optional
        Reweighting class
    rho : float, optional
        Relaxation parameter (default is ``0.5``)
    sigma : float, optional
        Proximal dual parameter (default is ``1.0``)
    tau : float, optional
        Proximal primal parameter (default is ``1.0``)
    rho_update : function, optional
        Relaxation parameter update method (default is ``None``)
    sigma_update : function, optional
        Proximal dual parameter update method (default is ``None``)
    tau_update : function, optional
        Proximal primal parameter update method (default is ``None``)
    auto_iterate : bool, optional
        Option to automatically begin iterations upon initialisation (default
        is ``True``)
    max_iter : int, optional
        Maximum number of iterations (default is ``150``)
    n_rewightings : int, optional
        Number of reweightings to perform (default is ``1``)
    metric_call_period : int, optional
        Forwarded to the parent ``SetUp`` initialiser (default is ``5``)
    metrics : dict, optional
        Forwarded to the parent ``SetUp`` initialiser (default is ``None``)
    **kwargs
        Additional keyword arguments forwarded to the parent ``SetUp``
        initialiser

    Notes
    -----
    The value of ``tau`` is overridden by the ``step_size`` attribute
    whenever that attribute is truthy (``step_size`` is presumably set by
    the parent class from the keyword of the same name).

    See Also
    --------
    SetUp : parent class

    """

    def __init__(
        self,
        x,
        y,
        grad,
        prox,
        prox_dual,
        linear=None,
        cost='auto',
        reweight=None,
        rho=0.5,
        sigma=1.0,
        tau=1.0,
        rho_update=None,
        sigma_update=None,
        tau_update=None,
        auto_iterate=True,
        max_iter=150,
        n_rewightings=1,
        metric_call_period=5,
        metrics=None,
        **kwargs,
    ):

        # Set default algorithm properties
        super().__init__(
            metric_call_period=metric_call_period,
            metrics=metrics,
            **kwargs,
        )

        # Set the initial variable values
        for input_data in (x, y):
            self._check_input_data(input_data)

        self._x_old = self.xp.copy(x)
        self._y_old = self.xp.copy(y)

        # Set the algorithm operators
        for operator in (grad, prox, prox_dual, linear, cost):
            self._check_operator(operator)

        self._grad = grad
        self._prox = prox
        self._prox_dual = prox_dual
        self._reweight = reweight
        self._linear = Identity() if linear is None else linear
        if cost == 'auto':
            self._cost_func = costObj([
                self._grad,
                self._prox,
                self._prox_dual,
            ])
        else:
            self._cost_func = cost

        # Set the algorithm parameters
        for param_val in (rho, sigma, tau):
            self._check_param(param_val)

        self._rho = rho
        self._sigma = sigma
        # A truthy step_size takes precedence over tau.
        self._tau = self.step_size or tau

        # Set the algorithm parameter update methods
        for param_update in (rho_update, sigma_update, tau_update):
            self._check_param_update(param_update)

        self._rho_update = rho_update
        self._sigma_update = sigma_update
        self._tau_update = tau_update

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate(max_iter=max_iter, n_rewightings=n_rewightings)

    def _update_param(self):
        """Update parameters.

        This method updates the values of the algorithm parameters with the
        methods provided. Parameters without an update method are left
        unchanged.

        """
        # Update relaxation parameter.
        if self._rho_update is not None:
            self._rho = self._rho_update(self._rho)

        # Update proximal dual parameter.
        if self._sigma_update is not None:
            self._sigma = self._sigma_update(self._sigma)

        # Update proximal primal parameter.
        if self._tau_update is not None:
            self._tau = self._tau_update(self._tau)

    def _update(self):
        """Update.

        This method updates the current reconstruction

        Notes
        -----
        Implements equation 9 (algorithm 3.1) from :cite:`condat2013`

        - primal proximity operator set up for positivity constraint

        """
        # Step 1 from eq.9.
        self._grad.get_grad(self._x_old)

        x_prox = self._prox.op(
            self._x_old
            - self._tau * self._grad.grad
            - self._tau * self._linear.adj_op(self._y_old),
        )

        # Step 2 from eq.9.
        y_temp = (
            self._y_old
            + self._sigma * self._linear.op(2 * x_prox - self._x_old)
        )

        y_prox = (
            y_temp
            - self._sigma * self._prox_dual.op(
                y_temp / self._sigma,
                extra_factor=(1.0 / self._sigma),
            )
        )

        # Step 3 from eq.9.
        self._x_new = self._rho * x_prox + (1 - self._rho) * self._x_old
        self._y_new = self._rho * y_prox + (1 - self._rho) * self._y_old

        del x_prox, y_prox, y_temp

        # Update old values for next iteration.
        self.xp.copyto(self._x_old, self._x_new)
        self.xp.copyto(self._y_old, self._y_new)

        # Update parameter values for next iteration.
        self._update_param()

        # Test cost function for convergence.
        if self._cost_func:
            self.converge = (
                self.any_convergence_flag()
                or self._cost_func.get_cost(self._x_new, self._y_new)
            )

    def iterate(self, max_iter=150, n_rewightings=1):
        """Iterate.

        This method calls update until either convergence criteria is met or
        the maximum number of iterations is reached. When a reweighting
        instance was provided, the algorithm is run again after each
        reweighting. The final iterates are then exposed as ``x_final`` and
        ``y_final``.

        Parameters
        ----------
        max_iter : int, optional
            Maximum number of iterations (default is ``150``)
        n_rewightings : int, optional
            Number of reweightings to perform (default is ``1``)

        """
        self._run_alg(max_iter)

        if self._reweight is not None:
            for _ in range(n_rewightings):
                self._reweight.reweight(self._linear.op(self._x_new))
                self._run_alg(max_iter)

        # retrieve metrics results
        self.retrieve_outputs()
        # rename outputs as attributes
        self.x_final = self._x_new
        self.y_final = self._y_new

    def get_notify_observers_kwargs(self):
        """Notify observers.

        Return the mapping between the metrics call and the iterated
        variables.

        Returns
        -------
        notify_observers_kwargs : dict,
           The mapping between the iterated variables

        """
        return {'x_new': self._x_new, 'y_new': self._y_new, 'idx': self.idx}

    def retrieve_outputs(self):
        """Retrieve outputs.

        Collect the results of every observer registered under
        ``'cv_metrics'`` and store them, keyed by observer name, in the
        ``metrics`` attribute.

        """
        self.metrics = {
            obs.name: obs.retrieve_metrics()
            for obs in self._observers['cv_metrics']
        }