コード例 #1
0
def _update_probe(
    op,
    comm,
    data,
    psi,
    scan,
    probe,
    num_iter,
    step_length,
    mode,
    probe_options,
):
    """Solve the probe recovery problem for the selected probe modes.

    Runs ``conjugate_gradient`` on the per-worker probes of ``comm.pool``.
    Only the modes selected by ``mode`` (along axis ``-3``) are stepped; the
    remaining modes are carried through unchanged. Afterwards the modes are
    orthogonalized when ``probe_options.orthogonality_constraint`` is set and
    there is more than one mode.

    Parameters
    ----------
    op
        Operator providing ``cost``, ``grad_probe`` and the array module
        ``xp``.
    comm
        Communicator providing ``pool`` (``map`` / ``bcast``), ``use_mpi``,
        ``reduce`` and ``Allreduce_reduce``.
    data, psi, scan, probe : list
        Per-worker inputs mapped over ``comm.pool``.
    num_iter : int
        Number of conjugate-gradient iterations.
    step_length : float
        Step length forwarded to ``conjugate_gradient``.
    mode
        Index expression for the mode axis (``-3``) of the probe; also
        forwarded to ``op.grad_probe``. Presumably a list of mode indices —
        confirm against callers.
    probe_options
        Read for ``orthogonality_constraint``; returned unchanged.

    Returns
    -------
    probe : list
        Updated per-worker probes, with the same shape as the input probes.
    cost
        Final cost reported by ``conjugate_gradient``.
    probe_options
        The ``probe_options`` argument.
    """

    def cost_function(probe):
        cost_out = comm.pool.map(op.cost, data, psi, scan, probe)
        if comm.use_mpi:
            return comm.Allreduce_reduce(cost_out, 'cpu')
        else:
            return comm.reduce(cost_out, 'cpu')

    def grad(probe):
        grad_list = comm.pool.map(
            op.grad_probe,
            data,
            psi,
            scan,
            probe,
            mode=mode,
        )
        if comm.use_mpi:
            return comm.Allreduce_reduce(grad_list, 'gpu')
        else:
            return comm.reduce(grad_list, 'gpu')

    def dir_multi(dir):
        """Broadcast ``dir`` to every worker of ``comm.pool``."""
        return comm.pool.bcast(dir)

    def update_multi(x, gamma, d):

        def f(x, d):
            # Step only the selected modes, but keep every mode in the
            # result so cost_function and grad always see full-size probes.
            # Copy first: conjugate_gradient may try several steps from the
            # same ``x`` (e.g. during a line search), so ``x`` must not be
            # modified in place.
            x = x.copy()
            x[..., mode, :, :] = x[..., mode, :, :] + gamma * d
            return x

        return comm.pool.map(f, x, d)

    probe, cost = conjugate_gradient(
        op.xp,
        x=probe,
        cost_function=cost_function,
        grad=grad,
        dir_multi=dir_multi,
        update_multi=update_multi,
        num_iter=num_iter,
        step_length=step_length,
    )

    if probe[0].shape[-3] > 1 and probe_options.orthogonality_constraint:
        probe = comm.pool.map(orthogonalize_gs, probe, axis=(-2, -1))

    logger.info('%10s cost is %+12.5e', 'probe', cost)
    return probe, cost, probe_options
コード例 #2
0
ファイル: divided.py プロジェクト: nikitinvv/tike
def update_probe(op, nearplane, probe, scan, psi, num_iter=1):
    """Recover a single probe from the nearplane wavefronts.

    The object patches are extracted once (with a unit probe) and shared by
    the cost and the gradient. The probe is then refined with
    ``conjugate_gradient``.

    Returns the updated probe and the final cost.
    """
    xp = op.xp
    patches = op.diffraction.fwd(
        psi=psi,
        scan=scan,
        probe=xp.ones_like(probe),
    )

    def residual(candidate):
        return candidate * patches - nearplane

    def cost_function(candidate):
        flat = xp.ravel(residual(candidate))
        return xp.linalg.norm(flat)**2

    def grad(candidate):
        # Average the gradient over all probe positions (axes 1 and 2).
        per_position = xp.conj(patches) * residual(candidate)
        return xp.mean(per_position, axis=(1, 2), keepdims=True)

    probe, cost = conjugate_gradient(
        xp,
        x=probe,
        cost_function=cost_function,
        grad=grad,
        num_iter=num_iter,
    )

    logger.info('%10s cost is %+12.5e', 'probe', cost)
    return probe, cost
