def set_sparse_weights(data_shape, psf, **kwargs):
    """Set the sparsity weights

    Build the thresholding weights for the sparse (wavelet) domain, wrap them
    in a reweighting object stored under ``kwargs['reweight']``, and store the
    shape of the dual variable under ``kwargs['dual_shape']``.

    Parameters
    ----------
    data_shape : tuple
        Shape of the input data array. The first axis indexes the images and
        the remaining axes give the shape of one image.
    psf : np.ndarray
        PSF data (2D array when ``kwargs['psf_type'] == 'fixed'``, otherwise a
        3D stack of PSFs)
    **kwargs
        Must provide ``'psf_type'``, ``'wavelet_filters'``,
        ``'wave_thresh_factor'`` and ``'noise_est'``.

    Returns
    -------
    dict
        The keyword arguments with ``'reweight'`` and ``'dual_shape'`` added

    """

    psf_type = kwargs['psf_type']

    # The PSF is rotated by 180 degrees before being convolved with the
    # wavelet filters.
    rotated_psf = np.rot90(psf, 2)
    filters = kwargs['wavelet_filters']
    image_shape = data_shape[1:]

    def _weight_planes(convolved):
        """Return one constant image per filter: norm times threshold factor."""
        return [norm(conv) * factor * np.ones(image_shape)
                for conv, factor in zip(convolved,
                                        kwargs['wave_thresh_factor'])]

    if psf_type == 'fixed':
        # Single PSF: compute the weights once and repeat them for every
        # image along the first data axis.
        planes = np.array(_weight_planes(filter_convolve(rotated_psf,
                                                         filters)))
        filter_norm = np.array([planes] * data_shape[0])
    else:
        # Stack of PSFs: one set of per-filter weights for each PSF.
        filter_norm = np.array([
            _weight_planes(per_psf)
            for per_psf in filter_convolve_stack(rotated_psf, filters)
        ])

    # Reweighting instance built from the noise-scaled weights
    kwargs['reweight'] = cwbReweight(kwargs['noise_est'] * filter_norm)

    # Dual variable: first data axis, then one slot per wavelet filter, then
    # the remaining data axes.
    kwargs['dual_shape'] = ([data_shape[0], filters.shape[0]] +
                            list(data_shape[1:]))

    return kwargs
def reconstruct_map(data, noise_est, layout, psf=None, psf_pcs=None,
                    psf_coef=None, wavelet_levels=1, wavelet_opt=None,
                    wave_thresh_factor=1, lowr_thresh_factor=1,
                    n_reweights=0, mode='all'):

    ######
    # SET THE GRADIENT OPERATOR

    if not isinstance(psf, type(None)):
            grad_op = FixedPSF(data, psf)

    else:
            grad_op = PixelVariantPSF(data, psf_pcs, psf_coef)

    print ' - Spetral Radius:', grad_op.spec_rad

    ######
    # SET THE LINEAR OPERATOR

    # linear_op = LinearCombo([Identity()])
    # linear_op = LinearCombo([Wavelet(data.shape, wavelet_levels, wavelet_opt)])
    linear_op = LinearCombo([Wavelet(data.shape, wavelet_levels, wavelet_opt),
                             Identity])

    # print linear_op.operators[0].filters.shape
    # exit()

    ######
    # ESTIMATE THE NOISE

    if not isinstance(psf, type(None)):
        noise_est *= norm(convolve_mr_filters(np.rot90(psf, 2), linear_op.operators[0].filters))
        # noise_est *= norm(linear_op.operators[0].op(np.rot90(psf, 2)))
        noise_est *= np.ones(linear_op.operators[0].filters.shape)

    else:
        noise_est = np.sqrt(convolve_mr_filters(noise_est,
                            linear_op.operators[0].filters ** 2))

    print noise_est.shape

    ######
    # SET THE WEIGHTS

    rw = cwbReweight(wave_thresh_factor * noise_est)

    ######
    # SET RHO, SIGMA AND TAU

    l1norm_filters = linear_op.operators[0].l1norm

    tau = 1.0 / (grad_op.spec_rad + l1norm_filters)
    sigma = tau
    rho = 0.5

    print ' - 1/tau - sigma||L||^2 >= beta/2:', (1 / tau - sigma *
                                                 l1norm_filters ** 2 >=
                                                 grad_op.spec_rad / 2)

    ######
    # SET THE SHAPE OF THE DUAL

    dual_shape = [linear_op.operators[0].filters.shape[0]] + list(data.shape)

    ######
    # INITALISE THE PRIMAL AND DUAL VALUES

    primal = np.ones(data.shape)
    # dual = np.array([np.ones(dual_shape)])
    dual = np.array([np.ones(dual_shape), np.ones(data.shape)])

    grad_op.get_grad(primal)

    ######
    # SET THE PROXIMITY OPERATORS

    # prox_op = Identity()
    prox_op = Positive()

    # prox_dual_op = ProximityCombo([LowRankMatrix(layout, thresh_factor=3,
    #                                grad_class=grad_op,
    #                                grad_factor=1/sigma)])
    # prox_dual_op = ProximityCombo([Threshold(rw.weights / sigma)])
    prox_dual_op = ProximityCombo([Threshold(rw.weights / sigma),
                                   LowRankMatrix(layout,
                                                 thresh_factor=lowr_thresh_factor,
                                                 grad_class=grad_op,
                                                 grad_factor=1/sigma)])

    ######
    # SET THE COST FUNCTION.

    cost_op = costTest(data, grad_op.MX)
    # cost_op = posThresh(data, sigma, (grad_op, linear_op), rw.weights)

    ######
    # PERFORM THE OPTIMISATION

    opt = Condat(primal, dual, grad_op, prox_op, prox_dual_op, linear_op,
                 cost_op, sigma=sigma, tau=tau, print_cost=True,
                 auto_iterate=False)
    opt.iterate(max_iter=150)

    ######
    # REWEIGHTING

    # for i in range(n_reweights):
    #
    #     rw.reweight(linear_op.op(opt.x_new))
    #     prox_dual_op.update_weights(rw.weights / sigma)
    #     cost_op.update_weights(rw.weights)
    #     opt.iterate()

    return opt.x_final
