Example #1
0
    def _generate_tds_probes(self,
                             scan: AbstractScan,
                             potential: AbstractTDSPotentialBuilder,
                             max_batch_probes: int,
                             max_batch_expansion: int,
                             potential_pbar: Union[ProgressBar, bool] = True,
                             multislice_pbar: Union[ProgressBar, bool] = True):
        """
        Yield one probe generator per frozen phonon configuration of a TDS potential.

        For each configuration produced by ``potential.generate_frozen_phonon_potentials``
        the scattering matrix is propagated with ``self.multislice`` and the generator
        returned by ``S._generate_probes`` is yielded.

        :param scan: Scan object defining the positions of the probe wave functions.
        :param potential: The frozen phonon (TDS) potential builder.
        :param max_batch_probes: The probe batch size.
        :param max_batch_expansion: The expansion plane wave batch size used during multislice.
        :param potential_pbar: A ProgressBar to reuse, or a bool toggling a newly created one.
        :param multislice_pbar: A ProgressBar to reuse, or a bool toggling a newly created one.
        :return: Generator of probe generators, one per frozen phonon configuration.
        """
        # Only bars created here are owned (and therefore closed) by this generator;
        # bars supplied by the caller are left open for the caller to manage.
        owns_potential_pbar = isinstance(potential_pbar, bool)
        if owns_potential_pbar:
            potential_pbar = ProgressBar(total=len(potential),
                                         desc='Potential',
                                         disable=not potential_pbar)

        owns_multislice_pbar = isinstance(multislice_pbar, bool)
        if owns_multislice_pbar:
            multislice_pbar = ProgressBar(total=len(potential),
                                          desc='Multislice',
                                          disable=not multislice_pbar)

        try:
            for potential_config in potential.generate_frozen_phonon_potentials(
                    pbar=potential_pbar):
                S = self.multislice(potential_config,
                                    max_batch=max_batch_expansion,
                                    pbar=multislice_pbar)
                yield S._generate_probes(scan, max_batch_probes,
                                         max_batch_expansion)
        finally:
            # Also runs when the consumer abandons the generator (GeneratorExit)
            # or an exception propagates out of the loop.
            multislice_pbar.refresh()
            if owns_multislice_pbar:
                multislice_pbar.close()

            potential_pbar.refresh()
            if owns_potential_pbar:
                potential_pbar.close()
Example #2
0
    def scan(self,
             scan: AbstractScan,
             detectors: Union[AbstractDetector, Sequence[AbstractDetector]],
             potential: Union[Atoms, AbstractPotential],
             max_batch: int = 1,
             pbar: bool = True) -> dict:
        """
        Raster scan the probe across the potential and record a measurement for each detector.

        :param scan: Scan object defining the positions of the probe wave functions.
        :param detectors: The detector, or detectors, recording the measurements.
        :param potential: The potential across which to scan the probe.
        :param max_batch: The probe batch size. Larger batches are faster, but require more memory.
        :param pbar: If true, display progress bars.
        :return: Dictionary of measurements with keys given by the detector.
        """

        self.grid.match(potential.grid)
        self.grid.check_is_defined()

        if isinstance(detectors, AbstractDetector):
            detectors = [detectors]

        measurements = {}
        for detector in detectors:
            measurements[detector] = detector.allocate_measurement(
                self.grid, self.wavelength, scan)

        scan_bar = ProgressBar(total=len(scan), desc='Scan', disable=not pbar)

        try:
            if isinstance(potential, AbstractTDSPotentialBuilder):
                # One probe generator per frozen phonon configuration.
                probe_generators = self._generate_tds_probes(
                    scan, potential, max_batch, pbar)
            else:
                if isinstance(potential, AbstractPotentialBuilder):
                    # Respect the caller's pbar choice instead of always showing the bar.
                    potential = potential.build(pbar=pbar)

                probe_generators = [
                    self._generate_probes(scan, potential, max_batch)
                ]

            for probe_generator in probe_generators:
                # The scan bar is restarted for every configuration.
                scan_bar.reset()
                for start, end, exit_probes in probe_generator:
                    for detector, measurement in measurements.items():
                        scan.insert_new_measurement(measurement, start, end,
                                                    detector.detect(exit_probes))

                    scan_bar.update(end - start)

                scan_bar.refresh()
        finally:
            scan_bar.close()

        return measurements
