def _update_probe(
    op,
    comm,
    data,
    psi,
    scan,
    probe,
    num_iter,
    step_length,
    mode,
    probe_options,
):
    """Solve the probe recovery problem for the selected probe modes.

    The probe is refined with `conjugate_gradient`. Only the entries
    ``probe[..., mode, :, :]`` are stepped; every other mode is carried
    through unchanged, so the returned probe keeps its full shape.

    Per-worker costs and gradients are combined with
    ``comm.Allreduce_reduce`` when ``comm.use_mpi`` is set and with
    ``comm.reduce`` otherwise.

    Parameters
    ----------
    op : object
        Operator providing ``xp``, ``cost`` and ``grad_probe``.
    comm : object
        Communicator providing ``pool``, ``use_mpi``, ``reduce`` and
        ``Allreduce_reduce``.
    data, psi, scan, probe : sequence
        Per-worker inputs mapped over ``comm.pool``.
    num_iter : int
        Number of conjugate-gradient iterations.
    step_length : float
        Step length forwarded to `conjugate_gradient`.
    mode : index
        Indices along axis -3 of each probe that are updated. The same
        value is forwarded to ``op.grad_probe``.
    probe_options : object
        When ``probe_options.orthogonality_constraint`` is true and there
        is more than one mode, the modes are orthogonalized with
        `orthogonalize_gs` after the solve.

    Returns
    -------
    probe : sequence
        The updated per-worker probes.
    cost
        The cost reported by `conjugate_gradient`.
    probe_options : object
        Returned unchanged.
    """

    def cost_function(probe):
        cost_out = comm.pool.map(op.cost, data, psi, scan, probe)
        if comm.use_mpi:
            return comm.Allreduce_reduce(cost_out, 'cpu')
        else:
            return comm.reduce(cost_out, 'cpu')

    def grad(probe):
        # `mode` is forwarded to op.grad_probe. The result is expected to
        # match x[..., mode, :, :] in shape (see update_multi below).
        grad_list = comm.pool.map(
            op.grad_probe,
            data,
            psi,
            scan,
            probe,
            mode=mode,
        )
        if comm.use_mpi:
            return comm.Allreduce_reduce(grad_list, 'gpu')
        else:
            return comm.reduce(grad_list, 'gpu')

    def dir_multi(dir):
        """Scatter dir to all GPUs"""
        return comm.pool.bcast(dir)

    def update_multi(x, gamma, d):

        def f(x, d):
            # Step only the selected modes, but return the FULL probe.
            # Returning just x[..., mode, :, :] would silently drop every
            # unselected mode whenever `mode` is a subset. The update is
            # done on a copy, so repeated calls with the same `x` (e.g.
            # trial step lengths) never compound.
            updated = x.copy()
            updated[..., mode, :, :] = x[..., mode, :, :] + gamma * d
            return updated

        return comm.pool.map(f, x, d)

    probe, cost = conjugate_gradient(
        op.xp,
        x=probe,
        cost_function=cost_function,
        grad=grad,
        dir_multi=dir_multi,
        update_multi=update_multi,
        num_iter=num_iter,
        step_length=step_length,
    )

    if probe[0].shape[-3] > 1 and probe_options.orthogonality_constraint:
        probe = comm.pool.map(orthogonalize_gs, probe, axis=(-2, -1))

    logger.info('%10s cost is %+12.5e', 'probe', cost)
    return probe, cost, probe_options
def update_probe(op, nearplane, probe, scan, psi, num_iter=1):
    """Recover a single probe that best explains the nearplane estimate.

    The nearplane model is ``probe * patches``. Here ``patches`` are the
    object patches produced by ``op.diffraction.fwd`` with a probe of ones.
    The cost is the squared 2-norm of the model minus `nearplane`.

    Returns
    -------
    probe
        The refined probe from `conjugate_gradient`.
    cost
        The cost reported by `conjugate_gradient`.
    """
    xp = op.xp
    # Illuminating with a unit probe isolates the object contribution, so
    # the probe enters the model as a plain elementwise factor.
    patches = op.diffraction.fwd(
        psi=psi,
        scan=scan,
        probe=xp.ones_like(probe),
    )

    def cost_function(candidate):
        misfit = candidate * patches - nearplane
        return xp.linalg.norm(xp.ravel(misfit))**2

    def grad(candidate):
        misfit = candidate * patches - nearplane
        # A single probe serves every scan position, so the per-position
        # gradients are averaged over axes 1 and 2.
        return xp.mean(
            xp.conj(patches) * misfit,
            axis=(1, 2),
            keepdims=True,
        )

    probe, cost = conjugate_gradient(
        xp,
        x=probe,
        cost_function=cost_function,
        grad=grad,
        num_iter=num_iter,
    )

    logger.info('%10s cost is %+12.5e', 'probe', cost)
    return probe, cost
