Ejemplo n.º 1
0
    def detect(self, waves, normalize: bool = True) -> np.ndarray:
        """
        Sum the far-field intensity of a batch of wave functions over the detector range.

        Parameters
        ----------
        waves : Waves object
            The batch of wave functions to detect.
        normalize : bool
            Whether to normalize the output by the total intensity of each wave function.

        Returns
        -------
        1d array
            One detected value per wave function, i.e. the array length equals the batch size.
        """
        array_module = get_array_module(waves.array)
        forward_fft = get_device_function(array_module, 'fft2')
        squared_magnitude = get_device_function(array_module, 'abs2')

        # overwrite_x=False: the caller's wave array must not be clobbered by the transform.
        far_field_intensity = squared_magnitude(forward_fft(waves.array, overwrite_x=False))
        sampling = waves.angular_sampling
        cutoff_angle = min(waves.cutoff_scattering_angles)
        return self._integrate_array(far_field_intensity, sampling, cutoff_angle, normalize)
Ejemplo n.º 2
0
    def detect(self, waves) -> np.ndarray:
        """
        Integrate the intensity of the wave functions over each detector region.

        Parameters
        ----------
        waves: Waves object
            The batch of wave functions to detect.

        Returns
        -------
        2d array
            Detected values. The array has shape of (batch size, number of bins).
        """

        xp = get_array_module(waves.array)
        fft2 = get_device_function(xp, 'fft2')
        abs2 = get_device_function(xp, 'abs2')
        sum_run_length_encoded = get_device_function(xp, 'sum_run_length_encoded')

        intensity = abs2(fft2(waves.array, overwrite_x=False))

        # One array of flat pixel indices per detector region.
        indices = self._get_regions(waves.gpts, waves.angular_sampling, min(waves.cutoff_scattering_angles), xp)

        # Run-length boundaries of the concatenated regions: [0, n0, n0 + n1, ...].
        separators = xp.concatenate((xp.array([0]), xp.cumsum(xp.array([len(ring) for ring in indices]))))
        # Gather the pixels of all regions, in region order, for every wave function in the batch.
        intensity = intensity.reshape((intensity.shape[0], -1))[:, xp.concatenate(indices)]
        result = xp.zeros((len(intensity), len(separators) - 1), dtype=xp.float32)
        # Sums each run [separators[k], separators[k + 1]) into result[:, k].
        sum_run_length_encoded(intensity, result, separators)

        return result
Ejemplo n.º 3
0
def fft_interpolate_2d(array,
                       new_shape,
                       normalization='values',
                       overwrite_x=False):
    """
    Resample the last two axes of an array by cropping its 2D Fourier spectrum with ``fft_crop``.

    Parameters
    ----------
    array : array
        Real or complex array; only the last two axes are resampled.
    new_shape : two int
        Target shape of the last two axes.
    normalization : {'values', 'norm', False, None}
        'values' rescales so that the mean value is preserved. 'norm' rescales so that the sum of
        squared magnitudes (L2 norm) is preserved, provided no spectral weight is cropped away.
        False or None applies no rescaling. Both scalings assume that fft2/ifft2 follow NumPy's
        convention (unnormalized forward transform, 1/N inverse).
    overwrite_x : bool
        Forwarded to the inverse FFT.

    Returns
    -------
    array
        The resampled array. Real input gives real output.

    Raises
    ------
    RuntimeError
        If ``normalization`` is not one of the supported options.
    """
    # Validate before doing any FFT work.
    if normalization not in ('values', 'norm', False, None):
        raise RuntimeError(f'unknown normalization: {normalization!r}')

    xp = get_array_module(array)
    fft2 = get_device_function(xp, 'fft2')
    ifft2 = get_device_function(xp, 'ifft2')

    old_size = array.shape[-2] * array.shape[-1]

    if np.iscomplexobj(array):
        cropped = fft_crop(fft2(array), new_shape)
        array = ifft2(cropped, overwrite_x=overwrite_x)
    else:
        array = xp.complex64(array)
        array = ifft2(fft_crop(fft2(array), new_shape),
                      overwrite_x=overwrite_x).real

    new_size = array.shape[-2] * array.shape[-1]

    if normalization == 'values':
        # fft2 sums old_size samples, but ifft2 divides by new_size. The mean therefore comes
        # out scaled by old_size / new_size, and this factor undoes that.
        array *= new_size / old_size
    elif normalization == 'norm':
        # By Parseval, the squared norm comes out scaled by old_size / new_size, so the
        # amplitudes need the square root of the inverse factor.
        array *= (new_size / old_size) ** 0.5

    return array
Ejemplo n.º 4
0
def calculate_far_field_intensity(waves, overwrite: bool = False):
    """
    Return the far-field intensity of ``waves``.

    The Fourier transform is taken, the zero frequency is shifted to the centre of the last two
    axes, the result is cropped with ``crop_to_center``, and the squared magnitude is returned.

    :param waves: Object exposing the wave-function ``array``.
    :param overwrite: Forwarded to the FFT as ``overwrite_x``; allows ``waves.array`` to be reused.
    :return: Cropped far-field intensity.
    """
    backend = get_array_module(waves.array)
    forward_fft = get_device_function(backend, 'fft2')
    squared_magnitude = get_device_function(backend, 'abs2')

    spectrum = forward_fft(waves.array, overwrite_x=overwrite)
    centred = backend.fft.fftshift(spectrum, axes=(-2, -1))
    return squared_magnitude(crop_to_center(centred))
Ejemplo n.º 5
0
    def detect(self, waves) -> np.ndarray:
        """
        Return the far-field intensity of the wave functions.

        The zero frequency is shifted to the centre of the last two axes, and the result is
        cropped with ``crop_to_center``.
        """
        backend = get_array_module(waves.array)
        squared_magnitude = get_device_function(backend, 'abs2')
        forward_fft = get_device_function(backend, 'fft2')

        spectrum_intensity = squared_magnitude(forward_fft(waves.array, overwrite_x=False))
        return crop_to_center(backend.fft.fftshift(spectrum_intensity, axes=(-1, -2)))
