def body_fun(state: LBFGSResults):
    """Run one L-BFGS iteration and return the updated state.

    A search direction is obtained from ``_two_loop_recursion``, a line
    search is run along it, and the resulting step / gradient change are
    pushed into the history buffers.

    The termination ``status`` is built from a chain of ``jnp.where`` calls
    in which later checks override earlier ones:

    * 4 -- decrease in the objective is below ``ftol``
    * 3 -- ``state.ngev`` (count before this iteration) reached ``maxgrad``
    * 2 -- ``state.nfev`` (count before this iteration) reached ``maxfun``
    * 1 -- ``state.k`` (count before this iteration) reached ``maxiter``
    * 5 -- the line search reported failure

    If the gradient norm at the new iterate is below ``gtol`` the state is
    marked converged and the stored status is reset to 0.
    """
    # Search direction, then a line search along it.
    direction = _two_loop_recursion(state)
    search = line_search(
        f=fun,
        xk=state.x_k,
        pk=direction,
        old_fval=state.f_k,
        gfk=state.g_k,
        maxiter=maxls,
    )

    # Quantities at the next iterate.
    step = search.a_k * direction
    next_x = state.x_k + step
    next_f = search.f_k
    next_g = search.g_k
    grad_change = next_g - state.g_k

    # rho is the reciprocal of <y, s>; gamma is <y, s> / <y, y>.
    curvature = jnp.real(_dot(grad_change, step))
    rho = jnp.reciprocal(curvature)
    gamma = curvature / jnp.real(_dot(jnp.conj(grad_change), grad_change))

    # Termination codes; each later line takes precedence over earlier ones.
    status = jnp.where(state.f_k - next_f < ftol, 4, 0)
    status = jnp.where(state.ngev >= maxgrad, 3, status)  # type: ignore
    status = jnp.where(state.nfev >= maxfun, 2, status)  # type: ignore
    status = jnp.where(state.k >= maxiter, 1, status)  # type: ignore
    status = jnp.where(search.failed, 5, status)

    converged = jnp.linalg.norm(next_g, ord=norm) < gtol
    failed = (status > 0) & (~converged)
    reported_status = jnp.where(converged, 0, status)

    new_s_history = _update_history_vectors(history=state.s_history, new=step)
    new_y_history = _update_history_vectors(
        history=state.y_history, new=grad_change
    )
    new_rho_history = _update_history_scalars(
        history=state.rho_history, new=rho
    )

    # TODO(jakevdp): use a fixed-point procedure rather than type-casting?
    return state._replace(
        converged=converged,
        failed=failed,
        k=state.k + 1,
        nfev=state.nfev + search.nfev,
        ngev=state.ngev + search.ngev,
        x_k=next_x.astype(state.x_k.dtype),
        f_k=next_f.astype(state.f_k.dtype),
        g_k=next_g.astype(state.g_k.dtype),
        s_history=new_s_history,
        y_history=new_y_history,
        rho_history=new_rho_history,
        gamma=gamma,
        status=reported_status,
        ls_status=search.status,
    )
def test_line_search_wolfe2_bounds(self):
    """Check the step length on a quadratic where the Wolfe bound is known.

    See gh-7475. For f(x) = <x, x> and direction p, starting at a point on
    axis 0, the strong Wolfe condition 2 is met if and only if the step
    length s satisfies |x + s| <= c2 * |x|.
    """

    def f(x):
        return jnp.dot(x, x)

    def fp(x):
        return 2 * x

    p = jnp.array([1, 0])
    # Smallest s satisfying strong Wolfe conditions for these arguments
    # is 30.
    x = -60 * p
    c2 = 0.5

    result = line_search(f, x, p, c2=c2)
    s = result.a_k
    self.assert_line_wolfe(x, p, s, f, fp)
    self.assertTrue(s >= 30.)

    # With the iteration budget capped at 5 the search must report failure.
    capped = line_search(f, x, p, c2=c2, maxiter=5)
    self.assertTrue(capped.failed)
def test_scalar_search_wolfe2(self, name):
    """Run ``line_search`` on the scalar function stored under ``name``.

    ``getattr(self, name)`` must return a callable whose result is indexable:
    element 0 is used as the function value and element 1 as its derivative.
    The search starts at 0. with direction 1., and the returned value and
    gradient are compared with direct evaluations at the returned step.
    """

    def bind_index(func, idx):
        # A nested function (not a loop-bound lambda) so ``idx`` is captured
        # per call -- remember Python's closure semantics!
        def component(*args, **kwargs):
            return func(*args, **kwargs)[idx]

        return component

    value = getattr(self, name)
    phi = bind_index(value, 0)
    derphi = bind_index(value, 1)

    # NOTE(review): old_phi0 only appears in the failure message below; it is
    # not passed to line_search, so each pass repeats the same search.
    for old_phi0 in self.rng().randn(3):
        result = line_search(phi, 0., 1.)
        s = result.a_k
        phi1 = result.f_k
        derphi1 = result.g_k

        self.assertAllClose(phi1, phi(s), check_dtypes=False, atol=1e-6)
        if derphi1 is not None:
            self.assertAllClose(
                derphi1, derphi(s), check_dtypes=False, atol=1e-6
            )
        self.assert_wolfe(s, phi, derphi, err_msg=f"{name} {old_phi0:g}")
