def scan(self,
         scan: AbstractScan,
         detectors: Union[AbstractDetector, Sequence[AbstractDetector]],
         potential: Union[Atoms, AbstractPotential],
         max_batch: int = 1,
         pbar: bool = True) -> dict:
    """
    Raster scan the probe across the potential and record a measurement for each detector.

    :param scan: Scan object defining the positions of the probe wave functions.
    :param detectors: The detectors recording the measurements. A single detector is accepted and treated as a
        one-element sequence.
    :param potential: The potential across which to scan the probe. A frozen-phonon (TDS) potential builder is
        scanned once per configuration; any other potential builder is built once before scanning.
    :param max_batch: The probe batch size. Larger batches are faster, but require more memory.
    :param pbar: If true, display progress bars (this flag is also forwarded to the potential build).
    :return: Dictionary of measurements with keys given by the detector.
    """
    # NOTE(review): the annotation admits `Atoms`, but `potential.grid` is read unconditionally below;
    # confirm that Atoms inputs are converted to a potential before reaching this method.
    self.grid.match(potential.grid)
    self.grid.check_is_defined()

    if isinstance(detectors, AbstractDetector):
        detectors = [detectors]

    # One pre-allocated measurement per detector; results are inserted batch by batch below.
    measurements = {detector: detector.allocate_measurement(self.grid, self.wavelength, scan)
                    for detector in detectors}

    scan_bar = ProgressBar(total=len(scan), desc='Scan', disable=not pbar)
    try:
        if isinstance(potential, AbstractTDSPotentialBuilder):
            # Lazily yields one probe generator per frozen-phonon configuration.
            probe_generators = self._generate_tds_probes(scan, potential, max_batch, pbar)
        else:
            if isinstance(potential, AbstractPotentialBuilder):
                # Honour the caller's flag instead of always showing the build progress bar.
                potential = potential.build(pbar=pbar)
            probe_generators = [self._generate_probes(scan, potential, max_batch)]

        for probe_generator in probe_generators:
            # The scan bar restarts for every probe generator (i.e. every TDS configuration).
            scan_bar.reset()
            for start, end, exit_probes in probe_generator:
                for detector, measurement in measurements.items():
                    scan.insert_new_measurement(measurement, start, end, detector.detect(exit_probes))
                scan_bar.update(end - start)
            scan_bar.refresh()
    finally:
        # Always release the progress bar, even if detection or the potential build raises.
        scan_bar.close()

    return measurements
def get_transition_potentials(self,
                              extent: Union[float, Sequence[float]] = None,
                              gpts: Union[int, Sequence[int]] = None,
                              sampling: Union[float, Sequence[float]] = None,
                              energy: float = None,
                              pbar=True):
    """
    Build a ProjectedAtomicTransition for every (bound, continuum) state pair of this object.

    The bound wave and the continuum waves are each wrapped in a cubic ``interp1d`` that extrapolates outside the
    tabulated range.

    :param extent: Passed through to each ProjectedAtomicTransition.
    :param gpts: Passed through to each ProjectedAtomicTransition.
    :param sampling: Passed through to each ProjectedAtomicTransition.
    :param energy: Passed through to each ProjectedAtomicTransition.
    :param pbar: If a bool, a 'Transitions' progress bar is created (displayed when True) and closed before
        returning. Any other object is used as the progress bar: it is updated and refreshed, but left open so the
        caller can keep using it.
    :return: List of ProjectedAtomicTransition objects, in the order yielded by get_transition_quantum_numbers().
    """
    owns_pbar = isinstance(pbar, bool)
    if owns_pbar:
        pbar = ProgressBar(total=len(self), desc='Transitions', disable=(not pbar))

    try:
        _, bound_wave = self._calculate_bound()
        _, continuum_waves = self._calculate_continuum()
        energy_loss = self.energy_loss

        bound_wave = interp1d(*bound_wave, kind='cubic', fill_value='extrapolate', bounds_error=False)

        # Several transitions index the same continuum wave (same continuum_state[0], presumably the angular
        # quantum number l -- confirm), so build each interpolator only once.
        continuum_interpolators = {}

        transitions = []
        for bound_state, continuum_state in self.get_transition_quantum_numbers():
            key = continuum_state[0]
            if key not in continuum_interpolators:
                continuum_interpolators[key] = interp1d(*continuum_waves[key], kind='cubic',
                                                        fill_value='extrapolate', bounds_error=False)

            transitions.append(ProjectedAtomicTransition(Z=self.Z,
                                                         bound_wave=bound_wave,
                                                         continuum_wave=continuum_interpolators[key],
                                                         bound_state=bound_state,
                                                         continuum_state=continuum_state,
                                                         energy_loss=energy_loss,
                                                         extent=extent,
                                                         gpts=gpts,
                                                         sampling=sampling,
                                                         energy=energy))
            pbar.update(1)

        pbar.refresh()
    finally:
        # Only close a progress bar this method created; a caller-supplied bar stays usable.
        if owns_pbar:
            pbar.close()

    return transitions
