示例#1
0
    def body_fun(state: LBFGSResults):
        """Run one L-BFGS iteration and return the updated loop state.

        The search direction comes from ``_two_loop_recursion(state)``; a
        line search (at most ``maxls`` iterations) picks the step along it.
        The returned state (built with ``state._replace``) holds the new
        iterate cast back to the incoming dtypes, the ``s``/``y``/``rho``
        values passed through the ``_update_history_*`` helpers, and the
        convergence / failure flags.

        Status codes (a later check overrides an earlier one, and
        convergence resets the code to 0):
          4: ``state.f_k - f_{k+1} < ftol``
          3: ``state.ngev >= maxgrad``
          2: ``state.nfev >= maxfun``
          1: ``state.k >= maxiter``
          5: the line search reported failure
        The counter checks use the values from *before* this step.
        """
        direction = _two_loop_recursion(state)

        search = line_search(
            f=fun,
            xk=state.x_k,
            pk=direction,
            old_fval=state.f_k,
            gfk=state.g_k,
            maxiter=maxls,
        )

        step = search.a_k * direction          # s_k
        x_next = state.x_k + step
        f_next = search.f_k
        g_next = search.g_k
        grad_delta = g_next - state.g_k        # y_k

        sy = jnp.real(_dot(grad_delta, step))  # 1 / rho_k
        rho = jnp.reciprocal(sy)
        # gamma = Re(y.s) / Re(y^H y), carried on the state for the next
        # direction computation (presumably the initial inverse-Hessian
        # scale used by _two_loop_recursion -- TODO confirm).
        gamma = sy / jnp.real(_dot(jnp.conj(grad_delta), grad_delta))

        # Apply the checks in increasing precedence: each later match
        # overwrites whatever code an earlier one produced.
        code = 0
        for hit, value in (
                (state.f_k - f_next < ftol, 4),
                (state.ngev >= maxgrad, 3),  # type: ignore
                (state.nfev >= maxfun, 2),  # type: ignore
                (state.k >= maxiter, 1),  # type: ignore
                (search.failed, 5),
        ):
            code = jnp.where(hit, value, code)

        converged = jnp.linalg.norm(g_next, ord=norm) < gtol

        # TODO(jakevdp): use a fixed-point procedure rather than type-casting?
        return state._replace(
            converged=converged,
            failed=(code > 0) & (~converged),
            k=state.k + 1,
            nfev=state.nfev + search.nfev,
            ngev=state.ngev + search.ngev,
            x_k=x_next.astype(state.x_k.dtype),
            f_k=f_next.astype(state.f_k.dtype),
            g_k=g_next.astype(state.g_k.dtype),
            s_history=_update_history_vectors(
                history=state.s_history, new=step),
            y_history=_update_history_vectors(
                history=state.y_history, new=grad_delta),
            rho_history=_update_history_scalars(
                history=state.rho_history, new=rho),
            gamma=gamma,
            status=jnp.where(converged, 0, code),
            ls_status=search.status,
        )
示例#2
0
    def test_line_search_wolfe2_bounds(self):
        """Check the step-length bound implied by the strong Wolfe conditions.

        See gh-7475.  For ``f(x) = x . x`` and ``p = [1, 0]``, starting at a
        point on axis 0, strong Wolfe condition 2 holds if and only if the
        step length ``s`` satisfies ``|x + s| <= c2 * |x|``.  With
        ``x = -60 * p`` and ``c2 = 0.5`` the smallest acceptable step is 30.
        The same search capped at ``maxiter=5`` is expected to fail.
        """
        def f(x):
            return jnp.dot(x, x)

        def fp(x):
            return 2 * x

        p = jnp.array([1, 0])
        x = -60 * p
        c2 = 0.5

        res = line_search(f, x, p, c2=c2)
        s = res.a_k
        self.assert_line_wolfe(x, p, s, f, fp)
        # assertGreaterEqual reports both operands on failure, unlike
        # assertTrue(s >= 30.), which only says "False is not true".
        self.assertGreaterEqual(s, 30.)

        res = line_search(f, x, p, c2=c2, maxiter=5)
        self.assertTrue(res.failed)
示例#3
0
    def test_scalar_search_wolfe2(self, name):
        """Line search on a scalar test function returns a Wolfe step.

        ``name`` selects a method on the test case; items 0 and 1 of its
        result are used as the function value and derivative.  The
        ``f_k`` / ``g_k`` reported by ``line_search`` must match direct
        evaluation at the accepted step ``a_k``.
        """
        func = getattr(self, name)

        def phi(*args, **kwargs):
            return func(*args, **kwargs)[0]

        def derphi(*args, **kwargs):
            return func(*args, **kwargs)[1]

        # NOTE(review): old_phi0 only feeds the failure message; it is not
        # passed to line_search, so every iteration repeats the same search.
        for old_phi0 in self.rng().randn(3):
            res = line_search(phi, 0., 1.)
            step = res.a_k
            self.assertAllClose(res.f_k, phi(step),
                                check_dtypes=False, atol=1e-6)
            if res.g_k is not None:
                self.assertAllClose(res.g_k, derphi(step),
                                    check_dtypes=False, atol=1e-6)
            self.assert_wolfe(step, phi, derphi,
                              err_msg=f"{name} {old_phi0:g}")
