Esempio n. 1
0
    def setUp(self, ntheta=3, pw=15, nscan=27):
        """Create seeded random inputs and an entered Ptycho operator.

        The attributes stored on the instance pair an operand with the
        keyword arguments that accompany it: `m` (named `psi`) with
        `kwargs`, `m1` (named `probe`) with `kwargs1`, the scan-only
        `kwargs2`, and `d` (named `farplane`).
        """
        self.nscan = nscan
        self.ntheta = ntheta
        self.nprobe = 3
        self.probe_shape = (ntheta, nscan, 1, self.nprobe, pw, pw)
        self.detector_shape = (pw * 3, pw * 3)
        self.original_shape = (ntheta, 128, 128)
        self.scan_shape = (ntheta, nscan, 2)
        print(Ptycho)

        # The draw order below is fixed so that the seed reproduces the data.
        np.random.seed(0)
        host_scan = np.random.rand(*self.scan_shape).astype('float32')
        host_scan = host_scan * (127 - 16)
        host_probe = random_complex(*self.probe_shape)
        host_psi = random_complex(*self.original_shape)
        host_farplane = random_complex(*self.probe_shape[:-2],
                                       *self.detector_shape)

        operator_settings = {
            'nscan': self.scan_shape[-2],
            'probe_shape': self.probe_shape[-1],
            'detector_shape': self.detector_shape[-1],
            'nz': self.original_shape[-2],
            'n': self.original_shape[-1],
            'ntheta': self.ntheta,
        }
        self.operator = Ptycho(**operator_settings)
        # Entered manually; the matching exit is not part of this method.
        self.operator.__enter__()
        self.xp = self.operator.xp

        def as_complex(array):
            return self.xp.asarray(array, dtype='complex64')

        def as_float(array):
            return self.xp.asarray(array, dtype='float32')

        # Move everything to the array module used by the operator.
        dev_probe = self.xp.asarray(host_probe.astype('complex64'))
        dev_psi = self.xp.asarray(host_psi.astype('complex64'))
        dev_farplane = self.xp.asarray(host_farplane.astype('complex64'))
        dev_scan = self.xp.asarray(host_scan.astype('float32'))

        # Operand and arguments for the object (psi) direction.
        self.m = as_complex(dev_psi)
        self.m_name = 'psi'
        self.kwargs = {
            'scan': as_float(dev_scan),
            'probe': as_complex(dev_probe),
        }

        # Operand and arguments for the probe direction.
        self.m1 = as_complex(dev_probe)
        self.m1_name = 'probe'
        self.kwargs1 = {
            'scan': as_float(dev_scan),
            'psi': as_complex(dev_psi),
        }
        self.kwargs2 = {
            'scan': as_float(dev_scan),
        }

        # Data-space operand.
        self.d = as_complex(dev_farplane)
        self.d_name = 'farplane'
Esempio n. 2
0
def simulate(
        detector_shape,
        probe, scan,
        psi,
        **kwargs
):  # yapf: disable
    """Return real-valued detector counts of simulated ptychography data.

    The forward operator is applied once for each probe mode (one slice of
    `probe` along axis -3). For every mode the squared 2-norm over the
    positions that combine into one detector frame is added to the result.

    Parameters
    ----------
    detector_shape : int
        The pixel width of the detector; used for both detector axes.
    probe : array
        The complex illumination function. Its last axis sets the
        `probe_shape` of the operator and axis -3 indexes the modes.
    scan : array with 3 dimensions
        The scan positions. Axis 0 sets `ntheta` of the operator and
        axis -2 is the number of positions.
    psi : array with 3 dimensions
        The object. Its last two axes set `nz` and `n` of the operator.
    **kwargs
        Forwarded both to the `Ptycho` constructor and to every
        `operator.fwd` call.

    Returns
    -------
    data
        The accumulated intensity after `operator.asnumpy`.

    """
    assert scan.ndim == 3
    assert psi.ndim == 3
    check_allowed_positions(scan, psi, probe)
    with Ptycho(
            probe_shape=probe.shape[-1],
            detector_shape=int(detector_shape),
            nz=psi.shape[-2],
            n=psi.shape[-1],
            ntheta=scan.shape[0],
            **kwargs,
    ) as operator:
        # scan and psi are identical for every probe mode, so they are
        # converted once instead of once per loop iteration.
        scan_on_device = operator.asarray(scan, dtype='float32')
        psi_on_device = operator.asarray(psi, dtype='complex64')
        data = 0
        for mode in np.split(probe, probe.shape[-3], axis=-3):
            farplane = operator.fwd(
                probe=operator.asarray(mode, dtype='complex64'),
                scan=scan_on_device,
                psi=psi_on_device,
                **kwargs,
            )
            # Axis 2 of the reshaped farplane gathers everything that
            # contributes to one frame; its squared 2-norm is the intensity.
            data += np.square(
                np.linalg.norm(
                    farplane.reshape(operator.ntheta,
                                     scan.shape[-2] // operator.fly, -1,
                                     detector_shape, detector_shape),
                    ord=2,
                    axis=2,
                ))
        return operator.asnumpy(data)
