# Example no. 1
# 0
def reconstruct_map(data, noise_est, layout, psf=None, psf_pcs=None,
                    psf_coef=None, wavelet_levels=1, wavelet_opt=None,
                    wave_thresh_factor=1, lowr_thresh_factor=1,
                    n_reweights=0, mode='all', max_iter=150):
    """Run the Condat solver on ``data`` with wavelet and low-rank terms.

    The routine builds a gradient operator (``FixedPSF`` when ``psf`` is
    given, ``PixelVariantPSF`` otherwise), a linear operator combining a
    ``Wavelet`` transform with an ``Identity``, a ``Positive`` primal
    proximity operator, and a dual proximity combination of ``Threshold``
    and ``LowRankMatrix``. It then iterates ``Condat`` ``max_iter`` times.

    Parameters
    ----------
    data : array-like with a ``shape`` attribute
        Observed data; forwarded to the gradient, wavelet and cost operators.
    noise_est : scalar or array
        Noise estimate used to derive the threshold weights. It is not
        modified in place.
    layout :
        Forwarded unchanged to ``LowRankMatrix``.
    psf : array, optional
        Fixed PSF. When ``None`` the pixel-variant branch is used.
    psf_pcs, psf_coef : optional
        Forwarded to ``PixelVariantPSF`` when ``psf`` is ``None``.
    wavelet_levels, wavelet_opt :
        Forwarded to ``Wavelet``.
    wave_thresh_factor : number
        Multiplier applied to the noise estimate before building the weights.
    lowr_thresh_factor : number
        ``thresh_factor`` passed to ``LowRankMatrix``.
    n_reweights : int
        Currently unused (no reweighting pass is performed).
    mode : str
        Currently unused by this function.
    max_iter : int
        Number of iterations requested from the optimiser (default 150).

    Returns
    -------
    The ``x_final`` attribute of the ``Condat`` instance after iterating.
    """

    fixed_psf = psf is not None

    ######
    # SET THE GRADIENT OPERATOR

    if fixed_psf:
        grad_op = FixedPSF(data, psf)
    else:
        grad_op = PixelVariantPSF(data, psf_pcs, psf_coef)

    print(' - Spetral Radius: %s' % (grad_op.spec_rad,))

    ######
    # SET THE LINEAR OPERATOR

    # Identity must be instantiated: passing the bare class would hand
    # LinearCombo something that is not an operator instance.
    linear_op = LinearCombo([Wavelet(data.shape, wavelet_levels, wavelet_opt),
                             Identity()])
    wavelet_op = linear_op.operators[0]
    filters = wavelet_op.filters

    ######
    # ESTIMATE THE NOISE

    if fixed_psf:
        psf_norm = norm(convolve_mr_filters(np.rot90(psf, 2), filters))
        # Out-of-place product: leaves the caller's noise_est untouched and
        # lets a scalar estimate broadcast to the filter shape.
        noise_est = noise_est * psf_norm * np.ones(filters.shape)
    else:
        noise_est = np.sqrt(convolve_mr_filters(noise_est, filters ** 2))

    print(noise_est.shape)

    ######
    # SET THE WEIGHTS

    rw = cwbReweight(wave_thresh_factor * noise_est)

    ######
    # SET SIGMA AND TAU

    l1norm_filters = wavelet_op.l1norm

    tau = 1.0 / (grad_op.spec_rad + l1norm_filters)
    sigma = tau

    # Diagnostic only: report whether the step sizes satisfy the inequality
    # named in the message.
    step_ok = (1 / tau - sigma * l1norm_filters ** 2 >=
               grad_op.spec_rad / 2)
    print(' - 1/tau - sigma||L||^2 >= beta/2: %s' % (step_ok,))

    ######
    # INITIALISE THE PRIMAL AND DUAL VALUES

    dual_shape = [filters.shape[0]] + list(data.shape)

    primal = np.ones(data.shape)

    # The two dual components have different shapes, so they are stored in
    # an explicit object array rather than relying on np.array() to build a
    # ragged array.
    dual = np.empty(2, dtype=np.ndarray)
    dual[0] = np.ones(dual_shape)
    dual[1] = np.ones(data.shape)

    # Get the initial gradient value
    grad_op.get_grad(primal)

    ######
    # SET THE PROXIMITY OPERATORS

    prox_op = Positive()

    prox_dual_op = ProximityCombo([
        Threshold(rw.weights / sigma),
        LowRankMatrix(layout,
                      thresh_factor=lowr_thresh_factor,
                      grad_class=grad_op,
                      grad_factor=1 / sigma),
    ])

    ######
    # SET THE COST FUNCTION

    cost_op = costTest(data, grad_op.MX)

    ######
    # PERFORM THE OPTIMISATION

    opt = Condat(primal, dual, grad_op, prox_op, prox_dual_op, linear_op,
                 cost_op, sigma=sigma, tau=tau, print_cost=True,
                 auto_iterate=False)
    opt.iterate(max_iter=max_iter)

    return opt.x_final