def update_object(op, nearplane, probe, scan, psi, num_iter=1):
    """Recover the object that best explains the nearplane estimate.

    The probe is held fixed. The cost is the squared 2-norm of
    ``op.diffraction.fwd(psi, scan, probe) - nearplane``. Its gradient is
    taken with ``op.diffraction.adj`` applied to that residual.

    Returns
    -------
    psi
        The refined object from `conjugate_gradient`.
    cost
        The cost reported by `conjugate_gradient`.
    """
    xp = op.xp

    def residual(estimate):
        modeled = op.diffraction.fwd(psi=estimate, scan=scan, probe=probe)
        return modeled - nearplane

    def cost_function(estimate):
        return xp.linalg.norm(xp.ravel(residual(estimate)))**2

    def grad(estimate):
        return op.diffraction.adj(
            nearplane=residual(estimate),
            scan=scan,
            probe=probe,
        )

    psi, cost = conjugate_gradient(
        xp,
        x=psi,
        cost_function=cost_function,
        grad=grad,
        num_iter=num_iter,
    )

    logger.info('%10s cost is %+12.5e', 'object', cost)
    return psi, cost
def update_object(op, num_gpu, data, psi, scan, probe, num_iter=1):
    """Solve the object recovery problem with conjugate gradient.

    With ``num_gpu <= 1`` the single-device ``op.cost`` / ``op.grad`` are
    used. Otherwise the multi-device helpers ``op.cost_multi``,
    ``op.grad_multi``, ``op.dir_multi`` and ``op.update_multi`` are bound to
    `num_gpu` and handed to `conjugate_gradient`.

    Returns
    -------
    psi
        The refined object.
    cost
        The cost reported by `conjugate_gradient`.
    """
    if num_gpu > 1:
        solver_hooks = {
            'cost_function':
                lambda x, **kwargs: op.cost_multi(
                    num_gpu, data, x, scan, probe, **kwargs),
            'grad':
                lambda x: op.grad_multi(num_gpu, data, x, scan, probe),
            'dir_multi':
                lambda *args: op.dir_multi(num_gpu, *args),
            'update_multi':
                lambda x, *args: op.update_multi(num_gpu, x, *args),
        }
    else:
        solver_hooks = {
            'cost_function': lambda x: op.cost(data, x, scan, probe),
            'grad': lambda x: op.grad(data, x, scan, probe),
        }

    psi, cost = conjugate_gradient(
        op.xp,
        x=psi,
        num_gpu=num_gpu,
        num_iter=num_iter,
        **solver_hooks,
    )

    logger.info('%10s cost is %+12.5e', 'object', cost)
    return psi, cost
def run(self, tomo, obj, theta, num_iter,
        rho=1.0, tau=0.0, reg=0j, K=1 + 0j, **kwargs
):  # yapf: disable
    """Use conjugate gradient to estimate `obj`.

    The minimized cost is::

        rho * ||K * fwd(obj, theta) - tomo||^2 + tau * ||tv(obj) - reg||^2

    where ``tv`` is ``tv.fwd`` and ``fwd`` is ``self.fwd``.

    Parameters
    ----------
    tomo : array-like
        Line integrals through the object.
    obj : array-like
        Starting guess for the object to be recovered.
    theta : array-like
        Forwarded to ``self.fwd`` and ``self.adj``.
    num_iter : int
        Number of conjugate-gradient steps.
    rho, tau : float
        Weights of the data term and the total-variation term.
    reg : complex
        Target of the total-variation term; converted to complex64.
    K : complex
        Factor applied to the forward projection; converted to complex64.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    dict
        ``{'obj': ...}`` holding whatever `conjugate_gradient` returned.
    """
    xp = self.array_module
    reg = xp.asarray(reg, dtype='complex64')
    K = xp.asarray(K, dtype='complex64')
    K_conj = xp.conj(K, dtype='complex64')

    def cost_function(obj):
        projected = K * self.fwd(obj=obj, theta=theta)
        data_term = rho * xp.square(xp.linalg.norm(projected - tomo))
        tv_term = tau * xp.square(xp.linalg.norm(tv.fwd(xp, obj) - reg))
        return data_term + tv_term

    def grad(obj):
        projected = K * self.fwd(obj, theta=theta)
        data_grad = rho * self.adj(K_conj * (projected - tomo), theta=theta)
        tv_grad = tau * tv.adj(xp, tv.fwd(xp, obj) - reg)
        return data_grad + tv_grad

    estimate = conjugate_gradient(
        self.array_module,
        x=obj,
        cost_function=cost_function,
        grad=grad,
        num_iter=num_iter,
    )
    return {'obj': estimate}
