Example No. 1
    def __init__(self,
                 u,
                 x,
                 y,
                 z,
                 grad,
                 prox,
                 cost='auto',
                 linear=None,
                 beta_param=1.0,
                 sigma_bar=1.0,
                 auto_iterate=True,
                 metric_call_period=5,
                 metrics={},
                 **kwargs):

        # Set default algorithm properties
        super(POGM, self).__init__(metric_call_period=metric_call_period,
                                   metrics=metrics,
                                   linear=linear,
                                   **kwargs)

        # Set the initial variable values
        for input_data in (u, x, y, z):
            self._check_input_data(input_data)
        self._u_old = np.copy(u)
        self._x_old = np.copy(x)
        self._y_old = np.copy(y)
        self._z = np.copy(z)

        # Set the algorithm operators
        for operator in (grad, prox, cost):
            self._check_operator(operator)
        self._grad = grad
        self._prox = prox
        self._linear = linear
        if cost == 'auto':
            self._cost_func = costObj([self._grad, self._prox])
        else:
            self._cost_func = cost
        # If linear is None, make it Identity for call of metrics
        if self._linear is None:
            self._linear = Identity()
        # Set the algorithm parameters
        for param in (beta_param, sigma_bar):
            self._check_param(param)
        if not (0 <= sigma_bar <= 1):
            raise ValueError('The sigma bar parameter needs to be in [0, 1]')
        self._beta = self.step_size or beta_param
        self._sigma_bar = sigma_bar
        self._xi = self._sigma = self._t_old = 1.0
        self._grad.get_grad(self._x_old)
        self._g_old = self._grad.grad

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate()
Example No. 2
def get_operators(
    kspace_data,
    loc,
    mask,
    fourier_type=1,
    max_iter=80,
    regularisation=None,
    linear=None,
):
    """Create the various operators from the config file."""
    n_coils = 1 if kspace_data.ndim == 2 else kspace_data.shape[0]
    shape = kspace_data.shape[-2:]
    if fourier_type == 0:  # offline reconstruction
        kspace_generator = KspaceGeneratorBase(full_kspace=kspace_data,
                                               mask=mask,
                                               max_iter=max_iter)
        fourier_op = FFT(shape=shape, n_coils=n_coils, mask=mask)
    elif fourier_type == 1:  # online type I reconstruction
        kspace_generator = Column2DKspaceGenerator(full_kspace=kspace_data,
                                                   mask_cols=loc)
        fourier_op = FFT(shape=shape, n_coils=n_coils, mask=mask)
    elif fourier_type == 2:  # online type II reconstruction
        kspace_generator = DataOnlyKspaceGenerator(full_kspace=kspace_data,
                                                   mask_cols=loc)
        fourier_op = ColumnFFT(shape=shape, n_coils=n_coils)
    else:
        raise NotImplementedError
    if linear is None:
        linear_op = Identity()
    else:
        lin_cls = linear.pop("class", None)
        if lin_cls == "WaveletN":
            linear_op = WaveletN(n_coils=n_coils, n_jobs=4, **linear)
            # Dry run so that the operator computes its coefficients shape
            linear_op.op(np.zeros_like(kspace_data))
        elif lin_cls == "Identity":
            linear_op = Identity()
        else:
            raise NotImplementedError

    prox_op = IdentityProx()
    if regularisation is not None:
        reg_cls = regularisation.pop("class")
        if reg_cls == "LASSO":
            prox_op = LASSO(weights=regularisation["weights"])
        if reg_cls == "GroupLASSO":
            prox_op = GroupLASSO(weights=regularisation["weights"])
        elif reg_cls == "OWL":
            prox_op = OWL(**regularisation,
                          n_coils=n_coils,
                          bands_shape=linear_op.coeffs_shape)
        elif reg_cls == "IdentityProx":
            prox_op = IdentityProx()
            linear_op = Identity()
    return kspace_generator, fourier_op, linear_op, prox_op
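
A minimal sketch of how `get_operators` might be called, following the `class`-key convention the function pops from the config dictionaries; the input arrays and parameter values here are hypothetical:

# Hypothetical call; kspace_data, loc and mask are assumed to be loaded.
# The "class" key selects the operator; remaining keys pass through as kwargs.
kspace_gen, fourier_op, linear_op, prox_op = get_operators(
    kspace_data,
    loc,
    mask,
    fourier_type=1,  # online type I reconstruction
    regularisation={"class": "LASSO", "weights": 1e-7},
    linear={"class": "WaveletN", "wavelet_name": "sym8", "nb_scale": 4},
)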
Example No. 3
 def set_objective(self, X, y, lmbd):
     self.X, self.y, self.lmbd = X, y, lmbd
     n_features = self.X.shape[1]
     if self.restart_strategy == 'greedy':
         min_beta = 1.0
         s_greedy = 1.1
         p_lazy = 1.0
         q_lazy = 1.0
     else:
         min_beta = None
         s_greedy = None
         p_lazy = 1 / 30
         q_lazy = 1 / 10
     self.fb = ForwardBackward(
         x=np.zeros(n_features),  # this is the coefficient w
         grad=GradBasic(
             input_data=y,
             op=lambda w: self.X@w,
             trans_op=lambda res: self.X.T@res,
         ),
         prox=SparseThreshold(Identity(), lmbd),
         beta_param=1.0,
         min_beta=min_beta,
         metric_call_period=None,
         restart_strategy=self.restart_strategy,
         xi_restart=0.96,
         s_greedy=s_greedy,
         p_lazy=p_lazy,
         q_lazy=q_lazy,
         auto_iterate=False,
         progress=False,
         cost=None,
     )
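
Since `auto_iterate=False`, constructing the solver runs nothing; iterations are triggered explicitly. A minimal sketch of the companion run step, assuming a benchopt-style solver class (the `run` method and `self.w` attribute are hypothetical):

 def run(self, n_iter):
     # With auto_iterate=False, the solver only runs on an explicit call.
     self.fb.iterate(max_iter=n_iter)
     self.w = self.fb.x_final  # recovered coefficients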
Example No. 4
 def get_linear_n_regularization_operator(wavelet_name,
                                          image_shape,
                                          dimension=2,
                                          nb_scale=3,
                                          n_coils=1,
                                          n_jobs=1,
                                          verbose=0):
     # A helper function to obtain linear and regularization operator
     try:
         linear_op = WaveletN(
             nb_scale=nb_scale,
             wavelet_name=wavelet_name,
             dim=dimension,
             n_coils=n_coils,
             n_jobs=n_jobs,
             verbose=verbose,
         )
     except ValueError:
         # TODO this is a hack and we need to have a separate WaveletUD2.
         # For Undecimated wavelets, the wavelet_name is wavelet_id
         linear_op = WaveletUD2(
             wavelet_id=wavelet_name,
             nb_scale=nb_scale,
             n_coils=n_coils,
             n_jobs=n_jobs,
             verbose=verbose,
         )
     # Dry run so that the operator computes its coefficients shape
     linear_op.op(np.squeeze(np.zeros((n_coils, *image_shape))))
     regularizer_op = WeightedSparseThreshold(
         linear=Identity(),
         weights=0,
         coeffs_shape=linear_op.coeffs_shape,
         thresh_type="soft")
     return linear_op, regularizer_op
Example No. 5
 def get_linear_n_regularization_operator(self,
                                          gradient_formulation,
                                          wavelet_name,
                                          dimension=2,
                                          nb_scale=3,
                                          n_coils=1,
                                          n_jobs=1,
                                          verbose=0):
     # A helper function to obtain linear and regularization operator
     try:
         linear_op = WaveletN(
             nb_scale=nb_scale,
             wavelet_name=wavelet_name,
             dim=dimension,
             n_coils=n_coils,
             n_jobs=n_jobs,
             verbose=verbose,
         )
     except ValueError:
         # TODO this is a hack and we need to have a separate WaveletUD2.
         # For Undecimated wavelets, the wavelet_name is wavelet_id
         linear_op = WaveletUD2(
             wavelet_id=wavelet_name,
             nb_scale=nb_scale,
             n_coils=n_coils,
             n_jobs=n_jobs,
             verbose=verbose,
         )
     if gradient_formulation == 'synthesis':
         regularizer_op = SparseThreshold(Identity(), 0, thresh_type="soft")
     elif gradient_formulation == "analysis":
         regularizer_op = SparseThreshold(linear_op, 0, thresh_type="soft")
     return linear_op, regularizer_op
Example No. 6
 def __init__(self,
              fourier_op,
              linear_op,
              regularizer_op,
              gradient_formulation,
              grad_class,
              init_gradient_op=True,
              verbose=0,
              **extra_grad_args):
     self.fourier_op = fourier_op
     self.linear_op = linear_op
     self.prox_op = regularizer_op
     self.gradient_method = gradient_formulation
     self.grad_class = grad_class
     self.verbose = verbose
     self.extra_grad_args = extra_grad_args
     if regularizer_op is None:
         warnings.warn("The prox_op is not set. Setting to identity. "
                       "Note that optimization is just a gradient descent.")
         self.prox_op = Identity()
     # TODO try to not use gradient_formulation and
     #  rely on static attributes
     # If the reconstruction formulation is synthesis,
     # we send the linear operator as well.
     if gradient_formulation == 'synthesis':
         self.extra_grad_args['linear_op'] = self.linear_op
     if init_gradient_op:
         self.initialize_gradient_op(**self.extra_grad_args)
Example No. 7
 def __init__(self,
              fourier_op,
              linear_op,
              regularizer_op=None,
              opt='condatvu',
              verbose=0):
     self.fourier_op = fourier_op
     self.linear_op = linear_op
     self.verbose = verbose
     if regularizer_op is None:
         warnings.warn("The prox_op is not set. Setting to identity. "
                       "Note that optimization is just a gradient descent.")
         self.prox_op = IdentityProx()
         self.linear_op = Identity()
     else:
         self.prox_op = regularizer_op
     assert opt in OPTIMIZERS.keys()
     self.opt = opt
     grad_formulation = ANALYSIS_OPT.get(opt, 'synthesis')
     if grad_formulation == 'analysis':
         self.gradient_op = OnlineGradAnalysis(self.fourier_op,
                                               verbose=self.verbose,
                                               num_check_lips=0,
                                               lipschitz_cst=1.1)
     elif grad_formulation == 'synthesis':
         self.gradient_op = OnlineGradSynthesis(self.linear_op,
                                                self.fourier_op,
                                                verbose=self.verbose,
                                                num_check_lips=0,
                                                lipschitz_cst=1.1)
     else:
         raise RuntimeError("Unknown gradient formulation")
     self.grad_formulation = grad_formulation
Example No. 8
def reco_wav(kspace,
             gradient_op,
             mu=1 * 1e-8,
             max_iter=10,
             nb_scales=4,
             wavelet_name='db4'):
    # For now this only works with my fork of pysap-fastMRI.
    # This will be changed soon so that we don't need to ask for a
    # specific pysap-mri install.
    from ..wavelets import WaveletDecimated
    from mri.numerics.reconstruct import sparse_rec_fista

    linear_op = WaveletDecimated(
        nb_scale=nb_scales,
        wavelet_name=wavelet_name,
        padding='periodization',
    )

    prox_op = LinearCompositionProx(
        linear_op=linear_op,
        prox_op=SparseThreshold(Identity(), None, thresh_type="soft"),
    )
    gradient_op.obs_data = kspace
    cost_op = None
    x_final, _, _, _ = sparse_rec_fista(
        gradient_op=gradient_op,
        linear_op=Identity(),
        prox_op=prox_op,
        cost_op=cost_op,
        xi_restart=0.96,
        s_greedy=1.1,
        mu=mu,
        restart_strategy='greedy',
        pov='analysis',
        max_nb_of_iter=max_iter,
        metrics=None,
        metric_call_period=1,
        verbose=0,
        progress=False,
    )
    x_final = np.abs(x_final)
    x_final = crop_center(x_final, 320)
    return x_final
Example No. 9
 def __init__(self,
              weights,
              coeffs_shape,
              weight_type='scale_based',
              zero_weight_coarse=True,
              linear=Identity(),
              **kwargs):
     self.cf_shape = coeffs_shape
     self.weight_type = weight_type
     available_weight_type = ('scale_based', 'custom')
     if self.weight_type not in available_weight_type:
         raise ValueError('Weight type must be one of ' +
                          ' '.join(available_weight_type))
     self.zero_weight_coarse = zero_weight_coarse
     self.mu = weights
     super(WeightedSparseThreshold, self).__init__(weights=self.mu,
                                                   linear=linear,
                                                   **kwargs)
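
A short usage sketch, assuming a wavelet linear operator has already been applied once so that its `coeffs_shape` attribute is populated (as in the dry-run calls of the earlier examples); the weight value is illustrative:

# Hypothetical usage: one threshold weight per wavelet scale, with the
# coarse scale left unthresholded (zero_weight_coarse=True).
prox_op = WeightedSparseThreshold(
    weights=1e-7,  # illustrative value
    coeffs_shape=linear_op.coeffs_shape,
    weight_type='scale_based',
    zero_weight_coarse=True,
    linear=Identity(),
)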
Example No. 10
 def set_objective(self, X, y, lmbd):
     self.X, self.y, self.lmbd = X, y, lmbd
     n_features = self.X.shape[1]
     sigma_bar = 0.96
     var_init = np.zeros(n_features)
     self.pogm = POGM(
         x=var_init,  # this is the coefficient w
         u=var_init,
         y=var_init,
         z=var_init,
         grad=GradBasic(
             op=lambda w: self.X @ w,
             trans_op=lambda res: self.X.T @ res,
             data=y,
         ),
         prox=SparseThreshold(Identity(), lmbd),
         beta_param=1.0,
         metric_call_period=None,
         sigma_bar=sigma_bar,
         auto_iterate=False,
         progress=False,
         cost=None,
     )
Example No. 11
class GenForwardBackward(SetUp):
    """Generalized Forward-Backward Algorithm.

    This class implements algorithm 1 from :cite:`raguet2011`

    Parameters
    ----------
    x : list, tuple or numpy.ndarray
        Initial guess for the primal variable
    grad : class instance
        Gradient operator class
    prox_list : list
        List of proximity operator class instances
    cost : class or str, optional
        Cost function class (default is 'auto'); Use 'auto' to automatically
        generate a costObj instance
    gamma_param : float, optional
        Initial value of the gamma parameter (default is ``1.0``)
    lambda_param : float, optional
        Initial value of the lambda parameter (default is ``1.0``)
    gamma_update : function, optional
        Gamma parameter update method (default is ``None``)
    lambda_update : function, optional
        Lambda parameter update method (default is ``None``)
    weights : list, tuple or numpy.ndarray, optional
        Proximity operator weights (default is ``None``)
    auto_iterate : bool, optional
        Option to automatically begin iterations upon initialisation (default
        is ``True``)

    Notes
    -----
    The `gamma_param` can also be set using the keyword `step_size`, which will
    override the value of `gamma_param`.

    See Also
    --------
    SetUp : parent class

    """
    def __init__(
        self,
        x,
        grad,
        prox_list,
        cost='auto',
        gamma_param=1.0,
        lambda_param=1.0,
        gamma_update=None,
        lambda_update=None,
        weights=None,
        auto_iterate=True,
        metric_call_period=5,
        metrics=None,
        linear=None,
        **kwargs,
    ):

        # Set default algorithm properties
        super().__init__(
            metric_call_period=metric_call_period,
            metrics=metrics,
            **kwargs,
        )

        # Set the initial variable values
        self._check_input_data(x)
        self._x_old = self.xp.copy(x)

        # Set the algorithm operators
        for operator in [grad, cost] + prox_list:
            self._check_operator(operator)

        self._grad = grad
        self._prox_list = self.xp.array(prox_list)
        self._linear = linear

        if cost == 'auto':
            self._cost_func = costObj([self._grad] + prox_list)
        else:
            self._cost_func = cost

        # Check if there is a linear op, needed for metrics in the FB algorithm
        if metrics and self._linear is None:
            raise ValueError(
                'When using metrics, you must pass a linear operator', )

        if self._linear is None:
            self._linear = Identity()

        # Set the algorithm parameters
        for param_val in (gamma_param, lambda_param):
            self._check_param(param_val)

        self._gamma = self.step_size or gamma_param
        self._lambda_param = lambda_param

        # Set the algorithm parameter update methods
        for param_update in (gamma_update, lambda_update):
            self._check_param_update(param_update)

        self._gamma_update = gamma_update
        self._lambda_update = lambda_update

        # Set the proximity weights
        self._set_weights(weights)

        # Set initial z
        self._z = self.xp.array(
            [self._x_old for i in range(self._prox_list.size)])

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate()

    def _set_weights(self, weights):
        """Set weights.

        This method sets weights on each of the proximity operators provided

        Parameters
        ----------
        weights : list, tuple or numpy.ndarray
            List of weights

        Raises
        ------
        TypeError
            For invalid input type
        ValueError
            If weights do not sum to one

        """
        if isinstance(weights, type(None)):
            weights = self.xp.repeat(
                1.0 / self._prox_list.size,
                self._prox_list.size,
            )
        elif not isinstance(weights, (list, tuple, np.ndarray)):
            raise TypeError('Weights must be provided as a list.')

        weights = self.xp.array(weights)

        if not np.issubdtype(weights.dtype, np.floating):
            raise ValueError('Weights must be list of float values.')

        if weights.size != self._prox_list.size:
            raise ValueError(
                'The number of weights must match the number of proximity ' +
                'operators.', )

        expected_weight_sum = 1.0

        if self.xp.sum(weights) != expected_weight_sum:
            raise ValueError(
                'Proximity operator weights must sum to 1.0. Current sum of ' +
                'weights = {0}'.format(self.xp.sum(weights)), )

        self._weights = weights

    def _update_param(self):
        """Update parameters.

        This method updates the values of the algorithm parameters with the
        methods provided

        """
        # Update the gamma parameter.
        if not isinstance(self._gamma_update, type(None)):
            self._gamma = self._gamma_update(self._gamma)

        # Update lambda parameter.
        if not isinstance(self._lambda_update, type(None)):
            self._lambda_param = self._lambda_update(self._lambda_param)

    def _update(self):
        """Update.

        This method updates the current reconstruction

        Notes
        -----
        Implements algorithm 1 from :cite:`raguet2011`

        """
        # Calculate gradient for current iteration.
        self._grad.get_grad(self._x_old)

        # Update z values.
        for i in range(self._prox_list.size):
            z_temp = (2 * self._x_old - self._z[i] -
                      self._gamma * self._grad.grad)
            z_prox = self._prox_list[i].op(
                z_temp,
                extra_factor=self._gamma / self._weights[i],
            )
            self._z[i] += self._lambda_param * (z_prox - self._x_old)

        # Update current reconstruction.
        self._x_new = self.xp.sum(
            [z_i * w_i for z_i, w_i in zip(self._z, self._weights)],
            axis=0,
        )

        # Update old values for next iteration.
        self.xp.copyto(self._x_old, self._x_new)

        # Update parameter values for next iteration.
        self._update_param()

        # Test cost function for convergence.
        if self._cost_func:
            self.converge = self._cost_func.get_cost(self._x_new)

    def iterate(self, max_iter=150):
        """Iterate.

        This method calls update until either convergence criteria is met or
        the maximum number of iterations is reached.

        Parameters
        ----------
        max_iter : int, optional
            Maximum number of iterations (default is ``150``)

        """
        self._run_alg(max_iter)

        # retrieve metrics results
        self.retrieve_outputs()

        self.x_final = self._x_new

    def get_notify_observers_kwargs(self):
        """Notify observers.

        Return the mapping between the metrics call and the iterated
        variables.

        Returns
        -------
        dict
           The mapping between the metrics and the iterated variables

        """
        return {
            'x_new': self._linear.adj_op(self._x_new),
            'z_new': self._z,
            'idx': self.idx,
        }

    def retrieve_outputs(self):
        """Retrieve outputs.

        Declare the outputs of the algorithms as attributes: x_final,
        y_final, metrics.

        """
        metrics = {}
        for obs in self._observers['cv_metrics']:
            metrics[obs.name] = obs.retrieve_metrics()
        self.metrics = metrics
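
A usage sketch for the class above with two hypothetical proximity operators; note that the weights must sum to exactly 1.0, and that a `linear` operator is mandatory whenever `metrics` is passed:

# Hypothetical instantiation; grad_op, prox_a, prox_b and n_features are
# assumed to exist. Equal weights average the two proximity updates.
gfb = GenForwardBackward(
    x=np.zeros(n_features),
    grad=grad_op,
    prox_list=[prox_a, prox_b],
    weights=[0.5, 0.5],  # must sum to 1.0
    gamma_param=1.0,
    auto_iterate=False,
)
gfb.iterate(max_iter=200)
x_hat = gfb.x_final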
Example No. 12
class ForwardBackward(SetUp):
    """Forward-Backward optimisation.

    This class implements standard forward-backward optimisation with the
    option to use the FISTA speed-up

    Parameters
    ----------
    x : numpy.ndarray
        Initial guess for the primal variable
    grad : class
        Gradient operator class
    prox : class
        Proximity operator class
    cost : class or str, optional
        Cost function class (default is 'auto'); Use 'auto' to automatically
        generate a costObj instance
    beta_param : float, optional
        Initial value of the beta parameter (default is ``1.0``)
    lambda_param : float, optional
        Initial value of the lambda parameter (default is ``1.0``)
    beta_update : function, optional
        Beta parameter update method (default is ``None``)
    lambda_update : function or str, optional
        Lambda parameter update method (default is 'fista')
    auto_iterate : bool, optional
        Option to automatically begin iterations upon initialisation (default
        is ``True``)

    Notes
    -----
    The `beta_param` can also be set using the keyword `step_size`, which will
    override the value of `beta_param`.

    See Also
    --------
    FISTA : complementary class
    SetUp : parent class

    """
    def __init__(
        self,
        x,
        grad,
        prox,
        cost='auto',
        beta_param=1.0,
        lambda_param=1.0,
        beta_update=None,
        lambda_update='fista',
        auto_iterate=True,
        metric_call_period=5,
        metrics=None,
        linear=None,
        **kwargs,
    ):

        # Set default algorithm properties
        super().__init__(
            metric_call_period=metric_call_period,
            metrics=metrics,
            **kwargs,
        )

        # Set the initial variable values
        self._check_input_data(x)
        self._x_old = self.copy_data(x)
        self._z_old = self.copy_data(x)

        # Set the algorithm operators
        for operator in (grad, prox, cost):
            self._check_operator(operator)

        self._grad = grad
        self._prox = prox
        self._linear = linear

        if cost == 'auto':
            self._cost_func = costObj([self._grad, self._prox])
        else:
            self._cost_func = cost

        # Check if there is a linear op, needed for metrics in the FB algorithm
        if metrics and self._linear is None:
            raise ValueError(
                'When using metrics, you must pass a linear operator', )

        if self._linear is None:
            self._linear = Identity()

        # Set the algorithm parameters
        for param_val in (beta_param, lambda_param):
            self._check_param(param_val)

        self._beta = self.step_size or beta_param
        self._lambda = lambda_param

        # Set the algorithm parameter update methods
        self._check_param_update(beta_update)
        self._beta_update = beta_update
        if isinstance(lambda_update, str) and lambda_update == 'fista':
            fista = FISTA(**kwargs)
            self._lambda_update = fista.update_lambda
            self._is_restart = fista.is_restart
            self._beta_update = fista.update_beta
        else:
            self._check_param_update(lambda_update)
            self._lambda_update = lambda_update
            self._is_restart = lambda *args, **kwargs: False

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate()

    def _update_param(self):
        """Update parameters.

        This method updates the values of the algorithm parameters with the
        methods provided

        """
        # Update the beta parameter.
        if not isinstance(self._beta_update, type(None)):
            self._beta = self._beta_update(self._beta)

        # Update lambda parameter.
        if not isinstance(self._lambda_update, type(None)):
            self._lambda = self._lambda_update(self._lambda)

    def _update(self):
        """Update.

        This method updates the current reconstruction

        Notes
        -----
        Implements algorithm 10.7 (or 10.5) from :cite:`bauschke2009`

        """
        # Step 1 from alg.10.7.
        self._grad.get_grad(self._z_old)
        y_old = self._z_old - self._beta * self._grad.grad

        # Step 2 from alg.10.7.
        self._x_new = self._prox.op(y_old, extra_factor=self._beta)

        # Step 5 from alg.10.7.
        self._z_new = self._x_old + self._lambda * (self._x_new - self._x_old)

        # Restarting step from alg.4-5 in [L2018]
        if self._is_restart(self._z_old, self._x_new, self._x_old):
            self._z_new = self._x_new

        # Update old values for next iteration.
        self._x_old = self.xp.copy(self._x_new)
        self._z_old = self.xp.copy(self._z_new)

        # Update parameter values for next iteration.
        self._update_param()

        # Test cost function for convergence.
        if self._cost_func:
            self.converge = (self.any_convergence_flag()
                             or self._cost_func.get_cost(self._x_new))

    def iterate(self, max_iter=150):
        """Iterate.

        This method calls update until either convergence criteria is met or
        the maximum number of iterations is reached

        Parameters
        ----------
        max_iter : int, optional
            Maximum number of iterations (default is ``150``)

        """
        self._run_alg(max_iter)

        # retrieve metrics results
        self.retrieve_outputs()
        # rename outputs as attributes
        self.x_final = self._z_new

    def get_notify_observers_kwargs(self):
        """Notify observers.

        Return the mapping between the metrics call and the iterated
        variables.

        Returns
        -------
        dict
           The mapping between the metrics and the iterated variables

        """
        return {
            'x_new': self._linear.adj_op(self._x_new),
            'z_new': self._z_new,
            'idx': self.idx,
        }

    def retrieve_outputs(self):
        """Retireve outputs.

        Declare the outputs of the algorithms as attributes: x_final,
        y_final, metrics.

        """
        metrics = {}
        for obs in self._observers['cv_metrics']:
            metrics[obs.name] = obs.retrieve_metrics()
        self.metrics = metrics
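
As the Notes section above says, `beta_param` can be overridden with the `step_size` keyword, which travels through `**kwargs` to the parent class. A brief sketch with hypothetical operators:

# Hypothetical: self._beta = self.step_size or beta_param, so step_size
# takes precedence when both are given.
fb = ForwardBackward(
    x=x0,
    grad=grad_op,
    prox=prox_op,
    beta_param=1.0,
    step_size=0.5,  # overrides beta_param
    auto_iterate=False,
)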
Example No. 13
    def __init__(
        self,
        x,
        grad,
        prox,
        cost='auto',
        beta_param=1.0,
        lambda_param=1.0,
        beta_update=None,
        lambda_update='fista',
        auto_iterate=True,
        metric_call_period=5,
        metrics=None,
        linear=None,
        **kwargs,
    ):

        # Set default algorithm properties
        super().__init__(
            metric_call_period=metric_call_period,
            metrics=metrics,
            **kwargs,
        )

        # Set the initial variable values
        self._check_input_data(x)
        self._x_old = self.copy_data(x)
        self._z_old = self.copy_data(x)

        # Set the algorithm operators
        for operator in (grad, prox, cost):
            self._check_operator(operator)

        self._grad = grad
        self._prox = prox
        self._linear = linear

        if cost == 'auto':
            self._cost_func = costObj([self._grad, self._prox])
        else:
            self._cost_func = cost

        # Check if there is a linear op, needed for metrics in the FB algorithm
        if metrics and self._linear is None:
            raise ValueError(
                'When using metrics, you must pass a linear operator', )

        if self._linear is None:
            self._linear = Identity()

        # Set the algorithm parameters
        for param_val in (beta_param, lambda_param):
            self._check_param(param_val)

        self._beta = self.step_size or beta_param
        self._lambda = lambda_param

        # Set the algorithm parameter update methods
        self._check_param_update(beta_update)
        self._beta_update = beta_update
        if isinstance(lambda_update, str) and lambda_update == 'fista':
            fista = FISTA(**kwargs)
            self._lambda_update = fista.update_lambda
            self._is_restart = fista.is_restart
            self._beta_update = fista.update_beta
        else:
            self._check_param_update(lambda_update)
            self._lambda_update = lambda_update
            self._is_restart = lambda *args, **kwargs: False

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate()
Example No. 14
def polychromatic_psf_field_est_2(im_stack_in, spectrums, wvl, D, opt_shift_est,
                                  nb_comp, field_pos=None, nb_iter=4,
                                  nb_subiter=100, mu=0.3, tol=0.1, sig_supp=3,
                                  sig=None, shifts=None, flux=None,
                                  nsig_shift_est=4, pos_en=True,
                                  simplex_en=False, wvl_en=True, wvl_opt=None,
                                  nsig=3, graph_cons_en=False):
    """ Main LambdaRCA function.
    
    Calls:
    
    * :func:`utils.get_noise_arr`
    * :func:`utils.diagonally_dominated_mat_stack` 
    * :func:`psf_learning_utils.full_displacement` 
    * :func:`utils.im_gauss_nois_est_cube` 
    * :func:`utils.thresholding_3D` 
    * :func:`utils.shift_est` 
    * :func:`utils.shift_ker_stack` 
    * :func:`utils.flux_estimate_stack` 
    * :func:`optim_utils.analysis` 
    * :func:`utils.cube_svd`
    * :func:`grads.polychrom_eigen_psf`
    * :func:`grads.polychrom_eigen_psf_coeff_graph`
    * :func:`grads.polychrom_eigen_psf_coeff`
    * :func:`psf_learning_utils.field_reconstruction`
    * :func:`operators.transport_plan_lin_comb_wavelet`
    * :func:`operators.transport_plan_marg_wavelet`
    * :func:`operators.transport_plan_lin_comb`
    * :func:`operators.transport_plan_lin_comb_coeff`
    * :func:`proxs.simplex_threshold`
    * :func:`proxs.Simplex`
    * :func:`proxs.KThreshold`
    """

    im_stack = copy(im_stack_in)
    if wvl_en:
        from utils import get_noise_arr

    print "--------------- Transport architecture setting ------------------"
    nb_im = im_stack.shape[-1]
    shap_obs = im_stack.shape
    shap = (shap_obs[0]*D,shap_obs[1]*D)
    P_stack = utils.diagonally_dominated_mat_stack(shap,nb_comp,sig=sig_supp,thresh_en=True)
    i,j = where(P_stack[:,:,0]>0)
    supp = transpose(array([i,j]))
    t = (wvl-wvl.min()).astype(float)/(wvl.max()-wvl.min())

    neighbors_graph,weights_neighbors,cent,coord_map,knn = psf_learning_utils.full_displacement(shap,supp,t,\
    pol_en=True,cent=None,theta_param=1,pol_mod=True,coord_map=None,knn=None)

    print "------------------- Forward operator parameters estimation ------------------------"
    centroids = None
    if sig is None:
        sig,filters = utils.im_gauss_nois_est_cube(copy(im_stack),opt=opt_shift_est)

    if shifts is None:
        # Renamed from 'map' to avoid shadowing the builtin
        thresh_map = ones(im_stack.shape)
        for i in range(0, shap_obs[2]):
            thresh_map[:, :, i] *= nsig_shift_est * sig[i]
        print('Shifts estimation...')
        psf_stack_shift = utils.thresholding_3D(copy(im_stack), thresh_map, 0)
        shifts, centroids = utils.shift_est(psf_stack_shift)
        print('Done...')
    else:
        print("---------- /!\\ Warning: shifts provided /!\\ ---------")
    ker,ker_rot = utils.shift_ker_stack(shifts,D)
    sig /= sig.min()
    for k in range(0, shap_obs[2]):
        im_stack[:, :, k] = im_stack[:, :, k] / sig[k]
    print(" ------ ref energy: ", (im_stack ** 2).sum(), " ------- ")
    if flux is None:
        flux = utils.flux_estimate_stack(copy(im_stack),rad=4)

    if graph_cons_en:
        print("-------------------- Spatial constraint setting -----------------------")
        e_opt, p_opt, weights, comp_temp, data, basis, alph = analysis(
            im_stack, 0.1 * prod(shap_obs) * sig.min() ** 2, field_pos,
            nb_max=nb_comp)

    print "------------- Coeff init ------------"
    A,comp,cube_est = utils.cube_svd(im_stack,nb_comp=nb_comp)

    i = 0
    print(" --------- Optimization instances setting ---------- ")

    # Data fidelity related instances
    polychrom_grad = grad.polychrom_eigen_psf(im_stack, supp, neighbors_graph, \
                weights_neighbors, spectrums, A, flux, sig, ker, ker_rot, D)

    if graph_cons_en:
        polychrom_grad_coeff = grad.polychrom_eigen_psf_coeff_graph(im_stack, supp, neighbors_graph, \
                weights_neighbors, spectrums, P_stack, flux, sig, ker, ker_rot, D, basis)
    else:
        polychrom_grad_coeff = grad.polychrom_eigen_psf_coeff(im_stack, supp, neighbors_graph, \
                weights_neighbors, spectrums, P_stack, flux, sig, ker, ker_rot, D)


    # Dual variable related linear operators instances
    dual_var_coeff = zeros((supp.shape[0],nb_im))
    if wvl_en and pos_en:
        lin_com = lambdaops.transport_plan_lin_comb_wavelet(A,supp,weights_neighbors,neighbors_graph,shap,wavelet_opt=wvl_opt)
    else:
        if wvl_en:
            lin_com = lambdaops.transport_plan_marg_wavelet(supp,weights_neighbors,neighbors_graph,shap,wavelet_opt=wvl_opt)
        else:
            lin_com = lambdaops.transport_plan_lin_comb(A, supp,shap)

    if not graph_cons_en:
        lin_com_coeff = lambdaops.transport_plan_lin_comb_coeff(P_stack, supp)

    # Proximity operators related instances
    id_prox = Identity()
    if wvl_en and pos_en:
        noise_map = get_noise_arr(lin_com.op(polychrom_grad.MtX(im_stack))[1])
        dual_var_plan = np.array([zeros((supp.shape[0],nb_im)),zeros(noise_map.shape)])
        dual_prox_plan = lambdaprox.simplex_threshold(lin_com, nsig*noise_map,pos_en=(not simplex_en))
    else:
        if wvl_en:
            # Noise estimation
            noise_map = get_noise_arr(lin_com.op(polychrom_grad.MtX(im_stack)))
            dual_var_plan = zeros(noise_map.shape)
            dual_prox_plan = prox.SparseThreshold(lin_com, nsig*noise_map)
        else:
            dual_var_plan = zeros((supp.shape[0],nb_im))
            if simplex_en:
                dual_prox_plan = lambdaprox.Simplex()
            else:
                dual_prox_plan = prox.Positivity()

    if graph_cons_en:
        iter_func = lambda x: floor(sqrt(x))
        prox_coeff = lambdaprox.KThreshold(iter_func)
    else:
        if simplex_en:
            dual_prox_coeff = lambdaprox.Simplex()
        else:
            dual_prox_coeff = prox.Positivity()

    # ---- (Re)Setting hyperparameters
    delta  = (polychrom_grad.inv_spec_rad**(-1)/2)**2 + 4*lin_com.mat_norm**2
    w = 0.9
    sigma_P = w*(np.sqrt(delta)-polychrom_grad.inv_spec_rad**(-1)/2)/(2*lin_com.mat_norm**2)
    tau_P = sigma_P
    rho_P = 1

    # Cost function instance
    cost_op = costObj([polychrom_grad])

    condat_min = optimalg.Condat(P_stack, dual_var_plan, polychrom_grad, id_prox, dual_prox_plan, lin_com, cost=cost_op,\
                 rho=rho_P,  sigma=sigma_P, tau=tau_P, rho_update=None, sigma_update=None,
                 tau_update=None, auto_iterate=False)
    print "------------------- Transport plans estimation ------------------"

    condat_min.iterate(max_iter=nb_subiter) # ! actually runs optimisation
    P_stack = condat_min.x_final
    dual_var_plan = condat_min.y_final

    obs_est = polychrom_grad.MX(P_stack)
    res = im_stack - obs_est

    for i in range(0,nb_iter):
        print "----------------Iter ",i+1,"/",nb_iter,"-------------------"

        # Parameters update
        polychrom_grad_coeff.set_P(P_stack)
        if not graph_cons_en:
            lin_com_coeff.set_P_stack(P_stack)
            # ---- (Re)Setting hyperparameters
            delta  = (polychrom_grad_coeff.inv_spec_rad**(-1)/2)**2 + 4*lin_com_coeff.mat_norm**2
            w = 0.9
            sigma_coeff = w*(np.sqrt(delta)-polychrom_grad_coeff.inv_spec_rad**(-1)/2)/(2*lin_com_coeff.mat_norm**2)
            tau_coeff = sigma_coeff
            rho_coeff = 1

        # Coefficients cost function instance
        cost_op_coeff = costObj([polychrom_grad_coeff])

        if graph_cons_en:
            beta_param = polychrom_grad_coeff.inv_spec_rad# set stepsize to inverse spectral radius of coefficient gradient
            min_coeff = optimalg.ForwardBackward(alph, polychrom_grad_coeff, prox_coeff, beta_param=beta_param, 
                                                 cost=cost_op_coeff,auto_iterate=False)
        else:
            min_coeff = optimalg.Condat(A, dual_var_coeff, polychrom_grad_coeff, id_prox, dual_prox_coeff, lin_com_coeff, cost=cost_op_coeff,\
                                            rho=rho_coeff,  sigma=sigma_coeff, tau=tau_coeff, rho_update=None, sigma_update=None,\
                                            tau_update=None, auto_iterate=False)

        print "------------------- Coefficients estimation ----------------------"
        min_coeff.iterate(max_iter=nb_subiter) # ! actually runs optimisation
        if graph_cons_en:
            prox_coeff.reset_iter()
            alph = min_coeff.x_final
            A = alph.dot(basis)
        else:
            A = min_coeff.x_final
            dual_var_coeff = min_coeff.y_final

        # Parameters update
        polychrom_grad.set_A(A)
        if not wvl_en:
            lin_com.set_A(A)
        if wvl_en:
            # Noise estimate update
            noise_map = get_noise_arr(lin_com.op(polychrom_grad.MtX(im_stack))[1])
            dual_prox_plan.update_weights(noise_map)

        # ---- (Re)Setting hyperparameters
        delta  = (polychrom_grad.inv_spec_rad**(-1)/2)**2 + 4*lin_com.mat_norm**2
        w = 0.9
        sigma_P = w*(np.sqrt(delta)-polychrom_grad.inv_spec_rad**(-1)/2)/(2*lin_com.mat_norm**2)
        tau_P = sigma_P
        rho_P = 1

        # Cost function instance
        condat_min = optimalg.Condat(P_stack, dual_var_plan, polychrom_grad, id_prox, dual_prox_plan, lin_com, cost=cost_op,\
                     rho=rho_P,  sigma=sigma_P, tau=tau_P, rho_update=None, sigma_update=None,
                     tau_update=None, auto_iterate=False)
        print "------------------- Transport plans estimation ------------------"

        condat_min.iterate(max_iter=nb_subiter) # ! actually runs optimisation
        P_stack = condat_min.x_final
        dual_var_plan = condat_min.y_final

        # Normalization
        for j in range(0,nb_comp):
            l1_P = sum(abs(P_stack[:,:,j]))
            P_stack[:,:,j]/= l1_P
            A[j,:] *= l1_P
            if graph_cons_en:
                alph[j,:] *= l1_P
        polychrom_grad.set_A(A)
        # Flux update
        obs_est = polychrom_grad.MX(P_stack)
        err_ref = 0.5*sum((obs_est-im_stack)**2)
        flux_new = (obs_est*im_stack).sum(axis=(0,1))/(obs_est**2).sum(axis=(0,1))
        print "Flux correction: ",flux_new
        polychrom_grad.set_flux(polychrom_grad.get_flux()*flux_new)
        polychrom_grad_coeff.set_flux(polychrom_grad_coeff.get_flux()*flux_new)

        obs_est = polychrom_grad.MX(P_stack)
        res = im_stack - obs_est
        err_rec = 0.5*sum(res**2)
        print "err_ref : ",err_ref," ; err_rec : ", err_rec
        # Computing residual


    psf_est = psf_learning_utils.field_reconstruction(P_stack,shap,supp,neighbors_graph,weights_neighbors,A)

    return psf_est,P_stack,A,res
Example No. 15
class POGM(SetUp):
    """Proximal Optimised Gradient Method.

    This class implements algorithm 3 from :cite:`kim2017`

    Parameters
    ----------
    u : numpy.ndarray
        Initial guess for the u variable
    x : numpy.ndarray
        Initial guess for the x variable (primal)
    y : numpy.ndarray
        Initial guess for the y variable
    z : numpy.ndarray
        Initial guess for the z variable
    grad : class
        Gradient operator class
    prox : class
        Proximity operator class
    cost : class or str, optional
        Cost function class (default is 'auto'); Use 'auto' to automatically
        generate a costObj instance
    linear : class instance, optional
        Linear operator class (default is ``None``)
    beta_param : float, optional
        Initial value of the beta parameter (default is ``1.0``).
        This corresponds to (1 / L) in :cite:`kim2017`
    sigma_bar : float, optional
        Value of the shrinking parameter sigma bar (default is ``1.0``)
    auto_iterate : bool, optional
        Option to automatically begin iterations upon initialisation (default
        is ``True``)

    Notes
    -----
    The `beta_param` can also be set using the keyword `step_size`, which will
    override the value of `beta_param`.

    See Also
    --------
    SetUp : parent class

    """
    def __init__(
        self,
        u,
        x,
        y,
        z,
        grad,
        prox,
        cost='auto',
        linear=None,
        beta_param=1.0,
        sigma_bar=1.0,
        auto_iterate=True,
        metric_call_period=5,
        metrics=None,
        **kwargs,
    ):

        # Set default algorithm properties
        super().__init__(
            metric_call_period=metric_call_period,
            metrics=metrics,
            linear=linear,
            **kwargs,
        )

        # Set the initial variable values
        for input_data in (u, x, y, z):
            self._check_input_data(input_data)

        self._u_old = self.xp.copy(u)
        self._x_old = self.xp.copy(x)
        self._y_old = self.xp.copy(y)
        self._z = self.xp.copy(z)

        # Set the algorithm operators
        for operator in (grad, prox, cost):
            self._check_operator(operator)

        self._grad = grad
        self._prox = prox
        self._linear = linear
        if cost == 'auto':
            self._cost_func = costObj([self._grad, self._prox])
        else:
            self._cost_func = cost

        # If linear is None, make it Identity for call of metrics
        if self._linear is None:
            self._linear = Identity()

        # Set the algorithm parameters
        for param_val in (beta_param, sigma_bar):
            self._check_param(param_val)
        if sigma_bar < 0 or sigma_bar > 1:
            raise ValueError('The sigma bar parameter needs to be in [0, 1]')

        self._beta = self.step_size or beta_param
        self._sigma_bar = sigma_bar
        self._xi = 1.0
        self._sigma = 1.0
        self._t_old = 1.0
        self._grad.get_grad(self._x_old)
        self._g_old = self._grad.grad

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate()

    def _update(self):
        """Update.

        This method updates the current reconstruction

        Notes
        -----
        Implements algorithm 3 from :cite:`kim2017`

        """
        # Step 4 from alg. 3
        self._grad.get_grad(self._x_old)
        self._u_new = self._x_old - self._beta * self._grad.grad

        # Step 5 from alg. 3
        self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old**2))

        # Step 6 from alg. 3
        t_shifted_ratio = (self._t_old - 1) / self._t_new
        sigma_t_ratio = self._sigma * self._t_old / self._t_new
        beta_xi_t_shifted_ratio = t_shifted_ratio * self._beta / self._xi
        self._z = -beta_xi_t_shifted_ratio * (self._x_old - self._z)
        self._z += self._u_new
        self._z += t_shifted_ratio * (self._u_new - self._u_old)
        self._z += sigma_t_ratio * (self._u_new - self._x_old)

        # Step 7 from alg. 3
        self._xi = self._beta * (1 + t_shifted_ratio + sigma_t_ratio)

        # Step 8 from alg. 3
        self._x_new = self._prox.op(self._z, extra_factor=self._xi)

        # Restarting and gamma-Decreasing
        # Step 9 from alg. 3
        self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi

        # Step 10 from alg 3.
        self._y_new = self._x_old - self._beta * self._g_new

        # Step 11 from alg. 3
        restart_crit = (self.xp.vdot(-self._g_new, self._y_new - self._y_old) <
                        0)
        if restart_crit:
            self._t_new = 1
            self._sigma = 1

        # Step 13 from alg. 3
        elif self.xp.vdot(self._g_new, self._g_old) < 0:
            self._sigma *= self._sigma_bar

        # updating variables
        self._t_old = self._t_new
        self.xp.copyto(self._u_old, self._u_new)
        self.xp.copyto(self._x_old, self._x_new)
        self.xp.copyto(self._g_old, self._g_new)
        self.xp.copyto(self._y_old, self._y_new)

        # Test cost function for convergence.
        if self._cost_func:
            self.converge = (self.any_convergence_flag()
                             or self._cost_func.get_cost(self._x_new))

    def iterate(self, max_iter=150):
        """Iterate.

        This method calls update until either convergence criteria is met or
        the maximum number of iterations is reached.

        Parameters
        ----------
        max_iter : int, optional
            Maximum number of iterations (default is ``150``)

        """
        self._run_alg(max_iter)

        # retrieve metrics results
        self.retrieve_outputs()
        # rename outputs as attributes
        self.x_final = self._x_new

    def get_notify_observers_kwargs(self):
        """Notify observers.

        Return the mapping between the metrics call and the iterated
        variables.

        Returns
        -------
        dict
           The mapping between the metrics and the iterated variables

        """
        return {
            'u_new': self._u_new,
            'x_new': self._linear.adj_op(self._x_new),
            'y_new': self._y_new,
            'z_new': self._z,
            'xi': self._xi,
            'sigma': self._sigma,
            't': self._t_new,
            'idx': self.idx,
        }

    def retrieve_outputs(self):
        """Retrieve outputs.

        Declare the outputs of the algorithms as attributes: x_final,
        y_final, metrics.

        """
        metrics = {}
        for obs in self._observers['cv_metrics']:
            metrics[obs.name] = obs.retrieve_metrics()
        self.metrics = metrics
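
The docstring above notes that `beta_param` corresponds to 1/L in :cite:`kim2017`. A sketch of setting the step from a known Lipschitz constant; the `spec_rad` attribute on the gradient operator is an assumption:

# Hypothetical: if the gradient operator exposes its spectral radius L,
# the natural POGM step size is 1 / L.
pogm = POGM(
    u=x0, x=x0, y=x0, z=x0,  # x0 is an assumed initial guess
    grad=grad_op,
    prox=prox_op,
    beta_param=1.0 / grad_op.spec_rad,  # assumed attribute
    sigma_bar=0.96,
    auto_iterate=False,
)
pogm.iterate(max_iter=100)
x_hat = pogm.x_final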
Example No. 16
def launch_grid(kspace_data,
                reconstructor_class,
                reconstructor_kwargs,
                fourier_op=None,
                linear_params=None,
                regularizer_params=None,
                optimizer_params=None,
                compare_metric_details=None,
                n_jobs=1,
                verbose=0):
    """This function launches off reconstruction for a grid specified
    through use of kwarg dictionaries.

    Dictionary Convention
    ---------------------
    These dictionaries are each defined to follow this convention:
    Each dictionary has a key `init_class` that specifies the
    initialization class for the operator (the exception to
    this is `optimizer_params`). A second key, `kwargs`, holds
    all the input arguments as a keyword dictionary.
    Each value in this keyword dictionary must be a list of all
    the values you want to search over in the grid search.

    This function builds the parameter search space and
    sets up the right parameters for the `_reconstruct_case` function.
    Please check the example code for more details.

    Parameters
    ----------
    kspace_data: np.ndarray
        the kspace data for reconstruction
    reconstructor_class: class
        reconstructor class
    reconstructor_kwargs: dict
        extra kwargs for reconstructor
    fourier_op: object of class FFT
        this defines the fourier operator. for NonCartesianFFT, please make
        fourier_op as `None` and pass fourier_params to allow
        parallel execution
    linear_params: dict, default None
        dictionary for linear operator parameters
        if None, a sym8 wavelet is chosen
    regularizer_params: dict, default None
        dictionary for regularizer operator parameters
        if None, mu=0, ie no regularization is done
    optimizer_params: dict, default None
        dictionary for optimizer key word arguments
        if None, a FISTA optimization is done for 100 iterations
    compare_metric_details: dict, default None
        dictionary that holds the metric to be compared and the metric
        direction; please refer to the `gather_result` documentation.
        if None, all raw_results are returned and best_idx is None
    n_jobs: int, default 1
        number of parallel jobs for execution
    verbose: int, default 0
        Verbosity level
        0 => No debug prints
        1 => View best results if present
    """
    # Convert non-list elements to list so that we can create
    # search space
    init_classes = []
    key_names = []
    if linear_params is None:
        linear_params = {
            'init_class': WaveletN,
            'kwargs': {
                'wavelet_name': 'sym8',
                'nb_scale': 4,
            }
        }
    if regularizer_params is None:
        regularizer_params = {
            'init_class': SparseThreshold,
            'kwargs': {
                'linear': Identity(),
                'weights': [0],
            }
        }
    if optimizer_params is None:
        optimizer_params = {
            # Just following convention
            'kwargs': {
                'optimization_alg': 'fista',
                'num_iterations': 100,
            }
        }
    for specific_params in [
            linear_params, regularizer_params, optimizer_params
    ]:
        for key, value in specific_params['kwargs'].items():
            if not isinstance(value, (list, tuple, np.ndarray)):
                specific_params['kwargs'][key] = [value]
        # Obtain Initialization classes
        if specific_params != optimizer_params:
            init_classes.append(specific_params['init_class'])
        # Obtain Key Names
        key_names.append(list(specific_params['kwargs'].keys()))
    # Create Search space
    cross_product_list = list(
        itertools.product(
            *linear_params['kwargs'].values(),
            *regularizer_params['kwargs'].values(),
            *optimizer_params['kwargs'].values(),
        ))
    test_cases = []
    number_of_test_cases = len(cross_product_list)
    if verbose > 0:
        print('Total number of gridsearch cases : ' +
              str(number_of_test_cases))
    # Reshape data such that they match values for key_names
    for test_case in cross_product_list:
        iterator = iter(test_case)
        # Add the test case after reshaping the list
        all_kwargs_values = []
        for individual_param_names in key_names:
            param_kwargs = {}
            for key in individual_param_names:
                param_kwargs[key] = next(iterator)
            all_kwargs_values.append(param_kwargs)
        test_cases.append(
            _TestCase(kspace_data, *init_classes, *all_kwargs_values))
    if isinstance(fourier_op, NonCartesianFFT):
        fourier_params = {
            'init_class': NonCartesianFFT,
            'kwargs': {
                'samples': fourier_op.samples,
                'shape': fourier_op.shape,
            }
        }
        fourier_op = None
    else:
        fourier_params = None
    # Call for reconstruction
    results = Parallel(n_jobs=n_jobs)(delayed(test_case.reconstruct_case)(
        fourier_op=fourier_op,
        reconstructor_class=reconstructor_class,
        reconstructor_kwargs=reconstructor_kwargs,
        fourier_params=fourier_params,
    ) for test_case in test_cases)
    best_idx = None
    if compare_metric_details is not None:
        best_value, best_idx = \
            gather_result(
                **compare_metric_details,
                results=results,
            )
        if verbose > 0:
            print('The best result of grid search is: ' +
                  str(cross_product_list[best_idx]))
            print('The best value of metric is : ' + str(best_value))
    return results, cross_product_list, key_names, best_idx
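
A sketch of the kwarg-dictionary convention described in the docstring, where every value to grid over is wrapped in a list; all names and values here are illustrative:

# Hypothetical grid: 2 wavelets x 3 regularisation weights = 6 cases.
linear_params = {
    'init_class': WaveletN,
    'kwargs': {'wavelet_name': ['sym8', 'db4'], 'nb_scale': [4]},
}
regularizer_params = {
    'init_class': SparseThreshold,
    'kwargs': {'linear': [Identity()], 'weights': [1e-8, 1e-7, 1e-6]},
}
results, cases, key_names, best_idx = launch_grid(
    kspace_data,
    reconstructor_class,
    reconstructor_kwargs,
    fourier_op=fourier_op,
    linear_params=linear_params,
    regularizer_params=regularizer_params,
    n_jobs=2,
)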
Example No. 17
class ForwardBackward(SetUp):
    r"""Forward-Backward optimisation

    This class implements standard forward-backward optimisation with an the
    option to use the FISTA speed-up

    Parameters
    ----------
    x : np.ndarray
        Initial guess for the primal variable
    grad : class
        Gradient operator class
    prox : class
        Proximity operator class
    cost : class or str, optional
        Cost function class (default is 'auto'); Use 'auto' to automatically
        generate a costObj instance
    beta_param : float, optional
        Initial value of the beta parameter (default is 1.0)
    lambda_param : float, optional
        Initial value of the lambda parameter (default is 1.0)
    beta_update : function, optional
        Beta parameter update method (default is None)
    lambda_update : function or string, optional
        Lambda parameter update method (default is 'fista')
    auto_iterate : bool, optional
        Option to automatically begin iterations upon initialisation (default
        is 'True')

    """
    def __init__(self,
                 x,
                 grad,
                 prox,
                 cost='auto',
                 beta_param=1.0,
                 lambda_param=1.0,
                 beta_update=None,
                 lambda_update='fista',
                 auto_iterate=True,
                 metric_call_period=5,
                 metrics={},
                 linear=None):

        # Set default algorithm properties
        super(ForwardBackward,
              self).__init__(metric_call_period=metric_call_period,
                             metrics=metrics,
                             linear=linear)

        # Set the initial variable values
        self._check_input_data(x)
        self._x_old = np.copy(x)
        self._z_old = np.copy(x)

        # Set the algorithm operators
        for operator in (grad, prox, cost):
            self._check_operator(operator)
        self._grad = grad
        self._prox = prox
        self._linear = linear

        if cost == 'auto':
            self._cost_func = costObj([self._grad, self._prox])
        else:
            self._cost_func = cost

        # Check if there is a linear op, needed for metrics in the FB algorithm
        if metrics != {} and self._linear is None:
            raise ValueError('When using metrics, you must pass a linear '
                             'operator')

        if self._linear is None:
            self._linear = Identity()

        # Set the algorithm parameters
        for param in (beta_param, lambda_param):
            self._check_param(param)
        self._beta = beta_param
        self._lambda = lambda_param

        # Set the algorithm parameter update methods
        if isinstance(lambda_update, str) and lambda_update == 'fista':
            self._lambda_update = FISTA().update_lambda
        else:
            self._check_param_update(lambda_update)
            self._lambda_update = lambda_update
        self._check_param_update(beta_update)
        self._beta_update = beta_update

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate()

    def _update_param(self):
        r"""Update parameters

        This method updates the values of the algorithm parameters with the
        methods provided

        """

        # Update the beta parameter.
        if not isinstance(self._beta_update, type(None)):
            self._beta = self._beta_update(self._beta)

        # Update lambda parameter.
        if not isinstance(self._lambda_update, type(None)):
            self._lambda = self._lambda_update(self._lambda)

    def _update(self):
        r"""Update

        This method updates the current reconstruction

        Notes
        -----
        Implements algorithm 10.7 (or 10.5) from [B2011]_

        """

        # Step 1 from alg.10.7.
        self._grad.get_grad(self._z_old)
        y_old = self._z_old - self._beta * self._grad.grad

        # Step 2 from alg.10.7.
        self._x_new = self._prox.op(y_old, extra_factor=self._beta)

        # Step 5 from alg.10.7.
        self._z_new = self._x_old + self._lambda * (self._x_new - self._x_old)

        # Update old values for next iteration.
        np.copyto(self._x_old, self._x_new)
        np.copyto(self._z_old, self._z_new)

        # Update parameter values for next iteration.
        self._update_param()

        # Test cost function for convergence.
        if self._cost_func:
            self.converge = self.any_convergence_flag() or \
                            self._cost_func.get_cost(self._x_new)

    def iterate(self, max_iter=150):
        r"""Iterate

        This method calls update until either the convergence criterion is met
        or the maximum number of iterations is reached

        Parameters
        ----------
        max_iter : int, optional
            Maximum number of iterations (default is ``150``)

        """

        self._run_alg(max_iter)

        # retrieve metrics results
        self.retrieve_outputs()
        # rename outputs as attributes
        self.x_final = self._z_new

    def get_notify_observers_kwargs(self):
        """ Return the mapping between the metrics call and the iterated
        variables.

        Returns
        -------
        notify_observers_kwargs: dict
           the mapping between the metrics call and the iterated variables.
        """
        return {
            'x_new': self._linear.adj_op(self._x_new),
            'z_new': self._z_new,
            'idx': self.idx
        }

    def retrieve_outputs(self):
        """ Declare the outputs of the algorithms as attributes: x_final,
        y_final, metrics.
        """

        metrics = {}
        for obs in self._observers['cv_metrics']:
            metrics[obs.name] = obs.retrieve_metrics()
        self.metrics = metrics
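A minimal usage sketch for the class above (a toy l1-denoising problem; the
GradBasic and SparseThreshold signatures below are assumptions based on
ModOpt's gradient and proximity modules, not taken from this example):

import numpy as np
from modopt.opt.algorithms import ForwardBackward
from modopt.opt.gradient import GradBasic
from modopt.opt.linear import Identity
from modopt.opt.proximity import SparseThreshold

# Toy denoising: y = x_true + noise, with an identity forward model.
rng = np.random.default_rng(0)
x_true = np.zeros(64)
x_true[::8] = 1.0
y = x_true + 0.05 * rng.standard_normal(64)

grad_op = GradBasic(y, lambda x: x, lambda x: x)  # data, op, adjoint op
prox_op = SparseThreshold(Identity(), 0.1, thresh_type='soft')

fb = ForwardBackward(
    x=np.zeros(64),
    grad=grad_op,
    prox=prox_op,
    beta_param=1.0,         # step size; the identity model has unit Lipschitz constant
    lambda_update='fista',  # FISTA acceleration, as documented above
    auto_iterate=False,     # call iterate() explicitly below
)
fb.iterate(max_iter=50)
x_hat = fb.x_final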
base_ssim = ssim(image_rec0, image)
print('The Base SSIM is : ' + str(base_ssim))

#############################################################################
# FISTA optimization
# ------------------
#
# We now want to refine the zero order solution using a FISTA optimization.
# The cost function is set to Proximity Cost + Gradient Cost

# Setup the operators
linear_op = WaveletN(
    wavelet_name='sym8',
    nb_scale=4,
)
regularizer_op = SparseThreshold(Identity(), 1.5e-8, thresh_type="soft")
# Setup Reconstructor
reconstructor = SelfCalibrationReconstructor(
    fourier_op=fourier_op,
    linear_op=linear_op,
    regularizer_op=regularizer_op,
    gradient_formulation='synthesis',
    kspace_portion=0.01,
    verbose=1,
)
x_final, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_obs,
    optimization_alg='fista',
    num_iterations=10,
)
image_rec = pysap.Image(data=x_final)
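The refined solution can be scored with the same SSIM helper used for the
zero-order reconstruction above (a sketch; `image` is the reference loaded
earlier in the tutorial):

recon_ssim = ssim(image_rec, image)
print('The Reconstruction SSIM is : ' + str(recon_ssim))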
Example #19
    def __init__(
        self,
        x,
        grad,
        prox_list,
        cost='auto',
        gamma_param=1.0,
        lambda_param=1.0,
        gamma_update=None,
        lambda_update=None,
        weights=None,
        auto_iterate=True,
        metric_call_period=5,
        metrics=None,
        linear=None,
        **kwargs,
    ):

        # Set default algorithm properties
        super().__init__(
            metric_call_period=metric_call_period,
            metrics=metrics,
            **kwargs,
        )

        # Set the initial variable values
        self._check_input_data(x)
        self._x_old = self.xp.copy(x)

        # Set the algorithm operators
        for operator in [grad, cost] + prox_list:
            self._check_operator(operator)

        self._grad = grad
        self._prox_list = self.xp.array(prox_list)
        self._linear = linear

        if cost == 'auto':
            self._cost_func = costObj([self._grad] + prox_list)
        else:
            self._cost_func = cost

        # Check if there is a linear op, needed for metrics in the FB algorithm
        if metrics and self._linear is None:
            raise ValueError(
                'When using metrics, you must pass a linear operator', )

        if self._linear is None:
            self._linear = Identity()

        # Set the algorithm parameters
        for param_val in (gamma_param, lambda_param):
            self._check_param(param_val)

        self._gamma = self.step_size or gamma_param
        self._lambda_param = lambda_param

        # Set the algorithm parameter update methods
        for param_update in (gamma_update, lambda_update):
            self._check_param_update(param_update)

        self._gamma_update = gamma_update
        self._lambda_update = lambda_update

        # Set the proximity weights
        self._set_weights(weights)

        # Set initial z
        self._z = self.xp.array(
            [self._x_old for i in range(self._prox_list.size)])

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate()
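A sketch of how the prox_list and weights arguments of this generalised
forward-backward variant might be assembled (operator names from ModOpt's
proximity module; the equal weighting is an illustrative assumption, not a
library default):

from modopt.opt.linear import Identity
from modopt.opt.proximity import Positivity, SparseThreshold

prox_list = [
    Positivity(),                                           # positivity constraint
    SparseThreshold(Identity(), 1e-6, thresh_type='soft'),  # l1 penalty
]
weights = [0.5, 0.5]  # one weight per proximity operator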
Example #20
def sparse_rec_condatvu(gradient_op,
                        linear_op,
                        std_est=None,
                        std_est_method=None,
                        std_thr=2.,
                        mu=1e-6,
                        tau=None,
                        sigma=None,
                        relaxation_factor=1.0,
                        nb_of_reweights=1,
                        max_nb_of_iter=150,
                        add_positivity=False,
                        atol=1e-4,
                        verbose=0):
    """ The Condat-Vu sparse reconstruction with reweightings.

    .. note:: At the moment, supports only 2D data.

    Parameters
    ----------
    gradient_op: instance of class GradBase
        the gradient operator.
    linear_op: instance of LinearBase
        the linear operator: seek the sparsity, ie. a wavelet transform.
    std_est: float, default None
        the noise std estimate.
        If None use the MAD as a consistent estimator for the std.
    std_est_method: str, default None
        if the standard deviation is not set, estimate this parameter using
        the mad routine in the image ('primal') or in the sparse wavelet
        decomposition ('dual') domain.
    std_thr: float, default 2.
        use this threshold expressed as a number of sigma in the residual
        proximity operator during the thresholding.
    mu: float, default 1e-6
        regularization hyperparameter.
    tau, sigma: float, default None
        parameters of the Condat-Vu proximal-dual splitting algorithm.
        If None estimates these parameters.
    relaxation_factor: float, default 1.0
        parameter of the Condat-Vu proximal-dual splitting algorithm.
        If 1, no relaxation.
    nb_of_reweights: int, default 1
        the number of reweightings.
    max_nb_of_iter: int, default 150
        the maximum number of iterations in the Condat-Vu proximal-dual
        splitting algorithm.
    add_positivity: bool, default False
        by setting this option, set the proximity operator to identity or
        positive.
    atol: float, default 1e-4
        tolerance threshold for convergence.
    verbose: int, default 0
        the verbosity level.

    Returns
    -------
    x_final: ndarray
        the estimated CONDAT-VU solution.
    transform: a WaveletTransformBase derived instance
        the wavelet transformation instance.
    """
    # Check inputs
    # analysis = True
    # if hasattr(gradient_op, 'linear_op'):
    #     analysis = False

    start = time.perf_counter()

    if std_est_method not in (None, "primal", "dual"):
        raise ValueError(
            "Unrecognize std estimation method '{0}'.".format(std_est_method))

    # Define the initial primal and dual solutions
    x_init = np.zeros(gradient_op.fourier_op.shape, dtype=np.complex128)
    weights = linear_op.op(x_init)

    # Define the weights used during the thresholding in the dual domain,
    # the reweighting strategy, and the prox dual operator

    # Case1: estimate the noise std in the image domain
    if std_est_method == "primal":
        if std_est is None:
            std_est = sigma_mad(gradient_op.MtX(gradient_op.obs_data))
        weights[...] = std_thr * std_est
        reweight_op = cwbReweight(weights)
        prox_dual_op = Threshold(reweight_op.weights)

    # Case2: estimate the noise std in the sparse wavelet domain
    elif std_est_method == "dual":
        if std_est is None:
            std_est = 0.0
        weights[...] = std_thr * std_est
        reweight_op = mReweight(weights, linear_op, thresh_factor=std_thr)
        prox_dual_op = Threshold(reweight_op.weights)

    # Case3: manual regularization mode, no reweighting
    else:
        weights[...] = mu
        reweight_op = None
        prox_dual_op = Threshold(weights)
        nb_of_reweights = 0

    # Define the Condat Vu optimizer: define the tau and sigma in the
    # Condat-Vu proximal-dual splitting algorithm if not already provided.
    # Check also that the combination of values will lead to convergence.
    norm = linear_op.l2norm(gradient_op.fourier_op.shape)
    lipschitz_cst = gradient_op.spec_rad
    if sigma is None:
        sigma = 0.5
    if tau is None:
        # to avoid numerics troubles with the convergence bound
        eps = 1.0e-8
        # due to the convergence bound
        tau = 1.0 / (lipschitz_cst / 2 + sigma * norm**2 + eps)
    convergence_test = (1.0 / tau - sigma * norm**2 >= lipschitz_cst / 2.0)

    # Define initial primal and dual solutions
    primal = np.zeros(gradient_op.fourier_op.shape, dtype=np.complex128)
    dual = linear_op.op(primal)
    dual[...] = 0.0

    # Welcome message
    if verbose > 0:
        print(condatvu_logo())
        print(" - mu: ", mu)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - tau: ", tau)
        print(" - sigma: ", sigma)
        print(" - rho: ", relaxation_factor)
        print(" - std: ", std_est)
        print(" - 1/tau - sigma||L||^2 >= beta/2: ", convergence_test)
        print(" - data: ", gradient_op.obs_data.shape)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - number of reweights: ", nb_of_reweights)
        print(" - primal variable shape: ", primal.shape)
        print(" - dual variable shape: ", dual.shape)
        print("-" * 40)

    # Define the proximity operator
    if add_positivity:
        prox_op = Positivity()
    else:
        prox_op = Identity()

    # Define the cost function
    cost_op = DualGapCost(linear_op=linear_op,
                          initial_cost=1e6,
                          tolerance=1e-4,
                          cost_interval=1,
                          test_range=4,
                          verbose=0,
                          plot_output=None)

    # Define the optimizer
    opt = Condat(x=primal,
                 y=dual,
                 grad=gradient_op,
                 prox=prox_op,
                 prox_dual=prox_dual_op,
                 linear=linear_op,
                 cost=cost_op,
                 rho=relaxation_factor,
                 sigma=sigma,
                 tau=tau,
                 rho_update=None,
                 sigma_update=None,
                 tau_update=None,
                 auto_iterate=False)

    # Perform the first reconstruction
    if verbose > 0:
        print("Starting optimization...")

    for i in range(max_nb_of_iter):
        opt._update()

    opt.x_final = opt._x_new
    opt.y_final = opt._y_new

    # Loop through the number of reweightings
    for reweight_index in range(nb_of_reweights):

        # Generate the new weights following reweighting prescription
        if std_est_method == "primal":
            reweight_op.reweight(linear_op.op(opt._x_new))
        else:
            std_est = reweight_op.reweight(opt._x_new)

        # Welcome message
        if verbose > 0:
            print(" - reweight: ", reweight_index + 1)
            print(" - std: ", std_est)

        # Update the weights in the dual proximity operator
        prox_dual_op.weights = reweight_op.weights

        # Perform optimisation with new weights
        opt.iterate(max_iter=max_nb_of_iter)

    # Goodbye message
    end = time.perf_counter()
    if verbose > 0:
        print(" - final iteration number: ", cost_op._iteration)
        print(" - final cost value: ", cost_op.cost)
        print(" - converged: ", opt.converge)
        print("Done.")
        print("Execution time: ", end - start, " seconds")
        print("-" * 40)

    # Get the final solution
    x_final = opt.x_final
    linear_op.transform.analysis_data = unflatten(opt.y_final,
                                                  linear_op.coeffs_shape)

    return x_final, linear_op.transform
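The tau/sigma defaults above enforce the Condat-Vu convergence condition
1/tau - sigma * ||L||^2 >= beta/2, where beta is the Lipschitz constant of
the gradient and ||L|| the norm of the linear operator. A small helper
isolating that heuristic (a sketch of the logic already inlined in the
function, not a library API):

def condatvu_step_sizes(lipschitz_cst, linear_norm, sigma=0.5, eps=1.0e-8):
    """Return (tau, sigma) satisfying 1/tau - sigma * ||L||^2 >= beta/2."""
    tau = 1.0 / (lipschitz_cst / 2 + sigma * linear_norm ** 2 + eps)
    # By construction 1/tau - sigma * ||L||^2 = beta/2 + eps > beta/2.
    assert 1.0 / tau - sigma * linear_norm ** 2 >= lipschitz_cst / 2.0
    return tau, sigma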
Example #21
 def test_gridsearch_single_channel(self):
     """Test Gridsearch script in mri.scripts for
     single channel reconstruction this is a test of sanity
     and not if the reconstruction is right.
     """
     image = get_sample_data('2d-mri')
     mask = np.ones(image.shape)
     kspace_loc = convert_mask_to_locations(mask)
     fourier_op = NonCartesianFFT(samples=kspace_loc, shape=image.shape)
     kspace_data = fourier_op.op(image.data)
     # Define the keyword dictionaries based on convention
     metrics = {
         'ssim': {
             'metric': ssim,
             'mapping': {
                 'x_new': 'test',
                 'y_new': None
             },
             'cst_kwargs': {
                 'ref': image,
                 'mask': None
             },
             'early_stopping': True,
         },
     }
     linear_params = {
         'init_class': WaveletN,
         'kwargs': {
             'wavelet_name': 'sym8',
             'nb_scale': 4,
         }
     }
     regularizer_params = {
         'init_class': SparseThreshold,
         'kwargs': {
             'linear': Identity(),
             'weights': [0, 1e-5],
         }
     }
     optimizer_params = {
         # Just following convention
         'kwargs': {
             'optimization_alg': 'fista',
             'num_iterations': 10,
             'metrics': metrics,
         }
     }
     # Call the launch grid function and obtain results
     raw_results, test_cases, key_names, best_idx = launch_grid(
         kspace_data=kspace_data,
         fourier_op=fourier_op,
         linear_params=linear_params,
         regularizer_params=regularizer_params,
         optimizer_params=optimizer_params,
         reconstructor_kwargs={'gradient_formulation': 'synthesis'},
         reconstructor_class=SingleChannelReconstructor,
         compare_metric_details={'metric': 'ssim'},
         n_jobs=self.n_jobs,
         verbose=1,
     )
     # In this test we don't undersample the k-space, so the best
     # reconstruction is the one with mu=0, i.e. best_idx=0
     np.testing.assert_equal(best_idx, 0)
     np.testing.assert_allclose(
         raw_results[best_idx][0],
         image,
         atol=1e-7,
     )
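Additional metrics follow the same dictionary convention; for instance, a
PSNR entry could sit next to the SSIM one (a sketch, assuming `psnr` is
importable from the same metrics module as `ssim`):

metrics['psnr'] = {
    'metric': psnr,
    'mapping': {'x_new': 'test', 'y_new': None},
    'cst_kwargs': {'ref': image, 'mask': None},
    'early_stopping': False,
}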
Example #22
    def __init__(self,
                 x,
                 grad,
                 prox,
                 cost='auto',
                 beta_param=1.0,
                 lambda_param=1.0,
                 beta_update=None,
                 lambda_update='fista',
                 auto_iterate=True,
                 metric_call_period=5,
                 metrics={},
                 linear=None):

        # Set default algorithm properties
        super(ForwardBackward,
              self).__init__(metric_call_period=metric_call_period,
                             metrics=metrics,
                             linear=linear)

        # Set the initial variable values
        self._check_input_data(x)
        self._x_old = np.copy(x)
        self._z_old = np.copy(x)

        # Set the algorithm operators
        for operator in (grad, prox, cost):
            self._check_operator(operator)
        self._grad = grad
        self._prox = prox
        self._linear = linear

        if cost == 'auto':
            self._cost_func = costObj([self._grad, self._prox])
        else:
            self._cost_func = cost

        # Check if there is a linear op, needed for metrics in the FB algorithm
        if metrics != {} and self._linear is None:
            raise ValueError('When using metrics, you must pass a linear '
                             'operator')

        if self._linear is None:
            self._linear = Identity()

        # Set the algorithm parameters
        for param in (beta_param, lambda_param):
            self._check_param(param)
        self._beta = beta_param
        self._lambda = lambda_param

        # Set the algorithm parameter update methods
        if isinstance(lambda_update, str) and lambda_update == 'fista':
            self._lambda_update = FISTA().update_lambda
        else:
            self._check_param_update(lambda_update)
            self._lambda_update = lambda_update
        self._check_param_update(beta_update)
        self._beta_update = beta_update

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate()
Example #23
def sparse_rec_condatvu(gradient_op, linear_op, prox_dual_op, cost_op,
                        std_est=None, std_est_method=None, std_thr=2.,
                        mu=1e-6, tau=None, sigma=None, relaxation_factor=1.0,
                        nb_of_reweights=1, max_nb_of_iter=150,
                        add_positivity=False, metric_call_period=5,
                        metrics=None, verbose=False, progress=True):
    """
    The Condat-Vu sparse reconstruction with reweightings.

    Parameters
    ----------
    gradient_op: instance of class GradBase
        the gradient operator.
    linear_op: instance of LinearBase
        the linear operator: seek the sparsity, ie. a wavelet transform.
    prox_dual_op: instance of ProximityParent
        the proximal dual operator.
    cost_op: instance of costObj
        the cost function used to check for convergence during the
        optimization.
    std_est: float, default None
        the noise std estimate.
        If None use the MAD as a consistent estimator for the std.
    std_est_method: str, default None
        if the standard deviation is not set, estimate this parameter using
        the mad routine in the sparse wavelet decomposition ('dual') domain.
    std_thr: float, default 2.
        use this threshold expressed as a number of sigma in the residual
        proximity operator during the thresholding.
    mu: float, default 1e-6
        regularization hyperparameter.
    tau, sigma: float, default None
        parameters of the Condat-Vu proximal-dual splitting algorithm.
        If None estimates these parameters.
    relaxation_factor: float, default 1.0
        parameter of the Condat-Vu proximal-dual splitting algorithm.
        If 1, no relaxation.
    nb_of_reweights: int, default 1
        the number of reweightings.
    max_nb_of_iter: int, default 150
        the maximum number of iterations in the Condat-Vu proximal-dual
        splitting algorithm.
    add_positivity: bool, default False
        by setting this option, set the proximity operator to identity or
        positive.
    metric_call_period: int (default 5)
        the period at which the metrics are computed.
    metrics: dict (optional, default None)
        the list of desired convergence metrics: {'metric_name':
        [@metric, metric_parameter]}. See modopt for the metrics API.
    verbose: bool, default False
        the verbosity level.
    progress: bool, optional
        Activation key for progression bar displaying

    Returns
    -------
    x_final: np.ndarray((m,n)) or np.ndarray((m,n,p))
        the estimated CONDAT-VU solution.
    transform_output: a WaveletTransformBase derived instance or an array
        the wavelet transformation instance or the transformation coefficients.
    costs: list of float
        the cost function values.
    metrics: dict
        the requested metrics values during the optimization.
    """
    # Check inputs
    start = time.perf_counter()
    if std_est_method not in (None, "dual"):
        raise ValueError(
            "Unrecognized std estimation method '{}'.".format(std_est_method))

    # Define the initial primal and dual solutions
    x_init = np.zeros(gradient_op.fourier_op.shape, dtype=np.complex128)
    weights = linear_op.op(x_init)

    # Define the weights used during the thresholding in the dual domain,
    # the reweighting strategy, and the prox dual operator

    # case1: estimate the noise std in the sparse wavelet domain
    if std_est_method == "dual":
        if std_est is None:
            std_est = 0.0
        weights[...] = std_thr * std_est
        reweight_op = mReweight(weights, linear_op, thresh_factor=std_thr)
        prox_dual_op.weights = reweight_op.weights

    # Case2: manual regularization mode, no reweighting
    else:
        weights[...] = mu
        reweight_op = None
        prox_dual_op.weights = weights
        nb_of_reweights = 0

    # Define the Condat Vu optimizer: define the tau and sigma in the
    # Condat-Vu proximal-dual splitting algorithm if not already provided.
    # Check also that the combination of values will lead to convergence.
    norm = linear_op.l2norm(gradient_op.fourier_op.shape)
    lipschitz_cst = gradient_op.spec_rad
    if sigma is None:
        sigma = 0.5
    if tau is None:
        # to avoid numerics troubles with the convergence bound
        eps = 1.0e-8
        # due to the convergence bound
        tau = 1.0 / (lipschitz_cst / 2 + sigma * norm ** 2 + eps)
    convergence_test = (1.0 / tau - sigma * norm ** 2 >= lipschitz_cst / 2.0)

    # Define initial primal and dual solutions
    primal = np.zeros(gradient_op.fourier_op.shape, dtype=np.complex128)
    dual = linear_op.op(primal)
    dual[...] = 0.0

    # Welcome message
    if verbose:
        print(" - mu: ", mu)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - tau: ", tau)
        print(" - sigma: ", sigma)
        print(" - rho: ", relaxation_factor)
        print(" - std: ", std_est)
        print(" - 1/tau - sigma||L||^2 >= beta/2: ", convergence_test)
        print(" - data: ", gradient_op.fourier_op.shape)
        if hasattr(linear_op, "nb_scale"):
            print(" - wavelet: ", linear_op, "-", linear_op.nb_scale)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - number of reweights: ", nb_of_reweights)
        print(" - primal variable shape: ", primal.shape)
        print(" - dual variable shape: ", dual.shape)
        print("-" * 40)

    # Define the proximity operator
    if add_positivity:
        prox_op = Positivity()
    else:
        prox_op = Identity()

    # Define the optimizer
    opt = Condat(
        x=primal,
        y=dual,
        grad=gradient_op,
        prox=prox_op,
        prox_dual=prox_dual_op,
        linear=linear_op,
        cost=cost_op,
        rho=relaxation_factor,
        sigma=sigma,
        tau=tau,
        rho_update=None,
        sigma_update=None,
        tau_update=None,
        auto_iterate=False,
        metric_call_period=metric_call_period,
        metrics=metrics or {},
        progress=progress)
    cost_op = opt._cost_func

    # Perform the first reconstruction
    if verbose:
        print("Starting optimization...")
    opt.iterate(max_iter=max_nb_of_iter)

    # Loop through the number of reweightings
    for reweight_index in range(nb_of_reweights):

        # Generate the new weights following reweighting prescription
        std_est = reweight_op.reweight(opt._x_new)

        # Welcome message
        if verbose:
            print(" - reweight: ", reweight_index + 1)
            print(" - std: ", std_est)

        # Update the weights in the dual proximity operator
        prox_dual_op.weights = reweight_op.weights

        # Perform optimisation with new weights
        opt.iterate(max_iter=max_nb_of_iter)

    # Goodbye message
    end = time.perf_counter()
    if verbose:
        if hasattr(cost_op, "cost"):
            print(" - final iteration number: ", cost_op._iteration)
            print(" - final cost value: ", cost_op.cost)
        print(" - converged: ", opt.converge)
        print("Done.")
        print("Execution time: ", end - start, " seconds")
        print("-" * 40)

    # Get the final solution
    x_final = opt.x_final
    if hasattr(linear_op, "transform"):
        linear_op.transform.analysis_data = unflatten(
            opt.y_final, linear_op.coeffs_shape)
        transform_output = linear_op.transform
    else:
        linear_op.coeff = opt.y_final
        transform_output = linear_op.coeff
    if hasattr(cost_op, "cost"):
        costs = cost_op._cost_list
    else:
        costs = None

    return x_final, transform_output, costs, opt.metrics
Example #24
    def __init__(self,
                 x,
                 y,
                 grad,
                 prox,
                 prox_dual,
                 linear=None,
                 cost='auto',
                 reweight=None,
                 rho=0.5,
                 sigma=1.0,
                 tau=1.0,
                 rho_update=None,
                 sigma_update=None,
                 tau_update=None,
                 auto_iterate=True,
                 metric_call_period=5,
                 metrics={}):

        # Set default algorithm properties
        super(Condat, self).__init__(
            metric_call_period=metric_call_period,
            metrics=metrics,
        )

        # Set the initial variable values
        for input_data in (x, y):
            self._check_input_data(input_data)
        self._x_old = np.copy(x)
        self._y_old = np.copy(y)

        # Set the algorithm operators
        for operator in (grad, prox, prox_dual, linear, cost):
            self._check_operator(operator)
        self._grad = grad
        self._prox = prox
        self._prox_dual = prox_dual
        self._reweight = reweight
        if isinstance(linear, type(None)):
            self._linear = Identity()
        else:
            self._linear = linear
        if cost == 'auto':
            self._cost_func = costObj(
                [self._grad, self._prox, self._prox_dual])
        else:
            self._cost_func = cost

        # Set the algorithm parameters
        for param in (rho, sigma, tau):
            self._check_param(param)
        self._rho = rho
        self._sigma = sigma
        self._tau = tau

        # Set the algorithm parameter update methods
        for param_update in (rho_update, sigma_update, tau_update):
            self._check_param_update(param_update)
        self._rho_update = rho_update
        self._sigma_update = sigma_update
        self._tau_update = tau_update

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate()
Example #25
def condatvu_online(
    kspace_generator,
    gradient_op,
    linear_op,
    prox_op,
    cost_op,
    max_nb_of_iter=150,
    tau=None,
    sigma=None,
    relaxation_factor=1.0,
    x_init=None,
    std_est=None,
    nb_run=1,
    metric_call_period=5,
    metrics=None,
    estimate_call_period=None,
    verbose=0,
):
    """ The Condat-Vu sparse reconstruction with reweightings.

    Parameters
    ----------
    kspace_generator: instance of class KspaceGenerator
        the observed data (i.e. k-space) generated at each iteration of the algorithm.
    gradient_op: instance of class GradBase
        the gradient operator.
    linear_op: instance of LinearBase
        the linear operator: seek the sparsity, ie. a wavelet transform.
    prox_op: instance of ProximityParent
        the dual regularization operator.
    cost_op: instance of costObj
        the cost function used to check for convergence during the
        optimization.
    max_nb_of_iter: int, default 150
        the maximum number of iterations in the Condat-Vu proximal-dual
        splitting algorithm.
    tau, sigma: float, default None
        parameters of the Condat-Vu proximal-dual splitting algorithm.
        If None estimates these parameters.
    relaxation_factor: float, default 1.0
        parameter of the Condat-Vu proximal-dual splitting algorithm.
        If 1, no relaxation.
    x_init: np.ndarray (optional, default None)
        the initial guess of image
    std_est: float, default None
        the noise std estimate.
        If None use the MAD as a consistent estimator for the std.
    metric_call_period: int (default 5)
        the period at which the metrics are computed.
    metrics: dict (optional, default None)
        the list of desired convergence metrics: {'metric_name':
        [@metric, metric_parameter]}. See modopt for the metrics API.
    verbose: int, default 0
        the verbosity level.

    Returns
    -------
    x_final: ndarray
        the estimated CONDAT-VU solution.
    costs: list of float
        the cost function values.
    metrics: dict
        the requested metrics values during the optimization.
    y_final: ndarray
        the estimated dual CONDAT-VU solution.
    """

    # Check inputs
    if metrics is None:
        metrics = dict()

    # Define the initial primal and dual solutions
    if x_init is None:
        x_init = np.squeeze(
            np.zeros((gradient_op.fourier_op.n_coils,
                      *gradient_op.fourier_op.shape),
                     dtype=np.complex128))
    primal = x_init
    dual = linear_op.op(primal)

    # Define the Condat Vu optimizer: define the tau and sigma in the
    # Condat-Vu proximal-dual splitting algorithm if not already provided.
    # Check also that the combination of values will lead to convergence.

    norm = linear_op.l2norm(x_init.shape)
    lipschitz_cst = gradient_op.spec_rad
    if sigma is None:
        sigma = 0.5
    if tau is None:
        # to avoid numerics troubles with the convergence bound
        eps = 1.0e-8
        # due to the convergence bound
        tau = 1.0 / (lipschitz_cst / 2 + sigma * norm**2 + eps)
    convergence_test = (1.0 / tau - sigma * norm**2 >= lipschitz_cst / 2.0)

    # Welcome message
    if verbose > 0:
        print(" - mu: ", prox_op.weights)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - tau: ", tau)
        print(" - sigma: ", sigma)
        print(" - rho: ", relaxation_factor)
        print(" - std: ", std_est)
        print(" - 1/tau - sigma||L||^2 >= beta/2: ", convergence_test)
        print(" - data: ", gradient_op.fourier_op.shape)
        if hasattr(linear_op, "nb_scale"):
            print(" - wavelet: ", linear_op, "-", linear_op.nb_scale)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - primal variable shape: ", primal.shape)
        print(" - dual variable shape: ", dual.shape)
        print("-" * 40)

    prox_primal = Identity()

    # Define the optimizer
    opt = Condat(
        x=primal,
        y=dual,
        grad=gradient_op,
        prox=prox_primal,
        prox_dual=prox_op,
        linear=linear_op,
        cost=cost_op,
        rho=relaxation_factor,
        # sigma=sigma,
        # tau=tau,
        rho_update=None,
        sigma_update=None,
        tau_update=None,
        auto_iterate=False,
        metric_call_period=metric_call_period,
        metrics=metrics)

    return online_algorithm(opt,
                            kspace_generator,
                            estimate_call_period=estimate_call_period,
                            nb_run=nb_run)
Example #26
class Condat(SetUp):
    """Condat optimisation.

    This class implements algorithm 3.1 from :cite:`condat2013`

    Parameters
    ----------
    x : numpy.ndarray
        Initial guess for the primal variable
    y : numpy.ndarray
        Initial guess for the dual variable
    grad : class instance
        Gradient operator class
    prox : class instance
        Proximity primal operator class
    prox_dual : class instance
        Proximity dual operator class
    linear : class instance, optional
        Linear operator class (default is ``None``)
    cost : class or str, optional
        Cost function class (default is 'auto'); Use 'auto' to automatically
        generate a costObj instance
    reweight : class instance, optional
        Reweighting class
    rho : float, optional
        Relaxation parameter (default is ``0.5``)
    sigma : float, optional
        Proximal dual parameter (default is ``1.0``)
    tau : float, optional
        Proximal primal parameter (default is ``1.0``)
    rho_update : function, optional
        Relaxation parameter update method (default is ``None``)
    sigma_update : function, optional
        Proximal dual parameter update method (default is ``None``)
    tau_update : function, optional
        Proximal primal parameter update method (default is ``None``)
    auto_iterate : bool, optional
        Option to automatically begin iterations upon initialisation (default
        is ``True``)
    max_iter : int, optional
        Maximum number of iterations (default is ``150``)
    n_rewightings : int, optional
        Number of reweightings to perform (default is ``1``)

    Notes
    -----
    The parameter `tau` can also be set using the keyword `step_size`, which
    will override the value of `tau`.

    See Also
    --------
    SetUp : parent class

    """
    def __init__(
        self,
        x,
        y,
        grad,
        prox,
        prox_dual,
        linear=None,
        cost='auto',
        reweight=None,
        rho=0.5,
        sigma=1.0,
        tau=1.0,
        rho_update=None,
        sigma_update=None,
        tau_update=None,
        auto_iterate=True,
        max_iter=150,
        n_rewightings=1,
        metric_call_period=5,
        metrics=None,
        **kwargs,
    ):

        # Set default algorithm properties
        super().__init__(
            metric_call_period=metric_call_period,
            metrics=metrics,
            **kwargs,
        )

        # Set the initial variable values
        for input_data in (x, y):
            self._check_input_data(input_data)

        self._x_old = self.xp.copy(x)
        self._y_old = self.xp.copy(y)

        # Set the algorithm operators
        for operator in (grad, prox, prox_dual, linear, cost):
            self._check_operator(operator)

        self._grad = grad
        self._prox = prox
        self._prox_dual = prox_dual
        self._reweight = reweight
        if isinstance(linear, type(None)):
            self._linear = Identity()
        else:
            self._linear = linear
        if cost == 'auto':
            self._cost_func = costObj([
                self._grad,
                self._prox,
                self._prox_dual,
            ])
        else:
            self._cost_func = cost

        # Set the algorithm parameters
        for param_val in (rho, sigma, tau):
            self._check_param(param_val)

        self._rho = rho
        self._sigma = sigma
        self._tau = self.step_size or tau

        # Set the algorithm parameter update methods
        for param_update in (rho_update, sigma_update, tau_update):
            self._check_param_update(param_update)

        self._rho_update = rho_update
        self._sigma_update = sigma_update
        self._tau_update = tau_update

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate(max_iter=max_iter, n_rewightings=n_rewightings)

    def _update_param(self):
        """Update parameters.

        This method updates the values of the algorithm parameters with the
        methods provided

        """
        # Update relaxation parameter.
        if not isinstance(self._rho_update, type(None)):
            self._rho = self._rho_update(self._rho)

        # Update proximal dual parameter.
        if not isinstance(self._sigma_update, type(None)):
            self._sigma = self._sigma_update(self._sigma)

        # Update proximal primal parameter.
        if not isinstance(self._tau_update, type(None)):
            self._tau = self._tau_update(self._tau)

    def _update(self):
        """Update.

        This method updates the current reconstruction

        Notes
        -----
        Implements equation 9 (algorithm 3.1) from :cite:`condat2013`

        - primal proximity operator set up for positivity constraint

        """
        # Step 1 from eq.9.
        self._grad.get_grad(self._x_old)

        x_prox = self._prox.op(
            self._x_old - self._tau * self._grad.grad -
            self._tau * self._linear.adj_op(self._y_old), )

        # Step 2 from eq.9.
        y_temp = (self._y_old +
                  self._sigma * self._linear.op(2 * x_prox - self._x_old))

        y_prox = (y_temp - self._sigma * self._prox_dual.op(
            y_temp / self._sigma,
            extra_factor=(1.0 / self._sigma),
        ))

        # Step 3 from eq.9.
        self._x_new = self._rho * x_prox + (1 - self._rho) * self._x_old
        self._y_new = self._rho * y_prox + (1 - self._rho) * self._y_old

        del x_prox, y_prox, y_temp

        # Update old values for next iteration.
        self.xp.copyto(self._x_old, self._x_new)
        self.xp.copyto(self._y_old, self._y_new)

        # Update parameter values for next iteration.
        self._update_param()

        # Test cost function for convergence.
        if self._cost_func:
            self.converge = (self.any_convergence_flag()
                             or self._cost_func.get_cost(
                                 self._x_new, self._y_new))

    def iterate(self, max_iter=150, n_rewightings=1):
        """Iterate.

        This method calls update until either the convergence criterion is met
        or the maximum number of iterations is reached

        Parameters
        ----------
        max_iter : int, optional
            Maximum number of iterations (default is ``150``)
        n_rewightings : int, optional
            Number of reweightings to perform (default is ``1``)

        """
        self._run_alg(max_iter)

        if not isinstance(self._reweight, type(None)):
            for _ in range(n_rewightings):
                self._reweight.reweight(self._linear.op(self._x_new))
                self._run_alg(max_iter)

        # retrieve metrics results
        self.retrieve_outputs()
        # rename outputs as attributes
        self.x_final = self._x_new
        self.y_final = self._y_new

    def get_notify_observers_kwargs(self):
        """Notify observers.

        Return the mapping between the metrics call and the iterated
        variables.

        Returns
        -------
        notify_observers_kwargs : dict
           The mapping between the metrics call and the iterated variables

        """
        return {'x_new': self._x_new, 'y_new': self._y_new, 'idx': self.idx}

    def retrieve_outputs(self):
        """Retrieve outputs.

        Declare the outputs of the algorithms as attributes: x_final,
        y_final, metrics.

        """
        metrics = {}
        for obs in self._observers['cv_metrics']:
            metrics[obs.name] = obs.retrieve_metrics()
        self.metrics = metrics
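Step 2 of _update evaluates the proximity operator of the conjugate function
through Moreau's identity, prox_{sigma h*}(y) = y - sigma * prox_{h/sigma}(y / sigma).
A quick numerical check of that identity with h = ||.||_1 (soft-thresholding),
whose conjugate prox is the projection onto the unit l-infinity ball (a
standalone sketch, independent of the class above):

import numpy as np

def soft_thresh(data, level):
    """Soft-thresholding, the proximity operator of the l1 norm."""
    return np.sign(data) * np.maximum(np.abs(data) - level, 0.0)

sigma = 2.0
y = np.linspace(-3, 3, 7)
lhs = y - sigma * soft_thresh(y / sigma, 1.0 / sigma)  # Moreau identity form
rhs = np.clip(y, -1.0, 1.0)                            # projection onto {|x| <= 1}
assert np.allclose(lhs, rhs)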
#############################################################################
# FISTA optimization
# ------------------
#
# We now want to refine the zero order solution using a FISTA optimization.
# The cost function is set to Proximity Cost + Gradient Cost

# Setup the operators
linear_op = WaveletN(
    wavelet_name="sym8",
    nb_scales=4,
    dim=3,
    padding_mode="periodization",
)
regularizer_op = SparseThreshold(Identity(), 2 * 1e-11, thresh_type="soft")
# Setup Reconstructor
reconstructor = SingleChannelReconstructor(
    fourier_op=fourier_op,
    linear_op=linear_op,
    regularizer_op=regularizer_op,
    gradient_formulation='synthesis',
    verbose=1,
)
# Start Reconstruction
x_final, costs, metrics = reconstructor.reconstruct(
    kspace_data=kspace_data,
    optimization_alg='fista',
    num_iterations=200,
)
image_rec = pysap.Image(data=np.abs(x_final))
Example #28
    def __init__(
        self,
        x,
        y,
        grad,
        prox,
        prox_dual,
        linear=None,
        cost='auto',
        reweight=None,
        rho=0.5,
        sigma=1.0,
        tau=1.0,
        rho_update=None,
        sigma_update=None,
        tau_update=None,
        auto_iterate=True,
        max_iter=150,
        n_rewightings=1,
        metric_call_period=5,
        metrics=None,
        **kwargs,
    ):

        # Set default algorithm properties
        super().__init__(
            metric_call_period=metric_call_period,
            metrics=metrics,
            **kwargs,
        )

        # Set the initial variable values
        for input_data in (x, y):
            self._check_input_data(input_data)

        self._x_old = self.xp.copy(x)
        self._y_old = self.xp.copy(y)

        # Set the algorithm operators
        for operator in (grad, prox, prox_dual, linear, cost):
            self._check_operator(operator)

        self._grad = grad
        self._prox = prox
        self._prox_dual = prox_dual
        self._reweight = reweight
        if isinstance(linear, type(None)):
            self._linear = Identity()
        else:
            self._linear = linear
        if cost == 'auto':
            self._cost_func = costObj([
                self._grad,
                self._prox,
                self._prox_dual,
            ])
        else:
            self._cost_func = cost

        # Set the algorithm parameters
        for param_val in (rho, sigma, tau):
            self._check_param(param_val)

        self._rho = rho
        self._sigma = sigma
        self._tau = self.step_size or tau

        # Set the algorithm parameter update methods
        for param_update in (rho_update, sigma_update, tau_update):
            self._check_param_update(param_update)

        self._rho_update = rho_update
        self._sigma_update = sigma_update
        self._tau_update = tau_update

        # Automatically run the algorithm
        if auto_iterate:
            self.iterate(max_iter=max_iter, n_rewightings=n_rewightings)
Ejemplo n.º 29
0
def condatvu(gradient_op,
             linear_op,
             dual_regularizer,
             cost_op,
             max_nb_of_iter=150,
             tau=None,
             sigma=None,
             relaxation_factor=1.0,
             x_init=None,
             std_est=None,
             std_est_method=None,
             std_thr=2.,
             nb_of_reweights=1,
             metric_call_period=5,
             metrics={},
             verbose=0):
    """ The Condat-Vu sparse reconstruction with reweightings.

    Parameters
    ----------
    gradient_op: instance of class GradBase
        the gradient operator.
    linear_op: instance of LinearBase
        the linear operator: seek the sparsity, ie. a wavelet transform.
    dual_regularizer: instance of ProximityParent
        the dual regularization operator.
    cost_op: instance of costObj
        the cost function used to check for convergence during the
        optimization.
    max_nb_of_iter: int, default 150
        the maximum number of iterations in the Condat-Vu proximal-dual
        splitting algorithm.
    tau, sigma: float, default None
        parameters of the Condat-Vu proximal-dual splitting algorithm.
        If None estimates these parameters.
    relaxation_factor: float, default 1.0
        parameter of the Condat-Vu proximal-dual splitting algorithm.
        If 1, no relaxation.
    x_init: np.ndarray (optional, default None)
        the initial guess of image
    std_est: float, default None
        the noise std estimate.
        If None use the MAD as a consistent estimator for the std.
    std_est_method: str, default None
        if the standard deviation is not set, estimate this parameter using
        the mad routine in the image ('primal') or in the sparse wavelet
        decomposition ('dual') domain.
    std_thr: float, default 2.
        use this threshold expressed as a number of sigma in the residual
        proximity operator during the thresholding.
    nb_of_reweights: int, default 1
        the number of reweightings.
    metric_call_period: int (default 5)
        the period at which the metrics are computed.
    metrics: dict (optional, default None)
        the list of desired convergence metrics: {'metric_name':
        [@metric, metric_parameter]}. See modopt for the metrics API.
    verbose: int, default 0
        the verbosity level.

    Returns
    -------
    x_final: ndarray
        the estimated CONDAT-VU solution.
    costs: list of float
        the cost function values.
    metrics: dict
        the requested metrics values during the optimization.
    y_final: ndarray
        the estimated dual CONDAT-VU solution.
    """
    # Check inputs
    start = time.perf_counter()
    if std_est_method not in (None, "primal", "dual"):
        raise ValueError(
            "Unrecognize std estimation method '{0}'.".format(std_est_method))

    # Define the initial primal and dual solutions
    if x_init is None:
        x_init = np.squeeze(
            np.zeros((linear_op.n_coils, *gradient_op.fourier_op.shape),
                     dtype=np.complex128))
    primal = x_init
    dual = linear_op.op(primal)
    weights = dual

    # Define the weights used during the thresholding in the dual domain,
    # the reweighting strategy, and the prox dual operator

    # Case1: estimate the noise std in the image domain
    if std_est_method == "primal":
        if std_est is None:
            std_est = sigma_mad(gradient_op.MtX(gradient_op.obs_data))
        weights[...] = std_thr * std_est
        reweight_op = cwbReweight(weights)
        dual_regularizer.weights = reweight_op.weights

    # Case2: estimate the noise std in the sparse wavelet domain
    elif std_est_method == "dual":
        if std_est is None:
            std_est = 0.0
        weights[...] = std_thr * std_est
        reweight_op = mReweight(weights, linear_op, thresh_factor=std_thr)
        dual_regularizer.weights = reweight_op.weights

    # Case3: manual regularization mode, no reweighting
    else:
        reweight_op = None
        nb_of_reweights = 0

    # Define the Condat Vu optimizer: define the tau and sigma in the
    # Condat-Vu proximal-dual splitting algorithm if not already provided.
    # Check also that the combination of values will lead to convergence.
    norm = linear_op.l2norm(x_init.shape)
    lipschitz_cst = gradient_op.spec_rad
    if sigma is None:
        sigma = 0.5
    if tau is None:
        # to avoid numerics troubles with the convergence bound
        eps = 1.0e-8
        # due to the convergence bound
        tau = 1.0 / (lipschitz_cst / 2 + sigma * norm**2 + eps)
    convergence_test = (1.0 / tau - sigma * norm**2 >= lipschitz_cst / 2.0)

    # Welcome message
    if verbose > 0:
        print(" - mu: ", dual_regularizer.weights)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - tau: ", tau)
        print(" - sigma: ", sigma)
        print(" - rho: ", relaxation_factor)
        print(" - std: ", std_est)
        print(" - 1/tau - sigma||L||^2 >= beta/2: ", convergence_test)
        print(" - data: ", gradient_op.fourier_op.shape)
        if hasattr(linear_op, "nb_scale"):
            print(" - wavelet: ", linear_op, "-", linear_op.nb_scale)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - number of reweights: ", nb_of_reweights)
        print(" - primal variable shape: ", primal.shape)
        print(" - dual variable shape: ", dual.shape)
        print("-" * 40)

    prox_op = Identity()

    # Define the optimizer
    opt = Condat(x=primal,
                 y=dual,
                 grad=gradient_op,
                 prox=prox_op,
                 prox_dual=dual_regularizer,
                 linear=linear_op,
                 cost=cost_op,
                 rho=relaxation_factor,
                 sigma=sigma,
                 tau=tau,
                 rho_update=None,
                 sigma_update=None,
                 tau_update=None,
                 auto_iterate=False,
                 metric_call_period=metric_call_period,
                 metrics=metrics)
    cost_op = opt._cost_func

    # Perform the first reconstruction
    if verbose > 0:
        print("Starting optimization...")
    opt.iterate(max_iter=max_nb_of_iter)

    # Loop through the number of reweightings
    for reweight_index in range(nb_of_reweights):

        # Generate the new weights following reweighting prescription
        if std_est_method == "primal":
            reweight_op.reweight(linear_op.op(opt._x_new))
        else:
            std_est = reweight_op.reweight(opt._x_new)

        # Welcome message
        if verbose > 0:
            print(" - reweight: ", reweight_index + 1)
            print(" - std: ", std_est)

        # Update the weights in the dual proximity operator
        dual_regularizer.weights = reweight_op.weights

        # Perform optimisation with new weights
        opt.iterate(max_iter=max_nb_of_iter)

    # Goodbye message
    end = time.perf_counter()
    if verbose > 0:
        if hasattr(cost_op, "cost"):
            print(" - final iteration number: ", cost_op._iteration)
            print(" - final cost value: ", cost_op.cost)
        print(" - converged: ", opt.converge)
        print("Done.")
        print("Execution time: ", end - start, " seconds")
        print("-" * 40)

    # Get the final solution
    x_final = opt.x_final
    y_final = opt.y_final
    if hasattr(cost_op, "cost"):
        costs = cost_op._cost_list
    else:
        costs = None

    return x_final, costs, opt.metrics, y_final
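A call-site sketch for this function (operator construction elided; keyword
names follow the parameter list above, and the SparseThreshold regulariser is
an illustrative choice):

x_final, costs, metrics, y_final = condatvu(
    gradient_op=gradient_op,
    linear_op=linear_op,
    dual_regularizer=SparseThreshold(Identity(), 1e-6, thresh_type='soft'),
    cost_op=None,  # or a costObj instance for cost-based convergence checks
    max_nb_of_iter=150,
    verbose=1,
)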
Example #30
            'mask': None
        },
        'early_stopping': True,
    },
}
linear_params = {
    'init_class': WaveletN,
    'kwargs': {
        'wavelet_name': ['sym8', 'sym12'],
        'nb_scale': [3, 4]
    }
}
regularizer_params = {
    'init_class': SparseThreshold,
    'kwargs': {
        'linear': Identity(),
        'weights': np.logspace(-8, -6, 5),
    }
}
optimizer_params = {
    # Just following convention
    'kwargs': {
        'optimization_alg': 'fista',
        'num_iterations': 20,
        'metrics': metrics,
    }
}
# Call the launch grid function and obtain results
raw_results, test_cases, key_names, best_idx = launch_grid(
    kspace_data=kspace_data,
    fourier_op=fourier_op,