Ejemplo n.º 6
0
    def apply_ctf(self, ctf: CTF = None, **kwargs):
        """
        Apply the aberrations of a contrast transfer function to the wave functions.

        :param ctf: CTF object to apply. If None, a CTF is constructed from ``kwargs``.
        :param kwargs: Aberration coefficients forwarded to ``CTF`` when ``ctf`` is None.
        :return: New wave functions of the same class with the CTF applied. The convolution is
            requested with ``overwrite_x=False`` so that ``self.array`` is not overwritten.
        """
        xp = get_array_module(self.array)
        convolve = get_device_function(xp, 'fft2_convolve')

        ctf = CTF(**kwargs) if ctf is None else ctf
        ctf.accelerator.match(self.accelerator)

        kx, ky = spatial_frequencies(self.grid.gpts, self.grid.sampling)
        # Scattering angles (polar coordinates) of every spatial frequency on the grid.
        alpha, phi = polargrid(xp.asarray(kx * self.wavelength),
                               xp.asarray(ky * self.wavelength))
        kernel = ctf.evaluate(alpha, phi)

        aberrated = convolve(self.array, kernel, overwrite_x=False)
        return self.__class__(aberrated, extent=self.extent, energy=self.energy)
Ejemplo n.º 7
0
 def _bandlimit(self, array):
     """
     Band-limit ``array`` by convolving it with this object's mask, using unit sampling.

     The return value of ``fft2_convolve`` is not used. The result therefore relies on the
     convolution writing into ``array``, which is requested with ``overwrite_x=True``.
     """
     backend = get_array_module(array)
     convolve = get_device_function(backend, 'fft2_convolve')
     mask = self.get_mask(array.shape[-2:], (1, 1), backend)
     convolve(array, mask, overwrite_x=True)
     return array
Ejemplo n.º 8
0
def fourier_translation_operator(positions: np.ndarray, shape: tuple):
    """
    Build phase ramps that translate an array by ``positions`` when multiplied onto its spectrum.

    Spatial frequencies are evaluated with unit sampling, so positions are in grid pixels.
    A single (x, y) position yields a single ramp. An array of positions yields one ramp
    per position along the first axis.
    """
    is_single = len(positions.shape) == 1
    if is_single:
        positions = positions[None]

    xp = get_array_module(positions)
    complex_exponential = get_device_function(xp, 'complex_exponential')

    kx, ky = spatial_frequencies(shape, (1., 1.))
    # Broadcast layout: (position, kx, ky).
    kx = xp.asarray(kx.reshape((1, -1, 1)))
    ky = xp.asarray(ky.reshape((1, 1, -1)))

    positions = xp.asarray(positions)
    x = positions[:, 0].reshape((-1, 1, 1))
    y = positions[:, 1].reshape((-1, 1, 1))

    ramp_x = complex_exponential(-2 * np.pi * kx * x)
    ramp_y = complex_exponential(-2 * np.pi * ky * y)
    ramps = ramp_x * ramp_y

    return ramps[0] if is_single else ramps
Ejemplo n.º 9
0
    def detect(self, waves) -> np.ndarray:
        """
        Calculate the far-field intensity of the wave functions.

        The far field is taken with ``waves.far_field(max_angle=self.max_angle)``. Its intensity
        is centred with ``fftshift`` and then resampled by ``self._interpolate``.

        Parameters
        ----------
        waves: Waves object
            The batch of wave functions to detect.

        Returns
        -------
        3d array
            Detected values. The first dimension indexes the batch; the second and third index the
            two components of the spatial frequency.
        """
        backend = get_array_module(waves.array)
        squared_magnitude = get_device_function(backend, 'abs2')

        far_field = waves.far_field(max_angle=self.max_angle)
        centred_intensity = backend.fft.fftshift(squared_magnitude(far_field.array), axes=(-2, -1))
        return self._interpolate(centred_intensity, far_field.angular_sampling)
Ejemplo n.º 10
0
    def generate_slices(self, start=0, end=None):
        """
        Generate projected-potential slices one at a time.

        The slices are those with index in ``range(start, end)``. For each slice, atoms whose z
        position lies within ``integrator.cutoff`` of the slice boundaries ``[a, b]`` contribute.
        Their potential is integrated between the boundaries and accumulated onto a 2D grid of
        ``self.gpts`` points by ``interpolate_radial_functions``.

        Parameters
        ----------
        start : int
            Index of the first slice.
        end : int, optional
            Index one past the last slice; defaults to ``len(self)``.

        Yields
        ------
        ProjectedPotential
            The accumulated slice array divided by ``kappa``, with the slice thickness and
            ``self.extent``.
        """
        if end is None:
            end = len(self)

        self.grid.check_is_defined()

        # NOTE(review): the probe and S-matrix builders resolve the module from a device with
        # get_array_module_from_device(self._device). Confirm that get_array_module also maps
        # device names such as 'gpu' to the right array module.
        xp = get_array_module(self._device)
        interpolate_radial_functions = get_device_function(
            xp, 'interpolate_radial_functions')

        atoms = self.atoms.copy()
        # Atom indices grouped by atomic number, so each species uses its own integrator.
        indices_by_number = {
            number: np.where(atoms.numbers == number)[0]
            for number in np.unique(atoms.numbers)
        }

        # Single buffer, reset for every slice. The yield below divides it by kappa, which
        # allocates a new array, so later slices do not overwrite what was already yielded.
        array = xp.zeros(self.gpts, dtype=xp.float32)
        # Lower z boundary of slice `start`: total thickness of all preceding slices.
        a = np.sum([self.get_slice_thickness(i) for i in range(0, start)])
        for i in range(start, end):
            array[:] = 0.
            # Upper z boundary of slice i.
            b = a + self.get_slice_thickness(i)

            for number, indices in indices_by_number.items():
                slice_atoms = atoms[indices]

                integrator, disc_indices = self.get_integrator(number)
                disc_indices = xp.asarray(disc_indices)

                # Keep atoms whose z position lies within the cutoff of [a, b].
                slice_atoms = slice_atoms[
                    (slice_atoms.positions[:, 2] > a - integrator.cutoff) *
                    (slice_atoms.positions[:, 2] < b + integrator.cutoff)]

                # presumably adds periodic images within the cutoff of the cell edges - TODO confirm
                slice_atoms = pad_atoms(slice_atoms, integrator.cutoff)

                if len(slice_atoms) == 0:
                    continue

                vr = np.zeros((len(slice_atoms), len(integrator.r)),
                              np.float32)
                dvdr = np.zeros((len(slice_atoms), len(integrator.r)),
                                np.float32)
                for j, atom in enumerate(slice_atoms):
                    # Integration limits relative to the atom's own z position.
                    am, bm = a - atom.z, b - atom.z
                    # The derivative fills all but the last radial sample; the last stays 0.
                    vr[j], dvdr[j, :-1] = integrator.integrate(am, bm)
                vr = xp.asarray(vr, dtype=xp.float32)
                dvdr = xp.asarray(dvdr, dtype=xp.float32)
                r = xp.asarray(integrator.r, dtype=xp.float32)

                slice_positions = xp.asarray(slice_atoms.positions[:, :2],
                                             dtype=xp.float32)
                sampling = xp.asarray(self.sampling, dtype=xp.float32)

                # Accumulates the radial profiles of all atoms into `array`.
                interpolate_radial_functions(array, disc_indices,
                                             slice_positions, vr, r, dvdr,
                                             sampling)
            a = b

            yield ProjectedPotential(array / kappa,
                                     self.get_slice_thickness(i),
                                     extent=self.extent)
