Example #1
0
def aperture3(qx, qy, lam, alpha_max):
    """
    Build a hard-edged aperture mask on the grid spanned by qx and qy.

    Args:
        qx: array of x components of the spatial frequency
        qy: array of y components, combined element-wise with qx
        lam: wavelength; the product q * lam is passed to arcsin
        alpha_max: cutoff angle in rad

    Returns:
        Boolean array that is True where
        arcsin(sqrt(qx**2 + qy**2) * lam) < alpha_max.
    """
    # Dispatch to the array library that owns qx.
    backend = cp.get_array_module(qx)
    # Radial spatial frequency of every sample.
    q_radius = backend.sqrt(qx**2 + qy**2)
    # Angle that corresponds to that radial frequency.
    angle = backend.arcsin(q_radius * lam)
    inside = angle < alpha_max
    return inside
Example #2
0
def aperture_xp(qx, qy, lam, alpha_max, edge=2):
    """
    Return an aperture mask on the grid (qx, qy), optionally with a smoothed edge.

    The mask is 1 where arcsin(sqrt(qx**2 + qy**2) * lam) < alpha_max and 0
    elsewhere. When edge > 0, samples whose angle lies in the relative band
    (1 - dEdge, 1 + dEdge) around alpha_max are replaced by a sine-shaped ramp
    that falls from 1 to 0 across the band.

    Args:
        qx: 2D array of Qx components; adjacent samples along the last axis
            are used to measure the grid step when edge > 0
        qy: array of Qy components, combined element-wise with qx
        lam (float): wavelength in angstrom
        alpha_max (float): convergence semi-angle in rad
        edge: width of the smoothed edge in grid steps; 0 gives a sharp aperture

    Returns:
        Array with the shape and dtype of qx holding the aperture values.
    """

    xp = sp.get_array_module(qx)
    q = xp.sqrt(qx ** 2 + qy ** 2)
    ktheta = xp.arcsin(q * lam)

    arr = xp.zeros_like(qx)
    arr[ktheta < alpha_max] = 1
    if edge > 0:
        qmax = alpha_max / lam
        # Grid step from two adjacent samples. Using qx[0][1] alone is only
        # correct when qx[0][0] == 0; on an fft-shifted grid it is a negative
        # frequency, which made dEdge negative and disabled the smoothing.
        dk = abs(qx[0][1] - qx[0][0])
        # fraction of the aperture radius that will be smoothed
        dEdge = edge / (qmax / dk)
        rel = ktheta / alpha_max
        # elements that fall inside the smoothing band around the aperture rim
        ind = (rel > (1 - dEdge)) & (rel < (1 + dEdge))
        arr[ind] = 0.5 * (1 - xp.sin(np.pi / (2 * dEdge) * (rel[ind] - 1)))
    return arr
Example #3
0
def aperture_xp(qx, qy, lam, alpha_max, edge=2):
    """
    Return an aperture mask on the grid (qx, qy), optionally with a smoothed edge.

    The mask is 1 where arcsin(sqrt(qx**2 + qy**2) * lam) < alpha_max and 0
    elsewhere. When edge > 0, samples whose angle lies in the relative band
    (1 - dEdge, 1 + dEdge) around alpha_max are replaced by a sine-shaped ramp
    that falls from 1 to 0 across the band.

    Args:
        qx: 2D array of Qx components; adjacent samples along the last axis
            are used to measure the grid step when edge > 0
        qy: array of Qy components, combined element-wise with qx
        lam: wavelength; the product q * lam is passed to arcsin
        alpha_max: cutoff angle in rad
        edge: width of the smoothed edge in grid steps; 0 gives a sharp aperture

    Returns:
        Array with the shape and dtype of qx holding the aperture values.
    """
    xp = cp.get_array_module(qx)
    q = xp.sqrt(qx**2 + qy**2)
    ktheta = xp.arcsin(q * lam)

    arr = xp.zeros_like(qx)
    arr[ktheta < alpha_max] = 1
    if edge > 0:
        qmax = alpha_max / lam
        # Grid step from two adjacent samples. Using qx[0][1] alone is only
        # correct when qx[0][0] == 0; on an fft-shifted grid it is a negative
        # frequency, which made dEdge negative and disabled the smoothing.
        dk = abs(qx[0][1] - qx[0][0])
        # fraction of the aperture radius that will be smoothed
        dEdge = edge / (qmax / dk)
        rel = ktheta / alpha_max
        # elements that fall inside the smoothing band around the aperture rim
        ind = (rel > (1 - dEdge)) & (rel < (1 + dEdge))
        arr[ind] = 0.5 * (1 - xp.sin(np.pi / (2 * dEdge) * (rel[ind] - 1)))
    return arr
