def test_energy():
    """An Accelerator reports its energy and keeps the wavelength in sync when the energy changes."""
    accelerator = Accelerator(energy=300e3)

    # Initial state: energy as given, wavelength derived from it.
    assert accelerator.energy == 300e3
    expected_wavelength = energy2wavelength(300e3)
    assert np.isclose(accelerator.wavelength, expected_wavelength)

    # After reassigning the energy the wavelength must follow.
    accelerator.energy = 200e3
    expected_wavelength = energy2wavelength(200e3)
    assert np.isclose(accelerator.wavelength, expected_wavelength)
def __init__(self,
             transitions,
             atoms=None,
             slice_thickness=None,
             gpts: Union[int, Sequence[float]] = None,
             sampling: Union[float, Sequence[float]] = None,
             energy: float = None,
             min_contrast=.95):
    """
    Build the grid, accelerator and sliced atoms used for the given transitions.

    Parameters
    ----------
    transitions:
        A single SubshellTransitions / SubshellTransitionsArrays object, or a collection of them. A single object is
        wrapped in a list.
    atoms: optional
        Atoms passed to ``SlicedAtoms`` and stored on ``self.atoms``.
    slice_thickness: optional
        Forwarded to ``SlicedAtoms`` as ``slice_thicknesses``.
    gpts, sampling: optional
        Forwarded to ``Grid``.
    energy: float, optional
        Forwarded to ``Accelerator``.
    min_contrast: float
        NOTE(review): accepted but not used by this constructor — confirm whether it should be stored.
    """
    single_transition_types = (SubshellTransitions, SubshellTransitionsArrays)
    if isinstance(transitions, single_transition_types):
        # Normalize a lone transitions object to a one-element list.
        transitions = [transitions]

    # Assignment order is kept as-is: ``self.atoms`` may be a property that relies on earlier state.
    self._slice_thickness = slice_thickness
    self._grid = Grid(gpts=gpts, sampling=sampling)
    self.atoms = atoms
    self._transitions = transitions
    self._accelerator = Accelerator(energy=energy)

    self._sliced_atoms = SlicedAtoms(atoms, slice_thicknesses=self._slice_thickness)
    # Single-entry cache for computed potentials.
    self._potentials_cache = Cache(1)
def __init__(self,
             array: np.ndarray,
             slice_thicknesses: Union[float, Sequence[float]],
             extent: Union[float, Sequence[float]] = None,
             sampling: Union[float, Sequence[float]] = None,
             energy: float = None):
    """
    Attach an ``Accelerator`` for ``energy`` and delegate the rest to the base-class initializer.

    Parameters
    ----------
    array: np.ndarray
        Forwarded unchanged to the base initializer.
    slice_thicknesses: float or sequence of float
        Forwarded unchanged to the base initializer.
    extent, sampling: optional
        Forwarded unchanged to the base initializer.
    energy: float, optional
        Used to construct this object's ``Accelerator``.
    """
    # The accelerator is created before the base initializer runs, as in the original ordering.
    self._accelerator = Accelerator(energy=energy)
    super().__init__(array, slice_thicknesses, extent, sampling)
