예제 #1
0
파일: test_ctf.py 프로젝트: nzaker/abTEM
def test_ctf_raises():
    """Check CTF keyword validation and that evaluation requires an energy."""
    # An unknown keyword must be rejected when the CTF is constructed.
    with pytest.raises(ValueError) as excinfo:
        CTF(not_a_parameter=10)
    assert str(excinfo.value) == 'not_a_parameter not a recognized parameter'

    # Evaluating before the energy is known must fail ...
    transfer = CTF()
    with pytest.raises(RuntimeError) as excinfo:
        transfer.evaluate(0, 0)
    assert str(excinfo.value) == 'energy is not defined'

    # ... and must succeed once an energy has been assigned.
    transfer.energy = 200e3
    transfer.evaluate(0, 0)
예제 #2
0
    def apply_ctf(self, ctf: CTF = None, **kwargs):
        """
        Apply the aberrations defined by a CTF object to the wave functions.

        The CTF is evaluated on the polar grid of scattering angles (spatial frequencies times wavelength) of these
        wave functions and applied with the device's ``fft2_convolve`` using ``overwrite_x=False``, so the current
        array is left untouched.

        Note: a supplied ``ctf`` has its accelerator matched to these wave functions through
        ``ctf.accelerator.match``; the passed object may therefore be modified.

        :param ctf: Contrast Transfer Function object to be applied. If None, one is built from ``kwargs``.
        :param kwargs: Provide the aberration coefficients as keyword arguments. Ignored when ``ctf`` is given.
        :return: The wave functions with aberrations applied, as a new object of the same class.
        """
        xp = get_array_module(self.array)
        # Reuse the array module found above rather than looking it up a second time.
        fft2_convolve = get_device_function(xp, 'fft2_convolve')

        if ctf is None:
            ctf = CTF(**kwargs)

        ctf.accelerator.match(self.accelerator)
        kx, ky = spatial_frequencies(self.grid.gpts, self.grid.sampling)
        # Scattering angles (alpha) and azimuths (phi) on the grid of these wave functions.
        alpha, phi = polargrid(xp.asarray(kx * self.wavelength),
                               xp.asarray(ky * self.wavelength))
        kernel = ctf.evaluate(alpha, phi)

        return self.__class__(fft2_convolve(self.array,
                                            kernel,
                                            overwrite_x=False),
                              extent=self.extent,
                              energy=self.energy)