Ejemplo n.º 11
0
    def build(self, positions: Sequence[Sequence[float]] = None) -> Waves:
        """
        Build probe wave functions at the given positions.

        :param positions: One (x, y) position or a sequence of them. Defaults to a single probe
            at (0, 0).
        :return: Probe wave functions as a Waves object, one per position.
        """
        self.grid.check_is_defined()
        self.accelerator.check_is_defined()
        backend = get_array_module_from_device(self._device)
        forward_fft = get_device_function(backend, 'fft2')

        if positions is not None:
            positions = backend.array(positions, dtype=backend.float32)
        else:
            positions = backend.zeros((1, 2), dtype=backend.float32)

        # A single (x, y) pair becomes a batch of one.
        if len(positions.shape) == 1:
            positions = positions[None]

        ctf = self._evaluate_ctf(backend)
        phase_ramps = self._fourier_translation_operator(positions)
        probes = forward_fft(ctf * phase_ramps, overwrite_x=True)

        return Waves(probes, extent=self.extent, energy=self.energy)
Ejemplo n.º 12
0
 def _evaluate_propagator_array(self, gpts, sampling, wavelength, dz, xp):
     """
     Return the Fresnel propagator for distance ``dz`` on a (gpts[0], gpts[1]) frequency grid.

     The propagator is the product of separable kx and ky phase factors, multiplied by the
     antialiasing aperture.
     """
     expi = get_device_function(xp, 'complex_exponential')
     freq_x = xp.fft.fftfreq(gpts[0], sampling[0]).astype(xp.float32)
     freq_y = xp.fft.fftfreq(gpts[1], sampling[1]).astype(xp.float32)

     phase_x = expi(-(freq_x**2)[:, None] * np.pi * wavelength * dz)
     phase_y = expi(-(freq_y**2)[None] * np.pi * wavelength * dz)
     propagator = phase_x * phase_y

     propagator *= xp.asarray(self._antialiasing_aperture(gpts))
     return propagator
Ejemplo n.º 13
0
    def as_transmission_functions(self, energy):
        """
        Convert the projected potential slices into transmission functions.

        :param energy: Electron energy, converted to an interaction parameter by ``energy2sigma``.
        :return: TransmissionFunctions with the same slice thicknesses and extent.
        """
        expi = get_device_function(get_array_module(self.array), 'complex_exponential')

        sigma = energy2sigma(energy)
        transmission = expi(sigma * self._array)
        return TransmissionFunctions(transmission,
                                     slice_thicknesses=self._slice_thicknesses,
                                     extent=self.extent,
                                     energy=energy)
Ejemplo n.º 14
0
    def detect(self, waves) -> np.ndarray:
        """
        Integrate the far-field intensity of the wave functions over each detector region.

        Parameters
        ----------
        waves: Waves object
            The batch of wave functions to detect.

        Returns
        -------
        array
            Detected values with the batch along the first axis. A radial-bin axis follows when
            ``nbins_radial > 1``, then an angular-bin axis when ``nbins_angular > 1``.
        """
        backend = get_array_module(waves.array)
        forward_fft = get_device_function(backend, 'fft2')
        squared_magnitude = get_device_function(backend, 'abs2')
        run_length_sum = get_device_function(backend, 'sum_run_length_encoded')

        spectrum_intensity = squared_magnitude(forward_fft(waves.array, overwrite_x=False))

        regions = self._get_regions(waves.gpts, waves.angular_sampling,
                                    min(waves.cutoff_scattering_angles), backend)

        # Boundaries of each region inside the concatenated pixel list: [0, n0, n0 + n1, ...].
        region_sizes = backend.array([len(region) for region in regions])
        boundaries = backend.concatenate((backend.array([0]), backend.cumsum(region_sizes)))

        flat = spectrum_intensity.reshape((spectrum_intensity.shape[0], -1))
        gathered = flat[:, backend.concatenate(regions)]
        sums = backend.zeros((len(gathered), len(boundaries) - 1), dtype=backend.float32)
        run_length_sum(gathered, sums, boundaries)

        # Only bin axes with more than one bin are kept.
        bin_axes = tuple(count for count in (self.nbins_radial, self.nbins_angular) if count > 1)
        return sums.reshape((-1, ) + bin_axes)
Ejemplo n.º 15
0
    def intensity(self) -> Measurement:
        """
        :return: Measurement of the squared magnitude of the wave functions at the image plane,
            calibrated along the two trailing (x, y) axes.
        """
        spatial = calibrations_from_grid(self.grid.gpts,
                                         self.grid.sampling, ['x', 'y'])
        # Leading (batch) axes have no calibration.
        batch = (None, ) * (len(self.array.shape) - 2)

        squared_magnitude = get_device_function(get_array_module(self.array), 'abs2')
        return Measurement(squared_magnitude(self.array), batch + spatial)
Ejemplo n.º 16
0
    def _interpolate(self, array, angular_sampling):
        """
        Bilinearly resample the last two axes of ``array``.

        The target grid and angular sampling come from ``self._resampled_gpts``.
        """
        backend = get_array_module(array)
        bilinear = get_device_function(backend, 'interpolate_bilinear')

        old_gpts = array.shape[-2:]
        target_gpts, target_sampling = self._resampled_gpts(old_gpts, angular_sampling)
        rows, cols, row_weights, col_weights = self._bilinear_nodes_and_weight(
            old_gpts, target_gpts, angular_sampling, target_sampling, backend)

        return bilinear(array, rows, cols, row_weights, col_weights)
