Esempio n. 1
0
    def _get_G_full(self, size):
        """Slice the data, move it to the GPU and FFT it along the two scan axes.

        Side effects:
            * ``self.dc`` is set to ``self.data[self.slic]``.
            * ``self.Qy1d``/``self.Qx1d`` (scan frequencies, sampling ``self.dxy``, not shifted) and
              ``self.Ky``/``self.Kx`` (frequencies over the last two axes of ``self.dc``, sampling
              ``self.r_min``, fft-shifted) are stored as float32 CuPy arrays.
            * Zeroed complex64 buffers ``self.Psi_Qp``, ``self.Psi_Qp_left_sb``,
              ``self.Psi_Qp_right_sb``, ``self.Psi_Rp``, ``self.Psi_Rp_left_sb`` and
              ``self.Psi_Rp_right_sb`` of shape ``self.scan_dimensions`` are allocated.
            * Progress messages are appended to ``self.out``.

        Args:
            size: not used by this method; kept for interface compatibility.

        Returns:
            ``G``: the FFT over axes (0, 1) of the complex64-cast data, divided by
            ``sqrt(N)`` where ``N`` is the product of the first two dimensions
            (i.e. orthonormal scaling).
        """
        self.dc = self.data[self.slic]
        self.out.append_stdout(f"Data shape is {self.dc.shape}\n")

        self.Qy1d, self.Qx1d = get_qx_qy_1d(self.scan_dimensions,
                                            self.dxy,
                                            fft_shifted=False)
        self.Ky, self.Kx = get_qx_qy_1d(self.dc.shape[-2:],
                                        self.r_min,
                                        fft_shifted=True)

        self.Kx = cp.array(self.Kx, dtype=cp.float32)
        self.Ky = cp.array(self.Ky, dtype=cp.float32)
        self.Qy1d = cp.array(self.Qy1d, dtype=cp.float32)
        self.Qx1d = cp.array(self.Qx1d, dtype=cp.float32)

        self.Psi_Qp = cp.zeros(self.scan_dimensions, dtype=np.complex64)
        self.Psi_Qp_left_sb = cp.zeros(self.scan_dimensions,
                                       dtype=np.complex64)
        self.Psi_Qp_right_sb = cp.zeros(self.scan_dimensions,
                                        dtype=np.complex64)
        self.Psi_Rp = cp.zeros(self.scan_dimensions, dtype=np.complex64)
        self.Psi_Rp_left_sb = cp.zeros(self.scan_dimensions,
                                       dtype=np.complex64)
        self.Psi_Rp_right_sb = cp.zeros(self.scan_dimensions,
                                        dtype=np.complex64)

        M = cp.array(self.dc, dtype=cp.complex64)
        start = time.perf_counter()
        G = fft.fft2(M, axes=(0, 1), overwrite_x=True)
        G /= cp.sqrt(np.prod(G.shape[:2]))
        # GPU work is launched asynchronously; wait for it to finish so the
        # reported time covers the FFT itself, not just the kernel launch.
        cp.cuda.Device().synchronize()

        self.out.append_stdout(
            f"FFT along scan coordinate took {time.perf_counter() - start:2.2g}s\n"
        )
        return G
