def worker_init_fn(worker_id):
    """Rebuild the ray transform of the dataset copy held by this worker.

    The forward operator found three ``.dataset`` levels below the worker's
    dataset copy is replaced by a new ``OperatorComp`` whose left factor is
    a newly constructed ``odl.tomo.RayTransform`` (same domain, geometry and
    ``impl`` as the previous left factor) and whose right factor is kept.

    Recreating the operator avoids
    ``astra.data2d.py ValueError: Data object not found``.

    Parameters
    ----------
    worker_id : int
        Worker index; accepted but not used.
    """
    # The dataset copy living in this worker process.
    worker_dataset = torch.utils.data.get_worker_info().dataset
    innermost = worker_dataset.dataset.dataset

    previous_op = innermost.forward_op
    previous_ray_trafo = previous_op.left

    # Same domain/geometry/impl as before, but a newly created object.
    fresh_ray_trafo = odl.tomo.RayTransform(
        previous_ray_trafo.domain,
        previous_ray_trafo.geometry,
        impl=previous_ray_trafo.impl)

    innermost.forward_op = OperatorComp(fresh_ray_trafo, previous_op.right)
def gauss_newton(op, x, rhs, niter, zero_seq=None, callback=None):
    """Optimized implementation of a Gauss-Newton method.

    This method solves the inverse problem (of the first kind)::

        A(x) = y

    for a (Frechet-) differentiable `Operator` ``A`` using a
    Gauss-Newton iteration.

    It uses a minimum amount of memory copies by applying re-usable
    temporaries and in-place evaluation.

    A variant of the method applied to a specific problem is described
    in a
    `Wikipedia article
    <https://en.wikipedia.org/wiki/Gauss%E2%80%93Newton_algorithm>`_.

    Parameters
    ----------
    op : `Operator`
        Operator in the inverse problem. If not linear, it must have
        an implementation of `Operator.derivative`, which
        in turn must implement `Operator.adjoint`, i.e.
        the call ``op.derivative(x).adjoint`` must be valid.
    x : ``op.domain`` element
        Element to which the result is written. Its initial value is
        used as starting point of the iteration, and its values are
        updated in each iteration step.
    rhs : ``op.range`` element
        Right-hand side of the equation defining the inverse problem
    niter : int
        Maximum number of iterations.
    zero_seq : iterable, optional
        Zero sequence whose values are used for the regularization of
        the linearized problem in each Newton step. One value is taken
        per iteration, so it must provide at least ``niter`` values.
        If an iterator is passed, it is advanced by this call.
        Default: a new ``exp_zero_seq(2.0)`` created for each call.
    callback : callable, optional
        Object executing code per iteration, e.g. plotting each iterate.

    Raises
    ------
    TypeError
        If ``x`` is not in ``op.domain``.

    Notes
    -----
    Nothing is returned; the result is written to ``x`` in place.
    """
    if x not in op.domain:
        raise TypeError('`x` {!r} is not in the domain of `op` {!r}'
                        ''.format(x, op.domain))

    # A default built in the signature would be evaluated only once and
    # then shared (and progressively consumed) by all calls relying on
    # it, so the default sequence is created here, once per call.
    if zero_seq is None:
        zero_seq = exp_zero_seq(2.0)
    # Accept any iterable (e.g. a list), not only iterators. For an
    # iterator, ``iter`` returns the object itself.
    zero_seq = iter(zero_seq)

    x0 = x.copy()
    id_op = IdentityOperator(op.domain)
    dx = op.domain.zero()

    # Temporaries allocated once and re-used in every iteration
    tmp_dom = op.domain.element()
    u = op.domain.element()
    tmp_ran = op.range.element()
    v = op.range.element()

    for _ in range(niter):
        tm = next(zero_seq)
        deriv = op.derivative(x)
        deriv_adjoint = deriv.adjoint

        # v = rhs - op(x) - deriv(x0-x)
        # u = deriv.T(v)
        op(x, out=tmp_ran)  # eval        op(x)
        v.lincomb(1, rhs, -1, tmp_ran)  # assign      v = rhs - op(x)
        tmp_dom.lincomb(1, x0, -1, x)  # assign temp tmp_dom = x0 - x
        deriv(tmp_dom, out=tmp_ran)  # eval        deriv(x0-x)
        v -= tmp_ran  # assign      v = rhs-op(x)-deriv(x0-x)
        deriv_adjoint(v, out=u)  # eval/assign u = deriv.T(v)

        # Solve equation Tikhonov regularized system
        # (deriv.T o deriv + tm * id_op)^-1 u = dx
        tikh_op = OperatorSum(OperatorComp(deriv_adjoint, deriv),
                              tm * id_op, tmp_dom)

        # TODO: allow user to select other method
        # NOTE(review): ``dx`` is not reset between Newton steps, and the
        # trailing ``3`` is presumably the number of CG iterations --
        # confirm against ``conjugate_gradient``.
        conjugate_gradient(tikh_op, dx, u, 3)

        # Update x
        x.lincomb(1, x0, 1, dx)  # x = x0 + dx

        if callback is not None:
            callback(x)