Exemplo n.º 1
0
        def __init__(self):
            """Initialise the counters and attach a counting observer to a new ``Event``."""
            self._notify_count = 0
            self._notified_property = 0

            def count_notification(*_ignored):
                # Whatever positional arguments the event passes, one call is one count.
                self._notify_count += 1

            self.event = Event()
            self.event.observe(count_notification)
Exemplo n.º 2
0
        def __init__(self):
            """Initialise the counters and register a counting callback on a new ``Event``."""
            self._notify_count = 0
            self._notified_property = 0

            def count_notification(notifier, property_name, change):
                # The three arguments are accepted but unused; only the call itself is counted.
                self._notify_count += 1

            self.event = Event()
            self.event.register(count_notification)
Exemplo n.º 3
0
    def __init__(self,
                 semiangle_cutoff: float = np.inf,
                 rolloff: float = 0.1,
                 focal_spread: float = 0.,
                 angular_spread: float = 0.,
                 gaussian_spread: float = 0.,
                 energy: float = None,
                 parameters: Mapping[str, float] = None,
                 **kwargs):
        """
        Store the envelope settings and aberration coefficients and wire up change notification.

        Parameters
        ----------
        semiangle_cutoff: float
            Stored as ``_semiangle_cutoff``. Default is infinity.
        rolloff: float
            Stored as ``_rolloff``.
        focal_spread: float
            Stored as ``_focal_spread``.
        angular_spread: float
            Stored as ``_angular_spread``.
        gaussian_spread: float
            Stored as ``_gaussian_spread``.
        energy: float
            Passed on to the ``Accelerator`` created here.
        parameters: dict
            Mapping from aberration symbols (or aliases) to values. The mapping is copied, never modified.
        kwargs:
            Aberration coefficients given as keyword arguments; they override entries in ``parameters``.

        Raises
        ------
        ValueError
            If a keyword argument is neither a polar symbol nor a polar alias.
        """
        for key in kwargs:
            if key not in polar_symbols and key not in polar_aliases:
                raise ValueError('{} not a recognized parameter'.format(key))

        self.changed = Event()

        # Changes of the accelerator are forwarded through this object's own event.
        self._accelerator = Accelerator(energy=energy)
        self._accelerator.changed.register(self.changed.notify)

        self._semiangle_cutoff = semiangle_cutoff
        self._rolloff = rolloff
        self._focal_spread = focal_spread
        self._angular_spread = angular_spread
        self._gaussian_spread = gaussian_spread
        self._parameters = dict(zip(polar_symbols, [0.] * len(polar_symbols)))

        # Work on a private copy so the caller's mapping is never mutated (and so a
        # read-only Mapping, which has no ``update``, is accepted as well).
        parameters = {} if parameters is None else dict(parameters)
        parameters.update(kwargs)
        self.set_parameters(parameters)

        def parametrization_property(key):
            def getter(self):
                return self._parameters[key]

            def setter(self, value):
                old = getattr(self, key)
                self._parameters[key] = value
                self.changed.notify(**{
                    'notifier': self,
                    'property_name': key,
                    'change': old != value
                })

            return property(getter, setter)

        # The properties are attached to the class (not the instance), so this runs
        # again, with equivalent results, on every instantiation.
        for symbol in polar_symbols:
            setattr(self.__class__, symbol, parametrization_property(symbol))

        # 'defocus' is skipped here; every other alias maps straight onto its symbol.
        for key, value in polar_aliases.items():
            if key != 'defocus':
                setattr(self.__class__, key, parametrization_property(value))
Exemplo n.º 4
0
    def __init__(self,
                 inner: float = None,
                 outer: float = None,
                 radial_steps: float = 1.,
                 azimuthal_steps: float = None,
                 offset: Tuple[float, float] = None,
                 rotation: float = 0.,
                 save_file: str = None):
        """
        Store the angular limits, step sizes, offset and rotation, then create the cache and event.

        ``azimuthal_steps`` falls back to ``2 * np.pi`` when it is not given. ``outer`` is also
        handed to the parent initialiser as ``max_detected_angle``, together with ``save_file``.
        """
        self._inner, self._outer = inner, outer
        self._radial_steps = radial_steps
        self._azimuthal_steps = 2 * np.pi if azimuthal_steps is None else azimuthal_steps
        self._rotation, self._offset = rotation, offset

        self.cache = Cache(1)
        self.changed = Event()
        super().__init__(max_detected_angle=outer, save_file=save_file)