示例#4
0
    def test_line_search(self):
        """JAX ``line_search`` agrees with scipy's ``line_search_wolfe2``.

        Both the accepted step and the function value there are compared.
        """
        def f(x):
            return jnp.cos(jnp.sum(jnp.exp(-x))**2)

        xk = jnp.ones(2)
        pk = jnp.array([-0.5, -0.25])

        ours = line_search(f, xk, pk, maxiter=100)
        reference = line_search_wolfe2(f, grad(f), xk, pk)

        # Items 0 and 3 of scipy's result are the step and the new f value.
        pairs = ((reference[0], ours.a_k), (reference[3], ours.f_k))
        for expected, actual in pairs:
            self.assertAllClose(expected, actual,
                                atol=1e-5, check_dtypes=False)
示例#5
0
  def body_fun(state):
    """Run one BFGS iteration and return the updated state.

    The search direction is ``-_dot(H_k, g_k)``.  A line search (at most
    ``line_search_maxiter`` iterations) picks the step along it.  ``H_k``
    then receives the rank-two update

        H_{k+1} = (I - rho s y^T) H_k (I - rho y s^T) + rho s s^T,
        rho = 1 / _dot(y, s),

    where ``s`` is the step taken and ``y`` is the change in gradient.
    If ``rho`` is not finite (e.g. ``_dot(y, s) == 0``), the previous
    ``H_k`` is kept.  The iterate is advanced and ``failed`` mirrors the
    line search's flag in either case.
    """
    direction = -_dot(state.H_k, state.g_k)
    search = line_search(
        fun,
        state.x_k,
        direction,
        old_fval=state.f_k,
        old_old_fval=state.old_old_fval,
        gfk=state.g_k,
        maxiter=line_search_maxiter,
    )

    step = search.a_k * direction             # s_k
    grad_delta = search.g_k - state.g_k       # y_k
    rho = jnp.reciprocal(_dot(grad_delta, step))

    outer_sy = step[:, None] * grad_delta[None, :]
    left = jnp.eye(d, dtype=rho.dtype) - rho * outer_sy
    # 'ij,jk,lk' contracts to left @ H_k @ left.T.
    candidate_H = (_einsum('ij,jk,lk', left, state.H_k, left)
                   + rho * step[:, None] * step[None, :])
    new_H = jnp.where(jnp.isfinite(rho), candidate_H, state.H_k)

    return state._replace(
        nfev=state.nfev + search.nfev,
        ngev=state.ngev + search.ngev,
        failed=search.failed,
        line_search_status=search.status,
        converged=jnp.linalg.norm(search.g_k, ord=norm) < gtol,
        k=state.k + 1,
        x_k=state.x_k + step,
        f_k=search.f_k,
        g_k=search.g_k,
        H_k=new_H,
        old_old_fval=state.f_k,
    )
示例#6
0
    def test_line_search_wolfe2(self, name):
        """Line search along random descent directions of an N-D function.

        ``name`` selects a method on the test case; items 0 and 1 of its
        result are used as the function value and gradient.  Random
        ``(x, p)`` pairs are drawn, discarding non-descent directions, until
        nine have been checked.  For each one, the ``f_k`` / ``g_k`` reported
        by ``line_search`` must match direct evaluation at ``x + a_k * p``.
        """
        func = getattr(self, name)

        def f(*args, **kwargs):
            return func(*args, **kwargs)[0]

        def fprime(*args, **kwargs):
            return func(*args, **kwargs)[1]

        N = 20
        rng = self.rng()
        # sets A in one of the line funcs
        self.A = self.rng().randn(N, N)

        accepted = 0
        while accepted < 9:
            x = rng.randn(N)
            p = rng.randn(N)
            if jnp.dot(p, fprime(x)) >= 0:
                continue  # not a descent direction; draw again
            accepted += 1

            f0 = f(x)
            g0 = fprime(x)
            self.fcount = 0
            res = line_search(f, x, p, old_fval=f0, gfk=g0)
            x_new = x + res.a_k * p
            self.assertAllClose(res.f_k, f(x_new),
                                check_dtypes=False, atol=1e-5)
            if res.g_k is not None:
                self.assertAllClose(res.g_k, fprime(x_new),
                                    check_dtypes=False, atol=1e-5)