Exemplo n.º 1
0
def tikhonov(*, psfs, measurements, tikhonov_lam=0.129, tikhonov_order=1):
    """Perform Tikhonov regularization based image reconstruction for PSSI.

    Solves x_hat = argmin_{x} { ||Ax-y||_2^2 + lam * ||Dx||_2^2 } in closed
    form in the 2D DFT domain. D is the discrete derivative operator of order
    `tikhonov_order`.

    Args:
        psfs (PSFs): PSFs object containing psfs and other csbs state data.
            Side effect: `psfs.psf_GAM` is overwritten with the Gram blocks
            computed at the measurement size.
        measurements (ndarray): 3d array of (noisy) measurements, indexed as
            (measurement, row, col)
        tikhonov_lam (float): regularization parameter of tikhonov
        tikhonov_order (int): [0,1 or 2] order of the discrete derivative
            operator used in the tikhonov regularization

    Returns:
        real-valued array of the reconstructed images, presumably of shape
        (num_sources, rows, cols) -- confirm against block_mul/block_inv
    """
    num_sources = psfs.psfs.shape[1]
    rows, cols = measurements.shape[1:]

    # zero-pad/crop the PSFs to the measurement size, take their DFTs and
    # repeat each measurement plane `psfs.copies` times
    exp_psf_dfts = np.repeat(
        np.fft.fft2(size_equalizer(psfs.psfs, ref_size=[rows, cols])),
        psfs.copies.astype(int),
        axis=0,
    )
    # Hermitian blocks, reused for both the Gram matrix and the back-projection
    exp_psf_dfts_h = block_herm(exp_psf_dfts)
    psfs.psf_GAM = block_mul(exp_psf_dfts_h, exp_psf_dfts)

    # DFT of the kernel corresponding to (D^TD), placed on the diagonal of a
    # (num_sources, num_sources) + LAM.shape block array so every source is
    # regularized identically
    LAM = get_LAM(rows=rows, cols=cols, order=tikhonov_order)
    regularization = tikhonov_lam * np.einsum('ij,kl', np.eye(num_sources), LAM)

    # back-projection of the (fftshifted) measurements in the DFT domain
    backprojection = block_mul(
        exp_psf_dfts_h,
        np.fft.fft2(np.fft.fftshift(measurements, axes=(1, 2))),
    )

    # closed-form solution: (A^H A + lam D^T D)^-1 A^H y, back to image domain
    return np.real(
        np.fft.ifft2(
            block_mul(block_inv(psfs.psf_GAM + regularization), backprojection)))
Exemplo n.º 2
0
    def test_size_equalizer(self):
        """size_equalizer zero-pads, crops and works on stacks of images."""
        # growing a 2x2 array to 4x4 surrounds it with a zero border
        padded = size_equalizer(self.simple_array, (4, 4))
        expected_padded = np.array(
            ((0, 0, 0, 0), (0, 1, 2, 0), (0, 3, 4, 0), (0, 0, 0, 0)))
        self.assertEqualNp(padded, expected_padded)

        # shrinking the padded array back to 2x2 recovers the original
        cropped = size_equalizer(padded, (2, 2))
        self.assertEqualNp(cropped, self.simple_array)

        # an odd-sized input cropped to an even size has the requested shape
        odd_input = np.random.random((5, 5))
        self.assertEqual(size_equalizer(odd_input, (4, 4)).shape, (4, 4))

        # vectorization: the leading axis is kept, the last two are resized
        frame = np.random.random((3, 3))
        stack = np.repeat(frame[np.newaxis, :, :], 4, axis=0)
        resized_stack = size_equalizer(stack, (2, 2))
        self.assertEqual(resized_stack.shape, (4, 2, 2))
