Beispiel #1
0
def _make_mini_batches(
    order,
    data,
    scan,
    eigen_weights=None,
    num_batch=1,
    subset_is_random=True,
):
    """Partition ptychography inputs into mini-batches of positions.

    One partition of the position indices (axis 1 of `data`) is computed
    and then applied to every input, so the n-th entry of each returned
    list refers to the same positions.

    Parameters
    ----------
    order : array
        Indexed with each mini-batch's index array; presumably the original
        index of every position so the split can be undone later (confirm
        against the caller).
    data : (M, N, ...)
        Array whose axis 1 enumerates the positions to be split.
    scan : (M, N, 2)
        Scan positions; sliced along axis 1 like `data`.
    eigen_weights : (M, N, ...) or None
        Optional array sliced along axis 1 like `data`.
    num_batch : int
        The number of mini-batches to create.
    subset_is_random : bool
        When truthy, positions are shuffled before being split; otherwise
        the mini-batches are contiguous ranges of positions.

    Returns
    -------
    order, data, scan, eigen_weights : list
        Four lists with one entry per mini-batch. The `data`, `scan`, and
        `eigen_weights` entries are float32 CuPy arrays; the
        `eigen_weights` entries are all None if no weights were provided.
    """
    logger.info(f'Split data into {num_batch} mini-batches.')
    num_position = data.shape[1]
    # FIXME: fly positions must stay together
    positions = (randomizer.permutation(num_position)
                 if subset_is_random else np.arange(num_position))
    batches = np.array_split(positions, num_batch)

    def _to_device(array):
        # One float32 device array per mini-batch, sliced along axis 1.
        return [cp.asarray(array[:, b], dtype='float32') for b in batches]

    order_batches = [order[b] for b in batches]
    data_batches = _to_device(data)
    scan_batches = _to_device(scan)
    if eigen_weights is None:
        weight_batches = [None] * len(batches)
    else:
        weight_batches = _to_device(eigen_weights)
    return order_batches, data_batches, scan_batches, weight_batches
