Пример #1
0
def tikhonov(*, psfs, measurements, tikhonov_lam=0.129, tikhonov_order=1):
    """Perform Tikhonov regularization based image reconstruction for PSSI.

    Solves x_hat = argmin_{x} { ||Ax-y||_2^2 + lam * ||Dx||_2^2 } in the
    Fourier domain, where D is the discrete derivative operator of order
    `tikhonov_order`.

    Args:
        psfs (PSFs): PSFs object containing psfs and other csbs state data.
            Its ``psf_GAM`` attribute is overwritten with
            block_mul(block_herm(psf_dfts), psf_dfts) computed here.
        measurements (ndarray): 3d array of measurements; the last two axes
            are the image rows and columns and axis 0 is (presumably) the
            measurement index, matched against the repeated psf rows
        tikhonov_lam (float): regularization parameter of tikhonov
        tikhonov_order (int): [0, 1 or 2] order of the discrete derivative
            operator used in the tikhonov regularization

    Returns:
        ndarray: real part of the reconstructed images, one per source
    """
    num_sources = psfs.psfs.shape[1]
    rows, cols = measurements.shape[1:]

    # DFTs of the psfs resized to the measurement size, with psf row i
    # repeated `copies[i]` times along axis 0
    exp_psf_dfts = np.repeat(
        np.fft.fft2(size_equalizer(psfs.psfs, ref_size=[rows, cols])),
        psfs.copies.astype(int),
        axis=0,
    )
    exp_psf_dfts_h = block_herm(exp_psf_dfts)
    psfs.psf_GAM = block_mul(exp_psf_dfts_h, exp_psf_dfts)

    # DFT of the kernel corresponding to (D^TD), placed on the block diagonal
    LAM = get_LAM(rows=rows, cols=cols, order=tikhonov_order)
    reg_spectrum = np.einsum('ij,kl->ijkl', np.eye(num_sources), LAM)

    # reshaped F A^H y term, where F is the DFT matrix
    meas_dfts = np.fft.fft2(np.fft.fftshift(measurements, axes=(1, 2)))
    psfdfts_h_meas = block_mul(exp_psf_dfts_h, meas_dfts)

    SIG_inv = block_inv(psfs.psf_GAM + tikhonov_lam * reg_spectrum)
    return np.real(np.fft.ifft2(block_mul(SIG_inv, psfdfts_h_meas)))
Пример #2
0
    def test_block_mul(self):
        """Check the output shapes of `block_mul` for matrix and vector operands."""
        from mas.block import block_mul

        # block matrix times block matrix keeps the full 4d layout
        product = block_mul(self.x, self.y)
        self.assertEqual(product.shape, (3, 3, 5, 5))

        # block matrix times a single block column yields a 3d result
        column = self.y[:, 0, :, :]
        product = block_mul(self.x, column)
        self.assertEqual(product.shape, (3, 5, 5))
Пример #3
0
def iteration_end(psfs, lowest_psf_group_index):
    """Remove the contribution of one psf group from the stored GAM.

    Subtracts block_mul(block_herm(H), H), where H is the one-group slice of
    ``initialized_data['psf_dfts']`` at `lowest_psf_group_index`, from
    ``initialized_data['GAM']``.

    Args:
        psfs (PSFs): object whose ``initialized_data`` dict holds 'psf_dfts'
            and 'GAM'; 'GAM' is updated with ``-=``
        lowest_psf_group_index (int): index of the psf group to remove
    """
    data = psfs.initialized_data
    start = lowest_psf_group_index
    group_dfts = data['psf_dfts'][start:start + 1]
    data['GAM'] -= block_mul(block_herm(group_dfts), group_dfts)