Exemplo n.º 3
0
def default_adjoint(x, psfs):
    """Apply the PSF DFTs to `x`, sum over measurements, and Radon-transform.

    Args:
        x (ndarray): 3d array whose first axis is contracted against the
            (repeated) measurement axis of the PSFs; its last two axes set
            the size the PSFs are zero-padded/cropped to
        psfs (PSFs): PSFs object providing `psfs` (4d, indexed as
            (measurement wavelength, source wavelength, row, col)) and
            `copies` (number of repeats per measurement wavelength)

    Returns:
        result of `radon_forward` on the fftshifted per-source image stack
    """
    _, aa, bb = x.shape

    # FIXME: make it work for 2D input, remove selected_psfs
    # FIXME: move psf_dft computation to PSFs (make PSFs accept
    # sampling_interval and output size arguments)

    # zero-pad/crop the PSFs to the input size and repeat each measurement
    # plane `copies` times
    expanded_psfs = size_equalizer(psfs.psfs, ref_size=[aa, bb])
    expanded_psfs = np.repeat(expanded_psfs, psfs.copies.astype(int), axis=0)
    # swap to (source, measurement, row, col) for the contraction below
    expanded_psf_dfts = np.fft.fft2(expanded_psfs).transpose((1, 0, 2, 3))

    # multiply by the PSF DFTs and sum over the measurement axis (j).
    # NOTE(review): the DFTs are not conjugated here, so this is not the
    # textbook adjoint of convolution unless that is compensated elsewhere --
    # confirm this is intended.
    im = np.fft.fftshift(
        np.fft.ifft2(
            np.einsum('ijkl,jkl->ikl', expanded_psf_dfts, np.fft.fft2(x))),
        axes=(1, 2),
    )
    return radon_forward(im)
Exemplo n.º 4
0
def motion_deblur(*, sv, coadded, drift, plot=True):
    """Remove in-frame motion blur from a coadded image.

    A linear motion-blur kernel is rasterized from `drift` and convolved with
    the first PSF of `sv.psfs` to form an "effective blurring kernel". The
    coadded image is then deconvolved with a Tikhonov reconstruction, which
    initializes a Plug-and-Play ADMM reconstruction using BM3D.

    Args:
        sv: object providing `pixel_size` (meters; converted to um below) and
            `psfs` (a PSFs object, deep-copied so `sv` is not modified)
        coadded (ndarray): 2d coadded image; not modified (a normalized copy
            is used)
        drift (ndarray): 2-element drift, presumably in pixels; it is scaled
            by the pixel size in um
        plot (bool): if True (default), show the Tikhonov and final
            reconstructions with matplotlib

    Returns:
        ndarray: reconstruction returned by `admm`
    """
    pixel_size_um = sv.pixel_size * 1e6

    # width (in CCD pixels) of the final motion blur kernel
    kernel_size = 11
    # drift converted to um
    (x, y) = (pixel_size_um * drift[0], pixel_size_um * drift[1])
    # NOTE(review): the line below runs along (row, col) ~ (-y, x), so
    # drift[1] drives the row direction, whereas shift_and_sum uses drift[0]
    # as the row shift -- confirm the drift convention expected here.
    N = int(np.ceil(np.max((abs(x), abs(y)))))

    # initial kernel sampled on a 1 um grid, large enough to hold the drift
    kernel_um = np.zeros((2 * N + 1, 2 * N + 1))

    # anti-aliased line through the kernel center representing the motion
    rr, cc, val = line_aa(
        N + np.round((y / 2)).astype(int),
        N - np.round((x / 2)).astype(int),
        N - np.round((y / 2)).astype(int),
        N + np.round((x / 2)).astype(int),
    )
    kernel_um[rr, cc] = val

    # resize the initial 1 um kernel to `kernel_size` CCD pixels
    kernel = resize(size_equalizer(kernel_um,
                                   [int(pixel_size_um) * kernel_size] * 2),
                    [kernel_size] * 2,
                    anti_aliasing=True)

    # work on a copy so the caller's PSFs are not modified
    psfs = copy.deepcopy(sv.psfs)

    # convolve the photon sieve PSF with the motion blur kernel to find the
    # "effective blurring kernel", then normalize it to unit sum
    psfs.psfs[0, 0] = convolve2d(psfs.psfs[0, 0], kernel, mode='same')
    psfs.psfs[0, 0] /= psfs.psfs[0, 0].sum()

    # normalize the coadded image (helps choosing the regularization parameter
    # consistently). Done out-of-place so the caller's array is untouched and
    # integer-typed inputs do not fail on in-place true division.
    coadded = coadded / coadded.max()

    # Tikhonov regularized deblurring of the coadded image to remove the
    # in-frame blur with the "effective blurring kernel"
    recon_tik = tikhonov(measurements=coadded[np.newaxis, :, :],
                         psfs=psfs,
                         tikhonov_lam=1e1,
                         tikhonov_order=1)
    if plot:
        plt.figure()
        plt.imshow(recon_tik[0], cmap='gist_heat')
        plt.title('Deblurred Tikhonov')
        plt.show()

    # Plug and Play with BM3D reconstruction with tikhonov initialization
    recon = admm(measurements=coadded[np.newaxis, :, :],
                 psfs=psfs,
                 regularizer=partial(bm3d_pnp),
                 recon_init=recon_tik,
                 plot=False,
                 iternum=5,
                 periter=1,
                 nu=10**-0.0,
                 lam=[10**-0.5])

    if plot:
        plt.figure()
        plt.imshow(recon[0], cmap='gist_heat')
        plt.title('Deblurred')
        plt.show()

    return recon
