def find_fps_with_opt_solver(self, candidates, opt_method=None):
    """Optimize fixed points with nonlinear optimization solvers.

    Runs one optimization per candidate in parallel (``bm.vmap`` over the
    batch, compiled with ``bm.jit``) and keeps only the runs whose result
    reports ``success``.  Results are stored on ``self._fixed_points``,
    ``self._losses`` and ``self._selected_ids`` as NumPy arrays.

    Parameters
    ----------
    candidates
        2-D array of starting points, one candidate per row; must be a
        ``bm.JaxArray`` or ``jax.numpy.ndarray``.
    opt_method: function, callable
        ``opt_method(f, x0)`` returning a result with ``.success``, ``.x``
        and ``.fun`` fields.  Defaults to BFGS via ``minimize`` applied to
        ``self.f_loss``.
    """
    # NOTE(review): `assert` is stripped under `python -O`; raising
    # TypeError/ValueError would be more robust input validation.
    assert bm.ndim(candidates) == 2 and isinstance(
        candidates, (bm.JaxArray, jax.numpy.ndarray))
    if opt_method is None:
        opt_method = lambda f, x0: minimize(f, x0, method='BFGS')
    if self.verbose:
        print(f"Optimizing to find fixed points:")
    # vmap maps the solver over candidate rows; jit compiles the whole batch.
    f_opt = bm.jit(bm.vmap(lambda x0: opt_method(self.f_loss, x0)))
    res = f_opt(bm.as_device_array(candidates))
    # Keep only the optimizations that converged successfully.
    valid_ids = jax.numpy.where(res.success)[0]
    self._fixed_points = np.asarray(res.x[valid_ids])
    self._losses = np.asarray(res.fun[valid_ids])
    self._selected_ids = np.asarray(valid_ids)
    if self.verbose:
        print(
            f' '
            f'Found {len(valid_ids)} fixed points from {len(candidates)} initial points.'
        )
def eval_and_update(self, fn: Callable, state: _IterOptState) -> _IterOptState:
    """Run one full ``minimize`` pass over the flattened parameters.

    ``state`` is ``(step, (flat_params, unravel_fn))``; the objective is
    ``fn`` evaluated on the unflattened (pytree) view of the flat vector.
    Returns the final objective value and the advanced state.
    """
    step, (flat_params, unravel_fn) = state

    def objective(flat_x):
        # Evaluate `fn` on the pytree reconstructed from the flat vector.
        return fn(unravel_fn(flat_x))

    results = minimize(
        objective, flat_params, (), method=self._method, **self._kwargs
    )
    next_state = (step + 1, _MinimizeState(results.x, unravel_fn))
    return results.fun, next_state
def do_minimisation():
    """BFGS-minimise `loss` starting from the origin of the U-space.

    Returns the constrained solution mapped through `prior_transform`,
    the constrained solution itself, and the solver status code.
    """
    x0 = jnp.zeros(prior_transform.U_ndims)
    opts = dict(gtol=1e-10, line_search_maxiter=200)
    results = minimize(loss, x0, method='BFGS', options=opts)
    print(results.message)
    x_constrained = constrain(results.x)
    return prior_transform(x_constrained), x_constrained, results.status
def do_minimize():
    """Minimise `loss` from Q0 with BFGS and unpack solver diagnostics.

    Returns the solution reshaped to (K, 7), followed by the solver
    status, final loss, function-evaluation count, iteration count, and
    final gradient.
    """
    opts = dict(gtol=1e-8, line_search_maxiter=100)
    results = minimize(loss, Q0, method='BFGS', options=opts)
    print(results.message)
    solution = results.x.reshape((K, 7))
    return (solution, results.status, results.fun,
            results.nfev, results.nit, results.jac)
def run_scf(x0, coords, mf, mo_coeff, mo_occ):
    """Minimise the total SCF energy with BFGS.

    Parameters
    ----------
    x0 : array_like
        Initial parameter vector for ``energy_tot``.
    coords, mf, mo_coeff, mo_occ
        Extra arguments forwarded unchanged to ``energy_tot``.

    Returns
    -------
    The optimizer result: ``res.x`` is the minimiser and ``res.fun`` the
    converged SCF energy.  (Previously the function returned ``None``;
    returning the result is backward compatible.)
    """
    options = {"gtol": 1e-6}
    res = minimize(energy_tot, x0, args=(coords, mf, mo_coeff, mo_occ),
                   method="BFGS", options=options)
    # `res.fun` is already energy_tot evaluated at `res.x` — avoid a
    # redundant (potentially expensive) extra energy evaluation.
    e = res.fun
    print("SCF energy: ", e)
    return res
