Beispiel #1
0
    def detect(self, waves) -> np.ndarray:
        """
        Integrate the diffraction intensity of a batch of wave functions.

        The wave functions are Fourier transformed, converted to intensity and
        handed to ``self._integrate_array`` together with the angular sampling
        and the smallest cutoff scattering angle of the waves.

        Parameters
        ----------
        waves : Waves object
            The batch of wave functions to detect.

        Returns
        -------
        1d array
            Detected values, one per wave function in the batch (as produced by
            ``self._integrate_array``).
        """
        xp = get_array_module(waves.array)
        fft2 = get_device_function(xp, 'fft2')
        abs2 = get_device_function(xp, 'abs2')

        # overwrite_x=False keeps the wave function array itself untouched.
        fourier_waves = fft2(waves.array, overwrite_x=False)
        intensity = abs2(fourier_waves)

        angular_sampling = waves.angular_sampling
        cutoff = min(waves.cutoff_scattering_angles)
        return self._integrate_array(intensity, angular_sampling, cutoff)
Beispiel #2
0
    def detect(self, waves) -> np.ndarray:
        """
        Sum the diffraction intensity of a batch of wave functions into detector regions.

        Parameters
        ----------
        waves: Waves object
            The batch of wave functions to detect.

        Returns
        -------
        3d array
            Detected values normalized by the total intensity of each wave function. The first
            dimension indexes the batch, the second and third index the radial and angular bins,
            respectively.
        """
        xp = get_array_module(waves.array)
        fft2 = get_device_function(xp, 'fft2')
        abs2 = get_device_function(xp, 'abs2')
        sum_run_length_encoded = get_device_function(xp, 'sum_run_length_encoded')

        intensity = abs2(fft2(waves.array, overwrite_x=False))

        # One index array per detector region, addressing the flattened pattern.
        regions = self._get_regions(waves.gpts, waves.angular_sampling,
                                    min(waves.cutoff_scattering_angles), xp)
        total = xp.sum(intensity, axis=(-2, -1))

        # Boundaries of each region inside the concatenated index array.
        region_sizes = xp.array([len(region) for region in regions])
        separators = xp.concatenate((xp.array([0]), xp.cumsum(region_sizes)))

        # Flatten every pattern and gather its pixels region by region.
        flat_intensity = intensity.reshape((intensity.shape[0], -1))
        gathered = flat_intensity[:, xp.concatenate(regions)]

        # The device function writes the per-region sums into `binned` in place.
        binned = xp.zeros((len(gathered), len(separators) - 1), dtype=xp.float32)
        sum_run_length_encoded(gathered, binned, separators)

        binned = binned.reshape((-1, self.nbins_radial, self.nbins_angular))
        return binned / total[:, None, None]
Beispiel #3
0
    def diffraction_pattern(self) -> Measurement:
        """
        Compute the intensity of the wave functions in the diffraction plane.

        The array is Fourier transformed, shifted with ``fftshift``, cropped with
        ``crop_to_center``, converted to intensity and moved to a NumPy array.

        :return: Measurement holding the diffraction intensity, with angular
            calibrations ('alpha_x', 'alpha_y' in mrad) on the last two axes and
            ``None`` for any leading axes.
        """
        angular_calibrations = calibrations_from_grid(
            self.grid.antialiased_gpts,
            self.grid.antialiased_sampling,
            names=['alpha_x', 'alpha_y'],
            units='mrad',
            scale_factor=self.wavelength * 1000,
            fourier_space=True)

        # Leading (non-spatial) axes carry no calibration.
        n_leading_axes = len(self.array.shape) - 2
        calibrations = (None, ) * n_leading_axes + angular_calibrations

        xp = get_array_module(self.array)
        abs2 = get_device_function(xp, 'abs2')
        fft2 = get_device_function(xp, 'fft2')

        spectrum = xp.fft.fftshift(fft2(self.array, overwrite_x=False))
        cropped = crop_to_center(spectrum)
        pattern = asnumpy(abs2(cropped))
        return Measurement(pattern, calibrations)