Exemplo n.º 5
0
def shift_and_sum(frames, drift, mode='full', shift_method='roll'):
    """Coadd frames by shifting each one by a multiple of `drift`.

    Frame ``t`` (0-based) is shifted by ``drift * (t + 1)``. All shifted
    frames are summed, and each pixel is divided by the number of frames that
    covered it. Pixels covered by no frame are left at 0.

    Args:
        frames (ndarray): 3d array of input frames to coadd, indexed as
            (frame, row, col)
        drift (ndarray): (row, col) drift between adjacent frames
        mode (str): zeropad before coadding ('full'), crop to region of
            frame overlap ('crop'), crop to region of first frame ('first'),
            or center-crop to the size of one input frame ('center')
        shift_method (str): method for shifting frames ('roll', 'fourier')

    Returns:
        (ndarray): real part of the coadded image

    Raises:
        ValueError: if `mode` or `shift_method` is not recognized
    """
    assert isinstance(drift, np.ndarray), "'drift' should be ndarray"
    # validate up front instead of after the (potentially long) coadd loop
    if shift_method not in ('roll', 'fourier'):
        raise ValueError('Invalid shift_method')
    if mode not in ('full', 'crop', 'first', 'center'):
        raise ValueError('Invalid mode')

    # pad on the side the frames drift towards, enough for the total drift
    pad = np.ceil(drift * (len(frames) - 1)).astype(int)
    pad_r = (0, pad[0]) if drift[0] > 0 else (-pad[0], 0)
    pad_c = (0, pad[1]) if drift[1] > 0 else (-pad[1], 0)
    pad_width = ((0, 0), pad_r, pad_c)
    # NOTE(review): the padding covers a total shift of
    # drift * (len(frames) - 1), but the loop shifts the last frame by
    # drift * len(frames) -- confirm the `+ 1` offset below is intended.

    # coverage masks, shifted exactly like the frames, to count contributions
    frames_ones = np.pad(np.ones(frames.shape, dtype=int), pad_width,
                         mode='constant')
    frames_pad = np.pad(frames, pad_width, mode='constant')

    summation = np.zeros(frames_pad[0].shape, dtype='complex128')
    # float (not int): Fourier-shifted coverage masks are non-integer/complex,
    # and adding them into an int array raises a casting error
    summation_scale = np.zeros(frames_pad[0].shape, dtype=float)

    for time_diff, (frame,
                    frame_ones) in tqdm(enumerate(zip(frames_pad,
                                                      frames_ones))):
        shift = drift * (time_diff + 1)
        if shift_method == 'roll':
            integer_shift = np.floor(shift).astype(int)
            shifted = roll(frame, (integer_shift[0], integer_shift[1]))
            shifted_ones = roll(frame_ones,
                                (integer_shift[0], integer_shift[1]))
        else:  # 'fourier'
            shifted = np.fft.ifftn(
                fourier_shift(np.fft.fftn(frame), (shift[0], shift[1])))
            shifted_ones = np.fft.ifftn(
                fourier_shift(np.fft.fftn(frame_ones), (shift[0], shift[1])))
        summation += shifted
        summation_scale += np.real(shifted_ones)

    # pixels no frame reached would otherwise divide 0 by 0 (NaN); this guard
    # must happen before the division to have any effect
    summation_scale[summation_scale == 0] = 1
    summation /= summation_scale

    if mode == 'crop':
        # keep only the region all frames overlap
        summation = size_equalizer(
            summation,
            np.array(frames_pad[0].shape).astype(int) - 2 * pad)
    elif mode == 'first':
        summation = summation[:frames.shape[1], :frames.shape[2]]
    elif mode == 'center':
        summation = size_equalizer(summation, frames.shape[1:])
    # 'full' keeps the whole padded canvas

    return summation.real