def _update_object(
    op,
    comm,
    data,
    psi,
    scan,
    probe,
    num_iter,
    step_length,
    object_options,
):
    """Solve the object recovery problem across the workers of ``comm``.

    ``op.cost`` and ``op.grad_psi`` are mapped over ``comm.pool``. The
    per-worker results are combined with ``comm.Allreduce_reduce`` under MPI
    and with ``comm.reduce`` otherwise, then minimized by
    `conjugate_gradient`.

    Returns
    -------
    psi
        The refined per-worker objects.
    cost
        The cost reported by `conjugate_gradient`.
    object_options
        Returned unchanged.
    """

    def combine(per_worker, device):
        # Reduce per-worker results, across MPI ranks too when enabled.
        if comm.use_mpi:
            return comm.Allreduce_reduce(per_worker, device)
        return comm.reduce(per_worker, device)

    def cost_function_multi(psi, **kwargs):
        return combine(comm.pool.map(op.cost, data, psi, scan, probe), 'cpu')

    def grad_multi(psi):
        return combine(
            comm.pool.map(op.grad_psi, data, psi, scan, probe),
            'gpu',
        )

    def dir_multi(direction):
        """Scatter the search direction to all GPUs."""
        return comm.pool.bcast(direction)

    def update_multi(psi, gamma, direction):

        def step(part, part_direction):
            return part + gamma * part_direction

        return list(comm.pool.map(step, psi, direction))

    psi, cost = conjugate_gradient(
        op.xp,
        x=psi,
        cost_function=cost_function_multi,
        grad=grad_multi,
        dir_multi=dir_multi,
        update_multi=update_multi,
        num_iter=num_iter,
        step_length=step_length,
    )

    logger.info('%10s cost is %+12.5e', 'object', cost)
    return psi, cost, object_options
def update_nearplane(
    op,
    nearplane, farplane, probe, psi, scan,
    ρ, λ, τ, μ,
    num_iter=1,
):  # yapf: disable
    """Solve the nearplane problem.

    With ``P = op.propagation.fwd`` and
    ``n0 = op.diffraction.fwd(probe, psi, scan)``, the minimized cost is::

        ρ * ||P(n) - farplane + λ/ρ||^2 + τ * ||n0 - n + μ/τ||^2

    Returns
    -------
    nearplane
        The refined nearplane.
    cost
        The cost reported by `conjugate_gradient`.
    """
    xp = op.xp
    nearplane0 = op.diffraction.fwd(probe=probe, psi=psi, scan=scan)

    def far_gap(estimate):
        # Mismatch between the propagated estimate and the farplane.
        return op.propagation.fwd(estimate) - farplane + λ / ρ

    def near_gap(estimate):
        # Mismatch between the estimate and the modeled nearplane.
        return nearplane0 - estimate + μ / τ

    def cost_function(estimate):
        far_term = ρ * xp.linalg.norm(xp.ravel(far_gap(estimate)))**2
        near_term = τ * xp.linalg.norm(xp.ravel(near_gap(estimate)))**2
        return far_term + near_term

    def grad(estimate):
        far_part = ρ * op.propagation.adj(far_gap(estimate))
        near_part = τ * near_gap(estimate)
        return far_part - near_part

    nearplane, cost = conjugate_gradient(
        xp,
        x=nearplane,
        cost_function=cost_function,
        grad=grad,
        num_iter=num_iter,
    )

    # Report the final cost as a sanity check.
    logger.info('%10s cost is %+12.5e', 'nearplane', cost)
    return nearplane, cost