Example #3
0
    def generate_frozen_phonon_potentials(self,
                                          pbar: Union[ProgressBar,
                                                      bool] = True):
        """
        Function to generate scattering potentials for a set of frozen phonon configurations.

        For each configuration the positions of ``self.atoms`` are overwritten in place.
        If the potential is precalculated, the built potential is yielded; otherwise this
        object itself is yielded.

        Parameters
        ----------
        pbar: bool or ProgressBar, optional
            Display a progress bar. Default is True. A ProgressBar may also be passed to
            reuse it; a supplied bar is reset and refreshed but not closed.

        Returns
        -------
        generator
            Generator of potentials.
        """

        owns_pbar = isinstance(pbar, bool)
        if owns_pbar:
            # The bar only tracks work when the potential is actually built.
            pbar = ProgressBar(total=len(self),
                               desc='Potential',
                               disable=(not pbar) or (not self._precalculate))

        try:
            for atoms in self.frozen_phonons:
                self.atoms.positions[:] = atoms.positions
                # self.atoms.wrap()
                pbar.reset()

                if self._precalculate:
                    yield self.build(pbar=pbar)
                else:
                    yield self
        finally:
            # Runs on normal exhaustion and also when the generator is abandoned.
            pbar.refresh()
            if owns_pbar:
                pbar.close()
Example #4
0
def _multislice(
    waves: Union['Waves', 'SMatrix', 'PartialSMatrix'],
    potential: AbstractPotential,
    propagator: FresnelPropagator = None,
    pbar: Union[ProgressBar, bool] = True
) -> Union['Waves', 'SMatrix', 'PartialSMatrix']:
    """
    Propagate wave functions through a potential with the multislice algorithm.

    For every slice yielded by iterating ``potential``, ``transmit`` is applied and the
    waves are then propagated over the slice thickness. The return values of these calls
    are discarded, so they are expected to update ``waves`` in place.

    :param waves: The wave functions (or scattering matrix) to propagate.
    :param potential: The potential; iterating it yields the potential slices.
    :param propagator: Propagator to use. A new FresnelPropagator is created if not given.
    :param pbar: A ProgressBar to reuse, or a bool toggling a newly created one.
        A supplied bar is reset and refreshed but not closed.
    :return: The same ``waves`` object that was passed in.
    """
    waves.grid.match(potential)

    waves.accelerator.check_is_defined()
    waves.grid.check_is_defined()

    if propagator is None:
        propagator = FresnelPropagator()

    # Only a bar created here is closed here; a supplied bar belongs to the caller.
    close_pbar = isinstance(pbar, bool)
    if close_pbar:
        pbar = ProgressBar(total=len(potential),
                           desc='Multislice',
                           disable=not pbar)

    try:
        pbar.reset()
        for potential_slice in potential:
            transmit(waves, potential_slice)
            propagator.propagate(waves, potential_slice.thickness)
            pbar.update(1)

        pbar.refresh()
    finally:
        if close_pbar:
            pbar.close()

    return waves