Exemplo n.º 6
0
from mas.forward_model import get_measurements, add_noise, size_equalizer
from mas.psf_generator import PSFs, PhotonSieve, circ_incoherent_psf
from mas.deconvolution import tikhonov, admm
from mas.deconvolution.admm import bm3d_pnp
from mas.plotting import plotter4d
from mas.data import strands_ext
from functools import partial
from skimage.measure import compare_ssim as ssim

# %% meas ------------------------

# photon sieve used for every PSF below
ps = PhotonSieve(diameter=16e-2, smallest_hole_diameter=7e-6)

# two closely spaced source wavelengths, reconstructed at a 160 x 160 size
source_wavelengths = np.array([33.4e-9, 33.5e-9])
num_sources = len(source_wavelengths)
meas_size = [160] * 2
sources = strands_ext[:num_sources]
sources_ref = size_equalizer(sources, ref_size=meas_size)

# generate psfs; the measurement wavelengths are the source wavelengths
psfs = PSFs(
    ps,
    source_wavelengths=source_wavelengths,
    measurement_wavelengths=source_wavelengths,
    psf_generator=circ_incoherent_psf,
    sampling_interval=3.5e-6,
    num_copies=1,
)

# ############## INVERSE CRIME ##############
# measured = get_measurements(
Exemplo n.º 7
0
from imageio import imread
import numpy as np
import matplotlib.pyplot as plt
from scipy.misc import face
from scipy.ndimage import rotate
from mas.forward_model import downsample, upsample, size_equalizer

offset = (500, 500)
x = face(gray=True)
# roi = upsample(downsample(x[500:600, 500:600], factor=5), factor=5)
# a cropped copy and a circularly shifted copy of the reference image
crop_roi = x[:-1, :-1]
roll_roi = np.roll(np.roll(x, -offset[0], axis=0), -offset[1], axis=1)

x_fft = np.fft.fft2(x)
# crop_fft = size_equalizer(np.fft.fft2(crop_roi), x.shape)
crop_fft = np.fft.fft2(size_equalizer(crop_roi, x.shape))
roll_fft = np.fft.fft2(roll_roi)


def _normalized_cross_power(a_fft, b_fft):
    """Return a * conj(b) / |a * conj(b)| (normalized cross-power spectrum).

    Frequencies where the product is exactly zero are set to 0 instead of
    producing NaN (0/0); a single NaN bin would make every value of the
    inverse FFT NaN.
    """
    product = np.multiply(a_fft, np.conj(b_fft))
    magnitude = np.abs(product)
    return np.divide(product, magnitude, out=np.zeros_like(product),
                     where=magnitude != 0)