Esempio n. 3
0
    def test_adjoint(self):
        """Verify that fwd, adj, and adj_probe satisfy the adjoint identity.

        The inner product of the forward result with a random farplane must
        match the inner product of each operand with its adjoint result.
        """
        # The draw order is fixed so that the seed reproduces the data.
        np.random.seed(0)
        host_scan = np.random.rand(*self.scan_shape).astype('float32')
        host_scan = host_scan * (127 - 16)
        host_probe = random_complex(*self.probe_shape)
        host_psi = random_complex(*self.original_shape)
        host_farplane = random_complex(*self.probe_shape[:-2],
                                       *self.detector_shape)

        operator_settings = {
            'nscan': self.scan_shape[-2],
            'probe_shape': self.probe_shape[-1],
            'detector_shape': self.detector_shape[-1],
            'nz': self.original_shape[-2],
            'n': self.original_shape[-1],
            'ntheta': self.ntheta,
            'fly': self.fly,
        }

        def parts(value):
            """Return the real and imaginary parts as Python scalars."""
            return value.real.item(), value.imag.item()

        with Ptycho(**operator_settings) as op:

            probe = op.asarray(host_probe.astype('complex64'))
            psi = op.asarray(host_psi.astype('complex64'))
            farplane = op.asarray(host_farplane.astype('complex64'))
            scan = op.asarray(host_scan.astype('float32'))

            forward = op.fwd(probe=probe, scan=scan, psi=psi)
            assert forward.shape == farplane.shape

            psi_adjoint = op.adj(farplane=farplane, probe=probe, scan=scan)
            assert psi.shape == psi_adjoint.shape

            probe_adjoint = op.adj_probe(farplane=farplane, scan=scan, psi=psi)
            assert probe.shape == probe_adjoint.shape

            a = inner_complex(forward, farplane)
            b = inner_complex(probe, probe_adjoint)
            c = inner_complex(psi, psi_adjoint)
            print()
            print('<FQP,     Ψ> = {:.6f}{:+.6f}j'.format(*parts(a)))
            print('<P  , Q*F*Ψ> = {:.6f}{:+.6f}j'.format(*parts(b)))
            print('<Q  , P*F*Ψ> = {:.6f}{:+.6f}j'.format(*parts(c)))
            # Both adjoints must reproduce the data-space inner product.
            for candidate in (b, c):
                op.xp.testing.assert_allclose(a.real, candidate.real,
                                              rtol=1e-5)
                op.xp.testing.assert_allclose(a.imag, candidate.imag,
                                              rtol=1e-5)
Esempio n. 4
0
def simulate(
        detector_shape,
        probe, scan,
        psi,
        fly=1,
        eigen_probe=None,
        eigen_weights=None,
        **kwargs
):  # yapf: disable
    """Return real-valued detector counts of simulated ptychography data.

    Parameters
    ----------
    detector_shape : int
        The pixel width of the detector.
    probe : (..., 1, 1, SHARED, WIDE, HIGH) complex64
        The shared complex illumination function amongst all positions.
    scan : (..., POSI, 2) float32
        Coordinates of the minimum corner of the probe grid for each
        measurement in the coordinate system of psi.
    psi : (..., WIDE, HIGH) complex64
        The complex wavefront modulation of the object.
    fly : int
        The number of scan positions which combine for one detector frame.
    eigen_probe : (..., 1, EIGEN, SHARED, WIDE, HIGH) complex64
        The eigen probes for all positions.
    eigen_weights : (..., POSI, EIGEN, SHARED) float32
        The relative intensity of the eigen probes at each position.
    **kwargs
        Forwarded to the `Ptycho` constructor.

    Returns
    -------
    data : (..., FRAME, WIDE, HIGH) float32
        The simulated intensity on the detector.

    """
    check_allowed_positions(scan, psi, probe)
    with Ptycho(
            probe_shape=probe.shape[-1],
            detector_shape=int(detector_shape),
            nz=psi.shape[-2],
            n=psi.shape[-1],
            ntheta=scan.shape[0],
            **kwargs,
    ) as operator:
        scan = operator.asarray(scan, dtype='float32')
        psi = operator.asarray(psi, dtype='complex64')
        probe = operator.asarray(probe, dtype='complex64')
        # Both optional eigen inputs are converted like the required ones so
        # that every array handed to _compute_intensity comes from the same
        # array module.
        if eigen_weights is not None:
            eigen_weights = operator.asarray(eigen_weights, dtype='float32')
        if eigen_probe is not None:
            eigen_probe = operator.asarray(eigen_probe, dtype='complex64')
        data = _compute_intensity(operator, psi, scan, probe, eigen_weights,
                                  eigen_probe, fly)
        return operator.asnumpy(data.real)