Ejemplo n.º 17
0
    def build(self) -> SMatrix:
        """
        Build the scattering matrix as a stack of plane waves.

        Wave vectors are sampled with spacing ``interpolation / extent`` along each axis. Only
        those with ``|k| < expansion_cutoff / 1000 / wavelength`` are kept, so
        ``expansion_cutoff`` is presumably an angle in mrad (TODO confirm).

        Returns
        -------
        SMatrix
            Plane waves stacked along the first axis, allocated with the array module of
            ``self._storage``, together with the retained wave vectors ``k=(kx, ky)``.
        """
        self.grid.check_is_defined()
        self.accelerator.check_is_defined()

        # Resolve the compute module from the device name in the same way as the storage module.
        xp = get_array_module_from_device(self._device)
        storage_xp = get_array_module_from_device(self._storage)
        complex_exponential = get_device_function(xp, 'complex_exponential')

        # These are host-side Python scalars, so round them with NumPy rather than passing
        # Python floats through a (possibly device) ufunc.
        n_max = int(
            np.ceil(self.expansion_cutoff / 1000. /
                    (self.wavelength / self.extent[0] * self.interpolation)))
        m_max = int(
            np.ceil(self.expansion_cutoff / 1000. /
                    (self.wavelength / self.extent[1] * self.interpolation)))

        n = xp.arange(-n_max, n_max + 1, dtype=xp.float32)
        w = xp.asarray(self.extent[0], dtype=xp.float32)
        m = xp.arange(-m_max, m_max + 1, dtype=xp.float32)
        h = xp.asarray(self.extent[1], dtype=xp.float32)

        kx = n / w * xp.float32(self.interpolation)
        ky = m / h * xp.float32(self.interpolation)

        # Keep only the wave vectors strictly inside the expansion cutoff.
        mask = kx[:, None]**2 + ky[None, :]**2 < (self.expansion_cutoff /
                                                  1000. / self.wavelength)**2
        kx, ky = xp.meshgrid(kx, ky, indexing='ij')
        kx = kx[mask]
        ky = ky[mask]

        x, y = coordinates(extent=self.extent,
                           gpts=self.gpts,
                           endpoint=self.grid.endpoint)
        x = xp.asarray(x)
        y = xp.asarray(y)

        array = storage_xp.zeros((len(kx), ) + (self.gpts[0], self.gpts[1]),
                                 dtype=np.complex64)

        # Each plane wave is evaluated on the compute device and then copied into storage.
        for i in range(len(kx)):
            array[i] = copy_to_device(
                complex_exponential(
                    -2 * np.pi * kx[i, None, None] * x[:, None]) *
                complex_exponential(
                    -2 * np.pi * ky[i, None, None] * y[None, :]),
                self._storage)

        return SMatrix(array,
                       expansion_cutoff=self.expansion_cutoff,
                       interpolation=self.interpolation,
                       extent=self.extent,
                       energy=self.energy,
                       k=(kx, ky))
Ejemplo n.º 18
0
    def diffraction_pattern(self) -> Measurement:
        """
        :return: The intensity of the wave functions at the diffraction plane, with the zero
            frequency centred, cropped with ``crop_to_center``, and calibrated in mrad along the
            two trailing axes.
        """
        calibrations = calibrations_from_grid(self.grid.antialiased_gpts,
                                              self.grid.antialiased_sampling,
                                              names=['alpha_x', 'alpha_y'],
                                              units='mrad',
                                              scale_factor=self.wavelength *
                                              1000,
                                              fourier_space=True)

        # Leading (batch) axes have no calibration.
        calibrations = (None, ) * (len(self.array.shape) - 2) + calibrations

        xp = get_array_module(self.array)
        abs2 = get_device_function(xp, 'abs2')
        fft2 = get_device_function(xp, 'fft2')
        # Shift only the two spatial-frequency axes. Shifting every axis would also roll the
        # leading batch axes and reorder the patterns relative to their wave functions.
        far_field = xp.fft.fftshift(fft2(self.array, overwrite_x=False), axes=(-2, -1))
        pattern = asnumpy(abs2(crop_to_center(far_field)))
        return Measurement(pattern, calibrations)
Ejemplo n.º 19
0
    def _fourier_translation_operator(self, positions):
        """
        Return one phase ramp per position on this object's spatial-frequency grid.

        Assuming ``complex_exponential(t)`` is exp(i*t), each ramp is exp(2*pi*i*(kx*x + ky*y)),
        computed as the product of separate x and y factors.

        :param positions: Array of shape (n, 2) holding (x, y) positions.
        :return: Complex array with the positions along the first axis.
        """
        backend = get_array_module(positions)
        expi = get_device_function(backend, 'complex_exponential')

        freq_x, freq_y = spatial_frequencies(self.grid.gpts, self.grid.sampling)
        # Broadcast layout: (position, kx, ky).
        freq_x = backend.asarray(freq_x.reshape((1, -1, 1)))
        freq_y = backend.asarray(freq_y.reshape((1, 1, -1)))

        positions = backend.asarray(positions)
        x = positions[:, 0].reshape((-1, 1, 1))
        y = positions[:, 1].reshape((-1, 1, 1))

        ramp_x = expi(2 * np.pi * freq_x * x)
        ramp_y = expi(2 * np.pi * freq_y * y)
        return ramp_x * ramp_y
Ejemplo n.º 20
0
    def propagate(self, waves: Union['Waves', 'SMatrix', 'PartialSMatrix'],
                  dz: float):
        """
        Propagate wave function or scattering matrix.

        :param waves: Wave function or scattering matrix to propagate. Its array is replaced by
            the propagated result.
        :param dz: Propagation distance [Å].
        :return: The same ``waves`` object, for convenience.
        """
        xp = get_array_module(waves.array)
        fft2_convolve = get_device_function(xp, 'fft2_convolve')

        # Arguments follow the (gpts, sampling, wavelength, dz, xp) parameter order of
        # _evaluate_propagator_array.
        propagator_array = self._evaluate_propagator_array(
            waves.grid.gpts, waves.grid.sampling, waves.wavelength, dz, xp)

        # Store the result explicitly instead of relying on fft2_convolve mutating its input.
        # A backend that returns a new array would otherwise have its result silently discarded.
        waves._array = fft2_convolve(waves._array, propagator_array, overwrite_x=True)
        return waves
Ejemplo n.º 21
0
    def bandlimit(self, waves):
        """
        Band-limit wave functions by convolving them with this filter's mask.

        The convolution is requested with ``overwrite_x=True`` and its return value is not used,
        so the filtering relies on ``fft2_convolve`` writing into ``waves.array``.

        Parameters
        ----------
        waves
            Object exposing ``array``, ``gpts`` and ``sampling``.

        Returns
        -------
        The same ``waves`` object.
        """
        backend = get_array_module(waves.array)
        convolve = get_device_function(backend, 'fft2_convolve')
        mask = self.get_mask(waves.gpts, waves.sampling, backend)
        convolve(waves.array, mask, overwrite_x=True)
        return waves