# phase correlation of the reference against each copy
crop_csd = _normalized_cross_power(x_fft, crop_fft)
crop_phase = np.abs(np.fft.ifft2(crop_csd))

roll_csd = _normalized_cross_power(x_fft, roll_fft)
roll_phase = np.abs(np.fft.ifft2(roll_csd))

def zoom_plot(im, offset, width):
Exemplo n.º 8
0
import numpy as np
import imageio, pickle, h5py

from mas.forward_model import size_equalizer

num_instances = 1
tikhonov_order = 1
tikhonov_lam = 1e-2
# Wrap a scalar regularization parameter in a list. `np.int` / `np.float`
# were plain aliases of the builtins and no longer exist in NumPy >= 1.24,
# so the builtin types are checked directly.
if isinstance(tikhonov_lam, (int, float)):
    tikhonov_lam = [tikhonov_lam]
psf_width = 201
# source_wavelengths = np.array([9.4e-9])
source_wavelengths = np.array([33.4e-9, 33.5e-9])
num_sources = len(source_wavelengths)

# Take frames 0 and 250 of the nanoflare movie, zero-padded/cropped to
# 160 x 160. The file is opened once, read-only, and closed afterwards, and
# only the two needed frames are read (instead of loading the whole movie
# twice).
nanomovie_path = (
    '/home/kamo/Research/mas/nanoflare_videos/NanoMovie0_2000strands_94.h5')
nanomovie_key = 'NanoMovie0_2000strands_94'
with h5py.File(nanomovie_path, 'r') as nanomovie_file:
    nanomovie = nanomovie_file[nanomovie_key]
    source1 = size_equalizer(np.array(nanomovie[0]), (160, 160))
    source2 = size_equalizer(np.array(nanomovie[250]), (160, 160))