Esempio n. 5
0
def reconstruct(
        data,
        probe, scan,
        algorithm,
        psi=None, num_gpu=1, num_iter=1, rtol=-1,
        model='gaussian', use_mpi=False, cost=None, times=None,
        eigen_probe=None, eigen_weights=None,
        batch_size=None,
        **kwargs
):  # yapf: disable
    """Solve the ptychography problem using the given `algorithm`.

    Parameters
    ----------
    data : (..., FRAME, WIDE, HIGH) float32
        The intensity (square of the absolute value) of the propagated
        wavefront; i.e. what the detector records.
    eigen_probe : (..., 1, EIGEN, SHARED, WIDE, HIGH) complex64
        The eigen probes for all positions.
    eigen_weights : (..., POSI, EIGEN, SHARED) float32
        The relative intensity of the eigen probes at each position.
    psi : (..., WIDE, HIGH) complex64
        The wavefront modulation coefficients of the object. When None,
        `get_padded_object(scan, probe)` provides both psi and scan.
    probe : (..., 1, 1, SHARED, WIDE, HIGH) complex64
        The shared complex illumination function amongst all positions.
    scan : (..., POSI, 2) float32
        Coordinates of the minimum corner of the probe grid for each
        measurement in the coordinate system of psi. Coordinate order
        consistent with WIDE, HIGH order.
    algorithm : string
        The name of one of the algorithms from :py:mod:`.ptycho.solvers`.
    num_gpu
        Passed to `Comm` as its first argument.
    num_iter : int
        The maximum number of times the solver is called.
    rtol : float
        Terminate early if the relative decrease of the cost function is
        less than this amount.
    model : string
        Passed to the `Ptycho` operator as `model`.
    use_mpi : bool
        When True, `MPIComm` is handed to `Comm`; otherwise None is.
    cost, times
        Accepted but not read; the returned 'cost' and 'times' entries are
        always rebuilt from this call.
    batch_size : int
        The approximate number of scan positions processed by each GPU
        simultaneously per view.
    **kwargs
        Forwarded to the solver. Values with one or more dimensions are
        first broadcast with `comm.pool.bcast`.

    Returns
    -------
    result : dict
        The output of the last solver call, plus the gathered 'scan' (and
        'eigen_weights' when the solver returned them) in the original
        position order, and the collected 'cost' and 'times'. Every value
        is passed through `operator.asnumpy`.

    Raises
    ------
    ValueError
        When `algorithm` is not listed in `solvers.__all__`.

    """
    (psi, scan) = get_padded_object(scan, probe) if psi is None else (psi,
                                                                      scan)
    check_allowed_positions(scan, psi, probe)
    if use_mpi is True:
        mpi = MPIComm
    else:
        mpi = None
    if algorithm in solvers.__all__:
        # Initialize an operator.
        with Ptycho(
                probe_shape=probe.shape[-1],
                detector_shape=data.shape[-1],
                nz=psi.shape[-2],
                n=psi.shape[-1],
                ntheta=scan.shape[0],
                model=model,
        ) as operator, Comm(num_gpu, mpi) as comm:
            logger.info("{} for {:,d} - {:,d} by {:,d} frames for {:,d} "
                        "iterations.".format(algorithm, *data.shape[-3:],
                                             num_iter))
            num_batch = 1 if batch_size is None else max(
                1,
                int(data.shape[-3] / batch_size / comm.pool.num_workers),
            )
            # Divide the inputs into regions: an odd worker count uses an
            # (N, 1) grid, an even count uses (N // 2, 2).
            odd_pool = comm.pool.num_workers % 2
            order, scan, data, eigen_weights = split_by_scan_grid(
                comm.pool,
                (
                    comm.pool.num_workers
                    if odd_pool else comm.pool.num_workers // 2,
                    1 if odd_pool else 2,
                ),
                scan,
                data,
                eigen_weights,
            )
            result = {
                'psi':
                comm.pool.bcast(psi.astype('complex64')),
                'probe':
                comm.pool.bcast(probe.astype('complex64')),
                'eigen_probe':
                comm.pool.bcast(eigen_probe.astype('complex64'))
                if eigen_probe is not None else None,
                'scan':
                scan,
                'eigen_weights':
                eigen_weights,
            }
            for key, value in kwargs.items():
                if np.ndim(value) > 0:
                    kwargs[key] = comm.pool.bcast(value)

            result['probe'] = comm.pool.bcast(
                _rescale_obj_probe(
                    operator,
                    comm,
                    data[0],
                    result['psi'][0],
                    scan[0],
                    result['probe'][0],
                    num_batch=num_batch,
                ))

            costs = []
            # NOTE: this replaces whatever was passed as the `times` argument.
            times = []
            start = time.perf_counter()
            for i in range(num_iter):

                logger.info(f"{algorithm} epoch {i:,d}")

                kwargs.update(result)
                result = getattr(solvers, algorithm)(
                    operator,
                    comm,
                    data=data,
                    num_batch=num_batch,
                    **kwargs,
                )
                if result['cost'] is not None:
                    costs.append(result['cost'])

                times.append(time.perf_counter() - start)
                start = time.perf_counter()

                # Check for early termination. A solver may return a cost of
                # None, so two recorded costs are required before comparing.
                if (i > 0 and len(costs) > 1 and
                        abs((costs[-1] - costs[-2]) / costs[-2]) < rtol):
                    logger.info(
                        "Cost function rtol < %g reached at %d "
                        "iterations.", rtol, i)
                    break

            # Undo the permutation introduced by splitting into regions.
            reorder = np.argsort(np.concatenate(order))
            result['scan'] = comm.pool.gather(scan, axis=1)[:, reorder]
            if 'eigen_weights' in result:
                result['eigen_weights'] = comm.pool.gather(
                    eigen_weights,
                    axis=1,
                )[:, reorder]
                result['eigen_probe'] = result['eigen_probe'][0]
            result['probe'] = result['probe'][0]
            result['cost'] = operator.asarray(costs)
            result['times'] = operator.asarray(times)
            # Any value that is still a per-worker list keeps the first copy.
            for k, v in result.items():
                if isinstance(v, list):
                    result[k] = v[0]
        return {k: operator.asnumpy(v) for k, v in result.items()}
    else:
        raise ValueError(f"The '{algorithm}' algorithm is not an option.\n"
                         f"\tAvailable algorithms are : {solvers.__all__}")