Пример #4
0
def cost(psfs, psf_group_index, **kwargs):
    """Return the summed trace of the inverse spectrum after removing a group.

    The term block_mul(block_herm(H), H), where H is the one-group slice of
    ``initialized_data['psf_dfts']`` at `psf_group_index`, is subtracted from
    ``SIG_e_dft(psfs, kwargs['lam'])``; the block inverse of the difference is
    taken and the real part of the sum of its traces is returned.

    Args:
        psfs (PSFs): object whose ``initialized_data['psf_dfts']`` holds the
            psf DFTs
        psf_group_index (int): index of the psf group being removed
        kwargs: must contain ``lam`` (passed to ``SIG_e_dft``) and ``no_dc``
            (when truthy, the [:, :, 0, 0] entries of the inverse are zeroed
            before the trace is taken)

    Returns:
        real part of the summed trace of the block inverse
    """
    sig_e = SIG_e_dft(psfs, kwargs['lam'])
    start = psf_group_index
    group_dfts = psfs.initialized_data['psf_dfts'][start:start + 1]
    cov = block_inv(sig_e - block_mul(block_herm(group_dfts), group_dfts))

    if kwargs['no_dc']:
        # set the DC values to zero
        cov[:, :, 0, 0] = 0

    return np.real(np.sum(np.trace(cov)))
Пример #5
0
def primal1_update_gaussian(*, pre_primal1, SIG_inv, ATy, nu):
    """Update the first ADMM primal variable for a least-squares data term.

    The least squares data fidelity term inherently assumes that the
    measurements have additive Gaussian noise. The update is solved in the
    Fourier domain as ifft2(SIG_inv (ATy + nu * fft2(pre_primal1))), keeping
    only the real part.

    Args:
        pre_primal1 (ndarray): intermediate ADMM variable, typically
            W(primal2 - dual) where W is the regularization transform
        SIG_inv (ndarray): spectrum of (A^TA + nu*W^TW + ...)^{-1}
        ATy (ndarray): F A^T y term where F is the DFT matrix
        nu (float): augmented Lagrangian parameter of ADMM (step size)

    Returns:
        ndarray: the updated primal1
    """
    rhs = ATy + nu * np.fft.fft2(pre_primal1)
    primal1_dft = block_mul(SIG_inv, rhs)
    return np.real(np.fft.ifft2(primal1_dft))
Пример #6
0
def init(psfs, **kwargs):
    """Pre-compute data used by the cost function and store it on `psfs`.

    Sets ``psfs.initialized_data`` to a dict with the keys:
        'psf_dfts': ``psfs.psf_dfts`` as-is (not recomputed here)
        'GAM': block_mul(block_herm(S), S), where S is psf_dfts with row i
            scaled by sqrt(psfs.copies[i])
        'LAM': ``get_LAM`` for the psf rows/cols and ``kwargs['order']``

    Args:
        psfs (PSFs): object providing a 4d ``psfs`` array, ``psf_dfts`` and
            ``copies``
        kwargs: must contain ``order``, forwarded to ``get_LAM``
    """
    _, _, rows, cols = psfs.psfs.shape
    psf_dfts = psfs.psf_dfts

    # scale rows of psf_dfts by copies; the weight is split across both
    # factors via sqrt to prevent rounding error
    weighted_dfts = np.einsum('i,ijkl->ijkl', np.sqrt(psfs.copies), psf_dfts)
    gam = block_mul(block_herm(weighted_dfts), weighted_dfts)
    lam_spectrum = get_LAM(rows=rows, cols=cols, order=kwargs['order'])

    psfs.initialized_data = {
        "psf_dfts": psf_dfts,
        "GAM": gam,
        "LAM": lam_spectrum,
    }