Ejemplo n.º 22
0
def transmit(waves: Union['Waves', 'SMatrix', 'PartialSMatrix'],
             potential_slice: ProjectedPotential):
    """
    Transmit wave function or scattering matrix through one potential slice, in place.

    A complex slice array is used directly as the transmission function. A real one is first
    scaled by ``waves.accelerator.sigma`` and passed through ``complex_exponential``.

    :param waves: Wave function or scattering matrix to transmit; its array is multiplied in place.
    :param potential_slice: Projected potential (or transmission function) of the slice.
    """
    backend = get_array_module(waves.array)
    expi = get_device_function(backend, 'complex_exponential')

    # Prepend unit axes so the slice broadcasts against any leading batch axes of the waves.
    missing_axes = len(waves._array.shape) - len(potential_slice.array.shape)
    slice_array = potential_slice.array.reshape((1, ) * missing_axes +
                                                potential_slice.array.shape)

    if not np.iscomplexobj(slice_array):
        transmission = expi(copy_to_device(waves.accelerator.sigma * slice_array, backend))
    else:
        transmission = copy_to_device(slice_array, backend)

    waves._array *= transmission
Ejemplo n.º 23
0
    def as_transmission_function(self,
                                 energy: float,
                                 in_place: bool = True,
                                 max_batch: int = 1,
                                 antialias_filter: AntialiasFilter = None):
        """
        Calculate the transmission functions for a specific energy.

        Parameters
        ----------
        energy: float
            Electron energy [eV].
        in_place: bool
            Kept for backward compatibility; it has no effect. The transmission function is always
            written to a new array, and the potential array is never modified.
        max_batch: int
            Number of slices handed to the antialiasing filter at a time.
        antialias_filter: AntialiasFilter, optional
            Filter used to band-limit the transmission functions. Defaults to ``AntialiasFilter()``.

        Returns
        -------
        TransmissionFunction object
        """
        xp = get_array_module(self.array)
        complex_exponential = get_device_function(xp, 'complex_exponential')

        # ``energy2sigma(energy) * self._array`` already allocates a new array, so the potential
        # is never written to. Copying it first (as done for in_place=False) only doubled the
        # peak memory.
        array = complex_exponential(energy2sigma(energy) * self._array)

        t = TransmissionFunction(
            array,
            slice_thicknesses=self._slice_thicknesses.copy(),
            extent=self.extent,
            energy=energy)

        if antialias_filter is None:
            antialias_filter = AntialiasFilter()

        # Band-limit the transmission functions batch by batch.
        for _, _, potential_slices in t.generate_slices(max_batch=max_batch):
            antialias_filter.bandlimit(potential_slices)

        return t
Ejemplo n.º 24
0
def fourier_translation_operator(positions: np.ndarray,
                                 shape: tuple) -> np.ndarray:
    """
    Create one or more phase ramps that shift an array when multiplied onto its Fourier spectrum.

    Parameters
    ----------
    positions : array of xy-positions
        A single ``(x, y)`` pair, or an array of shape ``(n, 2)``. Spatial frequencies are
        evaluated with unit sampling, so positions are in grid pixels.
    shape : two int
        Grid shape passed to ``spatial_frequencies``.

    Returns
    -------
    array
        Complex phase ramp(s). A single position yields one ramp; otherwise the ramps are
        stacked along a new first axis.
    """
    batched = len(positions.shape) != 1
    if not batched:
        positions = positions[None]

    xp = get_array_module(positions)
    expi = get_device_function(xp, 'complex_exponential')

    freq_x, freq_y = spatial_frequencies(shape, (1., 1.))
    # Broadcast layout: (position, kx, ky).
    freq_x = xp.asarray(freq_x.reshape((1, -1, 1)))
    freq_y = xp.asarray(freq_y.reshape((1, 1, -1)))

    positions = xp.asarray(positions)
    x = positions[:, 0].reshape((-1, 1, 1))
    y = positions[:, 1].reshape((-1, 1, 1))

    ramps = expi(-2 * np.pi * freq_x * x)
    ramps = ramps * expi(-2 * np.pi * freq_y * y)

    if batched:
        return ramps
    return ramps[0]