Esempio n. 6
0
class TestPtycho(unittest.TestCase, OperatorTests):
    """Exercise the forward and adjoint methods of the Ptycho operator."""

    def setUp(self, ntheta=3, pw=15, nscan=27):
        """Create seeded random inputs and an entered Ptycho operator."""
        self.nscan = nscan
        self.ntheta = ntheta
        self.probe_shape = (ntheta, nscan, 1, 1, pw, pw)
        self.detector_shape = (pw * 3, pw * 3)
        self.original_shape = (ntheta, 128, 128)
        self.scan_shape = (ntheta, nscan, 2)
        print(Ptycho)

        # The draw order is fixed so that the seed reproduces the data.
        np.random.seed(0)
        host_scan = np.random.rand(*self.scan_shape).astype('float32')
        host_scan = host_scan * (127 - 16)
        host_probe = random_complex(*self.probe_shape)
        host_psi = random_complex(*self.original_shape)
        host_farplane = random_complex(*self.probe_shape[:-2],
                                       *self.detector_shape)

        operator_settings = {
            'nscan': self.scan_shape[-2],
            'probe_shape': self.probe_shape[-1],
            'detector_shape': self.detector_shape[-1],
            'nz': self.original_shape[-2],
            'n': self.original_shape[-1],
            'ntheta': self.ntheta,
        }
        self.operator = Ptycho(**operator_settings)
        # Entered manually; the matching exit is not part of this class body.
        self.operator.__enter__()
        self.xp = self.operator.xp

        def as_complex(array):
            return self.xp.asarray(array, dtype='complex64')

        def as_float(array):
            return self.xp.asarray(array, dtype='float32')

        # Move everything to the array module used by the operator.
        dev_probe = self.xp.asarray(host_probe.astype('complex64'))
        dev_psi = self.xp.asarray(host_psi.astype('complex64'))
        dev_farplane = self.xp.asarray(host_farplane.astype('complex64'))
        dev_scan = self.xp.asarray(host_scan.astype('float32'))

        # Operand and arguments for the object (psi) direction.
        self.m = as_complex(dev_psi)
        self.m_name = 'psi'
        self.kwargs = {
            'scan': as_float(dev_scan),
            'probe': as_complex(dev_probe),
        }

        # Operand and arguments for the probe direction.
        self.m1 = as_complex(dev_probe)
        self.m1_name = 'probe'
        self.kwargs1 = {
            'scan': as_float(dev_scan),
            'psi': as_complex(dev_psi),
        }

        # Data-space operand.
        self.d = as_complex(dev_farplane)
        self.d_name = 'farplane'

    def test_adjoint_probe(self):
        """Verify that adj_probe is the adjoint of fwd with respect to probe."""
        forward = self.operator.fwd(**{self.m1_name: self.m1}, **self.kwargs1)
        assert forward.shape == self.d.shape
        adjoint = self.operator.adj_probe(**{self.d_name: self.d},
                                          **self.kwargs1)
        assert adjoint.shape == self.m1.shape
        a = inner_complex(forward, self.d)
        b = inner_complex(self.m1, adjoint)
        print()
        for template, value in (
            ('<Fm,   m> = {:.6f}{:+.6f}j', a),
            ('< d, F*d> = {:.6f}{:+.6f}j', b),
        ):
            print(template.format(value.real.item(), value.imag.item()))
        self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-5)
        self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-5)

    def test_adj_probe_time(self):
        """Report the wall time of a single adj_probe call."""
        began = time.perf_counter()
        self.operator.adj_probe(**{self.d_name: self.d}, **self.kwargs1)
        elapsed = time.perf_counter() - began
        print(f"\n{elapsed:1.3e} seconds")

    @unittest.skip('FIXME: This operator is not scaled.')
    def test_scaled(self):
        """Skipped on purpose; see the skip reason."""
        pass