Пример #7
0
def strollr(*, sources, psfs, measurements, recon_init, iternum, periter, lr,
            theta, s, lam, patch_shape, transform, learning, window_size,
            group_size, group_size_s):
    """Perform PSSI deconvolution with sparsifying transform learning and a
    low-rank penalty on groups of similar image patches.

    The algorithm details can be found in [arXiv:1808.01316]. The low-rank
    step is parallelized over patches with a ``multiprocessing`` pool.

    Args:
        sources (ndarray): array of sources, only used for plotting
        psfs (PSFs): PSFs object containing psfs and related data; its
            ``psf_dfts`` and ``psf_GAM`` attributes are overwritten
        measurements (ndarray): 3d array of measurements; the last two axes
            are the image rows and columns
        recon_init (ndarray): initialization for the reconstructed image(s)
        iternum (int): number of iterations
        periter (int): iteration period of displaying the reconstructions
        lr (float): augmented Lagrangian parameter of the low-rank term
        theta (float or array): penalty parameter of the low-rank term; a
            scalar is broadcast to one value per source
        s (float): augmented Lagrangian parameter of the sparsity term
        lam (float or array): penalty parameter of the sparsity term; a
            scalar is broadcast to one value per source
        patch_shape (tuple): shape of the patches used
        transform (ndarray): the sparsifying transform used
        learning (bool): whether the transform gets updated every iteration
        window_size (tuple): window size over which the group of patches
            similar to the centered reference patch is searched
        group_size (int): the number of patches in each group (low-rank term)
        group_size_s (int): the number of patches in each group (sparsity term)

    Returns:
        ndarray: the reconstructed image(s) after ``iternum`` iterations
    """
    num_sources = psfs.psfs.shape[1]
    rows, cols = measurements.shape[1:]
    psize = int(np.prod(patch_shape))
    # broadcast scalar penalties to one value per source; the np.float/np.int
    # aliases no longer exist in NumPy, so test the dimensionality instead
    if np.ndim(theta) == 0:
        theta = np.ones(num_sources) * theta
    if np.ndim(lam) == 0:
        lam = np.ones(num_sources) * lam

    recon = recon_init

    # ----- pre-compute arrays that stay fixed across iterations -----
    patcher_spectrum = np.einsum('ij,kl->ijkl', np.eye(num_sources),
                                 psize * np.ones((rows, cols)))
    psfs.psf_dfts = np.repeat(np.fft.fft2(
        size_equalizer(psfs.psfs, ref_size=[rows, cols])),
                              psfs.copies.astype(int),
                              axis=0)
    psfs.psf_GAM = block_mul(block_herm(psfs.psf_dfts), psfs.psf_dfts)
    # reshaped F A^T y term, where F is the DFT matrix
    psfdfts_h_meas = block_mul(
        block_herm(psfs.psf_dfts),
        np.fft.fft2(np.fft.fftshift(measurements, axes=(1, 2))))
    SIG_inv = block_inv(psfs.psf_GAM + (s + lr) * patcher_spectrum)

    # one worker pool for all iterations instead of a new pool per iteration;
    # the context manager also releases the workers if an exception is raised
    with multiprocessing.Pool() as pool:
        for it in range(iternum):

            # ----- Low-Rank Approximation -----
            patches = patch_extractor(recon, patch_shape=patch_shape)
            num_patches = patches.shape[1]

            patch_means = np.mean(patches, axis=0)
            patches_zeromean = (patches - patch_means).T

            lowrank_i = functools.partial(lowrank,
                                          patches_zeromean=patches_zeromean,
                                          window_size=window_size,
                                          imsize=(rows, cols),
                                          threshold=theta,
                                          group_size=group_size)
            D, indices = zip(*pool.map(lowrank_i, np.arange(num_patches)))
            D = np.array(D)
            indices = np.array(indices, dtype=int)

            # add the patch means back to the low-rank approximations
            D += np.einsum('ik,j->ijk', patch_means[indices],
                           np.ones(patches.shape[0]))

            if s > 0:
                # ----- Sparse Coding -----
                patch_len = patches.shape[0]
                # stack the group_size_s grouped patches of every reference
                # patch on top of each other
                patches_3d = np.zeros((patch_len * group_size_s, num_patches))
                for i in range(group_size_s):
                    patches_3d[i * patch_len:(i + 1) * patch_len, :] = \
                        patches[:, indices[:, i]]

                sparse_codes = transform @ patches_3d
                # per-source hard thresholding with threshold sqrt(lam[i] / s)
                for i in range(num_sources):
                    src = slice(i * rows * cols, (i + 1) * rows * cols)
                    sparse_codes[:, src] = hard_thresholding(
                        sparse_codes[:, src], threshold=np.sqrt(lam[i] / s))

                if learning:
                    # closed-form orthogonal transform update; the singular
                    # values are discarded so they do not overwrite `s`
                    u, _, v_T = np.linalg.svd(sparse_codes @ patches_3d.T)
                    transform = u @ v_T

                # ----- Image Update -----
                Whb = transform.T @ sparse_codes
                Whb1 = np.zeros_like(patches)
                for i in range(group_size_s):
                    Whb1 = indsum(
                        Whb1, Whb[i * patch_len:(i + 1) * patch_len, :],
                        indices[:, i])
                # divide by how often each patch index occurs in the groups
                # (counts come back ordered by sorted patch index)
                _, counts = np.unique(indices[:, :group_size_s],
                                      return_counts=True)
                Whb1 = Whb1 / counts

                Fc = np.fft.fft2(
                    patch_aggregator(Whb1,
                                     patch_shape=patch_shape,
                                     image_shape=(num_sources, rows, cols)))
            else:
                Fc = np.zeros_like(psfdfts_h_meas)

            if lr > 0:
                VhD = np.zeros_like(patches)
                for i in range(group_size):
                    VhD = indsum(VhD, D[:, :, i].T, indices[:, i])
                _, counts = np.unique(indices, return_counts=True)
                VhD = VhD / counts

                Fd = np.fft.fft2(
                    patch_aggregator(VhD,
                                     patch_shape=patch_shape,
                                     image_shape=(num_sources, rows, cols)))
            else:
                Fd = np.zeros_like(psfdfts_h_meas)

            recon = np.real(
                np.fft.ifft2(
                    block_mul(SIG_inv, s * Fc + lr * Fd + psfdfts_h_meas)))

            if (it + 1) % periter == 0 or it == iternum - 1:
                deconv_plotter(sources=sources, recons=recon, iter=it)

    return recon