Example #5
0
    def multislice(self,
                   potential: AbstractPotential,
                   max_batch=None,
                   pbar: Union[ProgressBar, bool] = True):
        """
        Propagate the scattering matrix through the provided potential.

        The scattering matrix is split into partial scattering matrices with
        ``self._generate_partial`` and each one is propagated with ``_multislice``.

        :param potential: The potential through which to propagate the scattering matrix.
        :param max_batch: Maximum size of each partial scattering matrix (as counted by
            ``len(self)``). Larger batches are faster, but require more memory.
        :param pbar: If true, display progress bars. A ProgressBar may also be passed to
            reuse it; a supplied bar is refreshed but not closed.
        :return: The scattering matrix itself, after propagation.
        """
        propagator = FresnelPropagator()

        # With a bool, the batch bar follows the caller's choice; with a supplied
        # ProgressBar, keep the previous default of showing it.
        show_batches = pbar if isinstance(pbar, bool) else True

        close_pbar = isinstance(pbar, bool)
        if close_pbar:
            pbar = ProgressBar(total=len(potential),
                               desc='Multislice',
                               disable=not pbar)

        try:
            for partial_s_matrix in self._generate_partial(max_batch,
                                                           pbar=show_batches):
                _multislice(partial_s_matrix,
                            potential,
                            propagator=propagator,
                            pbar=pbar)

            pbar.refresh()
        finally:
            if close_pbar:
                pbar.close()

        return self
Example #6
0
    def get_transition_potentials(self,
                                  extent: Union[float, Sequence[float]] = None,
                                  gpts: Union[float, Sequence[float]] = None,
                                  sampling: Union[float,
                                                  Sequence[float]] = None,
                                  energy: float = None,
                                  pbar=True):
        """
        Create a ProjectedAtomicTransition for each transition returned by
        ``self.get_transition_quantum_numbers()``.

        The bound wave and each continuum wave are turned into cubic ``interp1d``
        interpolants that extrapolate outside their sampled range.

        :param extent: Passed unchanged to every ProjectedAtomicTransition.
        :param gpts: Passed unchanged to every ProjectedAtomicTransition.
        :param sampling: Passed unchanged to every ProjectedAtomicTransition.
        :param energy: Passed unchanged to every ProjectedAtomicTransition.
        :param pbar: If true, display a progress bar. A ProgressBar may also be passed to
            reuse it; a supplied bar is refreshed but not closed.
        :return: List of ProjectedAtomicTransition objects, one per quantum number pair.
        """
        # Only a bar created here is closed here; a supplied bar belongs to the caller.
        close_pbar = isinstance(pbar, bool)
        if close_pbar:
            pbar = ProgressBar(total=len(self),
                               desc='Transitions',
                               disable=(not pbar))

        try:
            _, bound_wave = self._calculate_bound()
            _, continuum_waves = self._calculate_continuum()
            energy_loss = self.energy_loss

            # bound_wave is unpacked into interp1d's positional (x, y) arguments.
            bound_wave = interp1d(*bound_wave,
                                  kind='cubic',
                                  fill_value='extrapolate',
                                  bounds_error=False)

            transitions = []
            for bound_state, continuum_state in self.get_transition_quantum_numbers():
                # Continuum waves are looked up by the first continuum quantum number.
                continuum_wave = interp1d(*continuum_waves[continuum_state[0]],
                                          kind='cubic',
                                          fill_value='extrapolate',
                                          bounds_error=False)

                transitions.append(
                    ProjectedAtomicTransition(Z=self.Z,
                                              bound_wave=bound_wave,
                                              continuum_wave=continuum_wave,
                                              bound_state=bound_state,
                                              continuum_state=continuum_state,
                                              energy_loss=energy_loss,
                                              extent=extent,
                                              gpts=gpts,
                                              sampling=sampling,
                                              energy=energy))
                pbar.update(1)

            pbar.refresh()
        finally:
            if close_pbar:
                pbar.close()

        return transitions