def _update_probe(op, comm, data, psi, scan, probe, num_iter=1): """Solve the probe recovery problem.""" # TODO: Cache object patches between mode updates intensity = [ _compute_intensity(op, psi, scan, probe[..., m:m + 1, :, :])[0] for m in range(probe.shape[-3]) ] intensity = op.xp.array(intensity) for m in range(probe.shape[-3]): def cost_function(mode): intensity[m], _ = _compute_intensity(op, psi, scan, mode) return op.propagation.cost(data, op.xp.sum(intensity, axis=0)) def grad(mode): intensity[m], farplane = _compute_intensity(op, psi, scan, mode) # Use the average gradient for all probe positions return op.xp.mean( op.adj_probe( farplane=op.propagation.grad( data, farplane, op.xp.sum(intensity, axis=0), ), psi=psi, scan=scan, overwrite=True, ), axis=1, keepdims=True, ) probe[..., m:m + 1, :, :], cost = conjugate_gradient( op.xp, x=probe[..., m:m + 1, :, :], cost_function=cost_function, grad=grad, num_iter=num_iter, step_length=4, ) logger.info('%10s cost is %+12.5e', 'probe', cost) return probe, cost
def update_obj(op, data, obj, num_iter=1):
    """Solve the object recovery problem.

    ``op.cost(data, obj)`` is minimized with `conjugate_gradient`, using
    ``op.grad(data, obj)`` as its gradient.

    Returns
    -------
    obj
        The refined object.
    cost
        The cost reported by `conjugate_gradient`.
    """
    obj, cost = conjugate_gradient(
        op.xp,
        x=obj,
        cost_function=lambda estimate: op.cost(data, estimate),
        grad=lambda estimate: op.grad(data, estimate),
        num_iter=num_iter,
    )

    logger.info('%10s cost is %+12.5e', 'object', cost)
    return obj, cost
def update_phase(op, data, farplane, num_iter=1):
    """Solve the farplane phase problem.

    ``op.propagation.cost(data, farplane)`` is minimized with
    `conjugate_gradient`, using ``op.propagation.grad`` as its gradient.

    Returns
    -------
    farplane
        The refined farplane.
    cost
        The cost reported by `conjugate_gradient`.
    """
    propagation = op.propagation

    farplane, cost = conjugate_gradient(
        op.xp,
        x=farplane,
        cost_function=lambda estimate: propagation.cost(data, estimate),
        grad=lambda estimate: propagation.grad(data, estimate),
        num_iter=num_iter,
    )

    # Report the final cost as a sanity check.
    logger.info('%10s cost is %+12.5e', 'farplane', cost)
    return farplane, cost
def update_obj(op, comm, data, theta, obj, num_iter=1, step_length=1):
    """Solve the object recovery problem across the workers of ``comm``.

    ``op.cost`` and ``op.grad`` are mapped over ``comm.pool``. Their
    per-worker results are combined with ``comm.Allreduce_reduce`` under MPI
    and with ``comm.reduce`` otherwise.

    Returns
    -------
    obj
        The refined per-worker objects.
    cost
        The cost reported by `conjugate_gradient`.
    """

    def combine(per_worker, device):
        # Reduce per-worker results, across MPI ranks too when enabled.
        if comm.use_mpi:
            return comm.Allreduce_reduce(per_worker, device)
        return comm.reduce(per_worker, device)

    def cost_function(obj):
        return combine(comm.pool.map(op.cost, data, theta, obj), 'cpu')

    def grad(obj):
        return combine(comm.pool.map(op.grad, data, theta, obj), 'gpu')

    def dir_multi(direction):
        """Scatter the search direction to all GPUs."""
        return comm.pool.bcast(direction)

    def update_multi(x, gamma, direction):

        def step(part, part_direction):
            return part + gamma * part_direction

        return comm.pool.map(step, x, direction)

    obj, cost = conjugate_gradient(
        op.xp,
        x=obj,
        cost_function=cost_function,
        grad=grad,
        dir_multi=dir_multi,
        update_multi=update_multi,
        num_iter=num_iter,
        step_length=step_length,
    )

    logger.info('%10s cost is %+12.5e', 'object', cost)
    return obj, cost