Example #4
0
def aperture3(qx, qy, lam, alpha_max):
    """
        Build a sharp aperture mask from spatial-frequency components.

        Args:
            qx (float): Qx components
            qy (float): Qy components, combined element-wise with qx
            lam (float): wavelength in angstrom
            alpha_max (float): convergence semi-angle in rad

        Returns:
            bool array, True where arcsin(sqrt(qx**2 + qy**2) * lam) < alpha_max
    """

    # Pick the array library that owns qx so the math stays on its device.
    array_lib = sp.get_array_module(qx)
    q_squared = qx ** 2 + qy ** 2
    # Convert the radial frequency into an angle before thresholding.
    angle = array_lib.arcsin(array_lib.sqrt(q_squared) * lam)
    within_aperture = angle < alpha_max
    return within_aperture
Example #5
0
def weak_phase_reconstruction(dc: DataCube, verbose=False, use_cuda=True):
    r"""
    Perform a ptychographic reconstruction of the datacube assuming a weak phase object. In the weak phase object
    approximation, the dataset in double Fourier-space coordinates can be described as [1]

    G(r',\rho') = |A(r')|^2 \delta(\rho') + A(r')A*(r'+\rho')Ψ*(-\rho')+ A*(r')A(r'-\rho')Ψ(\rho')

    We solve this equation for Ψ*(\rho') in two different ways:

    1) collect all the signal in the bright-field by multiplying G with A(r')A*(r'+\rho')+ A*(r')A(r'-\rho')[2]
    2) collect only the signal in the double-overlap region [1]

    References:
    [1] Rodenburg, J. M., McCallum, B. C. & Nellist, P. D. Experimental tests on double-resolution coherent imaging via
        STEM. Ultramicroscopy 48, 304–314 (1993).
    [2] Yang, H., Ercius, P., Nellist, P. D. & Ophus, C. Enhanced phase contrast transfer using ptychography combined
        with a pre-specimen phase plate in a scanning transmission electron microscope. Ultramicroscopy 171, 117–125 (2016).

    :param dc: py4DSTEM datacube; dc.data is unpacked as (ny, nx, nky, nkx) and dc.metadata must contain the
            microscope and calibration keys asserted at the top of the function
    :param verbose: if True, print the experimental parameters and plot the scan-averaged diffraction pattern
    :param use_cuda: if True and torch reports an available CUDA device, run the GPU kernel; otherwise the
            NumPy implementation is used
    :return: (Ψ_Rp, Ψ_Rp_left_sb, Ψ_Rp_right_sb)
            Ψ_Rp is the result of method 1) and Ψ_Rp_left_sb, Ψ_Rp_right_sb are the results of method 2)
    """

    assert 'accelerating_voltage' in dc.metadata.microscope, 'metadata.microscope dictionary missing key: accelerating_voltage'
    assert 'convergence_semiangle_mrad' in dc.metadata.microscope, 'metadata.microscope dictionary missing key: convergence_semiangle_mrad'

    assert 'K_pix_size' in dc.metadata.calibration, 'metadata.calibration dictionary missing key: K_pix_size'
    assert 'R_pix_size' in dc.metadata.calibration, 'metadata.calibration dictionary missing key: R_pix_size'
    assert 'R_to_K_rotation_degrees' in dc.metadata.calibration, 'metadata.calibration dictionary missing key: R_to_K_rotation_degrees'

    M = dc.data
    ny, nx, nky, nkx = M.shape

    E = dc.metadata.microscope['accelerating_voltage']
    alpha_rad = dc.metadata.microscope['convergence_semiangle_mrad'] * 1e-3
    lam = electron_wavelength_angstrom(E)
    # threshold on |Γ| below which detector samples are left out of the sum
    eps = 1e-3
    k_max = dc.metadata.calibration['K_pix_size']
    dxy = dc.metadata.calibration['R_pix_size']
    theta = np.deg2rad(dc.metadata.calibration['R_to_K_rotation_degrees'])

    cuda_is_available = th.cuda.is_available() if use_cuda else False

    if verbose:
        print(f"E               = {E}             eV")
        print(f"λ               = {lam * 1e2:2.2}   pm")
        print(f"dR              = {dxy}             Å")
        print(f"dK              = {k_max}           Å")
        print(f"scan       size = {[ny, nx]}")
        print(f"detector   size = {[nky, nkx]}")

    if cuda_is_available:
        # move the data to the GPU; xp below then resolves to cupy
        M = cp.array(M, dtype=M.dtype)

    xp = cp.get_array_module(M)

    def get_qx_qy_1D(M, dx, dtype, fft_shifted=False):
        # 1D frequency axes for a grid with M[0] x M[1] samples and steps dx[0], dx[1]
        qxa = xp.fft.fftfreq(M[0], dx[0]).astype(dtype)
        qya = xp.fft.fftfreq(M[1], dx[1]).astype(dtype)
        if fft_shifted:
            qxa = xp.fft.fftshift(qxa)
            qya = xp.fft.fftshift(qya)
        return qxa, qya

    Kx, Ky = get_qx_qy_1D([nkx, nky], k_max, M.dtype, fft_shifted=True)
    Qx, Qy = get_qx_qy_1D([nx, ny], dxy, M.dtype, fft_shifted=False)

    # diffraction pattern averaged over both scan axes
    pacbed = xp.mean(M, (0, 1))
    mean_intensity = xp.sum(pacbed)
    print(mean_intensity)
    ap = aperture3(Kx, Ky, lam, alpha_rad).astype(xp.float32)
    aperture_intensity = float(xp.sum(ap))
    print(aperture_intensity)
    scale = 1  # math.sqrt(mean_intensity / aperture_intensity)
    ap *= scale

    if verbose:
        if cuda_is_available:
            plot(pacbed.get(), 'PACBED')
        else:
            plot(pacbed, 'PACBED')

    start = time.perf_counter()

    # Fourier transform over the two scan axes only
    G = xp.fft.fft2(M, axes=(0, 1), norm='ortho')
    end = time.perf_counter()
    print(f"FFT along scan coordinate took {end - start}s")

    aberrations = xp.zeros((16))
    aberration_angles = xp.zeros((12))

    Ψ_Qp = xp.zeros((ny, nx), dtype=G.dtype)
    Ψ_Qp_left_sb = xp.zeros((ny, nx), dtype=np.complex64)
    Ψ_Qp_right_sb = xp.zeros((ny, nx), dtype=np.complex64)

    start = time.perf_counter()
    if cuda_is_available:
        threadsperblock = 2**8
        blockspergrid = m.ceil(np.prod(G.shape) / threadsperblock)
        # Strides of G in elements rather than bytes. np.int was removed in
        # NumPy 1.24; the builtin int selects the same dtype.
        strides = cp.array(
            (np.array(G.strides) / (G.nbytes / G.size)).astype(int))

        single_sideband_kernel[blockspergrid,
                               threadsperblock](G, strides, Qx, Qy, Kx, Ky,
                                                aberrations, aberration_angles,
                                                theta, alpha_rad, Ψ_Qp,
                                                Ψ_Qp_left_sb, Ψ_Qp_right_sb,
                                                eps, lam, scale)
    else:

        def get_qx_qy(M, dx, fft_shifted=False):
            # 2D frequency grids built from the 1D axes via meshgrid
            qxa = fftfreq(M[0], dx[0])
            qya = fftfreq(M[1], dx[1])
            [qxn, qyn] = np.meshgrid(qxa, qya)
            if fft_shifted:
                qxn = fftshift(qxn)
                qyn = fftshift(qyn)
            return qxn, qyn

        Kx, Ky = get_qx_qy([nkx, nky], k_max, fft_shifted=True)
        # reciprocal in scanning space
        Qx, Qy = get_qx_qy([nx, ny], dxy)

        # |K + Q|, |K - Q| for every scan frequency (leading axes) and
        # detector sample (trailing axes), plus |K| on the detector grid
        Kplus = np.sqrt((Kx + Qx[:, :, None, None])**2 +
                        (Ky + Qy[:, :, None, None])**2)
        Kminus = np.sqrt((Kx - Qx[:, :, None, None])**2 +
                         (Ky - Qy[:, :, None, None])**2)
        K = np.sqrt(Kx**2 + Ky**2)

        A_KplusQ = np.zeros_like(G)
        A_KminusQ = np.zeros_like(G)

        # all-zero aberration coefficients: the phase term below is exp(0)
        C = np.zeros((12))
        A = np.exp(1j * cartesian_aberrations(Kx, Ky, lam, C)) * aperture_xp(
            Kx, Ky, lam, alpha_rad, edge=0)

        print('Creating aperture overlap functions')
        for ix, qx in enumerate(Qx[0]):
            print(f"{ix} / {Qx[0].shape}")
            for iy, qy in enumerate(Qy[:, 0]):
                # aperture function shifted by +Q
                x = Kx + qx
                y = Ky + qy
                A_KplusQ[iy, ix] = np.exp(1j * cartesian_aberrations(
                    x, y, lam, C)) * aperture_xp(x, y, lam, alpha_rad, edge=0)

                # aperture function shifted by -Q
                x = Kx - qx
                y = Ky - qy
                A_KminusQ[iy, ix] = np.exp(1j * cartesian_aberrations(
                    x, y, lam, C)) * aperture_xp(x, y, lam, alpha_rad, edge=0)

        # [1] Equ. (4): Γ = A*(Kf)A(Kf-Qp) - A(Kf)A*(Kf+Qp)
        Γ = A.conj() * A_KminusQ - A * A_KplusQ.conj()

        # detector samples inside the central disc and exactly one shifted disc
        double_overlap1 = (Kplus < alpha_rad / lam) * (K < alpha_rad / lam) * (
            Kminus > alpha_rad / lam)
        double_overlap2 = (Kplus > alpha_rad / lam) * (K < alpha_rad / lam) * (
            Kminus < alpha_rad / lam)

        Ψ_Qp = np.zeros((ny, nx), dtype=np.complex64)
        Ψ_Qp_left_sb = np.zeros((ny, nx), dtype=np.complex64)
        Ψ_Qp_right_sb = np.zeros((ny, nx), dtype=np.complex64)
        print(f"Now summing over K-space.")
        for y in trange(ny):
            for x in range(nx):
                Γ_abs = np.abs(Γ[y, x])
                take = Γ_abs > eps
                Ψ_Qp[y, x] = np.sum(G[y, x][take] * Γ[y, x][take].conj())
                Ψ_Qp_left_sb[y, x] = np.sum(G[y, x][double_overlap1[y, x]])
                Ψ_Qp_right_sb[y, x] = np.sum(G[y, x][double_overlap2[y, x]])

                # direct beam at zero spatial frequency
                if x == 0 and y == 0:
                    Ψ_Qp[y, x] = np.sum(np.abs(G[y, x]))
                    Ψ_Qp_left_sb[y, x] = np.sum(np.abs(G[y, x]))
                    Ψ_Qp_right_sb[y, x] = np.sum(np.abs(G[y, x]))

    end = time.perf_counter()
    print(f"SSB took {end - start}")

    # back to real space along the scan axes
    Ψ_Rp = xp.fft.ifft2(Ψ_Qp, norm='ortho')
    Ψ_Rp_left_sb = xp.fft.ifft2(Ψ_Qp_left_sb, norm='ortho')
    Ψ_Rp_right_sb = xp.fft.ifft2(Ψ_Qp_right_sb, norm='ortho')

    if cuda_is_available:
        # copy the results from the GPU back to host memory
        Ψ_Rp = Ψ_Rp.get()
        Ψ_Rp_left_sb = Ψ_Rp_left_sb.get()
        Ψ_Rp_right_sb = Ψ_Rp_right_sb.get()

    return Ψ_Rp, Ψ_Rp_left_sb, Ψ_Rp_right_sb
    ksp = np.load(args.ksp_file, 'r')
    coord = np.load(args.coord_file)
    dcf = np.load(args.dcf_file)
    mps = np.load(args.mps_file, 'r')
    resp = np.load(args.resp_file)

    comm = sp.Communicator()
    if args.multi_gpu:
        device = sp.Device(comm.rank)
    else:
        device = sp.Device(args.device)

    # Split between nodes.
    ksp = ksp[comm.rank::comm.size]
    mps = mps[comm.rank::comm.size]
    mrimg = MotionResolvedRecon(ksp,
                                coord,
                                dcf,
                                mps,
                                resp,
                                args.B,
                                max_iter=args.max_iter,
                                lamda=args.lamda,
                                device=device,
                                comm=comm).run()

    if comm.rank == 0:
        xp = sp.get_array_module(mrimg)
        xp.save(args.mrimg_file, mrimg)