Esempio n. 2
0
def weak_phase_reconstruction(dc: DataCube, aberrations=None, verbose=False, use_cuda=True):
    r"""
    Perform a ptychographic reconstruction of the datacube assuming a weak phase object.

    In the weak phase object approximation, the dataset in double Fourier-space
    coordinates can be described as [1]::

        G(r',\rho') = |A(r')|^2 \delta(\rho') + A(r')A*(r'+\rho')Ψ*(-\rho')+ A*(r')A(r'-\rho')Ψ(\rho')

    We solve this equation for Ψ*(\rho') in two different ways:

    1) collect all the signal in the bright-field by multiplying G with [2]::

        A(r')A*(r'+\rho')+ A*(r')A(r'-\rho')

    2) collect only the signal in the double-overlap region [1]

    References:
        * [1] Rodenburg, J. M., McCallum, B. C. & Nellist, P. D. Experimental tests on
          double-resolution coherent imaging via STEM. Ultramicroscopy 48, 304–314 (1993).
        * [2] Yang, H., Ercius, P., Nellist, P. D. & Ophus, C. Enhanced phase contrast
          transfer using ptychography combined with a pre-specimen phase plate in a
          scanning transmission electron microscope. Ultramicroscopy 171, 117–125 (2016).

    Args:
        dc: py4DSTEM datacube. ``dc.metadata.microscope`` must contain ``beam_energy`` and
            ``convergence_semiangle_mrad``; ``dc.metadata.calibration`` must contain
            ``Q_pixel_size``, ``R_pixel_size``, ``QR_rotation`` and ``QR_rotation_units``.
            ``dc.data`` is unpacked as (ny, nx, nky, nkx): scan axes first, detector axes last.
        aberrations: optional array shape (12,), cartesian aberration coefficients.
            Defaults to all zeros.
        verbose: optional bool, default: False. Print parameters, timings and progress.
        use_cuda: optional bool, default: True. Run on the GPU; only takes effect when
            CuPy is enabled in ``config``.

    Returns:
        (Psi_Rp, Psi_Rp_left_sb, Psi_Rp_right_sb), host arrays of shape (ny, nx).
        Psi_Rp is the result of method 1) and Psi_Rp_left_sb, Psi_Rp_right_sb are the results
        of method 2)

    Raises:
        AssertionError: if a required metadata key is missing.
    """

    assert 'beam_energy' in dc.metadata.microscope, 'metadata.microscope dictionary missing key: beam_energy'
    assert 'convergence_semiangle_mrad' in dc.metadata.microscope, 'metadata.microscope dictionary missing key: convergence_semiangle_mrad'

    assert 'Q_pixel_size' in dc.metadata.calibration, 'metadata.calibration dictionary missing key: Q_pixel_size'
    assert 'R_pixel_size' in dc.metadata.calibration, 'metadata.calibration dictionary missing key: R_pixel_size'
    assert 'QR_rotation' in dc.metadata.calibration, 'metadata.calibration dictionary missing key: QR_rotation'
    assert 'QR_rotation_units' in dc.metadata.calibration, 'metadata.calibration dictionary missing key: QR_rotation_units'

    M = dc.data

    ny, nx, nky, nkx = M.shape

    E = dc.metadata.microscope['beam_energy']
    alpha_rad = dc.metadata.microscope['convergence_semiangle_mrad'] * 1e-3
    lam = electron_wavelength_angstrom(E)
    # Threshold on |Γ| below which a detector pixel is ignored (CPU path); also passed to the kernel.
    eps = 1e-3
    k_max = dc.metadata.calibration['Q_pixel_size']
    dxy = dc.metadata.calibration['R_pixel_size']
    theta = dc.metadata.calibration['QR_rotation']
    if dc.metadata.calibration['QR_rotation_units'] == 'deg':
        theta = np.deg2rad(theta)

    # Use the GPU only if the caller asked for it AND CuPy is enabled.
    cuda_is_available = use_cuda and config.cupy_enabled

    if verbose:
        print(f"E               = {E}             eV")
        print(f"λ               = {lam * 1e2:2.2}   pm")
        print(f"dR              = {dxy}             Å")
        print(f"dK              = {k_max}           Å")
        print(f"scan       size = {[ny, nx]}")
        print(f"detector   size = {[nky, nkx]}")

    if cuda_is_available:
        M = cp.array(M, dtype=M.dtype)

    xp = sp.get_array_module(M)

    Kx, Ky = get_qx_qy_1d([nkx, nky], k_max, fft_shifted=True)
    Qx, Qy = get_qx_qy_1d([nx, ny], dxy, fft_shifted=False)

    Kx = Kx.astype(M.dtype)
    Ky = Ky.astype(M.dtype)
    Qx = Qx.astype(M.dtype)
    Qy = Qy.astype(M.dtype)

    ap = aperture3(Kx, Ky, lam, alpha_rad).astype(xp.float32)
    scale = 1  # math.sqrt(mean_intensity / aperture_intensity)
    ap *= scale

    start = time.perf_counter()

    G = xp.fft.fft2(M, axes=(0, 1), norm='ortho')
    if cuda_is_available:
        # CuPy launches are asynchronous; wait so the timing reflects the FFT itself.
        cp.cuda.Device().synchronize()
    end = time.perf_counter()
    if verbose:
        print(f"FFT along scan coordinate took {end - start}s")

    if aberrations is None:
        aberrations = xp.zeros((12))

    Psi_Qp = xp.zeros((ny, nx), dtype=G.dtype)
    Psi_Qp_left_sb = xp.zeros((ny, nx), dtype=xp.complex64)
    Psi_Qp_right_sb = xp.zeros((ny, nx), dtype=xp.complex64)

    start = time.perf_counter()
    if cuda_is_available:
        threadsperblock = 2 ** 8
        blockspergrid = m.ceil(np.prod(G.shape) / threadsperblock)
        # Element (not byte) strides of G. `np.int` was removed in NumPy 1.24, so use an
        # explicit integer dtype; strides are exact multiples of the item size.
        strides = cp.array(np.array(G.strides) // G.itemsize, dtype=np.int64)

        single_sideband_kernel_cartesian[blockspergrid, threadsperblock](G, strides, Qx, Qy, Kx, Ky, aberrations,
                                                                         theta, alpha_rad, Psi_Qp, Psi_Qp_left_sb,
                                                                         Psi_Qp_right_sb, eps, lam, scale)
        # Kernel launches are asynchronous; wait for completion before timing.
        cp.cuda.Device().synchronize()
    else:
        def get_qx_qy(M, dx, fft_shifted=False):
            qxa = fftfreq(M[0], dx[0])
            qya = fftfreq(M[1], dx[1])
            [qxn, qyn] = np.meshgrid(qxa, qya)
            if fft_shifted:
                qxn = fftshift(qxn)
                qyn = fftshift(qyn)
            return qxn, qyn

        # NOTE(review): unlike the CUDA kernel, this path does not use `theta`
        # (the QR rotation) — confirm whether that is intended.
        Kx, Ky = get_qx_qy([nkx, nky], k_max, fft_shifted=True)
        # reciprocal in scanning space
        Qx, Qy = get_qx_qy([nx, ny], dxy)

        Kplus = np.sqrt((Kx + Qx[:, :, None, None]) ** 2 + (Ky + Qy[:, :, None, None]) ** 2)
        Kminus = np.sqrt((Kx - Qx[:, :, None, None]) ** 2 + (Ky - Qy[:, :, None, None]) ** 2)
        K = np.sqrt(Kx ** 2 + Ky ** 2)

        A_KplusQ = np.zeros_like(G)
        A_KminusQ = np.zeros_like(G)

        # Use the caller's aberrations, as the CUDA path does (previously hard-coded to zero here).
        C = np.asarray(aberrations)
        A = np.exp(1j * cartesian_aberrations(Kx, Ky, lam, C)) * aperture_xp(Kx, Ky, lam, alpha_rad, edge=0)

        if verbose:
            print('Creating aperture overlap functions')
        for ix, qx in enumerate(Qx[0]):
            if verbose:
                print(f"{ix} / {Qx[0].shape}")
            for iy, qy in enumerate(Qy[:, 0]):
                x = Kx + qx
                y = Ky + qy
                A_KplusQ[iy, ix] = np.exp(1j * cartesian_aberrations(x, y, lam, C)) * aperture_xp(x, y, lam, alpha_rad,
                                                                                                  edge=0)

                x = Kx - qx
                y = Ky - qy
                A_KminusQ[iy, ix] = np.exp(1j * cartesian_aberrations(x, y, lam, C)) * aperture_xp(x, y, lam, alpha_rad,
                                                                                                   edge=0)

        # [1] Equ. (4): Γ = A*(Kf)A(Kf-Qp) - A(Kf)A*(Kf+Qp)
        Gamma = A.conj() * A_KminusQ - A * A_KplusQ.conj()

        # Masks of pixels inside the central disk and exactly one of the two shifted disks.
        double_overlap1 = (Kplus < alpha_rad / lam) * (K < alpha_rad / lam) * (Kminus > alpha_rad / lam)
        double_overlap2 = (Kplus > alpha_rad / lam) * (K < alpha_rad / lam) * (Kminus < alpha_rad / lam)

        Psi_Qp = np.zeros((ny, nx), dtype=np.complex64)
        Psi_Qp_left_sb = np.zeros((ny, nx), dtype=np.complex64)
        Psi_Qp_right_sb = np.zeros((ny, nx), dtype=np.complex64)
        if verbose:
            print(f"Now summing over K-space.")
        for y in trange(ny):
            for x in range(nx):
                Γ_abs = np.abs(Gamma[y, x])
                take = Γ_abs > eps
                Psi_Qp[y, x] = np.sum(G[y, x][take] * Gamma[y, x][take].conj())
                Psi_Qp_left_sb[y, x] = np.sum(G[y, x][double_overlap1[y, x]])
                Psi_Qp_right_sb[y, x] = np.sum(G[y, x][double_overlap2[y, x]])

                # direct beam at zero spatial frequency
                if x == 0 and y == 0:
                    Psi_Qp[y, x] = np.sum(np.abs(G[y, x]))
                    Psi_Qp_left_sb[y, x] = np.sum(np.abs(G[y, x]))
                    Psi_Qp_right_sb[y, x] = np.sum(np.abs(G[y, x]))

    end = time.perf_counter()
    if verbose:
        print(f"SSB took {end - start}")

    Psi_Rp = xp.fft.ifft2(Psi_Qp, norm='ortho')
    Psi_Rp_left_sb = xp.fft.ifft2(Psi_Qp_left_sb, norm='ortho')
    Psi_Rp_right_sb = xp.fft.ifft2(Psi_Qp_right_sb, norm='ortho')

    if cuda_is_available:
        # Copy results back to host memory.
        Psi_Rp = Psi_Rp.get()
        Psi_Rp_left_sb = Psi_Rp_left_sb.get()
        Psi_Rp_right_sb = Psi_Rp_right_sb.get()

    return Psi_Rp, Psi_Rp_left_sb, Psi_Rp_right_sb
Esempio n. 3
0
def find_rotation_angle_with_double_disk_overlap(G, lam, dx, dscan, alpha_rad, mask=None, n_fit=6, ranges=(360, 30),
                                                 partitions=(144, 120), verbose=False, manual_frequencies=None,
                                                 aberrations=None):
    """
    Find the best rotation angle by maximizing the double disk overlap intensity of the 4D dataset.

    Only valid for datasets where the scan step size is roughly on the same length scale as the
    illumination half-angle alpha.

    The search is coarse-to-fine. For each ``(range, parts)`` pair taken from ``ranges`` and
    ``partitions``, ``parts`` angles spanning ``range`` degrees are evaluated. Each stage is
    centred on the best angle found so far (starting at 0), and its best angle becomes the
    centre of the next stage.

    Args:
        G (th.tensor, float, 4D): (NY, NX, MY, MX) 4D tensor of disk overlap functions
        lam (float): wavelength in Angstrom
        dx (float, float): 1/(2 * k_max) real_space sampling determined from maximum sampled detector angle, in angstrom
        dscan (float, float): real-space sampling of the scan, in angstrom
        alpha_rad (float): convergence semi-angle in rad
        mask (th.tensor, float, 2D): (NY, NX) mask multiplied onto sum(|G|) over the detector axes
            before choosing the strongest spatial frequencies (unused with ``manual_frequencies``)
        n_fit (int): number of strongest spatial frequencies ("trotters") to use for summation
        ranges (sequence of float): angle range in degrees searched in each stage, default (360, 30)
        partitions (sequence of int): number of angles evaluated in each stage, default (144, 120)
        verbose (bool): optional, talk to me or not
        manual_frequencies (pair of index sequences): optional ``(iy, ix)`` indices into (NY, NX),
            in the form returned by ``np.unravel_index``, that pick out spatial frequencies at which
            the G-function has bragg-peaks/maxima
        aberrations (th.tensor, float, 1D): (12,) aberration coefficients, default all zeros

    Returns:
        tuple (max_ind, thetas, intensities): ``thetas`` (radians) and ``intensities`` of the last
        stage. ``max_ind`` indexes the maximum intensity, so ``thetas[max_ind]`` is the best STEM
        rotation angle.

    Raises:
        ValueError: if ``ranges`` or ``partitions`` is empty, so that no angle is evaluated.
    """
    ny, nx, nky, nkx = G.shape
    xp = sp.backend.get_array_module(G)

    Kx, Ky = get_qx_qy_1d([nkx, nky], dx, fft_shifted=True)
    Qx, Qy = get_qx_qy_1d([nx, ny], dscan, fft_shifted=False)

    # Real floating-point dtype matching G's precision.
    real_dtype = G[0, 0, 0, 0].real.dtype
    Kx = xp.array(Kx, dtype=real_dtype)
    Ky = xp.array(Ky, dtype=real_dtype)
    Qx = xp.array(Qx, dtype=real_dtype)
    Qy = xp.array(Qy, dtype=real_dtype)

    if aberrations is None:
        aberrations = xp.zeros((12))

    if manual_frequencies is None:
        Gabs = xp.sum(xp.abs(G), (2, 3))
        if mask is not None:
            Gabs = Gabs * mask
        inds = xp.argsort(Gabs.ravel())
        if xp is not np:
            # GPU array: copy indices to the host for np.unravel_index. NumPy arrays have no .get().
            inds = inds.get()
        # Take the n_fit next-strongest frequencies, skipping the single strongest one
        # (NOTE(review): presumably the q = 0 / DC term — confirm).
        strongest_object_frequencies = np.unravel_index(inds[-1 - n_fit:-1], G.shape[:2])
    else:
        strongest_object_frequencies = manual_frequencies

    G_max = G[strongest_object_frequencies]
    Qy_max = Qy[strongest_object_frequencies[0]]
    Qx_max = Qx[strongest_object_frequencies[1]]

    if verbose:
        print(f"strongest_object_frequencies: {strongest_object_frequencies}")

    best_angle = 0
    thetas = intensities = sortind = None

    for angle_range_deg, n_parts in zip(ranges, partitions):
        half_width = np.deg2rad(angle_range_deg / 2)
        thetas = np.linspace(best_angle - half_width, best_angle + half_width, n_parts)
        intensities = double_overlap_intensitities_in_range(G_max, thetas, Qx_max, Qy_max, Kx, Ky, aberrations,
                                                            alpha_rad, lam)

        sortind = np.argsort(intensities)
        best_angle = thetas[sortind[-1]]

    if sortind is None:
        raise ValueError("ranges and partitions must each contain at least one entry")

    max_ind = sortind[-1]

    return max_ind, thetas, intensities