def reconstruct(data, noise_est, layout, psf=None, psf_pcs=None,
                psf_coef=None, wavelet_levels=1, wavelet_opt=None,
                wave_thresh_factor=1, lowr_thresh_factor=1,
                n_reweights=0, mode='all', data_format='cube'):

    ######
    # SET THE GRADIENT OPERATOR

    if not isinstance(psf, type(None)):
            grad_op = FixedPSF(data, psf, data_format=data_format)

    else:
            grad_op = PixelVariantPSF(data, psf_pcs, psf_coef,
                                      data_format=data_format)

    print ' - Spetral Radius:', grad_op.spec_rad

    ######
    # SET THE LINEAR OPERATOR

    if mode == 'all':
        linear_op = LinearCombo([Wavelet(data.shape, wavelet_levels,
                                 wavelet_opt, data_format=data_format),
                                 Identity()])

    elif mode == 'lowr':
        linear_op = LinearCombo([Identity()])

    elif mode == 'wave':
        linear_op = LinearCombo([Wavelet(data.shape, wavelet_levels,
                                 wavelet_opt, data_format=data_format)])

    ######
    # ESTIMATE THE NOISE

    if mode in ('all', 'wave'):

        if not isinstance(psf, type(None)):
            noise_est *= norm(convolve_mr_filters(np.rot90(psf, 2),
                              linear_op.operators[0].filters))
            # noise_est *= norm(linear_op.operators[0].op(np.rot90(psf, 2)))
            if data_format == 'map':
                noise_est *= np.ones(linear_op.operators[0].filters.shape)
            else:
                noise_est *= np.ones([data.shape[0]] +
                                     list(linear_op.operators[0].filters.shape))

        else:
            noise_est = np.sqrt(convolve_cube(noise_est,
                                linear_op.operators[0].filters ** 2))

        print noise_est.shape

    ######
    # SET THE WEIGHTS

        rw = cwbReweight(wave_thresh_factor * noise_est)

    ######
    # SET THE SHAPE OF THE DUAL

        dual_shape = [linear_op.operators[0].filters.shape[0]] + list(data.shape)

        if data_format == 'cube':
            dual_shape[0], dual_shape[1] = dual_shape[1], dual_shape[0]

    ######
    # SET RHO, SIGMA AND TAU

    l1norm_filters = linear_op.operators[0].l1norm

    tau = 1.0 / (grad_op.spec_rad + l1norm_filters)
    sigma = tau
    rho = 0.5

    print ' - 1/tau - sigma||L||^2 >= beta/2:', (1 / tau - sigma *
                                                 l1norm_filters ** 2 >=
                                                 grad_op.spec_rad / 2)

    ######
    # INITALISE THE PRIMAL AND DUAL VALUES

    # 1 Primal Operator (Positivity or Identity)
    primal = np.ones(data.shape)

    if mode == 'all':

        # 2 Dual Operators (Wavelet + Threshold and Identity + LowRankMatrix)
        dual = np.empty(2, dtype=np.ndarray)
        dual[0] = np.ones(dual_shape)
        dual[1] = np.ones(data.shape)

    elif mode == 'lowr':

        # 1 Dual Operator (Identity + LowRankMatrix)
        dual = np.empty(1, dtype=np.ndarray)
        dual[0] = np.ones(data.shape)

    elif mode == 'wave':

        # 1 Dual Operator (Wavelet + Threshold or Identity + LowRankMatrix)
        dual = np.empty(1, dtype=np.ndarray)
        dual[0] = np.ones(dual_shape)

    # Get the initial gradient value
    grad_op.get_grad(primal)

    print ' - Primal Variable Shape:', primal.shape
    print ' - Dual Variable Shape:', dual.shape

    ######
    # SET THE PROXIMITY OPERATORS

    # prox_op = Identity()
    prox_op = Positive()

    if mode == 'all':

        prox_dual_op = ProximityCombo([Threshold(rw.weights / sigma),
                                       LowRankMatrix(layout,
                                       thresh_factor=lowr_thresh_factor,
                                       grad_class=grad_op,
                                       grad_factor=1/sigma,
                                       data_format=data_format)])

    elif mode == 'lowr':

        prox_dual_op = ProximityCombo([LowRankMatrix(layout,
                                       thresh_factor=lowr_thresh_factor,
                                       grad_class=grad_op,
                                       grad_factor=1/sigma,
                                       data_format=data_format)])

    elif mode == 'wave':

        prox_dual_op = ProximityCombo([Threshold(rw.weights / sigma),])

    ######
    # SET THE COST FUNCTION.

    cost_op = costTest(data, grad_op.MX)
    # cost_op = posThresh(data, sigma, (grad_op, linear_op), rw.weights)

    ######
    # PERFORM THE OPTIMISATION

    opt = Condat(primal, dual, grad_op, prox_op, prox_dual_op, linear_op,
                 cost_op, sigma=sigma, tau=tau, print_cost=True,
                 auto_iterate=False)
    opt.iterate(max_iter=150)

    ######
    # REWEIGHTING

    # for i in range(n_reweights):
    #
    #     rw.reweight(linear_op.op(opt.x_new))
    #     prox_dual_op.update_weights(rw.weights / sigma)
    #     cost_op.update_weights(rw.weights)
    #     opt.iterate()

    return opt.x_final, opt.y_final