Esempio n. 7
0
def reconstruct(
        data,
        probe, scan,
        algorithm,
        psi=None, num_gpu=1, num_iter=1, rtol=-1, **kwargs
):  # yapf: disable
    """Run the named solver on the ptychography problem.

    Parameters
    ----------
    algorithm : string
        The name of one of the algorithms from :py:mod:`.ptycho.solvers`.
    rtol : float
        Stop early once the relative change of the cost between two
        consecutive iterations is smaller than this amount.

    """
    if psi is None:
        psi, scan = get_padded_object(scan, probe)
    check_allowed_positions(scan, psi, probe)
    if algorithm not in solvers.__all__:
        raise ValueError(
            "The '{}' algorithm is not an available.".format(algorithm))

    # Initialize an operator and the worker pool.
    with Ptycho(
            probe_shape=probe.shape[-1],
            detector_shape=data.shape[-1],
            nz=psi.shape[-2],
            n=psi.shape[-1],
            ntheta=scan.shape[0],
            **kwargs,
    ) as operator, ThreadPool(num_gpu) as pool:
        logger.info("{} for {:,d} - {:,d} by {:,d} frames for {:,d} "
                    "iterations.".format(algorithm, *data.shape[1:],
                                         num_iter))
        # TODO: Merge code paths num_gpu is not used.
        # From here on the device count reported by the pool is used.
        num_gpu = pool.device_count
        single_device = num_gpu <= 1

        # Send the inputs to the device(s).
        if single_device:
            data = operator.asarray(data, dtype='float32')
            result = dict(
                psi=operator.asarray(psi, dtype='complex64'),
                probe=operator.asarray(probe, dtype='complex64'),
                scan=operator.asarray(scan, dtype='float32'),
            )
            send = operator.asarray
        else:
            scan, data = asarray_multi_split(operator, num_gpu, scan, data)
            result = dict(
                psi=pool.bcast(psi.astype('complex64')),
                probe=pool.bcast(probe.astype('complex64')),
                scan=scan,
            )
            send = pool.bcast

        # Array-like solver options travel the same way as the inputs.
        for key in kwargs:
            if np.ndim(kwargs[key]) > 0:
                kwargs[key] = send(kwargs[key])

        previous_cost = 0
        for epoch in range(num_iter):
            result['probe'] = _rescale_obj_probe(
                operator,
                pool,
                num_gpu,
                data,
                result['psi'],
                result['scan'],
                result['probe'],
            )
            kwargs.update(result)
            solver = getattr(solvers, algorithm)
            result = solver(
                operator,
                pool,
                num_gpu=num_gpu,
                data=data,
                **kwargs,
            )
            # Check for early termination
            if epoch > 0:
                change = abs((result['cost'] - previous_cost) / previous_cost)
                if change < rtol:
                    logger.info(
                        "Cost function rtol < %g reached at %d "
                        "iterations.", rtol, epoch)
                    break
            previous_cost = result['cost']

        if not single_device:
            result['scan'] = pool.gather(result['scan'], axis=1)
            # Values that are still per-device lists keep the first copy.
            for key, value in result.items():
                if isinstance(value, list):
                    result[key] = value[0]

    return {key: operator.asnumpy(value) for key, value in result.items()}