Exemplo n.º 5
0
    def __init__(self,
                 potential_unit: AbstractPotential,
                 repetitions: Tuple[int, int, int],
                 num_frozen_phonon_configs: int = 1):
        """
        Build a potential from a repeated potential unit.

        Parameters
        ----------
        potential_unit : AbstractPotential
            The unit that is repeated.
        repetitions : three int
            Number of repetitions; the first two entries scale the grid points and extent.
        num_frozen_phonon_configs : int
            Number of frozen phonon configurations. A warning is issued when this disagrees with
            whether the unit itself has more than one configuration.
        """
        self._potential_unit = potential_unit
        self.repetitions = repetitions
        self._num_frozen_phonon_configs = num_frozen_phonon_configs

        unit_configs = potential_unit.num_frozen_phonon_configs

        if unit_configs == 1 and num_frozen_phonon_configs > 1:
            # The trailing space matters: the two literals are concatenated into one message.
            warnings.warn(
                '"num_frozen_phonon_configs" is greater than one, but the potential unit does not have '
                'frozen phonons')

        if unit_configs > 1 and num_frozen_phonon_configs == 1:
            warnings.warn(
                'the potential unit has frozen phonons, but "num_frozen_phonon_configs" is set to 1'
            )

        self._cache = Cache(1)
        self._changed = Event()

        gpts = (self._potential_unit.gpts[0] * self.repetitions[0],
                self._potential_unit.gpts[1] * self.repetitions[1])
        extent = (self._potential_unit.extent[0] * self.repetitions[0],
                  self._potential_unit.extent[1] * self.repetitions[1])

        self._grid = Grid(extent=extent,
                          gpts=gpts,
                          sampling=self._potential_unit.sampling,
                          lock_extent=True)
        # A grid change is forwarded to this object's event, which in turn clears the cache.
        self._grid.changed.register(self._changed.notify)
        self._changed.register(cache_clear_callback(self._cache))

        super().__init__(precalculate=False)
Exemplo n.º 6
0
    class DummyWithWatchedMethod:
        """Test helper with one method and one property setter decorated to watch ``event``."""

        def __init__(self):
            self._notify_count = 0
            self._notified_property = 0

            def count_notification(notifier, property_name, change):
                # Arguments are ignored; each call simply bumps the counter.
                self._notify_count += 1

            self.event = Event()
            self.event.register(count_notification)

        @property
        def notified_property(self):
            """The value last stored through the setter (initially 0)."""
            return self._notified_property

        @notified_property.setter
        @watched_property('event')
        def notified_property(self, value):
            self._notified_property = value

        @watched_method('event')
        def notified_method(self):
            """Do nothing; the method exists only to carry the decorator."""
Exemplo n.º 7
0
    class DummyWithWatchedMethod:
        """Test helper with one method and one property setter decorated to watch ``event``."""

        def __init__(self):
            self._notify_count = 0
            self._notified_property = 0

            def count_notification(*_ignored):
                # Positional arguments are ignored; each call simply bumps the counter.
                self._notify_count += 1

            self.event = Event()
            self.event.observe(count_notification)

        @property
        def notified_property(self):
            """The value last stored through the setter (initially 0)."""
            return self._notified_property

        @notified_property.setter
        @watched_property('event')
        def notified_property(self, value):
            self._notified_property = value

        @watched_method('event')
        def notified_method(self):
            """Do nothing; the method exists only to carry the decorator."""
Exemplo n.º 8
0
def test_register_event():
    """A registered callback runs on ``notify`` and the event counts the notification."""
    recorded = {}

    def mark_called():
        recorded['a'] = 'a'

    event = Event()
    event.register(mark_called)
    event.notify()

    assert event.notify_count == 1
    assert recorded['a'] == 'a'
Exemplo n.º 9
0
def test_register_event():
    """An observing callback runs on ``notify`` and the event counts the notification."""
    recorded = {}

    def mark_called(*_ignored):
        recorded['a'] = 'a'

    event = Event()
    event.observe(mark_called)
    event.notify(None)

    assert event.notify_count == 1
    assert recorded['a'] == 'a'