예제 #3
0
class SMatrix(HasGridAndAcceleratorMixin):
    """
    Scattering matrix object.

    The scattering matrix object represents a plane wave expansion of a probe.

    :param array: The array representation of the scattering matrix. The first axis enumerates the plane waves of the
        expansion; the remaining axes are the grid points.
    :param expansion_cutoff: The angular cutoff of the plane wave expansion [mrad].
    :param interpolation: Interpolation factor.
    :param k: The spatial frequencies of each plane in the plane wave expansion.
    :param ctf: The probe contrast transfer function. If None, a CTF with a semi-angle cutoff equal to
        ``expansion_cutoff`` and a rolloff of 0.1 is used.
    :param extent: Lateral extent of wave functions [Å].
    :param sampling: Lateral sampling of wave functions [Å].
    :param energy: Electron energy [eV].
    :param device: Device identifier stored on the object (default ``'cpu'``).
    """
    def __init__(self,
                 array: np.ndarray,
                 expansion_cutoff: float,
                 interpolation: int,
                 k: Tuple[np.ndarray, np.ndarray],
                 ctf: CTF = None,
                 extent: Union[float, Sequence[float]] = None,
                 sampling: Union[float, Sequence[float]] = None,
                 energy: float = None,
                 device='cpu'):

        self._array = array
        self._interpolation = interpolation
        self._expansion_cutoff = expansion_cutoff
        self._k = k
        # The first array axis enumerates the expansion plane waves, so the grid is given by the remaining axes.
        self._grid = Grid(extent=extent,
                          gpts=array.shape[1:],
                          sampling=sampling,
                          lock_gpts=True)

        self._accelerator = Accelerator(energy=energy)

        if ctf is None:
            ctf = CTF(semiangle_cutoff=expansion_cutoff, rolloff=.1)

        self.set_ctf(ctf)

        self._device = device

    def set_ctf(self, ctf: CTF = None, **kwargs):
        """
        Set the contrast transfer function.

        The CTF is copied (or built from ``kwargs``) and then bound to this scattering matrix's accelerator, so it
        shares the energy of the scattering matrix.

        :param ctf: New contrast transfer function. If None, a new CTF is built from ``kwargs``.
        :param kwargs: Provide the contrast transfer function as keyword arguments.
        """

        if ctf is None:
            self._ctf = CTF(**kwargs)
        else:
            # Copy so that rebinding the accelerator below does not touch the caller's CTF object.
            self._ctf = copy(ctf)
        self._ctf._accelerator = self._accelerator

    @property
    def ctf(self) -> CTF:
        """
        Probe contrast transfer function.
        """
        return self._ctf

    @property
    def array(self) -> np.ndarray:
        """
        Array representing the scattering matrix.
        """
        return self._array

    @property
    def k(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        The spatial frequencies of each wave in the plane wave expansion.
        """
        return self._k

    @property
    def interpolation(self) -> int:
        """
        Interpolation factor.
        """
        return self._interpolation

    @property
    def interpolated_grid(self) -> Grid:
        """
        The grid of the interpolated scattering matrix.

        Each grid dimension is divided (floor division) by the interpolation factor; the sampling is unchanged.
        """
        interpolated_gpts = tuple(n // self.interpolation for n in self.gpts)
        return Grid(gpts=interpolated_gpts,
                    sampling=self.sampling,
                    lock_gpts=True)

    def _evaluate_ctf(self):
        # Evaluate the CTF at the scattering angle and azimuth of every expansion plane wave.
        # NOTE(review): arctan2 is called as (kx, ky); presumably this matches the convention of `polargrid` — confirm.
        xp = get_array_module(self._array)
        alpha = xp.sqrt(self.k[0]**2 + self.k[1]**2) * self.wavelength
        phi = xp.arctan2(self.k[0], self.k[1])
        return self._ctf.evaluate(alpha, phi)

    def __len__(self) -> int:
        """Number of plane waves in the expansion (length of the first array axis)."""
        return len(self._array)

    def _generate_partial(self, max_batch: int = None, pbar: bool = True):
        # Yield PartialSMatrix views covering consecutive ranges of expansion plane waves, at most `max_batch` each.
        if max_batch is None:
            n_batches = 1
        else:
            # Ceiling division: the number of batches needed so that no batch exceeds max_batch.
            n_batches = (len(self) + (-len(self) % max_batch)) // max_batch

        batch_pbar = ProgressBar(total=len(self),
                                 desc='Batches',
                                 disable=(not pbar) or (n_batches == 1))
        batch_sizes = split_integer(len(self), n_batches)
        N = 0
        for batch_size in batch_sizes:
            yield PartialSMatrix(N, N + batch_size, self)
            N += batch_size
            batch_pbar.update(batch_size)

        batch_pbar.refresh()
        batch_pbar.close()

    def multislice(self,
                   potential: AbstractPotential,
                   max_batch=None,
                   pbar: bool = True):
        """
        Propagate the scattering matrix through the provided potential.

        :param potential: The potential through which the scattering matrix is propagated.
        :param max_batch: The maximum number of expansion plane waves per batch. Larger batches are faster, but
            require more memory.
        :param pbar: If true, display progress bars. A ProgressBar instance may also be passed.
        :return: The scattering matrix object itself.
        """
        propagator = FresnelPropagator()

        owns_pbar = isinstance(pbar, bool)
        if owns_pbar:
            # Honour the caller's choice for the batch bar too; previously it was always enabled.
            show_batch_pbar = pbar
            pbar = ProgressBar(total=len(potential),
                               desc='Multislice',
                               disable=not pbar)
        else:
            # A ProgressBar instance was supplied; keep the batch bar enabled as before.
            show_batch_pbar = True

        for partial_s_matrix in self._generate_partial(max_batch, pbar=show_batch_pbar):
            _multislice(partial_s_matrix,
                        potential,
                        propagator=propagator,
                        pbar=pbar)

        pbar.refresh()
        if owns_pbar:
            # Only close a bar created here; a caller-supplied bar is left for the caller to manage.
            pbar.close()
        return self

    def collapse(self,
                 positions: Sequence[float],
                 max_batch_expansion: int = None) -> Waves:
        """
        Collapse the scattering matrix to probe wave functions centered on the provided positions.

        :param positions: The positions of the probe wave functions, with shape (2,) or (n, 2).
        :param max_batch_expansion: The maximum number of plane waves the reduction is applied to simultaneously.
        :return: Probe wave functions for the provided positions.
        :raises RuntimeError: If ``positions`` does not have shape (2,) or (n, 2).
        """
        xp = get_array_module(self.array)
        complex_exponential = get_device_function(xp, 'complex_exponential')
        scale_reduce = get_device_function(xp, 'scale_reduce')
        windowed_scale_reduce = get_device_function(xp,
                                                    'windowed_scale_reduce')

        positions = np.array(positions, dtype=xp.float32)

        if positions.shape == (2, ):
            positions = positions[None]
        elif (len(positions.shape) != 2) or (positions.shape[-1] != 2):
            raise RuntimeError(
                f'positions must have shape (2,) or (n, 2), got {positions.shape}')

        interpolated_grid = self.interpolated_grid
        W = np.floor_divide(interpolated_grid.gpts, 2)
        # Lower-left corner (in grid points) of the window centred on each position, wrapped periodically.
        # The builtin `int` replaces the `np.int`/`xp.int` aliases, which were removed in NumPy 1.24.
        corners = np.rint(positions / self.sampling - W).astype(int)
        corners = np.remainder(corners, np.asarray(self.gpts))
        corners = xp.asarray(corners)

        window = xp.zeros((
            len(positions),
            interpolated_grid.gpts[0],
            interpolated_grid.gpts[1],
        ),
                          dtype=xp.complex64)

        positions = xp.asarray(positions)

        # Phase factors translating each expansion plane wave to each probe position.
        translation = (complex_exponential(
            2. * np.pi * self.k[0][None] * positions[:, 0, None]) *
                       complex_exponential(2. * np.pi * self.k[1][None] *
                                           positions[:, 1, None]))

        coefficients = translation * self._evaluate_ctf()

        for partial_s_matrix in self._generate_partial(max_batch_expansion,
                                                       pbar=False):
            partial_coefficients = coefficients[:, partial_s_matrix.
                                                start:partial_s_matrix.stop]

            if self.interpolation > 1:
                windowed_scale_reduce(window, partial_s_matrix.array, corners,
                                      partial_coefficients)
            else:
                scale_reduce(window, partial_s_matrix.array,
                             partial_coefficients)

        return Waves(window,
                     extent=interpolated_grid.extent,
                     energy=self.energy)

    def _generate_probes(self, scan: AbstractScan, max_batch_probes,
                         max_batch_expansion):
        # Yield (start, end, probes) for each batch of scan positions.
        for start, end, positions in scan.generate_positions(
                max_batch=max_batch_probes):
            yield start, end, self.collapse(
                positions, max_batch_expansion=max_batch_expansion)

    def scan(self,
             scan: AbstractScan,
             detectors: Union[AbstractDetector, Sequence[AbstractDetector]],
             max_batch_probes=1,
             max_batch_expansion=None,
             pbar: Union[ProgressBar, bool] = True):
        """
        Raster scan the probe and record a measurement for each detector.

        No potential is taken here: probes are collapsed from the current scattering matrix array, so propagate the
        scattering matrix first (e.g. with ``multislice``) to record exit-wave measurements.

        :param scan: Scan object defining the positions of the probe wave functions.
        :param detectors: The detector, or detectors, recording the measurements.
        :param max_batch_probes: The probe batch size. Larger batches are faster, but require more memory.
        :param max_batch_expansion: The expansion plane wave batch size.
        :param pbar: If true, display progress bars. A ProgressBar instance may also be passed.
        :return: Dictionary of measurements with keys given by the detector.
        """
        # Accept a single detector, consistent with Probe.scan.
        if isinstance(detectors, AbstractDetector):
            detectors = [detectors]

        measurements = {}
        for detector in detectors:
            measurements[detector] = detector.allocate_measurement(
                self.interpolated_grid, self.wavelength, scan)

        if isinstance(pbar, bool):
            pbar = ProgressBar(total=len(scan), desc='Scan', disable=not pbar)

        for start, end, exit_probes in self._generate_probes(
                scan, max_batch_probes, max_batch_expansion):
            for detector, measurement in measurements.items():
                scan.insert_new_measurement(measurement, start, end,
                                            detector.detect(exit_probes))
            pbar.update(end - start)

        pbar.refresh()
        pbar.close()
        return measurements
예제 #4
0
class Probe(HasGridAndAcceleratorMixin):
    """
    Probe wave function object.

    The probe object can represent a stack of electron probe wave functions for simulating scanning transmission
    electron microscopy.

    :param semiangle_cutoff: Convergence semi-angle [mrad].
    :param rolloff: Softens the cutoff. A value of 0 gives a hard cutoff, while 1 gives the softest possible cutoff.
    :param focal_spread: The focal spread due to, among other factors, chromatic aberrations and lens current
        instabilities.
    :param angular_spread: Angular spread, passed on to the CTF; see the documentation of the CTF object.
    :param ctf_parameters: The parameters describing the phase aberrations using polar notation or an alias.
        See the documentation of the CTF object for a description.
        Convert from cartesian to polar parameters using `transfer.cartesian2polar`.
    :param extent: Lateral extent of wave functions [Å].
    :param gpts: Number of grid points describing the wave functions.
    :param sampling: Lateral sampling of wave functions [Å].
    :param energy: Electron energy [eV].
    :param device: The probe wave functions will be built on this device.
    :param kwargs: Provide the aberration coefficients as keyword arguments.
    """
    def __init__(self,
                 semiangle_cutoff: float = np.inf,
                 rolloff: float = 0.1,
                 focal_spread: float = 0.,
                 angular_spread: float = 0.,
                 ctf_parameters: dict = None,
                 extent: Union[float, Sequence[float]] = None,
                 gpts: Union[int, Sequence[int]] = None,
                 sampling: Union[float, Sequence[float]] = None,
                 energy: float = None,
                 device='cpu',
                 **kwargs):

        self._ctf = CTF(semiangle_cutoff=semiangle_cutoff,
                        rolloff=rolloff,
                        focal_spread=focal_spread,
                        angular_spread=angular_spread,
                        parameters=ctf_parameters,
                        energy=energy,
                        **kwargs)
        # The probe and its CTF share a single accelerator (and therefore a single energy).
        self._accelerator = self._ctf._accelerator
        self._grid = Grid(extent=extent, gpts=gpts, sampling=sampling)
        self._ctf_cache = Cache(1)

        # Invalidate the cached CTF array whenever anything it depends on changes.
        self._ctf.changed.register(cache_clear_callback(self._ctf_cache))
        self._grid.changed.register(cache_clear_callback(self._ctf_cache))
        self._accelerator.changed.register(
            cache_clear_callback(self._ctf_cache))

        self._device = device

    @property
    def ctf(self) -> CTF:
        """
        Probe contrast transfer function.
        """
        return self._ctf

    def _fourier_translation_operator(self, positions):
        # Phase factors exp(2*pi*i*(kx*x + ky*y)) with shape (n_positions, gpts[0], gpts[1]).
        xp = get_array_module(positions)
        complex_exponential = get_device_function(xp, 'complex_exponential')

        kx, ky = spatial_frequencies(self.grid.gpts, self.grid.sampling)
        kx = kx.reshape((1, -1, 1))
        ky = ky.reshape((1, 1, -1))
        kx = xp.asarray(kx)
        ky = xp.asarray(ky)
        positions = xp.asarray(positions)
        x = positions[:, 0].reshape((-1, ) + (1, 1))
        y = positions[:, 1].reshape((-1, ) + (1, 1))

        return complex_exponential(2 * np.pi * kx * x) * complex_exponential(
            2 * np.pi * ky * y)

    @cached_method('_ctf_cache')
    def _evaluate_ctf(self, xp):
        # CTF on the polar grid of scattering angles; cached until the CTF, grid or accelerator changes.
        kx, ky = spatial_frequencies(self.grid.gpts, self.grid.sampling)
        alpha, phi = polargrid(xp.asarray(kx * self.wavelength),
                               xp.asarray(ky * self.wavelength))
        return self._ctf.evaluate(alpha, phi)

    def build(self, positions: Sequence[Sequence[float]] = None) -> Waves:
        """
        Build probe wave functions at the provided positions.

        :param positions: Positions of the probe wave functions. A single (x, y) pair or a sequence of pairs;
            defaults to the origin.
        :return: Probe wave functions as a Waves object.
        """

        self.grid.check_is_defined()
        self.accelerator.check_is_defined()
        xp = get_array_module_from_device(self._device)
        fft2 = get_device_function(xp, 'fft2')

        if positions is None:
            positions = xp.zeros((1, 2), dtype=xp.float32)
        else:
            positions = xp.array(positions, dtype=xp.float32)

        if len(positions.shape) == 1:
            positions = xp.expand_dims(positions, axis=0)

        array = fft2(self._evaluate_ctf(xp) *
                     self._fourier_translation_operator(positions),
                     overwrite_x=True)

        return Waves(array, extent=self.extent, energy=self.energy)

    def multislice(self,
                   positions: Sequence[Sequence[float]],
                   potential: AbstractPotential,
                   pbar=True) -> Waves:
        """
        Build probe wave functions at the provided positions and propagate them through the potential.

        :param positions: Positions of the probe wave functions.
        :param potential: The potential through which the probe wave functions are propagated.
        :param pbar: If true, display progress bars.
        :return: Probe exit wave functions as a Waves object.
        """

        self.grid.match(potential)
        return _multislice(self.build(positions), potential, None, pbar)

    def _generate_probes(self, scan: AbstractScan,
                         potential: Union[AbstractPotential,
                                          Atoms], max_batch: int):
        # Yield (start, end, exit_probes) for each batch of scan positions.
        for start, end, positions in scan.generate_positions(
                max_batch=max_batch):
            yield start, end, self.multislice(positions, potential, pbar=False)

    def _generate_tds_probes(self, scan, potential, max_batch, pbar):
        # Yield one probe generator per frozen phonon configuration of the potential.
        tds_bar = ProgressBar(total=len(potential.frozen_phonons),
                              desc='TDS',
                              disable=(not pbar)
                              or (len(potential.frozen_phonons) == 1))
        potential_pbar = ProgressBar(total=len(potential),
                                     desc='Potential',
                                     disable=not pbar)

        for potential_config in potential.generate_frozen_phonon_potentials(
                pbar=potential_pbar):
            yield self._generate_probes(scan, potential_config, max_batch)
            tds_bar.update(1)

        potential_pbar.close()
        tds_bar.refresh()
        tds_bar.close()

    def scan(self,
             scan: AbstractScan,
             detectors: Union[AbstractDetector, Sequence[AbstractDetector]],
             potential: Union[Atoms, AbstractPotential],
             max_batch: int = 1,
             pbar: bool = True) -> dict:
        """
        Raster scan the probe across the potential and record a measurement for each detector.

        :param scan: Scan object defining the positions of the probe wave functions.
        :param detectors: The detector, or detectors, recording the measurements.
        :param potential: The potential across which to scan the probe.
        :param max_batch: The probe batch size. Larger batches are faster, but require more memory.
        :param pbar: If true, display progress bars.
        :return: Dictionary of measurements with keys given by the detector.
        """

        # NOTE(review): `potential.grid` is read before any conversion, so an `Atoms` argument (allowed by the type
        # hint) presumably fails here unless it exposes a `grid` attribute — confirm.
        self.grid.match(potential.grid)
        self.grid.check_is_defined()

        if isinstance(detectors, AbstractDetector):
            detectors = [detectors]

        measurements = {}
        for detector in detectors:
            measurements[detector] = detector.allocate_measurement(
                self.grid, self.wavelength, scan)

        scan_bar = ProgressBar(total=len(scan), desc='Scan', disable=not pbar)

        if isinstance(potential, AbstractTDSPotentialBuilder):
            probe_generators = self._generate_tds_probes(
                scan, potential, max_batch, pbar)
        else:
            if isinstance(potential, AbstractPotentialBuilder):
                # Respect the caller's progress-bar choice (previously hard-coded to True).
                potential = potential.build(pbar=pbar)

            probe_generators = [
                self._generate_probes(scan, potential, max_batch)
            ]

        for probe_generator in probe_generators:
            scan_bar.reset()
            for start, end, exit_probes in probe_generator:
                for detector, measurement in measurements.items():
                    scan.insert_new_measurement(measurement, start, end,
                                                detector.detect(exit_probes))

                scan_bar.update(end - start)

            scan_bar.refresh()

        scan_bar.close()

        return measurements

    def show(self, profile: bool = False, **kwargs):
        """
        Show the intensity of the probe wave function centered in the simulation cell.

        :param profile: If true, show a 1D slice of the probe as a line profile.
        :param kwargs: Additional keyword arguments for the plot.show_line or plot.show_image functions.
            See the documentation of the respective function for a description.
        :return: Whatever the underlying plotting call returns.
        """

        measurement = self.build(
            (self.extent[0] / 2, self.extent[1] / 2)).intensity()

        if profile:
            array = measurement.array[0]
            # Take the row through the middle of the first axis, i.e. through the probe centre.
            array = array[array.shape[0] // 2, :]
            calibration = calibrations_from_grid(
                gpts=(self.grid.gpts[1], ),
                sampling=(self.grid.sampling[1], ),
                names=['x'])[0]
            return show_line(array, calibration, **kwargs)

        return measurement.show(**kwargs)

    def __copy__(self) -> 'Probe':
        """
        Return a copy of the probe with its own grid, CTF and accelerator.

        The copy's CTF cache is invalidated by changes to the copy's own CTF, grid and accelerator, and the device
        is preserved.
        """
        new_copy = self.__class__()
        new_copy._grid = copy(self.grid)
        new_copy._ctf = copy(self.ctf)
        # Share one accelerator between the probe and its CTF, as in __init__, so that changing the energy of the
        # copy also changes the energy seen by its CTF.
        new_copy._accelerator = copy(self._accelerator)
        new_copy._ctf._accelerator = new_copy._accelerator
        # The callbacks registered in __init__ point at the objects that were just replaced; register a fresh cache
        # on the new objects so that a stale CTF array is never served.
        new_copy._ctf_cache = Cache(1)
        new_copy._ctf.changed.register(cache_clear_callback(new_copy._ctf_cache))
        new_copy._grid.changed.register(cache_clear_callback(new_copy._ctf_cache))
        new_copy._accelerator.changed.register(
            cache_clear_callback(new_copy._ctf_cache))
        new_copy._device = self._device
        return new_copy