Esempio n. 8
0
def reconstruct(
        data,
        probe, scan,
        algorithm,
        psi=None, num_gpu=1, num_iter=1, rtol=-1,
        model='gaussian', use_mpi=False, cost=None, times=None,
        batch_size=None, subset_is_random=None,
        eigen_probe=None, eigen_weights=None,
        **kwargs
):  # yapf: disable
    """Solve the ptychography problem using the given `algorithm`.

    The scan positions are divided among the workers of the pool and then
    into mini-batches; each epoch calls the solver once per mini-batch in a
    random order.

    Parameters
    ----------
    data
        The measured frames; axis 1 is the frame axis.
    probe
        The initial probe; broadcast to all workers as complex64.
    scan
        The scan positions; axis 0 sets `ntheta` of the operator.
    algorithm : string
        The name of one of the algorithms from :py:mod:`.ptycho.solvers`.
    psi
        The initial object. When None, `get_padded_object(scan, probe)`
        provides both psi and scan.
    num_gpu
        Passed to `Comm` as its first argument.
    num_iter : int
        The maximum number of epochs.
    rtol : float
        Terminate early if the relative decrease of the cost function is
        less than this amount.
    model : string
        Passed to the `Ptycho` operator as `model`.
    use_mpi : bool
        When True, `MPIComm` is handed to `Comm`; otherwise None is.
    cost, times
        Accepted but not read; the returned 'cost' and 'times' entries are
        always rebuilt from this call.
    batch_size : int
        When given, the number of mini-batches is the number of frames
        divided by `batch_size` and by the number of workers (at least 1).
    subset_is_random
        Passed to `_make_mini_batches`.
    eigen_probe
        Optional eigen probes; broadcast to all workers as complex64.
    eigen_weights
        Optional eigen probe weights; split along with the scan positions.
    **kwargs
        Forwarded to the solver. Values with one or more dimensions are
        first broadcast with `comm.pool.bcast`.

    Returns
    -------
    result : dict
        The output of the last solver call, plus the gathered 'scan' (and
        'eigen_weights' when the solver returned them) in the original
        position order, and the collected 'cost' and 'times'. Every value
        is passed through `operator.asnumpy`.

    Raises
    ------
    ValueError
        When `algorithm` is not listed in `solvers.__all__`.

    """
    (psi, scan) = get_padded_object(scan, probe) if psi is None else (psi,
                                                                      scan)
    check_allowed_positions(scan, psi, probe)
    if use_mpi is True:
        mpi = MPIComm
    else:
        mpi = None
    if algorithm in solvers.__all__:
        # Initialize an operator.
        with Ptycho(
                probe_shape=probe.shape[-1],
                detector_shape=data.shape[-1],
                nz=psi.shape[-2],
                n=psi.shape[-1],
                ntheta=scan.shape[0],
                model=model,
        ) as operator, Comm(num_gpu, mpi) as comm:
            logger.info("{} for {:,d} - {:,d} by {:,d} frames for {:,d} "
                        "iterations.".format(algorithm, *data.shape[-3:],
                                             num_iter))
            # Divide the inputs into regions and mini-batches
            num_batch = 1
            if batch_size is not None:
                num_batch = max(
                    1,
                    int(data.shape[1] / batch_size / comm.pool.num_workers),
                )
            # An odd worker count uses an (N, 1) grid, an even count uses
            # (N // 2, 2).
            odd_pool = comm.pool.num_workers % 2
            order = np.arange(data.shape[1])
            order, data, scan, eigen_weights = split_by_scan_grid(
                order,
                data,
                scan,
                (
                    comm.pool.num_workers
                    if odd_pool else comm.pool.num_workers // 2,
                    1 if odd_pool else 2,
                ),
                eigen_weights=eigen_weights,
            )
            order, data, scan, eigen_weights = zip(*comm.pool.map(
                _make_mini_batches,
                order,
                data,
                scan,
                eigen_weights,
                num_batch=num_batch,
                subset_is_random=subset_is_random,
            ))

            result = {
                'psi':
                comm.pool.bcast(psi.astype('complex64')),
                'probe':
                comm.pool.bcast(probe.astype('complex64')),
                'eigen_probe':
                comm.pool.bcast(eigen_probe.astype('complex64'))
                if eigen_probe is not None else None,
            }
            for key, value in kwargs.items():
                if np.ndim(value) > 0:
                    kwargs[key] = comm.pool.bcast(value)

            # The probe is rescaled once, using the first mini-batch of the
            # first worker.
            result['probe'] = comm.pool.bcast(
                _rescale_obj_probe(
                    operator,
                    comm,
                    data[0][0],
                    result['psi'][0],
                    scan[0][0],
                    result['probe'][0],
                ))

            costs = []
            # NOTE: this replaces whatever was passed as the `times` argument.
            times = []
            start = time.perf_counter()
            for i in range(num_iter):

                logger.info(f"{algorithm} epoch {i:,d}")

                for b in randomizer.permutation(num_batch):
                    kwargs.update(result)
                    kwargs['scan'] = [s[b] for s in scan]
                    kwargs['eigen_weights'] = [w[b] for w in eigen_weights]
                    result = getattr(solvers, algorithm)(
                        operator,
                        comm,
                        data=[d[b] for d in data],
                        **kwargs,
                    )
                    if result['cost'] is not None:
                        costs.append(result['cost'])
                    # Store the updated per-worker values for this batch.
                    for g in range(comm.pool.num_workers):
                        scan[g][b] = result['scan'][g]
                        eigen_weights[g][b] = result['eigen_weights'][
                            g] if 'eigen_weights' in result else None

                times.append(time.perf_counter() - start)
                start = time.perf_counter()

                # Check for early termination. `costs` holds one entry per
                # mini-batch whose cost was not None, so two recorded costs
                # are required before comparing.
                if (i > 0 and len(costs) > 1 and
                        abs((costs[-1] - costs[-2]) / costs[-2]) < rtol):
                    logger.info(
                        "Cost function rtol < %g reached at %d "
                        "iterations.", rtol, i)
                    break

            # Undo the permutation introduced by regions and mini-batches.
            reorder = np.argsort(
                np.concatenate(list(chain.from_iterable(order))))
            result['scan'] = comm.pool.gather(
                list(comm.pool.map(cp.concatenate, scan, axis=1)),
                axis=1,
            )[:, reorder]
            if 'eigen_weights' in result:
                result['eigen_weights'] = comm.pool.gather(
                    list(comm.pool.map(cp.concatenate, eigen_weights, axis=1)),
                    axis=1,
                )[:, reorder]
                result['eigen_probe'] = result['eigen_probe'][0]
            result['probe'] = result['probe'][0]
            result['cost'] = operator.asarray(costs)
            result['times'] = operator.asarray(times)
            # Any value that is still a per-worker list keeps the first copy.
            for k, v in result.items():
                if isinstance(v, list):
                    result[k] = v[0]
        return {k: operator.asnumpy(v) for k, v in result.items()}
    else:
        raise ValueError(f"The '{algorithm}' algorithm is not an option.\n"
                         f"\tAvailable algorithms are : {solvers.__all__}")