Example #7
0
    def generate_frozen_phonon_potentials(self,
                                          pbar: Union[ProgressBar,
                                                      bool] = True):
        """
        Function to generate scattering potentials for a set of frozen phonon configurations.

        For each configuration the positions of ``self.atoms`` are overwritten in place,
        the atoms are wrapped with ``self.atoms.wrap()``, and the built potential is yielded.

        Parameters
        ----------
        pbar: bool or ProgressBar, optional
            Display a progress bar. Default is True. A ProgressBar may also be passed to
            reuse it; a supplied bar is reset and refreshed but not closed.

        Returns
        -------
        generator
            Generator of built potentials, one per frozen phonon configuration.
        """
        owns_pbar = isinstance(pbar, bool)
        if owns_pbar:
            pbar = ProgressBar(total=len(self),
                               desc='Potential',
                               disable=not pbar)

        try:
            for atoms in self.frozen_phonons:
                self.atoms.positions[:] = atoms.positions
                self.atoms.wrap()
                # The same bar is reused for the build of every configuration.
                pbar.reset()
                yield self.build(pbar=pbar)
        finally:
            # Runs on normal exhaustion and also when the generator is abandoned.
            pbar.refresh()
            if owns_pbar:
                pbar.close()
Example #8
0
    def _generate_partial(self, max_batch: int = None, pbar: bool = True):
        """
        Split this scattering matrix into consecutive partial scattering matrices.

        :param max_batch: Upper bound on the size of each partial scattering matrix.
            If None, a single batch is used.
        :param pbar: If true, display a batch progress bar (hidden when there is only
            one batch).
        :return: Generator of PartialSMatrix objects covering consecutive
            ``[start, stop)`` ranges of ``len(self)``.
        """
        total = len(self)
        # Ceiling division: the smallest number of batches no larger than max_batch.
        n_batches = 1 if max_batch is None else (total + (-total % max_batch)) // max_batch

        batch_pbar = ProgressBar(total=total,
                                 desc='Batches',
                                 disable=not pbar or n_batches == 1)
        sizes = split_integer(total, n_batches)

        start = 0
        for size in sizes:
            stop = start + size
            yield PartialSMatrix(start, stop, self)
            start = stop
            batch_pbar.update(size)

        batch_pbar.refresh()
        batch_pbar.close()
Example #9
0
    def generate_positions(self, max_batch, pbar=False):
        """
        Yield the scan positions batch by batch.

        :param max_batch: Maximum batch size, passed to ``self._partition_batches``.
        :param pbar: If true, display a progress bar over all positions.
        :return: Generator of ``(indices, positions[indices])`` tuples, one per batch.
        """
        positions = self.get_positions()
        self._partition_batches(max_batch)

        # Keep the bar separate from the ``pbar`` flag and test it against None, instead
        # of relying on the truthiness of the bar object itself.
        progress_bar = ProgressBar(total=len(self)) if pbar else None

        try:
            for _ in range(len(self._batches)):
                indices = self.get_next_batch()
                yield indices, positions[indices]

                if progress_bar is not None:
                    progress_bar.update(len(indices))
        finally:
            # Also closes the bar if the consumer stops iterating early.
            if progress_bar is not None:
                progress_bar.close()
Example #10
0
    def scan(self,
             scan: AbstractScan,
             detectors: Union[AbstractDetector, Sequence[AbstractDetector]],
             max_batch_probes=1,
             max_batch_expansion=None,
             pbar: Union[ProgressBar, bool] = True):
        """
        Raster scan the probe and record a measurement for each detector.

        :param scan: Scan object defining the positions of the probe wave functions.
        :param detectors: The detector, or detectors, recording the measurements.
        :param max_batch_probes: The probe batch size. Larger batches are faster, but require more memory.
        :param max_batch_expansion: The expansion plane wave batch size.
        :param pbar: If true, display progress bars. A ProgressBar may also be passed to
            reuse it; a supplied bar is refreshed but not closed.
        :return: Dictionary of measurements with keys given by the detector.
        """
        # Accept a single detector, consistent with the probe scan.
        if isinstance(detectors, AbstractDetector):
            detectors = [detectors]

        measurements = {}
        for detector in detectors:
            measurements[detector] = detector.allocate_measurement(
                self.interpolated_grid, self.wavelength, scan)

        close_pbar = isinstance(pbar, bool)
        if close_pbar:
            pbar = ProgressBar(total=len(scan), desc='Scan', disable=not pbar)

        try:
            for start, end, exit_probes in self._generate_probes(
                    scan, max_batch_probes, max_batch_expansion):
                for detector, measurement in measurements.items():
                    scan.insert_new_measurement(measurement, start, end,
                                                detector.detect(exit_probes))
                pbar.update(end - start)

            pbar.refresh()
        finally:
            if close_pbar:
                pbar.close()

        return measurements
