Beispiel #1
0
    def show(self, waves, **kwargs):
        """
        Visualize the detector region(s) of the detector as applied to a specified wave function.

        Each pixel of the displayed image holds the index of the detector region it belongs to; pixels that are not
        part of any region are marked with -1.

        Parameters
        ----------
        waves : Waves or SMatrix object
            The wave function the visualization will be created to match.
        kwargs :
            Additional keyword arguments for abtem.visualize.mpl.show_measurement_2d.

        Returns
        -------
        The return value of show_measurement_2d.
        """

        waves.grid.check_is_defined()

        # `np.int` was deprecated in NumPy 1.20 and removed in 1.24; the builtin `int` is the documented replacement.
        array = np.full(waves.gpts, -1, dtype=int)

        regions = self._get_regions(waves.gpts, waves.angular_sampling, min(waves.cutoff_scattering_angles))
        for i, indices in enumerate(regions):
            # The region indices address the flattened (gpts[0], gpts[1]) array.
            array.ravel()[indices] = i

        calibrations = calibrations_from_grid(waves.gpts,
                                              waves.sampling,
                                              names=['alpha_x', 'alpha_y'],
                                              units='mrad',
                                              scale_factor=waves.wavelength * 1e3,
                                              fourier_space=True)

        # Move the zero-frequency pixel to the center of the image for display.
        array = np.fft.fftshift(array, axes=(-1, -2))

        measurement = Measurement(array, calibrations=calibrations, name='Detector regions')

        return show_measurement_2d(measurement, discrete_cmap=True, **kwargs)
Beispiel #2
0
    def allocate_measurement(self, waves, scan: AbstractScan) -> Measurement:
        """
        Allocate a Measurement object or an hdf5 file.

        One complex64 exit wave of shape ``waves.gpts`` is reserved per scan position, so the allocated array has
        shape ``scan.shape + waves.gpts``. The scan axes take the scan's calibrations and the wave axes take
        real-space 'x'/'y' calibrations in Å.

        Parameters
        ----------
        waves : Waves or SMatrix object
            The wave function that will define the shape of the diffraction patterns.
        scan: Scan object
            The scan object that will define the scan dimensions of the measurement.

        Returns
        -------
        Measurement object or str
            The allocated measurement or path to hdf5 file with the measurement data.
        """
        waves.grid.check_is_defined()

        wave_calibrations = calibrations_from_grid(waves.gpts, waves.sampling, names=['x', 'y'], units='Å')
        zeros = np.zeros(scan.shape + waves.gpts, dtype=np.complex64)
        allocated = Measurement(zeros, calibrations=scan.calibrations + wave_calibrations)

        if not isinstance(self.save_file, str):
            return allocated

        # A string save_file is used as the destination of the allocated measurement.
        return allocated.write(self.save_file)
Beispiel #3
0
    def allocate_measurement(self,
                             waves,
                             scan: AbstractScan = None) -> Measurement:
        """
        Allocate a Measurement object or an hdf5 file.

        The diffraction-pattern axes are sized from the wave function downsampled to ``self.max_angle`` and then
        resampled by ``self._resampled_gpts``. The leading axes follow ``scan``:
        - None gives no scan axes;
        - a plain tuple is used as the scan shape with uncalibrated axes;
        - any other object supplies ``shape`` and ``calibrations``.

        Parameters
        ----------
        waves : Waves or SMatrix object
            The wave function that will define the shape of the diffraction patterns.
        scan: Scan object, tuple or None
            The scan object (or shape tuple) that will define the scan dimensions of the measurement.

        Returns
        -------
        Measurement object or str
            The allocated measurement or path to hdf5 file with the measurement data.
        """
        waves.grid.check_is_defined()
        waves.accelerator.check_is_defined()
        check_max_angle_exceeded(waves, self.max_angle)

        downsampled_gpts = waves.downsampled_gpts(self.max_angle)
        gpts, new_angular_sampling = self._resampled_gpts(downsampled_gpts,
                                                          angular_sampling=waves.angular_sampling)

        # Real-space sampling matching the resampled angular grid; the factor 1000 presumably converts
        # mrad -> rad (angular sampling in mrad) -- confirm against _resampled_gpts.
        sampling_x = 1 / new_angular_sampling[0] / gpts[0] * waves.wavelength * 1000
        sampling_y = 1 / new_angular_sampling[1] / gpts[1] * waves.wavelength * 1000

        pattern_calibrations = calibrations_from_grid(gpts,
                                                      (sampling_x, sampling_y),
                                                      names=['alpha_x', 'alpha_y'],
                                                      units='mrad',
                                                      scale_factor=waves.wavelength * 1000,
                                                      fourier_space=True)

        if scan is None:
            scan_shape, scan_calibrations = (), ()
        elif isinstance(scan, tuple):
            # A bare shape tuple carries no calibration information for the scan axes.
            scan_shape, scan_calibrations = scan, (None, ) * len(scan)
        else:
            scan_shape, scan_calibrations = scan.shape, scan.calibrations

        allocated = Measurement(np.zeros(scan_shape + gpts),
                                calibrations=scan_calibrations + pattern_calibrations)

        if not isinstance(self.save_file, str):
            return allocated
        return allocated.write(self.save_file)