def test_line_search(self):
    """Compare ``line_search`` against ``line_search_wolfe2`` on one problem.

    Both searches start from a vector of ones and move along
    ``[-0.5, -0.25]``. Elements 0 and 3 of the reference result are compared
    with the ``a_k`` and ``f_k`` fields of the result under test.
    """

    def f(x):
        total = jnp.sum(jnp.exp(-x))
        return jnp.cos(total ** 2)

    xk = jnp.ones(2)
    pk = jnp.array([-0.5, -0.25])

    result = line_search(f, xk, pk, maxiter=100)
    reference = line_search_wolfe2(f, grad(f), xk, pk)

    expected_step = reference[0]
    expected_value = reference[3]
    self.assertAllClose(
        expected_step, result.a_k, atol=1e-5, check_dtypes=False
    )
    self.assertAllClose(
        expected_value, result.f_k, atol=1e-5, check_dtypes=False
    )
def body_fun(state):
    """Run one BFGS iteration and return the updated state.

    The search direction is ``-H_k @ g_k``. After the line search, the
    inverse-Hessian estimate is refreshed as ``w H_k w^T + rho * s s^T`` with
    ``w = I - rho * s y^T``; if ``rho`` is not finite the previous ``H_k`` is
    kept. Convergence is declared when the new gradient norm is below
    ``gtol``.
    """
    direction = -_dot(state.H_k, state.g_k)
    search = line_search(
        fun,
        state.x_k,
        direction,
        old_fval=state.f_k,
        old_old_fval=state.old_old_fval,
        gfk=state.g_k,
        maxiter=line_search_maxiter,
    )

    # Step taken and the resulting change in gradient.
    step = search.a_k * direction
    next_x = state.x_k + step
    next_f = search.f_k
    next_g = search.g_k
    grad_change = next_g - state.g_k

    rho = jnp.reciprocal(_dot(grad_change, step))
    step_col = step[:, jnp.newaxis]
    step_row = step[jnp.newaxis, :]
    outer_sy = step_col * grad_change[jnp.newaxis, :]
    w = jnp.eye(d, dtype=rho.dtype) - rho * outer_sy

    candidate_H = (
        _einsum('ij,jk,lk', w, state.H_k, w) + rho * step_col * step_row
    )
    # Fall back to the previous estimate when rho is inf or nan.
    next_H = jnp.where(jnp.isfinite(rho), candidate_H, state.H_k)

    converged = jnp.linalg.norm(next_g, ord=norm) < gtol

    # All fields are computed from the incoming state, so a single
    # replacement is sufficient.
    return state._replace(
        nfev=state.nfev + search.nfev,
        ngev=state.ngev + search.ngev,
        failed=search.failed,
        line_search_status=search.status,
        converged=converged,
        k=state.k + 1,
        x_k=next_x,
        f_k=next_f,
        g_k=next_g,
        H_k=next_H,
        old_old_fval=state.f_k,
    )
def test_line_search_wolfe2(self, name):
    """Run ``line_search`` on the N-d function stored under ``name``.

    ``getattr(self, name)`` must return a callable whose result is indexable:
    element 0 is used as the function value and element 1 as the gradient.
    Nine random (point, descent direction) pairs are drawn; for each, the
    value and gradient returned by the search are compared with direct
    evaluations at ``x + s * p``.
    """

    def bind_index(func, idx):
        # A nested function so ``idx`` is captured per call -- remember
        # Python's closure semantics!
        def component(*args, **kwargs):
            return func(*args, **kwargs)[idx]

        return component

    value = getattr(self, name)
    f = bind_index(value, 0)
    fprime = bind_index(value, 1)

    N = 20
    rng = self.rng()
    # sets A in one of the line funcs
    self.A = self.rng().randn(N, N)

    trials_done = 0
    while trials_done < 9:
        x = rng.randn(N)
        p = rng.randn(N)
        if jnp.dot(p, fprime(x)) >= 0:
            # always pick a descent pk
            continue
        trials_done += 1

        f0 = f(x)
        g0 = fprime(x)
        self.fcount = 0

        result = line_search(f, x, p, old_fval=f0, gfk=g0)
        s = result.a_k
        fv = result.f_k
        gv = result.g_k

        self.assertAllClose(
            fv, f(x + s * p), check_dtypes=False, atol=1e-5
        )
        if gv is not None:
            self.assertAllClose(
                gv, fprime(x + s * p), check_dtypes=False, atol=1e-5
            )