Example #11
0
    def build(self,
              pbar: Union[bool, ProgressBar] = False) -> 'ArrayPotential':
        """
        Precalculate the potential as an ArrayPotential.

        Every slice from ``self.generate_slices()`` is copied to the storage device
        and its thickness is recorded.

        :param pbar: If true, show a progress bar. A ProgressBar may also be passed to
            reuse it; a supplied bar is reset and refreshed but not closed.
        :return: ArrayPotential holding the slice arrays, slice thicknesses and extent.
        """
        self.grid.check_is_defined()

        storage_xp = get_array_module_from_device(self._storage)
        array = storage_xp.zeros(
            (self.num_slices, ) + (self.gpts[0], self.gpts[1]),
            dtype=np.float32)
        slice_thicknesses = np.zeros(self.num_slices)

        # Only a bar created here is closed here; a supplied bar belongs to the caller.
        close_pbar = isinstance(pbar, bool)
        if close_pbar:
            pbar = ProgressBar(total=len(self),
                               desc='Potential',
                               disable=not pbar)

        try:
            pbar.reset()
            for i, potential_slice in enumerate(self.generate_slices()):
                array[i] = copy_to_device(potential_slice.array, self._storage)
                slice_thicknesses[i] = potential_slice.thickness
                pbar.update(1)

            pbar.refresh()
        finally:
            if close_pbar:
                pbar.close()

        return ArrayPotential(array, slice_thicknesses, self.extent)
