Example #1
0
    def ft1D(x,y,*,axis=0,zero_DC=False):
        """Fourier transform y(x) -> f(k) along a single (1D) axis.

        The fftshift/ifftshift bookkeeping and the normalization
        (multiplication by the grid spacing dx) are handled here so callers
        can work with centred grids directly.

        Args:
            x (np.ndarray) : independent variable; must be 1D and uniformly
                spaced (only x[1]-x[0] is used as the spacing)
            y (np.ndarray) : dependent variable, can be nD
        Kwargs:
            axis (int) : axis of y along which the FFT is taken
            zero_DC (bool) : if True, the k = 0 component of the result is
                set to zero
        Returns:
            (k, f): k is the angular-frequency grid in ascending (centred)
                order; f is the complex transform, centred along `axis`.
        """
        spacing = x[1] - x[0]
        k = fftshift(fftfreq(x.size, d=spacing)) * 2 * np.pi

        # When x is a centred grid (x == 0 lands at index 0 after ifftshift),
        # undo the centring of y before transforming.
        if np.isclose(ifftshift(x)[0], 0):
            data = ifftshift(y, axes=(axis))
        else:
            data = y
        f = fft(data, axis=axis) * spacing

        if zero_DC:
            # Before the final fftshift the k = 0 term sits at index 0.
            index = [slice(None)] * f.ndim
            index[axis] = slice(0, 1, 1)
            f[tuple(index)] = 0

        return k, fftshift(f, axes=(axis))
Example #2
0
    def polarization_to_signal(self,P_of_t_in,*,return_polarization=False,
                               local_oscillator_number = -1):
        """Generate a frequency-resolved signal from a polarization field.

        Args:
            P_of_t_in (np.ndarray): polarization sampled on ``self.t``. It is
                not modified by this method.
        Kwargs:
            return_polarization (bool): if True, return the (possibly damped)
                time-domain polarization instead of the signal.
            local_oscillator_number: index into ``self.pulse_times`` and
                ``self.efields`` selecting the local oscillator; usually the
                last pulse in the list ``self.efields`` (the default, -1).
        Returns:
            np.ndarray: ``Im(P(w) * conj(E_LO(w)))``, or the polarization when
            ``return_polarization`` is True.
        """
        pulse_time = self.pulse_times[local_oscillator_number]
        P_of_t = P_of_t_in
        if self.gamma != 0:
            # Build a new array instead of multiplying in place, so the
            # caller's polarization is left untouched.
            P_of_t = P_of_t_in * np.exp(-self.gamma * (self.t - pulse_time))
        if return_polarization:
            return P_of_t

        dt = self.t[1] - self.t[0]
        if local_oscillator_number == 'impulsive':
            # NOTE(review): appears unreachable — self.pulse_times is indexed
            # with 'impulsive' above, which presumably fails for a list/array.
            efield = np.exp(1j*self.w*(pulse_time))
        else:
            # Place the LO envelope on the time grid, centred on its arrival
            # time, then move it to the frequency domain.
            pulse_time_ind = np.argmin(np.abs(self.t - pulse_time))
            pulse_start_ind = pulse_time_ind - self.size//2
            pulse_end_ind = pulse_time_ind + self.size//2 + self.size%2

            efield = np.zeros(self.t.size, dtype='complex')
            efield[pulse_start_ind:pulse_end_ind] = self.efields[local_oscillator_number]
            efield = fftshift(ifft(ifftshift(efield)))*len(P_of_t)*dt/np.sqrt(2*np.pi)

        if P_of_t.size % 2:
            # Odd length: drop the last sample and trim the LO spectrum to match.
            P_of_t = P_of_t[:-1]
            efield = efield[:len(P_of_t)]

        P_of_w = fftshift(ifft(ifftshift(P_of_t)))*len(P_of_t)*dt/np.sqrt(2*np.pi)
        return np.imag(P_of_w * np.conjugate(efield))
Example #3
0
    def ift1D(k,f,*,axis=0,zero_DC=False):
        """Inverse Fourier transform f(k) -> y(x) along a single (1D) axis.

        Handles the fftshift/ifftshift bookkeeping and scales the inverse DFT
        by dk*N/(2*pi), which equals 1/dx for the conjugate grid, so a forward
        transform normalized by dx is undone.
        Args:
            k (np.ndarray): angular-frequency grid; must be 1D and uniformly
                spaced (only k[1]-k[0] is used as the spacing)
            f (np.ndarray): values f(k); can be nD
        Kwargs:
            axis (int) : axis of f along which the inverse FFT is taken
            zero_DC (bool) : if True, the x = 0 sample of the result is set
                to zero
        Returns:
            (x, y): x is the conjugate position grid in ascending (centred)
                order; y is the complex result, centred along `axis`.
        """
        dk = k[1]-k[0]
        x = fftshift(fftfreq(k.size,d=dk))*2*np.pi
        ifft_norm = dk*k.size/(2*np.pi)

        # For a centred k grid (k == 0 at index 0 after ifftshift), move the
        # zero frequency to index 0 before the inverse transform.
        if np.isclose(ifftshift(k)[0], 0):
            y = ifft(ifftshift(f, axes=(axis)), axis=axis)*ifft_norm
        else:
            y = ifft(f, axis=axis)*ifft_norm

        if zero_DC:
            # Before the final fftshift, index 0 along `axis` is x = 0.
            nd_slice = [slice(None)] * y.ndim
            nd_slice[axis] = slice(0, 1, 1)
            y[tuple(nd_slice)] = 0

        return x, fftshift(y, axes=(axis))
Example #4
0
    def polarization_to_signal(self,
                               P_of_t_in,
                               *,
                               return_polarization=False,
                               local_oscillator_number=-1,
                               undersample_factor=1):
        """Generate a frequency-resolved signal from a polarization field.

        Args:
            P_of_t_in (np.ndarray): polarization sampled on ``self.t``. It is
                copied before any envelope is applied, so the caller's array
                is left unchanged.
        Kwargs:
            return_polarization (bool): if True, return the undersampled,
                enveloped time-domain polarization instead of the signal.
            local_oscillator_number: index into ``self.pulse_times`` and
                ``self.efields`` selecting the local oscillator; usually the
                last pulse in the list ``self.efields`` (the default, -1).
            undersample_factor (int): keep every n-th time sample of the
                polarization; the LO spectrum is cropped around its centre.
        Returns:
            np.ndarray: ``Im(P(w) * conj(E_LO(w)))``, or the polarization when
            ``return_polarization`` is True.
        """
        undersample_slice = slice(None, None, undersample_factor)
        # Basic slicing returns a view; copy so the in-place envelopes below
        # do not write into the caller's array.
        P_of_t = P_of_t_in[undersample_slice].copy()
        t = self.t[undersample_slice]
        dt = t[1] - t[0]
        pulse_time = self.pulse_times[local_oscillator_number]
        if self.gamma != 0:
            # Exponential decay in |t - pulse_time|.
            P_of_t *= np.exp(-self.gamma * np.abs(t - pulse_time))
        if self.sigma_I != 0:
            # Gaussian envelope in (t - pulse_time), width set by sigma_I.
            P_of_t *= np.exp(-(t - pulse_time)**2 * self.sigma_I**2 / 2)
        if return_polarization:
            return P_of_t

        # The LO field is built on the full (not undersampled) time grid.
        pulse_time_ind = np.argmin(np.abs(self.t - pulse_time))
        efield = np.zeros(self.t.size, dtype='complex')

        if self.efield_t.size == 1:
            # Impulsive limit
            efield[pulse_time_ind] = self.efields[local_oscillator_number]
            efield = fftshift(ifft(ifftshift(efield))) * efield.size / np.sqrt(
                2 * np.pi)
        else:
            pulse_start_ind = pulse_time_ind - self.size // 2
            pulse_end_ind = pulse_time_ind + self.size // 2 + self.size % 2
            efield[pulse_start_ind:pulse_end_ind] = self.efields[local_oscillator_number]
            efield = fftshift(ifft(ifftshift(efield))) * self.t.size * (
                self.t[1] - self.t[0]) / np.sqrt(2 * np.pi)

        # Keep the central part of the LO spectrum matching the narrower
        # frequency range of the undersampled polarization.
        # NOTE(review): for some combinations of self.w.size and
        # undersample_factor (e.g. 10 and 2) the cropped length differs from
        # P_of_t.size and the product below fails to broadcast.
        halfway = self.w.size // 2
        pm = self.w.size // (2 * undersample_factor)
        efield = efield[halfway - pm:halfway + pm + self.w.size % 2]

        P_of_w = fftshift(ifft(
            ifftshift(P_of_t))) * len(P_of_t) * dt / np.sqrt(2 * np.pi)

        return np.imag(P_of_w * np.conjugate(efield))
Example #5
0
    def add_gaussian_linewidth(self, sigma):
        """Apply Gaussian windows of width ``sigma`` to ``self.signal``.

        The last axis of the signal is transformed with fft (after ifftshift),
        multiplied by exp(-t**2 / (2 sigma**2)) along that axis (``self.t``) and
        by exp(-t21**2 / (2 sigma**2)) along the first axis
        (``self.t21_array``), then transformed back with ifft. The indexing
        below requires a 3D signal. The unmodified signal is kept in
        ``self.old_signal``.
        """
        self.old_signal = self.signal.copy()
        last = (-1)

        in_time = fftshift(fft(ifftshift(self.old_signal, axes=last), axis=last),
                           axes=last)

        gauss = lambda values: np.exp(-values**2 / (2 * sigma**2))
        window_t = gauss(self.t)[np.newaxis, np.newaxis, :]
        window_t21 = gauss(self.t21_array)[:, np.newaxis, np.newaxis]
        in_time = in_time * (window_t * window_t21)

        self.signal = fftshift(ifft(ifftshift(in_time, axes=last), axis=last),
                               axes=last)