コード例 #3
0
ファイル: divided.py プロジェクト: nikitinvv/tike
def update_object(op, nearplane, probe, scan, psi, num_iter=1):
    """Recover the object from the nearplane wavefronts.

    Minimizes ``||fwd(psi) - nearplane||^2`` with ``conjugate_gradient``,
    where ``fwd`` is ``op.diffraction.fwd`` with the fixed ``scan`` and
    ``probe``. Returns the updated object and the final cost.
    """
    xp = op.xp

    def residual(candidate):
        model = op.diffraction.fwd(psi=candidate, scan=scan, probe=probe)
        return model - nearplane

    def cost_function(candidate):
        return xp.linalg.norm(xp.ravel(residual(candidate)))**2

    def grad(candidate):
        return op.diffraction.adj(
            nearplane=residual(candidate),
            scan=scan,
            probe=probe,
        )

    psi, cost = conjugate_gradient(
        xp,
        x=psi,
        cost_function=cost_function,
        grad=grad,
        num_iter=num_iter,
    )

    logger.info('%10s cost is %+12.5e', 'object', cost)
    return psi, cost
コード例 #4
0
def update_object(op, num_gpu, data, psi, scan, probe, num_iter=1):
    """Solve the object recovery problem on one or several GPUs.

    With ``num_gpu <= 1`` the single-device ``op.cost`` / ``op.grad`` drive
    ``conjugate_gradient``. Otherwise the ``op.*_multi`` methods are used,
    and the direction-broadcast and update hooks are also supplied.

    Returns the updated object and the final cost.
    """
    if num_gpu <= 1:

        def cost_function(psi):
            return op.cost(data, psi, scan, probe)

        def grad(psi):
            return op.grad(data, psi, scan, probe)

        multi_gpu_hooks = {}
    else:

        def cost_function(psi, **kwargs):
            return op.cost_multi(num_gpu, data, psi, scan, probe, **kwargs)

        def grad(psi):
            return op.grad_multi(num_gpu, data, psi, scan, probe)

        def dir_multi(*args):
            return op.dir_multi(num_gpu, *args)

        def update_multi(psi, *args):
            return op.update_multi(num_gpu, psi, *args)

        multi_gpu_hooks = {
            'dir_multi': dir_multi,
            'update_multi': update_multi,
        }

    psi, cost = conjugate_gradient(
        op.xp,
        x=psi,
        cost_function=cost_function,
        grad=grad,
        num_gpu=num_gpu,
        num_iter=num_iter,
        **multi_gpu_hooks,
    )

    logger.info('%10s cost is %+12.5e', 'object', cost)
    return psi, cost
コード例 #5
0
    def run(self, tomo, obj, theta, num_iter,
            rho=1.0, tau=0.0, reg=0j, K=1 + 0j, **kwargs
    ):  # yapf: disable
        """Estimate `obj` with conjugate gradient.

        The minimized cost is
        ``rho * ||K * fwd(obj) - tomo||^2 + tau * ||tv(obj) - reg||^2``.

        Parameters
        ----------
        tomo : array-like float32
            Line integrals through the object.
        obj : array-like float32
            Initial guess of the object to be recovered.
        theta : array-like
            Forwarded unchanged to ``self.fwd`` and ``self.adj``.
        num_iter : int
            Number of conjugate-gradient steps.
        rho, tau : float32
            Weights of the data term and the total-variation term.
        reg : complex64
            Subtracted from the total-variation of ``obj`` in the cost.
        K : complex64
            Multiplies the forward projection in the data term.
        **kwargs
            Accepted and ignored.

        Returns
        -------
        dict
            ``{'obj': ...}`` holding what ``conjugate_gradient`` returned.
        """
        xp = self.array_module
        reg = xp.asarray(reg, dtype='complex64')
        K = xp.asarray(K, dtype='complex64')
        K_conj = xp.conj(K, dtype='complex64')

        def cost_function(obj):
            data_misfit = K * self.fwd(obj=obj, theta=theta) - tomo
            tv_misfit = tv.fwd(xp, obj) - reg
            return (rho * xp.square(xp.linalg.norm(data_misfit))
                    + tau * xp.square(xp.linalg.norm(tv_misfit)))

        def grad(obj):
            model = K * self.fwd(obj, theta=theta)
            data_term = self.adj(K_conj * (model - tomo), theta=theta)
            tv_term = tv.adj(xp, tv.fwd(xp, obj) - reg)
            return rho * data_term + tau * tv_term

        result = conjugate_gradient(
            self.array_module,
            x=obj,
            cost_function=cost_function,
            grad=grad,
            num_iter=num_iter,
        )

        # NOTE(review): other solver snippets unpack
        # ``x, cost = conjugate_gradient(...)``. If this version also returns
        # a tuple, 'obj' below holds that tuple — confirm.
        return {'obj': result}
