Example #1
    def find_fps_with_opt_solver(self, candidates, opt_method=None):
        """Optimize fixed points with nonlinear optimization solvers.

        Parameters
        ----------
        candidates : JaxArray or jax.numpy.ndarray
            A 2D array of candidate fixed points, one candidate per row.
        opt_method : callable, optional
            A function ``(f, x0) -> OptimizeResults``. Defaults to
            ``minimize`` with the BFGS method.
        """

        assert bm.ndim(candidates) == 2 and isinstance(
            candidates, (bm.JaxArray, jax.numpy.ndarray))
        if opt_method is None:
            opt_method = lambda f, x0: minimize(f, x0, method='BFGS')
        if self.verbose:
            print(f"Optimizing to find fixed points:")
        f_opt = bm.jit(bm.vmap(lambda x0: opt_method(self.f_loss, x0)))
        res = f_opt(bm.as_device_array(candidates))
        valid_ids = jax.numpy.where(res.success)[0]
        self._fixed_points = np.asarray(res.x[valid_ids])
        self._losses = np.asarray(res.fun[valid_ids])
        self._selected_ids = np.asarray(valid_ids)
        if self.verbose:
            print(f'    Found {len(valid_ids)} fixed points '
                  f'from {len(candidates)} initial points.')
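
The pattern above wraps the optimizer in vmap and jit so that many candidate starting points are refined in one batched call; the same idea works with plain jax.scipy.optimize.minimize. A minimal self-contained sketch (the toy loss and names are illustrative, not from the project above):

import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def loss(x):
    # simple quadratic with its minimum at x = 3
    return jnp.sum((x - 3.0) ** 2)

# eight candidate starting points, one per row
x0s = jnp.linspace(-5.0, 5.0, 8).reshape(8, 1)
batched_solve = jax.jit(jax.vmap(lambda x0: minimize(loss, x0, method='BFGS')))
res = batched_solve(x0s)
print(res.x[res.success])  # keep only the runs that converged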
Example #2
File: optim.py Project: xidulu/numpyro
    def eval_and_update(self, fn: Callable, state: _IterOptState) -> _IterOptState:
        i, (flat_params, unravel_fn) = state
        results = minimize(lambda x: fn(unravel_fn(x)), flat_params, (),
                           method=self._method, **self._kwargs)
        flat_params, out = results.x, results.fun
        state = (i + 1, _MinimizeState(flat_params, unravel_fn))
        return out, state
Example #3
File: main.py Project: fehiepsi/jaxns
def do_minimisation():
    results = minimize(loss,
                       jnp.zeros(prior_transform.U_ndims),
                       method='BFGS',
                       options=dict(gtol=1e-10, line_search_maxiter=200))
    print(results.message)
    return prior_transform(constrain(results.x)), constrain(
        results.x), results.status


def do_minimize():
    results = minimize(loss,
                       Q0,
                       method='BFGS',
                       options=dict(gtol=1e-8, line_search_maxiter=100))
    print(results.message)
    return results.x.reshape(
        (K, 7)
    ), results.status, results.fun, results.nfev, results.nit, results.jac
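
For reference, jax.scipy.optimize.minimize returns an OptimizeResults namedtuple whose fields mirror SciPy's OptimizeResult; those fields are what the two helpers above unpack. A quick way to inspect them (toy objective, illustrative only):

import jax.numpy as jnp
from jax.scipy.optimize import minimize

res = minimize(lambda x: jnp.sum(x ** 2), jnp.ones(2), method='BFGS')
print(res.x, res.fun, res.success, res.status, res.nfev, res.njev, res.nit)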
Example #5
def run_scf(x0, coords, mf, mo_coeff, mo_occ):
    options = {"gtol": 1e-6}
    res = minimize(energy_tot,
                   x0,
                   args=(coords, mf, mo_coeff, mo_occ),
                   method="BFGS",
                   options=options)
    e = energy_tot(res.x, coords, mf, mo_coeff, mo_occ)
    print("SCF energy: ", e)
Example #6
File: debug.py Project: fehiepsi/jaxns
def debug_vmap_bfgs():
    import os
    import jax.numpy as jnp
    from jax import jit, config, random
    from jax.scipy.optimize import minimize
    config.enable_omnistaging()  # only needed on old JAX versions; omnistaging is the default now
    ncpu = 2
    os.environ['XLA_FLAGS'] = f"--xla_force_host_platform_device_count={ncpu}"

    def cost_fn(x):
        return -jnp.sum(x**2)

    x = random.uniform(random.PRNGKey(0), (3,), minval=-1, maxval=1)
    result = jit(lambda x: minimize(cost_fn, x, method='BFGS'))(x)
    print(result)
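
One caveat with the snippet above: XLA reads XLA_FLAGS only when the backend first initializes, so setting it inside the function, after jax has already been imported, has no effect. A sketch of the ordering that does work (same flag, device count assumed to be 2):

import os
os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=2"

import jax  # the backend initializes here and picks up the flag
print(jax.devices())  # should now list two host (CPU) devices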
Example #7
    def eval_and_update(self, fn: Callable[[Any], Tuple], state: _IterOptState):
        i, (flat_params, unravel_fn) = state

        def loss_fn(x):
            x = unravel_fn(x)
            out, aux = fn(x)
            if aux is not None:
                raise ValueError(
                    "Minimize does not support models with mutable states."
                )
            return out

        results = minimize(
            loss_fn, flat_params, (), method=self._method, **self._kwargs
        )
        flat_params, out = results.x, results.fun
        state = (i + 1, _MinimizeState(flat_params, unravel_fn))
        return (out, None), state
Example #8
    def train_bfgs(self,
                   n_batches,
                   batch_fn,
                   options,
                   loss_names,
                   log_file=None,
                   scale=1.0):
        param_shapes = apply_to_nested_list(self.params, lambda x: x.shape)
        flatten = flatten_list(self.params)
        flatten_params = jnp.hstack([x.reshape(-1, ) for x in flatten])

        @jax.jit
        def loss_fn_bfgs(params, batch):
            params_ = unflatten_to_shape(params, param_shapes)
            return self.loss_fn(params_, batch) * scale

        for i in range(n_batches):
            batch = batch_fn(i)
            loss_fn_batch = jax.jit(partial(loss_fn_bfgs, batch=batch))
            opt_results = minimize(loss_fn_batch,
                                   flatten_params,
                                   method="bfgs",
                                   tol=1e-7,
                                   options=options)
            print(
                "Success: {},\n Status: {},\n Message: {},\n nfev: {},\n njev: {},\n nit: {}"
                .format(opt_results.success, opt_results.status,
                        opt_results.message, opt_results.nfev,
                        opt_results.njev, opt_results.nit))

            flatten_params = opt_results.x
            losses = self.evaluate_fn(
                unflatten_to_shape(flatten_params, param_shapes), batch)
            print("{}, Batch: {}, BFGS".format(get_time(), i) + \
              ','.join([" {}: {:.4e}".format(name, loss) for name, loss in zip(loss_names, losses)]), file = sys.stdout if log_file is None else log_file)

        return unflatten_to_shape(flatten_params, param_shapes)
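
The hand-written flatten/unflatten helpers above implement a round trip that jax also ships as jax.flatten_util.ravel_pytree. A hedged sketch of the same pattern using it (toy parameters, illustrative only):

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy.optimize import minimize

params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}
flat, unravel = ravel_pytree(params)  # 1D vector plus its inverse mapping

def loss(flat_params):
    p = unravel(flat_params)
    return jnp.sum(p['w'] ** 2) + jnp.sum(p['b'] ** 2)

res = minimize(loss, flat, method='BFGS', tol=1e-7)
print(unravel(res.x))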
Example #9
def optimize_subspace(key, d, D):
    """
    Optimize the subspace loss function for a given dimension d.

    Parameters
    ----------
    key : jax.random.PRNGKey
        Random number generator key
    d : int
        Dimension of the subspace
    
    Returns
    -------
    jax._src.scipy.optimize.minimize.OptimizeResults: Optimization results
    """
    key_weight, key_map, key_sign = random.split(key, 3)
    theta_0 = random.normal(key_weight, (D, )) / 10
    theta_sub_0 = jnp.zeros(d)

    choice_map = random.bernoulli(key_map, 1 / jnp.sqrt(D), shape=(D, d))
    P = random.choice(key_sign, jnp.array([-1, 1]), shape=(D, d)) * choice_map
    f_part = partial(subspace_loss, P=P, theta_0=theta_0, y=y)
    res = minimize(f_part, theta_sub_0, method="bfgs", tol=1e-3)
    return res
Example #10
E = partial(E_base, Phi=Phi, y=y, alpha=alpha)
initial_state = mh.new_state(w0, E)

mcmc_kernel = mh.kernel(E, jnp.ones(M) * sigma_mcmc)
mcmc_kernel = jax.jit(mcmc_kernel)

n_samples = 5_000
burnin = 300
key_init = jax.random.PRNGKey(0)
states = inference_loop(key_init, mcmc_kernel, initial_state, n_samples)

chains = states.position[burnin:, :]
nsamp, _ = chains.shape

# ** Laplace approximation **
res = minimize(lambda x: E(x) / len(y), w0, method="BFGS")
w_map = res.x
SN = jax.hessian(E)(w_map)
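# Added sketch (not in the original script): Laplace treats SN as the
# precision of a Gaussian centered at the MAP, so the approximate posterior
# is N(w_map, inv(SN)); drawing samples from it:
cov = jnp.linalg.inv(SN)  # covariance = inverse Hessian at the MAP
laplace_samples = jax.random.multivariate_normal(
    jax.random.PRNGKey(1), w_map, cov, shape=(1000,))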

# ** ADF inference **
q = 0.14
lbound, ubound = -10, 10
mu_t = jnp.zeros(M)
tau_t = jnp.ones(M) * q

init_state = (mu_t, tau_t)
xs = (Phi, y)

adf_loop = partial(adf_step, q=q, lbound=lbound, ubound=ubound)
(mu_t, tau_t), (mu_t_hist, tau_t_hist) = jax.lax.scan(adf_loop, init_state, xs)
Example #11
    res = minimize(f_part, theta_sub_0, method="bfgs", tol=1e-3)
    return res


if __name__ == "__main__":
    plt.rcParams["axes.spines.top"] = False
    plt.rcParams["axes.spines.right"] = False

    D = 1000
    R = 10
    y = jnp.arange(R) + 1

    # 1. Obtain optimal loss for the full-dimensional function
    theta_0 = jnp.zeros(D)
    f_part = partial(full_dimension_loss, y=y)
    res = minimize(f_part, theta_0, method="bfgs")
    optimal_loss = res.fun

    # 2. Obtain optimal loss for the subspace function at
    #    different dimensions
    dimensions = jnp.array(list(range(1, 16)) + [20, 30, 30])
    key = random.PRNGKey(314)
    keys = random.split(key, len(dimensions))

    ans = {"dim": [], "loss": [], "w": []}

    for key, dim in zip(keys, dimensions):
        print(f"@dim={dim}", end="\r")
        res = optimize_subspace(key, dim, D)
        ans["dim"].append(dim)
        ans["loss"].append(res.fun)