def detect(self, waves) -> np.ndarray:
    """
    Integrate the intensity of the wave functions over the detector range.

    The wave functions are transformed with the device `fft2`, converted to intensities with
    `abs2`, and the result is handed to `self._integrate_array` together with the angular
    sampling and the smallest cutoff scattering angle of the waves.

    Parameters
    ----------
    waves : Waves object
        The batch of wave functions to detect.

    Returns
    -------
    1d array
        Detected values as a 1D array. The array has the same length as the batch size of the
        wave functions.
    """
    xp = get_array_module(waves.array)
    fft2 = get_device_function(xp, 'fft2')
    abs2 = get_device_function(xp, 'abs2')

    # overwrite_x=False keeps the caller's wave array intact.
    diffraction_intensity = abs2(fft2(waves.array, overwrite_x=False))
    # Only angles below the smallest cutoff of the batch are valid for every wave.
    cutoff_angle = min(waves.cutoff_scattering_angles)

    return self._integrate_array(diffraction_intensity, waves.angular_sampling, cutoff_angle)
def detect(self, waves) -> np.ndarray:
    """
    Integrate the intensity of the wave functions over the radial and angular detector regions.

    Parameters
    ----------
    waves: Waves object
        The batch of wave functions to detect.

    Returns
    -------
    3d array
        Detected values. The first dimension indexes the batch size, the second and third
        indexes the radial and angular bins, respectively. Each wave's values are divided by
        the total diffraction intensity of that wave.
    """
    xp = get_array_module(waves.array)
    fft2 = get_device_function(xp, 'fft2')
    abs2 = get_device_function(xp, 'abs2')
    segment_sum = get_device_function(xp, 'sum_run_length_encoded')

    diffraction_intensity = abs2(fft2(waves.array, overwrite_x=False))

    regions = self._get_regions(waves.gpts, waves.angular_sampling,
                                min(waves.cutoff_scattering_angles), xp)

    # Total intensity per wave, used for normalization at the end.
    norm = xp.sum(diffraction_intensity, axis=(-2, -1))

    # Start/stop offsets of each region within the concatenated pixel index list.
    region_sizes = xp.array([len(region) for region in regions])
    offsets = xp.concatenate((xp.array([0]), xp.cumsum(region_sizes)))

    # Flatten the two pixel axes and gather all region pixels in region order.
    flattened = diffraction_intensity.reshape((diffraction_intensity.shape[0], -1))
    gathered = flattened[:, xp.concatenate(regions)]

    binned = xp.zeros((len(gathered), len(offsets) - 1), dtype=xp.float32)
    # Fills `binned` in place with the sum over each offset-delimited run.
    segment_sum(gathered, binned, offsets)

    binned = binned.reshape((-1, self.nbins_radial, self.nbins_angular))
    return binned / norm[:, None, None]
def diffraction_pattern(self) -> Measurement:
    """
    Calculate the intensity of the wave functions at the diffraction plane.

    The wave array is transformed with the device `fft2` and shifted so that the zero frequency
    sits at the center of the last two axes. It is then cropped with `crop_to_center`, converted
    to intensities with `abs2`, and moved to host memory.

    Returns
    -------
    Measurement
        The diffraction pattern(s). The last two axes are calibrated as 'alpha_x' / 'alpha_y'
        in mrad; any leading (batch) axes carry no calibration.
    """
    calibrations = calibrations_from_grid(self.grid.antialiased_gpts,
                                          self.grid.antialiased_sampling,
                                          names=['alpha_x', 'alpha_y'],
                                          units='mrad',
                                          scale_factor=self.wavelength * 1000,
                                          fourier_space=True)

    # Leading axes (e.g. a batch of waves) are left uncalibrated.
    calibrations = (None, ) * (len(self.array.shape) - 2) + calibrations

    xp = get_array_module(self.array)
    abs2 = get_device_function(xp, 'abs2')
    fft2 = get_device_function(xp, 'fft2')

    # Shift only the two frequency axes: fftshift without `axes` would also roll the leading
    # batch axes and permute the order of the wave functions.
    spectrum = xp.fft.fftshift(fft2(self.array, overwrite_x=False), axes=(-2, -1))
    pattern = asnumpy(abs2(crop_to_center(spectrum)))

    return Measurement(pattern, calibrations)