Ejemplo n.º 25
0
    def generate_slices(self, first_slice=0, last_slice=None, max_batch=1):
        """
        Generate projected-potential slices from a DFT calculator (presumably GPAW - TODO confirm).

        Each slice combines two contributions:

        - the calculator's electrostatic potential, interpolated onto this object's grid and
          summed over the slice's vertical voxels;
        - per-atom radial PAW corrections, integrated between the slice boundaries and
          accumulated onto the grid.

        Parameters
        ----------
        first_slice : int
            Index of the first slice to generate.
        last_slice : int, optional
            Index one past the last slice; defaults to ``len(self)``.
        max_batch : int
            Accepted for interface compatibility; not used here.

        Yields
        ------
        tuple
            ``(i, i + 1, PotentialArray)`` for every slice index ``i``.
        """
        interpolate_radial_functions = get_device_function(np, 'interpolate_radial_functions')

        if last_slice is None:
            last_slice = len(self)

        if self._plane != 'xy':
            atoms = rotate_atoms_to_plane(self._calculator.atoms.copy(), self._plane)
        else:
            atoms = self._calculator.atoms.copy()

        old_cell = atoms.cell

        # Tags record each atom's index in the calculator, so per-atom PAW data can still be
        # looked up after the cell is cut, orthogonalized or padded.
        atoms.set_tags(range(len(atoms)))

        if self._orthogonal_cell is None:
            atoms = orthogonalize_cell(atoms)
        else:
            scaled = atoms.cell.scaled_positions(np.diag(self._orthogonal_cell))
            atoms = cut(atoms, a=scaled[0], b=scaled[1], c=scaled[2])

        valence = self._calculator.get_electrostatic_potential()
        new_gpts = self.gpts + (sum(self._slice_vertical_voxels),)

        axes = plane_to_axes(self._plane)
        if self._plane != 'xy':
            array = np.moveaxis(valence, axes[:2], (0, 1))
        else:
            array = valence

        from scipy.interpolate import RegularGridInterpolator

        origin = (0., 0., 0.)

        # Append a copy of the first plane along each axis so the interpolator covers the
        # closed interval [0, 1] in fractional coordinates.
        padded_array = np.zeros((array.shape[0] + 1, array.shape[1] + 1, array.shape[2] + 1))
        padded_array[:-1, :-1, :-1] = array
        padded_array[-1] = padded_array[0]
        padded_array[:, -1] = padded_array[:, 0]
        padded_array[:, :, -1] = padded_array[:, :, 0]

        x = np.linspace(0, 1, padded_array.shape[0], endpoint=True)
        y = np.linspace(0, 1, padded_array.shape[1], endpoint=True)
        z = np.linspace(0, 1, padded_array.shape[2], endpoint=True)

        interpolator = RegularGridInterpolator((x, y, z), padded_array)

        new_cell = np.diag(atoms.cell)
        x = np.linspace(origin[0], origin[0] + new_cell[0], new_gpts[0], endpoint=False)
        y = np.linspace(origin[1], origin[1] + new_cell[1], new_gpts[1], endpoint=False)
        z = np.linspace(origin[2], origin[2] + new_cell[2], new_gpts[2], endpoint=False)

        # Cartesian row vectors times the inverse cell give fractional coordinates of the old cell.
        P = np.array(old_cell)
        P_inv = np.linalg.inv(P)

        tags = atoms.get_tags()

        cutoffs = {}
        for number in np.unique(atoms.numbers):
            indices = np.where(atoms.numbers == number)[0]
            # ``setups`` follows the calculator's atom order, so index it by the original atom
            # (the tag) rather than by the position in the cut/orthogonalized atoms.
            setup = self._calculator.density.setups[int(tags[indices[0]])]
            r = setup.xc_correction.rgd.r_g[1:] * units.Bohr
            cutoffs[number] = r[-1]

        if self._periodic_z:
            atoms = pad_atoms(atoms, margin=max(cutoffs.values()), directions='z', in_place=True)

        indices_by_number = {number: np.where(atoms.numbers == number)[0] for number in np.unique(atoms.numbers)}

        na = sum(self._slice_vertical_voxels[:first_slice])
        a = na * self._voxel_height
        for i in range(first_slice, last_slice):
            # Voxel range [na, nb) and z range [a, b) of slice i.
            nb = na + self._slice_vertical_voxels[i]
            b = a + self._slice_vertical_voxels[i] * self._voxel_height

            X, Y, Z = np.meshgrid(x, y, z[na:nb], indexing='ij')

            points = np.array([X.ravel(), Y.ravel(), Z.ravel()]).T

            scaled_points = np.dot(points, P_inv) % 1.0

            projected_valence = interpolator(scaled_points).reshape(self.gpts + (nb - na,)).sum(
                axis=-1) * self._voxel_height

            array = np.zeros((1,) + self.gpts, dtype=np.float32)
            for number, indices in indices_by_number.items():
                slice_atoms = atoms[indices]

                cutoff = cutoffs[number]

                slice_atoms = slice_atoms[(slice_atoms.positions[:, 2] > a - cutoff) *
                                          (slice_atoms.positions[:, 2] < b + cutoff)]

                slice_atoms = pad_atoms(slice_atoms, margin=cutoff, directions='xy', )

                # Check after filtering. Before filtering the group is never empty, so a check
                # there could not skip anything.
                if len(slice_atoms) == 0:
                    continue

                margin = int(np.ceil(cutoff / np.min(self.sampling)))
                rows, cols = _disc_meshgrid(margin)
                disc_indices = np.hstack((rows[:, None], cols[:, None]))

                R = np.geomspace(np.min(self.sampling) / 2, cutoff, margin * 10)

                vr = np.zeros((len(slice_atoms), len(R)), np.float32)
                dvdr = np.zeros((len(slice_atoms), len(R)), np.float32)
                # TODO : improve speed of this
                for j, atom in enumerate(slice_atoms):
                    r, v = get_paw_corrections(atom.tag, self._calculator, self._core_size)

                    f = interp1d(r * units.Bohr, v, fill_value=(v[0], 0), bounds_error=False, kind='linear')

                    integrator = PotentialIntegrator(f, R, self.get_slice_thickness(i), tolerance=1e-6)

                    vr[j], dvdr[j] = integrator.integrate(np.array([atom.z]), a, b)

                sampling = np.asarray(self.sampling, dtype=np.float32)
                run_length_enconding = np.zeros((2,), dtype=np.int32)
                run_length_enconding[1] = len(slice_atoms)

                interpolate_radial_functions(array,
                                             run_length_enconding,
                                             disc_indices,
                                             slice_atoms.positions,
                                             vr,
                                             R,
                                             dvdr,
                                             sampling)

            array = -(projected_valence + array / np.sqrt(4 * np.pi) * units.Ha)

            yield i, i + 1, PotentialArray(array, np.array([self.get_slice_thickness(i)]), extent=self.extent)

            a = b
            na = nb
Ejemplo n.º 26
0
 def evaluate_aberrations(self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray]) -> \
         Union[float, np.ndarray]:
     """
     Return the aberration phase factor ``complex_exponential(-chi(alpha, phi))``.

     :param alpha: Scattering angle(s).
     :param phi: Azimuthal angle(s).
     """
     expi = get_device_function(get_array_module(alpha), 'complex_exponential')
     chi = self.evaluate_chi(alpha, phi)
     return expi(-chi)