コード例 #6
0
def _update_object(
    op,
    comm,
    data,
    psi,
    scan,
    probe,
    num_iter,
    step_length,
    object_options,
):
    """Solve the object recovery problem across the workers of ``comm.pool``.

    Per-worker costs and gradients (``op.cost`` / ``op.grad_psi``) are
    combined with ``comm.Allreduce_reduce`` when ``comm.use_mpi`` is set and
    with ``comm.reduce`` otherwise, then fed to ``conjugate_gradient``.

    Returns
    -------
    psi : list
        The updated per-worker objects.
    cost
        Final cost reported by ``conjugate_gradient``.
    object_options
        The ``object_options`` argument, unchanged.
    """

    def combine(per_worker, target):
        # comm.Allreduce_reduce under MPI, comm.reduce otherwise.
        reduce_fn = comm.Allreduce_reduce if comm.use_mpi else comm.reduce
        return reduce_fn(per_worker, target)

    def cost_function_multi(psi, **kwargs):
        per_worker = comm.pool.map(op.cost, data, psi, scan, probe)
        return combine(per_worker, 'cpu')

    def grad_multi(psi):
        per_worker = comm.pool.map(op.grad_psi, data, psi, scan, probe)
        return combine(per_worker, 'gpu')

    def dir_multi(dir):
        """Broadcast ``dir`` to every worker of ``comm.pool``."""
        return comm.pool.bcast(dir)

    def update_multi(psi, gamma, dir):

        def take_step(psi, dir):
            return psi + gamma * dir

        return list(comm.pool.map(take_step, psi, dir))

    psi, cost = conjugate_gradient(
        op.xp,
        x=psi,
        cost_function=cost_function_multi,
        grad=grad_multi,
        dir_multi=dir_multi,
        update_multi=update_multi,
        num_iter=num_iter,
        step_length=step_length,
    )

    logger.info('%10s cost is %+12.5e', 'object', cost)
    return psi, cost, object_options
コード例 #7
0
ファイル: admm.py プロジェクト: xiaodong-yu/tike
def update_nearplane(
    op, nearplane, farplane, probe, psi, scan,
    ρ, λ, τ, μ, num_iter=1,
):  # yapf: disable
    """Solve the nearplane sub-problem.

    Minimizes, over ``nearplane``::

        ρ ||P(nearplane) - farplane + λ/ρ||² + τ ||D - nearplane + μ/τ||²

    where ``P`` is ``op.propagation.fwd`` and ``D`` is
    ``op.diffraction.fwd(probe=probe, psi=psi, scan=scan)``.

    Returns the updated nearplane and the final cost.
    """
    xp = op.xp
    nearplane0 = op.diffraction.fwd(probe=probe, psi=psi, scan=scan)

    def far_residual(x):
        return op.propagation.fwd(x) - farplane + λ / ρ

    def near_residual(x):
        return nearplane0 - x + μ / τ

    def cost_function(x):
        far = xp.linalg.norm(xp.ravel(far_residual(x)))**2
        near = xp.linalg.norm(xp.ravel(near_residual(x)))**2
        return ρ * far + τ * near

    def grad(x):
        return ρ * op.propagation.adj(far_residual(x)) - τ * near_residual(x)

    nearplane, cost = conjugate_gradient(
        op.xp,
        x=nearplane,
        cost_function=cost_function,
        grad=grad,
        num_iter=num_iter,
    )

    # Log the final cost as a sanity check.
    logger.info('%10s cost is %+12.5e', 'nearplane', cost)
    return nearplane, cost
