Example #1
0
def test_probe_waves_raises():
    """Check the errors raised by ``Probe`` for bad or incomplete setup.

    Covers, in order: an unrecognized constructor keyword, building with no
    grid defined, building with a grid but no energy, and finally a build
    that is expected to succeed once everything is set.
    """
    # An unknown keyword argument is rejected at construction time.
    with pytest.raises(ValueError) as unknown_kwarg:
        Probe(not_a_parameter=10)
    assert 'not_a_parameter not a recognized parameter' == str(
        unknown_kwarg.value)

    probe = Probe()

    # Nothing has been configured yet, so the grid check fails first.
    with pytest.raises(RuntimeError) as missing_grid:
        probe.build()
    assert 'Grid extent is not defined' == str(missing_grid.value)

    # With the grid in place, the next missing piece reported is the energy.
    probe.extent = 10
    probe.gpts = 100
    with pytest.raises(RuntimeError) as missing_energy:
        probe.build()
    assert 'Energy is not defined' == str(missing_energy.value)

    # Fully specified: building must no longer raise.
    probe.energy = 60e3
    probe.build()
Example #2
0
def epie(
    measurement: Measurement,
    probe_guess: Probe,
    maxiter: int = 5,
    alpha: float = 1.,
    beta: float = 1.,
    fix_probe: bool = False,
    fix_com: bool = False,
    return_iterations: bool = False,
    max_angle=None,
    seed=None,
    device='cpu',
):
    """
    Reconstruct the phase of a 4D-STEM measurement using the extended Ptychographical Iterative Engine.

    See https://doi.org/10.1016/j.ultramic.2009.05.012

    Note that ``probe_guess`` is modified in place: its ``extent``, ``gpts`` and ``_device`` attributes are
    overwritten so that the probe grid matches the (optionally zero-padded) diffraction patterns.

    Parameters
    ----------
    measurement : Measurement object
        4D-STEM measurement. The first two axes are flattened into a single list of diffraction patterns; the
        remaining axes are used as the diffraction pattern axes.
    probe_guess : Probe object
        The initial guess for the probe. Modified in place (see above).
    maxiter : int
        Run the algorithm for this many iterations.
    alpha : float
        Controls the size of the iterative updates for the object. See reference.
    beta : float
        Controls the size of the iterative updates for the probe. See reference.
    fix_probe : bool
        If True, the probe will not be updated by the algorithm. Default is False.
    fix_com : bool
        If True, the center of mass of the probe will be centered. Default is False.
    return_iterations : bool
        If True, return the reconstruction after every iteration. Default is False.
    max_angle : float, optional
        The maximum reconstructed scattering angle. If this is larger than the input data, the data will be
        zero-padded. A falsy value (None or 0) disables padding.
        NOTE(review): a value smaller than the angular range of the data yields a negative pad width, which
        numpy.pad rejects - confirm whether cropping is the intended behaviour in that case.
    seed : int, optional
        Seed the random number generator.
    device : str
        Set the calculation device. It is assigned to the probe before the probe is built; it is not passed
        on to the iteration routine by this function.

    Returns
    -------
    tuple
        A 3-tuple ``(object, probe, extra)``. If ``return_iterations`` is False, ``object`` and ``probe`` are
        single Measurement objects; if it is True, they are lists of Measurement objects, one per entry
        returned by the iteration routine. The probe arrays are passed through ``np.fft.fftshift`` before
        being wrapped. ``extra`` is the third element returned by the iteration routine, passed through
        unchanged (presumably an error metric - confirm against ``_run_epie``).
    """

    # Collapse the two leading (scan) axes so every diffraction pattern sits along axis 0.
    data = measurement.array
    diffraction_patterns = data.reshape((-1, ) + data.shape[2:])

    if max_angle:
        # Symmetric zero-padding of the two detector axes; the scan axis is left untouched.
        pad_width = [(0, 0)]
        for axis in (-2, -1):
            n = diffraction_patterns.shape[axis]
            half_target = int(
                (max_angle / abs(measurement.calibrations[axis].offset) * n)
                // 2)
            pad = half_target - n // 2
            pad_width.append((pad, pad))
        diffraction_patterns = np.pad(diffraction_patterns, pad_width)

    detector_shape = diffraction_patterns.shape[-2:]

    # Real-space extent implied by the detector sampling.
    # Assumes the detector sampling is in mrad (hence the factor 1e3) - TODO confirm.
    extent = (
        probe_guess.wavelength * 1e3 / measurement.calibrations[2].sampling,
        probe_guess.wavelength * 1e3 / measurement.calibrations[3].sampling,
    )
    sampling = (extent[0] / detector_shape[0], extent[1] / detector_shape[1])

    # Scan positions expressed in units of the reconstruction grid sampling.
    x = measurement.calibrations[0].coordinates(
        measurement.shape[0]) / sampling[0]
    y = measurement.calibrations[1].coordinates(
        measurement.shape[1]) / sampling[1]
    x, y = np.meshgrid(x, y, indexing='ij')
    positions = np.array([x.ravel(), y.ravel()]).T

    # Make the caller's probe match the reconstruction grid (in-place mutation).
    probe_guess.extent = extent
    probe_guess.gpts = detector_shape

    calibrations = calibrations_from_grid(probe_guess.gpts,
                                          probe_guess.sampling,
                                          names=['x', 'y'],
                                          units='Å')

    probe_guess._device = device
    # Build a single probe at the origin and keep only its raw array.
    probe_array = probe_guess.build(np.array([0, 0])).array[0]

    result = _run_epie(detector_shape,
                       probe_array,
                       diffraction_patterns,
                       positions,
                       maxiter=maxiter,
                       alpha=alpha,
                       beta=beta,
                       return_iterations=return_iterations,
                       fix_probe=fix_probe,
                       fix_com=fix_com,
                       seed=seed)

    if return_iterations:
        object_iterations = [
            Measurement(obj, calibrations=calibrations) for obj in result[0]
        ]
        probe_iterations = [
            Measurement(np.fft.fftshift(probe), calibrations=calibrations)
            for probe in result[1]
        ]
        return object_iterations, probe_iterations, result[2]

    reconstructed_object = Measurement(result[0], calibrations=calibrations)
    reconstructed_probe = Measurement(np.fft.fftshift(result[1]),
                                      calibrations=calibrations)
    return reconstructed_object, reconstructed_probe, result[2]