def debug_vmap_bfgs():
    """Smoke-test jitted BFGS on a tiny 3-parameter problem.

    Prints the `OptimizeResults` and returns None.  Fixes vs. the
    original: `random` is now imported (it was used but never brought
    into scope), the removed `enable_omnistaging` API is guarded so the
    function no longer crashes on modern JAX, and XLA_FLAGS is set
    before touching any JAX API.
    """
    import os
    ncpu = 2
    # NOTE(review): XLA_FLAGS only takes effect if set before JAX
    # initializes its backends; if jax was imported earlier in the
    # process this assignment is a no-op — confirm at the call site.
    os.environ['XLA_FLAGS'] = f"--xla_force_host_platform_device_count={ncpu}"
    import jax
    import jax.numpy as jnp
    from jax import jit, random
    from jax.scipy.optimize import minimize
    # Omnistaging is always on (and the toggle removed) in modern JAX;
    # only call it where the API still exists.
    if hasattr(jax.config, "enable_omnistaging"):
        jax.config.enable_omnistaging()

    def cost_fn(x):
        return -jnp.sum(x ** 2)

    x = random.uniform(random.PRNGKey(0), (3,), minval=-1, maxval=1)
    result = jit(lambda x0: minimize(cost_fn, x0, method='BFGS'))(x)
    print(result)
def eval_and_update(self, fn: Callable[[Any], Tuple], state: _IterOptState):
    """Run a full ``minimize`` pass and advance the iteration state.

    ``fn`` must return ``(loss, aux)`` with ``aux is None``; models with
    mutable state are rejected.  Returns ``((final_loss, None), state)``.
    """
    step, (flat_params, unravel_fn) = state

    def objective(flat_x):
        out, aux = fn(unravel_fn(flat_x))
        # Mutable model state cannot round-trip through the optimizer.
        if aux is not None:
            raise ValueError(
                "Minimize does not support models with mutable states."
            )
        return out

    results = minimize(
        objective, flat_params, (), method=self._method, **self._kwargs
    )
    next_state = (step + 1, _MinimizeState(results.x, unravel_fn))
    return (results.fun, None), next_state
def train_bfgs(self, n_batches, batch_fn, options, loss_names, log_file=None, scale=1.0):
    """Train by running a full BFGS solve on each batch in sequence.

    Parameters
    ----------
    n_batches : int
        Number of batches to process; batch ``i`` comes from ``batch_fn(i)``.
    batch_fn : callable
        Maps a batch index to the batch passed to the loss.
    options : dict
        Solver options forwarded to ``minimize``.
    loss_names : sequence of str
        Names paired with the values from ``self.evaluate_fn`` for logging.
    log_file : file-like, optional
        Where the per-batch log line goes; defaults to stdout.
    scale : float
        Multiplier applied to the loss before optimization.

    Returns
    -------
    The optimized parameters, unflattened back to their original shapes.
    """
    # Record the nested parameter shapes, then flatten everything into a
    # single 1-D vector — BFGS optimizes over a flat parameter vector.
    param_shapes = apply_to_nested_list(self.params, lambda x: x.shape)
    flatten = flatten_list(self.params)
    flatten_params = jnp.hstack([x.reshape(-1, ) for x in flatten])

    @jax.jit
    def loss_fn_bfgs(params, batch):
        # Rebuild the nested structure so self.loss_fn sees normal params.
        params_ = unflatten_to_shape(params, param_shapes)
        return self.loss_fn(params_, batch) * scale

    for i in range(n_batches):
        batch = batch_fn(i)
        # Bind the batch so the optimizer sees a function of params only.
        loss_fn_batch = jax.jit(partial(loss_fn_bfgs, batch=batch))
        opt_results = minimize(loss_fn_batch, flatten_params,
                               method="bfgs", tol=1e-7, options=options)
        print(
            "Success: {},\n Status: {},\n Message: {},\n nfev: {},\n njev: {},\n nit: {}"
            .format(opt_results.success, opt_results.status,
                    opt_results.message, opt_results.nfev,
                    opt_results.njev, opt_results.nit))
        # Warm-start the next batch from this batch's solution.
        flatten_params = opt_results.x
        losses = self.evaluate_fn(
            unflatten_to_shape(flatten_params, param_shapes), batch)
        print("{}, Batch: {}, BFGS".format(get_time(), i) + \
              ','.join([" {}: {:.4e}".format(name, loss)
                        for name, loss in zip(loss_names, losses)]),
              file=sys.stdout if log_file is None else log_file)
    return unflatten_to_shape(flatten_params, param_shapes)