def fourier_translation_operator(positions: np.ndarray, shape: tuple) -> np.ndarray:
    """
    Create an array representing one or more phase ramp(s) for shifting another array.

    Parameters
    ----------
    positions : array of xy-positions
        Either a single position of shape (2,) or a batch of positions of shape (n, 2).
        The frequencies are built with unit sampling, so positions are presumably in
        pixels — confirm against callers.
    shape : two int
        Shape of the (last two axes of the) array to be shifted.

    Returns
    -------
    array
        Complex phase ramp(s): one ramp of the given shape for a single position, or a stack
        with a leading axis of length n for a batch of positions.
    """
    is_single = len(positions.shape) == 1
    if is_single:
        positions = positions[None]

    xp = get_array_module(positions)
    complex_exponential = get_device_function(xp, 'complex_exponential')

    freq_x, freq_y = spatial_frequencies(shape, (1., 1.))
    # Lay the frequencies out along axis 1 (x) and axis 2 (y); axis 0 is the position batch.
    freq_x = xp.asarray(freq_x.reshape((1, -1, 1)))
    freq_y = xp.asarray(freq_y.reshape((1, 1, -1)))

    positions = xp.asarray(positions)
    shift_x = positions[:, 0].reshape((-1, 1, 1))
    shift_y = positions[:, 1].reshape((-1, 1, 1))

    ramps = (complex_exponential(-2 * np.pi * freq_x * shift_x) *
             complex_exponential(-2 * np.pi * freq_y * shift_y))

    return ramps[0] if is_single else ramps
def evaluate_spatial_envelope(self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray]) -> \
        Union[float, np.ndarray]:
    """
    Evaluate the spatial-coherence (angular spread) envelope.

    The envelope is ``exp(-sign(s) * (s / 2 / 1000)**2 * (dchi_dk**2 + dchi_dphi**2))`` with
    ``s = self.angular_spread``. ``dchi_dk`` and ``dchi_dphi`` appear to be, respectively, the
    derivative with respect to ``alpha`` and the ``phi``-derivative divided by ``alpha`` of the
    polar aberration series built from ``self.parameters`` (``Cnm`` / ``phinm``, up to 5th
    order), including the ``2 * pi / self.wavelength`` prefactor.

    Parameters
    ----------
    alpha : float or array
        Angle to the optical axis. Presumably in radians, since the spread is divided by
        1000 (mrad -> rad?) — confirm.
    phi : float or array
        Azimuthal angle around the optical axis; used inside cos/sin, so radians.

    Returns
    -------
    float or array
        The envelope, broadcast from ``alpha`` and ``phi``.
    """
    xp = get_array_module(alpha)
    p = self.parameters

    # Radial derivative: each term alpha**n / n * C * cos(...) of the series becomes
    # alpha**(n - 1) * C * cos(...).
    dchi_dk = 2 * xp.pi / self.wavelength * (
        (p['C12'] * xp.cos(2. * (phi - p['phi12'])) + p['C10']) * alpha +
        (p['C23'] * xp.cos(3. * (phi - p['phi23'])) +
         p['C21'] * xp.cos(1. * (phi - p['phi21']))) * alpha**2 +
        (p['C34'] * xp.cos(4. * (phi - p['phi34'])) +
         p['C32'] * xp.cos(2. * (phi - p['phi32'])) + p['C30']) * alpha**3 +
        (p['C45'] * xp.cos(5. * (phi - p['phi45'])) +
         p['C43'] * xp.cos(3. * (phi - p['phi43'])) +
         p['C41'] * xp.cos(1. * (phi - p['phi41']))) * alpha**4 +
        (p['C56'] * xp.cos(6. * (phi - p['phi56'])) +
         p['C54'] * xp.cos(4. * (phi - p['phi54'])) +
         p['C52'] * xp.cos(2. * (phi - p['phi52'])) + p['C50']) * alpha**5)

    # Azimuthal derivative divided by alpha: the m-fold term gains a factor -m and sin(...).
    # Rotationally symmetric coefficients (C10, C30, C50) do not contribute.
    dchi_dphi = -2 * xp.pi / self.wavelength * (
        1 / 2. * (2. * p['C12'] * xp.sin(2. * (phi - p['phi12']))) * alpha +
        1 / 3. * (3. * p['C23'] * xp.sin(3. * (phi - p['phi23'])) +
                  1. * p['C21'] * xp.sin(1. * (phi - p['phi21']))) * alpha**2 +
        1 / 4. * (4. * p['C34'] * xp.sin(4. * (phi - p['phi34'])) +
                  2. * p['C32'] * xp.sin(2. * (phi - p['phi32']))) * alpha**3 +
        1 / 5. * (5. * p['C45'] * xp.sin(5. * (phi - p['phi45'])) +
                  3. * p['C43'] * xp.sin(3. * (phi - p['phi43'])) +
                  1. * p['C41'] * xp.sin(1. * (phi - p['phi41']))) * alpha**4 +
        1 / 6. * (6. * p['C56'] * xp.sin(6. * (phi - p['phi56'])) +
                  4. * p['C54'] * xp.sin(4. * (phi - p['phi54'])) +
                  2. * p['C52'] * xp.sin(2. * (phi - p['phi52']))) * alpha**5)

    # NOTE(review): sign(angular_spread) makes a negative spread produce a growing exponential
    # (envelope > 1); confirm this is intended.
    return xp.exp(-xp.sign(self.angular_spread) *
                  (self.angular_spread / 2 / 1000)**2 *
                  (dchi_dk**2 + dchi_dphi**2))