Пример #8
0
def sparsepatch(*,
                sources,
                psfs,
                measurements,
                recon_init,
                iternum,
                plot=True,
                periter,
                nu,
                lam,
                patch_shape,
                transform,
                learning):
    """Perform PSSI deconvolution with a patch based sparsifying transform
    for sparse recovery.

    Based on the P1 and P3 formulations described in [doi:10.1137/141002293].
    The sparsifying transform is updated each iteration only when `learning`
    is True; otherwise the given transform is kept fixed.

    Args:
        sources (ndarray): array of sources, only used for plotting
        psfs (PSFs): PSFs object containing psfs and related data; its
            ``psf_dfts`` and ``psf_GAM`` attributes are overwritten
        measurements (ndarray): 3d array of measurements; the last two axes
            are the image rows and columns
        recon_init (ndarray): initialization for the reconstructed image(s)
        iternum (int): number of iterations of ADMM
        plot (boolean): if truthy, display the reconstructions as the
            iterations go
        periter (int): iteration period of displaying the reconstructions
        nu (float): augmented Lagrangian parameter
        lam (float or array): penalty parameter of the sparsity term; a
            scalar is broadcast to one value per source
        patch_shape (tuple): shape of the patches used
        transform (ndarray): the sparsifying transform used
        learning (bool): whether the transform gets updated

    Returns:
        ndarray: the reconstructed image(s) after ``iternum`` iterations
    """
    num_sources = psfs.psfs.shape[1]
    rows, cols = measurements.shape[1:]
    psize = int(np.prod(patch_shape))
    # broadcast a scalar penalty to one value per source; the np.float/np.int
    # aliases no longer exist in NumPy, so test the dimensionality instead
    if np.ndim(lam) == 0:
        lam = np.ones(num_sources) * lam

    recon = recon_init

    # ----- pre-compute arrays that stay fixed across iterations -----
    psfs.psf_dfts = np.repeat(np.fft.fft2(
        size_equalizer(psfs.psfs, ref_size=[rows, cols])),
                              psfs.copies.astype(int),
                              axis=0)
    psfs.psf_GAM = block_mul(block_herm(psfs.psf_dfts), psfs.psf_dfts)
    # reshaped F A^T y term, where F is the DFT matrix
    psfdfts_h_meas = block_mul(
        block_herm(psfs.psf_dfts),
        np.fft.fft2(np.fft.fftshift(measurements, axes=(1, 2))))
    LAM = psize * np.ones((rows, cols))
    spectrum = nu * np.einsum('ij,kl->ijkl', np.eye(num_sources), LAM)
    SIG_inv = block_inv(psfs.psf_GAM + spectrum)
    pixels_per_source = rows * cols

    for it in range(iternum):

        # ----- Sparse Coding -----
        patches = patch_extractor(recon, patch_shape=patch_shape)

        sparse_codes = transform @ patches
        # per-source hard thresholding with threshold sqrt(lam[i] / nu)
        for i in range(num_sources):
            src = slice(i * pixels_per_source, (i + 1) * pixels_per_source)
            sparse_codes[:, src] = hard_thresholding(
                sparse_codes[:, src], threshold=np.sqrt(lam[i] / nu))

        # ----- Image Update -----
        Fc = np.fft.fft2(
            patch_aggregator(transform.conj().T @ sparse_codes,
                             patch_shape=patch_shape,
                             image_shape=(num_sources, rows, cols)))

        recon = np.real(
            np.fft.ifft2(block_mul(SIG_inv, nu * Fc + psfdfts_h_meas)))

        # ----- Transform Update -----
        if learning:
            # closed-form orthogonal transform update from the codes and the
            # patches of this iteration; singular values are not needed
            u, _, vT = np.linalg.svd(sparse_codes @ patches.T)
            transform = u @ vT

        if plot and ((it + 1) % periter == 0 or it == iternum - 1):
            deconv_plotter(sources=sources, recons=recon, iter=it)

    return recon