Beispiel #2
0
def rpie(
    op,
    comm,
    data,
    batches,
    *,
    probe,
    scan,
    psi,
    algorithm_options,
    probe_options=None,
    position_options=None,
    object_options=None,
):
    """Run one pass of the regularized ptychographical iterative engine.

    Every batch in `batches` is visited once, in random order. For each
    batch the wavefronts are updated, the cost is reduced across devices,
    and the object, probe, and scan positions are refined. After the last
    batch the optional probe and object constraints are applied and the
    cost of the last processed batch is appended to
    `algorithm_options.costs`.

    Parameters
    ----------
    op : :py:class:`tike.operators.Ptycho`
        The ptychography operator used for the updates.
    comm : :py:class:`tike.communicators.Comm`
        Manages communication between GPUs and nodes.
    data : list((FRAME, WIDE, HIGH) float32, ...)
        One unique CuPy array per device holding the measured intensity
        (squared magnitude of the propagated wavefront) as recorded by the
        detector. FFT-shifted so that the diffraction peak sits at the
        corners.
    batches : list(list((BATCH_SIZE, ) int, ...), ...)
        For each device, a list of index arrays along the FRAME axis of
        `data`; each index array selects the frames processed together.
    probe : list((1, 1, SHARED, WIDE, HIGH) complex64, ...)
        One duplicate CuPy array per device holding the complex
        illumination function shared by all positions.
    scan : list((POSI, 2) float32, ...)
        One unique CuPy array per device holding, for every measurement,
        the minimum corner of the probe grid in the coordinate system of
        psi. Coordinates are ordered like WIDE, HIGH.
    psi : list((WIDE, HIGH) complex64, ...)
        One duplicate CuPy array per device holding the wavefront
        modulation coefficients of the object.
    algorithm_options : :py:class:`tike.ptycho.IterativeOptions`
        Settings for this algorithm; its `costs` list is appended to.
    position_options : :py:class:`tike.ptycho.PositionOptions`
        Settings for position correction. None skips the split/join of
        position options around each batch.
    probe_options : :py:class:`tike.ptycho.ProbeOptions`
        Settings for probe updates.
    object_options : :py:class:`tike.ptycho.ObjectOptions`
        Settings for object updates.

    Returns
    -------
    result : dict
        The keyword-only arguments of this function after being updated.

    References
    ----------
    Maiden, Andrew M., and John M. Rodenburg. 2009. “An Improved
    Ptychographical Phase Retrieval Algorithm for Diffractive Imaging.”
    Ultramicroscopy 109 (10): 1256–62.
    https://doi.org/10.1016/j.ultramic.2009.05.012.

    .. seealso:: :py:mod:`tike.ptycho`

    """
    for batch_index in randomizer.permutation(len(batches[0])):

        bdata = comm.pool.map(get_batch, data, batches, n=batch_index)
        bscan = comm.pool.map(get_batch, scan, batches, n=batch_index)

        # Only carve out per-batch position options when correction is on.
        bposition_options = None
        if position_options is not None:
            bposition_options = comm.pool.map(
                PositionOptions.split,
                position_options,
                [b[batch_index] for b in batches],
            )

        # Only the shared probe is used here; no eigen probes or weights.
        unique_probe = probe
        beigen_probe = None
        beigen_weights = None

        wavefront_results = comm.pool.map(
            _update_wavefront,
            bdata,
            unique_probe,
            bscan,
            psi,
            op=op,
        )
        nearplane, cost = zip(*wavefront_results)

        # TODO: This reduction should be mean (applies to the MPI branch)
        reduce_cost = comm.Allreduce_reduce if comm.use_mpi else comm.reduce
        cost = reduce_cost(cost, 'cpu')

        # The two boolean arguments presumably switch the object and probe
        # updates on or off -- confirm against _update_nearplane.
        updated = _update_nearplane(
            op,
            comm,
            nearplane,
            psi,
            bscan,
            probe,
            unique_probe,
            beigen_probe,
            beigen_weights,
            object_options is not None,
            probe_options is not None,
            position_options=bposition_options,
            algorithm_options=algorithm_options,
        )
        (
            psi,
            probe,
            beigen_probe,
            beigen_weights,
            bscan,
            bposition_options,
        ) = updated

        # Fold the per-batch position options back into the full set.
        if position_options is not None:
            comm.pool.map(
                PositionOptions.join,
                position_options,
                bposition_options,
                [b[batch_index] for b in batches],
            )

        # Store the refined batch positions back into the full scan arrays.
        comm.pool.map(
            put_batch,
            bscan,
            scan,
            batches,
            n=batch_index,
        )

    if probe_options and probe_options.orthogonality_constraint:
        probe = comm.pool.map(orthogonalize_eig, probe)

    if object_options:
        psi = comm.pool.map(
            positivity_constraint,
            psi,
            r=object_options.positivity_constraint,
        )
        psi = comm.pool.map(
            smoothness_constraint,
            psi,
            a=object_options.smoothness_constraint,
        )

    algorithm_options.costs.append(cost)
    return {
        'probe': probe,
        'psi': psi,
        'scan': scan,
        'algorithm_options': algorithm_options,
        'probe_options': probe_options,
        'object_options': object_options,
        'position_options': position_options,
    }