Exemplo n.º 10
0
class CTF(HasAcceleratorMixin):
    """
    Contrast transfer function object

    The Contrast Transfer Function (CTF) describes the aberrations of the objective lens in HRTEM and specifies how the
    condenser system shapes the probe in STEM.

    abTEM implements phase aberrations up to 5th order using polar coefficients. See Eq. 2.22 in the reference [1]_.
    Cartesian coefficients can be converted to polar using the utility function abtem.transfer.cartesian2polar.

    Partial coherence is included as an envelope in the quasi-coherent approximation. See Chapter 3.2 in reference [1]_.

    For a more detailed discussion with examples, see our `walkthrough
    <https://abtem.readthedocs.io/en/latest/walkthrough/05_contrast_transfer_function.html>`_.

    Parameters
    ----------
    semiangle_cutoff: float
        The semiangle cutoff describes the sharp Fourier space cutoff due to the objective aperture [mrad].
    rolloff: float
        Softens the cutoff. A value of 0 gives a hard cutoff, while 1 gives the softest possible cutoff. Given as a
        fraction of the semiangle cutoff.
    focal_spread: float
        The 1/e width of the focal spread due to chromatic aberration and lens current instability [Å].
    angular_spread: float
        The 1/e width of the angular deviations due to source size [mrad].
    gaussian_spread: float
        The 1/e width image deflections due to vibrations and thermal magnetic noise [Å].
    energy: float
        The electron energy of the wave functions this contrast transfer function will be applied to [eV].
    parameters: dict
        Mapping from aberration symbols to their corresponding values. All aberration magnitudes should be given in Å.
        The mapping is copied; it is never modified.
    kwargs:
        Provide the aberration coefficients as keyword arguments.

    References
    ----------
    .. [1] Kirkland, E. J. (2010). Advanced Computing in Electron Microscopy (2nd ed.). Springer.

    """
    def __init__(self,
                 semiangle_cutoff: float = np.inf,
                 rolloff: float = 0.1,
                 focal_spread: float = 0.,
                 angular_spread: float = 0.,
                 gaussian_spread: float = 0.,
                 energy: float = None,
                 parameters: Mapping[str, float] = None,
                 **kwargs):

        for key in kwargs:
            if key not in polar_symbols and key not in polar_aliases:
                raise ValueError('{} not a recognized parameter'.format(key))

        self.changed = Event()

        # Changes of the accelerator are forwarded through this object's own event.
        self._accelerator = Accelerator(energy=energy)
        self._accelerator.changed.register(self.changed.notify)

        self._semiangle_cutoff = semiangle_cutoff
        self._rolloff = rolloff
        self._focal_spread = focal_spread
        self._angular_spread = angular_spread
        self._gaussian_spread = gaussian_spread
        self._parameters = dict(zip(polar_symbols, [0.] * len(polar_symbols)))

        # Work on a private copy so the caller's mapping is never mutated (and so a
        # read-only Mapping, which has no ``update``, is accepted as well).
        parameters = {} if parameters is None else dict(parameters)
        parameters.update(kwargs)
        self.set_parameters(parameters)

        def parametrization_property(key):
            def getter(self):
                return self._parameters[key]

            def setter(self, value):
                old = getattr(self, key)
                self._parameters[key] = value
                self.changed.notify(**{
                    'notifier': self,
                    'property_name': key,
                    'change': old != value
                })

            return property(getter, setter)

        # The properties are attached to the class (not the instance), so this runs
        # again, with equivalent results, on every instantiation.
        for symbol in polar_symbols:
            setattr(self.__class__, symbol, parametrization_property(symbol))

        # 'defocus' has its own sign-flipping property below; other aliases map directly.
        for key, value in polar_aliases.items():
            if key != 'defocus':
                setattr(self.__class__, key, parametrization_property(value))

    @property
    def nyquist_sampling(self):
        """The value ``1 / (4 * semiangle_cutoff / wavelength * 1e-3)``."""
        return 1 / (4 * self.semiangle_cutoff / self.wavelength * 1e-3)

    @property
    def parameters(self):
        """The dictionary of aberration coefficients, keyed by polar symbol."""
        return self._parameters

    @property
    def defocus(self) -> float:
        """The defocus [Å]; stored as the negative of ``C10``."""
        return -self._parameters['C10']

    @defocus.setter
    def defocus(self, value: float):
        self.C10 = -value

    @property
    def semiangle_cutoff(self) -> float:
        """The semi-angle cutoff [mrad]."""
        return self._semiangle_cutoff

    @semiangle_cutoff.setter
    @watched_property('changed')
    def semiangle_cutoff(self, value: float):
        self._semiangle_cutoff = value

    @property
    def rolloff(self) -> float:
        """The fraction of soft tapering of the cutoff."""
        return self._rolloff

    @rolloff.setter
    @watched_property('changed')
    def rolloff(self, value: float):
        self._rolloff = value

    @property
    def focal_spread(self) -> float:
        """The focal spread [Å]."""
        return self._focal_spread

    @focal_spread.setter
    @watched_property('changed')
    def focal_spread(self, value: float):
        self._focal_spread = value

    @property
    def angular_spread(self) -> float:
        """The angular spread [mrad]."""
        return self._angular_spread

    @angular_spread.setter
    @watched_property('changed')
    def angular_spread(self, value: float):
        self._angular_spread = value

    @property
    def gaussian_spread(self) -> float:
        """The Gaussian spread [Å]."""
        return self._gaussian_spread

    @gaussian_spread.setter
    @watched_property('changed')
    def gaussian_spread(self, value: float):
        self._gaussian_spread = value

    @watched_method('changed')
    def set_parameters(self, parameters: dict):
        """
        Set the aberration coefficients.

        Parameters
        ----------
        parameters: dict
            Mapping from aberration symbols (or aliases) to their corresponding values. A 'defocus' entry is
            stored with its sign flipped under the symbol it aliases.

        Returns
        -------
        dict
            The mapping that was passed in.

        Raises
        ------
        ValueError
            If a key is neither a known symbol nor a known alias. Entries processed before the offending key
            have already been stored.
        """

        for symbol, value in parameters.items():
            if symbol in self._parameters:
                self._parameters[symbol] = value

            elif symbol == 'defocus':
                self._parameters[polar_aliases[symbol]] = -value

            elif symbol in polar_aliases:
                self._parameters[polar_aliases[symbol]] = value

            else:
                raise ValueError(
                    '{} not a recognized parameter'.format(symbol))

        return parameters

    def evaluate_aperture(
            self, alpha: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """
        Evaluate the aperture function at the angles ``alpha``.

        The cutoff is converted from mrad by dividing by 1000 before it is compared with ``alpha``. An infinite
        cutoff gives ones everywhere; a positive rolloff gives a cosine taper below the cutoff and zero above it;
        otherwise the result is a hard step.
        """
        xp = get_array_module(alpha)
        semiangle_cutoff = self.semiangle_cutoff / 1000

        if self.semiangle_cutoff == xp.inf:
            return xp.ones_like(alpha)

        if self.rolloff > 0.:
            rolloff = self.rolloff * semiangle_cutoff
            array = .5 * (
                1 + xp.cos(np.pi *
                           (alpha - semiangle_cutoff + rolloff) / rolloff))
            array[alpha > semiangle_cutoff] = 0.
            # Inside the untapered region the aperture is exactly one.
            array = xp.where(alpha > semiangle_cutoff - rolloff, array,
                             xp.ones_like(alpha, dtype=xp.float32))
        else:
            array = xp.array(alpha < semiangle_cutoff).astype(xp.float32)
        return array

    def evaluate_temporal_envelope(
            self, alpha: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """Evaluate the envelope that depends on ``focal_spread`` at the angles ``alpha``."""
        xp = get_array_module(alpha)
        return xp.exp(-(.5 * xp.pi / self.wavelength * self.focal_spread *
                        alpha**2)**2).astype(xp.float32)

    def evaluate_gaussian_envelope(
            self, alpha: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """Evaluate the envelope that depends on ``gaussian_spread`` at the angles ``alpha``."""
        xp = get_array_module(alpha)
        return xp.exp(-.5 * self.gaussian_spread**2 * alpha**2 /
                      self.wavelength**2)

    def evaluate_spatial_envelope(self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray]) -> \
            Union[float, np.ndarray]:
        """
        Evaluate the envelope that depends on ``angular_spread`` at polar angles ``alpha`` and azimuths ``phi``.

        The envelope is built from two derivative-like sums over the aberration coefficients (``dchi_dk`` and
        ``dchi_dphi``) and the angular spread divided by 2 and by 1000.
        """
        xp = get_array_module(alpha)
        p = self.parameters
        dchi_dk = 2 * xp.pi / self.wavelength * (
            (p['C12'] * xp.cos(2. * (phi - p['phi12'])) + p['C10']) * alpha +
            (p['C23'] * xp.cos(3. * (phi - p['phi23'])) +
             p['C21'] * xp.cos(1. * (phi - p['phi21']))) * alpha**2 +
            (p['C34'] * xp.cos(4. * (phi - p['phi34'])) +
             p['C32'] * xp.cos(2. * (phi - p['phi32'])) + p['C30']) * alpha**3 +
            (p['C45'] * xp.cos(5. * (phi - p['phi45'])) +
             p['C43'] * xp.cos(3. * (phi - p['phi43'])) +
             p['C41'] * xp.cos(1. * (phi - p['phi41']))) * alpha**4 +
            (p['C56'] * xp.cos(6. * (phi - p['phi56'])) +
             p['C54'] * xp.cos(4. * (phi - p['phi54'])) +
             p['C52'] * xp.cos(2. * (phi - p['phi52'])) + p['C50']) * alpha**5)

        dchi_dphi = -2 * xp.pi / self.wavelength * (
            1 / 2. *
            (2. * p['C12'] * xp.sin(2. * (phi - p['phi12']))) * alpha +
            1 / 3. *
            (3. * p['C23'] * xp.sin(3. * (phi - p['phi23'])) +
             1. * p['C21'] * xp.sin(1. * (phi - p['phi21']))) * alpha**2 +
            1 / 4. *
            (4. * p['C34'] * xp.sin(4. * (phi - p['phi34'])) +
             2. * p['C32'] * xp.sin(2. * (phi - p['phi32']))) * alpha**3 +
            1 / 5. *
            (5. * p['C45'] * xp.sin(5. * (phi - p['phi45'])) +
             3. * p['C43'] * xp.sin(3. * (phi - p['phi43'])) +
             1. * p['C41'] * xp.sin(1. * (phi - p['phi41']))) * alpha**4 +
            1 / 6. *
            (6. * p['C56'] * xp.sin(6. * (phi - p['phi56'])) +
             4. * p['C54'] * xp.sin(4. * (phi - p['phi54'])) +
             2. * p['C52'] * xp.sin(2. * (phi - p['phi52']))) * alpha**5)

        return xp.exp(-xp.sign(self.angular_spread) *
                      (self.angular_spread / 2 / 1000)**2 *
                      (dchi_dk**2 + dchi_dphi**2))

    def evaluate_chi(
            self, alpha: Union[float, np.ndarray],
            phi: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """
        Evaluate the aberration phase at polar angles ``alpha`` and azimuths ``phi``.

        Each aberration order is only accumulated when at least one of its coefficients or angles is non-zero.
        The sum is scaled by ``2 * pi / wavelength`` before it is returned.
        """
        xp = get_array_module(alpha)
        p = self.parameters

        alpha2 = alpha**2
        alpha = xp.array(alpha)

        array = xp.zeros(alpha.shape, dtype=np.float32)
        if any(p[symbol] != 0. for symbol in ('C10', 'C12', 'phi12')):
            array += (1 / 2 * alpha2 *
                      (p['C10'] + p['C12'] * xp.cos(2 * (phi - p['phi12']))))

        if any(p[symbol] != 0.
               for symbol in ('C21', 'phi21', 'C23', 'phi23')):
            array += (1 / 3 * alpha2 * alpha *
                      (p['C21'] * xp.cos(phi - p['phi21']) +
                       p['C23'] * xp.cos(3 * (phi - p['phi23']))))

        if any(p[symbol] != 0.
               for symbol in ('C30', 'C32', 'phi32', 'C34', 'phi34')):
            array += (1 / 4 * alpha2**2 *
                      (p['C30'] + p['C32'] * xp.cos(2 * (phi - p['phi32'])) +
                       p['C34'] * xp.cos(4 * (phi - p['phi34']))))

        # The last entry is 'phi45' (the angle paired with C45), not a repeated 'phi41'.
        if any(p[symbol] != 0.
               for symbol in ('C41', 'phi41', 'C43', 'phi43', 'C45', 'phi45')):
            array += (1 / 5 * alpha2**2 * alpha *
                      (p['C41'] * xp.cos((phi - p['phi41'])) +
                       p['C43'] * xp.cos(3 * (phi - p['phi43'])) +
                       p['C45'] * xp.cos(5 * (phi - p['phi45']))))

        if any(p[symbol] != 0.
               for symbol in ('C50', 'C52', 'phi52', 'C54', 'phi54', 'C56',
                              'phi56')):
            array += (1 / 6 * alpha2**3 *
                      (p['C50'] + p['C52'] * xp.cos(2 * (phi - p['phi52'])) +
                       p['C54'] * xp.cos(4 * (phi - p['phi54'])) +
                       p['C56'] * xp.cos(6 * (phi - p['phi56']))))

        array = 2 * xp.pi / self.wavelength * array
        return array

    def evaluate_aberrations(self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray]) -> \
            Union[float, np.ndarray]:
        """Apply the device's ``complex_exponential`` function to the negated aberration phase."""
        xp = get_array_module(alpha)
        complex_exponential = get_device_function(xp, 'complex_exponential')
        return complex_exponential(-self.evaluate_chi(alpha, phi))

    def evaluate(self, alpha: Union[float, np.ndarray],
                 phi: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """
        Evaluate the aberrations multiplied by every active envelope.

        The aperture is applied for a finite cutoff; each of the other envelopes is applied only when its spread
        is greater than zero.
        """
        array = self.evaluate_aberrations(alpha, phi)

        if self.semiangle_cutoff < np.inf:
            array *= self.evaluate_aperture(alpha)

        if self.focal_spread > 0.:
            array *= self.evaluate_temporal_envelope(alpha)

        if self.angular_spread > 0.:
            array *= self.evaluate_spatial_envelope(alpha, phi)

        if self.gaussian_spread > 0.:
            array *= self.evaluate_gaussian_envelope(alpha)

        return array

    def evaluate_on_grid(self, grid, xp=np):
        """Evaluate on the spatial frequencies of ``grid``, using the array module ``xp``."""
        kx, ky = spatial_frequencies(grid.gpts, grid.sampling)
        # Leading singleton axis plus one axis per direction, so kx and ky broadcast against each other.
        kx = kx.reshape((1, -1, 1))
        ky = ky.reshape((1, 1, -1))
        kx = xp.asarray(kx)
        ky = xp.asarray(ky)
        alpha, phi = polar_coordinates(xp.asarray(kx * self.wavelength),
                                       xp.asarray(ky * self.wavelength))
        return self.evaluate(alpha, phi)

    def profiles(self, max_semiangle: float = None, phi: float = 0.):
        """
        Compute line profiles of the transfer function and its envelopes.

        Parameters
        ----------
        max_semiangle: float, optional
            Largest angle of the profile [mrad]. Defaults to 1.6 times the semiangle cutoff, or 50 when the
            cutoff is infinite.
        phi: float
            Azimuth along which the profiles are evaluated.

        Returns
        -------
        dict
            Measurements keyed by 'ctf', 'aperture', 'temporal_envelope', 'spatial_envelope',
            'gaussian_spread' and 'envelope'.
        """
        if max_semiangle is None:
            if self._semiangle_cutoff == np.inf:
                max_semiangle = 50
            else:
                max_semiangle = self._semiangle_cutoff * 1.6

        alpha = np.linspace(0, max_semiangle / 1000., 500)

        aberrations = self.evaluate_aberrations(alpha, phi)
        aperture = self.evaluate_aperture(alpha)
        temporal_envelope = self.evaluate_temporal_envelope(alpha)
        spatial_envelope = self.evaluate_spatial_envelope(alpha, phi)
        gaussian_envelope = self.evaluate_gaussian_envelope(alpha)
        envelope = aperture * temporal_envelope * spatial_envelope * gaussian_envelope

        calibration = Calibration(offset=0.,
                                  sampling=(alpha[1] - alpha[0]) * 1000.,
                                  units='mrad',
                                  name='alpha')

        profiles = {}
        profiles['ctf'] = Measurement(aberrations.imag * envelope,
                                      calibrations=[calibration],
                                      name='CTF')
        profiles['aperture'] = Measurement(aperture,
                                           calibrations=[calibration],
                                           name='Aperture')
        profiles['temporal_envelope'] = Measurement(temporal_envelope,
                                                    calibrations=[calibration],
                                                    name='Temporal')
        profiles['spatial_envelope'] = Measurement(spatial_envelope,
                                                   calibrations=[calibration],
                                                   name='Spatial')
        profiles['gaussian_spread'] = Measurement(gaussian_envelope,
                                                  calibrations=[calibration],
                                                  name='Gaussian')
        profiles['envelope'] = Measurement(envelope,
                                           calibrations=[calibration],
                                           name='Envelope')
        return profiles

    def apply(self, waves, interact=False, sliders=None, throttling=0.):
        """
        Apply this transfer function to ``waves``.

        With ``interact=False`` the result of ``waves.apply_ctf(self)`` is returned. With ``interact=True`` a
        copy of the waves is kept up to date through the ``changed`` event and returned together with a figure
        (wrapped in a box with sliders when ``sliders`` is given).

        Raises
        ------
        RuntimeError
            If ``sliders`` is given while ``interact`` is False.
        """
        if not interact:
            if sliders:
                raise RuntimeError('sliders require interact=True')

            return waves.apply_ctf(self)

        # The widget stack is imported only here so plain application works without it.
        from abtem.visualize.bqplot import show_measurement_2d
        from abtem.visualize.widgets import quick_sliders, throttle
        import ipywidgets as widgets

        image_waves = waves.copy()

        def update():
            image_waves._array[:] = waves.apply_ctf(self).array
            return image_waves.intensity()

        figure, callback = show_measurement_2d(update)

        if throttling:
            callback = throttle(throttling)(callback)

        self.changed.register(callback)
        if sliders:
            sliders = quick_sliders(self, **sliders)
            figure = widgets.HBox([figure, widgets.VBox(sliders)])
        return image_waves, figure

    def interact(self,
                 max_semiangle: float = None,
                 phi: float = 0.,
                 sliders=None,
                 throttling=False):
        """
        Create a figure of the profiles that is redrawn through the ``changed`` event.

        Returns the figure, or a box holding the figure and sliders when ``sliders`` is given.
        """
        import bqplot.pyplot as plt
        from abtem.visualize.bqplot import show_measurement_1d
        from abtem.visualize.widgets import quick_sliders, throttle
        import ipywidgets as widgets

        figure = plt.figure(fig_margin={
            'top': 0,
            'bottom': 50,
            'left': 50,
            'right': 0
        })
        figure.layout.height = '250px'
        figure.layout.width = '300px'

        _, callback = show_measurement_1d(
            lambda: self.profiles(max_semiangle, phi).values(), figure)

        if throttling:
            callback = throttle(throttling)(callback)

        self.changed.register(callback)

        if sliders:
            sliders = quick_sliders(self, **sliders)
            return widgets.HBox([figure, widgets.VBox(sliders)])
        else:
            return figure

    def show(self,
             max_semiangle: float = None,
             phi: float = 0,
             ax=None,
             **kwargs):
        """
        Show the contrast transfer function.

        Parameters
        ----------
        max_semiangle: float
            Maximum semiangle to display in the plot.
        phi: float, optional
            The contrast transfer function will be plotted along this angle. Default is 0.
        ax: matplotlib Axes, optional
            If given, the plot will be added to this matplotlib axes.
        kwargs:
            Additional keyword arguments for the line plots.

        Returns
        -------
        The axes the profiles were drawn on.
        """
        import matplotlib.pyplot as plt

        if ax is None:
            ax = plt.subplot()

        for profile in self.profiles(max_semiangle, phi).values():
            # Profiles that are identically one carry no information and are skipped.
            if not np.all(profile.array == 1.):
                ax, _ = profile.show(legend=True, ax=ax, **kwargs)

        return ax

    def copy(self):
        """Return a new object of the same class with the same settings and a copy of the parameters."""
        parameters = self.parameters.copy()
        return self.__class__(semiangle_cutoff=self.semiangle_cutoff,
                              rolloff=self.rolloff,
                              focal_spread=self.focal_spread,
                              angular_spread=self.angular_spread,
                              gaussian_spread=self.gaussian_spread,
                              energy=self.energy,
                              parameters=parameters)
Exemplo n.º 11
0
    def __init__(self,
                 atoms: Union[Atoms, AbstractFrozenPhonons] = None,
                 gpts: Union[int, Sequence[int]] = None,
                 sampling: Union[float, Sequence[float]] = None,
                 slice_thickness: float = .5,
                 parametrization: str = 'lobato',
                 projection: str = 'finite',
                 cutoff_tolerance: float = 1e-3,
                 device='cpu',
                 precalculate: bool = True,
                 storage=None):
        """
        Validate the settings, pick the parametrization functions and set up the grid.

        Raises
        ------
        RuntimeError
            If the parametrization is not 'lobato' or 'kirkland' (case-insensitive), if the projection is not
            'finite' or 'infinite', if an infinite projection is requested without the Kirkland
            parametrization, if the cell has no thickness, or if the atoms are not orthogonal.
        """
        self._cutoff_tolerance = cutoff_tolerance
        self._parametrization = parametrization
        self._slice_thickness = slice_thickness

        self._storage = storage

        # Loader, potential function and derivative for each supported parametrization.
        parametrization_name = parametrization.lower()
        supported = {
            'lobato': (load_lobato_parameters, lobato, dvdr_lobato),
            'kirkland': (load_kirkland_parameters, kirkland, dvdr_kirkland),
        }
        if parametrization_name not in supported:
            raise RuntimeError(
                'Parametrization {} not recognized'.format(parametrization))

        load_parameters, function, derivative = supported[parametrization_name]
        self._parameters = load_parameters()
        self._function = function
        self._derivative = derivative

        if projection not in ('finite', 'infinite'):
            raise RuntimeError('Projection must be "finite" or "infinite"')

        if projection == 'infinite' and parametrization_name != 'kirkland':
            raise RuntimeError(
                'Infinite projections are only implemented for the Kirkland parametrization'
            )

        self._projection = projection

        # Plain atoms are wrapped so that both cases can be iterated the same way.
        already_frozen_phonons = isinstance(atoms, AbstractFrozenPhonons)
        self._frozen_phonons = atoms if already_frozen_phonons else DummyFrozenPhonons(atoms)

        atoms = next(iter(self._frozen_phonons))

        if np.abs(atoms.cell[2, 2]) < 1e-12:
            raise RuntimeError('Atoms cell has no thickness')

        if not is_cell_orthogonal(atoms):
            raise RuntimeError('Atoms are not orthogonal')

        self._atoms = atoms
        self._grid = Grid(extent=np.diag(atoms.cell)[:2],
                          gpts=gpts,
                          sampling=sampling,
                          lock_extent=True)

        self._cutoffs = {}
        self._integrators = {}
        self._disc_indices = {}

        def reset_grid_dependent_state(*args, **kwargs):
            # Both dictionaries are emptied whenever the grid reports a change.
            self._integrators = {}
            self._disc_indices = {}

        self.grid.changed.register(reset_grid_dependent_state)
        self.changed = Event()

        # Without an explicit storage, the device doubles as the storage location.
        effective_storage = device if storage is None else storage

        super().__init__(precalculate=precalculate,
                         device=device,
                         storage=effective_storage)
Exemplo n.º 12
0
class CrystalPotential(AbstractPotential):
    """
    Crystal potential object

    The crystal potential may be used to represent a potential consisting of a repeating unit. This may allow
    calculations to be performed with lower memory and computational cost.

    The crystal potential has an additional function in conjunction with frozen phonon calculations. The number of
    frozen phonon configurations are not given by the FrozenPhonon objects, rather the ensemble of frozen phonon
    potentials represented by a potential with frozen phonons represent a collection of units, which will be assembled
    randomly to represent a random potential. The number of frozen phonon configurations should be given explicitly.
    This may save computational cost since a smaller number of units can be combined to a larger frozen phonon ensemble.

    Parameters
    ----------
    potential_unit : AbstractPotential
        The potential unit that repeated will create the full potential.
    repetitions : three int
        The repetitions of the potential in x, y and z.
    num_frozen_phonon_configs : int
        Number of frozen phonon configurations.
    """
    def __init__(self,
                 potential_unit: AbstractPotential,
                 repetitions: Tuple[int, int, int],
                 num_frozen_phonon_configs: int = 1):

        self._potential_unit = potential_unit
        self.repetitions = repetitions
        self._num_frozen_phonon_configs = num_frozen_phonon_configs

        unit_configs = potential_unit.num_frozen_phonon_configs

        if unit_configs == 1 and num_frozen_phonon_configs > 1:
            # The trailing space matters: the two literals are concatenated into one message.
            warnings.warn(
                '"num_frozen_phonon_configs" is greater than one, but the potential unit does not have '
                'frozen phonons')

        if unit_configs > 1 and num_frozen_phonon_configs == 1:
            warnings.warn(
                'the potential unit has frozen phonons, but "num_frozen_phonon_configs" is set to 1'
            )

        self._cache = Cache(1)
        self._changed = Event()

        gpts = (self._potential_unit.gpts[0] * self.repetitions[0],
                self._potential_unit.gpts[1] * self.repetitions[1])
        extent = (self._potential_unit.extent[0] * self.repetitions[0],
                  self._potential_unit.extent[1] * self.repetitions[1])

        self._grid = Grid(extent=extent,
                          gpts=gpts,
                          sampling=self._potential_unit.sampling,
                          lock_extent=True)
        # A grid change is forwarded to this object's event, which in turn clears the cache.
        self._grid.changed.register(self._changed.notify)
        self._changed.register(cache_clear_callback(self._cache))

        super().__init__(precalculate=False)

    @HasGridMixin.gpts.setter
    def gpts(self, gpts):
        # Each direction must be divisible by its own repetition count.
        if not ((gpts[0] % self.repetitions[0] == 0) and
                (gpts[1] % self.repetitions[1] == 0)):
            raise ValueError(
                'gpts must be divisible by the number of potential repetitions'
            )
        self.grid.gpts = gpts
        self._potential_unit.gpts = (gpts[0] // self._repetitions[0],
                                     gpts[1] // self._repetitions[1])

    @HasGridMixin.sampling.setter
    def sampling(self, sampling):
        # Assign through the grid; assigning to self.sampling here would re-enter this setter forever.
        self.grid.sampling = sampling
        self._potential_unit.sampling = sampling

    @property
    def num_frozen_phonon_configs(self):
        """The number of frozen phonon configurations given at construction."""
        return self._num_frozen_phonon_configs

    def generate_frozen_phonon_potentials(self, pbar=False):
        """Yield this object once per frozen phonon configuration. ``pbar`` is not used."""
        for _ in range(self.num_frozen_phonon_configs):
            yield self

    @property
    def repetitions(self) -> Tuple[int, int, int]:
        """The repetitions of the potential unit, as a tuple of length 3."""
        return self._repetitions

    @repetitions.setter
    def repetitions(self, repetitions: Tuple[int, int, int]):
        repetitions = tuple(repetitions)

        if len(repetitions) != 3:
            raise ValueError('repetitions must be sequence of length 3')

        self._repetitions = repetitions

    @property
    def num_slices(self) -> int:
        """The number of slices of the unit times the third repetition count."""
        return self._potential_unit.num_slices * self.repetitions[2]

    def get_slice_thickness(self, i) -> float:
        """Return the thickness the potential unit reports for slice index ``i``."""
        return self._potential_unit.get_slice_thickness(i)

    @cached_method('_cache')
    def _calculate_configs(self, energy, max_batch=1):
        """
        Build one tiled potential per frozen phonon potential of the unit.

        Builders are built first; when ``energy`` is given, each potential is converted to a transmission
        function before it is tiled by the first two repetition counts.
        """
        potential_generators = self._potential_unit.generate_frozen_phonon_potentials(
            pbar=False)

        potential_configs = []
        for potential in potential_generators:

            if isinstance(potential, AbstractPotentialBuilder):
                potential = potential.build(max_batch=max_batch)
            elif not isinstance(potential, PotentialArray):
                raise RuntimeError(
                    'the potential unit must yield potential builders or potential arrays')

            if energy is not None:
                potential = potential.as_transmission_function(
                    energy=energy, max_batch=max_batch)

            potential = potential.tile(self.repetitions[:2])
            potential_configs.append(potential)

        return potential_configs

    def _generate_slices_base(self,
                              first_slice=0,
                              last_slice=None,
                              max_batch=1,
                              energy=None):
        """
        Yield ``(start, end, slice)`` triples from the repeated layers.

        With a single configuration every layer is that configuration; otherwise each layer is drawn at random
        from the available configurations.
        """

        first_layer = first_slice // self._potential_unit.num_slices
        if last_slice is None:
            last_layer = self.repetitions[2]
        else:
            last_layer = last_slice // self._potential_unit.num_slices

        first_slice = first_slice % self._potential_unit.num_slices
        last_slice = None

        configs = self._calculate_configs(energy, max_batch)

        if len(configs) == 1:
            layers = configs * self.repetitions[2]
        else:
            layers = [
                configs[np.random.randint(len(configs))]
                for _ in range(self.repetitions[2])
            ]

        for layer_num, layer in enumerate(layers[first_layer:last_layer]):

            # NOTE(review): last_slice was reset to None above and layer_num counts from 0 within the
            # sliced list, so this branch appears unreachable and a partial final layer is never
            # produced - confirm the intended handling before relying on last_slice.
            if layer_num == last_layer:
                last_slice = last_slice % self._potential_unit.num_slices

            for start, end, potential_slice in layer.generate_slices(
                    first_slice=first_slice,
                    last_slice=last_slice,
                    max_batch=max_batch):
                yield layer_num + start, layer_num + end, potential_slice

                # Only the first layer starts part-way through; later layers start at slice 0.
                first_slice = 0

    def generate_slices(self, first_slice=0, last_slice=None, max_batch=1):
        """Generate potential slices (no energy is passed on)."""
        return self._generate_slices_base(first_slice=first_slice,
                                          last_slice=last_slice,
                                          max_batch=max_batch)

    def generate_transmission_functions(self,
                                        energy,
                                        first_slice=0,
                                        last_slice=None,
                                        max_batch=1):
        """Generate slices converted to transmission functions at ``energy``."""
        return self._generate_slices_base(first_slice=first_slice,
                                          last_slice=last_slice,
                                          max_batch=max_batch,
                                          energy=energy)