Пример #9
0
# Script fragment: for every candidate Tikhonov parameter, reconstruct each
# noisy measurement instance and record per-source PSNR and SSIM.
# NOTE(review): sources, psfs, tikhonov_lam, tikhonov_order, num_instances,
# measured_noisy_instances and compare_ssim are defined outside this view.
plotter4d(sources, title='Orig', figsize=(5.6, 8), cmap='gist_heat')

# k and num_sources come from the first two axes of the selected psfs;
# aa, bb are the image rows and columns taken from the sources
# (k is not used in the visible code)
[k, num_sources, aa, bb] = psfs.selected_psfs.shape[:2] + sources.shape[2:]
# DFT of the kernel corresponding to (D^TD)
LAM = get_LAM(rows=aa, cols=bb, order=tikhonov_order)
# metric tables indexed by (lam index, instance index, source index)
ssim = np.zeros((len(tikhonov_lam), num_instances, num_sources))
psnr = np.zeros((len(tikhonov_lam), num_instances, num_sources))
for i in range(len(tikhonov_lam)):
    # the regularized inverse depends only on the lam value, so it is built
    # once per lam and reused for every noisy instance
    SIG_inv = block_inv(psfs.selected_GAM + tikhonov_lam[i] *
                        np.einsum('ij,kl', np.eye(num_sources), LAM))
    for j in range(num_instances):
        # Fourier-domain Tikhonov reconstruction of the j-th noisy instance
        recon = np.real(
            np.fft.ifft2(
                block_mul(
                    SIG_inv,
                    block_mul(psfs.selected_psf_dfts_h,
                              np.fft.fft2(measured_noisy_instances[j])))))
        ###### COMPUTE PERFORMANCE METRICS ######
        # per-source MSE; PSNR peak is each source's own maximum value
        mse = np.mean((sources - recon)**2, axis=(1, 2, 3))
        psnr[i, j] = 20 * np.log10(
            np.max(sources, axis=(1, 2, 3)) / np.sqrt(mse))
        for p in range(num_sources):
            # NOTE(review): data_range is taken from the reconstruction
            # rather than from the reference source; confirm this is intended
            ssim[i, j, p] = compare_ssim(sources[p, 0],
                                         recon[p, 0],
                                         data_range=np.max(recon[p, 0]) -
                                         np.min(recon[p, 0]))