def update_probe(op, num_gpu, data, psi, scan, probe, num_iter=1):
    """Recover the probe one mode at a time.

    Each mode (index along axis -3 of the probe) is refined by
    `conjugate_gradient` while the other modes stay fixed. The refined slice
    is written back into the probe array in place. The gradient for each
    mode is averaged over axes 1 and 2 so one update serves every position.

    There is no real multi-GPU path yet. When ``num_gpu > 1``, `scan` and
    `data` are fused with ``op.asarray_multi_fuse`` and only the first entry
    of `psi` and `probe` is used. The result is converted back with
    ``op.asarray_multi`` at the end.

    Returns
    -------
    probe
        The updated probe.
    cost
        The cost reported by `conjugate_gradient` for the last mode.
    """
    distributed = num_gpu > 1

    if distributed:
        # TODO: add multi-GPU support
        scan = op.asarray_multi_fuse(num_gpu, scan)
        data = op.asarray_multi_fuse(num_gpu, data)
        psi = psi[0]
        probe = probe[0]

    # TODO: Cache object patches between mode updates
    for m in range(probe.shape[-3]):
        # Equivalent to probe[..., m:m + 1, :, :]
        current_mode = (Ellipsis, slice(m, m + 1), slice(None), slice(None))

        def cost_function(mode):
            return op.cost(data, psi, scan, probe, m, mode)

        def grad(mode):
            per_position = op.grad_probe(data, psi, scan, probe, m, mode)
            return op.xp.mean(per_position, axis=(1, 2), keepdims=True)

        probe[current_mode], cost = conjugate_gradient(
            op.xp,
            x=probe[current_mode],
            cost_function=cost_function,
            grad=grad,
            num_iter=num_iter,
        )

    if distributed:
        probe = op.asarray_multi(num_gpu, probe)
        del scan
        del data

    logger.info('%10s cost is %+12.5e', 'probe', cost)
    return probe, cost
def update_phase(op, data, farplane, nearplane, ρ, λ, num_iter=1):
    """Solve the farplane phase problem with a coupling penalty.

    With ``f0 = op.propagation.fwd(nearplane)``, the minimized cost is::

        op.propagation.cost(data, f) + ρ * ||f0 - f + λ/ρ||^2

    Returns
    -------
    farplane
        The refined farplane.
    cost
        The cost reported by `conjugate_gradient`.
    """
    xp = op.xp
    farplane0 = op.propagation.fwd(nearplane)

    def coupling(estimate):
        # Gap between the propagated nearplane and the estimate, shifted by λ/ρ.
        return farplane0 - estimate + λ / ρ

    def cost_function(estimate):
        data_term = op.propagation.cost(data, estimate)
        return data_term + ρ * xp.linalg.norm(xp.ravel(coupling(estimate)))**2

    def grad(estimate):
        data_grad = op.propagation.grad(data, estimate)
        return data_grad - ρ * coupling(estimate)

    farplane, cost = conjugate_gradient(
        xp,
        x=farplane,
        cost_function=cost_function,
        grad=grad,
        num_iter=num_iter,
    )

    # Report the final cost as a sanity check.
    logger.info('%10s cost is %+12.5e', 'farplane', cost)
    return farplane, cost
def update_object(op, pool, num_gpu, data, psi, scan, probe, num_iter=1):
    """Solve the object recovery problem.

    With ``num_gpu <= 1`` the single-device ``op.cost`` / ``op.grad`` are
    minimized directly. Otherwise both are mapped over `pool` and combined
    here:

    - costs are converted with ``op.asnumpy`` and summed on the host;
    - each extra gradient is copied through host memory and accumulated
      into the first device's gradient.

    Returns
    -------
    psi
        The refined object.
    cost
        The cost reported by `conjugate_gradient`.
    """

    def cost_function(psi):
        return op.cost(data, psi, scan, probe)

    def grad(psi):
        return op.grad(data, psi, scan, probe)

    def cost_function_multi(psi, **kwargs):
        # TODO: Implement reduce function for ThreadPool
        per_device = pool.map(op.cost, data, psi, scan, probe)
        return sum(op.asnumpy(c) for c in per_device)

    def grad_multi(psi):
        partials = list(pool.map(op.grad, data, psi, scan, probe))
        total = partials[0]
        # TODO: Implement an optimal reduce in ThreadPool (e.g. direct
        # device-to-device copies when peer access is available). For now
        # every partial gradient is staged through host memory.
        for i in range(1, num_gpu):
            total += op.asarray(op.asnumpy(partials[i]))
        return total

    def dir_multi(direction):
        """Scatter the search direction to all GPUs."""
        return pool.bcast(direction)

    def update_multi(psi, gamma, direction):

        def step(part, part_direction):
            return part + gamma * part_direction

        return list(pool.map(step, psi, direction))

    if num_gpu > 1:
        solver_hooks = {
            'cost_function': cost_function_multi,
            'grad': grad_multi,
            'dir_multi': dir_multi,
            'update_multi': update_multi,
        }
    else:
        solver_hooks = {
            'cost_function': cost_function,
            'grad': grad,
        }

    psi, cost = conjugate_gradient(
        op.xp,
        x=psi,
        num_gpu=num_gpu,
        num_iter=num_iter,
        **solver_hooks,
    )

    logger.info('%10s cost is %+12.5e', 'object', cost)
    return psi, cost