Beispiel #4
0
def fourier_translation_operator(positions: np.ndarray,
                                 shape: tuple) -> np.ndarray:
    """
    Build the Fourier-space phase ramp(s) used to shift an array by the given position(s).

    Parameters
    ----------
    positions : array of xy-positions
        Either a single position of shape (2,) or several positions of shape (n, 2).
        The spatial frequencies are generated with unit sampling, so positions are
        presumably expressed in pixels -- confirm against the callers.
    shape : two int
        Grid shape for which the spatial frequencies are generated.

    Returns
    -------
    array
        Product of the x and y phase ramps. One 2D ramp when a single position
        was given, otherwise a stack of ramps with the positions along the first axis.
    """
    single_position = len(positions.shape) == 1
    if single_position:
        # Promote to a batch of one so the broadcasting below is uniform.
        positions = positions[None]

    xp = get_array_module(positions)
    complex_exponential = get_device_function(xp, 'complex_exponential')

    kx, ky = spatial_frequencies(shape, (1., 1.))
    kx = xp.asarray(kx.reshape((1, -1, 1)))
    ky = xp.asarray(ky.reshape((1, 1, -1)))

    positions = xp.asarray(positions)
    x = positions[:, 0].reshape((-1, 1, 1))
    y = positions[:, 1].reshape((-1, 1, 1))

    ramp_x = complex_exponential(-2 * np.pi * kx * x)
    ramp_y = complex_exponential(-2 * np.pi * ky * y)
    ramps = ramp_x * ramp_y

    return ramps[0] if single_position else ramps
Beispiel #5
0
    def evaluate_spatial_envelope(self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray]) -> \
            Union[float, np.ndarray]:
        """
        Evaluate the envelope associated with ``self.angular_spread``.

        Two polynomial expansions in ``alpha`` are built from the aberration
        coefficients in ``self.parameters`` (named ``dchi_dk`` and ``dchi_dphi``;
        presumably the radial and azimuthal derivatives of the aberration
        function -- confirm against the reference used for ``evaluate_chi``).
        The result is
        ``exp(-sign(angular_spread) * (angular_spread / 2 / 1000)**2 * (dchi_dk**2 + dchi_dphi**2))``.

        Parameters
        ----------
        alpha : float or array
            Radial scattering angle.
        phi : float or array
            Azimuthal scattering angle.

        Returns
        -------
        float or array
            The envelope evaluated at the given angles.
        """
        xp = get_array_module(alpha)
        p = self.parameters
        # Cosine terms of every aberration order (1st to 5th), weighted by alpha**order.
        dchi_dk = 2 * xp.pi / self.wavelength * (
            (p['C12'] * xp.cos(2. * (phi - p['phi12'])) + p['C10']) * alpha +
            (p['C23'] * xp.cos(3. * (phi - p['phi23'])) +
             p['C21'] * xp.cos(1. * (phi - p['phi21']))) * alpha**2 +
            (p['C34'] * xp.cos(4. * (phi - p['phi34'])) +
             p['C32'] * xp.cos(2. *
                               (phi - p['phi32'])) + p['C30']) * alpha**3 +
            (p['C45'] * xp.cos(5. * (phi - p['phi45'])) +
             p['C43'] * xp.cos(3. * (phi - p['phi43'])) +
             p['C41'] * xp.cos(1. * (phi - p['phi41']))) * alpha**4 +
            (p['C56'] * xp.cos(6. * (phi - p['phi56'])) +
             p['C54'] * xp.cos(4. * (phi - p['phi54'])) +
             p['C52'] * xp.cos(2. * (phi - p['phi52'])) + p['C50']) * alpha**5)

        # Sine terms; the rotationally symmetric coefficients (C10, C30, C50) do not appear here.
        dchi_dphi = -2 * xp.pi / self.wavelength * (
            1 / 2. *
            (2. * p['C12'] * xp.sin(2. *
                                    (phi - p['phi12']))) * alpha + 1 / 3. *
            (3. * p['C23'] * xp.sin(3. * (phi - p['phi23'])) +
             1. * p['C21'] * xp.sin(1. *
                                    (phi - p['phi21']))) * alpha**2 + 1 / 4. *
            (4. * p['C34'] * xp.sin(4. * (phi - p['phi34'])) +
             2. * p['C32'] * xp.sin(2. *
                                    (phi - p['phi32']))) * alpha**3 + 1 / 5. *
            (5. * p['C45'] * xp.sin(5. * (phi - p['phi45'])) +
             3. * p['C43'] * xp.sin(3. * (phi - p['phi43'])) +
             1. * p['C41'] * xp.sin(1. *
                                    (phi - p['phi41']))) * alpha**4 + 1 / 6. *
            (6. * p['C56'] * xp.sin(6. * (phi - p['phi56'])) +
             4. * p['C54'] * xp.sin(4. * (phi - p['phi54'])) +
             2. * p['C52'] * xp.sin(2. * (phi - p['phi52']))) * alpha**5)

        # NOTE(review): the division by 1000 presumably converts angular_spread from mrad to rad -- confirm.
        # The sign factor makes the exponent vanish (envelope == 1) when angular_spread is zero.
        return xp.exp(-xp.sign(self.angular_spread) *
                      (self.angular_spread / 2 / 1000)**2 *
                      (dchi_dk**2 + dchi_dphi**2))