Beispiel #3
0
def lstsq_grad(
    op,
    comm,
    data,
    batches,
    *,
    probe,
    scan,
    psi,
    algorithm_options,
    eigen_probe=None,
    eigen_weights=None,
    probe_options=None,
    position_options=None,
    object_options=None,
):
    """Run one pass of the least-squares solver of Odstrcil et al.

    The object and the probe are updated simultaneously with optimal step
    sizes obtained from a least-squares problem. Every batch in `batches`
    is visited once, in random order; the optional probe, object, and
    variable-probe constraints are applied after the last batch, and the
    cost of the last processed batch is appended to
    `algorithm_options.costs`.

    Parameters
    ----------
    op : :py:class:`tike.operators.Ptycho`
        The ptychography operator used for the updates.
    comm : :py:class:`tike.communicators.Comm`
        Manages communication between GPUs and nodes.
    data : list((FRAME, WIDE, HIGH) float32, ...)
        One unique CuPy array per device holding the measured intensity
        (squared magnitude of the propagated wavefront) as recorded by the
        detector. FFT-shifted so that the diffraction peak sits at the
        corners.
    batches : list(list((BATCH_SIZE, ) int, ...), ...)
        For each device, a list of index arrays along the FRAME axis of
        `data`; each index array selects the frames processed together.
    probe : list((1, 1, SHARED, WIDE, HIGH) complex64, ...)
        One duplicate CuPy array per device holding the complex
        illumination function shared by all positions. When the
        orthogonality constraint is active, element 0 of this list is
        replaced in place before the list is re-broadcast.
    scan : list((POSI, 2) float32, ...)
        One unique CuPy array per device holding, for every measurement,
        the minimum corner of the probe grid in the coordinate system of
        psi. Coordinates are ordered like WIDE, HIGH.
    psi : list((WIDE, HIGH) complex64, ...)
        One duplicate CuPy array per device holding the wavefront
        modulation coefficients of the object.
    algorithm_options : :py:class:`tike.ptycho.IterativeOptions`
        Settings for this algorithm; its `costs` list is appended to.
    eigen_probe : list or None
        Per-device eigen probes. The variable-probe code paths run only
        when this is a list; otherwise None placeholders are used.
    eigen_weights : list or None
        Per-device weights paired with `eigen_probe`; batched along the
        same indices as `data` when `eigen_probe` is a list.
    position_options : :py:class:`tike.ptycho.PositionOptions`
        Settings for position correction.
    probe_options : :py:class:`tike.ptycho.ProbeOptions`
        Settings for probe updates.
    object_options : :py:class:`tike.ptycho.ObjectOptions`
        Settings for object updates.

    Returns
    -------
    result : dict
        The keyword-only arguments of this function after being updated.

    References
    ----------
    Michal Odstrcil, Andreas Menzel, and Manuel Guizar-Sicairos. Iterative
    least-squares solver for generalized maximum-likelihood ptychography.
    Optics Express. 2018.

    .. seealso:: :py:mod:`tike.ptycho`

    """
    # eigen_probe is never rebound inside the loop, so test its type once.
    has_eigen_probe = isinstance(eigen_probe, list)

    for batch_index in randomizer.permutation(len(batches[0])):

        bdata = comm.pool.map(get_batch, data, batches, n=batch_index)
        bscan = comm.pool.map(get_batch, scan, batches, n=batch_index)

        bposition_options = None
        if position_options:
            bposition_options = comm.pool.map(
                PositionOptions.split,
                position_options,
                [b[batch_index] for b in batches],
            )

        if has_eigen_probe:
            beigen_probe = eigen_probe
            beigen_weights = comm.pool.map(
                get_batch,
                eigen_weights,
                batches,
                n=batch_index,
            )
        else:
            # One None placeholder per worker so the pool maps line up.
            beigen_probe = [None] * comm.pool.num_workers
            beigen_weights = [None] * comm.pool.num_workers

        unique_probe = comm.pool.map(
            get_varying_probe,
            probe,
            beigen_probe,
            beigen_weights,
        )

        wavefront_results = comm.pool.map(
            _update_wavefront,
            bdata,
            unique_probe,
            bscan,
            psi,
            op=op,
        )
        nearplane, cost = zip(*wavefront_results)

        reduce_cost = comm.Allreduce_reduce if comm.use_mpi else comm.reduce
        cost = reduce_cost(cost, 'cpu')

        # The two boolean arguments presumably switch the object and probe
        # updates on or off -- confirm against _update_nearplane.
        updated = _update_nearplane(
            op,
            comm,
            nearplane,
            psi,
            bscan,
            probe,
            unique_probe,
            beigen_probe,
            beigen_weights,
            object_options is not None,
            probe_options is not None,
            bposition_options,
        )
        (
            psi,
            probe,
            beigen_probe,
            beigen_weights,
            bscan,
            bposition_options,
        ) = updated

        # Fold the per-batch position options back into the full set.
        if position_options:
            comm.pool.map(
                PositionOptions.join,
                position_options,
                bposition_options,
                [b[batch_index] for b in batches],
            )

        # Store the refined batch weights back into the full weights.
        if has_eigen_probe:
            comm.pool.map(
                put_batch,
                beigen_weights,
                eigen_weights,
                batches,
                n=batch_index,
            )

        # Store the refined batch positions back into the full scan arrays.
        comm.pool.map(
            put_batch,
            bscan,
            scan,
            batches,
            n=batch_index,
        )

    if probe_options and probe_options.orthogonality_constraint:
        # Orthogonalize on the first device only, then share the result.
        probe[0] = orthogonalize_gs(probe[0], axis=(-2, -1))
        probe = comm.pool.bcast([probe[0]])

    if object_options:
        psi = comm.pool.map(
            positivity_constraint,
            psi,
            r=object_options.positivity_constraint,
        )
        psi = comm.pool.map(
            smoothness_constraint,
            psi,
            a=object_options.smoothness_constraint,
        )

    if has_eigen_probe:
        constrained = comm.pool.map(
            constrain_variable_probe,
            eigen_probe,
            eigen_weights,
        )
        # Turn the per-device (probe, weights) pairs into two lists.
        eigen_probe, eigen_weights = (list(a) for a in zip(*constrained))

    algorithm_options.costs.append(cost)
    return {
        'probe': probe,
        'psi': psi,
        'scan': scan,
        'eigen_probe': eigen_probe,
        'eigen_weights': eigen_weights,
        'algorithm_options': algorithm_options,
        'probe_options': probe_options,
        'object_options': object_options,
        'position_options': position_options,
    }