Ejemplo n.º 27
0
    def generate_slices(self, first_slice=0, last_slice=None, max_batch=1):
        """
        Generate projected-potential slices from a DFT calculator (presumably GPAW - TODO confirm).

        Each slice combines two contributions:

        - the calculator's electrostatic potential, summed over the slice's vertical voxels and
          mapped onto this object's grid with ``interpolate_rectangle``;
        - per-atom radial PAW corrections, integrated between the slice boundaries and
          accumulated onto the grid.

        Parameters
        ----------
        first_slice : int
            Index of the first slice to generate.
        last_slice : int, optional
            Index one past the last slice; defaults to ``len(self)``.
        max_batch : int
            Accepted for interface compatibility; not used here.

        Yields
        ------
        tuple
            ``(i, i + 1, PotentialArray)`` for every slice index ``i``.
        """
        interpolate_radial_functions = get_device_function(np, 'interpolate_radial_functions')

        if last_slice is None:
            last_slice = len(self)

        valence = self._calculator.get_electrostatic_potential()
        cell = self._calculator.atoms.cell[:2, :2]

        atoms = self._calculator.atoms.copy()
        # Tags record each atom's index in the calculator (used by get_paw_corrections and the
        # setups lookup below) so they survive orthogonalization.
        atoms.set_tags(range(len(atoms)))
        atoms = orthogonalize_cell(atoms)
        tags = atoms.get_tags()

        indices_by_number = {number: np.where(atoms.numbers == number)[0] for number in np.unique(atoms.numbers)}

        na = sum(self._slice_vertical_voxels[:first_slice])
        a = na * self._voxel_height
        for i in range(first_slice, last_slice):
            # Voxel range [na, nb) and z range [a, b) of slice i.
            nb = na + self._slice_vertical_voxels[i]
            b = a + self._slice_vertical_voxels[i] * self._voxel_height

            projected_valence = valence[..., na:nb].sum(axis=-1) * self._voxel_height
            projected_valence = interpolate_rectangle(projected_valence, cell, self.extent, self.gpts, self._origin)

            array = np.zeros((1,) + self.gpts, dtype=np.float32)
            for number, indices in indices_by_number.items():
                slice_atoms = atoms[indices]

                # ``setups`` follows the calculator's atom order, so index it by the original atom
                # (the tag) rather than by the position in the orthogonalized atoms.
                setup = self._calculator.density.setups[int(tags[indices[0]])]
                r = setup.xc_correction.rgd.r_g[1:] * units.Bohr
                cutoff = r[-1]

                slice_atoms = slice_atoms[(slice_atoms.positions[:, 2] > a - cutoff) *
                                          (slice_atoms.positions[:, 2] < b + cutoff)]

                slice_atoms = pad_atoms(slice_atoms, cutoff)

                # Check after filtering. Before filtering the group is never empty, so a check
                # there could not skip anything.
                if len(slice_atoms) == 0:
                    continue

                margin = int(np.ceil(cutoff / np.min(self.sampling)))
                rows, cols = _disc_meshgrid(margin)
                disc_indices = np.hstack((rows[:, None], cols[:, None]))

                R = np.geomspace(np.min(self.sampling) / 2, cutoff, margin * 10)

                vr = np.zeros((len(slice_atoms), len(R)), np.float32)
                dvdr = np.zeros((len(slice_atoms), len(R)), np.float32)
                for j, atom in enumerate(slice_atoms):
                    r, v = get_paw_corrections(atom.tag, self._calculator, self._core_size)

                    f = interp1d(r * units.Bohr, v, fill_value=(v[0], 0), bounds_error=False, kind='linear')

                    integrator = PotentialIntegrator(f, R, self.get_slice_thickness(i), tolerance=1e-6)

                    vr[j], dvdr[j] = integrator.integrate(np.array([atom.z]), a, b)

                sampling = np.asarray(self.sampling, dtype=np.float32)
                run_length_enconding = np.zeros((2,), dtype=np.int32)
                run_length_enconding[1] = len(slice_atoms)

                interpolate_radial_functions(array,
                                             run_length_enconding,
                                             disc_indices,
                                             slice_atoms.positions,
                                             vr,
                                             R,
                                             dvdr,
                                             sampling)

            array = -(projected_valence + array / np.sqrt(4 * np.pi) * units.Ha)

            yield i, i + 1, PotentialArray(array, np.array([self.get_slice_thickness(i)]), extent=self.extent)

            a = b
            na = nb
Ejemplo n.º 28
0
 def measure(self):
     """
     Build the array and wrap its first item in a calibrated Measurement.

     Returns
     -------
     Measurement
         ``self.build()[0]`` fft-shifted over its last two axes, calibrated along 'x' and 'y'
         from ``self.gpts`` and ``self.sampling`` and named ``str(self)``.
     """
     # NOTE(review): assumes self.build() returns an array whose last two axes are spatial
     # (e.g. (batch, nx, ny)) -- TODO confirm.
     # Shift only the two spatial axes: shifting the leading (batch) axis too would make
     # ``[0]`` select a rotated batch entry rather than the first one whenever the batch
     # holds more than one item. For 2-D input or a batch of size 1 the result is unchanged.
     array = np.fft.fftshift(self.build(), axes=(-2, -1))[0]
     calibrations = calibrations_from_grid(self.gpts, self.sampling,
                                           ['x', 'y'])
     return Measurement(array, calibrations, name=str(self))
Ejemplo n.º 29
0
    def _generate_slices_infinite(self,
                                  first_slice=0,
                                  last_slice=None,
                                  max_batch=1) -> Generator:
        """
        Generate batches of potential slices from Fourier-space scattering factors.

        Atoms are superposed as deltas on the grid (one batch of slices at a time), convolved
        per species with the factor returned by ``kirkland_projected_fourier`` and summed.

        Parameters
        ----------
        first_slice : int
            Index of the first slice to generate.
        last_slice : int, optional
            Index one past the last slice to generate. Defaults to ``self.num_slices``.
        max_batch : int
            Maximum number of slices per yielded batch.

        Yields
        ------
        start : int
        end : int
            Slice index range ``[start, end)`` covered by the batch.
        PotentialArray
            The real part of the batch's potential with the matching slice thicknesses.
        """
        if last_slice is None:
            last_slice = self.num_slices

        xp = get_array_module_from_device(self._device)
        fft2_convolve = get_device_function(xp, 'fft2_convolve')

        atoms = self.atoms.copy()
        atoms.wrap()
        positions = atoms.get_positions().astype(np.float32)
        numbers = atoms.get_atomic_numbers()
        unique = np.unique(numbers)

        # Sort by height so every batch of slices maps to a contiguous range of atoms,
        # which is then located with searchsorted below.
        order = np.argsort(positions[:, 2])
        positions = positions[order]
        numbers = numbers[order]

        kx = xp.fft.fftfreq(self.gpts[0], self.sampling[0])
        ky = xp.fft.fftfreq(self.gpts[1], self.sampling[1])
        kx, ky = xp.meshgrid(kx, ky, indexing='ij')
        k = xp.sqrt(kx ** 2 + ky ** 2)

        # Divisor applied to the scattering factors (presumably compensating for placing atoms
        # as deltas on the pixel grid). The second term must use ky; using kx for both axes
        # made the correction wrong for every non-axis-aligned spatial frequency.
        sinc = xp.sinc(
            xp.sqrt((kx * self.sampling[0]) ** 2 + (ky * self.sampling[1]) ** 2))

        scattering_factors = {}
        for atomic_number in unique:
            f = kirkland_projected_fourier(k, self.parameters[atomic_number])
            scattering_factors[atomic_number] = (
                f /
                (sinc * self.sampling[0] * self.sampling[1] * kappa)).astype(
                    xp.complex64)

        # Builtin ``int`` replaces the ``np.int`` alias, removed in NumPy 1.24.
        slice_idx = np.floor(positions[:, 2] / atoms.cell[2, 2] *
                             self.num_slices).astype(int)

        # Allocate the work buffers once, sized by the first batch, and reuse them.
        start, end = next(
            generate_batches(last_slice - first_slice,
                             max_batch=max_batch,
                             start=first_slice))

        array = xp.zeros((end - start, ) + self.gpts, dtype=xp.complex64)
        temp = xp.zeros((end - start, ) + self.gpts, dtype=xp.complex64)

        for start, end in generate_batches(last_slice - first_slice,
                                           max_batch=max_batch,
                                           start=first_slice):
            array[:] = 0.
            start_idx = np.searchsorted(slice_idx, start)
            end_idx = np.searchsorted(slice_idx, end)

            if start_idx != end_idx:
                batch_positions = positions[start_idx:end_idx]
                batch_numbers = numbers[start_idx:end_idx]
                # Slice indices relative to the first slice of this batch.
                batch_slice_idx = slice_idx[start_idx:end_idx] - start

                for number in unique:
                    if len(unique) > 1:
                        is_species = batch_numbers == number
                        chunk_positions = batch_positions[is_species]
                        chunk_slice_idx = batch_slice_idx[is_species]
                    else:
                        chunk_positions = batch_positions
                        chunk_slice_idx = batch_slice_idx

                    if len(chunk_positions) == 0:
                        # No atoms of this species in the batch: its contribution is zero.
                        continue

                    temp[:] = 0.
                    # In-plane positions in pixel units.
                    chunk_positions = xp.asarray(chunk_positions[:, :2] /
                                                 self.sampling)

                    superpose_deltas(chunk_positions, chunk_slice_idx, temp)
                    fft2_convolve(temp, scattering_factors[number])

                    array += temp

            slice_thicknesses = [
                self.get_slice_thickness(i) for i in range(start, end)
            ]
            # The final batch may be shorter than the buffers; trim to its length.
            yield start, end, PotentialArray(array.real[:end - start],
                                             slice_thicknesses,
                                             extent=self.extent)