Example #7
0
def weak_phase_reconstruction(dc: DataCube, aberrations=None, verbose=False, use_cuda=True):
    r"""
    Perform a ptychographic reconstruction of the datacube assuming a weak phase object.
    In the weak phase object approximation, the dataset in double Fourier-space
    coordinates can be described as [1]::

        G(r',\rho') = |A(r')|^2 \delta(\rho') + A(r')A*(r'+\rho')Ψ*(-\rho')+ A*(r')A(r'-\rho')Ψ(\rho')

    We solve this equation for Ψ*(\rho') in two different ways:

    1) collect all the signal in the bright-field by multiplying G with::

        A(r')A*(r'+\rho')+ A*(r')A(r'-\rho')[2]

    2) collect only the signal in the double-overlap region [1]

    References:
        * [1] Rodenburg, J. M., McCallum, B. C. & Nellist, P. D. Experimental tests on
          double-resolution coherent imaging via STEM. Ultramicroscopy 48, 304–314 (1993).
        * [2] Yang, H., Ercius, P., Nellist, P. D. & Ophus, C. Enhanced phase contrast
          transfer using ptychography combined with a pre-specimen phase plate in a
          scanning transmission electron microscope. Ultramicroscopy 171, 117–125 (2016).

    Args:
        dc: py4DSTEM datacube; dc.data is unpacked as (ny, nx, nky, nkx) and
            dc.metadata must contain the microscope and calibration keys
            asserted at the top of the function
        aberrations: optional array shape (12,), cartesian aberration coefficients;
            zeros are used when omitted. Applied on both the GPU and the CPU path.
        verbose: optional bool, default: False
        use_cuda: optional bool, default: True. Not read by this function; the
            GPU path is selected by config.cupy_enabled.

    Returns:
        (Psi_Rp, Psi_Rp_left_sb, Psi_Rp_right_sb)
        Psi_Rp is the result of method 1) and Psi_Rp_left_sb, Psi_Rp_right_sb are the results
        of method 2)
    """

    assert 'beam_energy' in dc.metadata.microscope, 'metadata.microscope dictionary missing key: beam_energy'
    assert 'convergence_semiangle_mrad' in dc.metadata.microscope, 'metadata.microscope dictionary missing key: convergence_semiangle_mrad'

    assert 'Q_pixel_size' in dc.metadata.calibration, 'metadata.calibration dictionary missing key: Q_pixel_size'
    assert 'R_pixel_size' in dc.metadata.calibration, 'metadata.calibration dictionary missing key: R_pixel_size'
    assert 'QR_rotation' in dc.metadata.calibration, 'metadata.calibration dictionary missing key: QR_rotation'
    assert 'QR_rotation_units' in dc.metadata.calibration, 'metadata.calibration dictionary missing key: QR_rotation_units'

    M = dc.data

    ny, nx, nky, nkx = M.shape

    E = dc.metadata.microscope['beam_energy']
    alpha_rad = dc.metadata.microscope['convergence_semiangle_mrad'] * 1e-3
    lam = electron_wavelength_angstrom(E)
    # threshold on |Gamma| below which detector samples are left out of the sum
    eps = 1e-3
    k_max = dc.metadata.calibration['Q_pixel_size']
    dxy = dc.metadata.calibration['R_pixel_size']
    theta = dc.metadata.calibration['QR_rotation']
    if dc.metadata.calibration['QR_rotation_units'] == 'deg':
        theta = np.deg2rad(theta)

    cuda_is_available = config.cupy_enabled

    if verbose:
        print(f"E               = {E}             eV")
        print(f"λ               = {lam * 1e2:2.2}   pm")
        print(f"dR              = {dxy}             Å")
        print(f"dK              = {k_max}           Å")
        print(f"scan       size = {[ny, nx]}")
        print(f"detector   size = {[nky, nkx]}")

    if cuda_is_available:
        # move the data to the GPU; xp below then resolves to cupy
        M = cp.array(M, dtype=M.dtype)

    xp = sp.get_array_module(M)

    Kx, Ky = get_qx_qy_1d([nkx, nky], k_max, fft_shifted=True)
    Qx, Qy = get_qx_qy_1d([nx, ny], dxy, fft_shifted=False)

    Kx = Kx.astype(M.dtype)
    Ky = Ky.astype(M.dtype)
    Qx = Qx.astype(M.dtype)
    Qy = Qy.astype(M.dtype)

    scale = 1  # math.sqrt(mean_intensity / aperture_intensity)

    start = time.perf_counter()

    # Fourier transform over the two scan axes only
    G = xp.fft.fft2(M, axes=(0, 1), norm='ortho')
    end = time.perf_counter()
    print(f"FFT along scan coordinate took {end - start}s")

    if aberrations is None:
        aberrations = xp.zeros((12))

    Psi_Qp = xp.zeros((ny, nx), dtype=G.dtype)
    Psi_Qp_left_sb = xp.zeros((ny, nx), dtype=xp.complex64)
    Psi_Qp_right_sb = xp.zeros((ny, nx), dtype=xp.complex64)

    start = time.perf_counter()
    if cuda_is_available:
        threadsperblock = 2 ** 8
        blockspergrid = m.ceil(np.prod(G.shape) / threadsperblock)
        # Strides of G in elements rather than bytes. np.int was removed in
        # NumPy 1.24; the builtin int selects the same dtype.
        strides = cp.array((np.array(G.strides) / (G.nbytes / G.size)).astype(int))

        single_sideband_kernel_cartesian[blockspergrid, threadsperblock](G, strides, Qx, Qy, Kx, Ky, aberrations,
                                                                         theta, alpha_rad, Psi_Qp, Psi_Qp_left_sb,
                                                                         Psi_Qp_right_sb, eps, lam, scale)
    else:
        def get_qx_qy(M, dx, fft_shifted=False):
            # 2D frequency grids built from the 1D axes via meshgrid
            qxa = fftfreq(M[0], dx[0])
            qya = fftfreq(M[1], dx[1])
            [qxn, qyn] = np.meshgrid(qxa, qya)
            if fft_shifted:
                qxn = fftshift(qxn)
                qyn = fftshift(qyn)
            return qxn, qyn

        Kx, Ky = get_qx_qy([nkx, nky], k_max, fft_shifted=True)
        # reciprocal in scanning space
        Qx, Qy = get_qx_qy([nx, ny], dxy)

        # |K + Q|, |K - Q| for every scan frequency (leading axes) and
        # detector sample (trailing axes), plus |K| on the detector grid
        Kplus = np.sqrt((Kx + Qx[:, :, None, None]) ** 2 + (Ky + Qy[:, :, None, None]) ** 2)
        Kminus = np.sqrt((Kx - Qx[:, :, None, None]) ** 2 + (Ky - Qy[:, :, None, None]) ** 2)
        K = np.sqrt(Kx ** 2 + Ky ** 2)

        A_KplusQ = np.zeros_like(G)
        A_KminusQ = np.zeros_like(G)

        # Use the caller's coefficients here as the GPU kernel does; this path
        # previously hard-coded zeros and silently ignored the argument.
        C = np.asarray(aberrations)
        A = np.exp(1j * cartesian_aberrations(Kx, Ky, lam, C)) * aperture_xp(Kx, Ky, lam, alpha_rad, edge=0)

        print('Creating aperture overlap functions')
        for ix, qx in enumerate(Qx[0]):
            print(f"{ix} / {Qx[0].shape}")
            for iy, qy in enumerate(Qy[:, 0]):
                # aperture function shifted by +Q
                x = Kx + qx
                y = Ky + qy
                A_KplusQ[iy, ix] = np.exp(1j * cartesian_aberrations(x, y, lam, C)) * aperture_xp(x, y, lam, alpha_rad,
                                                                                                  edge=0)

                # aperture function shifted by -Q
                x = Kx - qx
                y = Ky - qy
                A_KminusQ[iy, ix] = np.exp(1j * cartesian_aberrations(x, y, lam, C)) * aperture_xp(x, y, lam, alpha_rad,
                                                                                                   edge=0)

        # [1] Equ. (4): Γ = A*(Kf)A(Kf-Qp) - A(Kf)A*(Kf+Qp)
        Gamma = A.conj() * A_KminusQ - A * A_KplusQ.conj()

        # detector samples inside the central disc and exactly one shifted disc
        double_overlap1 = (Kplus < alpha_rad / lam) * (K < alpha_rad / lam) * (Kminus > alpha_rad / lam)
        double_overlap2 = (Kplus > alpha_rad / lam) * (K < alpha_rad / lam) * (Kminus < alpha_rad / lam)

        Psi_Qp = np.zeros((ny, nx), dtype=np.complex64)
        Psi_Qp_left_sb = np.zeros((ny, nx), dtype=np.complex64)
        Psi_Qp_right_sb = np.zeros((ny, nx), dtype=np.complex64)
        print(f"Now summing over K-space.")
        for y in trange(ny):
            for x in range(nx):
                Γ_abs = np.abs(Gamma[y, x])
                take = Γ_abs > eps
                Psi_Qp[y, x] = np.sum(G[y, x][take] * Gamma[y, x][take].conj())
                Psi_Qp_left_sb[y, x] = np.sum(G[y, x][double_overlap1[y, x]])
                Psi_Qp_right_sb[y, x] = np.sum(G[y, x][double_overlap2[y, x]])

                # direct beam at zero spatial frequency
                if x == 0 and y == 0:
                    Psi_Qp[y, x] = np.sum(np.abs(G[y, x]))
                    Psi_Qp_left_sb[y, x] = np.sum(np.abs(G[y, x]))
                    Psi_Qp_right_sb[y, x] = np.sum(np.abs(G[y, x]))

    end = time.perf_counter()
    print(f"SSB took {end - start}")

    # back to real space along the scan axes
    Psi_Rp = xp.fft.ifft2(Psi_Qp, norm='ortho')
    Psi_Rp_left_sb = xp.fft.ifft2(Psi_Qp_left_sb, norm='ortho')
    Psi_Rp_right_sb = xp.fft.ifft2(Psi_Qp_right_sb, norm='ortho')

    if cuda_is_available:
        # copy the results from the GPU back to host memory
        Psi_Rp = Psi_Rp.get()
        Psi_Rp_left_sb = Psi_Rp_left_sb.get()
        Psi_Rp_right_sb = Psi_Rp_right_sb.get()

    return Psi_Rp, Psi_Rp_left_sb, Psi_Rp_right_sb