def _generate_tds_probes(self, scan, potential, max_batch, pbar):
    """
    Yield one probe generator per frozen-phonon configuration of ``potential``.

    Each yielded item is ``self._generate_probes(scan, potential_config, max_batch)`` for the next potential
    produced by ``potential.generate_frozen_phonon_potentials``. The 'TDS' bar advances once the consumer asks
    for the next configuration.

    :param scan: Scan object forwarded to ``_generate_probes``.
    :param potential: Object exposing ``frozen_phonons``, ``__len__`` and ``generate_frozen_phonon_potentials``.
    :param max_batch: Probe batch size forwarded to ``_generate_probes``.
    :param pbar: If true, display the 'TDS' and 'Potential' progress bars. The 'TDS' bar is also hidden when
        there is only one configuration.
    """
    num_configs = len(potential.frozen_phonons)
    tds_bar = ProgressBar(total=num_configs, desc='TDS', disable=(not pbar) or (num_configs == 1))
    potential_pbar = ProgressBar(total=len(potential), desc='Potential', disable=not pbar)

    try:
        for potential_config in potential.generate_frozen_phonon_potentials(pbar=potential_pbar):
            yield self._generate_probes(scan, potential_config, max_batch)
            tds_bar.update(1)
    finally:
        # Runs on normal exhaustion, on an exception, and when the consumer closes this generator early.
        potential_pbar.close()
        tds_bar.refresh()
        tds_bar.close()
def _generate_partial(self, max_batch: int = None, pbar: bool = True):
    """
    Yield consecutive PartialSMatrix objects that together cover indices ``0 .. len(self)``.

    :param max_batch: Maximum number of entries per partial matrix. ``None`` yields everything as a single batch.
    :param pbar: If true, display a 'Batches' progress bar (hidden when there is only one batch).
    :raises ValueError: If ``max_batch`` is given and smaller than 1.
    """
    total = len(self)

    if max_batch is None:
        n_batches = 1
    else:
        if max_batch < 1:
            raise ValueError(f'max_batch must be a positive integer or None, got {max_batch!r}')
        # Ceiling division: the smallest number of batches of at most max_batch entries.
        n_batches = -(-total // max_batch)

    batch_pbar = ProgressBar(total=total, desc='Batches', disable=(not pbar) or (n_batches == 1))
    try:
        batch_sizes = split_integer(total, n_batches)

        start = 0
        for batch_size in batch_sizes:
            yield PartialSMatrix(start, start + batch_size, self)
            start += batch_size
            batch_pbar.update(batch_size)

        batch_pbar.refresh()
    finally:
        # Close the bar even when the consumer stops iterating early or an exception propagates.
        batch_pbar.close()
def scan(self,
         potential: Union[Atoms, AbstractPotential],
         scan: AbstractScan,
         detectors: Union[AbstractDetector, Sequence[AbstractDetector]],
         max_batch_probes: int = 1,
         max_batch_expansion: int = None,
         pbar: bool = True) -> dict:
    """
    Scan probes generated from this object's multislice result across the potential, recording each detector.

    :param potential: The potential to scan across. A frozen-phonon (TDS) potential builder is handled by
        ``_generate_tds_probes``. Any other potential builder is built once, then propagated with ``multislice``.
    :param scan: Scan object defining the probe positions.
    :param detectors: The detectors recording the measurements. A single detector is accepted and treated as a
        one-element sequence.
    :param max_batch_probes: Batch size forwarded to ``multislice`` and to the probe generators.
    :param max_batch_expansion: Expansion batch size forwarded to the probe generators.
    :param pbar: If true, display progress bars (this flag is also forwarded to the potential build).
    :return: Dictionary of measurements with keys given by the detector.
    """
    self.grid.match(potential.grid)
    self.grid.check_is_defined()

    if isinstance(detectors, AbstractDetector):
        detectors = [detectors]

    measurements = {detector: detector.allocate_measurement(self.interpolated_grid, self.wavelength, scan)
                    for detector in detectors}

    if isinstance(potential, AbstractTDSPotentialBuilder):
        num_configs = len(potential.frozen_phonons)
        probe_generators = self._generate_tds_probes(scan,
                                                     potential,
                                                     max_batch_probes=max_batch_probes,
                                                     max_batch_expansion=max_batch_expansion,
                                                     potential_pbar=pbar,
                                                     multislice_pbar=pbar)
    else:
        if isinstance(potential, AbstractPotentialBuilder):
            # Honour the caller's flag instead of always showing the build progress bar.
            potential = potential.build(pbar=pbar)
        S = self.multislice(potential, max_batch=max_batch_probes, pbar=pbar)
        probe_generators = [S._generate_probes(scan, max_batch_probes, max_batch_expansion)]
        # Exactly one probe generator here. A non-TDS potential is not guaranteed to expose `frozen_phonons`,
        # so do not read it.
        num_configs = 1

    tds_bar = ProgressBar(total=num_configs, desc='TDS', disable=(not pbar) or (num_configs == 1))
    scan_bar = ProgressBar(total=len(scan), desc='Scan', disable=not pbar)
    try:
        for probe_generator in probe_generators:
            scan_bar.reset()
            for start, end, exit_probes in probe_generator:
                for detector, measurement in measurements.items():
                    scan.insert_new_measurement(measurement, start, end, detector.detect(exit_probes))
                scan_bar.update(end - start)
            scan_bar.refresh()
            tds_bar.update(1)
    finally:
        # Release both progress bars even if a detector or the multislice propagation raises.
        scan_bar.close()
        tds_bar.refresh()
        tds_bar.close()

    return measurements