# source1 = rectangle_adder(image=np.zeros((100,100)), size=(30,30), upperleft=(35,10))
# source2 = 10 * rectangle_adder(image=np.zeros((100,100)), size=(30,30), upperleft=(35,60))
# source1 = readsav('/home/kamo/Research/mas/nanoflare_videos/old/movie0_1250strands_335.sav',python_dict=True)['movie'][500]
# source2 = readsav('/home/kamo/Research/mas/nanoflare_videos/old/movie0_1250strands_94.sav',python_dict=True)['movie'][500]
[aa, bb] = source1.shape
meas_size = tuple(np.array([aa, bb]) - 0)
sources = np.zeros((len(source_wavelengths), 1, aa, bb))
# sources are normalized to a peak value of 1
sources[0, 0] = source1 / source1.max()
Exemplo n.º 9
0
Arquivo: admm.py Projeto: zoey0919/MAS
def admm(*,
         sources=None,
         psfs,
         measurements,
         regularizer,
         recon_init,
         iternum,
         plot=True,
         periter=5,
         nu,
         lam,
         **kwargs):
    """Function that implements PSSI deconvolution with the ADMM algorithm using
    the specified regularization method.

    Args:
        sources (ndarray): 3d array of sources (only used for plotting)
        psfs (PSFs): PSFs object containing psfs and related data. Side
            effect: `psfs.psf_dfts` and `psfs.psf_GAM` are overwritten with
            values computed at the measurement size.
        measurements (ndarray): 3d array of measurements
        regularizer (function): function that specifies the regularization
            type; called once per iteration and must return
            (primal1, primal2, pre_primal2, dual)
        recon_init (ndarray): initialization for the reconstructed image(s)
        iternum (int): number of iterations of ADMM
        plot (boolean): if set to True, display the reconstructions as the
            iterations go
        periter (int): iteration period of displaying the reconstructions
        nu (float): augmented Lagrangian parameter (step size) of ADMM
        lam (float or list): regularization parameter of dimension
            num_sources; a scalar is broadcast to every source
        kwargs (dict): keyword arguments to be passed to the regularizers
            (also forwarded to get_SIG_inv)

    Returns:
        ndarray of reconstructed images (the final primal1)
    """
    num_sources = psfs.psfs.shape[1]
    rows, cols = measurements.shape[1:]
    # Broadcast a scalar lam to one value per source. `np.float` / `np.int`
    # were aliases of the builtins and were removed in NumPy 1.24, so check
    # the builtin and NumPy scalar types instead.
    if isinstance(lam, (int, float, np.integer, np.floating)):
        lam = np.ones(num_sources) * lam

    # ----- initialize the primal/dual variables -----
    primal1 = recon_init
    primal2 = None
    dual = None

    # ----- pre-compute some arrays for efficiency -----
    psfs.psf_dfts = np.repeat(np.fft.fft2(
        size_equalizer(psfs.psfs, ref_size=[rows, cols])),
                              psfs.copies.astype(int),
                              axis=0)
    psfs.psf_GAM = block_mul(block_herm(psfs.psf_dfts), psfs.psf_dfts)
    # reshaped FA^Ty term where F is the DFT matrix
    psfdfts_h_meas = block_mul(
        block_herm(psfs.psf_dfts),
        np.fft.fft2(np.fft.fftshift(measurements, axes=(1, 2))))
    SIG_inv = get_SIG_inv(regularizer=regularizer,
                          measurements=measurements,
                          psfs=psfs,
                          nu=nu,
                          **kwargs)

    for it in range(iternum):
        # ----- primal 1,2 updates -----
        primal1, primal2, pre_primal2, dual = regularizer(
            psfs=psfs,
            measurements=measurements,
            psfdfts_h_meas=psfdfts_h_meas,
            SIG_inv=SIG_inv,
            primal1=primal1,
            primal2=primal2,
            dual=dual,
            nu=nu,
            lam=lam,
            **kwargs)

        # ----- dual update -----
        dual += (pre_primal2 - primal2)

        if plot and ((it + 1) % periter == 0 or it == iternum - 1):
            deconv_plotter(sources=sources, recons=primal1, iter=it)

    return primal1