Ejemplo n.º 30
0
    def _generate_slices_finite(self,
                                first_slice=0,
                                last_slice=None,
                                max_batch=1) -> Generator:
        """
        Generate batches of potential slices by integrating radial potentials in real space.

        For each species, atoms within ``integrator.cutoff`` of each slice are collected. Their
        radial functions are integrated over the slice's [a, b] range and interpolated onto
        the grid.

        Parameters
        ----------
        first_slice : int
            Index of the first slice to generate.
        last_slice : int, optional
            Index one past the last slice to generate. Defaults to ``self.num_slices``.
        max_batch : int
            Maximum number of slices per yielded batch.

        Yields
        ------
        start : int
        end : int
            Slice index range ``[start, end)`` covered by the batch.
        PotentialArray
            The batch's potential divided by ``kappa``, with the matching slice thicknesses.
        """
        if last_slice is None:
            last_slice = self.num_slices

        xp = get_array_module_from_device(self._device)

        interpolate_radial_functions = get_device_function(
            xp, 'interpolate_radial_functions')

        atoms = self.atoms.copy()
        atoms.wrap()
        indices_by_number = {
            number: np.where(atoms.numbers == number)[0]
            for number in np.unique(atoms.numbers)
        }

        # Allocate the output buffer once, sized by the first batch, and reuse it.
        start, end = next(
            generate_batches(last_slice - first_slice,
                             max_batch=max_batch,
                             start=first_slice))
        array = xp.zeros((end - start, ) + self.gpts, dtype=xp.float32)

        slice_edges = np.linspace(0, self.atoms.cell[2, 2],
                                  self.num_slices + 1)

        for start, end in generate_batches(last_slice - first_slice,
                                           max_batch=max_batch,
                                           start=first_slice):
            array[:] = 0.

            for number, indices in indices_by_number.items():
                species_atoms = atoms[indices]
                integrator = self.get_integrator(number)
                cutoff = integrator.cutoff

                # Keep only atoms whose cutoff sphere can reach this batch of slices.
                batch_a = slice_edges[start]
                batch_b = slice_edges[end]
                chunk_atoms = species_atoms[
                    (species_atoms.positions[:, 2] > batch_a - cutoff) *
                    (species_atoms.positions[:, 2] < batch_b + cutoff)]
                chunk_atoms = pad_atoms(chunk_atoms, cutoff)

                if len(chunk_atoms) == 0:
                    continue

                chunk_positions = chunk_atoms.positions
                disc_indices = xp.asarray(
                    self._get_radial_interpolation_points(number))

                # Collect per-slice pieces in lists and concatenate once; growing the arrays
                # inside the loop was quadratic in the number of slices.
                slice_position_parts = []
                lower_edge_parts = []
                upper_edge_parts = []
                # run_length_encoding[i]:run_length_encoding[i + 1] indexes the atoms
                # contributing to the i-th slice of the batch.
                run_length_encoding = np.zeros((end - start + 1, ),
                                               dtype=xp.int32)

                for i, j in enumerate(range(start, end)):
                    a = slice_edges[j]
                    b = slice_edges[j + 1]
                    slice_positions = chunk_positions[
                        (chunk_positions[:, 2] > a - cutoff) *
                        (chunk_positions[:, 2] < b + cutoff)]
                    count = len(slice_positions)

                    slice_position_parts.append(slice_positions)
                    lower_edge_parts.append(np.full(count, a))
                    upper_edge_parts.append(np.full(count, b))

                    run_length_encoding[i + 1] = run_length_encoding[i] + count

                positions = np.concatenate(slice_position_parts)
                A = np.concatenate(lower_edge_parts)
                B = np.concatenate(upper_edge_parts)

                vr, dvdr = integrator.integrate(positions[:, 2], A, B, xp=xp)

                vr = xp.asarray(vr, dtype=xp.float32)
                dvdr = xp.asarray(dvdr, dtype=xp.float32)
                r = xp.asarray(integrator.r, dtype=xp.float32)
                sampling = xp.asarray(self.sampling, dtype=xp.float32)

                # NOTE(review): positions and run_length_encoding are NumPy arrays even when
                # xp is a GPU module -- confirm interpolate_radial_functions accepts them.
                interpolate_radial_functions(array, run_length_encoding,
                                             disc_indices, positions, vr, r,
                                             dvdr, sampling)

            slice_thicknesses = [
                self.get_slice_thickness(i) for i in range(start, end)
            ]

            # The final batch may be shorter than the buffer; trim to its length.
            yield start, end, PotentialArray(array[:end - start] / kappa,
                                             slice_thicknesses,
                                             extent=self.extent)