Example #12
0
def _run_epie(object,
              probe: np.ndarray,
              diffraction_patterns: np.ndarray,
              positions: np.ndarray,
              maxiter: int,
              alpha: float = 1.,
              beta: float = 1.,
              fix_probe: bool = False,
              fix_com: bool = False,
              return_iterations: bool = False,
              seed=None):
    """
    Reconstruct an object and probe with the extended ptychographic iterative engine (ePIE).

    Parameters
    ----------
    object: array-like
        Initial object guess (2D), or a length-2 shape ``(rows, cols)``, in which case the
        reconstruction starts from a complex64 array of ones with that shape.
    probe: ndarray
        Initial probe guess. Must have the same 2D shape as the object and as each
        diffraction pattern.
    diffraction_patterns: ndarray
        3D stack with one pattern per position. The square root is taken and the last two
        axes are ifftshifted, so these are presumably centered intensities — confirm with
        the caller.
    positions: ndarray
        One position per diffraction pattern, applied as shifts via ``fft_shift``
        (units presumably pixels — TODO confirm).
    maxiter: int
        Number of passes over all positions.
    alpha: float
        Step size of the object update.
    beta: float
        Step size of the probe update.
    fix_probe: bool
        If true, the probe is not updated.
    fix_com: bool
        If true, recenter the probe on the centre of mass of its intensity after each pass.
    return_iterations: bool
        If true, return the object, probe and SSE of every pass instead of only the last.
    seed: int, optional
        Seed for NumPy's global random generator, which shuffles the order of positions.

    Returns
    -------
    tuple
        ``(object, probe, SSE)``, or three lists with per-pass values if
        ``return_iterations`` is true. SSE is currently always 0 because its
        computation is disabled.

    Raises
    ------
    ValueError
        If the shapes of the inputs are inconsistent.
    """
    xp = get_array_module(probe)

    object = xp.array(object)
    probe = xp.array(probe)

    if len(diffraction_patterns.shape) != 3:
        raise ValueError(
            'diffraction_patterns must be a 3D array with one pattern per position')

    if len(diffraction_patterns) != len(positions):
        raise ValueError(
            'the number of diffraction patterns must equal the number of positions')

    if object.shape == (2, ):
        # The object was given as a (rows, cols) shape: start from a uniform object.
        object = xp.ones((int(object[0]), int(object[1])), dtype=xp.complex64)
    elif len(object.shape) != 2:
        raise ValueError('object must be a 2D array or a length-2 shape')

    if probe.shape != diffraction_patterns.shape[1:]:
        raise ValueError('the probe shape must match the diffraction pattern shape')

    if probe.shape != object.shape:
        raise ValueError('the probe shape must match the object shape')

    if return_iterations:
        object_iterations = []
        probe_iterations = []
        SSE_iterations = []

    if seed is not None:
        np.random.seed(seed)

    # Measured amplitudes, with the zero frequency moved to the corner to match fft2.
    diffraction_patterns = np.fft.ifftshift(np.sqrt(diffraction_patterns),
                                            axes=(-2, -1))

    SSE = 0.
    k = 0
    outer_pbar = ProgressBar(total=maxiter)
    inner_pbar = ProgressBar(total=len(positions))

    try:
        while k < maxiter:
            # Visit the positions in a new random order on every pass.
            indices = np.arange(len(positions))
            np.random.shuffle(indices)

            # The object is kept shifted to the most recently visited position, so each
            # step only applies the shift relative to the previous position.
            old_position = xp.array((0., 0.))
            inner_pbar.reset()
            # NOTE: the SSE accumulation is disabled, so SSE stays 0.
            SSE = 0.
            for j in indices:
                position = xp.array(positions[j])

                diffraction_pattern = xp.array(diffraction_patterns[j])
                illuminated_object = fft_shift(object, old_position - position)

                # Exit wave, and the exit wave with its Fourier magnitude replaced by
                # the measured amplitude (keeping the computed phase).
                g = illuminated_object * probe
                gprime = xp.fft.ifft2(diffraction_pattern *
                                      xp.exp(1j * xp.angle(xp.fft.fft2(g))))

                object = illuminated_object + alpha * (
                    gprime - g) * xp.conj(probe) / (xp.max(xp.abs(probe))**2)
                old_position = position

                if not fix_probe:
                    probe = probe + beta * (
                        gprime - g) * xp.conj(illuminated_object) / (xp.max(
                            xp.abs(illuminated_object))**2)

                inner_pbar.update(1)

            # Undo the accumulated shift. old_position equals the last visited position
            # and stays (0, 0) when there are no positions (previously a NameError).
            object = fft_shift(object, old_position)

            if fix_com:
                com = center_of_mass(xp.fft.fftshift(xp.abs(probe)**2))
                probe = xp.fft.ifftshift(fft_shift(probe, -xp.array(com)))

            if return_iterations:
                object_iterations.append(object)
                probe_iterations.append(probe)
                SSE_iterations.append(SSE)

            outer_pbar.update(1)
            k += 1
    finally:
        inner_pbar.close()
        outer_pbar.close()

    if return_iterations:
        return object_iterations, probe_iterations, SSE_iterations
    else:
        return object, probe, SSE