Exemplo n.º 10
0
    def __init__(self,
                 sieve,
                 *,
                 source_wavelengths=np.array([33.4, 33.5]) * 1e-9,
                 measurement_wavelengths=30,
                 image_width=1001,
                 cropped_width=None,
                 energy_ratio=0.9995,
                 num_copies=1,
                 psf_generator=circ_incoherent_psf,
                 sampling_interval=3.5e-6,
                 grid_width=10,
                 zero_mean=False):
        """Generate a PSF for every (measurement wavelength, source wavelength) pair.

        Args:
            sieve: sieve object; its `diameter` and `smallest_hole_diameter`
                are read here, and it is passed to `psf_generator`
            source_wavelengths (ndarray): source wavelengths
            measurement_wavelengths (int or array-like): explicit measurement
                wavelengths, or (when a builtin int) the number of linearly
                spaced wavelengths to generate around the source focal planes;
                in that case the source wavelengths are also inserted in
                sorted position
            image_width (int): width of each generated (square) PSF
            cropped_width (int or None): if given, every PSF is
                cropped/padded to this width; otherwise the width is chosen
                with `size_compressor` using `energy_ratio`
            energy_ratio (float): passed to `size_compressor`
            num_copies (int): number of copies per measurement wavelength,
                stored in `self.copies`
            psf_generator (function): called once per wavelength pair to
                build a PSF
            sampling_interval (float): passed to `psf_generator` and stored
            grid_width (int): how many depths of focus the generated
                wavelength grid extends beyond the extreme focal lengths
            zero_mean (bool): if True, subtract each PSF's mean
        """

        # focal length (D * d / lambda) and depth of focus (2 d^2 / lambda)
        # of the sieve for each source wavelength
        focal_lengths = sieve.diameter * sieve.smallest_hole_diameter / source_wavelengths
        dofs = 2 * sieve.smallest_hole_diameter**2 / source_wavelengths
        # an int means "generate this many wavelengths". NOTE(review): a NumPy
        # integer does not satisfy `type(...) is int` and would be treated as
        # an explicit wavelength list instead.
        if type(measurement_wavelengths) is int:
            # wavelengths whose focal lengths lie `grid_width` depths of focus
            # beyond the farthest / nearest source focal plane
            approx_start = sieve.diameter * sieve.smallest_hole_diameter / (
                max(focal_lengths) + grid_width * max(dofs))
            approx_end = sieve.diameter * sieve.smallest_hole_diameter / (
                min(focal_lengths) - grid_width * min(dofs))
            measurement_wavelengths = np.linspace(approx_start, approx_end,
                                                  measurement_wavelengths)
            # also measure exactly at the source wavelengths, keeping order
            measurement_wavelengths = np.insert(
                measurement_wavelengths,
                np.searchsorted(measurement_wavelengths, source_wavelengths),
                source_wavelengths)

        # accumulated as (measurement wavelength, source wavelength,
        # image_width, image_width)
        psfs = np.empty((0, len(source_wavelengths), image_width, image_width))

        # generate incoherent measurements for each wavelength and plane location
        # NOTE(review): `bar` is not closed if psf_generator raises; and
        # np.append copies the whole array on every call -- preallocating
        # would avoid the repeated copies.
        bar = tqdm(total=len(measurement_wavelengths) *
                   len(source_wavelengths),
                   desc='PSFs',
                   leave=None,
                   position=1)
        for m, measurement_wavelength in enumerate(measurement_wavelengths):
            psf_group = np.empty((0, image_width, image_width))
            for n, source_wavelength in enumerate(source_wavelengths):
                bar.update(1)
                # sys.stdout.write('\033[K')
                # print(
                #     'PSF {}/{}\r'.format(
                #         m * len(source_wavelengths) + n + 1,
                #         len(measurement_wavelengths) * len(source_wavelengths)
                #     ),
                #     end=''
                # )
                psf = psf_generator(
                    sieve=sieve,
                    source_wavelength=float(source_wavelength),
                    measurement_wavelength=measurement_wavelength,
                    image_width=image_width,
                    source_distance=float('inf'),
                    sampling_interval=float(sampling_interval))

                psf_group = np.append(psf_group, [psf], axis=0)
            psfs = np.append(psfs, [psf_group], axis=0)
        bar.close()

        if cropped_width is not None:
            psfs = size_equalizer(psfs, [cropped_width, cropped_width])
        else:
            # width chosen from two corner PSFs (first measurement / last
            # source and last measurement / first source); presumably these
            # have the widest support -- confirm against size_compressor
            width0 = size_compressor(psfs[0, -1], energy_ratio=energy_ratio)
            width1 = size_compressor(psfs[-1, 0], energy_ratio=energy_ratio)
            width = max(width0, width1)
            psfs = size_equalizer(psfs, [width, width])

        if zero_mean:
            # subtract the mean of each individual PSF
            psfs -= psfs.mean(axis=(2, 3))[:, :, np.newaxis, np.newaxis]

        self.psfs = psfs
        # DFTs at the (cropped) PSF size
        self.psf_dfts = np.fft.fft2(psfs)
        self.num_copies = num_copies
        # one (float) copy count per measurement wavelength
        self.copies = np.ones((len(measurement_wavelengths))) * num_copies
        self.measurement_wavelengths = measurement_wavelengths
        self.source_wavelengths = source_wavelengths
        self.sampling_interval = sampling_interval
        self.cropped_width = cropped_width
        self.copies_history = []