# Example no. 2
# 0
def reconstruct(data, noise_est, layout, psf=None, psf_pcs=None,
                psf_coef=None, wavelet_levels=1, wavelet_opt=None,
                wave_thresh_factor=1, lowr_thresh_factor=1,
                n_reweights=0, mode='all', data_format='cube',
                max_iter=150):
    """Run the Condat solver on ``data`` with the terms selected by ``mode``.

    ``mode`` selects which dual terms are used:

    * ``'all'``  - ``Wavelet`` + ``Threshold`` and ``Identity`` +
      ``LowRankMatrix``;
    * ``'wave'`` - ``Wavelet`` + ``Threshold`` only;
    * ``'lowr'`` - ``Identity`` + ``LowRankMatrix`` only.

    Parameters
    ----------
    data : array-like with a ``shape`` attribute
        Observed data; forwarded to the gradient, wavelet and cost operators.
    noise_est : scalar or array
        Noise estimate used to derive the threshold weights (only in the
        ``'all'`` and ``'wave'`` modes). It is not modified in place.
    layout :
        Forwarded unchanged to ``LowRankMatrix``.
    psf : array, optional
        Fixed PSF. When ``None`` the pixel-variant branch is used.
    psf_pcs, psf_coef : optional
        Forwarded to ``PixelVariantPSF`` when ``psf`` is ``None``.
    wavelet_levels, wavelet_opt :
        Forwarded to ``Wavelet``.
    wave_thresh_factor : number
        Multiplier applied to the noise estimate before building the weights.
    lowr_thresh_factor : number
        ``thresh_factor`` passed to ``LowRankMatrix``.
    n_reweights : int
        Currently unused (no reweighting pass is performed).
    mode : {'all', 'lowr', 'wave'}
        Which regularisation terms to use.
    data_format : str
        Forwarded to the operators. ``'map'`` and ``'cube'`` additionally
        select how the noise and dual shapes are built.
    max_iter : int
        Number of iterations requested from the optimiser (default 150).

    Returns
    -------
    tuple
        ``(opt.x_final, opt.y_final)`` from the ``Condat`` instance.

    Raises
    ------
    ValueError
        If ``mode`` is not one of ``'all'``, ``'lowr'`` or ``'wave'``.
    """

    # Fail early with a clear message instead of hitting an unbound
    # ``linear_op`` further down.
    if mode not in ('all', 'lowr', 'wave'):
        raise ValueError("Invalid mode %r; expected 'all', 'lowr' or 'wave'."
                         % (mode,))

    use_wavelet = mode in ('all', 'wave')
    use_lowr = mode in ('all', 'lowr')
    fixed_psf = psf is not None

    ######
    # SET THE GRADIENT OPERATOR

    if fixed_psf:
        grad_op = FixedPSF(data, psf, data_format=data_format)
    else:
        grad_op = PixelVariantPSF(data, psf_pcs, psf_coef,
                                  data_format=data_format)

    print(' - Spetral Radius: %s' % (grad_op.spec_rad,))

    ######
    # SET THE LINEAR OPERATOR

    # Order matters: the wavelet operator (when present) comes first.
    operators = []
    if use_wavelet:
        operators.append(Wavelet(data.shape, wavelet_levels, wavelet_opt,
                                 data_format=data_format))
    if use_lowr:
        operators.append(Identity())
    linear_op = LinearCombo(operators)

    ######
    # ESTIMATE THE NOISE, SET THE WEIGHTS AND THE SHAPE OF THE DUAL

    if use_wavelet:

        filters = linear_op.operators[0].filters

        if fixed_psf:
            psf_norm = norm(convolve_mr_filters(np.rot90(psf, 2), filters))

            if data_format == 'map':
                scale_shape = list(filters.shape)
            else:
                scale_shape = [data.shape[0]] + list(filters.shape)

            # Out-of-place product: leaves the caller's noise_est untouched
            # and lets a scalar estimate broadcast to the target shape.
            noise_est = noise_est * psf_norm * np.ones(scale_shape)

        else:
            # NOTE(review): convolve_cube is used here regardless of
            # data_format - confirm this is intended for 'map' input.
            noise_est = np.sqrt(convolve_cube(noise_est, filters ** 2))

        print(noise_est.shape)

        rw = cwbReweight(wave_thresh_factor * noise_est)

        dual_shape = [filters.shape[0]] + list(data.shape)

        if data_format == 'cube':
            # Swap the first two axes of the dual shape for cube input.
            dual_shape[0], dual_shape[1] = dual_shape[1], dual_shape[0]

    ######
    # SET SIGMA AND TAU

    # NOTE(review): in 'lowr' mode operators[0] is the Identity operator;
    # this assumes it exposes an ``l1norm`` attribute - TODO confirm.
    l1norm_filters = linear_op.operators[0].l1norm

    tau = 1.0 / (grad_op.spec_rad + l1norm_filters)
    sigma = tau

    # Diagnostic only: report whether the step sizes satisfy the inequality
    # named in the message.
    step_ok = (1 / tau - sigma * l1norm_filters ** 2 >=
               grad_op.spec_rad / 2)
    print(' - 1/tau - sigma||L||^2 >= beta/2: %s' % (step_ok,))

    ######
    # INITIALISE THE PRIMAL AND DUAL VALUES

    # 1 Primal Operator (Positivity)
    primal = np.ones(data.shape)

    # One dual component per linear operator, in the same order. They are
    # kept in an object array because their shapes may differ.
    dual_parts = []
    if use_wavelet:
        dual_parts.append(np.ones(dual_shape))
    if use_lowr:
        dual_parts.append(np.ones(data.shape))

    dual = np.empty(len(dual_parts), dtype=np.ndarray)
    for index, part in enumerate(dual_parts):
        dual[index] = part

    # Get the initial gradient value
    grad_op.get_grad(primal)

    print(' - Primal Variable Shape: %s' % (primal.shape,))
    print(' - Dual Variable Shape: %s' % (dual.shape,))

    ######
    # SET THE PROXIMITY OPERATORS

    prox_op = Positive()

    dual_prox = []
    if use_wavelet:
        dual_prox.append(Threshold(rw.weights / sigma))
    if use_lowr:
        dual_prox.append(LowRankMatrix(layout,
                                       thresh_factor=lowr_thresh_factor,
                                       grad_class=grad_op,
                                       grad_factor=1 / sigma,
                                       data_format=data_format))
    prox_dual_op = ProximityCombo(dual_prox)

    ######
    # SET THE COST FUNCTION

    cost_op = costTest(data, grad_op.MX)

    ######
    # PERFORM THE OPTIMISATION

    opt = Condat(primal, dual, grad_op, prox_op, prox_dual_op, linear_op,
                 cost_op, sigma=sigma, tau=tau, print_cost=True,
                 auto_iterate=False)
    opt.iterate(max_iter=max_iter)

    return opt.x_final, opt.y_final