def optimize_subspace(key, d, D, y_target=None):
    """
    Optimize the subspace loss function for a given dimension d.

    Parameters
    ----------
    key : jax.random.PRNGKey
        Random number generator key
    d : int
        Dimension of the subspace
    D : int
        Dimension of the full parameter space
    y_target : array_like, optional
        Targets passed to `subspace_loss`.  Defaults to the module-level
        `y` for backward compatibility with the previous implicit-global
        behavior.

    Returns
    -------
    jax._src.scipy.optimize.minimize.OptimizeResults:
        Optimization results
    """
    if y_target is None:
        # Backward-compatible fallback to the module-level target vector.
        y_target = y
    key_weight, key_map, key_sign = random.split(key, 3)
    theta_0 = random.normal(key_weight, (D,)) / 10
    theta_sub_0 = jnp.zeros(d)
    # Sparse random projection: each entry is ±1 with prob 1/sqrt(D), else 0.
    choice_map = random.bernoulli(key_map, 1 / jnp.sqrt(D), shape=(D, d))
    P = random.choice(key_sign, jnp.array([-1, 1]), shape=(D, d)) * choice_map
    f_part = partial(subspace_loss, P=P, theta_0=theta_0, y=y_target)
    res = minimize(f_part, theta_sub_0, method="bfgs", tol=1e-3)
    return res
# Bind the data and prior strength into the energy function.
E = partial(E_base, Phi=Phi, y=y, alpha=alpha)
initial_state = mh.new_state(w0, E)
# Metropolis-Hastings kernel with per-dimension proposal scale sigma_mcmc.
mcmc_kernel = mh.kernel(E, jnp.ones(M) * sigma_mcmc)
mcmc_kernel = jax.jit(mcmc_kernel)

n_samples = 5_000
burnin = 300
key_init = jax.random.PRNGKey(0)
states = inference_loop(key_init, mcmc_kernel, initial_state, n_samples)
# Drop the burn-in prefix before using the chain.
chains = states.position[burnin:, :]
nsamp, _ = chains.shape

# ** Laplace approximation **
# MAP estimate of the (per-datum normalized) energy; the Hessian at the
# mode gives the Gaussian precision of the Laplace approximation.
res = minimize(lambda x: E(x) / len(y), w0, method="BFGS")
w_map = res.x
SN = jax.hessian(E)(w_map)

# ** ADF inference **
# Assumed-density filtering: scan adf_step over the (Phi, y) pairs,
# carrying the running mean/precision (mu_t, tau_t).
q = 0.14
lbound, ubound = -10, 10
mu_t = jnp.zeros(M)
tau_t = jnp.ones(M) * q
init_state = (mu_t, tau_t)
xs = (Phi, y)
adf_loop = partial(adf_step, q=q, lbound=lbound, ubound=ubound)
(mu_t, tau_t), (mu_t_hist, tau_t_hist) = jax.lax.scan(adf_loop, init_state, xs)
    # Tail of an enclosing function whose `def` line is outside this view.
    res = minimize(f_part, theta_sub_0, method="bfgs", tol=1e-3)
    return res


if __name__ == "__main__":
    plt.rcParams["axes.spines.top"] = False
    plt.rcParams["axes.spines.right"] = False

    D = 1000
    R = 10
    y = jnp.arange(R) + 1

    # 1. Obtain optimal loss for the full-dimensional function
    theta_0 = jnp.zeros(D)
    f_part = partial(full_dimension_loss, y=y)
    res = minimize(f_part, theta_0, method="bfgs")
    optimal_loss = res.fun

    # 2. Obtain optimal loss for the subspace function at
    # different dimensions
    # NOTE(review): 30 appears twice in this list — possibly a typo for
    # [20, 30, 40]; confirm intent before changing.
    dimensions = jnp.array(list(range(1, 16)) + [20, 30, 30])
    key = random.PRNGKey(314)
    keys = random.split(key, len(dimensions))

    # NOTE(review): ans["w"] is declared but never appended below —
    # verify whether the weights were meant to be collected too.
    ans = {"dim": [], "loss": [], "w": []}
    for key, dim in zip(keys, dimensions):
        print(f"@dim={dim}", end="\r")
        res = optimize_subspace(key, dim, D)
        ans["dim"].append(dim)
        ans["loss"].append(res.fun)