Example #13
0
    def build(
        self,
        first_slice: int = 0,
        last_slice: int = None,
        energy: float = None,
        max_batch: int = None,
        pbar: Union[bool, ProgressBar] = False,
    ) -> 'PotentialArray':
        """
        Precalculate the potential as a potential array.

        Parameters
        ----------
        first_slice: int
            First potential slice to generate.
        last_slice: int, optional
            End of the slice range to generate (exclusive). Defaults to ``len(self)``.
        energy: float, optional
            Electron energy [eV]. If given, the transmission functions are calculated and
            returned instead of the potential.
        max_batch: int, optional
            Maximum number of potential slices calculated in parallel. Estimated with
            ``self._estimate_max_batch()`` if not given.
        pbar: bool or ProgressBar
            If true, show a progress bar. A ProgressBar may also be passed to reuse it;
            a supplied bar is reset and refreshed but not closed.

        Returns
        -------
        PotentialArray or TransmissionFunction
            PotentialArray if ``energy`` is None, otherwise TransmissionFunction.
        """

        self.grid.check_is_defined()

        if last_slice is None:
            last_slice = len(self)

        if max_batch is None:
            max_batch = self._estimate_max_batch()

        storage_xp = get_array_module_from_device(self._storage)
        num_built = last_slice - first_slice

        if energy is None:
            array = storage_xp.zeros(
                (num_built, ) + (self.gpts[0], self.gpts[1]),
                dtype=np.float32)
            generator = self.generate_slices(max_batch=max_batch,
                                             first_slice=first_slice,
                                             last_slice=last_slice)
        else:
            array = storage_xp.zeros(
                (num_built, ) + (self.gpts[0], self.gpts[1]),
                dtype=np.complex64)
            generator = self.generate_transmission_functions(
                energy=energy,
                max_batch=max_batch,
                first_slice=first_slice,
                last_slice=last_slice)

        slice_thicknesses = np.zeros(num_built)

        close_pbar = isinstance(pbar, bool)
        if close_pbar:
            # Only the requested slice range is built, so that is the bar's total.
            pbar = ProgressBar(total=num_built,
                               desc='Potential',
                               disable=not pbar)

        try:
            pbar.reset()
            # NOTE(review): assumes the generator yields start/end relative to first_slice,
            # since `array` only holds last_slice - first_slice slices — TODO confirm.
            for start, end, potential_slice in generator:
                array[start:end] = copy_to_device(potential_slice.array,
                                                  self._storage)
                slice_thicknesses[start:end] = potential_slice.slice_thicknesses
                pbar.update(end - start)

            pbar.refresh()
        finally:
            if close_pbar:
                pbar.close()

        if energy is None:
            return PotentialArray(array,
                                  slice_thicknesses=slice_thicknesses,
                                  extent=self.extent)
        else:
            return TransmissionFunction(array,
                                        slice_thicknesses=slice_thicknesses,
                                        extent=self.extent,
                                        energy=energy)
Example #14
0
    def _generate_tds_probes(self, scan, potential, max_batch, pbar):
        """
        Yield one probe generator for every frozen phonon configuration of ``potential``.

        A 'TDS' bar is advanced each time the consumer requests the next configuration
        (it is hidden when there is only one), and a 'Potential' bar is handed to
        ``potential.generate_frozen_phonon_potentials``.

        :param scan: Scan object defining the positions of the probe wave functions.
        :param potential: The frozen phonon (TDS) potential builder.
        :param max_batch: The probe batch size.
        :param pbar: If true, display progress bars.
        :return: Generator of probe generators, one per configuration.
        """
        n_phonons = len(potential.frozen_phonons)
        hide = not pbar

        tds_bar = ProgressBar(total=n_phonons,
                              desc='TDS',
                              disable=hide or n_phonons == 1)
        potential_bar = ProgressBar(total=len(potential),
                                    desc='Potential',
                                    disable=hide)

        configurations = potential.generate_frozen_phonon_potentials(
            pbar=potential_bar)
        for config in configurations:
            yield self._generate_probes(scan, config, max_batch)
            tds_bar.update(1)

        potential_bar.close()
        tds_bar.refresh()
        tds_bar.close()
