Example #1
    def setUp(self):
        """Set test parameter values."""
        self.data1 = np.arange(9).reshape(3, 3).astype(float) + 1
        self.data2 = np.array(
            [[0.5, 1.0, 1.5], [2.0, 2.5, 3.0], [3.5, 4.0, 4.5]],
        )
        self.rw = reweight.cwbReweight(self.data1)
        self.rw.reweight(self.data1)
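For context, the Candès-Wakin-Boyd (CWB) update that reweight.cwbReweight applies can be sketched in plain NumPy. The formula below follows Candès, Wakin & Boyd (2008); it is a minimal sketch of the assumed library behaviour, not a verbatim copy of the implementation.

import numpy as np

def cwb_reweight(weights, data, original_weights, thresh_factor=1.0):
    """Shrink weights where |data| is large (assumed CWB update rule)."""
    return weights / (1.0 + np.abs(data) / (thresh_factor * original_weights))

w0 = np.arange(9).reshape(3, 3).astype(float) + 1
w1 = cwb_reweight(w0, w0, w0)  # mirrors rw.reweight(self.data1) above; halves each weight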
Example #2
def set_sparse_weights(data_shape, psf, **kwargs):
    """Set the sparsity weights

    This method defines the weights for thresholding in the sparse domain and
    add them to the keyword arguments. It additionally defines the shape of the
    dual variable.

    Parameters
    ----------
    data_shape : tuple
        Shape of the input data array
    psf : np.ndarray
        PSF data (2D or 3D array)

    Returns
    -------
    dict Updated keyword arguments

    """

    # Convolve the PSF with the wavelet filters
    if kwargs['psf_type'] == 'fixed':

        filter_conv = (filter_convolve(np.rot90(psf, 2),
                                       kwargs['wavelet_filters'],
                                       method=kwargs['convolve_method']))

        filter_norm = np.array([
            norm(a) * b * np.ones(data_shape[1:])
            for a, b in zip(filter_conv, kwargs['wave_thresh_factor'])
        ])

        filter_norm = np.array([filter_norm for i in range(data_shape[0])])

    else:

        filter_conv = (filter_convolve_stack(np.rot90(psf, 2),
                                             kwargs['wavelet_filters'],
                                             method=kwargs['convolve_method']))

        filter_norm = np.array([[
            norm(b) * c * np.ones(data_shape[1:])
            for b, c in zip(a, kwargs['wave_thresh_factor'])
        ] for a in filter_conv])

    # Define a reweighting instance
    kwargs['reweight'] = cwbReweight(kwargs['noise_est'] * filter_norm)

    # Set the shape of the dual variable
    dual_shape = ([kwargs['wavelet_filters'].shape[0]] + list(data_shape))
    dual_shape[0], dual_shape[1] = dual_shape[1], dual_shape[0]
    kwargs['dual_shape'] = dual_shape

    return kwargs
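A hypothetical usage sketch for set_sparse_weights follows. The keys ('psf_type', 'wavelet_filters', 'convolve_method', 'wave_thresh_factor', 'noise_est') mirror what the function reads above; the array shapes and values are illustrative assumptions only.

import numpy as np

kwargs = dict(
    psf_type='fixed',
    wavelet_filters=np.random.randn(4, 32, 32),   # one filter per scale (assumed shape)
    convolve_method='scipy',                      # assumed backend name
    wave_thresh_factor=np.array([3.0, 3.0, 3.0, 3.0]),
    noise_est=0.01,
)
kwargs = set_sparse_weights((10, 32, 32), np.random.randn(32, 32), **kwargs)
print(kwargs['dual_shape'])  # expected: [10, 4, 32, 32]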
Example #3
def condatvu(gradient_op,
             linear_op,
             dual_regularizer,
             cost_op,
             max_nb_of_iter=150,
             tau=None,
             sigma=None,
             relaxation_factor=1.0,
             x_init=None,
             std_est=None,
             std_est_method=None,
             std_thr=2.,
             nb_of_reweights=1,
             metric_call_period=5,
             metrics=None,
             verbose=0):
    """ The Condat-Vu sparse reconstruction with reweightings.

    Parameters
    ----------
    gradient_op: instance of class GradBase
        the gradient operator.
    linear_op: instance of LinearBase
        the linear operator: seek the sparsity, ie. a wavelet transform.
    dual_regularizer: instance of ProximityParent
        the  dual regularization operator
    cost_op: instance of costObj
        the cost function used to check for convergence during the
        optimization.
    max_nb_of_iter: int, default 150
        the maximum number of iterations in the Condat-Vu proximal-dual
        splitting algorithm.
    tau, sigma: float, default None
        parameters of the Condat-Vu proximal-dual splitting algorithm.
        If None, these parameters are estimated.
    relaxation_factor: float, default 1.0
        parameter of the Condat-Vu proximal-dual splitting algorithm.
        If 1, no relaxation.
    x_init: np.ndarray (optional, default None)
        the initial guess for the image.
    std_est: float, default None
        the noise std estimate.
        If None, use the MAD as a consistent estimator of the std.
    std_est_method: str, default None
        if the standard deviation is not set, estimate this parameter using
        the mad routine in the image ('primal') or in the sparse wavelet
        decomposition ('dual') domain.
    std_thr: float, default 2.
        use this threshold, expressed as a number of sigmas, in the residual
        proximity operator during the thresholding.
    nb_of_reweights: int, default 1
        the number of reweightings.
    metric_call_period: int, default 5
        the period at which the metrics are computed.
    metrics: dict (optional, default None)
        the dict of desired convergence metrics: {'metric_name':
        [@metric, metric_parameter]}. See ModOpt for the metrics API.
    verbose: int, default 0
        the verbosity level.

    Returns
    -------
    x_final: ndarray
        the estimated CONDAT-VU solution.
    costs: list of float
        the cost function values.
    metrics: dict
        the requested metrics values during the optimization.
    y_final: ndarray
        the estimated dual CONDAT-VU solution
    """
    # Check inputs
    start = time.perf_counter()
    if metrics is None:
        metrics = {}
    if std_est_method not in (None, "primal", "dual"):
        raise ValueError(
            "Unrecognized std estimation method '{0}'.".format(
                std_est_method))

    # Define the initial primal and dual solutions
    if x_init is None:
        x_init = np.squeeze(
            np.zeros((linear_op.n_coils, *gradient_op.fourier_op.shape),
                     dtype=complex))
    primal = x_init
    dual = linear_op.op(primal)
    weights = np.copy(dual)  # copy so the initial dual variable keeps its value

    # Define the weights used during the thresholding in the dual domain,
    # the reweighting strategy, and the prox dual operator

    # Case1: estimate the noise std in the image domain
    if std_est_method == "primal":
        if std_est is None:
            std_est = sigma_mad(gradient_op.MtX(gradient_op.obs_data))
        weights[...] = std_thr * std_est
        reweight_op = cwbReweight(weights)
        dual_regularizer.weights = reweight_op.weights

    # Case2: estimate the noise std in the sparse wavelet domain
    elif std_est_method == "dual":
        if std_est is None:
            std_est = 0.0
        weights[...] = std_thr * std_est
        reweight_op = mReweight(weights, linear_op, thresh_factor=std_thr)
        dual_regularizer.weights = reweight_op.weights

    # Case3: manual regularization mode, no reweighting
    else:
        reweight_op = None
        nb_of_reweights = 0

    # Define the Condat Vu optimizer: define the tau and sigma in the
    # Condat-Vu proximal-dual splitting algorithm if not already provided.
    # Check also that the combination of values will lead to convergence.
    norm = linear_op.l2norm(x_init.shape)
    lipschitz_cst = gradient_op.spec_rad
    if sigma is None:
        sigma = 0.5
    if tau is None:
        # to avoid numerics troubles with the convergence bound
        eps = 1.0e-8
        # due to the convergence bound
        tau = 1.0 / (lipschitz_cst / 2 + sigma * norm**2 + eps)
    convergence_test = (1.0 / tau - sigma * norm**2 >= lipschitz_cst / 2.0)

    # Welcome message
    if verbose > 0:
        print(" - mu: ", dual_regularizer.weights)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - tau: ", tau)
        print(" - sigma: ", sigma)
        print(" - rho: ", relaxation_factor)
        print(" - std: ", std_est)
        print(" - 1/tau - sigma||L||^2 >= beta/2: ", convergence_test)
        print(" - data: ", gradient_op.fourier_op.shape)
        if hasattr(linear_op, "nb_scale"):
            print(" - wavelet: ", linear_op, "-", linear_op.nb_scale)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - number of reweights: ", nb_of_reweights)
        print(" - primal variable shape: ", primal.shape)
        print(" - dual variable shape: ", dual.shape)
        print("-" * 40)

    prox_op = Identity()

    # Define the optimizer
    opt = Condat(x=primal,
                 y=dual,
                 grad=gradient_op,
                 prox=prox_op,
                 prox_dual=dual_regularizer,
                 linear=linear_op,
                 cost=cost_op,
                 rho=relaxation_factor,
                 sigma=sigma,
                 tau=tau,
                 rho_update=None,
                 sigma_update=None,
                 tau_update=None,
                 auto_iterate=False,
                 metric_call_period=metric_call_period,
                 metrics=metrics)
    cost_op = opt._cost_func

    # Perform the first reconstruction
    if verbose > 0:
        print("Starting optimization...")
    opt.iterate(max_iter=max_nb_of_iter)

    # Loop through the number of reweightings
    for reweight_index in range(nb_of_reweights):

        # Generate the new weights following reweighting prescription
        if std_est_method == "primal":
            reweight_op.reweight(linear_op.op(opt._x_new))
        else:
            std_est = reweight_op.reweight(opt._x_new)

        # Welcome message
        if verbose > 0:
            print(" - reweight: ", reweight_index + 1)
            print(" - std: ", std_est)

        # Update the weights in the dual proximity operator
        dual_regularizer.weights = reweight_op.weights

        # Perform optimisation with new weights
        opt.iterate(max_iter=max_nb_of_iter)

    # Goodbye message
    end = time.perf_counter()
    if verbose > 0:
        if hasattr(cost_op, "cost"):
            print(" - final iteration number: ", cost_op._iteration)
            print(" - final cost value: ", cost_op.cost)
        print(" - converged: ", opt.converge)
        print("Done.")
        print("Execution time: ", end - start, " seconds")
        print("-" * 40)

    # Get the final solution
    x_final = opt.x_final
    y_final = opt.y_final
    if hasattr(cost_op, "cost"):
        costs = cost_op._cost_list
    else:
        costs = None

    return x_final, costs, opt.metrics, y_final
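The tau/sigma defaults above come from Condat's convergence bound 1/tau - sigma * ||L||^2 >= beta/2, where beta is the Lipschitz constant of the gradient and ||L|| the norm of the linear operator. A minimal, self-contained sketch of that parameter selection:

def pick_tau(lipschitz_cst, l_norm, sigma=0.5, eps=1e-8):
    """Largest tau (up to eps) satisfying the Condat-Vu bound."""
    return 1.0 / (lipschitz_cst / 2 + sigma * l_norm ** 2 + eps)

lipschitz_cst, l_norm, sigma = 2.0, 1.0, 0.5
tau = pick_tau(lipschitz_cst, l_norm, sigma)
assert 1.0 / tau - sigma * l_norm ** 2 >= lipschitz_cst / 2.0  # convergence test holds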
Example #4
    def setUp(self):
        """Set test parameter values."""
        self.data1 = np.arange(9).reshape(3, 3).astype(float)
        self.data2 = self.data1 + np.random.randn(*self.data1.shape) * 1e-6
        self.data3 = np.arange(9).reshape(3, 3).astype(float) + 1

        grad_inst = gradient.GradBasic(
            self.data1,
            func_identity,
            func_identity,
        )

        prox_inst = proximity.Positivity()
        prox_dual_inst = proximity.IdentityProx()
        linear_inst = linear.Identity()
        reweight_inst = reweight.cwbReweight(self.data3)
        cost_inst = cost.costObj([grad_inst, prox_inst, prox_dual_inst])
        self.setup = algorithms.SetUp()
        self.max_iter = 20

        self.fb_all_iter = algorithms.ForwardBackward(
            self.data1,
            grad=grad_inst,
            prox=prox_inst,
            cost=None,
            auto_iterate=False,
            beta_update=func_identity,
        )
        self.fb_all_iter.iterate(self.max_iter)

        self.fb1 = algorithms.ForwardBackward(
            self.data1,
            grad=grad_inst,
            prox=prox_inst,
            beta_update=func_identity,
        )

        self.fb2 = algorithms.ForwardBackward(
            self.data1,
            grad=grad_inst,
            prox=prox_inst,
            cost=cost_inst,
            lambda_update=None,
        )

        self.fb3 = algorithms.ForwardBackward(
            self.data1,
            grad=grad_inst,
            prox=prox_inst,
            beta_update=func_identity,
            a_cd=3,
        )

        self.fb4 = algorithms.ForwardBackward(
            self.data1,
            grad=grad_inst,
            prox=prox_inst,
            beta_update=func_identity,
            r_lazy=3,
            p_lazy=0.7,
            q_lazy=0.7,
        )

        self.fb5 = algorithms.ForwardBackward(
            self.data1,
            grad=grad_inst,
            prox=prox_inst,
            restart_strategy='adaptive',
            xi_restart=0.9,
        )

        self.fb6 = algorithms.ForwardBackward(
            self.data1,
            grad=grad_inst,
            prox=prox_inst,
            restart_strategy='greedy',
            xi_restart=0.9,
            min_beta=1.0,
            s_greedy=1.1,
        )

        self.gfb_all_iter = algorithms.GenForwardBackward(
            self.data1,
            grad=grad_inst,
            prox_list=[prox_inst, prox_dual_inst],
            cost=None,
            auto_iterate=False,
            gamma_update=func_identity,
            beta_update=func_identity,
        )
        self.gfb_all_iter.iterate(self.max_iter)

        self.gfb1 = algorithms.GenForwardBackward(
            self.data1,
            grad=grad_inst,
            prox_list=[prox_inst, prox_dual_inst],
            gamma_update=func_identity,
            lambda_update=func_identity,
        )

        self.gfb2 = algorithms.GenForwardBackward(
            self.data1,
            grad=grad_inst,
            prox_list=[prox_inst, prox_dual_inst],
            cost=cost_inst,
        )

        self.gfb3 = algorithms.GenForwardBackward(
            self.data1,
            grad=grad_inst,
            prox_list=[prox_inst, prox_dual_inst],
            cost=cost_inst,
            step_size=2,
        )

        self.condat_all_iter = algorithms.Condat(
            self.data1,
            self.data2,
            grad=grad_inst,
            prox=prox_inst,
            cost=None,
            prox_dual=prox_dual_inst,
            sigma_update=func_identity,
            tau_update=func_identity,
            rho_update=func_identity,
            auto_iterate=False,
        )
        self.condat_all_iter.iterate(self.max_iter)

        self.condat1 = algorithms.Condat(
            self.data1,
            self.data2,
            grad=grad_inst,
            prox=prox_inst,
            prox_dual=prox_dual_inst,
            sigma_update=func_identity,
            tau_update=func_identity,
            rho_update=func_identity,
        )

        self.condat2 = algorithms.Condat(
            self.data1,
            self.data2,
            grad=grad_inst,
            prox=prox_inst,
            prox_dual=prox_dual_inst,
            linear=linear_inst,
            cost=cost_inst,
            reweight=reweight_inst,
        )

        self.condat3 = algorithms.Condat(
            self.data1,
            self.data2,
            grad=grad_inst,
            prox=prox_inst,
            prox_dual=prox_dual_inst,
            linear=Dummy(),
            cost=cost_inst,
            auto_iterate=False,
        )

        self.pogm_all_iter = algorithms.POGM(
            u=self.data1,
            x=self.data1,
            y=self.data1,
            z=self.data1,
            grad=grad_inst,
            prox=prox_inst,
            auto_iterate=False,
            cost=None,
        )
        self.pogm_all_iter.iterate(self.max_iter)

        self.pogm1 = algorithms.POGM(
            u=self.data1,
            x=self.data1,
            y=self.data1,
            z=self.data1,
            grad=grad_inst,
            prox=prox_inst,
        )

        self.vanilla_grad = algorithms.VanillaGenericGradOpt(
            self.data1,
            grad=grad_inst,
            prox=prox_inst,
            cost=cost_inst,
        )
        self.ada_grad = algorithms.AdaGenericGradOpt(
            self.data1,
            grad=grad_inst,
            prox=prox_inst,
            cost=cost_inst,
        )
        self.adam_grad = algorithms.ADAMGradOpt(
            self.data1,
            grad=grad_inst,
            prox=prox_inst,
            cost=cost_inst,
        )
        self.momentum_grad = algorithms.MomentumGradOpt(
            self.data1,
            grad=grad_inst,
            prox=prox_inst,
            cost=cost_inst,
        )
        self.rms_grad = algorithms.RMSpropGradOpt(
            self.data1,
            grad=grad_inst,
            prox=prox_inst,
            cost=cost_inst,
        )
        self.saga_grad = algorithms.SAGAOptGradOpt(
            self.data1,
            grad=grad_inst,
            prox=prox_inst,
            cost=cost_inst,
        )

        self.dummy = Dummy()
        self.dummy.cost = func_identity
        self.setup._check_operator(self.dummy.cost)
Example #5
def sparse_rec_condatvu(gradient_op,
                        linear_op,
                        std_est=None,
                        std_est_method=None,
                        std_thr=2.,
                        mu=1e-6,
                        tau=None,
                        sigma=None,
                        relaxation_factor=1.0,
                        nb_of_reweights=1,
                        max_nb_of_iter=150,
                        add_positivity=False,
                        atol=1e-4,
                        verbose=0):
    """ The Condat-Vu sparse reconstruction with reweightings.

    .. note:: At the moment, supports only 2D data.

    Parameters
    ----------
    gradient_op: instance of class GradBase
        the gradient operator: observations are expected in Fourier space.
    linear_op: instance of LinearBase
        the linear operator: seeks the sparsity, i.e. a wavelet transform.
    std_est: float, default None
        the noise std estimate.
        If None use the MAD as a consistent estimator for the std.
    std_est_method: str, default None
        if the standard deviation is not set, estimate this parameter using
        the mad routine in the image ('primal') or in the sparse wavelet
        decomposition ('dual') domain.
    std_thr: float, default 2.
        use this threshold, expressed as a number of sigmas, in the residual
        proximity operator during the thresholding.
    mu: float, default 1e-6
        regularization hyperparameter.
    tau, sigma: float, default None
        parameters of the Condat-Vu proximal-dual splitting algorithm.
        If None, these parameters are estimated.
    relaxation_factor: float, default 1.0
        parameter of the Condat-Vu proximal-dual splitting algorithm.
        If 1, no relaxation.
    nb_of_reweights: int, default 1
        the number of reweightings.
    max_nb_of_iter: int, default 150
        the maximum number of iterations in the Condat-Vu proximal-dual
        splitting algorithm.
    add_positivity: bool, default False
        if set, use a positivity constraint as the proximity operator;
        otherwise use the identity.
    atol: float, default 1e-4
        tolerance threshold for convergence.
    verbose: int, default 0
        the verbosity level.

    Returns
    -------
    x_final: ndarray
        the estimated CONDAT-VU solution.
    transform: a WaveletTransformBase derived instance
        the wavelet transformation instance.
    """
    # Check inputs
    start = time.perf_counter()

    if std_est_method not in (None, "primal", "dual"):
        raise ValueError(
            "Unrecognized std estimation method '{0}'.".format(
                std_est_method))

    # Define the initial primal and dual solutions
    x_init = np.zeros(gradient_op.fourier_op.shape, dtype=complex)
    weights = linear_op.op(x_init)

    # Define the weights used during the thresholding in the dual domain,
    # the reweighting strategy, and the prox dual operator

    # Case1: estimate the noise std in the image domain
    if std_est_method == "primal":
        if std_est is None:
            std_est = sigma_mad(gradient_op.MtX(gradient_op.y))
        weights[...] = std_thr * std_est
        reweight_op = cwbReweight(weights)
        prox_dual_op = Threshold(reweight_op.weights)

    # Case2: estimate the noise std in the sparse wavelet domain
    elif std_est_method == "dual":
        if std_est is None:
            std_est = 0.0
        weights[...] = std_thr * std_est
        reweight_op = mReweight(weights, linear_op, thresh_factor=std_thr)
        prox_dual_op = Threshold(reweight_op.weights)

    # Case3: manual regularization mode, no reweighting
    else:
        weights[...] = mu
        reweight_op = None
        prox_dual_op = Threshold(weights)
        nb_of_reweights = 0

    # Define the Condat Vu optimizer: define the tau and sigma in the
    # Condat-Vu proximal-dual splitting algorithm if not already provided.
    # Check also that the combination of values will lead to convergence.
    norm = linear_op.l2norm(gradient_op.fourier_op.shape)
    lipschitz_cst = gradient_op.spec_rad
    if sigma is None:
        sigma = 0.5
    if tau is None:
        # to avoid numerics troubles with the convergence bound
        eps = 1.0e-8
        # due to the convergence bound
        tau = 1.0 / (lipschitz_cst / 2 + sigma * norm**2 + eps)
    convergence_test = (1.0 / tau - sigma * norm**2 >= lipschitz_cst / 2.0)

    # Define initial primal and dual solutions
    primal = np.zeros(gradient_op.fourier_op.shape, dtype=complex)
    dual = linear_op.op(primal)
    dual[...] = 0.0

    # Welcome message
    if verbose > 0:
        print(condatvu_logo())
        print(" - mu: ", mu)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - tau: ", tau)
        print(" - sigma: ", sigma)
        print(" - rho: ", relaxation_factor)
        print(" - std: ", std_est)
        print(" - 1/tau - sigma||L||^2 >= beta/2: ", convergence_test)
        print(" - data: ", gradient_op.obs_data.shape)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - number of reweights: ", nb_of_reweights)
        print(" - primal variable shape: ", primal.shape)
        print(" - dual variable shape: ", dual.shape)
        print("-" * 40)

    # Define the proximity operator
    if add_positivity:
        prox_op = Positivity()
    else:
        prox_op = Identity()

    # Define the cost function
    cost_op = DualGapCost(linear_op=linear_op,
                          initial_cost=1e6,
                          tolerance=1e-4,
                          cost_interval=1,
                          test_range=4,
                          verbose=0,
                          plot_output=None)

    # Define the optimizer
    opt = Condat(x=primal,
                 y=dual,
                 grad=gradient_op,
                 prox=prox_op,
                 prox_dual=prox_dual_op,
                 linear=linear_op,
                 cost=cost_op,
                 rho=relaxation_factor,
                 sigma=sigma,
                 tau=tau,
                 rho_update=None,
                 sigma_update=None,
                 tau_update=None,
                 auto_iterate=False)

    # Perform the first reconstruction
    if verbose > 0:
        print("Starting optimization...")

    opt.iterate(max_iter=max_nb_of_iter)

    # Loop through the number of reweightings
    for reweight_index in range(nb_of_reweights):

        # Generate the new weights following reweighting prescription
        if std_est_method == "primal":
            reweight_op.reweight(linear_op.op(opt._x_new))
        else:
            std_est = reweight_op.reweight(opt._x_new)

        # Welcome message
        if verbose > 0:
            print(" - reweight: ", reweight_index + 1)
            print(" - std: ", std_est)

        # Update the weights in the dual proximity operator
        prox_dual_op.weights = reweight_op.weights

        # Perform optimisation with new weights
        opt.iterate(max_iter=max_nb_of_iter)

    # Goodbye message
    end = time.perf_counter()
    if verbose > 0:
        print(" - final iteration number: ", cost_op._iteration)
        print(" - final cost value: ", cost_op.cost)
        print(" - converged: ", opt.converge)
        print("Done.")
        print("Execution time: ", end - start, " seconds")
        print("-" * 40)

    # Get the final solution
    x_final = opt.x_final
    linear_op.transform.analysis_data = unflatten(opt.y_final,
                                                  linear_op.coeffs_shape)

    return x_final, linear_op.transform
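When std_est is None in the 'primal' case above, the noise level is estimated with sigma_mad. Below is a minimal sketch of the usual MAD estimator, assumed to match the library's sigma_mad (1.4826 is the consistency constant for Gaussian noise), not a verbatim copy:

import numpy as np

def sigma_mad_sketch(data):
    """Median absolute deviation scaled to estimate a Gaussian std."""
    return 1.4826 * np.median(np.abs(data - np.median(data)))

noise = np.random.randn(64, 64)
print(sigma_mad_sketch(noise))  # close to 1.0 for unit-variance noise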
Example #6
def sparse_deconv_condatvu(data, psf, n_iter=300, n_reweights=1):
    """Sparse Deconvolution with Condat-Vu

    Parameters
    ----------
    data : np.ndarray
        Input data, 2D image
    psf : np.ndarray
        Input PSF, 2D image
    n_iter : int, optional
        Maximum number of iterations
    n_reweights : int, optional
        Number of reweightings

    Returns
    -------
    np.ndarray
        Deconvolved image

    """

    # Print the algorithm set-up
    print(condatvu_logo())

    # Define the wavelet filters
    filters = (get_cospy_filters(
        data.shape, transform_name='LinearWaveletTransformATrousAlgorithm'))

    # Set the reweighting scheme
    reweight = cwbReweight(get_weights(data, psf, filters))

    # Set the initial variable values
    primal = np.ones(data.shape)
    dual = np.ones(filters.shape)

    # Set the gradient operators
    grad_op = GradBasic(data, lambda x: psf_convolve(x, psf),
                        lambda x: psf_convolve(x, psf, psf_rot=True))

    # Set the linear operator
    linear_op = WaveletConvolve2(filters)

    # Set the proximity operators
    prox_op = Positivity()
    prox_dual_op = SparseThreshold(linear_op, reweight.weights)

    # Set the cost function
    cost_op = costObj([grad_op, prox_op, prox_dual_op],
                      tolerance=1e-6,
                      cost_interval=1,
                      plot_output=True,
                      verbose=False)

    # Set the optimisation algorithm
    alg = Condat(primal,
                 dual,
                 grad_op,
                 prox_op,
                 prox_dual_op,
                 linear_op,
                 cost_op,
                 rho=0.8,
                 sigma=0.5,
                 tau=0.5,
                 auto_iterate=False)

    # Run the algorithm
    alg.iterate(max_iter=n_iter)

    # Perform reweighting
    for rw_num in range(n_reweights):
        print(' - Reweighting: {}'.format(rw_num + 1))
        reweight.reweight(linear_op.op(alg.x_final))
        alg.iterate(max_iter=n_iter)

    # Return the final result
    return alg.x_final
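A hypothetical usage sketch for sparse_deconv_condatvu; the toy image, Gaussian PSF, and noise level below are illustrative assumptions, and scipy.signal.fftconvolve stands in for the forward model that produced the data.

import numpy as np
from scipy.signal import fftconvolve

x_true = np.pad(np.ones((8, 8)), 28)          # 64x64 toy image
g = np.exp(-0.5 * (np.mgrid[-3:4, -3:4] ** 2).sum(axis=0))
psf = g / g.sum()                             # normalised Gaussian PSF
obs = fftconvolve(x_true, psf, mode='same') + 1e-3 * np.random.randn(64, 64)

x_rec = sparse_deconv_condatvu(obs, psf, n_iter=100, n_reweights=1)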
Example #7
    def _fit(self):
        weights = self.A
        comp = self.S
        alpha = self.alpha
        #### Source updates set-up ####
        # Initialize the dual variable and compute Starlet filters for the
        # Condat source updates
        dual_var = np.zeros(self.im_hr_shape)
        if self.default_filters:
            self.Phi_filters = get_mr_filters(self.im_hr_shape[:2],
                                              opt=self.opt,
                                              coarse=True)
        rho_phi = np.sqrt(
            np.sum(np.sum(np.abs(self.Phi_filters), axis=(1, 2))**2))

        # Set up source updates, starting with the gradient
        source_grad = grads.SourceGrad(self.obs_data, self.obs_weights,
                                       weights, self.flux, self.sigs,
                                       self.shift_ker_stack,
                                       self.shift_ker_stack_adj, self.upfact,
                                       self.Phi_filters)

        # Sparsity prox in the Starlet domain (this actually assumes the
        # synthesis form); the actual thresholds are set later
        sparsity_prox = rca_prox.StarletThreshold(0)

        # and the linear recombination for the positivity constraint
        lin_recombine = rca_prox.LinRecombine(weights, self.Phi_filters)

        #### Weight updates set-up ####
        # gradient
        weight_grad = grads.CoeffGrad(self.obs_data, self.obs_weights, comp,
                                      self.VT, self.flux, self.sigs,
                                      self.shift_ker_stack,
                                      self.shift_ker_stack_adj, self.upfact)

        # cost function
        weight_cost = costObj([weight_grad], verbose=self.modopt_verb)
        source_cost = costObj([source_grad], verbose=self.modopt_verb)

        # k-thresholding for the spatial constraint
        def iter_func(x):
            return np.floor(np.sqrt(x)) + 1

        coeff_prox = rca_prox.KThreshold(iter_func)

        for k in range(self.nb_iter):
            #### Eigenpsf update ####
            # update gradient instance with new weights...
            source_grad.update_A(weights)

            # ... update linear recombination weights...
            lin_recombine.update_A(weights)

            # ... set optimization parameters...
            beta = source_grad.spec_rad + rho_phi
            tau = 1. / beta
            sigma = (1.0 / lin_recombine.norm) * beta / 2

            # ... update sparsity prox thresholds...
            thresh = utils.reg_format(
                utils.acc_sig_maps(self.shap,
                                   self.shift_ker_stack_adj,
                                   self.sigs,
                                   self.flux,
                                   self.flux_ref,
                                   self.upfact,
                                   weights,
                                   sig_data=np.ones(
                                       (self.shap[2], )) * self.sig_min))
            thresholds = self.ksig * np.sqrt(
                np.array([
                    filter_convolve(Sigma_k**2, self.Phi_filters**2)
                    for Sigma_k in thresh
                ]))

            sparsity_prox.update_threshold(tau * thresholds)

            # and run source update:
            transf_comp = utils.apply_transform(comp, self.Phi_filters)
            if self.nb_reweight:
                reweighter = cwbReweight(thresholds)
                for _ in range(self.nb_reweight):
                    source_optim = optimalg.Condat(transf_comp,
                                                   dual_var,
                                                   source_grad,
                                                   sparsity_prox,
                                                   Positivity(),
                                                   linear=lin_recombine,
                                                   cost=source_cost,
                                                   max_iter=self.nb_subiter_S,
                                                   tau=tau,
                                                   sigma=sigma)
                    transf_comp = source_optim.x_final
                    reweighter.reweight(transf_comp)
                    thresholds = reweighter.weights
            else:
                source_optim = optimalg.Condat(transf_comp,
                                               dual_var,
                                               source_grad,
                                               sparsity_prox,
                                               Positivity(),
                                               linear=lin_recombine,
                                               cost=source_cost,
                                               max_iter=self.nb_subiter_S,
                                               tau=tau,
                                               sigma=sigma)
                transf_comp = source_optim.x_final
            comp = utils.rca_format(
                np.array([
                    filter_convolve(transf_compj, self.Phi_filters, True)
                    for transf_compj in transf_comp
                ]))

            #TODO: replace line below with Fred's component selection (to be extracted from `low_rank_global_src_est_comb`)
            ind_select = range(comp.shape[2])

            #### Weight update ####
            if k < self.nb_iter - 1:
                # update sources and reset iteration counter for K-thresholding
                weight_grad.update_S(comp)
                coeff_prox.reset_iter()
                weight_optim = optimalg.ForwardBackward(
                    alpha,
                    weight_grad,
                    coeff_prox,
                    cost=weight_cost,
                    beta_param=weight_grad.inv_spec_rad,
                    auto_iterate=False)
                weight_optim.iterate(max_iter=self.nb_subiter_weights)
                alpha = weight_optim.x_final
                weights_k = alpha.dot(self.VT)

                # renormalize to break scale invariance
                weight_norms = np.sqrt(np.sum(weights_k**2, axis=1))
                comp *= weight_norms
                weights_k /= weight_norms.reshape(-1, 1)
                #TODO: replace line below with Fred's component selection
                ind_select = range(weights.shape[0])
                weights = weights_k[ind_select, :]
                supports = None  #TODO

        self.A = weights
        self.S = comp
        self.alpha = alpha
        source_grad.MX(transf_comp)
        self.current_rec = source_grad._current_rec
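For reference, the k-thresholding used for the spatial constraint above (rca_prox.KThreshold with iter_func) can be sketched as keeping the k largest-magnitude entries, with k growing as floor(sqrt(iteration)) + 1. This is an assumption about the prox behaviour, not a verbatim copy:

import numpy as np

def k_threshold(data, k):
    """Zero all but the k largest-magnitude entries of data."""
    cutoff = np.sort(np.abs(data).ravel())[-k]
    return np.where(np.abs(data) >= cutoff, data, 0.0)

x = np.array([0.1, -2.0, 0.5, 3.0, -0.2])
print(k_threshold(x, int(np.floor(np.sqrt(4)) + 1)))  # keeps the top 3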