Beispiel #4
0
def reconstruct(
        data,
        probe, scan,
        algorithm,
        psi=None, num_gpu=1, num_iter=1, rtol=-1,
        model='gaussian', use_mpi=False, cost=None, times=None,
        batch_size=None, subset_is_random=None,
        eigen_probe=None, eigen_weights=None,
        **kwargs
):  # yapf: disable
    """Solve the ptychography problem using the given `algorithm`.

    The inputs are divided among the workers by scan position and then
    into mini-batches. The chosen solver is called once per mini-batch per
    epoch, and the updated scan positions (and eigen weights, if any) are
    gathered back into their original order at the end.

    Parameters
    ----------
    data : array
        The measurements. Axis 1 enumerates the scan positions and the
        last axis sets the detector size.
    probe : array
        The illumination; its last axis sets the probe size. It is cast
        to complex64, broadcast to every worker, and rescaled once before
        iterating.
    scan : array
        The scan positions; the length of axis 0 is passed to the
        operator as `ntheta`.
    algorithm : string
        The name of one algorithms from :py:mod:`.ptycho.solvers`.
    psi : array, optional
        The initial object. When None, both `psi` and `scan` are taken
        from `get_padded_object(scan, probe)`.
    num_gpu : int
        Forwarded to the communicator.
    num_iter : int
        The number of epochs; every mini-batch is visited once per epoch.
    rtol : float
        Terminate early if the relative decrease of the cost function is
        less than this amount. The check needs at least two recorded
        costs and is never made during the first epoch.
    model : string
        Forwarded to the ptychography operator.
    use_mpi : bool
        MPI communication is enabled only when this is exactly True.
    cost, times : optional
        Accepted but not read; fresh values are placed in the result.
    batch_size : int, optional
        Requested number of positions per mini-batch per worker. When
        None, each worker uses a single batch.
    subset_is_random : bool, optional
        Forwarded to the mini-batch splitter.
    eigen_probe, eigen_weights : array, optional
        Variable-probe inputs. `eigen_probe` is cast to complex64 and
        broadcast; `eigen_weights` is split like `data`.
    **kwargs
        Forwarded to the solver. Values with one or more dimensions are
        first broadcast to every worker.

    Returns
    -------
    result : dict
        The solver's final outputs converted with `operator.asnumpy`,
        including 'scan', 'probe', 'cost' (all recorded costs), and
        'times' (the duration of every epoch).

    Raises
    ------
    ValueError
        If `algorithm` is not listed in `solvers.__all__`.
    """
    if psi is None:
        psi, scan = get_padded_object(scan, probe)
    check_allowed_positions(scan, psi, probe)
    mpi = MPIComm if use_mpi is True else None
    if algorithm not in solvers.__all__:
        raise ValueError(f"The '{algorithm}' algorithm is not an option.\n"
                         f"\tAvailable algorithms are : {solvers.__all__}")
    # Initialize an operator.
    with Ptycho(
            probe_shape=probe.shape[-1],
            detector_shape=data.shape[-1],
            nz=psi.shape[-2],
            n=psi.shape[-1],
            ntheta=scan.shape[0],
            model=model,
    ) as operator, Comm(num_gpu, mpi) as comm:
        logger.info("{} for {:,d} - {:,d} by {:,d} frames for {:,d} "
                    "iterations.".format(algorithm, *data.shape[-3:],
                                         num_iter))
        # Divide the inputs into regions and mini-batches
        num_batch = 1
        if batch_size is not None:
            num_batch = max(
                1,
                int(data.shape[1] / batch_size / comm.pool.num_workers),
            )
        odd_pool = comm.pool.num_workers % 2
        # Remember the original index of every position so that the
        # results can be put back in order after splitting and shuffling.
        order = np.arange(data.shape[1])
        order, data, scan, eigen_weights = split_by_scan_grid(
            order,
            data,
            scan,
            # An odd worker count uses a (num_workers, 1) grid of regions;
            # an even count uses (num_workers // 2, 2).
            (
                comm.pool.num_workers
                if odd_pool else comm.pool.num_workers // 2,
                1 if odd_pool else 2,
            ),
            eigen_weights=eigen_weights,
        )
        order, data, scan, eigen_weights = zip(*comm.pool.map(
            _make_mini_batches,
            order,
            data,
            scan,
            eigen_weights,
            num_batch=num_batch,
            subset_is_random=subset_is_random,
        ))

        result = {
            'psi':
            comm.pool.bcast(psi.astype('complex64')),
            'probe':
            comm.pool.bcast(probe.astype('complex64')),
            'eigen_probe':
            comm.pool.bcast(eigen_probe.astype('complex64'))
            if eigen_probe is not None else None,
        }
        # Only existing keys are reassigned, so the dictionary does not
        # change size while it is being iterated.
        for key, value in kwargs.items():
            if np.ndim(value) > 0:
                kwargs[key] = comm.pool.bcast(value)

        # Rescale using the first mini-batch held by the first worker.
        result['probe'] = comm.pool.bcast(
            _rescale_obj_probe(
                operator,
                comm,
                data[0][0],
                result['psi'][0],
                scan[0][0],
                result['probe'][0],
            ))

        costs = []
        times = []
        start = time.perf_counter()
        for i in range(num_iter):

            logger.info(f"{algorithm} epoch {i:,d}")

            for b in randomizer.permutation(num_batch):
                kwargs.update(result)
                kwargs['scan'] = [s[b] for s in scan]
                kwargs['eigen_weights'] = [w[b] for w in eigen_weights]
                result = getattr(solvers, algorithm)(
                    operator,
                    comm,
                    data=[d[b] for d in data],
                    **kwargs,
                )
                if result['cost'] is not None:
                    costs.append(result['cost'])
                # Keep the per-worker, per-batch state up to date.
                for g in range(comm.pool.num_workers):
                    scan[g][b] = result['scan'][g]
                    eigen_weights[g][b] = result['eigen_weights'][
                        g] if 'eigen_weights' in result else None

            times.append(time.perf_counter() - start)
            start = time.perf_counter()

            # Check for early termination. A relative change needs two
            # recorded costs, and solvers may report no cost at all, so
            # guard against indexing a list with fewer than two entries.
            if (i > 0 and len(costs) > 1 and
                    abs((costs[-1] - costs[-2]) / costs[-2]) < rtol):
                logger.info(
                    "Cost function rtol < %g reached at %d "
                    "iterations.", rtol, i)
                break

        # Sorting the concatenated original indices gives the permutation
        # that restores the input ordering of the positions.
        reorder = np.argsort(
            np.concatenate(list(chain.from_iterable(order))))
        result['scan'] = comm.pool.gather(
            list(comm.pool.map(cp.concatenate, scan, axis=1)),
            axis=1,
        )[:, reorder]
        if 'eigen_weights' in result:
            result['eigen_weights'] = comm.pool.gather(
                list(comm.pool.map(cp.concatenate, eigen_weights, axis=1)),
                axis=1,
            )[:, reorder]
            result['eigen_probe'] = result['eigen_probe'][0]
        result['probe'] = result['probe'][0]
        result['cost'] = operator.asarray(costs)
        result['times'] = operator.asarray(times)
        # Anything still duplicated per worker collapses to the first copy.
        for k, v in result.items():
            if isinstance(v, list):
                result[k] = v[0]
    return {k: operator.asnumpy(v) for k, v in result.items()}