def stCoefs_1d(signals, filters, block_size=100, second_order=False):
    """Filter a batch of 1D signals with a bank of filters via the FFT.

    Args:
        signals: array of shape (..., signal_length); all leading axes are
            flattened into a single sample axis (a single 1-D signal is
            treated as one sample).
        filters: array of shape (..., signal_length); all leading axes are
            flattened into a single filter axis. Filters are ifftshifted
            before transforming, i.e. their origin is expected at index
            signal_length // 2.
        block_size: unused; kept for backward compatibility.
        second_order: second-order coefficients are not implemented; when
            True an all-zero complex array of the output shape is returned.

    Returns:
        Complex array of shape (n_samples, n_filters, signal_length): the
        circular convolution of every signal with every filter.

    The input arrays are left unchanged.
    """
    signals = np.asarray(signals)
    filters = np.asarray(filters)
    signal_length = signals.shape[-1]
    # int(np.prod(())) == 1, so 1-D inputs work as a single row.
    n_signals = int(np.prod(signals.shape[:-1]))
    n_filters = int(np.prod(filters.shape[:-1]))
    # reshape (unlike assigning .shape) does not alter the caller's arrays.
    signals = signals.reshape(n_signals, signal_length)
    filters = filters.reshape(n_filters, filters.shape[-1])

    if second_order:
        # Placeholder: np.zeros needs the shape as a single tuple.
        return np.zeros((n_signals, n_filters, signal_length), dtype=complex)

    f_signals = np.fft.fft(signals, axis=1)
    f_filters = np.fft.fft(np.fft.ifftshift(filters, axes=(1,)), axis=1)
    f_conv = f_signals[:, np.newaxis, :] * f_filters[np.newaxis, :, :]
    return np.fft.ifft(f_conv, axis=2)
Example #7
0
 def ft_interface(a, s, axes, norm, **kwargs):
     """Call ``ftfunc`` and apply the matching fftshift/ifftshift.

     Forward transforms are fftshifted after the call; inverse transforms are
     ifftshifted before it. With ``omitlast`` (presumably the real-FFT
     variants) the last transformed axis is not shifted. When that leaves no
     axis to shift, ``ftfunc`` is called directly.
     ``ftfunc``, ``fftmod``, ``forward`` and ``omitlast`` come from the
     enclosing scope.
     """
     if omitlast:
         if isinstance(axes, int) or len(axes) < 2:
             # Only one axis: nothing left to shift (slicing an int would
             # raise TypeError).
             return ftfunc(a, s, axes, norm, **kwargs)
         shftax = axes[:-1]
     else:
         shftax = axes
     if forward:
         return fftmod.fftshift(ftfunc(a, s, axes, norm, **kwargs), shftax)
     return ftfunc(fftmod.ifftshift(a, shftax), s, axes, norm, **kwargs)
Example #8
0
 def fillPowerFromTemplate(self, twodPower):
     """
     Fill ``self.powerMap`` from a template 2D power object.

     Parameters
     ----------
     twodPower : object
         Template exposing ``Nx``, ``Ny``, ``pixScaleX``, ``pixScaleY``,
         ``lx``, ``ly``, ``powerMap`` and ``copy()``. Work is done on the
         copy (presumably leaving the template itself untouched).

     Notes
     -----
     On the same grid size the power map is copied element-wise. Otherwise
     the template power is multiplied by ``(Nx*Ny)**2 / area``, interpolated
     in the fftshifted frequency plane onto ``self.lx``/``self.ly``, shifted
     back, and multiplied by ``area / (Nx*Ny)**2`` of ``self``'s grid.
     """
     template = twodPower.copy()

     if template.Nx == self.Nx and template.Ny == self.Ny:
         self.powerMap[:] = template.powerMap[:]
         return

     # Remove the template's area normalisation before interpolating.
     template_area = template.Nx * template.Ny * template.pixScaleX * template.pixScaleY
     template.powerMap *= (template.Nx * template.Ny)**2 / template_area

     # NOTE(review): scipy.interpolate.interp2d is deprecated (SciPy 1.10)
     # and no longer functional from SciPy 1.14; this path needs porting.
     interpolator = interp2d(fftshift(template.lx),
                             fftshift(template.ly),
                             fftshift(template.powerMap))
     resampled = ifftshift(interpolator(fftshift(self.lx), fftshift(self.ly)))

     # Apply the target grid's area normalisation.
     target_area = self.Nx * self.Ny * self.pixScaleX * self.pixScaleY
     resampled *= target_area / (self.Nx * self.Ny * 1.)**2

     self.powerMap[:] = resampled[:]
def convolve(arr1, arr2, dx=None, axes=None):
    """
    Performs a centred convolution of input arrays

    The lower-dimensional array (arr2 when both have the same ndim) acts as
    the kernel and is assumed centred: its origin is at index n // 2 along
    each convolved axis. The convolution is circular (computed with FFTs).

    Parameters
    ----------
    arr1, arr2 : `numpy.ndarray`
        Arrays to be convolved. If dimensions are not equal then 1s are appended
        to the lower dimensional array. Otherwise, arrays must be broadcastable.
    dx : float > 0, list of float, or `None` , optional
        Grid spacing of input arrays. A scalar is raised to the number of
        convolved axes; a list is multiplied together. default=`None` applies
        no scaling
    axes : tuple of ints or `None`, optional
        Choice of axes to convolve. default=`None` convolves all axes, except
        that when arr2 has more dimensions than arr1 only the axes of the
        lower-dimensional array are used.

    Returns
    -------
    numpy.ndarray
        Complex result, C-contiguous and aligned.
    """
    if arr2.ndim > arr1.ndim:
        arr1, arr2 = arr2, arr1
        if axes is None:
            # Only reached when the arguments were swapped; kept for
            # backward compatibility.
            axes = range(arr2.ndim)
    arr2 = arr2.reshape(arr2.shape + (1, ) * (arr1.ndim - arr2.ndim))

    if dx is None:
        dx = 1
    elif isscalar(dx):
        dx = dx**(len(axes) if axes is not None else arr1.ndim)
    else:
        dx = prod(dx)

    arr1 = fftn(arr1, axes=axes)
    # Move the kernel's centre to index 0 only along the convolved axes;
    # shifting any other axis would permute the output along it.
    arr2 = fftn(ifftshift(arr2, axes=axes), axes=axes)
    out = ifftn(arr1 * arr2, axes=axes) * dx
    return require(out, requirements="CA")
Example #10
0
def decomposition_fft(X, filter, **kwargs):
    """Decompose a 2d input field into multiple spatial scales by using the Fast 
    Fourier Transform (FFT) and a bandpass filter.
    
    Parameters
    ----------
    X : array_like
      Two-dimensional array containing the input field. All values are required 
      to be finite.
    filter : dict
      A filter returned by any method implemented in bandpass_filters.py.
      Must contain "weights_1d" and "weights_2d" (shape (n, ny, nx)).
    
    Other Parameters
    ----------------
    MASK : array_like
      Optional mask to use for computing the statistics for the cascade levels. 
      Pixels with MASK==False are excluded from the computations.
    
    Returns
    -------
    out : dict
      Dictionary with keys "cascade_levels" (array of shape (n,) + X.shape
      holding the real band-passed fields), "means" and "stds" (lists of
      length n with the mean / standard deviation of each level, over the
      MASK pixels when given). n = len(filter["weights_1d"]).

    Raises
    ------
    ValueError
      If X is not 2D, contains non-finite values, or its shape does not match
      MASK or filter["weights_2d"].
    """
    X = np.asanyarray(X)
    MASK = kwargs.get("MASK", None)

    if X.ndim != 2:
        raise ValueError("the input is not two-dimensional array")
    if MASK is not None and MASK.shape != X.shape:
        raise ValueError("dimension mismatch between X and MASK: X.shape=%s, MASK.shape=%s" % \
          (str(X.shape), str(MASK.shape)))
    if X.shape != filter["weights_2d"].shape[1:3]:
        raise ValueError(
            "dimension mismatch between X and filter: X.shape=%s, filter['weights_2d'].shape[1:3]=%s"
            % (str(X.shape), str(filter["weights_2d"].shape[1:3])))
    if not np.all(np.isfinite(X)):
        raise ValueError("X contains non-finite values")

    # NOTE(review): `fft` (providing fft2/ifft2/fftshift/ifftshift) and
    # `fft_kwargs` are not defined in this function; presumably module-level
    # names — confirm.
    F = fft.fftshift(fft.fft2(X, **fft_kwargs))

    X_decomp = []
    means = []
    stds = []
    for k in range(len(filter["weights_1d"])):
        W_k = filter["weights_2d"][k, :, :]
        # Apply the band weights to the centred spectrum, return to grid
        # space and keep the real part.
        X_k = np.real(fft.ifft2(fft.ifftshift(F * W_k), **fft_kwargs))
        X_decomp.append(X_k)

        stats_field = X_k[MASK] if MASK is not None else X_k
        means.append(np.mean(stats_field))
        stds.append(np.std(stats_field))

    return {"cascade_levels": np.stack(X_decomp), "means": means, "stds": stds}