def __init__(self,
             semiangle_cutoff: float = np.inf,
             rolloff: float = 0.1,
             focal_spread: float = 0.,
             angular_spread: float = 0.,
             gaussian_spread: float = 0.,
             energy: float = None,
             parameters: Mapping[str, float] = None,
             **kwargs):
    """
    Initialize the aperture/envelope settings, the accelerator and the aberration parameters.

    Parameters
    ----------
    semiangle_cutoff, rolloff, focal_spread, angular_spread, gaussian_spread: float
        Stored on the corresponding private attributes.
    energy: float, optional
        Used to construct the ``Accelerator``; its ``changed`` event is forwarded to ``self.changed``.
    parameters: Mapping[str, float], optional
        Aberration symbols (or aliases) mapped to values. The mapping is copied and never modified.
    kwargs:
        Additional aberration symbols/aliases; these override entries in ``parameters``.

    Raises
    ------
    ValueError
        If a keyword argument is neither a polar symbol nor a known alias.
    """
    for key in kwargs.keys():
        if (key not in polar_symbols) and (key not in polar_aliases.keys()):
            raise ValueError('{} not a recognized parameter'.format(key))

    self.changed = Event()

    self._accelerator = Accelerator(energy=energy)
    # Re-broadcast accelerator (energy) changes through this object's own event.
    self._accelerator.changed.register(self.changed.notify)

    self._semiangle_cutoff = semiangle_cutoff
    self._rolloff = rolloff
    self._focal_spread = focal_spread
    self._angular_spread = angular_spread
    self._gaussian_spread = gaussian_spread
    # Every polar aberration coefficient starts at zero.
    self._parameters = dict(zip(polar_symbols, [0.] * len(polar_symbols)))

    # Merge into a private copy: updating the caller's mapping in place would leak kwargs into it
    # (and would fail outright for read-only Mappings).
    parameters = {} if parameters is None else dict(parameters)
    parameters.update(kwargs)
    self.set_parameters(parameters)

    def parametrization_property(key):
        # Build a property that reads/writes self._parameters[key] and notifies listeners on assignment.
        def getter(self):
            return self._parameters[key]

        def setter(self, value):
            old = getattr(self, key)
            self._parameters[key] = value
            self.changed.notify(**{
                'notifier': self,
                'property_name': key,
                'change': old != value
            })

        return property(getter, setter)

    for symbol in polar_symbols:
        setattr(self.__class__, symbol, parametrization_property(symbol))

    # 'defocus' is skipped here: it is an alias with a sign flip (-C10), so it cannot map directly onto C10.
    for key, value in polar_aliases.items():
        if key != 'defocus':
            setattr(self.__class__, key, parametrization_property(value))
def __init__(self,
             outer_angle,
             energy=None,
             inner_angle=0.,
             num_radials=0,
             cross=0.,
             rotation=0.):
    """
    Store the angular layout settings and create an ``Accelerator`` for ``energy``.

    Parameters
    ----------
    outer_angle:
        Outer angle (units not stated here — TODO confirm, likely mrad).
    energy: float, optional
        Used to construct the ``Accelerator``.
    inner_angle: float
        Inner angle; defaults to 0.
    num_radials: int
        Number of radial divisions; defaults to 0.
    cross: float
        Defaults to 0.
    rotation: float
        Defaults to 0.
    """
    # Angular extent.
    self._outer_angle = outer_angle
    self._inner_angle = inner_angle

    # Segmentation settings.
    self._num_radials = num_radials
    self._rotation = rotation
    self._cross = cross

    self._accelerator = Accelerator(energy=energy)