Beispiel #5
0
def cgrad(
    op,
    comm,
    data,
    batches,
    *,
    probe,
    scan,
    psi,
    algorithm_options,
    probe_options=None,
    position_options=None,
    object_options=None,
):
    """Run one pass of the conjugate-gradient ptychography solver.

    Every batch in `batches` is visited once, in random order. Depending
    on which options are provided, the object, the probe, and (with a
    single worker only) the scan positions are updated for each batch.
    The cost of the last update performed is appended to
    `algorithm_options.costs`.

    Parameters
    ----------
    op : :py:class:`tike.operators.Ptycho`
        The ptychography operator used for the updates.
    comm : :py:class:`tike.communicators.Comm`
        Manages communication between GPUs and nodes.
    data : list((FRAME, WIDE, HIGH) float32, ...)
        One unique CuPy array per device holding the measured intensity
        (squared magnitude of the propagated wavefront) as recorded by the
        detector. FFT-shifted so that the diffraction peak sits at the
        corners.
    batches : list(list((BATCH_SIZE, ) int, ...), ...)
        For each device, a list of index arrays along the FRAME axis of
        `data`; each index array selects the frames processed together.
    probe : list((1, 1, SHARED, WIDE, HIGH) complex64, ...)
        One duplicate CuPy array per device holding the complex
        illumination function shared by all positions.
    scan : list((POSI, 2) float32, ...)
        One unique CuPy array per device holding, for every measurement,
        the minimum corner of the probe grid in the coordinate system of
        psi. Coordinates are ordered like WIDE, HIGH. This function does
        not rebind `scan`; updated batch positions are not written back
        (see the TODO in the body).
    psi : list((WIDE, HIGH) complex64, ...)
        One duplicate CuPy array per device holding the wavefront
        modulation coefficients of the object.
    algorithm_options : :py:class:`tike.ptycho.IterativeOptions`
        Settings for this algorithm; `cg_iter` and `step_length` are read
        and the `costs` list is appended to.
    position_options : :py:class:`tike.ptycho.PositionOptions`
        Settings for position correction.
    probe_options : :py:class:`tike.ptycho.ProbeOptions`
        Settings for probe updates; a falsy value skips the probe update.
    object_options : :py:class:`tike.ptycho.ObjectOptions`
        Settings for object updates; a falsy value skips the object
        update and its constraints.

    Returns
    -------
    result : dict
        The keyword-only arguments of this function after being updated.

    .. seealso:: :py:mod:`tike.ptycho`

    """
    for batch_index in randomizer.permutation(len(batches[0])):

        bdata = comm.pool.map(get_batch, data, batches, n=batch_index)
        bscan = comm.pool.map(get_batch, scan, batches, n=batch_index)

        # NOTE(review): the split options are not read again in this
        # function; the call is kept so behavior stays the same.
        bposition_options = None
        if position_options:
            bposition_options = comm.pool.map(
                PositionOptions.split,
                position_options,
                [b[batch_index] for b in batches],
            )

        if object_options:
            psi, cost, object_options = _update_object(
                op,
                comm,
                bdata,
                psi,
                bscan,
                probe,
                num_iter=algorithm_options.cg_iter,
                step_length=algorithm_options.step_length,
                object_options=object_options,
            )
            # Constraints are applied after every object update.
            psi = comm.pool.map(
                positivity_constraint,
                psi,
                r=object_options.positivity_constraint,
            )
            psi = comm.pool.map(
                smoothness_constraint,
                psi,
                a=object_options.smoothness_constraint,
            )

        if probe_options:
            # Select every index along the probe's third-from-last axis.
            all_modes = list(range(probe[0].shape[-3]))
            probe, cost, probe_options = _update_probe(
                op,
                comm,
                bdata,
                psi,
                bscan,
                probe,
                num_iter=algorithm_options.cg_iter,
                step_length=algorithm_options.step_length,
                mode=all_modes,
                probe_options=probe_options,
            )

        # Position correction only runs with exactly one worker.
        if position_options and comm.pool.num_workers == 1:
            gathered_data = comm.pool.gather(bdata, axis=-3)
            gathered_scan = comm.pool.gather(bscan, axis=-2)
            bscan, cost = update_positions_pd(
                op,
                gathered_data,
                psi[0],
                probe[0],
                gathered_scan,
            )
            bscan = comm.pool.bcast([bscan])
            # TODO: Assign bscan into scan when positions are updated

    algorithm_options.costs.append(cost)
    return {
        'probe': probe,
        'psi': psi,
        'scan': scan,
        'algorithm_options': algorithm_options,
        'probe_options': probe_options,
        'object_options': object_options,
        'position_options': position_options,
    }