コード例 #8
0
ファイル: combined.py プロジェクト: OrkoHunter/tike
def _update_probe(op, comm, data, psi, scan, probe, num_iter=1):
    """Solve the probe recovery problem one probe mode at a time.

    For each mode ``m`` along axis ``-3`` of ``probe``, run
    ``conjugate_gradient`` on ``probe[..., m:m + 1, :, :]`` while the other
    modes' intensities are held fixed, then write the result back into
    ``probe`` in place.

    Parameters
    ----------
    op
        Operator providing ``propagation`` (``cost`` / ``grad``),
        ``adj_probe`` and the array module ``xp``.
    comm
        Not used in this function body.
    data
        Passed to ``op.propagation.cost`` / ``op.propagation.grad``.
    psi, scan
        Passed to ``_compute_intensity`` and ``op.adj_probe``.
    probe
        Array whose axis ``-3`` indexes the probe modes; updated in place.
    num_iter : int
        Conjugate-gradient iterations per mode.

    Returns
    -------
    probe
        The same ``probe`` array, updated in place.
    cost
        The cost returned for the last mode updated.
        NOTE(review): unbound (NameError) if ``probe.shape[-3] == 0``.
    """

    # TODO: Cache object patches between mode updates
    # Intensity of each mode computed separately; the cost and gradient
    # below use their sum over all modes (axis 0).
    intensity = [
        _compute_intensity(op, psi, scan, probe[..., m:m + 1, :, :])[0]
        for m in range(probe.shape[-3])
    ]
    intensity = op.xp.array(intensity)

    for m in range(probe.shape[-3]):

        # These closures capture ``m``. They are only passed to the
        # conjugate_gradient call in this same iteration, so late binding
        # of ``m`` is not a problem.
        def cost_function(mode):
            # Side effect: replace mode m's entry in ``intensity`` with the
            # candidate's, so the other modes stay fixed in the summed model.
            intensity[m], _ = _compute_intensity(op, psi, scan, mode)
            return op.propagation.cost(data, op.xp.sum(intensity, axis=0))

        def grad(mode):
            # Same side effect on ``intensity[m]`` as in cost_function.
            intensity[m], farplane = _compute_intensity(op, psi, scan, mode)
            # Use the average gradient for all probe positions
            return op.xp.mean(
                op.adj_probe(
                    farplane=op.propagation.grad(
                        data,
                        farplane,
                        op.xp.sum(intensity, axis=0),
                    ),
                    psi=psi,
                    scan=scan,
                    overwrite=True,
                ),
                axis=1,
                keepdims=True,
            )

        # Refine mode m and write it back into ``probe`` in place.
        # NOTE(review): step_length=4 is a fixed tuning constant here.
        probe[..., m:m + 1, :, :], cost = conjugate_gradient(
            op.xp,
            x=probe[..., m:m + 1, :, :],
            cost_function=cost_function,
            grad=grad,
            num_iter=num_iter,
            step_length=4,
        )

    logger.info('%10s cost is %+12.5e', 'probe', cost)
    return probe, cost
コード例 #9
0
ファイル: cgrad.py プロジェクト: nikitinvv/tike
def update_obj(op, data, obj, num_iter=1):
    """Solve the object recovery problem.

    Refines ``obj`` with ``conjugate_gradient`` using ``op.cost`` and
    ``op.grad`` bound to ``data``. Returns the updated object and the final
    cost.
    """
    obj, cost = conjugate_gradient(
        op.xp,
        x=obj,
        cost_function=lambda candidate: op.cost(data, candidate),
        grad=lambda candidate: op.grad(data, candidate),
        num_iter=num_iter,
    )

    logger.info('%10s cost is %+12.5e', 'object', cost)
    return obj, cost