Beispiel #4
0
    def intensity(self) -> Measurement:
        """
        :return: The intensity of the wave functions at the image plane, i.e. the device ``abs2`` of the array.
            The last two axes are calibrated as real-space 'x' and 'y'; any leading axes are left uncalibrated.
        """
        xy_calibrations = calibrations_from_grid(self.grid.gpts, self.grid.sampling, ['x', 'y'])
        num_leading_axes = len(self.array.shape) - 2
        all_calibrations = (None, ) * num_leading_axes + xy_calibrations

        xp = get_array_module(self.array)
        squared_modulus = get_device_function(xp, 'abs2')
        return Measurement(squared_modulus(self.array), all_calibrations)
Beispiel #5
0
def test_export_import_measurement(tmp_path):
    """Round-trip a 2D Measurement through an hdf5 file and check that data and calibrations survive."""
    subdir = tmp_path / 'sub'
    subdir.mkdir()
    path = subdir / 'measurement.hdf5'

    calibrations = calibrations_from_grid((512, 256), (.1, .3), ['x', 'y'], 'Å')
    original = Measurement(np.random.rand(512, 256), calibrations)
    original.write(path)

    restored = Measurement.read(path)
    assert np.allclose(original.array, restored.array)
    for axis in (0, 1):
        assert original.calibrations[axis] == restored.calibrations[axis]
Beispiel #6
0
    def project(self):
        """
        Create a 2d measurement of the projected potential.

        The array is summed along its first axis, moved to host memory, and offset so that its minimum is zero.

        Returns
        -------
        Measurement
        """
        xy_calibrations = calibrations_from_grid(self.grid.gpts, self.grid.sampling, names=['x', 'y'])

        projected = asnumpy(self.array.sum(0))
        # Shift in place so the smallest projected value becomes zero.
        projected -= projected.min()

        return Measurement(projected, xy_calibrations)
Beispiel #7
0
    def allocate_measurement(self, grid: Grid, wavelength: float,
                             scan: AbstractScan) -> Measurement:
        """
        Allocate a complex64 Measurement of shape ``scan.shape + grid.gpts``, optionally writing it to a file.

        :param grid: Grid defining the real-space axes; must be fully defined.
        :param wavelength: Not used by this method (presumably kept for signature parity with other detectors).
        :param scan: Scan defining the leading axes and their calibrations.
        :return: The allocated Measurement, or the result of ``Measurement.write`` if ``self.save_file`` is a str.
        """
        grid.check_is_defined()

        xy_calibrations = calibrations_from_grid(grid.gpts, grid.sampling, names=['x', 'y'], units='Å')
        zeros = np.zeros(scan.shape + grid.gpts, dtype=np.complex64)
        allocated = Measurement(zeros, calibrations=scan.calibrations + xy_calibrations)

        if not isinstance(self.save_file, str):
            return allocated
        return allocated.write(self.save_file)