Beispiel #6
0
def adam_grad(
    op,
    comm,
    data,
    batches,
    *,
    probe,
    scan,
    psi,
    algorithm_options,
    probe_options=None,
    position_options=None,
    object_options=None,
):
    """Run one pass of the ADAptive Moment gradient descent solver.

    Every batch in `batches` is visited once, in random order, and the
    object and probe are updated for each batch. After the last batch the
    optional probe and object constraints are applied and the cost of the
    last processed batch is appended to `algorithm_options.costs`.

    Parameters
    ----------
    op : :py:class:`tike.operators.Ptycho`
        The ptychography operator used for the updates.
    comm : :py:class:`tike.communicators.Comm`
        Manages communication between GPUs and nodes.
    data : list((FRAME, WIDE, HIGH) float32, ...)
        One unique CuPy array per device holding the measured intensity
        (squared magnitude of the propagated wavefront) as recorded by the
        detector. FFT-shifted so that the diffraction peak sits at the
        corners.
    batches : list(list((BATCH_SIZE, ) int, ...), ...)
        For each device, a list of index arrays along the FRAME axis of
        `data`; each index array selects the frames processed together.
    probe : list((1, 1, SHARED, WIDE, HIGH) complex64, ...)
        One duplicate CuPy array per device holding the complex
        illumination function shared by all positions.
    scan : list((POSI, 2) float32, ...)
        One unique CuPy array per device holding, for every measurement,
        the minimum corner of the probe grid in the coordinate system of
        psi. Coordinates are ordered like WIDE, HIGH. This function does
        not rebind `scan`.
    psi : list((WIDE, HIGH) complex64, ...)
        One duplicate CuPy array per device holding the wavefront
        modulation coefficients of the object.
    algorithm_options : :py:class:`tike.ptycho.IterativeOptions`
        Settings for this algorithm; its `costs` list is appended to.
    position_options : :py:class:`tike.ptycho.PositionOptions`
        Settings for position correction. Not read here; only passed
        through to the result.
    probe_options : :py:class:`tike.ptycho.ProbeOptions`
        Settings for probe updates.
    object_options : :py:class:`tike.ptycho.ObjectOptions`
        Settings for object updates.

    Returns
    -------
    result : dict
        The keyword-only arguments of this function after being updated.

    .. seealso:: :py:mod:`tike.ptycho`

    """
    for batch_index in randomizer.permutation(len(batches[0])):

        batch_data = comm.pool.map(get_batch, data, batches, n=batch_index)
        batch_scan = comm.pool.map(get_batch, scan, batches, n=batch_index)

        # A single call updates the object and probe for this batch.
        cost, psi, probe = _update_all(
            op,
            comm,
            batch_data,
            psi,
            batch_scan,
            probe,
            object_options,
            probe_options,
            algorithm_options,
        )

    if probe_options and probe_options.orthogonality_constraint:
        probe = comm.pool.map(orthogonalize_eig, probe)

    if object_options:
        psi = comm.pool.map(
            positivity_constraint,
            psi,
            r=object_options.positivity_constraint,
        )
        psi = comm.pool.map(
            smoothness_constraint,
            psi,
            a=object_options.smoothness_constraint,
        )

    algorithm_options.costs.append(cost)
    return {
        'probe': probe,
        'psi': psi,
        'scan': scan,
        'algorithm_options': algorithm_options,
        'probe_options': probe_options,
        'object_options': object_options,
        'position_options': position_options,
    }