def evaluate_gaussian_envelope(self, alpha: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
    """
    Evaluate the Gaussian envelope ``exp(-0.5 * gaussian_spread**2 * alpha**2 / wavelength**2)``.

    Parameters
    ----------
    alpha : float or array
        Angle(s) to the optical axis at which to evaluate the envelope.

    Returns
    -------
    float or array
        The envelope value(s), with the shape of ``alpha``.
    """
    xp = get_array_module(alpha)
    exponent = -.5 * self.gaussian_spread**2 * alpha**2 / self.wavelength**2
    return xp.exp(exponent)
def evaluate_temporal_envelope(self, alpha: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
    """
    Evaluate the temporal-coherence envelope ``exp(-(0.5 * pi / wavelength * focal_spread * alpha**2)**2)``.

    Parameters
    ----------
    alpha : float or array
        Angle(s) to the optical axis at which to evaluate the envelope.

    Returns
    -------
    float or array
        The envelope value(s), cast to float32.
    """
    xp = get_array_module(alpha)
    argument = .5 * xp.pi / self.wavelength * self.focal_spread * alpha**2
    envelope = xp.exp(-argument**2)
    return envelope.astype(xp.float32)
def _run_epie(object,
              probe: np.ndarray,
              diffraction_patterns: np.ndarray,
              positions: np.ndarray,
              maxiter: int,
              alpha: float = 1.,
              beta: float = 1.,
              fix_probe: bool = False,
              fix_com: bool = False,
              return_iterations: bool = False,
              seed=None):
    """
    Reconstruct an object (and optionally the probe) with the ePIE update scheme.

    Parameters
    ----------
    object : array or two ints
        Initial object guess. An array of shape (2,) is interpreted as a grid size (nx, ny),
        and an all-ones complex64 object of that size is used.
    probe : array
        Initial probe guess. Must have the same shape as the object and as each diffraction
        pattern.
    diffraction_patterns : 3d array
        Measured diffraction intensities, one per scan position. They are square-rooted and
        ifftshift-ed over the last two axes here, so presumably they are given with the zero
        frequency at the center — confirm against callers.
    positions : array of xy-positions
        Scan positions, one per diffraction pattern, in the units used by `fft_shift`.
    maxiter : int
        Number of passes over all positions.
    alpha : float
        Step size of the object update.
    beta : float
        Step size of the probe update.
    fix_probe : bool
        If True, the probe is not updated.
    fix_com : bool
        If True, the probe is re-centered on its center of mass after each pass.
    return_iterations : bool
        If True, return lists with the object, probe and SSE after every pass.
    seed : int, optional
        Seed for the random order in which positions are visited.

    Returns
    -------
    object, probe, SSE
        The final reconstruction, or lists of per-pass values if `return_iterations` is True.
        SSE is currently not computed and is always 0.

    Raises
    ------
    ValueError
        If the shapes of the inputs are inconsistent.
    """
    xp = get_array_module(probe)

    obj = xp.array(object)
    probe = xp.array(probe)

    if len(diffraction_patterns.shape) != 3:
        raise ValueError('diffraction_patterns must be a 3d array (position, x, y)')

    if len(diffraction_patterns) != len(positions):
        raise ValueError('the number of diffraction patterns must equal the number of positions')

    if obj.shape == (2, ):
        obj = xp.ones((int(obj[0]), int(obj[1])), dtype=xp.complex64)
    elif len(obj.shape) != 2:
        raise ValueError('object must be a 2d array or a grid size (nx, ny)')

    if probe.shape != diffraction_patterns.shape[1:]:
        raise ValueError('the probe shape must match the shape of the diffraction patterns')

    if probe.shape != obj.shape:
        raise ValueError('the probe shape must match the object shape')

    if return_iterations:
        object_iterations = []
        probe_iterations = []
        SSE_iterations = []

    # A private generator produces the same permutations as seeding the global numpy RNG,
    # without resetting the global random state for the rest of the program.
    rng = np.random.RandomState(seed) if seed is not None else np.random

    # Measured amplitudes with the zero frequency moved to the corner, matching unshifted fft2.
    diffraction_patterns = np.fft.ifftshift(np.sqrt(diffraction_patterns), axes=(-2, -1))

    SSE = 0.

    outer_pbar = ProgressBar(total=maxiter)
    inner_pbar = ProgressBar(total=len(positions))

    try:
        for _ in range(maxiter):
            indices = np.arange(len(positions))
            rng.shuffle(indices)

            # The object is kept shifted so that the current scan position is at the origin;
            # `old_position` is the shift currently applied to it.
            old_position = xp.array((0., 0.))
            inner_pbar.reset()
            SSE = 0.

            for j in indices:
                position = xp.array(positions[j])
                diffraction_pattern = xp.array(diffraction_patterns[j])

                illuminated_object = fft_shift(obj, old_position - position)

                g = illuminated_object * probe
                # Keep the phase of the exit wave's spectrum, replace its modulus by the
                # measured amplitude.
                gprime = xp.fft.ifft2(diffraction_pattern * xp.exp(1j * xp.angle(xp.fft.fft2(g))))

                obj = illuminated_object + alpha * (gprime - g) * xp.conj(probe) / (
                    xp.max(xp.abs(probe))**2)

                old_position = position

                if not fix_probe:
                    # Uses the object from before this step's update.
                    probe = probe + beta * (gprime - g) * xp.conj(illuminated_object) / (
                        xp.max(xp.abs(illuminated_object))**2)

                inner_pbar.update(1)

            # Undo the last shift. `old_position` is still (0, 0) if there were no positions,
            # which avoids referencing an undefined loop variable.
            obj = fft_shift(obj, old_position)

            if fix_com:
                com = center_of_mass(xp.fft.fftshift(xp.abs(probe)**2))
                probe = xp.fft.ifftshift(fft_shift(probe, -xp.array(com)))

            if return_iterations:
                object_iterations.append(obj)
                probe_iterations.append(probe)
                SSE_iterations.append(SSE)

            outer_pbar.update(1)
    finally:
        # Close the progress bars even if an update step raises.
        inner_pbar.close()
        outer_pbar.close()

    if return_iterations:
        return object_iterations, probe_iterations, SSE_iterations
    else:
        return obj, probe, SSE
def transmit(self, waves):
    """
    Multiply the wave functions in place by this object's array and return them.

    The array is first copied to the same device (array module) as the waves, after checking
    via `self.accelerator.check_match` that the waves match this object.

    Parameters
    ----------
    waves : Waves object
        The wave functions to transmit; modified in place.

    Returns
    -------
    Waves object
        The same `waves` instance.
    """
    self.accelerator.check_match(waves)
    array_module = get_array_module(waves._array)
    transmission = copy_to_device(self.array, array_module)
    waves._array *= transmission
    return waves
def measure(self):
    """
    Build the array and return its first item as a `Measurement` with 'x' / 'y' calibrations.

    The result of `self.build()` is fftshift-ed over its last two axes before the first item
    along the leading axis is taken.

    Returns
    -------
    Measurement
        The shifted array, calibrated from `self.gpts` and `self.sampling`, named `str(self)`.
    """
    # Shift only the last two axes: shifting the leading axis too would move a different item
    # into index 0 whenever that axis has more than one entry.
    array = np.fft.fftshift(self.build(), axes=(-2, -1))[0]
    calibrations = calibrations_from_grid(self.gpts, self.sampling, ['x', 'y'])
    return Measurement(array, calibrations, name=str(self))
def polar_coordinates(x, y):
    """
    Calculate a polar grid for a given Cartesian grid.

    `x` is laid out along the first axis and `y` along the second, giving 2d outer-product grids.

    Returns
    -------
    alpha, phi
        The radius ``sqrt(x**2 + y**2)`` and the angle ``arctan2(x, y)`` (note the argument
        order: x first).
    """
    xp = get_array_module(x)
    x_column = x.reshape((-1, 1))
    y_row = y.reshape((1, -1))
    radius = xp.sqrt(x_column**2 + y_row**2)
    angle = xp.arctan2(x_column, y_row)
    return radius, angle
def evaluate_chi(self, alpha, phi) -> np.ndarray:
    """
    Calculates the polar expansion of the phase error up to 5th order.

    See Eq. 2.22 in ref [1]. The aberration coefficients are read from the mapping
    `self.parameters` (Cnn and phinn keys, see parameter `parameters` in class CTFBase) and the
    relativistic wavelength [Å] from `self.wavelength`.

    Parameters
    ----------
    alpha : numpy.ndarray
        Angle between the scattered electrons and the optical axis [rad].
    phi : numpy.ndarray
        Azimuthal angle around the optical axis of the scattered electrons [rad].

    Returns
    -------
    numpy.ndarray
        The phase error chi, with the shape of `alpha`. It is accumulated in a float32 array and
        finally scaled by ``2 * pi / wavelength``.

    References
    ----------
    .. [1] Kirkland, E. J. (2010). Advanced Computing in Electron Microscopy (2nd ed.). Springer.
    """
    xp = get_array_module(alpha)
    p = self.parameters

    alpha2 = alpha**2

    array = xp.zeros(alpha.shape, dtype=np.float32)

    # Each order is skipped when all of its coefficients and angles are zero; this only saves
    # work, as a term with zero coefficients vanishes anyway.
    if any(p[symbol] != 0. for symbol in ('C10', 'C12', 'phi12')):
        array += (1 / 2 * alpha2 *
                  (p['C10'] +
                   p['C12'] * xp.cos(2 * (phi - p['phi12']))))

    if any(p[symbol] != 0. for symbol in ('C21', 'phi21', 'C23', 'phi23')):
        array += (1 / 3 * alpha2 * alpha *
                  (p['C21'] * xp.cos(phi - p['phi21']) +
                   p['C23'] * xp.cos(3 * (phi - p['phi23']))))

    if any(p[symbol] != 0. for symbol in ('C30', 'C32', 'phi32', 'C34', 'phi34')):
        array += (1 / 4 * alpha2**2 *
                  (p['C30'] +
                   p['C32'] * xp.cos(2 * (phi - p['phi32'])) +
                   p['C34'] * xp.cos(4 * (phi - p['phi34']))))

    if any(p[symbol] != 0. for symbol in ('C41', 'phi41', 'C43', 'phi43', 'C45', 'phi45')):
        array += (1 / 5 * alpha2**2 * alpha *
                  (p['C41'] * xp.cos((phi - p['phi41'])) +
                   p['C43'] * xp.cos(3 * (phi - p['phi43'])) +
                   p['C45'] * xp.cos(5 * (phi - p['phi45']))))

    if any(p[symbol] != 0. for symbol in ('C50', 'C52', 'phi52', 'C54', 'phi54', 'C56', 'phi56')):
        array += (1 / 6 * alpha2**3 *
                  (p['C50'] +
                   p['C52'] * xp.cos(2 * (phi - p['phi52'])) +
                   p['C54'] * xp.cos(4 * (phi - p['phi54'])) +
                   p['C56'] * xp.cos(6 * (phi - p['phi56']))))

    array = 2 * xp.pi / self.wavelength * array
    return array
def polargrid(x, y):
    """
    Return the radius and angle of the 2d grid spanned by `x` (first axis) and `y` (second axis).

    The angle is ``arctan2(x, y)``, i.e. measured with x as the first argument.
    """
    xp = get_array_module(x)
    xx = x.reshape((-1, 1))
    yy = y.reshape((1, -1))
    return xp.sqrt(xx**2 + yy**2), xp.arctan2(xx, yy)
def collapse(self, positions: Sequence[float], max_batch_expansion: int = None) -> Waves:
    """
    Collapse the scattering matrix to probe wave functions centered on the provided positions.

    Parameters
    ----------
    positions : array of xy-positions
        A single position of shape (2,) or several positions of shape (n, 2).
    max_batch_expansion : int, optional
        The maximum number of plane waves the reduction is applied to simultaneously.

    Returns
    -------
    Waves
        Probe wave functions for the provided positions, one per position, on the
        interpolated grid.

    Raises
    ------
    RuntimeError
        If `positions` does not have shape (2,) or (n, 2).
    """
    xp = get_array_module(self.array)
    complex_exponential = get_device_function(xp, 'complex_exponential')
    scale_reduce = get_device_function(xp, 'scale_reduce')
    windowed_scale_reduce = get_device_function(xp, 'windowed_scale_reduce')

    positions = np.array(positions, dtype=xp.float32)

    if positions.shape == (2, ):
        positions = positions[None]
    elif (len(positions.shape) != 2) or (positions.shape[-1] != 2):
        raise RuntimeError(f'positions must have shape (2,) or (n, 2), got {positions.shape}')

    interpolated_grid = self.interpolated_grid
    W = np.floor_divide(interpolated_grid.gpts, 2)

    # Lower-left corner of each probe window, wrapped into the periodic grid.
    # `np.int` / `xp.int` were aliases of the builtin `int` and no longer exist in NumPy >= 1.24.
    corners = np.rint(positions / self.sampling - W).astype(int)
    corners = np.remainder(corners, np.asarray(self.gpts))
    corners = xp.asarray(corners)

    window = xp.zeros((
        len(positions),
        interpolated_grid.gpts[0],
        interpolated_grid.gpts[1],
    ), dtype=xp.complex64)

    positions = xp.asarray(positions)

    # Phase factors that translate each plane wave to each probe position.
    translation = (complex_exponential(2. * np.pi * self.k[0][None] * positions[:, 0, None]) *
                   complex_exponential(2. * np.pi * self.k[1][None] * positions[:, 1, None]))

    coefficients = translation * self._evaluate_ctf()

    for partial_s_matrix in self._generate_partial(max_batch_expansion, pbar=False):
        partial_coefficients = coefficients[:, partial_s_matrix.start:partial_s_matrix.stop]

        # Accumulates into `window` in place.
        if self.interpolation > 1:
            windowed_scale_reduce(window, partial_s_matrix.array, corners, partial_coefficients)
        else:
            scale_reduce(window, partial_s_matrix.array, partial_coefficients)

    return Waves(window, extent=interpolated_grid.extent, energy=self.energy)
def _evaluate_ctf(self):
    """
    Evaluate `self._ctf` at the spatial frequencies in `self.k`.

    The polar angle is ``|k| * wavelength`` and the azimuth is ``arctan2(k[0], k[1])``.
    """
    xp = get_array_module(self._array)
    k_x, k_y = self.k[0], self.k[1]
    polar_angle = xp.sqrt(k_x**2 + k_y**2) * self.wavelength
    azimuth = xp.arctan2(k_x, k_y)
    return self._ctf.evaluate(polar_angle, azimuth)
def evaluate_aberrations(self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray]) -> \
        Union[float, np.ndarray]:
    """
    Return the aberration phase factor ``complex_exponential(-chi(alpha, phi))``.

    `chi` is computed by `self.evaluate_chi`.
    """
    xp = get_array_module(alpha)
    complex_exponential = get_device_function(xp, 'complex_exponential')
    phase_error = self.evaluate_chi(alpha, phi)
    return complex_exponential(-phase_error)
def fft_shift(array, positions):
    """
    Shift `array` over its last two axes by multiplying its spectrum with a phase ramp.

    The ramp comes from `fourier_translation_operator(positions, array.shape[-2:])`. The result
    is the complex inverse transform, so it is complex even for a real input.
    """
    xp = get_array_module(array)
    spectrum = xp.fft.fft2(array)
    phase_ramp = fourier_translation_operator(positions, array.shape[-2:])
    return xp.fft.ifft2(spectrum * phase_ramp)