plotter4d(recon,
          cmap='gist_heat',
          figsize=(5.6, 8),
          title='Recon. SSIM={}\n Recon. PSNR={}'.format(
Пример #10
0
def admm(*,
         sources=None,
         psfs,
         measurements,
         regularizer,
         recon_init,
         iternum,
         plot=True,
         periter=5,
         nu,
         lam,
         **kwargs):
    """Perform PSSI deconvolution with the ADMM algorithm using the
    specified regularization method.

    Args:
        sources (ndarray): 3d array of sources, only used for plotting
        psfs (PSFs): PSFs object containing psfs and related data; its
            ``psf_dfts`` and ``psf_GAM`` attributes are overwritten
        measurements (ndarray): 3d array of measurements; the last two axes
            are the image rows and columns
        regularizer (function): called every iteration with the keyword
            arguments psfs, measurements, psfdfts_h_meas, SIG_inv, primal1,
            primal2, dual, nu, lam and **kwargs; must return the tuple
            (primal1, primal2, pre_primal2, dual). On the first call primal2
            and dual are None.
        recon_init (ndarray): initialization for the reconstructed image(s)
        iternum (int): number of iterations of ADMM
        plot (boolean): if truthy, display the reconstructions as the
            iterations go
        periter (int): iteration period of displaying the reconstructions
        nu (float): augmented Lagrangian parameter (step size) of ADMM
        lam (list or scalar): regularization parameter of dimension
            num_sources; a scalar is broadcast to every source
        kwargs (dict): keyword arguments passed to ``get_SIG_inv`` and the
            regularizer

    Returns:
        ndarray of reconstructed images (the final primal1)
    """
    num_sources = psfs.psfs.shape[1]
    rows, cols = measurements.shape[1:]
    # broadcast a scalar lam to one value per source; the np.float/np.int
    # aliases no longer exist in NumPy, so test the dimensionality instead
    if np.ndim(lam) == 0:
        lam = np.ones(num_sources) * lam

    # ----- initialize the primal/dual variables -----
    primal1 = recon_init
    primal2 = None
    dual = None

    # ----- pre-compute arrays that stay fixed across iterations -----
    psfs.psf_dfts = np.repeat(np.fft.fft2(
        size_equalizer(psfs.psfs, ref_size=[rows, cols])),
                              psfs.copies.astype(int),
                              axis=0)
    psfs.psf_GAM = block_mul(block_herm(psfs.psf_dfts), psfs.psf_dfts)
    # reshaped F A^T y term, where F is the DFT matrix
    psfdfts_h_meas = block_mul(
        block_herm(psfs.psf_dfts),
        np.fft.fft2(np.fft.fftshift(measurements, axes=(1, 2))))
    SIG_inv = get_SIG_inv(regularizer=regularizer,
                          measurements=measurements,
                          psfs=psfs,
                          nu=nu,
                          **kwargs)

    for it in range(iternum):
        # ----- primal 1, 2 updates -----
        primal1, primal2, pre_primal2, dual = regularizer(
            psfs=psfs,
            measurements=measurements,
            psfdfts_h_meas=psfdfts_h_meas,
            SIG_inv=SIG_inv,
            primal1=primal1,
            primal2=primal2,
            dual=dual,
            nu=nu,
            lam=lam,
            **kwargs)

        # ----- dual update -----
        dual += (pre_primal2 - primal2)

        if plot and ((it + 1) % periter == 0 or it == iternum - 1):
            deconv_plotter(sources=sources, recons=primal1, iter=it)

    return primal1