コード例 #10
0
ファイル: divided.py プロジェクト: nikitinvv/tike
def update_phase(op, data, farplane, num_iter=1):
    """Solve the farplane phase problem.

    Refines ``farplane`` with ``conjugate_gradient`` using
    ``op.propagation.cost`` / ``op.propagation.grad`` against ``data``.
    Returns the updated farplane and the final cost.
    """
    farplane, cost = conjugate_gradient(
        op.xp,
        x=farplane,
        cost_function=lambda candidate: op.propagation.cost(data, candidate),
        grad=lambda candidate: op.propagation.grad(data, candidate),
        num_iter=num_iter,
    )

    # Log the final cost as a sanity check.
    logger.info('%10s cost is %+12.5e', 'farplane', cost)
    return farplane, cost
コード例 #11
0
def update_obj(op, comm, data, theta, obj, num_iter=1, step_length=1):
    """Solve the object recovery problem across the workers of ``comm.pool``.

    Per-worker ``op.cost`` / ``op.grad`` results are combined with
    ``comm.Allreduce_reduce`` when ``comm.use_mpi`` is set and with
    ``comm.reduce`` otherwise. Returns the updated object and the final cost.
    """

    def combine(per_worker, target):
        # comm.Allreduce_reduce under MPI, comm.reduce otherwise.
        if comm.use_mpi:
            return comm.Allreduce_reduce(per_worker, target)
        return comm.reduce(per_worker, target)

    def cost_function(obj):
        return combine(comm.pool.map(op.cost, data, theta, obj), 'cpu')

    def grad(obj):
        return combine(comm.pool.map(op.grad, data, theta, obj), 'gpu')

    def dir_multi(dir):
        """Broadcast ``dir`` to every worker of ``comm.pool``."""
        return comm.pool.bcast(dir)

    def update_multi(x, gamma, dir):

        def take_step(x, dir):
            return x + gamma * dir

        return comm.pool.map(take_step, x, dir)

    obj, cost = conjugate_gradient(
        op.xp,
        x=obj,
        cost_function=cost_function,
        grad=grad,
        dir_multi=dir_multi,
        update_multi=update_multi,
        num_iter=num_iter,
        step_length=step_length,
    )

    logger.info('%10s cost is %+12.5e', 'object', cost)
    return obj, cost
コード例 #12
0
def update_probe(op, num_gpu, data, psi, scan, probe, num_iter=1):
    """Solve the probe recovery problem one probe mode at a time.

    Each mode (axis ``-3`` of ``probe``) is refined in turn with
    ``conjugate_gradient`` and written back into ``probe``. With more than
    one GPU, ``scan`` and ``data`` go through ``op.asarray_multi_fuse``, only
    ``psi[0]`` / ``probe[0]`` are used, and the updated probe is passed
    through ``op.asarray_multi`` at the end.

    Returns the updated probe and the cost of the last mode updated.
    """
    use_multi_gpu = num_gpu > 1

    # TODO: add multi-GPU support
    if use_multi_gpu:
        scan = op.asarray_multi_fuse(num_gpu, scan)
        data = op.asarray_multi_fuse(num_gpu, data)
        psi, probe = psi[0], probe[0]

    # TODO: Cache object patches between mode updates
    for m in range(probe.shape[-3]):

        def cost_function(candidate):
            return op.cost(data, psi, scan, probe, m, candidate)

        def grad(candidate):
            # Average the gradient over all probe positions (axes 1 and 2).
            per_position = op.grad_probe(data, psi, scan, probe, m, candidate)
            return op.xp.mean(per_position, axis=(1, 2), keepdims=True)

        probe[..., m:m + 1, :, :], cost = conjugate_gradient(
            op.xp,
            x=probe[..., m:m + 1, :, :],
            cost_function=cost_function,
            grad=grad,
            num_iter=num_iter,
        )

    if use_multi_gpu:
        probe = op.asarray_multi(num_gpu, probe)
        del scan, data

    logger.info('%10s cost is %+12.5e', 'probe', cost)
    return probe, cost
