Ejemplo n.º 1
0
def tikhonov(*, psfs, measurements, tikhonov_lam=0.129, tikhonov_order=1):
    """Perform Tikhonov regularization based image reconstruction for PSSI.

    Solves x_hat = argmin_{x} { ||Ax-y||_2^2 + lam * ||Dx||_2^2 }. D is the
    discrete derivative operator of order `tikhonov_order`. The system
    (A^H A + lam * D^T D) x = A^H y is formed and inverted in the 2-D DFT
    domain using the block_* helpers.

    Args:
        psfs (PSFs): PSFs object containing psfs and other csbs state data.
            Side effect: its ``psf_GAM`` attribute is overwritten with
            block_mul(block_herm(A_dft), A_dft).
        measurements (ndarray): 3d array of (noisy) measurements, indexed as
            (measurement, row, col)
        tikhonov_lam (float): regularization parameter of tikhonov
        tikhonov_order (int): [0,1 or 2] order of the discrete derivative
            operator used in the tikhonov regularization

    Returns:
        ndarray: real part of the inverse 2-D DFT of the solution, i.e. the
        reconstructed images (presumably one image per source -- TODO confirm
        the exact shape against block_mul).
    """
    num_sources = psfs.psfs.shape[1]
    rows, cols = measurements.shape[1:]

    # DFTs of the PSFs resized to the measurement size by size_equalizer, with
    # entries along axis 0 repeated `psfs.copies` times.
    exp_psf_dfts = np.repeat(
        np.fft.fft2(size_equalizer(psfs.psfs, ref_size=[rows, cols])),
        psfs.copies.astype(int),
        axis=0,
    )
    # The Hermitian of the expanded PSF DFTs is needed twice; compute it once.
    exp_psf_dfts_h = block_herm(exp_psf_dfts)
    psfs.psf_GAM = block_mul(exp_psf_dfts_h, exp_psf_dfts)

    # DFT of the kernel corresponding to (D^TD)
    LAM = get_LAM(rows=rows, cols=cols, order=tikhonov_order)
    # lam * (I_{num_sources} outer LAM); implicit einsum output is 'ijkl'.
    regularization = tikhonov_lam * np.einsum('ij,kl', np.eye(num_sources), LAM)

    # A^H y in the DFT domain (measurements are fftshifted over the spatial
    # axes before the transform).
    rhs = block_mul(
        exp_psf_dfts_h,
        np.fft.fft2(np.fft.fftshift(measurements, axes=(1, 2))),
    )
    solution_dft = block_mul(block_inv(psfs.psf_GAM + regularization), rhs)
    return np.real(np.fft.ifft2(solution_dft))
Ejemplo n.º 2
0
def iteration_end(psfs, lowest_psf_group_index):
    """Remove one PSF group's contribution from the cached ``GAM`` entry.

    Computes block_mul(block_herm(P), P), where P is the length-1 slice of
    ``psfs.initialized_data['psf_dfts']`` starting at
    ``lowest_psf_group_index``. It then subtracts the result in place from
    ``psfs.initialized_data['GAM']``.

    Args:
        psfs: object whose ``initialized_data`` dict holds the keys
            'GAM' and 'psf_dfts'.
        lowest_psf_group_index (int): index of the PSF group to remove.
    """
    data = psfs.initialized_data
    first = lowest_psf_group_index
    # Slice (rather than index) so the leading axis is kept with length 1.
    group_dfts = data['psf_dfts'][first:first + 1]
    data['GAM'] -= block_mul(block_herm(group_dfts), group_dfts)
Ejemplo n.º 3
0
def cost(psfs, psf_group_index, **kwargs):
    """Return the trace-based cost of dropping one PSF group.

    Subtracts block_mul(block_herm(P), P) from SIG_e_dft(psfs, lam), where
    P is the length-1 slice of ``psfs.initialized_data['psf_dfts']`` at
    ``psf_group_index``. It block-inverts the result and sums np.trace over
    its first two axes.

    Args:
        psfs: object whose ``initialized_data`` dict holds 'psf_dfts'.
        psf_group_index (int): index of the PSF group being evaluated.
        **kwargs:
            lam: regularization parameter forwarded to SIG_e_dft (required).
            no_dc (bool): if truthy, zero the [:, :, 0, 0] (DC) entries of the
                inverted matrix before taking the trace. Defaults to False.

    Returns:
        Real-valued scalar cost.

    Raises:
        KeyError: if 'lam' is not supplied in kwargs.
    """
    group_dfts = psfs.initialized_data['psf_dfts'][
        psf_group_index:psf_group_index + 1]
    iteration_SIG_e_dft = (SIG_e_dft(psfs, kwargs['lam']) -
                           block_mul(block_herm(group_dfts), group_dfts))

    cov = block_inv(iteration_SIG_e_dft)
    if kwargs.get('no_dc', False):
        cov[:, :, 0, 0] = 0  # set the DC values to zero
    # np.trace sums over axes 0 and 1, leaving a per-frequency array.
    return np.real(np.sum(np.trace(cov)))