Esempio n. 9
0
def reconstruct(
    data,
    probe,
    psi,
    scan,
    algorithm_options=None,
    eigen_probe=None,
    eigen_weights=None,
    model='gaussian',
    num_gpu=1,
    object_options=None,
    position_options=None,
    probe_options=None,
    use_mpi=False,
):
    """Solve the ptychography problem as configured by `algorithm_options`.

    Parameters
    ----------
    data : (FRAME, WIDE, HIGH) float32
        The intensity (square of the absolute value) of the propagated
        wavefront; i.e. what the detector records. FFT-shifted so the
        diffraction peak is at the corners.
    probe : (1, 1, SHARED, WIDE, HIGH) complex64
        The shared complex illumination function amongst all positions.
    scan : (POSI, 2) float32
        Coordinates of the minimum corner of the probe grid for each
        measurement in the coordinate system of psi. Coordinate order
        consistent with WIDE, HIGH order.
    algorithm_options : :py:class:`tike.ptycho.solvers.IterativeOptions`
        A class containing algorithm specific parameters. This object is
        updated in place (iteration times are appended to it). When None
        (the default), a new `solvers.RpieOptions()` is created for this
        call so that no state is shared between calls.
    eigen_probe : (EIGEN, SHARED, WIDE, HIGH) complex64
        The eigen probes for all positions.
    eigen_weights : (POSI, EIGEN, SHARED) float32
        The relative intensity of the eigen probes at each position.
    model : "gaussian", "poisson"
        The noise model to use for the cost function.
    num_gpu : int, tuple(int)
        The number of GPUs to use or a tuple of the device numbers of the GPUs
        to use. If the number of GPUs is less than the requested number, only
        workers for the available GPUs are allocated.
    object_options : :py:class:`tike.ptycho.ObjectOptions`
        A class containing settings related to object updates.
    position_options : :py:class:`tike.ptycho.PositionOptions`
        A class containing settings related to position correction.
    probe_options : :py:class:`tike.ptycho.ProbeOptions`
        A class containing settings related to probe updates.
    psi : (WIDE, HIGH) complex64
        The wavefront modulation coefficients of the object.
    use_mpi : bool
        Whether to use MPI or not.

    Raises
    ------
        ValueError
            When shapes are incorrect for various input parameters.

    Returns
    -------
    result : dict
        A dictionary of the above parameters that may be passed to this
        function to resume reconstruction from the previous state.

    """
    if (np.any(np.asarray(data.shape) < 1) or data.ndim != 3
            or data.shape[-2] != data.shape[-1]):
        raise ValueError(
            f"data shape {data.shape} is incorrect. "
            "It should be (N, W, H), "
            "where N >= 1 is the number of square diffraction patterns.")
    if (scan.ndim != 2 or scan.shape[1] != 2
            or np.any(np.asarray(scan.shape) < 1)):
        raise ValueError(f"scan shape {scan.shape} is incorrect. "
                         "It should be (N, 2) "
                         "where N >= 1 is the number of scan positions.")
    if data.shape[0] != scan.shape[0]:
        raise ValueError(
            f"data shape {data.shape} and scan shape {scan.shape} "
            "are incompatible. They should have the same leading dimension.")
    if (probe.ndim != 5 or probe.shape[:2] != (1, 1)
            or np.any(np.asarray(probe.shape) < 1)
            or probe.shape[-2] != probe.shape[-1]):
        raise ValueError(f"probe shape {probe.shape} is incorrect. "
                         "It should be (1, 1, S, W, H) "
                         "where S >=1 is the number of probes, and "
                         "W, H >= 1 are the square probe grid dimensions.")
    if np.any(np.asarray(probe.shape[-2:]) > np.asarray(data.shape[-2:])):
        raise ValueError(f"probe shape {probe.shape} is incorrect. "
                         "The probe width/height must be "
                         f"<= the data width/height {data.shape}.")
    if (psi.ndim != 2
            or np.any(np.asarray(psi.shape) <= np.asarray(probe.shape[-2:]))):
        raise ValueError(f"psi shape {psi.shape} is incorrect. "
                         "It should be (W, H) where W, H > probe.shape[-2:].")
    check_allowed_positions(scan, psi, probe.shape)

    # The options object is mutated below, so a default instance must not be
    # shared between calls; build a fresh one when none was supplied.
    if algorithm_options is None:
        algorithm_options = solvers.RpieOptions()

    logger.info("{} for {:,d} - {:,d} by {:,d} frames for {:,d} "
                "iterations.".format(
                    algorithm_options.name,
                    *data.shape[-3:],
                    algorithm_options.num_iter,
                ))

    if use_mpi is True:
        mpi = MPIComm
        # NOTE(review): psi.ndim was already read by the shape validation
        # above, so a None psi cannot reach this check.
        if psi is None:
            raise ValueError(
                "When MPI is enabled, initial object guess cannot be None; "
                "automatic psi initialization is not synchronized "
                "across processes.")
    else:
        mpi = None
    # When explicit device numbers are given, the first one is made current.
    with (cp.cuda.Device(num_gpu[0] if isinstance(num_gpu, tuple) else None)):
        with Ptycho(
                probe_shape=probe.shape[-1],
                detector_shape=data.shape[-1],
                nz=psi.shape[-2],
                n=psi.shape[-1],
                model=model,
        ) as operator, Comm(num_gpu, mpi) as comm:

            (
                batches,
                data,
                result,
                scan,
            ) = _setup(
                algorithm_options,
                comm,
                data,
                eigen_probe,
                eigen_weights,
                object_options,
                operator,
                probe,
                psi,
                position_options,
                probe_options,
                scan,
            )

            start = time.perf_counter()
            for i in range(algorithm_options.num_iter):

                logger.info(f"{algorithm_options.name} epoch {i:,d}")

                # TODO: Append new information to everything that emits from
                # _setup.

                result = _iterate(
                    algorithm_options,
                    batches,
                    comm,
                    data,
                    operator,
                    position_options,
                    probe_options,
                    result,
                )

                # TODO: Grab intermediate psi/probe from GPU.

                # Record the wall time of this epoch on the options object.
                algorithm_options.times.append(time.perf_counter() - start)
                start = time.perf_counter()

            return _teardown(
                algorithm_options,
                comm,
                eigen_probe,
                eigen_weights,
                object_options,
                position_options,
                probe_options,
                result,
                scan,
            )