Example #15
0
    def multislice(self,
                   potential: AbstractPotential,
                   pbar: Union[ProgressBar, bool] = True) -> 'Waves':
        """
        Propagate and transmit wave function through the provided potential.

        For a frozen phonon (TDS) potential builder, a copy of this wave function is
        propagated through every configuration and the exit waves are stacked along a
        new leading axis of a new wave function object.

        :param potential: The potential through which to propagate the wave function.
        :param pbar: If true, display a progress bar.
        :return: Wave function at the exit plane of the potential.
        """
        self.grid.match(potential)

        propagator = FresnelPropagator()

        if not isinstance(potential, AbstractTDSPotentialBuilder):
            return _multislice(self, potential, propagator, pbar)

        xp = get_array_module(self.array)
        n_configs = len(potential.frozen_phonons)

        # One exit wave per configuration, stacked along axis 0.
        stacked = xp.zeros((n_configs, ) + self.array.shape, dtype=xp.complex64)
        tds_waves = self.__class__(stacked,
                                   extent=self.extent,
                                   energy=self.energy)

        tds_pbar = ProgressBar(total=n_configs,
                               desc='TDS',
                               disable=(not pbar) or (n_configs == 1))
        multislice_pbar = ProgressBar(total=len(potential),
                                      desc='Multislice',
                                      disable=not pbar)

        configurations = potential.generate_frozen_phonon_potentials(pbar=pbar)
        for index, config in enumerate(configurations):
            multislice_pbar.reset()

            # Each configuration starts from an unpropagated copy of this wave function.
            exit_waves = _multislice(copy(self),
                                     config,
                                     propagator=propagator,
                                     pbar=multislice_pbar)
            tds_waves.array[index] = exit_waves.array
            tds_pbar.update(1)

        multislice_pbar.close()
        tds_pbar.close()

        return tds_waves
Example #16
0
    def scan(self,
             potential: Union[Atoms, AbstractPotential],
             scan: AbstractScan,
             detectors: Union[AbstractDetector, Sequence[AbstractDetector]],
             max_batch_probes: int = 1,
             max_batch_expansion: int = None,
             pbar: bool = True):
        """
        Raster scan the probe across the potential and record a measurement for each detector.

        :param potential: The potential across which to scan the probe.
        :param scan: Scan object defining the positions of the probe wave functions.
        :param detectors: The detector, or detectors, recording the measurements.
        :param max_batch_probes: The probe batch size. Larger batches are faster, but require more memory.
        :param max_batch_expansion: The expansion plane wave batch size used during multislice.
        :param pbar: If true, display progress bars.
        :return: Dictionary of measurements with keys given by the detector.
        """

        self.grid.match(potential.grid)
        self.grid.check_is_defined()

        if isinstance(detectors, AbstractDetector):
            detectors = [detectors]

        measurements = {}
        for detector in detectors:
            measurements[detector] = detector.allocate_measurement(
                self.interpolated_grid, self.wavelength, scan)

        if isinstance(potential, AbstractTDSPotentialBuilder):
            num_configs = len(potential.frozen_phonons)
            probe_generators = self._generate_tds_probes(
                scan,
                potential,
                max_batch_probes=max_batch_probes,
                max_batch_expansion=max_batch_expansion,
                potential_pbar=pbar,
                multislice_pbar=pbar)
        else:
            # Only TDS potential builders are queried for frozen phonons; any other
            # potential produces exactly one probe generator.
            num_configs = 1
            if isinstance(potential, AbstractPotentialBuilder):
                potential = potential.build(pbar=pbar)

            # The expansion is batched by max_batch_expansion, as in the TDS path.
            S = self.multislice(potential,
                                max_batch=max_batch_expansion,
                                pbar=pbar)
            probe_generators = [
                S._generate_probes(scan, max_batch_probes, max_batch_expansion)
            ]

        tds_bar = ProgressBar(total=num_configs,
                              desc='TDS',
                              disable=(not pbar) or (num_configs == 1))

        scan_bar = ProgressBar(total=len(scan), desc='Scan', disable=not pbar)

        try:
            for probe_generator in probe_generators:
                scan_bar.reset()

                for start, end, exit_probes in probe_generator:
                    for detector, measurement in measurements.items():
                        scan.insert_new_measurement(measurement, start, end,
                                                    detector.detect(exit_probes))

                    scan_bar.update(end - start)

                scan_bar.refresh()
                tds_bar.update(1)
        finally:
            scan_bar.close()
            tds_bar.refresh()
            tds_bar.close()

        return measurements