def test_ctf_raises():
    """Check the errors raised by CTF for bad parameters and missing energy."""
    # An unrecognized keyword is rejected when the CTF is constructed.
    with pytest.raises(ValueError) as bad_parameter:
        CTF(not_a_parameter=10)
    assert (str(bad_parameter.value)
            == 'not_a_parameter not a recognized parameter')

    transfer_function = CTF()

    # Evaluating before an energy has been set must fail.
    with pytest.raises(RuntimeError) as missing_energy:
        transfer_function.evaluate(0, 0)
    assert str(missing_energy.value) == 'energy is not defined'

    # Once the energy is defined the same call must go through.
    transfer_function.energy = 200e3
    transfer_function.evaluate(0, 0)
def apply_ctf(self, ctf: CTF = None, **kwargs):
    """
    Convolve the wave function array with a contrast transfer function.

    :param ctf: Contrast Transfer Function object to be applied. If None, a
        CTF is created from the keyword arguments.
    :param kwargs: Aberration coefficients used to create the CTF when ``ctf``
        is not given.
    :return: A new object of the same class holding the convolved array, with
        the extent and energy of this object.
    """
    xp = get_array_module(self.array)
    fft2_convolve = get_device_function(xp, 'fft2_convolve')

    if ctf is None:
        ctf = CTF(**kwargs)

    # Match the CTF's accelerator against the accelerator of these waves
    # before the kernel is evaluated.
    ctf.accelerator.match(self.accelerator)

    # Scale the spatial frequencies by the wavelength and convert them to the
    # polar grid on which the CTF is evaluated.
    kx, ky = spatial_frequencies(self.grid.gpts, self.grid.sampling)
    scaled_kx = xp.asarray(kx * self.wavelength)
    scaled_ky = xp.asarray(ky * self.wavelength)
    alpha, phi = polargrid(scaled_kx, scaled_ky)

    kernel = ctf.evaluate(alpha, phi)
    convolved = fft2_convolve(self.array, kernel, overwrite_x=False)
    return self.__class__(convolved, extent=self.extent, energy=self.energy)