Beispiel #8
0
    def allocate_measurement(self, grid: Grid, wavelength: float,
                             scan: AbstractScan) -> Measurement:
        """
        Allocate a real-valued Measurement of diffraction patterns, optionally writing it to a file.

        :param grid: Grid defining the diffraction-pattern size; must be fully defined.
        :param wavelength: Wavelength used to scale the angular calibrations (``scale_factor=wavelength * 1000``).
        :param scan: Scan defining the leading axes and their calibrations.
        :return: The allocated Measurement, or the result of ``Measurement.write`` if ``self.save_file`` is a str.
        """
        grid.check_is_defined()
        half_x, half_y = grid.gpts[0] // 2, grid.gpts[1] // 2

        # NOTE(review): the array is sized from gpts // 2 while the calibrations come from antialiased_gpts;
        # confirm these describe the same pattern grid.
        angular_calibrations = calibrations_from_grid(grid.antialiased_gpts,
                                                      grid.antialiased_sampling,
                                                      names=['alpha_x', 'alpha_y'],
                                                      units='mrad',
                                                      scale_factor=wavelength * 1000,
                                                      fourier_space=True)

        zeros = np.zeros(scan.shape + (half_x, half_y))
        allocated = Measurement(zeros, calibrations=scan.calibrations + angular_calibrations)

        if not isinstance(self.save_file, str):
            return allocated
        return allocated.write(self.save_file)