Example #11
0
 def generate_psf(self,
                  sphase=slice(4, None, None),
                  size=None,
                  zsize=None,
                  zrange=None):
     """Make a perfect PSF from the fitted pupil.

     Works on a shallow copy of ``self.model`` with ``zsize``/``zrange``
     overridden when given. If ``self.zd_result`` does not exist yet,
     ``self.fit_to_zernikes(120)`` is called first. The pupil from
     ``self.zd_result.complex_pupil(sphase=sphase)`` is ifftshifted and passed
     to ``model._gen_psf``. The resulting ``model.PSFi`` (nz, ny, nx) is
     zero-padded or centre-cropped in y/x to ``size`` when requested.
     """
     model = copy.copy(self.model)
     for attr, value in (("zsize", zsize), ("zrange", zrange)):
         if value is not None:
             setattr(model, attr, value)

     if not hasattr(self, 'zd_result'):
         self.fit_to_zernikes(120)
     pupil = self.zd_result.complex_pupil(sphase=sphase)
     model._gen_psf(ifftshift(pupil))

     psf = model.PSFi
     nz, ny, nx = psf.shape
     assert ny == nx, "Something is very wrong"
     if size is None or nx == size:
         return psf
     if nx < size:
         # PSF is smaller than requested: zero-pad it out.
         return fft_pad(psf, (nz, size, size), mode="constant")
     # PSF is larger than requested: crop a centred window.
     lower = size // 2
     window = slice(nx // 2 - lower, nx // 2 + size - lower)
     return psf[:, window, window]
Example #12
0
def stCoefs_1d(signals, filters, block_size=100, second_order=False):
    """Filter a batch of 1D signals with a bank of filters via the FFT.

    Args:
        signals: array of shape (..., signal_length); all leading axes are
            flattened into a single sample axis (a single 1-D signal is
            treated as one sample).
        filters: array of shape (..., signal_length); all leading axes are
            flattened into a single filter axis. Filters are ifftshifted
            before transforming, i.e. their origin is expected at index
            signal_length // 2.
        block_size: unused; kept for backward compatibility.
        second_order: second-order coefficients are not implemented; when
            True an all-zero complex array of the output shape is returned.

    Returns:
        Complex array of shape (n_samples, n_filters, signal_length): the
        circular convolution of every signal with every filter.

    The input arrays are left unchanged.
    """
    signals = np.asarray(signals)
    filters = np.asarray(filters)
    signal_length = signals.shape[-1]
    # int(np.prod(())) == 1, so 1-D inputs work as a single row.
    n_signals = int(np.prod(signals.shape[:-1]))
    n_filters = int(np.prod(filters.shape[:-1]))
    # reshape (unlike assigning .shape) does not alter the caller's arrays.
    signals = signals.reshape(n_signals, signal_length)
    filters = filters.reshape(n_filters, filters.shape[-1])

    if second_order:
        # Placeholder: np.zeros needs the shape as a single tuple.
        return np.zeros((n_signals, n_filters, signal_length), dtype=complex)

    f_signals = np.fft.fft(signals, axis=1)
    f_filters = np.fft.fft(np.fft.ifftshift(filters, axes=(1,)), axis=1)
    f_conv = f_signals[:, np.newaxis, :] * f_filters[np.newaxis, :, :]
    return np.fft.ifft(f_conv, axis=2)
Example #13
0
    def check_efield_resolution(self, efield, *, plot_fields=False):
        """Warn if ``efield`` is not well resolved in time or frequency.

        A warning is issued when the larger of the two end samples exceeds 1%
        of the peak magnitude, checked both for ``efield`` and for its Fourier
        transform (scaled by ``self.dt``).

        Args:
            efield (np.ndarray): 1D field sampled in time.
        Kwargs:
            plot_fields (bool): if True, plot real/imaginary parts against
                ``self.efield_t`` and ``self.efield_w`` and show the figure.
        """
        efield_tail = np.max(np.abs([efield[0], efield[-1]]))

        if efield_tail > np.max(np.abs(efield)) / 100:
            warnings.warn(
                'Consider using larger time interval, pulse does not decay to less than 1% of maximum value in time domain'
            )

        efield_fft = fftshift(fft(ifftshift(efield))) * self.dt
        efield_fft_tail = np.max(np.abs([efield_fft[0], efield_fft[-1]]))

        if efield_fft_tail > np.max(np.abs(efield_fft)) / 100:
            warnings.warn(
                '''Consider using smaller value of dt, pulse does not decay to less than 1% of maximum value in frequency domain'''
            )

        if plot_fields:
            fig, axes = plt.subplots(1, 2)
            l1, l2, = axes[0].plot(self.efield_t, np.real(efield),
                                   self.efield_t, np.imag(efield))
            # Attach the legend to the axes holding l1/l2; plt.legend would
            # target the current axes, i.e. the last subplot created.
            axes[0].legend([l1, l2], ['Real', 'Imag'])
            axes[1].plot(self.efield_w, np.real(efield_fft), self.efield_w,
                         np.imag(efield_fft))

            axes[0].set_ylabel('Electric field Amp')
            # Raw strings: "\o" is not a valid escape sequence.
            axes[0].set_xlabel(r'Time ($\omega_0^{-1}$)')
            axes[1].set_xlabel(r'Frequency ($\omega_0$)')

            fig.suptitle(
                'Check that efield is well-resolved in time and frequency')
            plt.show()
Example #14
0
def retrieve_phase_far_field(src_fname,
                             save_path,
                             output_fname=None,
                             pad_length=256,
                             n_epoch=100,
                             learning_rate=0.001):
    """Fit a complex object to a far-field intensity pattern.

    An object stored as real/imaginary channels is optimised with Adam so
    that |FFT2(object)| matches sqrt(intensity) of the input TIFF. Writes the
    object magnitude to ``output_fname`` and the detector-plane phase and
    intensity (``detector_phase`` / ``detector_mag``) into ``save_path``.

    Args:
        src_fname: path of the intensity TIFF, assumed centred at zero
            frequency.
        save_path: output directory.
        output_fname: base name for the object; defaults to the source base
            name + '_recon'.
        pad_length: currently unused (padding is disabled).
        n_epoch: number of optimisation steps.
        learning_rate: Adam learning rate.
    """
    # raw data is assumed to be centered at zero frequency
    prj_np = dxchange.read_tiff(src_fname)
    if output_fname is None:
        output_fname = os.path.basename(
            os.path.splitext(src_fname)[0]) + '_recon'

    # take modulus and inverse shift so zero frequency sits at index 0
    prj_np = ifftshift(np.sqrt(prj_np))

    obj_init = np.random.normal(50, 10, list(prj_np.shape) + [2])

    obj = tf.Variable(obj_init, dtype=tf.float32, name='obj')
    # float32 to match tf.abs(det) below; a float64 constant would make
    # tf.squared_difference fail with a dtype mismatch.
    prj = tf.constant(prj_np, dtype=tf.float32, name='prj')

    obj_real = tf.cast(obj[:, :, 0], dtype=tf.complex64)
    obj_imag = tf.cast(obj[:, :, 1], dtype=tf.complex64)
    det = tf.fft2d(obj_real + 1j * obj_imag, name='detector_plane')

    loss = tf.reduce_mean(tf.squared_difference(tf.abs(det), prj, name='loss'))

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    optimizer = optimizer.minimize(loss)

    # The context manager closes the session and releases its resources.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for i_epoch in range(n_epoch):
            t0 = time.time()
            _, current_loss = sess.run([optimizer, loss])
            print('Iteration {}: loss = {}, Δt = {} s.'.format(
                i_epoch, current_loss,
                time.time() - t0))

        det_final = sess.run(det)
        obj_final = sess.run(obj)

    res = np.linalg.norm(obj_final, 2, axis=2)
    dxchange.write_tiff(res,
                        os.path.join(save_path, output_fname),
                        dtype='float32',
                        overwrite=True)
    dxchange.write_tiff(fftshift(np.angle(det_final)),
                        os.path.join(save_path, 'detector_phase'),
                        dtype='float32',
                        overwrite=True)
    dxchange.write_tiff(fftshift(np.abs(det_final)**2),
                        os.path.join(save_path, 'detector_mag'),
                        dtype='float32',
                        overwrite=True)

    return
Example #15
0
    def polarization_to_signal(self,
                               P_of_t_in,
                               *,
                               local_oscillator_number=-1,
                               undersample_factor=1):
        """Generate a frequency-resolved signal from a polarization field.

        Args:
            P_of_t_in (np.ndarray): polarization sampled on ``self.t``; it is
                copied, not modified.
        Kwargs:
            local_oscillator_number: index into ``self.efield_times`` and
                ``self.efields`` selecting the local oscillator; usually the
                last pulse in the list ``self.efields`` (the default, -1).
            undersample_factor (int): keep every n-th time sample of the
                polarization; the LO spectrum is cropped around its centre.
        Returns:
            np.ndarray: ``Im(P(w) * conj(E_LO(w)))``, or
            ``1j * P(w) * conj(E_LO(w))`` when ``self.return_complex_signal``
            is true.
        """
        undersample_slice = slice(None, None, undersample_factor)
        P_of_t = P_of_t_in[undersample_slice].copy()
        t = self.t[undersample_slice]
        dt = t[1] - t[0]
        efield_t = self.efield_times[local_oscillator_number]
        # NOTE: demodulation of P_of_t by exp(1j * self.centers[n] * t) is
        # intentionally disabled here.

        if efield_t.size == 1:
            # Impulsive limit: delta in time is flat in frequency
            efield = np.ones(
                self.w.size) * self.efields[local_oscillator_number]
        else:
            # Place the LO envelope centred on t = 0 of the full time grid.
            center_ind = np.argmin(np.abs(self.t))
            pulse_start_ind = center_ind - efield_t.size // 2
            pulse_end_ind = center_ind + efield_t.size // 2 + efield_t.size % 2

            efield = np.zeros(self.t.size, dtype='complex')
            efield[pulse_start_ind:pulse_end_ind] = self.efields[local_oscillator_number]
            efield = fftshift(ifft(ifftshift(efield))) * efield.size * dt

        # Keep the central part of the LO spectrum that overlaps the
        # undersampled polarization's frequency range.
        halfway = self.w.size // 2
        pm = self.w.size // (2 * undersample_factor)
        efield = efield[halfway - pm:halfway + pm + self.w.size % 2]

        P_of_w = fftshift(ifft(ifftshift(P_of_t))) * P_of_t.size * dt

        signal = P_of_w * np.conjugate(efield)
        if not self.return_complex_signal:
            return np.imag(signal)
        return 1j * signal
Example #16
0
def easy_ifft(data, axes=None):
    """Inverse FFT of a centred spectrum, returning a centred result.

    The zero-frequency sample is expected at index n // 2 along each
    transformed axis (fftshift layout), and the output places the origin at
    index n // 2 as well. The shift order ifftshift -> ifftn -> fftshift is
    what makes this correct for odd lengths; for even lengths fftshift and
    ifftshift coincide.

    Parameters
    ----------
    data : array_like
        Centred spectrum.
    axes : sequence of int, optional
        Axes to transform; all axes by default.

    Returns
    -------
    ndarray
        Complex array of the same shape as ``data``.
    """
    return fftshift(ifftn(ifftshift(data, axes=axes), axes=axes), axes=axes)
Example #17
0
def cfftn(data, axes):
    """Centred n-dimensional FFT with orthonormal scaling.

    :param data: Complex input data, origin at index n // 2 along each axis.
    :param axes: Axes along which to shift and transform.
    :return: Fourier transformed data with zero frequency at index n // 2.
    """
    uncentred = fft.ifftshift(data, axes=axes)
    spectrum = fft.fftn(uncentred, axes=axes, norm='ortho')
    return fft.fftshift(spectrum, axes=axes)
Example #18
0
 def filter_frames(self, data):
     """Filter each slice of ``data[0]`` in the Fourier domain.

     For every index along ``self.slice_dir`` (written into ``self.sslice``
     in place), the selected slice is transformed with ``self.fft_object``
     and fftshifted. Entries ``self.row1:self.row2`` of the centred spectrum
     are multiplied by ``self.filtercomplex``. The spectrum is then
     ifftshifted and inverse transformed with ``self.ifft_object``, and the
     real part is stored.

     Returns an array shaped like ``data[0]``.
     """
     output = np.empty_like(data[0])
     n_slices = data[0].shape[self.slice_dir]
     for i in range(n_slices):
         self.sslice[self.slice_dir] = i
         # Multi-axis indexing needs a tuple; NumPy rejects a list of slices.
         index = tuple(self.sslice)
         sino = fft.fftshift(self.fft_object(data[0][index]))
         sino[self.row1:self.row2] = \
             sino[self.row1:self.row2] * self.filtercomplex
         sino = self.ifft_object(fft.ifftshift(sino)).real
         output[index] = sino
     return output
Example #19
0
 def ft_interface(a, s, axes, norm, **kwargs):
     """Run ``ftfunc`` with the matching fftshift/ifftshift.

     Forward transforms are fftshifted after ``ftfunc``; inverse transforms
     are ifftshifted before it. With ``omitlast`` the last transformed axis is
     excluded from shifting, and a single axis means no shift at all.
     ``ftfunc``, ``fftmod``, ``forward`` and ``omitlast`` are taken from the
     enclosing scope.
     """
     if not omitlast:
         shftax = axes
     elif isinstance(axes, int) or len(axes) < 2:
         # no fftshift in the case of rfft
         return ftfunc(a, s, axes, norm, **kwargs)
     else:
         shftax = axes[:-1]

     if not forward:
         return ftfunc(fftmod.ifftshift(a, shftax), s, axes, norm, **kwargs)
     return fftmod.fftshift(ftfunc(a, s, axes, norm, **kwargs), shftax)
Example #20
0
def fftdeconvolve(image, psf):
    """Deconvolve ``image`` by ``psf`` by dividing their DFTs directly.

    Direct division may not be numerically desirable: noise is amplified, and
    zeros in the PSF spectrum yield inf/nan. Adapted from
    http://stackoverflow.com/questions/17473917/is-there-a-equivalent-of-scipy-signal-deconvolve-for-2d-arrays

    Raises:
        NotImplementedError: if ``_pyfftw`` is falsy.

    Returns:
        Complex array holding the recovered kernel, fftshifted (centred).
    """
    if not _pyfftw:
        raise NotImplementedError

    image_spectrum = fftshift(fftn(image.astype('float')))
    psf_spectrum = fftshift(fftn(psf.astype('float')))
    ratio = image_spectrum / psf_spectrum
    return fftshift(ifftn(ifftshift(ratio)))
Example #21
0
def fresnel_propagate(wavefield,
                      energy_ev,
                      psize_cm,
                      dist_cm,
                      fresnel_approx=True,
                      pad=0,
                      sign_convention=1):
    """
    Perform Fresnel propagation on a batch of wavefields.
    :param wavefield: complex wavefield with shape [n_batches, n_y, n_x].
    :param energy_ev: float; the wavelength used is 1240 / energy_ev nm.
    :param psize_cm: pixel size in cm, either a size-3 vector ([dy, dx, dz])
        or a single value used for all three.
    :param dist_cm: propagation distance in cm.
    :param fresnel_approx: forwarded to get_kernel.
    :param pad: number of edge-replicated pixels added on each side of the
        y/x axes before propagation and cropped off afterwards.
    :param sign_convention: forwarded to get_kernel.
    :return: propagated complex wavefield, same shape as the input.
    """
    if pad > 0:
        wavefield = np.pad(wavefield, [[0, 0], [pad, pad], [pad, pad]],
                           mode='edge')

    grid_shape = wavefield.shape[1:]
    psize_cm = np.atleast_1d(np.asarray(psize_cm, dtype=float))
    if psize_cm.size == 1:
        # Broadcast a single pixel size to [dy, dx, dz] (works for scalars
        # and length-1 sequences alike).
        psize_cm = np.repeat(psize_cm, 3)
    voxel_nm = psize_cm * 1.e7
    lmbda_nm = 1240. / energy_ev
    dist_nm = dist_cm * 1e7

    h = get_kernel(dist_nm,
                   lmbda_nm,
                   voxel_nm,
                   grid_shape,
                   fresnel_approx=fresnel_approx,
                   sign_convention=sign_convention)

    # Multiply the centred spectrum by h and return to real space.
    spectrum = fftshift(fft2(wavefield), axes=[1, 2])
    wavefield = ifft2(ifftshift(spectrum * h, axes=[1, 2]))

    if pad > 0:
        wavefield = wavefield[:, pad:-pad, pad:-pad]

    return wavefield
Example #22
0
    def subtract_DC(signal,return_ft = False, axis = 1):
        """Use the discrete Fourier transform to remove the DC component of a
        signal along one axis.

        Args:
            signal (np.ndarray): real signal to be processed
            return_ft (bool): if True, return the Fourier transform of the
                              input signal (with its DC term zeroed) instead
                              of the filtered signal
            axis (int): axis along which the Fourier transform is taken
        Returns:
            np.ndarray: complex array equal to ``signal`` minus its mean along
            ``axis`` (or the zeroed-DC transform when ``return_ft`` is True)
        """
        sig_fft = fft(ifftshift(signal, axes=(axis)), axis=axis)

        dc_index = [slice(None)] * sig_fft.ndim
        dc_index[axis] = slice(0, 1, 1)
        sig_fft[tuple(dc_index)] = 0

        if return_ft:
            return sig_fft
        # Invert along the same axis that was transformed (ifft defaults to
        # the last axis otherwise).
        return fftshift(ifft(sig_fft, axis=axis), axes=(axis))
Example #23
0
def fft_filter(im, rois, shift, plot=True, imshow_kwargs=None):
    """Overwrite regions of the centred 2D spectrum of ``im`` with shifted copies.

    For each ``roi`` in ``rois``, the spectrum inside ``roi`` and inside
    ``symmetric_roi(roi, im.shape)`` is replaced by the spectrum in the same
    regions translated by ``shift`` (via ``translate_roi``). The real part of
    the inverse transform is returned.

    Args:
        im: 2D image.
        rois: iterable of regions usable as indices into the fftshifted
            spectrum.
        shift: (dx, dy) translation passed to ``translate_roi``.
        plot: if True, show the image and |spectrum| before and after, with
            the ROI rectangles drawn.
        imshow_kwargs: optional extra keyword arguments for the image
            ``imshow`` calls; they override the default vmin/vmax.
    Returns:
        np.ndarray: the filtered image (real part).
    """
    f0 = fft.fftshift(fft.fft2(im))
    f1 = f0.copy()

    if plot:
        kw = dict(vmin=-0.05, vmax=1)
        # None means no extra options (avoids a shared mutable default).
        kw.update(imshow_kwargs or {})
        fig, axes = plt.subplots(2,
                                 2,
                                 figsize=(9, 4),
                                 sharex='col',
                                 sharey='col')
        (ax, ax_f), (ax1, ax1_f) = axes
        ax.imshow(im, **kw)
        ax_f.imshow(abs(f0), norm=LogNorm())

    dx, dy = shift
    for j, roi in enumerate(rois):
        print(j, roi)
        roi_t = translate_roi(roi, dx, dy)
        roi_symm = symmetric_roi(roi, im.shape)
        roi_symm_t = symmetric_roi(roi_t, im.shape)

        # Copy the spectrum from the translated regions into the originals.
        f1[roi] = f1[roi_t]
        f1[roi_symm] = f1[roi_symm_t]
        if plot:
            # Source (translated) regions are outlined with ec='r'.
            roi_to_rect(roi, ax=ax_f, index=j)
            roi_to_rect(roi, ax=ax1_f)
            roi_to_rect(roi_t, ax=ax_f, ec='r')
            roi_to_rect(roi_symm, ax=ax_f)
            roi_to_rect(roi_symm, ax=ax1_f)
            roi_to_rect(roi_symm_t, ax=ax_f, ec='r')

    im1 = fft.ifft2(fft.ifftshift(f1)).real
    if plot:
        ax1.imshow(im1, **kw)
        ax1_f.imshow(abs(f1), norm=LogNorm())

    return im1
Example #24
0
def pattern_params(my_pat, size=2):
    """Estimate parameters of the strongest periodic component of a 2D pattern.

    The pattern is transformed with a real FFT (only the leading axes are
    fftshifted). The DC term is masked, and the brightest remaining peak is
    refined with ``localize_peak`` on its 3x3 neighbourhood.

    Returns a dict with:
        "period": 1 / norm(peak), with the peak position divided by the
            pattern shape;
        "angle": arctan2 of the (row, column) components of that peak;
        "phase": angle of the FFT at the integer peak location;
        "fft": the shifted real FFT;
        "mod": |sum of the FFT over slice_maker(max_loc, size)| / |DC|;
        "max_loc": integer (row, column) index of the peak.
    """
    # The last axis of rfftn is not shifted (it is the half-spectrum axis).
    leading_axes = tuple(range(my_pat.ndim))[:-1]
    pat_fft = fftshift(rfftn(ifftshift(my_pat)), axes=leading_axes)
    amp = abs(pat_fft)

    # DC after shifting: centre row, column 0.
    n_rows, _ = amp.shape
    dc_loc = (n_rows // 2, 0)
    dc_power = amp[dc_loc]
    amp[dc_loc] = 0

    max_loc = np.unravel_index(amp.argmax(), amp.shape)
    sub_pixel = localize_peak(amp[slice_maker(max_loc, 3)])
    peak = np.array(max_loc) + np.array(sub_pixel) - np.array(dc_loc)
    peak_corr = peak / np.array(my_pat.shape)

    return {"period": 1 / norm(peak_corr),
            "angle": np.arctan2(*peak_corr),
            "phase": np.angle(pat_fft[max_loc[0], max_loc[1]]),
            "fft": pat_fft,
            "mod": abs(pat_fft[slice_maker(max_loc, size)].sum()) / dc_power,
            "max_loc": max_loc}
Example #25
0
def fftconvolve_fast(data, kernel, **kwargs):
    """A faster version of fft convolution.

    The kernel is ifftshifted before its FFT but the data is not. Fourier
    convolution wraps around the data edges, so ifftshifting the data before
    the FFT and fftshifting after makes no difference, and that step is
    skipped.

    Parameters
    ----------
    data : ndarray
        Real data; it is reflect-padded up to a fast FFT size.
    kernel : ndarray
        Real kernel, assumed already centred and background-subtracted; it is
        zero-padded to the padded data shape if needed.
    **kwargs
        Passed to rfftn / irfftn.

    Returns
    -------
    ndarray
        The convolution cropped back to ``data.shape``.
    """
    # TODO: add error checking, complex-input support and other padding modes.
    # Public location of next_fast_len (same function as the private
    # scipy.fftpack.helper one previously reached via sig.fftpack.helper).
    from scipy.fftpack import next_fast_len

    dshape = np.array(data.shape)
    kshape = np.array(kernel.shape)
    # find maximum dimensions
    maxshape = np.max((dshape, kshape), 0)
    # calculate a nice shape
    fshape = [next_fast_len(int(d)) for d in maxshape]
    # pad out with reflection
    pad_data = fft_pad(data, fshape, "reflect")
    # padding per axis, used to crop the result back to the original shape
    padding = tuple(
        _calc_pad(o, n) for o, n in zip(data.shape, pad_data.shape))
    fslice = tuple(
        slice(s, -e) if e != 0 else slice(s, None) for s, e in padding)
    if kernel.shape != pad_data.shape:
        kernel = fft_pad(kernel, pad_data.shape, mode='constant')
    k_kernel = rfftn(ifftshift(kernel), pad_data.shape, **kwargs)
    k_data = rfftn(pad_data, pad_data.shape, **kwargs)
    convolve_data = irfftn(k_kernel * k_data, pad_data.shape, **kwargs)
    # return data with same shape as original data
    return convolve_data[fslice]
Example #26
0
    ax3.set_xlim([x1,x2-dx])
    ax3.set_ylim([y1,y2-dy])
    ax3.set_title("Spatial density (NUMERIC)")

    ax2.imshow(phi.T, origin='lower', interpolation='none', extent=[px1,px2-dpx,py1,py2-dpy],vmin=amin(phi_exact), vmax=amax(phi_exact))
    ax4.set_xlabel('x')
    ax4.set_ylabel('y')
    ax2.set_xlim([px1,px2-dpx])
    ax2.set_ylim([py1,py2-dpy])
    ax4.set_title("Momentum density (NUMERIC)")

    plt.tight_layout()
    fig.savefig('frames' + '/%04d.png' % k)
    fig.clf()
    plt.close('all')

# Time-stepping driver. dU, dT, the grids, gauss, W_analytic, draw_frame and
# solve_spectral are defined elsewhere in this script (not visible here).
dt = (t2-t1)/100. # the first very rough guess of time step
# Exponential step factors passed to solve_spectral on every step;
# presumably the two halves of a split-step scheme — confirm against dU/dT.
expU = exp(dt*dU)
expT = exp(dt*dT)

# Initial state built by gauss(...), fftshifted before stepping.
W = fftshift(gauss(xgrid, ygrid, pxgrid, pygrid, x0, y0, px0, py0, sigma_x, sigma_y, sigma_px, sigma_py))
t = t1
Nt = 1  # frame counter passed to draw_frame

# NOTE(review): t accumulates dt in floating point, so whether a final step
# at t == t2 runs depends on rounding.
while t <= t2:
    Wexact = W_analytic(xgrid, ygrid, pxgrid, pygrid, t)
    # ifftshift undoes the fftshift applied to the initial state.
    draw_frame(ifftshift(W), Wexact, Nt)
    W = solve_spectral(W, expU, expT)
    t += dt
    Nt += 1
Example #27
0
def multislice_propagate(grid_delta_batch,
                         grid_beta_batch,
                         probe_real,
                         probe_imag,
                         energy_ev,
                         psize_cm,
                         free_prop_cm=None,
                         return_intermediate=False,
                         sign_convention=1):
    """
    Perform multislice propagation on a batch of 3D objects.

    The probe is multiplied, slice by slice along the last object axis, by the
    complex transmission ``exp(i k delta) * exp(-k beta)`` of each slice and
    propagated to the next slice with the kernel returned by ``get_kernel``.
    Optionally the exit wave is then propagated over a further distance.

    :param grid_delta_batch: 4D array for object delta with shape [n_batches, n_y, n_x, n_z].
    :param grid_beta_batch: 4D array for object beta with shape [n_batches, n_y, n_x, n_z].
    :param probe_real: 2D array for the real part of the probe.
    :param probe_imag: 2D array for the imaginary part of the probe.
    :param energy_ev: photon energy in eV; the wavelength in nm is taken as
        ``1240 / energy_ev``.
    :param psize_cm: size-3 vector with pixel size ([dy, dx, dz]).
    :param free_prop_cm: extra propagation distance in cm after the last
        slice. ``None`` or ``0`` skips it; the string ``'inf'`` replaces the
        wave by its centred 2D FFT (presumably the far-field limit).
        Otherwise the transfer-function kernel (``get_kernel``) is used when
        the mean voxel size exceeds the critical sampling
        ``lambda * dist / L``, and the impulse-response kernel
        (``get_kernel_ir``) when it does not.
    :param return_intermediate: if True, also return the list of wavefront
        magnitudes: before the first slice, after every slice, and after the
        free propagation when one is performed.
    :param sign_convention: forwarded to ``get_kernel`` / ``get_kernel_ir``.
    :return: complex wavefront of shape [n_batches, n_y, n_x], or the tuple
        ``(wavefront, wavefront_ls)`` when ``return_intermediate`` is True.
    """
    minibatch_size = grid_delta_batch.shape[0]
    grid_shape = grid_delta_batch.shape[1:]
    voxel_nm = np.array(psize_cm) * 1.e7
    wavefront = np.zeros([minibatch_size, grid_shape[0], grid_shape[1]],
                         dtype='complex64')
    # broadcast the same probe onto every object in the batch
    wavefront += (probe_real + 1j * probe_imag)

    lmbda_nm = 1240. / energy_ev
    mean_voxel_nm = np.prod(voxel_nm)**(1. / 3)
    size_nm = np.array(grid_shape) * voxel_nm

    n_slice = grid_shape[-1]
    delta_nm = voxel_nm[-1]

    def _propagate(wave, kernel):
        # Multiply the spectrum by a kernel laid out with zero frequency at
        # the centre (hence fftshift before and ifftshift after), over the two
        # transverse axes only.
        return ifft2(
            ifftshift(fftshift(fft2(wave), axes=[1, 2]) * kernel,
                      axes=[1, 2]))

    # h = get_kernel_ir(delta_nm, lmbda_nm, voxel_nm, grid_shape)
    h = get_kernel(delta_nm,
                   lmbda_nm,
                   voxel_nm,
                   grid_shape,
                   sign_convention=sign_convention)
    k = 2. * PI * delta_nm / lmbda_nm

    wavefront_ls = [abs(wavefront)] if return_intermediate else None

    for i in trange(n_slice):
        delta_slice = grid_delta_batch[:, :, :, i]
        beta_slice = grid_beta_batch[:, :, :, i]
        c = np.exp(1j * k * delta_slice) * np.exp(-k * beta_slice)
        wavefront = wavefront * c
        # no inter-slice propagation after the last slice
        if i < n_slice - 1:
            wavefront = _propagate(wavefront, h)
        if return_intermediate:
            wavefront_ls.append(abs(wavefront))

    do_free_prop = free_prop_cm not in [None, 0]
    if do_free_prop:
        if free_prop_cm == 'inf':
            wavefront = fftshift(fft2(wavefront), axes=[1, 2])
        else:
            dist_nm = free_prop_cm * 1e7
            mean_size_nm = np.prod(size_nm)**(1. / 3)
            crit_samp = lmbda_nm * dist_nm / mean_size_nm
            # transfer function ('TF') when voxels are coarser than the
            # critical sampling, impulse response ('IR') otherwise
            kernel_fn = get_kernel if mean_voxel_nm > crit_samp else get_kernel_ir
            h = kernel_fn(dist_nm,
                          lmbda_nm,
                          voxel_nm,
                          grid_shape,
                          sign_convention=sign_convention)
            wavefront = _propagate(wavefront, h)

    if return_intermediate:
        if do_free_prop:
            wavefront_ls.append(abs(wavefront))
        return wavefront, wavefront_ls
    return wavefront
Example #28
0
def fft_convolve2d(x, f, mode='same', boundary='constant', fft_filter=False):
    r"""
    Performs fast 2d convolution in the frequency domain convolving each image
    channel with its corresponding filter channel.

    Parameters
    ----------
    x : ``(channels, height, width)`` `ndarray`
        Image.
    f : ``(channels, height, width)`` `ndarray`
        Filter.
    mode : str {`full`, `same`, `valid`}, optional
        Determines the shape of the resulting convolution.
    boundary: str {`constant`, `symmetric`}, optional
        Determines how the image is padded.
    fft_filter: `bool`, optional
        If `True`, the filter is assumed to be defined on the frequency
        domain. If `False` the filter is assumed to be defined on the
        spatial domain.

    Returns
    -------
    c: ``(channels, height, width)`` `ndarray`
        Result of convolving each image channel with its corresponding
        filter channel.

    Raises
    ------
    ValueError
        If ``mode`` is not one of ``'full'``, ``'same'`` or ``'valid'``.
    """
    # Validate up front (and by equality, not identity) so no FFT work is
    # wasted on an unsupported mode.
    if mode not in ('full', 'same', 'valid'):
        raise ValueError(
            "mode={}, is not supported. The only supported "
            "modes are: 'full', 'same' and 'valid'.".format(mode))

    # spatial shape of the image; needed for cropping in both branches
    x_shape = np.asarray(x.shape[-2:])

    if fft_filter:
        # extended shape is filter shape
        ext_shape = np.asarray(f.shape[-2:])
        # The spatial filter size is not known here; infer its half-size from
        # the same relation used below (ext_shape = x_shape + f_half_shape - 1).
        # NOTE(review): only affects mode='valid' -- confirm this is intended.
        f_half_shape = ext_shape - x_shape + 1

        # extend image
        ext_x = pad(x, ext_shape, boundary=boundary)

        # compute ffts of extended image; the filter is already in Fourier space
        fft_ext_x = fft2(ext_x)
        fft_ext_f = f
    else:
        # extended shape
        f_shape = np.asarray(f.shape[-2:])
        f_half_shape = (f_shape / 2).astype(int)
        ext_shape = x_shape + f_half_shape - 1

        # extend image and filter
        ext_x = pad(x, ext_shape, boundary=boundary)
        ext_f = pad(f, ext_shape)

        # compute ffts of extended image and extended filter
        fft_ext_x = fft2(ext_x)
        fft_ext_f = fft2(ext_f)

    # compute extended convolution in Fourier domain
    fft_ext_c = fft_ext_f * fft_ext_x

    # compute ifft of extended convolution
    ext_c = np.real(ifftshift(ifft2(fft_ext_c), axes=(-2, -1)))

    if mode == 'full':
        return ext_c
    if mode == 'same':
        return crop(ext_c, x_shape)
    # mode == 'valid'
    return crop(ext_c, x_shape - f_half_shape + 1)
Example #29
0
def retrieve_phase(data,
                   params,
                   max_iters=200,
                   pupil_tol=1e-8,
                   mse_tol=1e-8,
                   phase_only=False,
                   mclass=HanserPSF):
    """Retrieve the phase across the objective's back pupil from an
    experimentally measured PSF.

    Follows: [Hanser, B. M.; Gustafsson, M. G. L.; Agard, D. A.;
    Sedat, J. W. Phase Retrieval for High-Numerical-Aperture Optical Systems.
    Optics Letters 2003, 28 (10), 801.](dx.doi.org/10.1364/OL.28.000801)

    Parameters
    ----------
    data : ndarray (3 dim)
        The experimentally measured PSF of a subdiffractive source; must be
        square in its last two (y/x) axes
    params : dict
        Parameters to pass to ``mclass``. ``vec_corr``, ``condition``,
        ``zsize`` and ``size`` are overwritten in this dict (in place), the
        latter two from ``data.shape``
    max_iters : int
        The maximum number of iterations to run (at least 1), default is 200
    pupil_tol : float
        the tolerance on the relative change in the pupil between iterations,
        default is 1e-8
    mse_tol : float
        the tolerance on the relative change of the mean squared error between
        data and simulated data, default is 1e-8; iteration also stops once
        the mean squared error itself falls below it
    phase_only : bool
        True means only the phase of the back pupil is retrieved while the
        amplitude is not.
    mclass : class
        The PSF model class instantiated with ``params``, default is HanserPSF

    Returns
    -------
    PR_result : PhaseRetrievalResult
        An object that contains the phase retrieval result

    Raises
    ------
    AssertionError
        If ``data`` is not 3D or not square in y/x
    ValueError
        If ``max_iters`` is less than 1
    """
    # check the rank first so that the squareness check can index shape[2]
    assert data.ndim == 3, "Data doesn't have enough dims"
    assert data.shape[1] == data.shape[2], "Data is not square in x/y"
    if max_iters < 1:
        # with no iterations the loop variable below would never be bound
        raise ValueError(
            "max_iters must be at least 1, got {}".format(max_iters))
    # make sure the user hasn't screwed up the params
    params.update(
        dict(vec_corr="none",
             condition="none",
             zsize=data.shape[0],
             size=data.shape[-1]))
    # assume that data prep has been handled outside function
    # The field magnitude is the square root of the intensity
    mag = psqrt(data)
    # generate a model from parameters
    model = mclass(**params)
    # generate coordinates
    model._gen_kr()
    # per-iteration diagnostics
    mse = np.zeros(max_iters)
    mse_diff = np.zeros(max_iters)
    pupil_diff = np.zeros(max_iters)
    # generate a pupil to start with
    new_pupil = model._gen_pupil()
    # save it as a mask
    mask = new_pupil.real
    # iterate
    old_mse, old_pupil = None, None
    for i in range(max_iters):
        # generate new mse and add it to the list
        model._gen_psf(new_pupil)
        new_mse = _calc_mse(data, model.PSFi)
        mse[i] = new_mse
        if i > 0:
            # relative change in mse, to test for convergence
            mse_diff[i] = abs(old_mse - new_mse) / old_mse
            # relative (power-normalised) change in pupil
            pupil_diff[i] = (abs(old_pupil - new_pupil)**
                             2).mean() / (abs(old_pupil)**2).mean()
        else:
            # nan compares False below, so the first iteration never stops on these
            mse_diff[i] = np.nan
            pupil_diff[i] = np.nan
        # check tolerances, how much has the pupil changed, how much has the mse changed
        # and what's the absolute mse
        logger.info(
            "Iteration {}, mse_diff = {:.2g}, pupil_diff = {:.2g}".format(
                i, mse_diff[i], pupil_diff[i]))
        if pupil_diff[i] < pupil_tol or mse_diff[i] < mse_tol or mse[
                i] < mse_tol:
            break
        # update old_mse
        old_mse = new_mse
        # retrieve new pupil
        old_pupil = new_pupil
        # keep phase
        phase = np.angle(model.PSFa.squeeze())
        # replace magnitude with experimentally measured mag
        new_psf = mag * np.exp(1j * phase)
        # generate the new pupils
        new_pupils = fftn(ifftshift(new_psf, axes=(1, 2)), axes=(1, 2))
        # undo defocus and take the mean
        new_pupils /= model._calc_defocus()
        new_pupil = new_pupils.mean(0) * mask
        # if phase only discard magnitude info
        if phase_only:
            new_pupil = np.exp(1j * np.angle(new_pupil)) * mask
    else:
        # Logger.warn is a deprecated alias of Logger.warning
        logger.warning("Reach max iterations without convergence")
    mse = mse[:i + 1]
    mse_diff = mse_diff[:i + 1]
    pupil_diff = pupil_diff[:i + 1]
    # shift mask
    mask = fftshift(mask)
    # shift phase then unwrap and mask
    phase = unwrap_phase(fftshift(np.angle(new_pupil))) * mask
    # shift magnitude
    magnitude = fftshift(abs(new_pupil)) * mask
    return PhaseRetrievalResult(magnitude, phase, mse, pupil_diff, mse_diff,
                                model)
Example #30
0
def ifftnc(x, axes):
    """Inverse n-D FFT of ``x`` over ``axes`` with shifts on both sides.

    The input is ``fft.fftshift``-ed, inverse transformed with ``fft.ifftn``,
    and the result is ``fft.ifftshift``-ed, each over the same ``axes``.

    NOTE(review): the common "centred" convention is
    ``fftshift(ifftn(ifftshift(x)))``; it coincides with this ordering only
    when every transformed axis has even length. Kept as-is because a
    matching forward transform elsewhere may rely on this exact order.
    """
    return fft.ifftshift(
        fft.ifftn(fft.fftshift(x, axes=axes), axes=axes),
        axes=axes,
    )
Example #31
0
 def process_frames(self, data):
     """Filter a band of rows of the centred spectrum of ``data[0]``.

     ``data[0]`` is transformed with ``self.fft_object`` and ``fftshift``-ed;
     rows ``self.row1:self.row2`` are multiplied by ``self.filtercomplex``;
     the spectrum is ``ifftshift``-ed back and passed to
     ``self.ifft_object``, whose real part is returned.
     """
     band = slice(self.row1, self.row2)
     spectrum = fft.fftshift(self.fft_object(data[0]))
     spectrum[band] = spectrum[band] * self.filtercomplex
     restored = self.ifft_object(fft.ifftshift(spectrum))
     return restored.real
    def set_pulse_shapes(self, pump_field, probe_field, *, plot_fields=True):
        """Set the list of 4 pulse amplitudes from a pump shape and a probe shape.

        Assumes 4-wave mixing signals, and so 4 interactions:
        ``self.efields`` becomes ``[pump, pump, probe, probe]``.

        When ``self.efield_t`` has more than one point, each field is checked
        and ``warnings.warn`` is called when it is not the same size as
        ``efield_t``, or when its amplitude at the grid edges is above 1% of
        its peak in the time domain (grid too short) or in the frequency
        domain (``dt`` too large).

        Args:
            pump_field (np.ndarray) : pump amplitude, evaluated on efield_t
            probe_field (np.ndarray) : probe amplitude, evaluated on efield_t
        Kwargs:
            plot_fields (bool) : if True, plot real and imaginary parts of
                both fields in time and frequency for visual inspection
        """
        self.efields = [pump_field, pump_field, probe_field, probe_field]

        # nothing to validate or plot on a single-point grid
        if self.efield_t.size == 1:
            return

        def _tail_exceeds(arr):
            # True when the larger endpoint magnitude exceeds 1% of the peak
            tail = np.max(np.abs([arr[0], arr[-1]]))
            return tail > np.max(np.abs(arr)) / 100

        fields = (('Pump', 'pump', pump_field), ('Probe', 'probe', probe_field))

        # Warnings are issued in the same order as before: sizes, then
        # time-domain tails, then frequency-domain tails.
        for label, _, field in fields:
            if field.size != self.efield_t.size:
                warnings.warn(
                    '{} must be evaluated on efield_t, the grid defined by dt and num_conv_points'
                    .format(label))

        for _, name, field in fields:
            if _tail_exceeds(field):
                warnings.warn(
                    'Consider using larger num_conv_points, {} does not decay to less than 1% of maximum value in time domain'
                    .format(name))

        pump_fft = fftshift(fft(ifftshift(pump_field))) * self.dt
        probe_fft = fftshift(fft(ifftshift(probe_field))) * self.dt

        for name, spectrum in (('pump', pump_fft), ('probe', probe_fft)):
            if _tail_exceeds(spectrum):
                warnings.warn(
                    'Consider using smaller value of dt, {} does not decay to less than 1% of maximum value in frequency domain'
                    .format(name))

        if plot_fields:
            fig, axes = plt.subplots(2, 2)
            l1, l2, = axes[0, 0].plot(self.efield_t, np.real(pump_field),
                                      self.efield_t, np.imag(pump_field))
            # attach the legend to the panel that owns l1/l2; plt.legend
            # would target the current axes (the last subplot created)
            axes[0, 0].legend([l1, l2], ['Real', 'Imag'])
            axes[0, 1].plot(self.efield_w, np.real(pump_fft),
                            self.efield_w, np.imag(pump_fft))
            axes[1, 0].plot(self.efield_t, np.real(probe_field),
                            self.efield_t, np.imag(probe_field))
            axes[1, 1].plot(self.efield_w, np.real(probe_fft),
                            self.efield_w, np.imag(probe_fft))

            axes[0, 0].set_ylabel('Pump Amp')
            axes[1, 0].set_ylabel('Probe Amp')
            axes[1, 0].set_xlabel('Time')
            axes[1, 1].set_xlabel('Frequency')

            fig.suptitle(
                'Check that pump and probe are well-resolved in time and frequency'
            )
Example #33
0
def py_zogy(Nf, Rf, P_Nf, P_Rf, S_Nf, S_Rf, SN, SR, dx=0.25, dy=0.25):
    '''Python implementation of ZOGY image subtraction algorithm.

    As per Frank's instructions, will assume images have been aligned,
    background subtracted, and gain-matched.

    Arguments:
    Nf: New image (filename)
    Rf: Reference image (filename)
    P_Nf: PSF of New image (filename)
    P_Rf: PSF of Reference image (filename)
    S_Nf: 2D Uncertainty (sigma) of New image (filename)
    S_Rf: 2D Uncertainty (sigma) of Reference image (filename)
    SN: Average uncertainty (sigma) of New image
    SR: Average uncertainty (sigma) of Reference image
    dx: Astrometric uncertainty (sigma) in x coordinate
    dy: Astrometric uncertainty (sigma) in y coordinate

    Returns:
    D: Subtracted image
    P_D: PSF of subtracted image
    S_corr: Corrected subtracted image
    '''

    def _read_primary(path):
        # Load the primary-HDU data and close the file straight away; copy so
        # the returned array does not depend on the closed file.
        with fits.open(path) as hdul:
            return np.array(hdul[0].data)

    def _centre_slices(big_shape, small_shape):
        # Slices placing an array of small_shape inside big_shape so that its
        # pixel at index s // 2 lands on index n // 2 (odd or even sizes).
        return tuple(slice(n // 2 - s // 2, n // 2 - s // 2 + s)
                     for n, s in zip(big_shape, small_shape))

    # Load the new and ref images into memory
    N = _read_primary(Nf)
    R = _read_primary(Rf)

    # Load the PSFs into memory
    P_N_small = _read_primary(P_Nf)
    P_R_small = _read_primary(P_Rf)

    # Place PSF at center of image with same size as new / reference.
    # Integer bounds and a tuple index are required by Python 3 / NumPy.
    idx = _centre_slices(N.shape, P_N_small.shape)
    P_N = np.zeros(N.shape)
    P_N[idx] = P_N_small
    P_R = np.zeros(R.shape)
    P_R[_centre_slices(R.shape, P_R_small.shape)] = P_R_small

    # Shift the PSF centre (index n // 2) to the origin so it will not
    # introduce a shift. ifftshift is the exact inverse of that placement;
    # fftshift is off by one pixel along odd-length axes.
    P_N = fft.ifftshift(P_N)
    P_R = fft.ifftshift(P_R)

    # Take all the Fourier Transforms
    N_hat = fft.fft2(N)
    R_hat = fft.fft2(R)

    P_N_hat = fft.fft2(P_N)
    P_R_hat = fft.fft2(P_R)

    # Fourier Transform of Difference Image (Equation 13)
    D_hat_num = (P_R_hat * N_hat - P_N_hat * R_hat)
    D_hat_den = np.sqrt(SN**2 * np.abs(P_R_hat**2) + SR**2 * np.abs(P_N_hat**2))
    D_hat = D_hat_num / D_hat_den

    # Flux-based zero point (Equation 15)
    FD = 1. / np.sqrt(SN**2 + SR**2)

    # Difference Image
    # TODO: Why is the FD normalization in there?
    D = np.real(fft.ifft2(D_hat)) / FD

    # Fourier Transform of PSF of Subtraction Image (Equation 14)
    P_D_hat = P_R_hat * P_N_hat / FD / D_hat_den

    # PSF of Subtraction Image: fftshift moves the origin back to the centre
    # (inverse of the ifftshift above), then crop to the PSF footprint.
    P_D = np.real(fft.ifft2(P_D_hat))
    P_D = fft.fftshift(P_D)
    P_D = P_D[idx]

    # Fourier Transform of Score Image (Equation 17)
    S_hat = FD * D_hat * np.conj(P_D_hat)

    # Score Image
    S = np.real(fft.ifft2(S_hat))

    # Now start calculating Scorr matrix (including all noise terms)

    # Start out with source noise
    # Load the sigma images into memory
    sigma_N = _read_primary(S_Nf)
    sigma_R = _read_primary(S_Rf)

    # Sigma to variance
    V_N = sigma_N**2
    V_R = sigma_R**2

    # Fourier Transform of variance images
    V_N_hat = fft.fft2(V_N)
    V_R_hat = fft.fft2(V_R)

    # Equation 28
    kr_hat = np.conj(P_R_hat) * np.abs(P_N_hat**2) / (D_hat_den**2)
    kr = np.real(fft.ifft2(kr_hat))

    # Equation 29
    kn_hat = np.conj(P_N_hat) * np.abs(P_R_hat**2) / (D_hat_den**2)
    kn = np.real(fft.ifft2(kn_hat))

    # Noise in New Image: Equation 26
    V_S_N = np.real(fft.ifft2(V_N_hat * fft.fft2(kn**2)))

    # Noise in Reference Image: Equation 27
    V_S_R = np.real(fft.ifft2(V_R_hat * fft.fft2(kr**2)))

    # Astrometric Noise
    # Equation 31
    # TODO: Check axis (0/1) vs x/y coordinates
    S_N = np.real(fft.ifft2(kn_hat * N_hat))
    dSNdx = S_N - np.roll(S_N, 1, axis=1)
    dSNdy = S_N - np.roll(S_N, 1, axis=0)

    # Equation 30
    V_ast_S_N = dx**2 * dSNdx**2 + dy**2 * dSNdy**2

    # Equation 33
    S_R = np.real(fft.ifft2(kr_hat * R_hat))
    dSRdx = S_R - np.roll(S_R, 1, axis=1)
    dSRdy = S_R - np.roll(S_R, 1, axis=0)

    # Equation 32
    V_ast_S_R = dx**2 * dSRdx**2 + dy**2 * dSRdy**2

    # Calculate Scorr
    S_corr = S / np.sqrt(V_S_N + V_S_R + V_ast_S_N + V_ast_S_R)

    return D, P_D, S_corr
Example #34
0
    def fillWithGRFFromTemplate(self, twodPower, bufferFactor=1, threads=1):
        """
        Generate a Gaussian random field from an input power spectrum
        specified as a 2d powerMap

        Notes
        -----
        bufferFactor = 1 means the map will have periodic boundary conditions,
        while bufferFactor > 1 means the map will be generated on a patch
        bufferFactor times larger in each dimension and then cut out so as to
        have non-periodic boundary conditions.

        Fills the data field of the map with the GRF realization. The
        template ``twodPower`` is not modified.

        Parameters
        ----------
        twodPower : power-map object
            Template spectrum; its ``powerMap``, ``lx``, ``ly``, ``Nx``,
            ``Ny``, ``pixScaleX`` and ``pixScaleY`` are read.
        bufferFactor : int, optional
            Padding factor (>= 1), converted with ``int``. Default is 1.
        threads : int, optional
            Number of threads passed to ``ifft2`` when pyFFTW is available.
        """
        # Validate before sizing any arrays with it.
        bufferFactor = int(bufferFactor)
        assert bufferFactor >= 1

        Ny = self.Ny * bufferFactor
        Nx = self.Nx * bufferFactor

        # Divide out the area factor. Scale a copy rather than
        # twodPower.powerMap in place, so repeated calls with the same
        # template do not compound the factor.
        area = twodPower.Nx * twodPower.Ny * twodPower.pixScaleX * twodPower.pixScaleY
        powerMap = twodPower.powerMap * ((twodPower.Nx * twodPower.Ny) ** 2 / area)

        if bufferFactor > 1 or twodPower.Nx != Nx or twodPower.Ny != Ny:
            # Interpolate the template spectrum onto this map's Fourier grid.
            # Both grids are fftshift-ed so coordinates are (presumably, for
            # the template) in increasing order, as the interpolator needs.
            # NOTE(review): if interp2d is scipy.interpolate.interp2d, it was
            # removed in SciPy 1.14.
            ly = fftfreq(Ny, d=self.pixScaleY) * (2 * numpy.pi)
            lx = fftfreq(Nx, d=self.pixScaleX) * (2 * numpy.pi)
            f_interp = interp2d(fftshift(twodPower.lx),
                                fftshift(twodPower.ly),
                                fftshift(powerMap))
            p = ifftshift(f_interp(fftshift(lx), fftshift(ly)))
        else:
            p = powerMap

        # Independent Gaussian real and imaginary parts with variance p.
        amplitude = numpy.sqrt(p)
        realPart = amplitude * numpy.random.randn(Ny, Nx)
        imgPart = amplitude * numpy.random.randn(Ny, Nx)

        kMap = realPart + 1j * imgPart
        if have_pyFFTW:
            data = numpy.real(ifft2(kMap, threads=threads))
        else:
            data = numpy.real(ifft2(kMap))

        # Cut a map of the original size out of the buffered realization;
        # floor division keeps the slice bounds integers.
        b = bufferFactor
        y_start = (b - 1) // 2 * self.Ny
        y_stop = (b + 1) // 2 * self.Ny
        x_start = (b - 1) // 2 * self.Nx
        x_stop = (b + 1) // 2 * self.Nx
        self.data = data[y_start:y_stop, x_start:x_stop]
Example #35
0
def upgradePixelPitch(m, N=1, threads=1):
    """
    Go to finer pixels with fourier interpolation.

    The map's Fourier transform is zero-padded onto a grid 2^N times larger
    in each dimension, divided by the magnitude of the transform of a
    2^N x 2^N box (pixel-window deconvolution) and transformed back. The WCS
    is rescaled, and CRPIX adjusted so that the sky position of the original
    pixel (0, 0) maps to pixel (0, 0) of the new map.

    Parameters
    ----------
    m : liteMap
        The liteMap object holding the data to upgrade the pixel size of.
    N : int, optional
        Go to 2^N times smaller pixels. Default is 1.
    threads : int, optional
        Number of threads to use in pyFFTW calculations. Default is 1.

    Returns
    -------
    mNew : liteMap
        The map with smaller pixels. ``m`` itself is not modified.
    """
    if N < 1:
        return m.copy()

    Ny = m.Ny * 2 ** N
    Nx = m.Nx * 2 ** N

    def _fft2(a):
        # pass the thread count only when pyFFTW backs fft2
        return fft2(a, threads=threads) if have_pyFFTW else fft2(a)

    def _ifft2(a):
        return ifft2(a, threads=threads) if have_pyFFTW else ifft2(a)

    ftShifted = fftshift(_fft2(m.data))
    newFtShifted = numpy.zeros((Ny, Nx), dtype=numpy.complex128)

    # After fftshift the zero frequency of an n-point axis sits at index
    # n // 2; ifftshift of the padded grid expects it at index Ny // 2
    # (Nx // 2). Aligning the two puts the extra zero at the beginning for odd
    # map dimensions. Floor division keeps the offsets integers.
    offsetY = Ny // 2 - m.Ny // 2
    offsetX = Nx // 2 - m.Nx // 2

    newFtShifted[offsetY : offsetY + m.Ny, offsetX : offsetX + m.Nx] = ftShifted
    del ftShifted
    ftNew = ifftshift(newFtShifted)
    del newFtShifted

    # Finally, deconvolve by the pixel window: a centred 2^N x 2^N box of
    # unit total weight.
    half = 2 ** (N - 1)
    mPix = numpy.zeros((Ny, Nx))
    mPix[Ny // 2 - half : Ny // 2 + half,
         Nx // 2 - half : Nx // 2 + half] = (1.0 / (2.0 ** N) ** 2)
    ftPix = _fft2(mPix)
    del mPix

    inds = numpy.where(ftNew != 0)
    ftNew[inds] /= numpy.abs(ftPix[inds])
    del ftPix

    newData = _ifft2(ftNew) * (2 ** N) ** 2
    del ftNew

    x0_new, y0_new = m.pixToSky(0, 0)

    m = m.copy()  # don't overwrite original
    m.wcs.header.update("NAXIS1", 2 ** N * m.wcs.header["NAXIS1"])
    m.wcs.header.update("NAXIS2", 2 ** N * m.wcs.header["NAXIS2"])
    m.wcs.header.update("CDELT1", m.wcs.header["CDELT1"] / 2.0 ** N)
    m.wcs.header.update("CDELT2", m.wcs.header["CDELT2"] / 2.0 ** N)
    m.wcs.updateFromHeader()

    p_x, p_y = m.skyToPix(x0_new, y0_new)

    m.wcs.header.update("CRPIX1", m.wcs.header["CRPIX1"] - p_x)
    m.wcs.header.update("CRPIX2", m.wcs.header["CRPIX2"] - p_y)
    m.wcs.updateFromHeader()

    realData = numpy.real(newData)
    mNew = liteMapFromDataAndWCS(realData, m.wcs)
    mNew.data[:] = realData[:]
    return mNew
Example #36
0
 def process_frames(self, data):
     """Filter a band of rows of the centred spectrum of ``data[0]``.

     ``data[0]`` is transformed with ``self.fft_object`` and ``fftshift``-ed;
     rows ``self.row1:self.row2`` are multiplied by ``self.filtercomplex``;
     the spectrum is ``ifftshift``-ed back and passed to
     ``self.ifft_object``, whose real part is returned.
     """
     band = slice(self.row1, self.row2)
     spectrum = fft.fftshift(self.fft_object(data[0]))
     spectrum[band] = spectrum[band] * self.filtercomplex
     restored = self.ifft_object(fft.ifftshift(spectrum))
     return restored.real
Example #37
0
def fft_convolve2d(x, f, mode='same', boundary='constant', fft_filter=False):
    r"""
    Performs fast 2d convolution in the frequency domain convolving each image
    channel with its corresponding filter channel.

    Parameters
    ----------
    x : ``(channels, height, width)`` `ndarray`
        Image.
    f : ``(channels, height, width)`` `ndarray`
        Filter.
    mode : str {`full`, `same`, `valid`}, optional
        Determines the shape of the resulting convolution.
    boundary: str {`constant`, `symmetric`}, optional
        Determines how the image is padded.
    fft_filter: `bool`, optional
        If `True`, the filter is assumed to be defined on the frequency
        domain. If `False` the filter is assumed to be defined on the
        spatial domain.

    Returns
    -------
    c: ``(channels, height, width)`` `ndarray`
        Result of convolving each image channel with its corresponding
        filter channel.

    Raises
    ------
    ValueError
        If ``mode`` is not one of ``'full'``, ``'same'`` or ``'valid'``.
    """
    # Validate up front (and by equality, not identity) so no FFT work is
    # wasted on an unsupported mode.
    if mode not in ('full', 'same', 'valid'):
        raise ValueError("mode={}, is not supported. The only supported "
                         "modes are: 'full', 'same' and 'valid'.".format(mode))

    # spatial shape of the image; needed for cropping in both branches
    x_shape = np.asarray(x.shape[-2:])

    if fft_filter:
        # extended shape is filter shape
        ext_shape = np.asarray(f.shape[-2:])
        # The spatial filter size is not known here; infer its half-size from
        # the same relation used below (ext_shape = x_shape + f_half_shape - 1).
        # NOTE(review): only affects mode='valid' -- confirm this is intended.
        f_half_shape = ext_shape - x_shape + 1

        # extend image
        ext_x = pad(x, ext_shape, boundary=boundary)

        # compute ffts of extended image; the filter is already in Fourier space
        fft_ext_x = fft2(ext_x)
        fft_ext_f = f
    else:
        # extended shape
        f_shape = np.asarray(f.shape[-2:])
        f_half_shape = (f_shape / 2).astype(int)
        ext_shape = x_shape + f_half_shape - 1

        # extend image and filter
        ext_x = pad(x, ext_shape, boundary=boundary)
        ext_f = pad(f, ext_shape)

        # compute ffts of extended image and extended filter
        fft_ext_x = fft2(ext_x)
        fft_ext_f = fft2(ext_f)

    # compute extended convolution in Fourier domain
    fft_ext_c = fft_ext_f * fft_ext_x

    # compute ifft of extended convolution
    ext_c = np.real(ifftshift(ifft2(fft_ext_c), axes=(-2, -1)))

    if mode == 'full':
        return ext_c
    if mode == 'same':
        return crop(ext_c, x_shape)
    # mode == 'valid'
    return crop(ext_c, x_shape - f_half_shape + 1)