コード例 #13
0
ファイル: admm.py プロジェクト: xiaodong-yu/tike
def update_phase(op, data, farplane, nearplane, ρ, λ, num_iter=1):
    """Solve the farplane phase sub-problem.

    Minimizes ``cost(data, farplane) + ρ ||P(nearplane) - farplane + λ/ρ||²``
    over ``farplane``, where ``P`` is ``op.propagation.fwd``. Returns the
    updated farplane and the final cost.
    """
    xp = op.xp
    farplane0 = op.propagation.fwd(nearplane)

    def penalty_residual(x):
        return farplane0 - x + λ / ρ

    def cost_function(x):
        penalty = xp.linalg.norm(xp.ravel(penalty_residual(x)))**2
        return op.propagation.cost(data, x) + ρ * penalty

    def grad(x):
        return op.propagation.grad(data, x) - ρ * penalty_residual(x)

    farplane, cost = conjugate_gradient(
        op.xp,
        x=farplane,
        cost_function=cost_function,
        grad=grad,
        num_iter=num_iter,
    )

    # Log the final cost as a sanity check.
    logger.info('%10s cost is %+12.5e', 'farplane', cost)
    return farplane, cost
コード例 #14
0
ファイル: combined.py プロジェクト: nikitinvv/tike
def update_object(op, pool, num_gpu, data, psi, scan, probe, num_iter=1):
    """Solve the object recovery problem on one or several GPUs.

    With ``num_gpu <= 1`` the single-device ``op.cost`` / ``op.grad`` drive
    ``conjugate_gradient``. Otherwise per-device results from ``pool.map``
    are combined: costs are summed on the host, and gradients are summed
    into the first device's gradient.

    Parameters
    ----------
    op
        Operator providing ``cost``, ``grad``, ``asnumpy``, ``asarray`` and
        the array module ``xp``.
    pool
        Worker pool providing ``map`` and ``bcast``.
    num_gpu : int
        Number of GPUs; selects the single- or multi-device path.
    data, psi, scan, probe
        Inputs to ``op.cost`` / ``op.grad`` (per-device sequences on the
        multi-device path).
    num_iter : int
        Number of conjugate-gradient iterations.

    Returns
    -------
    psi
        The updated object.
    cost
        Final cost reported by ``conjugate_gradient``.
    """

    def cost_function(psi):
        return op.cost(data, psi, scan, probe)

    def grad(psi):
        return op.grad(data, psi, scan, probe)

    def cost_function_multi(psi, **kwargs):
        # TODO: Implement reduce function for ThreadPool
        # Sum every device's cost on the host.
        cost_out = pool.map(op.cost, data, psi, scan, probe)
        return sum(op.asnumpy(c) for c in cost_out)

    def grad_multi(psi):
        grad_list = list(pool.map(op.grad, data, psi, scan, probe))
        # Accumulate every returned gradient into the first one, matching
        # cost_function_multi, which also uses all results. The copy is staged
        # through host memory (asnumpy -> asarray).
        # TODO: Implement optimal reduce in ThreadPool, e.g. use direct
        # peer-to-peer copies when the devices allow peer access.
        total = grad_list[0]
        for other in grad_list[1:]:
            total += op.asarray(op.asnumpy(other))
        return total

    def dir_multi(dir):
        """Broadcast ``dir`` to every worker of ``pool``."""
        return pool.bcast(dir)

    def update_multi(psi, gamma, dir):

        def f(psi, dir):
            return psi + gamma * dir

        return list(pool.map(f, psi, dir))

    if num_gpu <= 1:
        psi, cost = conjugate_gradient(
            op.xp,
            x=psi,
            cost_function=cost_function,
            grad=grad,
            num_gpu=num_gpu,
            num_iter=num_iter,
        )
    else:
        psi, cost = conjugate_gradient(
            op.xp,
            x=psi,
            cost_function=cost_function_multi,
            grad=grad_multi,
            dir_multi=dir_multi,
            update_multi=update_multi,
            num_gpu=num_gpu,
            num_iter=num_iter,
        )

    logger.info('%10s cost is %+12.5e', 'object', cost)
    return psi, cost