Beispiel #6
0
 def evaluate_gaussian_envelope(
         self, alpha: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
     """
     Evaluate the Gaussian envelope ``exp(-gaussian_spread**2 * alpha**2 / (2 * wavelength**2))``.

     :param alpha: Radial scattering angle(s).
     :return: The envelope evaluated at ``alpha``.
     """
     xp = get_array_module(alpha)
     exponent = -.5 * self.gaussian_spread**2 * alpha**2 / self.wavelength**2
     return xp.exp(exponent)
Beispiel #7
0
 def evaluate_temporal_envelope(
         self, alpha: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
     """
     Evaluate the envelope associated with ``self.focal_spread``.

     Computes ``exp(-(pi / (2 * wavelength) * focal_spread * alpha**2)**2)`` and
     casts the result to 32-bit floats.

     :param alpha: Radial scattering angle(s).
     :return: The envelope evaluated at ``alpha`` as float32.
     """
     xp = get_array_module(alpha)
     argument = .5 * xp.pi / self.wavelength * self.focal_spread * alpha**2
     envelope = xp.exp(-argument**2)
     return envelope.astype(xp.float32)
Beispiel #8
0
def _run_epie(object,
              probe: np.ndarray,
              diffraction_patterns: np.ndarray,
              positions: np.ndarray,
              maxiter: int,
              alpha: float = 1.,
              beta: float = 1.,
              fix_probe: bool = False,
              fix_com: bool = False,
              return_iterations: bool = False,
              seed=None):
    """
    Iteratively reconstruct an object and a probe from a set of diffraction patterns.

    In every iteration the scan positions are visited in random order. For each
    position the object is shifted so that the probe illuminates the right
    region, the exit wave is formed, its Fourier modulus is replaced by the
    measured one, and the object (and optionally the probe) are updated from
    the difference between the corrected and the original exit wave.

    Parameters
    ----------
    object : array or two int
        Initial guess for the object as a 2D array, or a length-2 sequence giving
        the shape of an initial object filled with ones.
    probe : array
        Initial guess for the probe. Must have the same shape as the object and
        as each diffraction pattern.
    diffraction_patterns : array
        3D array of diffraction patterns, one per position. The square root is
        taken and the last two axes are passed through ``ifftshift``, so the
        input is presumably intensities with the zero frequency centered --
        confirm against the caller.
    positions : array
        One xy-position per diffraction pattern, in the units expected by
        ``fft_shift``.
    maxiter : int
        Number of passes over all positions.
    alpha : float
        Step size of the object update.
    beta : float
        Step size of the probe update.
    fix_probe : bool
        If True, the probe is not updated.
    fix_com : bool
        If True, the probe is re-centered on its intensity center of mass after
        every pass.
    return_iterations : bool
        If True, return the object, probe and error after every pass instead of
        only the final ones.
    seed : int, optional
        Seed for NumPy's global random number generator, which determines the
        order in which the positions are visited.

    Returns
    -------
    tuple
        ``(object, probe, SSE)``, or three lists with one entry per pass when
        ``return_iterations`` is True. NOTE: the error is currently never
        accumulated, so SSE is always 0.

    Raises
    ------
    ValueError
        If the shapes of the inputs are inconsistent.
    """
    xp = get_array_module(probe)

    object = xp.array(object)
    probe = xp.array(probe)

    if len(diffraction_patterns.shape) != 3:
        raise ValueError(
            f'diffraction_patterns must be a 3D array, got shape {diffraction_patterns.shape}')

    if len(diffraction_patterns) != len(positions):
        raise ValueError(
            f'number of diffraction patterns ({len(diffraction_patterns)}) does not match '
            f'number of positions ({len(positions)})')

    if object.shape == (2, ):
        # A length-2 `object` is interpreted as the shape of a flat initial object.
        object = xp.ones((int(object[0]), int(object[1])), dtype=xp.complex64)
    elif len(object.shape) != 2:
        raise ValueError(
            f'object must be a 2D array or a shape of two ints, got shape {object.shape}')

    if probe.shape != diffraction_patterns.shape[1:]:
        raise ValueError(
            f'probe shape {probe.shape} does not match diffraction pattern shape '
            f'{diffraction_patterns.shape[1:]}')

    if probe.shape != object.shape:
        raise ValueError(
            f'probe shape {probe.shape} does not match object shape {object.shape}')

    if return_iterations:
        object_iterations = []
        probe_iterations = []
        SSE_iterations = []

    if seed is not None:
        np.random.seed(seed)

    # Work with Fourier moduli laid out like the output of fft2.
    diffraction_patterns = np.fft.ifftshift(np.sqrt(diffraction_patterns),
                                            axes=(-2, -1))

    SSE = 0.
    k = 0
    outer_pbar = ProgressBar(total=maxiter)
    inner_pbar = ProgressBar(total=len(positions))

    while k < maxiter:
        indices = np.arange(len(positions))
        np.random.shuffle(indices)

        # The object is kept shifted to the most recently visited position;
        # each step only applies the shift relative to the previous one.
        old_position = xp.array((0., 0.))
        inner_pbar.reset()
        SSE = 0.
        for j in indices:
            position = xp.array(positions[j])

            diffraction_pattern = xp.array(diffraction_patterns[j])
            illuminated_object = fft_shift(object, old_position - position)

            g = illuminated_object * probe
            # Keep the phase of the exit wave, replace its Fourier modulus by the measured one.
            gprime = xp.fft.ifft2(diffraction_pattern *
                                  xp.exp(1j * xp.angle(xp.fft.fft2(g))))

            object = illuminated_object + alpha * (
                gprime - g) * xp.conj(probe) / (xp.max(xp.abs(probe))**2)
            old_position = position

            if not fix_probe:
                probe = probe + beta * (
                    gprime - g) * xp.conj(illuminated_object) / (xp.max(
                        xp.abs(illuminated_object))**2)

            inner_pbar.update(1)

        # Undo the accumulated shift. `old_position` equals the last visited
        # position and stays at the origin when there were no positions, which
        # avoids referencing an undefined loop variable.
        object = fft_shift(object, old_position)

        if fix_com:
            com = center_of_mass(xp.fft.fftshift(xp.abs(probe)**2))
            probe = xp.fft.ifftshift(fft_shift(probe, -xp.array(com)))

        if return_iterations:
            object_iterations.append(object)
            probe_iterations.append(probe)
            SSE_iterations.append(SSE)

        outer_pbar.update(1)

        k += 1

    inner_pbar.close()
    outer_pbar.close()

    if return_iterations:
        return object_iterations, probe_iterations, SSE_iterations
    else:
        return object, probe, SSE
Beispiel #9
0
 def transmit(self, waves):
     """
     Multiply the wave functions in place by this object's array.

     :param waves: The wave functions to transmit; must match ``self.accelerator``.
     :return: The same ``waves`` object, with its array modified in place.
     """
     self.accelerator.check_match(waves)
     xp = get_array_module(waves._array)
     # Bring the transmission array to the device the waves live on.
     transmission = copy_to_device(self.array, xp)
     waves._array *= transmission
     return waves
Beispiel #10
0
 def measure(self):
     """
     Build the array and wrap its first entry in a Measurement.

     :return: Measurement of the fft-shifted, built array with 'x' and 'y'
         calibrations derived from ``self.gpts`` and ``self.sampling``.
     """
     # NOTE(review): fftshift is applied to all axes before the first entry is
     # selected; confirm that shifting the leading axis is intended.
     array = np.fft.fftshift(self.build())[0]
     calibrations = calibrations_from_grid(self.gpts, self.sampling,
                                           ['x', 'y'])
     return Measurement(array, calibrations, name=str(self))
Beispiel #11
0
def polar_coordinates(x, y):
    """
    Calculate a polar grid for a given Cartesian grid.

    ``x`` is laid out along the first axis and ``y`` along the second. Returns
    the radius ``sqrt(x**2 + y**2)`` and the angle ``arctan2(x, y)`` (note the
    argument order: ``x`` is passed first).
    """
    xp = get_array_module(x)
    x_column = x.reshape((-1, 1))
    y_row = y.reshape((1, -1))
    alpha = xp.sqrt(x_column**2 + y_row**2)
    phi = xp.arctan2(x_column, y_row)
    return alpha, phi
Beispiel #12
0
    def evaluate_chi(self, alpha, phi) -> np.ndarray:
        """
        Calculates the polar expansion of the phase error up to 5th order.

        See Eq. 2.22 in ref [1]. The aberration coefficients (``Cnm``) and
        azimuthal offsets (``phinm``) are read from ``self.parameters`` and the
        wavelength from ``self.wavelength``. Orders whose parameters are all
        zero are skipped.

        Parameters
        ----------
        alpha : numpy.ndarray
            Angle between the scattered electrons and the optical axis.
            NOTE(review): the expansion uses ``alpha`` directly together with the
            wavelength, which suggests radians; an earlier version of this
            docstring stated mrad -- confirm.
        phi : numpy.ndarray
            Angle around the optical axis of the scattered electrons.

        Returns
        -------
        numpy.ndarray
            The phase error, with the same shape as ``alpha``. The terms are
            accumulated in a float32 array.

        References
        ----------
        .. [1] Kirkland, E. J. (2010). Advanced Computing in Electron Microscopy (2nd ed.). Springer.

        """
        xp = get_array_module(alpha)
        p = self.parameters

        def any_nonzero(symbols):
            # True if at least one of the named parameters is set.
            return any(p[symbol] != 0. for symbol in symbols)

        alpha2 = alpha**2

        array = xp.zeros(alpha.shape, dtype=np.float32)
        if any_nonzero(('C10', 'C12', 'phi12')):
            array += (1 / 2 * alpha2 *
                      (p['C10'] + p['C12'] * xp.cos(2 * (phi - p['phi12']))))

        if any_nonzero(('C21', 'phi21', 'C23', 'phi23')):
            array += (1 / 3 * alpha2 * alpha *
                      (p['C21'] * xp.cos(phi - p['phi21']) +
                       p['C23'] * xp.cos(3 * (phi - p['phi23']))))

        if any_nonzero(('C30', 'C32', 'phi32', 'C34', 'phi34')):
            array += (1 / 4 * alpha2**2 *
                      (p['C30'] + p['C32'] * xp.cos(2 * (phi - p['phi32'])) +
                       p['C34'] * xp.cos(4 * (phi - p['phi34']))))

        # The last symbol used to be a duplicated 'phi41'; 'phi45' belongs here.
        if any_nonzero(('C41', 'phi41', 'C43', 'phi43', 'C45', 'phi45')):
            array += (1 / 5 * alpha2**2 * alpha *
                      (p['C41'] * xp.cos((phi - p['phi41'])) +
                       p['C43'] * xp.cos(3 * (phi - p['phi43'])) +
                       p['C45'] * xp.cos(5 * (phi - p['phi45']))))

        if any_nonzero(('C50', 'C52', 'phi52', 'C54', 'phi54', 'C56', 'phi56')):
            array += (1 / 6 * alpha2**3 *
                      (p['C50'] + p['C52'] * xp.cos(2 * (phi - p['phi52'])) +
                       p['C54'] * xp.cos(4 * (phi - p['phi54'])) +
                       p['C56'] * xp.cos(6 * (phi - p['phi56']))))

        array = 2 * xp.pi / self.wavelength * array
        return array
Beispiel #13
0
def polargrid(x, y):
    """
    Convert 1D Cartesian coordinate vectors into a 2D polar grid.

    ``x`` varies along the first axis and ``y`` along the second. Returns the
    radius ``sqrt(x**2 + y**2)`` and the angle ``arctan2(x, y)`` (``x`` is
    passed as the first argument).
    """
    xp = get_array_module(x)
    x_column = x.reshape((-1, 1))
    y_row = y.reshape((1, -1))
    radius = xp.sqrt(x_column**2 + y_row**2)
    angle = xp.arctan2(x_column, y_row)
    return radius, angle
Beispiel #14
0
    def collapse(self,
                 positions: Sequence[float],
                 max_batch_expansion: int = None) -> Waves:
        """
        Collapse the scattering matrix to probe wave functions centered on the provided positions.

        :param positions: The positions of the probe wave functions, either a single xy-position or an
            array of shape (N, 2).
        :param max_batch_expansion: The maximum number of plane waves the reduction is applied to simultaneously.
        :return: Probe wave functions for the provided positions.
        :raises RuntimeError: If ``positions`` is neither a single xy-position nor an (N, 2) array.
        """
        xp = get_array_module(self.array)
        complex_exponential = get_device_function(xp, 'complex_exponential')
        scale_reduce = get_device_function(xp, 'scale_reduce')
        windowed_scale_reduce = get_device_function(xp,
                                                    'windowed_scale_reduce')

        positions = np.array(positions, dtype=xp.float32)

        if positions.shape == (2, ):
            positions = positions[None]
        elif (len(positions.shape) != 2) or (positions.shape[-1] != 2):
            raise RuntimeError(
                'positions must have shape (2,) or (N, 2), got {}'.format(
                    positions.shape))

        interpolated_grid = self.interpolated_grid
        W = np.floor_divide(interpolated_grid.gpts, 2)
        # Lower corner (in grid points) of the window centered on each position,
        # wrapped periodically into the full grid. The builtin `int` replaces the
        # `np.int` alias, which was removed from NumPy.
        corners = np.rint(positions / self.sampling - W).astype(int)
        corners = np.remainder(corners, np.asarray(self.gpts))
        corners = xp.asarray(corners)

        window = xp.zeros((
            len(positions),
            interpolated_grid.gpts[0],
            interpolated_grid.gpts[1],
        ),
                          dtype=xp.complex64)

        positions = xp.asarray(positions)

        # One phase factor per (position, plane wave) pair.
        translation = (complex_exponential(
            2. * np.pi * self.k[0][None] * positions[:, 0, None]) *
                       complex_exponential(2. * np.pi * self.k[1][None] *
                                           positions[:, 1, None]))

        coefficients = translation * self._evaluate_ctf()

        # The device functions accumulate into `window` in place, one batch of
        # plane waves at a time.
        for partial_s_matrix in self._generate_partial(max_batch_expansion,
                                                       pbar=False):
            partial_coefficients = coefficients[:, partial_s_matrix.
                                                start:partial_s_matrix.stop]

            if self.interpolation > 1:
                windowed_scale_reduce(window, partial_s_matrix.array, corners,
                                      partial_coefficients)
            else:
                scale_reduce(window, partial_s_matrix.array,
                             partial_coefficients)

        return Waves(window,
                     extent=interpolated_grid.extent,
                     energy=self.energy)
Beispiel #15
0
 def _evaluate_ctf(self):
     """
     Evaluate ``self._ctf`` at the polar coordinates of the wave vectors in ``self.k``.

     The radial coordinate is ``|k| * wavelength`` and the azimuth is
     ``arctan2(k[0], k[1])``.
     """
     xp = get_array_module(self._array)
     kx = self.k[0]
     ky = self.k[1]
     alpha = xp.sqrt(kx**2 + ky**2) * self.wavelength
     phi = xp.arctan2(kx, ky)
     return self._ctf.evaluate(alpha, phi)
Beispiel #16
0
 def evaluate_aberrations(self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray]) -> \
         Union[float, np.ndarray]:
     """
     Evaluate the aberration phase factor at the given polar angles.

     Applies the device ``complex_exponential`` function to the negated result
     of ``self.evaluate_chi(alpha, phi)``.

     :param alpha: Radial scattering angle(s).
     :param phi: Azimuthal scattering angle(s).
     :return: The complex exponential of ``-chi``.
     """
     xp = get_array_module(alpha)
     complex_exponential = get_device_function(xp, 'complex_exponential')
     chi = self.evaluate_chi(alpha, phi)
     return complex_exponential(-chi)
Beispiel #17
0
def fft_shift(array, positions):
    """
    Shift an array by multiplying its 2D Fourier transform with a translation operator.

    :param array: The array to shift; the translation operator is built for its last two axes.
    :param positions: The shift(s), as accepted by ``fourier_translation_operator``.
    :return: The inverse 2D Fourier transform of the multiplied spectrum (complex valued).
    """
    xp = get_array_module(array)
    spectrum = xp.fft.fft2(array)
    operator = fourier_translation_operator(positions, array.shape[-2:])
    return xp.fft.ifft2(spectrum * operator)