class SMatrix(HasGridAndAcceleratorMixin):
    """
    Scattering matrix object.

    The scattering matrix object represents a plane wave expansion of a probe.

    :param array: The array representation of the scattering matrix. The number
        of grid points of the wave functions is taken from ``array.shape[1:]``.
    :param expansion_cutoff: The angular cutoff of the plane wave expansion [mrad].
    :param interpolation: Interpolation factor.
    :param k: The spatial frequencies of each plane in the plane wave expansion.
    :param ctf: The probe contrast transfer function. If None, a CTF with
        ``semiangle_cutoff=expansion_cutoff`` and ``rolloff=.1`` is used.
    :param extent: Lateral extent of wave functions [Å].
    :param sampling: Lateral sampling of wave functions [1 / Å].
    :param energy: Electron energy [eV].
    :param device: Device identifier stored on the instance.
    """

    def __init__(self,
                 array: np.ndarray,
                 expansion_cutoff: float,
                 interpolation: int,
                 k: Tuple[np.ndarray, np.ndarray],
                 ctf: CTF = None,
                 extent: Union[float, Sequence[float]] = None,
                 sampling: Union[float, Sequence[float]] = None,
                 energy: float = None,
                 device='cpu'):

        self._array = array
        self._interpolation = interpolation
        self._expansion_cutoff = expansion_cutoff
        self._k = k

        # The grid points are fixed by the array, hence the locked gpts.
        self._grid = Grid(extent=extent, gpts=array.shape[1:],
                          sampling=sampling, lock_gpts=True)
        self._accelerator = Accelerator(energy=energy)

        if ctf is None:
            ctf = CTF(semiangle_cutoff=expansion_cutoff, rolloff=.1)

        self.set_ctf(ctf)
        self._device = device

    def set_ctf(self, ctf: CTF = None, **kwargs):
        """
        Set the contrast transfer function.

        The CTF stored on the scattering matrix is a copy of the one provided,
        and it is made to share the accelerator of the scattering matrix.

        :param ctf: New contrast transfer function.
        :param kwargs: Provide the contrast transfer function as keyword
            arguments. Only used when ``ctf`` is None.
        """
        if ctf is None:
            self._ctf = CTF(**kwargs)
        else:
            self._ctf = copy(ctf)

        self._ctf._accelerator = self._accelerator

    @property
    def ctf(self) -> CTF:
        """
        Probe contrast transfer function.
        """
        return self._ctf

    @property
    def array(self) -> np.ndarray:
        """
        Array representing the scattering matrix.
        """
        return self._array

    @property
    def k(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        The spatial frequencies of each wave in the plane wave expansion.
        """
        return self._k

    @property
    def interpolation(self) -> int:
        """
        Interpolation factor.
        """
        return self._interpolation

    @property
    def interpolated_grid(self) -> Grid:
        """
        The grid of the interpolated scattering matrix.

        A new Grid is created on every access, with the number of grid points
        floor-divided by the interpolation factor and the same sampling.
        """
        interpolated_gpts = tuple(n // self.interpolation for n in self.gpts)
        return Grid(gpts=interpolated_gpts, sampling=self.sampling,
                    lock_gpts=True)

    def _evaluate_ctf(self):
        """
        Evaluate the CTF at the spatial frequencies of the plane wave expansion.
        """
        xp = get_array_module(self._array)
        alpha = xp.sqrt(self.k[0] ** 2 + self.k[1] ** 2) * self.wavelength
        # NOTE(review): the azimuth is computed as arctan2(k[0], k[1]); confirm
        # this argument order matches the convention expected by CTF.evaluate.
        phi = xp.arctan2(self.k[0], self.k[1])
        return self._ctf.evaluate(alpha, phi)

    def __len__(self) -> int:
        return len(self._array)

    def _generate_partial(self, max_batch: int = None, pbar: bool = True):
        """
        Generate PartialSMatrix objects covering the plane waves in batches.

        :param max_batch: Maximum number of plane waves per batch. If None,
            a single batch covers the whole expansion.
        :param pbar: If true, display a progress bar (only when there is more
            than one batch).
        """
        if max_batch is None:
            n_batches = 1
        else:
            # Ceiling division: number of batches of at most max_batch waves.
            n_batches = (len(self) + (-len(self) % max_batch)) // max_batch

        batch_pbar = ProgressBar(total=len(self), desc='Batches',
                                 disable=(not pbar) or (n_batches == 1))
        batch_sizes = split_integer(len(self), n_batches)

        N = 0
        for batch_size in batch_sizes:
            yield PartialSMatrix(N, N + batch_size, self)
            N += batch_size
            batch_pbar.update(batch_size)

        batch_pbar.refresh()
        batch_pbar.close()

    def multislice(self, potential: AbstractPotential, max_batch=None,
                   pbar: bool = True):
        """
        Propagate the scattering matrix through the provided potential.

        :param potential: The potential to propagate the scattering matrix
            through.
        :param max_batch: The plane wave batch size. Larger batches are faster,
            but require more memory.
        :param pbar: If true, display progress bars. A ProgressBar object may
            be passed instead of a bool, in which case it is used directly.
        :return: This scattering matrix object (``self``).
        """
        propagator = FresnelPropagator()

        if isinstance(pbar, bool):
            pbar = ProgressBar(total=len(potential), desc='Multislice',
                               disable=not pbar)

        for partial_s_matrix in self._generate_partial(max_batch):
            _multislice(partial_s_matrix, potential, propagator=propagator,
                        pbar=pbar)

        pbar.refresh()
        return self

    def collapse(self, positions: Sequence[float],
                 max_batch_expansion: int = None) -> Waves:
        """
        Collapse the scattering matrix to probe wave functions centered on the
        provided positions.

        :param positions: The positions of the probe wave functions, given as
            a single (x, y) pair or as an array-like of shape (n, 2).
        :param max_batch_expansion: The maximum number of plane waves the
            reduction is applied to simultaneously.
        :return: Probe wave functions for the provided positions.
        :raises RuntimeError: If ``positions`` does not have shape (2,) or
            (n, 2).
        """
        xp = get_array_module(self.array)
        complex_exponential = get_device_function(xp, 'complex_exponential')
        scale_reduce = get_device_function(xp, 'scale_reduce')
        windowed_scale_reduce = get_device_function(xp,
                                                    'windowed_scale_reduce')

        positions = np.array(positions, dtype=xp.float32)

        if positions.shape == (2,):
            positions = positions[None]
        elif (len(positions.shape) != 2) or (positions.shape[-1] != 2):
            raise RuntimeError(
                'positions must have shape (2,) or (n, 2), got shape {}'
                .format(positions.shape))

        interpolated_grid = self.interpolated_grid
        W = np.floor_divide(interpolated_grid.gpts, 2)

        # Lower corner of the window around each position, wrapped onto the
        # full grid. The builtin ``int`` replaces the ``np.int`` alias, which
        # was removed from NumPy.
        corners = np.rint(positions / self.sampling - W).astype(int)
        corners = np.remainder(corners, np.asarray(self.gpts))
        corners = xp.asarray(corners)

        window = xp.zeros((len(positions),
                           interpolated_grid.gpts[0],
                           interpolated_grid.gpts[1]), dtype=xp.complex64)

        positions = xp.asarray(positions)

        # Phase factor per (position, plane wave) pair that shifts the probe
        # to each position.
        translation = (
            complex_exponential(
                2. * np.pi * self.k[0][None] * positions[:, 0, None])
            * complex_exponential(
                2. * np.pi * self.k[1][None] * positions[:, 1, None]))

        coefficients = translation * self._evaluate_ctf()

        for partial_s_matrix in self._generate_partial(max_batch_expansion,
                                                       pbar=False):
            partial_coefficients = coefficients[
                :, partial_s_matrix.start:partial_s_matrix.stop]

            if self.interpolation > 1:
                windowed_scale_reduce(window, partial_s_matrix.array, corners,
                                      partial_coefficients)
            else:
                scale_reduce(window, partial_s_matrix.array,
                             partial_coefficients)

        return Waves(window, extent=interpolated_grid.extent,
                     energy=self.energy)

    def _generate_probes(self, scan: AbstractScan, max_batch_probes,
                         max_batch_expansion):
        """
        Yield (start, end, probes) for each batch of scan positions, where the
        probes are obtained by collapsing the scattering matrix.
        """
        for start, end, positions in scan.generate_positions(
                max_batch=max_batch_probes):
            yield start, end, self.collapse(
                positions, max_batch_expansion=max_batch_expansion)

    def scan(self,
             scan: AbstractScan,
             detectors: Union[AbstractDetector, Sequence[AbstractDetector]],
             max_batch_probes=1,
             max_batch_expansion=None,
             pbar: Union[ProgressBar, bool] = True):
        """
        Raster scan the probe and record a measurement for each detector.

        The probes are generated by collapsing this scattering matrix at the
        scan positions.

        :param scan: Scan object defining the positions of the probe wave
            functions.
        :param detectors: The detector, or sequence of detectors, recording
            the measurements.
        :param max_batch_probes: The probe batch size. Larger batches are
            faster, but require more memory.
        :param max_batch_expansion: The expansion plane wave batch size.
        :param pbar: If true, display progress bars. A ProgressBar object may
            be passed instead of a bool, in which case it is used directly.
        :return: Dictionary of measurements with keys given by the detector.
        """
        # Accept a single detector as well as a sequence of detectors.
        if isinstance(detectors, AbstractDetector):
            detectors = [detectors]

        measurements = {}
        for detector in detectors:
            measurements[detector] = detector.allocate_measurement(
                self.interpolated_grid, self.wavelength, scan)

        if isinstance(pbar, bool):
            pbar = ProgressBar(total=len(scan), desc='Scan', disable=not pbar)

        for start, end, exit_probes in self._generate_probes(
                scan, max_batch_probes, max_batch_expansion):
            for detector, measurement in measurements.items():
                scan.insert_new_measurement(measurement, start, end,
                                            detector.detect(exit_probes))
            pbar.update(end - start)

        pbar.refresh()
        pbar.close()
        return measurements
class Probe(HasGridAndAcceleratorMixin):
    """
    Probe wave function object.

    The probe object can represent a stack of electron probe wave functions
    for simulating scanning transmission electron microscopy.

    :param semiangle_cutoff: Convergence semi-angle [mrad].
    :param rolloff: Softens the cutoff. A value of 0 gives a hard cutoff, while 1 gives the softest possible cutoff.
    :param focal_spread: The focal spread due to, among other factors, chromatic aberrations and lens current instabilities.
    :param angular_spread: Angular spread parameter, forwarded to the CTF.
    :param ctf_parameters: The parameters describing the phase aberrations using polar notation or an alias.
        See the documentation of the CTF object for a description.
        Convert from cartesian to polar parameters using `transfer.cartesian2polar`.
    :param extent: Lateral extent of wave functions [Å].
    :param gpts: Number of grid points describing the wave functions.
    :param sampling: Lateral sampling of wave functions [1 / Å].
    :param energy: Electron energy [eV].
    :param device: The probe wave functions will be built on this device.
    :param kwargs: Provide the aberration coefficients as keyword arguments.
    """

    def __init__(self,
                 semiangle_cutoff: float = np.inf,
                 rolloff: float = 0.1,
                 focal_spread: float = 0.,
                 angular_spread: float = 0.,
                 ctf_parameters: dict = None,
                 extent: Union[float, Sequence[float]] = None,
                 gpts: Union[int, Sequence[int]] = None,
                 sampling: Union[float, Sequence[float]] = None,
                 energy: float = None,
                 device='cpu',
                 **kwargs):

        self._ctf = CTF(semiangle_cutoff=semiangle_cutoff, rolloff=rolloff,
                        focal_spread=focal_spread,
                        angular_spread=angular_spread,
                        parameters=ctf_parameters, energy=energy, **kwargs)
        # The probe and its CTF share one accelerator object.
        self._accelerator = self._ctf._accelerator
        self._grid = Grid(extent=extent, gpts=gpts, sampling=sampling)

        # The evaluated CTF is cached; the cache is cleared whenever the CTF,
        # the grid or the accelerator reports a change.
        self._ctf_cache = Cache(1)
        self._ctf.changed.register(cache_clear_callback(self._ctf_cache))
        self._grid.changed.register(cache_clear_callback(self._ctf_cache))
        self._accelerator.changed.register(
            cache_clear_callback(self._ctf_cache))

        self._device = device

    @property
    def ctf(self) -> CTF:
        """
        Probe contrast transfer function.
        """
        return self._ctf

    def _fourier_translation_operator(self, positions):
        """
        Return the Fourier-space phase factors for the given positions.

        :param positions: Array of positions; column 0 is used as x and
            column 1 as y.
        :return: Product of the complex exponentials of 2*pi*kx*x and
            2*pi*ky*y, broadcast over (position, kx, ky).
        """
        xp = get_array_module(positions)
        complex_exponential = get_device_function(xp, 'complex_exponential')

        kx, ky = spatial_frequencies(self.grid.gpts, self.grid.sampling)
        # Reshape so that positions, kx and ky broadcast along separate axes.
        kx = xp.asarray(kx.reshape((1, -1, 1)))
        ky = xp.asarray(ky.reshape((1, 1, -1)))

        positions = xp.asarray(positions)
        x = positions[:, 0].reshape((-1,) + (1, 1))
        y = positions[:, 1].reshape((-1,) + (1, 1))

        return (complex_exponential(2 * np.pi * kx * x)
                * complex_exponential(2 * np.pi * ky * y))

    @cached_method('_ctf_cache')
    def _evaluate_ctf(self, xp):
        """
        Evaluate the CTF on the polar grid of the probe, using array module xp.
        """
        kx, ky = spatial_frequencies(self.grid.gpts, self.grid.sampling)
        alpha, phi = polargrid(xp.asarray(kx * self.wavelength),
                               xp.asarray(ky * self.wavelength))
        return self._ctf.evaluate(alpha, phi)

    def build(self, positions: Sequence[Sequence[float]] = None) -> Waves:
        """
        Build probe wave functions at the provided positions.

        :param positions: Positions of the probe wave functions. If None, a
            single probe is built at position (0, 0). A single (x, y) pair is
            also accepted.
        :return: Probe wave functions as a Waves object.
        """
        self.grid.check_is_defined()
        self.accelerator.check_is_defined()

        xp = get_array_module_from_device(self._device)
        fft2 = get_device_function(xp, 'fft2')

        if positions is None:
            positions = xp.zeros((1, 2), dtype=xp.float32)
        else:
            positions = xp.array(positions, dtype=xp.float32)

        # Promote a single (x, y) pair to a batch of one position.
        if len(positions.shape) == 1:
            positions = xp.expand_dims(positions, axis=0)

        array = fft2(
            self._evaluate_ctf(xp)
            * self._fourier_translation_operator(positions),
            overwrite_x=True)

        return Waves(array, extent=self.extent, energy=self.energy)

    def multislice(self, positions: Sequence[Sequence[float]],
                   potential: AbstractPotential, pbar=True) -> Waves:
        """
        Build probe wave functions at the provided positions and propagate
        them through the potential.

        :param positions: Positions of the probe wave functions.
        :param potential: The potential to propagate the probes through.
        :param pbar: If true, display progress bars.
        :return: Probe exit wave functions as a Waves object.
        """
        self.grid.match(potential)
        return _multislice(self.build(positions), potential, None, pbar)

    def _generate_probes(self, scan: AbstractScan,
                         potential: Union[AbstractPotential, Atoms],
                         max_batch: int):
        """
        Yield (start, end, exit probes) for each batch of scan positions.
        """
        for start, end, positions in scan.generate_positions(
                max_batch=max_batch):
            yield start, end, self.multislice(positions, potential,
                                              pbar=False)

    def _generate_tds_probes(self, scan, potential, max_batch, pbar):
        """
        Yield one probe generator per frozen phonon configuration.
        """
        tds_bar = ProgressBar(
            total=len(potential.frozen_phonons), desc='TDS',
            disable=(not pbar) or (len(potential.frozen_phonons) == 1))
        potential_pbar = ProgressBar(total=len(potential), desc='Potential',
                                     disable=not pbar)

        for potential_config in potential.generate_frozen_phonon_potentials(
                pbar=potential_pbar):
            yield self._generate_probes(scan, potential_config, max_batch)
            tds_bar.update(1)

        potential_pbar.close()
        tds_bar.refresh()
        tds_bar.close()

    def scan(self,
             scan: AbstractScan,
             detectors: Union[AbstractDetector, Sequence[AbstractDetector]],
             potential: Union[Atoms, AbstractPotential],
             max_batch: int = 1,
             pbar: bool = True) -> dict:
        """
        Raster scan the probe across the potential and record a measurement
        for each detector.

        :param scan: Scan object defining the positions of the probe wave functions.
        :param detectors: The detector, or sequence of detectors, recording the measurements.
        :param potential: The potential across which to scan the probe.
        :param max_batch: The probe batch size. Larger batches are faster, but require more memory.
        :param pbar: If true, display progress bars.
        :return: Dictionary of measurements with keys given by the detector.
        """
        self.grid.match(potential.grid)
        self.grid.check_is_defined()

        if isinstance(detectors, AbstractDetector):
            detectors = [detectors]

        measurements = {}
        for detector in detectors:
            measurements[detector] = detector.allocate_measurement(
                self.grid, self.wavelength, scan)

        scan_bar = ProgressBar(total=len(scan), desc='Scan', disable=not pbar)

        if isinstance(potential, AbstractTDSPotentialBuilder):
            probe_generators = self._generate_tds_probes(scan, potential,
                                                         max_batch, pbar)
        else:
            if isinstance(potential, AbstractPotentialBuilder):
                # Honor the caller's progress bar setting instead of always
                # showing a bar while the potential is built.
                potential = potential.build(pbar=pbar)

            probe_generators = [
                self._generate_probes(scan, potential, max_batch)
            ]

        for probe_generator in probe_generators:
            scan_bar.reset()

            for start, end, exit_probes in probe_generator:
                for detector, measurement in measurements.items():
                    scan.insert_new_measurement(measurement, start, end,
                                                detector.detect(exit_probes))
                scan_bar.update(end - start)

            scan_bar.refresh()

        scan_bar.close()
        return measurements

    def show(self, profile: bool = False, **kwargs):
        """
        Show the probe wave function.

        The probe is built at the center of the extent and its intensity is
        displayed.

        :param profile: If true, show a 1D slice of the probe as a line profile.
        :param kwargs: Additional keyword arguments for the plot.show_line or plot.show_image functions.
            See the documentation of the respective function for a description.
        :return: The result of ``measurement.show`` when ``profile`` is false;
            None when a line profile is shown.
        """
        measurement = self.build(
            (self.extent[0] / 2, self.extent[1] / 2)).intensity()

        if profile:
            array = measurement.array[0]
            # Take the row through the middle of the first axis.
            array = array[array.shape[0] // 2, :]
            calibration = calibrations_from_grid(
                gpts=(self.grid.gpts[1],),
                sampling=(self.grid.sampling[1],),
                names=['x'])[0]
            show_line(array, calibration, **kwargs)
        else:
            return measurement.show(**kwargs)

    def __copy__(self) -> 'Probe':
        """
        Return a copy of the probe with its own grid, CTF and accelerator.

        As in ``__init__``, the copied probe and its CTF share one accelerator,
        and the CTF cache of the copy is cleared when any of its own CTF, grid
        or accelerator changes.
        """
        new_copy = self.__class__()
        new_copy._device = self._device
        new_copy._grid = copy(self.grid)
        new_copy._ctf = copy(self.ctf)

        # Probe and CTF must use the same accelerator object; otherwise a
        # change of energy on the copy would not reach its CTF.
        new_copy._accelerator = copy(self._accelerator)
        new_copy._ctf._accelerator = new_copy._accelerator

        # The callbacks registered by __init__ are attached to the default
        # objects that were just replaced, so register them on the copies.
        for observed in (new_copy._ctf, new_copy._grid,
                         new_copy._accelerator):
            observed.changed.register(
                cache_clear_callback(new_copy._ctf_cache))

        return new_copy