Beispiel #9
0
    def show(self, profile: bool = False, **kwargs):
        """
        Show the probe wave function.

        The probe is built at the center of its extent. Its intensity is then shown either as an image or, when
        ``profile`` is true, as a line profile through the middle row of the intensity.

        :param profile: If true, show a 1D slice of the probe as a line profile.
        :param kwargs: Additional keyword arguments for the plot.show_line or plot.show_image functions.
            See the documentation of the respective function for a description.
        :return: The return value of the plotting function that was called (both branches now return it).
        """
        measurement = self.build((self.extent[0] / 2, self.extent[1] / 2)).intensity()

        if not profile:
            return measurement.show(**kwargs)

        # Take the first entry along the leading axis (the single built position), then the middle row.
        array = measurement.array[0]
        array = array[array.shape[0] // 2, :]
        calibration = calibrations_from_grid(gpts=(self.grid.gpts[1], ),
                                             sampling=(self.grid.sampling[1], ),
                                             names=['x'])[0]
        # Return the plot so the profile branch behaves like the image branch.
        return show_line(array, calibration, **kwargs)
Beispiel #10
0
    def diffraction_pattern(self) -> Measurement:
        """
        :return: The intensity of the wave functions at the diffraction plane.
            The last two axes are calibrated as scattering angles 'alpha_x'/'alpha_y' in mrad; any leading
            (batch) axes are left uncalibrated and keep their original order.
        """
        calibrations = calibrations_from_grid(self.grid.antialiased_gpts,
                                              self.grid.antialiased_sampling,
                                              names=['alpha_x', 'alpha_y'],
                                              units='mrad',
                                              scale_factor=self.wavelength * 1000,
                                              fourier_space=True)

        calibrations = (None, ) * (len(self.array.shape) - 2) + calibrations

        xp = get_array_module(self.array)
        abs2 = get_device_function(xp, 'abs2')
        fft2 = get_device_function(xp, 'fft2')

        # Shift only the two spatial-frequency axes. Without `axes`, fftshift rolls every axis, including leading
        # batch axes, which would reorder the stacked wave functions.
        spectrum = xp.fft.fftshift(fft2(self.array, overwrite_x=False), axes=(-2, -1))
        pattern = asnumpy(abs2(crop_to_center(spectrum)))
        return Measurement(pattern, calibrations)
Beispiel #11
0
    def show(self, transitions_idx=0):
        """
        Show the summed intensity of the transition potentials as a 2D real-space image.

        The intensities ``|t|**2`` of every transition potential generated for ``transitions_idx`` are accumulated
        over all slices. The first entry along the leading axis of the sum is then displayed.

        Parameters
        ----------
        transitions_idx : int
            Passed to ``_generate_slice_transition_potentials`` to select the transition(s) to show.

        Returns
        -------
        The return value of ``Measurement.show``.
        """
        intensity = None

        # Without explicit slice thicknesses, temporarily use the cell's [2, 2] entry as the slice thickness. The
        # original value (None) is restored in the finally-block, even if potential generation raises.
        restore_slice_thicknesses = self._sliced_atoms.slice_thicknesses is None
        if restore_slice_thicknesses:
            self._sliced_atoms.slice_thicknesses = self._sliced_atoms.atoms.cell[2, 2]

        try:
            for slice_idx in range(self.num_slices):
                for t in self._generate_slice_transition_potentials(slice_idx, transitions_idx):
                    if intensity is None:
                        intensity = np.abs(t) ** 2
                    else:
                        intensity += np.abs(t) ** 2
        finally:
            if restore_slice_thicknesses:
                self._sliced_atoms.slice_thicknesses = None

        calibrations = calibrations_from_grid(self.gpts, self.sampling, ['x', 'y'])
        return Measurement(intensity[0], calibrations, name=str(self)).show()
Beispiel #12
0
    def show(self,
             grid: Grid,
             wavelength: float,
             cbar_label: str = 'Detector regions',
             **kwargs):
        """
        Show the detector regions on the antialiased grid of ``grid``.

        Each pixel holds the index of the detector region it belongs to; pixels outside every region are -1.

        Parameters
        ----------
        grid : Grid
            Grid providing the antialiased gpts and sampling; must be fully defined.
        wavelength : float
            Passed to ``_get_regions`` and used to scale the angular calibrations (``wavelength * 1000``).
        cbar_label : str
            Label of the colorbar.
        kwargs :
            Additional keyword arguments for show_image.

        Returns
        -------
        The return value of show_image.
        """
        grid.check_is_defined()

        # `np.int` was deprecated in NumPy 1.20 and removed in 1.24; the builtin `int` is the documented replacement.
        array = np.full(grid.antialiased_gpts, -1, dtype=int)
        regions = self._get_regions(grid.antialiased_gpts, grid.antialiased_sampling, wavelength)
        for i, indices in enumerate(regions):
            # The region indices address the flattened array.
            array.ravel()[indices] = i

        # Use the same 'mrad' unit label as the other angular calibrations (the old 'mrad.' had a stray period).
        calibrations = calibrations_from_grid(grid.antialiased_gpts,
                                              grid.antialiased_sampling,
                                              names=['alpha_x', 'alpha_y'],
                                              units='mrad',
                                              scale_factor=wavelength * 1000,
                                              fourier_space=True)
        return show_image(array,
                          calibrations,
                          cbar_label=cbar_label,
                          discrete=True,
                          **kwargs)
Beispiel #13
0
def epie(
    measurement: Measurement,
    probe_guess: Probe,
    maxiter: int = 5,
    alpha: float = 1.,
    beta: float = 1.,
    fix_probe: bool = False,
    fix_com: bool = False,
    return_iterations: bool = False,
    max_angle=None,
    seed=None,
    device='cpu',
):
    """
    Reconstruct the phase of a 4D-STEM measurement using the extended Ptychographical Iterative Engine.

    See https://doi.org/10.1016/j.ultramic.2009.05.012

    Parameters
    ----------
    measurement : Measurement object
        4D-STEM measurement. The first two axes are the scan axes and the last two the diffraction-pattern axes.
    probe_guess : Probe object
        The initial guess for the probe. Note: its ``extent``, ``gpts`` and device are overwritten by this function.
    maxiter : int
        Run the algorithm for this many iterations.
    alpha : float
        Controls the size of the iterative updates for the object. See reference.
    beta : float
        Controls the size of the iterative updates for the probe. See reference.
    fix_probe : bool
        If True, the probe will not be updated by the algorithm. Default is False.
    fix_com : bool
        If True, the center of mass of the probe will be centered. Default is False.
    return_iterations : bool
        If True, return the reconstruction after every iteration. Default is False.
    max_angle : float, optional
        The maximum reconstructed scattering angle. If this is larger than the input data, the data will be
        zero-padded; otherwise the data is used without padding.
    seed : int, optional
        Seed the random number generator.
    device : str
        Set the calculation device.

    Returns
    -------
    tuple
        ``(object, probe, extra)``: the reconstructed object and the fft-shifted probe as Measurement objects,
        plus the third value returned by the internal solver (presumably the reconstruction error(s) -- confirm).
        If ``return_iterations`` is True, the object and probe entries are lists of Measurement objects, one per
        stored iteration.
    """

    # Flatten the two scan axes so the solver sees a stack of diffraction patterns.
    diffraction_patterns = measurement.array.reshape(
        (-1, ) + measurement.array.shape[2:])

    if max_angle:
        # Grow each pattern axis so it reaches max_angle. Widths are clamped at zero: np.pad rejects negative
        # widths, and a max_angle that is not larger than the data (or smaller only by rounding) needs no padding.
        padding_x = max(int((max_angle / abs(measurement.calibrations[-2].offset) *
                             diffraction_patterns.shape[-2]) // 2) -
                        diffraction_patterns.shape[-2] // 2, 0)
        padding_y = max(int((max_angle / abs(measurement.calibrations[-1].offset) *
                             diffraction_patterns.shape[-1]) // 2) -
                        diffraction_patterns.shape[-1] // 2, 0)
        diffraction_patterns = np.pad(diffraction_patterns,
                                      ((0, ) * 2, (padding_x, ) * 2,
                                       (padding_y, ) * 2))

    extent = (probe_guess.wavelength * 1e3 / measurement.calibrations[2].sampling,
              probe_guess.wavelength * 1e3 / measurement.calibrations[3].sampling)

    sampling = (extent[0] / diffraction_patterns.shape[-2],
                extent[1] / diffraction_patterns.shape[-1])

    # Scan positions expressed in units of reconstruction pixels.
    x = measurement.calibrations[0].coordinates(measurement.shape[0]) / sampling[0]
    y = measurement.calibrations[1].coordinates(measurement.shape[1]) / sampling[1]
    x, y = np.meshgrid(x, y, indexing='ij')
    positions = np.array([x.ravel(), y.ravel()]).T

    # NOTE: this mutates the caller's probe_guess.
    probe_guess.extent = extent
    probe_guess.gpts = diffraction_patterns.shape[-2:]

    calibrations = calibrations_from_grid(probe_guess.gpts,
                                          probe_guess.sampling,
                                          names=['x', 'y'],
                                          units='Å')

    probe_guess._device = device
    probe_guess = probe_guess.build(np.array([0, 0])).array[0]

    result = _run_epie(diffraction_patterns.shape[-2:],
                       probe_guess,
                       diffraction_patterns,
                       positions,
                       maxiter=maxiter,
                       alpha=alpha,
                       beta=beta,
                       return_iterations=return_iterations,
                       fix_probe=fix_probe,
                       fix_com=fix_com,
                       seed=seed)

    if return_iterations:
        object_iterations = [
            Measurement(obj, calibrations=calibrations)
            for obj in result[0]
        ]
        probe_iterations = [
            Measurement(np.fft.fftshift(probe), calibrations=calibrations)
            for probe in result[1]
        ]
        return object_iterations, probe_iterations, result[2]

    return (Measurement(result[0], calibrations=calibrations),
            Measurement(np.fft.fftshift(result[1]), calibrations=calibrations),
            result[2])
Beispiel #14
0
 def measure(self):
     """
     Build the transition potentials and return the intensity of the first one as a 2D measurement.

     Returns
     -------
     Measurement
         ``abs2`` of the first entry along the leading axis of ``self.build()`` (presumably the first
         transition -- confirm), fft-shifted over the two spatial axes and calibrated as 'x'/'y'.
     """
     # Shift only the two spatial axes. Shifting the leading axis too would make the [0] below select an entry
     # other than the first one.
     array = np.fft.fftshift(self.build(), axes=(-2, -1))[0]
     calibrations = calibrations_from_grid(self.gpts, self.sampling,
                                           ['x', 'y'])
     # abs2 was looked up but never applied before; return the intensity, as intensity() does for waves.
     abs2 = get_device_function(get_array_module(array), 'abs2')
     return Measurement(abs2(array), calibrations, name=str(self))
Beispiel #15
0
 def show(self, **kwargs):
     """
     Display the array as a 2D image with real-space 'x'/'y' calibrations derived from the grid.

     :param kwargs: Additional keyword arguments passed on to show_image.
     :return: The return value of show_image.
     """
     grid = self.grid
     xy_calibrations = calibrations_from_grid(grid.gpts, grid.sampling, names=['x', 'y'])
     return show_image(self.array, xy_calibrations, **kwargs)