コード例 #15
0
def update_obj(
    op,
    comm,
    data, theta, obj, grid,
    obj_split,
    fwd_op,
    num_iter=1,
    step_length=1,
):
    """Solve the object recovery problem with a worker-split object.

    Runs ``conjugate_gradient`` using a custom Dai-Yuan direction. The
    ``obj_split`` value controls which workers take part:

    - the cost uses every ``obj_split``-th worker;
    - the direction uses the first ``obj_split`` workers;
    - ``obj_split`` is also forwarded to ``comm.pool.bcast`` and
      ``comm.reduce``.

    Parameters
    ----------
    op
        Operator providing ``cost``, ``grad`` and the array module ``xp``.
    comm
        Communicator providing ``pool``, ``use_mpi``, ``reduce`` and
        ``Allreduce_reduce``.
    data, theta, grid
        Per-worker inputs; ``theta`` and ``grid`` are only passed to
        ``op.grad``.
    obj
        Initial object (per-worker pieces).
    obj_split : int
        Worker stride/count described above.
    fwd_op : callable
        Maps the object to the forward data consumed by ``op.cost`` and
        ``op.grad``.
    num_iter : int
        Number of conjugate-gradient iterations.
    step_length : float
        Step length forwarded to ``conjugate_gradient``.

    Returns
    -------
    obj
        The updated object.
    cost
        Final cost reported by ``conjugate_gradient``.
    """

    def cost_function(obj):
        fwd_data = fwd_op(obj)
        # Only every obj_split-th worker (and the matching slices of data /
        # fwd_data) contributes. NOTE(review): presumably the skipped
        # workers would duplicate these terms — confirm.
        workers = comm.pool.workers[::obj_split]
        cost_out = comm.pool.map(
            op.cost,
            data[::obj_split],
            fwd_data[::obj_split],
            workers=workers,
        )
        if comm.use_mpi:
            return comm.Allreduce_reduce(cost_out, 'cpu')
        else:
            return comm.reduce(cost_out, 'cpu')

    def grad(obj):
        fwd_data = fwd_op(obj)
        grad_list = comm.pool.map(op.grad, data, theta, fwd_data, grid)
        # NOTE(review): unlike cost_function and direction_dy, this never
        # calls comm.Allreduce_reduce, so under MPI the gradient is not
        # combined across ranks here — confirm this is intended.
        return comm.reduce(grad_list, 'gpu', s=obj_split)

    def direction_dy(xp, grad1, grad0=None, dir_=None):
        """Return the Dai-Yuan search direction.

        When ``dir_`` is None, return ``-grad1`` (steepest descent).
        Otherwise return
        ``-grad1 + dir_ * ||grad1||^2 / (sum(conj(dir_) * (grad1 - grad0))
        + 1e-32)``. The squared norm is reduced over workers (via
        Allreduce_reduce under MPI). The inner product in the denominator is
        computed per worker. Only the first ``obj_split`` workers compute
        the direction.
        """

        def init(grad1):
            return -grad1

        def f(grad1):
            # Squared L2 norm of this worker's gradient.
            return xp.linalg.norm(grad1.ravel())**2

        def d(grad0, grad1, dir_, norm_):
            # The 1e-32 keeps the denominator from being exactly zero.
            return (
                - grad1
                + dir_ * norm_
                / (xp.sum(dir_.conj() * (grad1 - grad0)) + 1e-32)
            )  # yapf: disable

        workers = comm.pool.workers[:obj_split]

        if dir_ is None:
            return comm.pool.map(init, grad1, workers=workers)

        n = comm.pool.map(f, grad1, workers=workers)
        if comm.use_mpi:
            norm_ = comm.Allreduce_reduce(n, 'cpu')
        else:
            norm_ = comm.reduce(n, 'cpu')
        return comm.pool.map(
            d,
            grad0,
            grad1,
            dir_,
            norm_=norm_,
            workers=workers,
        )

    def dir_multi(dir):
        """Broadcast ``dir`` via ``comm.pool.bcast``, forwarding obj_split."""
        return comm.pool.bcast(dir, obj_split)

    def update_multi(x, gamma, dir):

        # Take a step of size gamma along dir on every worker.
        def f(x, dir):
            return x + gamma * dir

        return comm.pool.map(f, x, dir)

    obj, cost = conjugate_gradient(
        op.xp,
        x=obj,
        cost_function=cost_function,
        grad=grad,
        direction_dy=direction_dy,
        dir_multi=dir_multi,
        update_multi=update_multi,
        num_iter=num_iter,
        step_length=step_length,
    )

    logger.info('%10s cost is %+12.5e', 'object', cost)
    return obj, cost