class CTF(HasAcceleratorMixin, HasEventMixin):
    """
    Contrast transfer function object

    The Contrast Transfer Function (CTF) describes the aberrations of the objective lens in HRTEM and specifies how the
    condenser system shapes the probe in STEM.

    abTEM implements phase aberrations up to 5th order using polar coefficients. See Eq. 2.22 in the reference [1]_.
    Cartesian coefficients can be converted to polar using the utility function abtem.transfer.cartesian2polar.

    Partial coherence is included as an envelope in the quasi-coherent approximation. See Chapter 3.2 in reference [1]_.

    For a more detailed discussion with examples, see our `walkthrough
    <https://abtem.readthedocs.io/en/latest/walkthrough/05_contrast_transfer_function.html>`_.

    Parameters
    ----------
    semiangle_cutoff: float
        The semiangle cutoff describes the sharp Fourier space cutoff due to the objective aperture [mrad].
    rolloff: float
        Tapers the cutoff edge over the given angular range [mrad].
    focal_spread: float
        The 1/e width of the focal spread due to chromatic aberration and lens current instability [Å].
    angular_spread: float
        The 1/e width of the angular deviations due to source size [mrad].
    gaussian_spread: float
        The 1/e width of image deflections due to vibrations and thermal magnetic noise [Å].
    energy: float
        The electron energy of the wave functions this contrast transfer function will be applied to [eV].
    parameters: dict
        Mapping from aberration symbols to their corresponding values. All aberration magnitudes should be given in Å
        and angles should be given in radians. The mapping is copied and never modified.
    aperture: optional
        If given, its ``evaluate(alpha, phi)`` replaces the built-in aperture function, and its accelerator is matched
        to this CTF.
    kwargs:
        Provide the aberration coefficients as keyword arguments.

    References
    ----------
    .. [1] Kirkland, E. J. (2010). Advanced Computing in Electron Microscopy (2nd ed.). Springer.

    """

    def __init__(self,
                 semiangle_cutoff: float = np.inf,
                 rolloff: float = 2,
                 focal_spread: float = 0.,
                 angular_spread: float = 0.,
                 gaussian_spread: float = 0.,
                 energy: float = None,
                 parameters: Mapping[str, float] = None,
                 aperture=None,
                 **kwargs):

        for key in kwargs.keys():
            if (key not in polar_symbols) and (key not in polar_aliases.keys()):
                raise ValueError('{} not a recognized parameter'.format(key))

        self._event = Event()

        self._accelerator = Accelerator(energy=energy)
        # Energy changes are re-broadcast through this CTF's own event.
        self._accelerator.observe(self.event.notify)

        self._semiangle_cutoff = semiangle_cutoff
        self._rolloff = rolloff
        self._focal_spread = focal_spread
        self._angular_spread = angular_spread
        self._gaussian_spread = gaussian_spread
        # Every polar aberration coefficient starts at zero.
        self._parameters = dict(zip(polar_symbols, [0.] * len(polar_symbols)))

        self._aperture = aperture
        if self._aperture is not None:
            self._aperture.accelerator.match(self)

        # Merge into a private copy: updating the caller's mapping in place would leak kwargs into it
        # (and would fail outright for read-only Mappings).
        parameters = {} if parameters is None else dict(parameters)
        parameters.update(kwargs)
        self.set_parameters(parameters)

        def parametrization_property(key):
            # Build a property that reads/writes self._parameters[key] and notifies listeners on assignment.
            def getter(self):
                return self._parameters[key]

            def setter(self, value):
                old = getattr(self, key)
                self._parameters[key] = value
                self.event.notify({
                    'notifier': self,
                    'name': key,
                    'change': old != value
                })

            return property(getter, setter)

        for symbol in polar_symbols:
            setattr(self.__class__, symbol, parametrization_property(symbol))

        # 'defocus' is defined explicitly below because it is -C10, not a plain alias.
        for key, value in polar_aliases.items():
            if key != 'defocus':
                setattr(self.__class__, key, parametrization_property(value))

    @property
    def nyquist_sampling(self):
        """Real-space sampling needed to resolve the semiangle cutoff [Å, assuming the wavelength is in Å]."""
        return 1 / (4 * self.semiangle_cutoff / self.wavelength * 1e-3)

    @property
    def parameters(self):
        """The parameters."""
        return self._parameters

    @property
    def defocus(self) -> float:
        """The defocus [Å]."""
        return -self._parameters['C10']

    @defocus.setter
    def defocus(self, value: float):
        self.C10 = -value

    @property
    def semiangle_cutoff(self) -> float:
        """The semi-angle cutoff [mrad]."""
        return self._semiangle_cutoff

    @semiangle_cutoff.setter
    @watched_property('_event')
    def semiangle_cutoff(self, value: float):
        self._semiangle_cutoff = value

    @property
    def rolloff(self) -> float:
        """The angular width of the soft taper at the cutoff edge [mrad]."""
        return self._rolloff

    @rolloff.setter
    @watched_property('_event')
    def rolloff(self, value: float):
        self._rolloff = value

    @property
    def focal_spread(self) -> float:
        """The focal spread [Å]."""
        return self._focal_spread

    @focal_spread.setter
    @watched_property('_event')
    def focal_spread(self, value: float):
        self._focal_spread = value

    @property
    def angular_spread(self) -> float:
        """The angular spread [mrad]."""
        return self._angular_spread

    @angular_spread.setter
    @watched_property('_event')
    def angular_spread(self, value: float):
        self._angular_spread = value

    @property
    def gaussian_spread(self) -> float:
        """The Gaussian spread [Å]."""
        return self._gaussian_spread

    @gaussian_spread.setter
    @watched_property('_event')
    def gaussian_spread(self, value: float):
        self._gaussian_spread = value

    @watched_method('_event')
    def set_parameters(self, parameters: dict):
        """
        Set the phase of the phase aberration.

        Parameters
        ----------
        parameters: dict
            Mapping from aberration symbols to their corresponding values.

        Raises
        ------
        ValueError
            If a key is neither a polar symbol, 'defocus', nor a known alias.
        """

        for symbol, value in parameters.items():
            if symbol in self._parameters.keys():
                self._parameters[symbol] = value

            elif symbol == 'defocus':
                # defocus is the negative of C10
                self._parameters[polar_aliases[symbol]] = -value

            elif symbol in polar_aliases.keys():
                self._parameters[polar_aliases[symbol]] = value

            else:
                raise ValueError('{} not a recognized parameter'.format(symbol))

        return parameters

    def evaluate_aperture(self,
                          alpha: Union[float, np.ndarray],
                          phi: Union[float, np.ndarray] = None) -> Union[float, np.ndarray]:
        """
        Evaluate the aperture function at scattering angles ``alpha`` [rad].

        Delegates to the custom aperture if one was given. Otherwise it returns ones for an infinite cutoff, a cosine
        taper of width ``rolloff`` ending at the cutoff when ``rolloff > 0``, or a hard cutoff.
        """
        if self._aperture is not None:
            return self._aperture.evaluate(alpha, phi)

        xp = get_array_module(alpha)
        semiangle_cutoff = self.semiangle_cutoff / 1000

        if self.semiangle_cutoff == xp.inf:
            return xp.ones_like(alpha)

        if self.rolloff > 0.:
            # rolloff is an absolute angular width in mrad (not a fraction of the cutoff).
            rolloff = self.rolloff / 1000.
            array = .5 * (1 + xp.cos(np.pi * (alpha - semiangle_cutoff + rolloff) / rolloff))
            array[alpha > semiangle_cutoff] = 0.
            array = xp.where(alpha > semiangle_cutoff - rolloff, array, xp.ones_like(alpha, dtype=xp.float32))
        else:
            array = xp.array(alpha < semiangle_cutoff).astype(xp.float32)
        return array

    def evaluate_temporal_envelope(self, alpha: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """Temporal-coherence (focal spread) envelope at angles ``alpha`` [rad]."""
        xp = get_array_module(alpha)
        return xp.exp(- (.5 * xp.pi / self.wavelength * self.focal_spread * alpha ** 2) ** 2).astype(xp.float32)

    def evaluate_gaussian_envelope(self, alpha: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """Gaussian envelope from ``gaussian_spread`` at angles ``alpha`` [rad]."""
        xp = get_array_module(alpha)
        return xp.exp(- .5 * self.gaussian_spread ** 2 * alpha ** 2 / self.wavelength ** 2)

    def evaluate_spatial_envelope(self,
                                  alpha: Union[float, np.ndarray],
                                  phi: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """
        Spatial-coherence (angular spread) envelope.

        It is built from the radial and azimuthal derivatives of the aberration phase at (alpha, phi).
        """
        xp = get_array_module(alpha)
        p = self.parameters

        dchi_dk = 2 * xp.pi / self.wavelength * (
                (p['C12'] * xp.cos(2. * (phi - p['phi12'])) + p['C10']) * alpha +
                (p['C23'] * xp.cos(3. * (phi - p['phi23'])) +
                 p['C21'] * xp.cos(1. * (phi - p['phi21']))) * alpha ** 2 +
                (p['C34'] * xp.cos(4. * (phi - p['phi34'])) +
                 p['C32'] * xp.cos(2. * (phi - p['phi32'])) + p['C30']) * alpha ** 3 +
                (p['C45'] * xp.cos(5. * (phi - p['phi45'])) +
                 p['C43'] * xp.cos(3. * (phi - p['phi43'])) +
                 p['C41'] * xp.cos(1. * (phi - p['phi41']))) * alpha ** 4 +
                (p['C56'] * xp.cos(6. * (phi - p['phi56'])) +
                 p['C54'] * xp.cos(4. * (phi - p['phi54'])) +
                 p['C52'] * xp.cos(2. * (phi - p['phi52'])) + p['C50']) * alpha ** 5)

        dchi_dphi = -2 * xp.pi / self.wavelength * (
                1 / 2. * (2. * p['C12'] * xp.sin(2. * (phi - p['phi12']))) * alpha +
                1 / 3. * (3. * p['C23'] * xp.sin(3. * (phi - p['phi23'])) +
                          1. * p['C21'] * xp.sin(1. * (phi - p['phi21']))) * alpha ** 2 +
                1 / 4. * (4. * p['C34'] * xp.sin(4. * (phi - p['phi34'])) +
                          2. * p['C32'] * xp.sin(2. * (phi - p['phi32']))) * alpha ** 3 +
                1 / 5. * (5. * p['C45'] * xp.sin(5. * (phi - p['phi45'])) +
                          3. * p['C43'] * xp.sin(3. * (phi - p['phi43'])) +
                          1. * p['C41'] * xp.sin(1. * (phi - p['phi41']))) * alpha ** 4 +
                1 / 6. * (6. * p['C56'] * xp.sin(6. * (phi - p['phi56'])) +
                          4. * p['C54'] * xp.sin(4. * (phi - p['phi54'])) +
                          2. * p['C52'] * xp.sin(2. * (phi - p['phi52']))) * alpha ** 5)

        return xp.exp(-xp.sign(self.angular_spread) * (self.angular_spread / 2 / 1000) ** 2 *
                      (dchi_dk ** 2 + dchi_dphi ** 2))

    def evaluate_chi(self,
                     alpha: Union[float, np.ndarray],
                     phi: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """
        Aberration phase chi(alpha, phi) up to 5th order in polar coefficients.

        Each order's term is skipped when all of its coefficients and angles are zero.
        """
        xp = get_array_module(alpha)
        p = self.parameters

        alpha2 = alpha ** 2
        alpha = xp.array(alpha)

        array = xp.zeros(alpha.shape, dtype=np.float32)
        if any([p[symbol] != 0. for symbol in ('C10', 'C12', 'phi12')]):
            array += (1 / 2 * alpha2 *
                      (p['C10'] +
                       p['C12'] * xp.cos(2 * (phi - p['phi12']))))

        if any([p[symbol] != 0. for symbol in ('C21', 'phi21', 'C23', 'phi23')]):
            array += (1 / 3 * alpha2 * alpha *
                      (p['C21'] * xp.cos(phi - p['phi21']) +
                       p['C23'] * xp.cos(3 * (phi - p['phi23']))))

        if any([p[symbol] != 0. for symbol in ('C30', 'C32', 'phi32', 'C34', 'phi34')]):
            array += (1 / 4 * alpha2 ** 2 *
                      (p['C30'] +
                       p['C32'] * xp.cos(2 * (phi - p['phi32'])) +
                       p['C34'] * xp.cos(4 * (phi - p['phi34']))))

        # Guard lists every 4th-order symbol exactly once (previously 'phi41' appeared twice and 'phi45' was missing).
        if any([p[symbol] != 0. for symbol in ('C41', 'phi41', 'C43', 'phi43', 'C45', 'phi45')]):
            array += (1 / 5 * alpha2 ** 2 * alpha *
                      (p['C41'] * xp.cos((phi - p['phi41'])) +
                       p['C43'] * xp.cos(3 * (phi - p['phi43'])) +
                       p['C45'] * xp.cos(5 * (phi - p['phi45']))))

        if any([p[symbol] != 0. for symbol in ('C50', 'C52', 'phi52', 'C54', 'phi54', 'C56', 'phi56')]):
            array += (1 / 6 * alpha2 ** 3 *
                      (p['C50'] +
                       p['C52'] * xp.cos(2 * (phi - p['phi52'])) +
                       p['C54'] * xp.cos(4 * (phi - p['phi54'])) +
                       p['C56'] * xp.cos(6 * (phi - p['phi56']))))

        array = 2 * xp.pi / self.wavelength * array
        return array

    def evaluate_aberrations(self,
                             alpha: Union[float, np.ndarray],
                             phi: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """Complex phase factor exp(-i * chi(alpha, phi)) computed via the device's complex_exponential."""
        xp = get_array_module(alpha)
        complex_exponential = get_device_function(xp, 'complex_exponential')
        return complex_exponential(-self.evaluate_chi(alpha, phi))

    def evaluate(self,
                 alpha: Union[float, np.ndarray],
                 phi: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """Full CTF: aberrations multiplied by whichever aperture and envelopes are active."""
        array = self.evaluate_aberrations(alpha, phi)

        if (self.semiangle_cutoff < np.inf) or (self._aperture is not None):
            array *= self.evaluate_aperture(alpha, phi)

        if self.focal_spread > 0.:
            array *= self.evaluate_temporal_envelope(alpha)

        if self.angular_spread > 0.:
            array *= self.evaluate_spatial_envelope(alpha, phi)

        if self.gaussian_spread > 0.:
            array *= self.evaluate_gaussian_envelope(alpha)

        return array

    def _polar_coordinates(self, gpts=None, extent=None, sampling=None, xp=np):
        # Angular polar coordinates (alpha, phi) of the spatial frequencies of the given grid.
        grid = Grid(gpts=gpts, extent=extent, sampling=sampling)
        gpts = grid.gpts
        sampling = grid.sampling

        kx, ky = spatial_frequencies(gpts, sampling)
        kx = kx.reshape((1, -1, 1))
        ky = ky.reshape((1, 1, -1))
        kx = xp.asarray(kx)
        ky = xp.asarray(ky)
        return polar_coordinates(xp.asarray(kx * self.wavelength), xp.asarray(ky * self.wavelength))

    def evaluate_on_grid(self, gpts=None, extent=None, sampling=None, xp=np):
        """Evaluate the CTF on the polar coordinates of the given real-space grid."""
        return self.evaluate(*self._polar_coordinates(gpts, extent, sampling, xp))

    def profiles(self, max_semiangle: float = None, phi: float = 0.):
        """
        Radial profiles of the CTF and its envelopes along azimuth ``phi``.

        Parameters
        ----------
        max_semiangle: float, optional
            Upper end of the profile [mrad]. Defaults to 1.6 times the semiangle cutoff, or 50 mrad if the cutoff is
            infinite.
        phi: float
            Azimuthal angle of the profile.

        Returns
        -------
        dict
            Measurements keyed by 'ctf', 'aperture', 'temporal_envelope', 'spatial_envelope', 'gaussian_envelope' and
            'envelope', sampled at 500 points.
        """
        if max_semiangle is None:
            if self._semiangle_cutoff == np.inf:
                max_semiangle = 50
            else:
                max_semiangle = self._semiangle_cutoff * 1.6

        alpha = np.linspace(0, max_semiangle / 1000., 500)

        aberrations = self.evaluate_aberrations(alpha, phi)
        # Pass phi so a custom aperture is evaluated along the same azimuth as the other profiles.
        aperture = self.evaluate_aperture(alpha, phi)
        temporal_envelope = self.evaluate_temporal_envelope(alpha)
        spatial_envelope = self.evaluate_spatial_envelope(alpha, phi)
        gaussian_envelope = self.evaluate_gaussian_envelope(alpha)
        envelope = aperture * temporal_envelope * spatial_envelope * gaussian_envelope

        calibration = Calibration(offset=0., sampling=(alpha[1] - alpha[0]) * 1000., units='mrad', name='alpha')

        profiles = {}
        profiles['ctf'] = Measurement(aberrations.imag * envelope, calibrations=[calibration], name='CTF')
        profiles['aperture'] = Measurement(aperture, calibrations=[calibration], name='Aperture')
        profiles['temporal_envelope'] = Measurement(temporal_envelope, calibrations=[calibration], name='Temporal')
        profiles['spatial_envelope'] = Measurement(spatial_envelope, calibrations=[calibration], name='Spatial')
        profiles['gaussian_envelope'] = Measurement(gaussian_envelope, calibrations=[calibration], name='Gaussian')
        profiles['envelope'] = Measurement(envelope, calibrations=[calibration], name='Envelope')
        return profiles

    def apply(self, waves, interact=False, sliders=None, throttling=0.):
        """
        Apply this CTF to ``waves``.

        With ``interact=True``, returns ``(image_waves, figure)``. The image is recomputed whenever this CTF changes,
        and ``sliders`` may be given as keyword arguments for ``quick_sliders``. Otherwise returns
        ``waves.apply_ctf(self)``.

        Raises
        ------
        RuntimeError
            If ``sliders`` is given without ``interact=True``.
        """
        if interact:
            from abtem.visualize.interactive import Canvas, MeasurementArtist2d
            from abtem.visualize.widgets import quick_sliders, throttle
            import ipywidgets as widgets

            image_waves = waves.copy()
            canvas = Canvas()
            artist = MeasurementArtist2d()
            canvas.artists = {'artist': artist}

            def update(*args):
                image_waves.array[:] = waves.apply_ctf(self).array
                artist.measurement = image_waves.intensity()[0]
                canvas.adjust_limits_to_artists()
                canvas.adjust_labels_to_artists()

            if throttling:
                update = throttle(throttling)(update)

            self.observe(update)
            update()

            if sliders:
                sliders = quick_sliders(self, **sliders)
                figure = widgets.HBox([canvas.figure, widgets.VBox(sliders)])
            else:
                figure = canvas.figure

            return image_waves, figure
        else:
            if sliders:
                raise RuntimeError('sliders can only be used together with interact=True')

            return waves.apply_ctf(self)

    def interact(self, max_semiangle: float = None, phi: float = 0., sliders=None, throttling=False):
        """Interactive plot of the CTF and envelope profiles that redraws whenever this CTF changes."""
        from abtem.visualize.interactive.utils import quick_sliders, throttle
        from abtem.visualize.interactive import Canvas, MeasurementArtist1d
        import ipywidgets as widgets

        canvas = Canvas(lock_scale=False)
        ctf_artist = MeasurementArtist1d()
        envelope_artist = MeasurementArtist1d()
        canvas.artists = {'ctf': ctf_artist, 'envelope': envelope_artist}
        canvas.y_scale.min = -1.1
        canvas.y_scale.max = 1.1

        def callback(*args):
            profiles = self.profiles(max_semiangle, phi)

            for name, artist in canvas.artists.items():
                artist.measurement = profiles[name]

        if throttling:
            callback = throttle(throttling)(callback)

        self.observe(callback)

        callback()
        canvas.adjust_limits_to_artists(adjust_y=False)
        canvas.adjust_labels_to_artists()

        if sliders:
            sliders = quick_sliders(self, **sliders)
            return widgets.HBox([canvas.figure, widgets.VBox(sliders)])
        else:
            return canvas.figure

    def show(self, max_semiangle: float = None, phi: float = 0, ax=None, **kwargs):
        """
        Show the contrast transfer function.

        Profiles that are identically 1 (e.g. disabled envelopes) are not drawn.

        Parameters
        ----------
        max_semiangle: float, optional
            Maximum semiangle to display in the plot [mrad]. Defaults to 1.6 times the semiangle cutoff, or 50 mrad
            if the cutoff is infinite.
        phi: float, optional
            The contrast transfer function will be plotted along this angle. Default is 0.
        ax: matplotlib Axes, optional
            If given, the plot will be added to this matplotlib axes.
        kwargs:
            Additional keyword arguments passed to each profile's ``show`` method.

        Returns
        -------
        The matplotlib axes the profiles were drawn on.
        """
        import matplotlib.pyplot as plt

        if ax is None:
            ax = plt.subplot()

        for key, profile in self.profiles(max_semiangle, phi).items():
            if not np.all(profile.array == 1.):
                ax, _ = profile.show(legend=True, ax=ax, **kwargs)

        return ax

    def copy(self):
        """
        Return a new CTF with the same cutoff, rolloff, envelopes, energy and aberration parameters.

        NOTE(review): the ``aperture`` object is not forwarded. A copy of a CTF built with a custom aperture therefore
        evaluates with the built-in aperture. Forwarding it as-is would re-match the shared aperture's accelerator to
        the copy (see ``__init__``), so the right fix depends on how apertures are meant to be shared — TODO confirm.
        """
        parameters = self.parameters.copy()
        return self.__class__(semiangle_cutoff=self.semiangle_cutoff,
                              rolloff=self.rolloff,
                              focal_spread=self.focal_spread,
                              angular_spread=self.angular_spread,
                              gaussian_spread=self.gaussian_spread,
                              energy=self.energy,
                              parameters=parameters)
def __init__(self, Z, extent, gpts, sampling, energy):
    """
    Store the element ``Z`` and build the grid and accelerator.

    Parameters
    ----------
    Z:
        Stored on ``self._Z`` (presumably an atomic number — confirm against callers).
    extent, gpts, sampling:
        Forwarded to ``Grid``.
    energy:
        Forwarded to ``Accelerator``.
    """
    self._Z = Z

    grid_kwargs = dict(extent=extent, gpts=gpts, sampling=sampling)
    self._grid = Grid(**grid_kwargs)

    self._accelerator = Accelerator(energy=energy)
def test_accelerator_event():
    """Changing the energy once notifies the accelerator's ``changed`` event exactly once."""
    acc = Accelerator(300e3)
    acc.energy = 200e3

    expected_notifications = 1
    assert acc.changed._notify_count == expected_notifications
def test_energy_raises():
    """Undefined or mismatched energies raise RuntimeError; matching energies pass the check."""
    reference = Accelerator(300e3)
    other = Accelerator()

    # No energy set yet.
    with pytest.raises(RuntimeError):
        other.check_is_defined()

    # Defined, but different from the reference.
    other.energy = 200e3
    with pytest.raises(RuntimeError):
        reference.check_match(other)

    # Once equal, matching must succeed without raising.
    other.energy = reference.energy
    reference.check_match(other)