# Example #1
def gauss_newton(op, x, rhs, niter, zero_seq=None, callback=None):
    """Optimized implementation of a Gauss-Newton method.

    This method solves the inverse problem (of the first kind)::

        A(x) = y

    for a (Frechet-) differentiable `Operator` ``A`` using a
    Gauss-Newton iteration.

    It uses a minimum amount of memory copies by applying re-usable
    temporaries and in-place evaluation.

    A variant of the method applied to a specific problem is described
    in a
    `Wikipedia article
    <https://en.wikipedia.org/wiki/Gauss%E2%80%93Newton_algorithm>`_.

    Parameters
    ----------
    op : `Operator`
        Operator in the inverse problem. If not linear, it must have
        an implementation of `Operator.derivative`, which
        in turn must implement `Operator.adjoint`, i.e.
        the call ``op.derivative(x).adjoint`` must be valid.
    x : ``op.domain`` element
        Element to which the result is written. Its initial value is
        used as starting point of the iteration, and its values are
        updated in each iteration step.
    rhs : ``op.range`` element
        Right-hand side of the equation defining the inverse problem
    niter : int
        Maximum number of iterations.
    zero_seq : iterable, optional
        Zero sequence whose values are used for the regularization of
        the linearized problem in each Newton step.
        Default: ``exp_zero_seq(2.0)``. A fresh sequence is created on
        every call, so repeated calls with the default behave
        identically.
    callback : callable, optional
        Object executing code per iteration, e.g. plotting each iterate.
    """
    if x not in op.domain:
        raise TypeError('`x` {!r} is not in the domain of `op` {!r}'
                        ''.format(x, op.domain))

    # NOTE: the default must NOT be created in the signature — a default
    # argument is evaluated only once at function definition, so a shared
    # generator would be exhausted after the first call. Create it fresh
    # here instead; `iter()` also accepts any (re-)iterable the user passes.
    if zero_seq is None:
        zero_seq = exp_zero_seq(2.0)
    else:
        zero_seq = iter(zero_seq)

    x0 = x.copy()
    id_op = IdentityOperator(op.domain)
    dx = op.domain.zero()

    # Re-usable temporaries to avoid allocations inside the loop.
    tmp_dom = op.domain.element()
    u = op.domain.element()
    tmp_ran = op.range.element()
    v = op.range.element()

    for _ in range(niter):
        tm = next(zero_seq)  # regularization parameter for this step
        deriv = op.derivative(x)
        deriv_adjoint = deriv.adjoint

        # v = rhs - op(x) - deriv(x0-x)
        # u = deriv.T(v)
        op(x, out=tmp_ran)  # eval  op(x)
        v.lincomb(1, rhs, -1, tmp_ran)  # assign  v = rhs - op(x)
        tmp_dom.lincomb(1, x0, -1, x)  # assign temp  tmp_dom = x0 - x
        deriv(tmp_dom, out=tmp_ran)  # eval  deriv(x0-x)
        v -= tmp_ran  # assign  v = rhs-op(x)-deriv(x0-x)
        deriv_adjoint(v, out=u)  # eval/assign  u = deriv.T(v)

        # Solve equation Tikhonov regularized system
        # (deriv.T o deriv + tm * id_op)^-1 u = dx
        # Reuse the cached adjoint instead of recomputing deriv.adjoint.
        tikh_op = OperatorSum(OperatorComp(deriv_adjoint, deriv),
                              tm * id_op, tmp_dom)

        # TODO: allow user to select other method
        conjugate_gradient(tikh_op, dx, u, 3)

        # Update x
        x.lincomb(1, x0, 1, dx)  # x = x0 + dx

        if callback is not None:
            callback(x)