Ejemplo n.º 4
0
def init(psfs, **kwargs):
    """Populate ``psfs.initialized_data`` with arrays reused across iterations.

    The stored dict has the keys:
        'psf_dfts': ``psfs.psf_dfts`` unchanged.
        'GAM': block_mul(block_herm(S), S), where S is ``psfs.psf_dfts`` with
            entry i along axis 0 scaled by sqrt(psfs.copies[i]). This is
            presumably the copies-weighted Gram matrix
            sum_i copies[i] * P_i^H P_i (assumes block_mul is a block matrix
            product -- TODO confirm).
        'LAM': get_LAM(rows=rows, cols=cols, order=kwargs['order']), with
            rows/cols taken from the last two axes of ``psfs.psfs``.

    Args:
        psfs: PSFs object; reads ``psfs``, ``psf_dfts`` and ``copies`` and
            sets ``initialized_data``.
        **kwargs: must contain 'order', forwarded to get_LAM.

    Raises:
        KeyError: if 'order' is not supplied in kwargs.
    """
    _, _, rows, cols = psfs.psfs.shape
    psf_dfts = psfs.psf_dfts

    # Scale axis-0 entries of psf_dfts by sqrt(copies). The copies factor is
    # split across both Gram factors (per the original author, to avoid
    # rounding error). The scaled array is computed once and used for both.
    scaled_dfts = np.einsum('i,ijkl->ijkl', np.sqrt(psfs.copies), psf_dfts)

    psfs.initialized_data = {
        "psf_dfts": psf_dfts,
        "GAM": block_mul(block_herm(scaled_dfts), scaled_dfts),
        "LAM": get_LAM(rows=rows, cols=cols, order=kwargs['order']),
    }
Ejemplo n.º 5
0
Archivo: admm.py Proyecto: zoey0919/MAS
def admm(*,
         sources=None,
         psfs,
         measurements,
         regularizer,
         recon_init,
         iternum,
         plot=True,
         periter=5,
         nu,
         lam,
         **kwargs):
    """Function that implements PSSI deconvolution with the ADMM algorithm using
    the specified regularization method.

    Side effects: sets ``psfs.psf_dfts`` and ``psfs.psf_GAM``.

    Args:
        sources (ndarray): 3d array of sources; only passed to the plotter
        psfs (PSFs): PSFs object containing psfs and related data
        measurements (ndarray): 3d array of measurements, indexed as
            (measurement, row, col)
        regularizer (function): function that specifies the regularization
            type. It is called once per iteration and must return the tuple
            (primal1, primal2, pre_primal2, dual).
        recon_init (ndarray): initialization for the reconstructed image(s)
        iternum (int): number of iterations of ADMM
        plot (boolean): if truthy, display the reconstructions as the
            iterations go
        periter (int): iteration period of displaying the reconstructions
        nu (float): augmented Lagrangian parameter (step size) of ADMM
        lam (float or array-like): regularization parameter. A real scalar
            (Python or NumPy) is broadcast to an array of length num_sources;
            anything else is passed through unchanged.
        kwargs (dict): keyword arguments to be passed to the regularizers

    Returns:
        ndarray of reconstructed images (``recon_init`` if iternum is 0)
    """
    num_sources = psfs.psfs.shape[1]
    rows, cols = measurements.shape[1:]
    # np.float / np.int were removed in NumPy 1.24; test against the abstract
    # scalar types instead (this also accepts np.int64, np.float32, ...).
    if isinstance(lam, (int, float, np.integer, np.floating)):
        lam = np.ones(num_sources) * lam

    ################## initialize the primal/dual variables ##################
    primal1 = recon_init
    primal2 = None
    # NOTE(review): dual starts as None and is updated with += below, so the
    # regularizer presumably initializes it on the first call -- confirm.
    dual = None

    ################# pre-compute some arrays for efficiency #################
    psfs.psf_dfts = np.repeat(
        np.fft.fft2(size_equalizer(psfs.psfs, ref_size=[rows, cols])),
        psfs.copies.astype(int),
        axis=0,
    )
    psf_dfts_h = block_herm(psfs.psf_dfts)  # reused twice below
    psfs.psf_GAM = block_mul(psf_dfts_h, psfs.psf_dfts)
    # this is reshaped FA^Ty term where F is DFT matrix
    psfdfts_h_meas = block_mul(
        psf_dfts_h,
        np.fft.fft2(np.fft.fftshift(measurements, axes=(1, 2))),
    )
    SIG_inv = get_SIG_inv(regularizer=regularizer,
                          measurements=measurements,
                          psfs=psfs,
                          nu=nu,
                          **kwargs)

    for it in range(iternum):
        ######################### PRIMAL 1,2 UPDATES #########################
        primal1, primal2, pre_primal2, dual = regularizer(
            psfs=psfs,
            measurements=measurements,
            psfdfts_h_meas=psfdfts_h_meas,
            SIG_inv=SIG_inv,
            primal1=primal1,
            primal2=primal2,
            dual=dual,
            nu=nu,
            lam=lam,
            **kwargs)

        ########################### DUAL UPDATE ###########################
        dual += (pre_primal2 - primal2)

        if plot and ((it + 1) % periter == 0 or it == iternum - 1):
            deconv_plotter(sources=sources, recons=primal1, iter=it)

    return primal1