def update_obj(
    op,
    comm,
    data,
    theta,
    obj,
    grid,
    obj_split,
    fwd_op,
    num_iter=1,
    step_length=1,
):
    """Solve the object recovery problem with the object split over workers.

    `fwd_op` maps the current object to forward data. The cost is evaluated
    on every `obj_split`-th worker (``comm.pool.workers[::obj_split]``). The
    gradient is reduced with ``comm.reduce(..., s=obj_split)``. A custom
    Dai-Yuan direction is computed on the first `obj_split` workers and
    passed to `conjugate_gradient` as `direction_dy`.

    Returns
    -------
    obj
        The refined object.
    cost
        The cost reported by `conjugate_gradient`.
    """

    def cost_function(obj):
        fwd_data = fwd_op(obj)
        # Evaluate the cost only on every obj_split-th worker, with the
        # matching slices of data and forward data.
        workers = comm.pool.workers[::obj_split]
        cost_out = comm.pool.map(
            op.cost,
            data[::obj_split],
            fwd_data[::obj_split],
            workers=workers,
        )
        if comm.use_mpi:
            return comm.Allreduce_reduce(cost_out, 'cpu')
        else:
            return comm.reduce(cost_out, 'cpu')

    def grad(obj):
        fwd_data = fwd_op(obj)
        grad_list = comm.pool.map(op.grad, data, theta, fwd_data, grid)
        # NOTE(review): unlike cost_function there is no comm.use_mpi branch
        # here; confirm gradients need no cross-rank reduction under MPI.
        return comm.reduce(grad_list, 'gpu', s=obj_split)

    def direction_dy(xp, grad1, grad0=None, dir_=None):
        """Return the Dai-Yuan search direction.

        On the first call (``dir_ is None``) this is ``-grad1`` on each of
        the first `obj_split` workers. Afterwards it is
        ``-grad1 + dir_ * ||grad1||^2 / (<dir_, grad1 - grad0> + 1e-32)``,
        with the squared norm reduced across those workers.
        """

        def init(grad1):
            return -grad1

        def f(grad1):
            return xp.linalg.norm(grad1.ravel())**2

        # NOTE(review): the inner product in the denominator is computed from
        # each worker's own chunk only, whereas norm_ is reduced across
        # workers. If the object is split across workers, each chunk gets a
        # different ratio; confirm this is intended.
        def d(grad0, grad1, dir_, norm_):
            return (
                - grad1
                + dir_ * norm_
                / (xp.sum(dir_.conj() * (grad1 - grad0)) + 1e-32)
            )  # yapf: disable

        workers = comm.pool.workers[:obj_split]

        if dir_ is None:
            return comm.pool.map(init, grad1, workers=workers)

        n = comm.pool.map(f, grad1, workers=workers)
        if comm.use_mpi:
            norm_ = comm.Allreduce_reduce(n, 'cpu')
        else:
            norm_ = comm.reduce(n, 'cpu')
        return comm.pool.map(
            d,
            grad0,
            grad1,
            dir_,
            norm_=norm_,
            workers=workers,
        )

    def dir_multi(dir):
        """Broadcast dir to the workers via comm.pool.bcast(dir, obj_split)."""
        return comm.pool.bcast(dir, obj_split)

    def update_multi(x, gamma, dir):

        def f(x, dir):
            return x + gamma * dir

        return comm.pool.map(f, x, dir)

    obj, cost = conjugate_gradient(
        op.xp,
        x=obj,
        cost_function=cost_function,
        grad=grad,
        direction_dy=direction_dy,
        dir_multi=dir_multi,
        update_multi=update_multi,
        num_iter=num_iter,
        step_length=step_length,
    )

    logger.info('%10s cost is %+12.5e', 'object', cost)
    return obj, cost