def network(inputs: jnp.ndarray) -> Tuple[Logits, Beta, Q, Logits_Omega]:
    """Evaluate the shared-torso option network on ``inputs``.

    The inputs are flattened, passed through a two-layer MLP torso and
    then fed to five linear heads (one output per option, except the
    intra-option policy which has ``action_spec.num_values`` outputs per
    option).

    Args:
        inputs: Observation array; it is flattened with ``hk.Flatten``.

    Returns:
        A tuple ``(logits, beta, q, pi_omega)`` where

        * ``logits`` is the intra-option policy reshaped to
          ``(-1, n_options, action_spec.num_values)``.  Despite the name,
          a softmax over the last axis has already been applied.
        * ``beta`` is the sigmoid output of the termination head.
        * ``q`` is the raw linear output of the option-value head.
        * ``pi_omega`` is the softmax policy over options, multiplied by
          the sigmoid interest head and re-normalised over the last axis.
    """
    flat_inputs = hk.Flatten()(inputs)
    # Shared state processor with two hidden layers of width 64.
    # NOTE(review): with haiku's defaults the MLP does not apply its
    # activation after the final layer - confirm whether that is intended.
    torso = hk.nets.MLP([64, 64])

    def split_per_option(x):
        # One row of action preferences per option.
        return jnp.reshape(x, (-1, n_options, action_spec.num_values))

    # hk.Sequential expects a single iterable of layers.
    policy_over_options_head = hk.Sequential([
        hk.Linear(n_options),
        Partial(jax.nn.softmax, axis=-1),
    ])
    beta_head = hk.Sequential([hk.Linear(n_options), jax.nn.sigmoid])
    interest_head = hk.Sequential([hk.Linear(n_options), jax.nn.sigmoid])
    q_head = hk.Linear(n_options)
    policy_head = hk.Sequential([
        hk.Linear(action_spec.num_values * n_options),
        split_per_option,
        Partial(jax.nn.softmax, axis=-1),
    ])

    embedding = torso(flat_inputs)
    logits = policy_head(embedding)
    beta = beta_head(embedding)
    interest = interest_head(embedding)

    # Interest-weighted policy over options.  keepdims=True keeps the
    # normaliser broadcastable against a batched (batch, n_options) array.
    pi_omega = policy_over_options_head(embedding) * interest
    pi_omega = pi_omega / jnp.sum(pi_omega, axis=-1, keepdims=True)

    q = q_head(embedding)
    return logits, beta, q, pi_omega
def losses_gmres_flax(preconditioner, model, n, new_matvec, x0, b):
    """Return the scaled residual norm of a preconditioned GMRES solve.

    ``gmres.gmres`` is called with the operator ``new_matvec``, right-hand
    side ``b``, initial guess ``x0``, ``n=n`` and the preconditioner
    ``preconditioner`` with ``model`` bound as its first argument.  The
    result is ``||A(x) - b|| * 1e7``.
    """
    operator = Partial(new_matvec)
    precond = Partial(preconditioner, model)
    approx = gmres.gmres(operator, b, x0, n=n, M=precond)
    residual = operator(approx) - b
    return np.linalg.norm(residual) * 10000000
def mat_vec_factory(forward_fn, params, model_state, samples):
    """Build a ``mat_vec`` partial bound to the linearisation of ``forward_fn``.

    ``forward_fn`` is linearised with respect to its ``"params"`` entry at
    ``params`` (with ``model_state`` and ``samples`` held fixed) and the
    resulting JVP function is bound as the first argument of ``mat_vec``.
    """

    def apply_with_params(W):
        # Map a parameter pytree to the network outputs on `samples`.
        variables = {"params": W, **model_state}
        return forward_fn(variables, samples)

    linearized = jax.linearize(apply_with_params, params)
    jvp_fn = linearized[1]  # the primal output is not needed
    return Partial(mat_vec, jvp_fn)
def losses_lin_supervised(lin_op, mesh, params, n, x0, b, solution):
    """Relative error against ``solution`` after ``n`` corrected V-cycle steps.

    Each step computes ``v = V_Cycle(x)`` (via ``MG._V_Cycle`` with
    right-hand side ``b`` reshaped to ``mesh.shape``) and returns
    ``v + lin_op(params, v - x)``.  The step function is driven by
    ``_lin_iter`` through ``jax.lax.scan`` over ``np.arange(n)``.

    Returns:
        ``||x_opt - solution|| * 1e7 / ||solution||``.
    """
    rhs = b.reshape(mesh.shape)

    def iterator(x):
        # Evaluate the V-cycle once and reuse it for both terms; the
        # original expression ran the full cycle twice on the same input.
        smoothed = MG._V_Cycle(x.reshape(mesh.shape), rhs, 3, 'R').ravel()
        return smoothed + lin_op(params, smoothed - x)

    A = Partial(iterator)
    (_, x_opt), _ = jax.lax.scan(_lin_iter, (A, x0), np.arange(n))
    return np.linalg.norm(x_opt - solution) * 10000000 / np.linalg.norm(solution)
def losses_lin(lin_op, mesh, params, n, k, x0, b):
    """Relative Helmholtz residual after ``n`` corrected V-cycle steps.

    Each step computes ``v = V_Cycle(x)`` (via ``MG._V_Cycle`` with
    right-hand side ``b`` reshaped to ``mesh.shape`` and ``k=k``) and
    returns ``v + lin_op(params, v - x)``.  The step function is driven
    by ``_lin_iter`` through ``jax.lax.scan`` over ``np.arange(n)``.

    Returns:
        ``||H(x_opt) - b|| * 1e7 / ||b||`` where ``H`` is
        ``mesh.matvec_helmholtz`` called with ``k`` and ``1.0`` as its
        first two arguments and the rectangular mask functions from
        ``equations``.
    """
    rhs = b.reshape(mesh.shape)

    def iterator(x):
        # Evaluate the V-cycle once and reuse it for both terms; the
        # original expression ran the full cycle twice on the same input.
        smoothed = MG._V_Cycle(x.reshape(mesh.shape), rhs, 3, 'R', k=k).ravel()
        return smoothed + lin_op(params, smoothed - x)

    A = Partial(iterator)
    (_, x_opt), _ = jax.lax.scan(_lin_iter, (A, x0), np.arange(n))

    residual = mesh.matvec_helmholtz(
        k, 1.0, equations.make_mask, equations.make_mask_dual, x_opt) - b
    return np.linalg.norm(residual) * 10000000 / np.linalg.norm(b)
def vjp(traceable, primals, has_aux=False, reduce_axes=()):
    """Linearize ``traceable`` at ``primals`` and return a VJP callable.

    Returns ``(out_primals, vjp_)`` or, when ``has_aux`` is true,
    ``(out_primals, vjp_, aux)``.  ``vjp_`` takes the output cotangents
    and runs ``backward_pass`` over the linearized jaxpr.
    """
    aux = None
    if has_aux:
        out_primals, pvals, jaxpr, consts, aux = linearize(
            traceable, *primals, has_aux=True)
    else:
        out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)

    def unbound_vjp(pvals, jaxpr, consts, *cts):
        cts = tuple(map(ignore_consts, cts, pvals))
        # The backward pass only needs placeholders for the jaxpr inputs.
        dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
        arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)
        return map(instantiate_zeros, arg_cts)

    # Wrap in Partial so vjp_ is a PyTree and can be passed from the forward
    # to the backward pass of a custom VJP; only `consts` is a pytree leaf,
    # pvals and jaxpr are bound statically.
    vjp_ = Partial(partial(unbound_vjp, pvals, jaxpr), consts)

    if has_aux:
        return out_primals, vjp_, aux
    return out_primals, vjp_
def vjp_chunked(
    fun,
    *primals,
    has_aux=False,
    chunk_argnums=(),
    chunk_size=None,
    nondiff_argnums=(),
    return_forward=False,
    conjugate=False,
):
    """Calculate the vjp in small chunks.

    Intended for a function where the leading dimension of the output only
    depends on the leading dimension of some of the arguments.

    Args:
        fun: Function to be differentiated. It must accept chunks of size
            ``chunk_size`` of the primals in ``chunk_argnums``.
        primals: A sequence of primal values at which the Jacobian of
            ``fun`` should be evaluated.
        has_aux: Optional, bool. Only False is implemented on the chunked
            path, which raises ``NotImplementedError`` otherwise. When no
            chunking takes place the flag is forwarded to ``nkvjp``.
            Indicates whether ``fun`` returns a pair where the first
            element is considered the output of the mathematical function
            to be differentiated and the second element is auxiliary data.
            Default False.
        chunk_argnums: an integer or tuple of integers indicating the
            primals which should be chunked. The leading dimension of each
            of the primals indicated must be the same as the output of
            ``fun``.
        chunk_size: an integer indicating the size of the chunks over which
            the vjp is computed. It must be an integer divisor of the
            leading dimension of the primals specified in
            ``chunk_argnums``. If it is ``None``, if ``chunk_argnums`` is
            empty, or if it is greater than or equal to that leading
            dimension, no chunking is performed.
        nondiff_argnums: an integer or tuple of integers indicating the
            primals which should not be differentiated with. Specifying the
            arguments which are not needed should increase performance.
        return_forward: whether the returned function should also return
            the output of the forward pass.
        conjugate: forwarded unchanged to ``nkvjp`` (un-chunked path) or to
            the chunked implementation.

    Returns:
        a function corresponding to the ``vjp_fun`` returned by an
        equivalent ``jax.vjp(fun, *primals)[1]`` call which computes the
        vjp in chunks (recomputing the forward pass every time on
        subsequent calls). If ``return_forward=True`` the ``vjp_fun``
        returned returns a tuple containing the output of the forward pass
        and the vjp.

    Raises:
        ValueError: if ``chunk_argnums`` does not index ``primals`` or the
            chunked arguments have different leading dimensions.
        NotImplementedError: if ``has_aux`` is true on the chunked path.

    Example:
        >>> import jax
        >>> from netket.jax import vjp_chunked
        >>> from functools import partial
        >>>
        >>> @partial(jax.vmap, in_axes=(None, 0))
        ... def f(p, x):
        ...     return jax.lax.log(p.dot(jax.lax.sin(x)))
        >>>
        >>> k = jax.random.split(jax.random.PRNGKey(123), 4)
        >>> p = jax.random.uniform(k[0], shape=(8,))
        >>> v = jax.random.uniform(k[1], shape=(8,))
        >>> X = jax.random.uniform(k[2], shape=(1024,8))
        >>> w = jax.random.uniform(k[3], shape=(1024,))
        >>>
        >>> vjp_fun_chunked = vjp_chunked(f, p, X, chunk_argnums=(1,), chunk_size=32, nondiff_argnums=1)
        >>> vjp_fun = jax.vjp(f, p, X)[1]
        >>>
        >>> vjp_fun_chunked(w)
        (DeviceArray([106.76358917, 113.3123931 , 101.95475061, 104.11138622,
                      111.95590131, 109.17531467, 108.97138052, 106.89249739], dtype=float64),)
        >>> vjp_fun(w)[:1]
        (DeviceArray([106.76358917, 113.3123931 , 101.95475061, 104.11138622,
                      111.95590131, 109.17531467, 108.97138052, 106.89249739], dtype=float64),)
    """
    # `primals` is collected through *args and is therefore always a tuple,
    # so no container-type validation is needed here.

    if isinstance(chunk_argnums, int):
        chunk_argnums = (chunk_argnums,)

    if not all(0 <= i < len(primals) for i in chunk_argnums):
        raise ValueError(
            "chunk_argnums must index primals. Got chunk_argnums={} but len(primals)={}"
            .format(chunk_argnums, len(primals)))
    # TODO also check they are unique?

    if isinstance(nondiff_argnums, int):
        nondiff_argnums = (nondiff_argnums,)

    # Nothing to chunk over (also covers an empty list, not just `()`).
    if len(chunk_argnums) == 0:
        chunk_size = None

    if chunk_size is not None:
        # jax.tree_util.tree_leaves is the stable spelling; the top-level
        # jax.tree_leaves alias is not available in recent JAX releases.
        tree_leaves = jax.tree_util.tree_leaves
        n_elements = tree_leaves(primals[chunk_argnums[0]])[0].shape[0]

        # check that they are all the same size
        chunk_leaves = tree_leaves([primals[i] for i in chunk_argnums])
        if not all(leaf.shape[0] == n_elements for leaf in chunk_leaves):
            raise ValueError(
                "The chunked arguments have inconsistent leading array dimensions"
            )

        # A single chunk would cover everything: fall back to the plain vjp.
        if chunk_size >= n_elements:
            chunk_size = None

    if chunk_size is None:
        # NOTE(review): has_aux is forwarded here but the result is unpacked
        # into two values; presumably this fails if nkvjp also returns aux.
        y, vjp_fun = nkvjp(fun, *primals, conjugate=conjugate, has_aux=has_aux)

        if return_forward:

            def __vjp_fun(y, vjp_fun, cotangents):
                res = vjp_fun(cotangents)
                res = _trash_tuple_elements(res, nondiff_argnums)
                return y, res

            return Partial(__vjp_fun, y, vjp_fun)

        def __vjp_fun(vjp_fun, cotangents):
            res = vjp_fun(cotangents)
            res = _trash_tuple_elements(res, nondiff_argnums)
            return res

        return Partial(__vjp_fun, vjp_fun)

    if has_aux:
        raise NotImplementedError
        # fun = compose(lambda x_aux: x_aux[0], fun)
        # TODO in principle we could also return the aux of the fwd pass for every chunk...

    _vjp_fun = _value_and_vjp_fun_chunked if return_forward else _vjp_fun_chunked

    # Static configuration is bound with functools.partial; only `primals`
    # is bound through Partial.
    return Partial(
        partial(
            _vjp_fun,
            fun,
            chunk_argnums=chunk_argnums,
            nondiff_argnums=nondiff_argnums,
            chunk_size=chunk_size,
            conjugate=conjugate,
        ),
        primals,
    )
def losses_gmres(preconditioner, params, n, new_matvec, x0, b):
    """Return the ``gmres.gmres_training`` loss scaled by ``1e7``.

    The operator is ``new_matvec`` and the preconditioner is
    ``preconditioner`` with ``params`` bound as its first argument; both
    are wrapped in ``Partial`` before being handed to the solver.
    """
    operator = Partial(new_matvec)
    precond = Partial(preconditioner, params)
    training_loss = gmres.gmres_training(operator, b, x0, n=n, M=precond)
    return training_loss * 10000000
def losses_gmres_inf(preconditioner, params, n, new_matvec, x0, b):
    """Return the scaled infinity-norm residual of a preconditioned GMRES solve.

    Solves with ``gmres.gmres`` (operator ``new_matvec``, preconditioner
    ``preconditioner`` bound to ``params``) and returns
    ``||A(x) - b||_inf * 1000 * x.shape[0]``.
    """
    operator = Partial(new_matvec)
    precond = Partial(preconditioner, params)
    x_opt = gmres.gmres(operator, b, x0, n=n, M=precond)
    residual = operator(x_opt) - b
    sup_norm = np.linalg.norm(residual, np.inf)
    return sup_norm * 1000 * x_opt.shape[0]
def modelSelection(
    X, y, models=None, prior="normal", family="logistic", groups=None, method="post"
):
    """Bayesian model selection for generalized linear models.

    ``modelSelection`` enumerates all the models or iterates over ``models``
    to perform Bayesian model selection for logistic or poisson regression
    using local or non-local priors.

    Parameters
    ----------
    X : array_like
        Design matrix with one column per candidate variable.
    y : array_like
        Observations
    models : array_like, optional
        One row per model, marking the variables included in that model.
        If ``None``, enumeration will consider all non-empty subsets of
        the columns of ``X`` `ignoring groups`.
    prior : {mom, normal}, optional
    family : {logistic, poisson}, optional
    groups : array_like, optional
        Defaults to ``jnp.arange(p)``, i.e. every variable in its own group.
    method : {post, lik}, optional

    Returns
    -------
    models : array_like
        The ``models`` argument, returned unchanged (``None`` when the
        models were enumerated internally).
    modelprobs : array_like
        One score per model, as returned by the ALA marginal-likelihood
        routine (plus the MOM correction when ``prior == "mom"``). The
        values are not normalised here. When enumerating, the array has
        ``2 ** p`` entries and the last one is never written, because the
        empty model is not evaluated.

    Raises
    ------
    ValueError
        If ``prior``, ``family`` or ``method`` is not recognized.
    """
    # Validate every option up front so that a typo fails before any of the
    # numerical work below is started.
    if prior not in ("normal", "mom"):
        raise ValueError("prior not recognized")
    if family not in ("poisson", "logistic"):
        raise ValueError("family not recognized")
    if method not in ("post", "lik"):
        raise ValueError("method not recognized")

    _, p = X.shape
    if models is None:
        n_models = 2 ** p
        model_i = 0  # write offset into modelprobs during enumeration
    else:
        n_models = models.shape[0]

    if groups is None:
        groups = jnp.arange(p)

    W, Winv = get_group_zellner(groups, X, prior == "mom")
    p_j = get_p_j(groups)
    modelprobs = jnp.empty(n_models)
    fact_y = jnp.sum(gammaln(y + 1))
    ytX = jnp.dot(y, X)
    if prior == "mom":
        XtX = jnp.dot(X.T, X)

    ## define vmapped functions ##
    # likelihood helpers
    if family == "poisson":
        _loglik = lambda b, ytx, x: poisson_log_lik(b, ytx, x, fact_y=fact_y)
    else:
        _loglik = logistic_log_lik

    if method == "post":
        _logpost = lambda b, ytx, x, w: (_loglik(b, ytx, x) + normalprior(b, 1, 1, w))
        vmaparg = (0, 0, 0, 0)
        logpost = jit(vmap(_logpost, vmaparg, 0))
        glogpost = jit(vmap(grad(_logpost, argnums=0), vmaparg, 0))
        hlogpost = jit(vmap(hessian(_logpost, argnums=0), vmaparg, 0))
        jitted_ala = jit(
            Partial(
                marghood_ala_post, logpost=logpost, glogpost=glogpost, hlogpost=hlogpost
            )
        )
    else:
        logpr = jit(vmap(lambda b, w: normalprior(b, 1, 1, w), (0, 0), 0))
        loglik = jit(vmap(_loglik, (0, 0, 0), 0))
        gloglik = jit(vmap(grad(_loglik, argnums=0), (0, 0, 0), 0))
        hloglik = jit(vmap(hessian(_loglik, argnums=0), (0, 0, 0), 0))
        jitted_ala = jit(
            Partial(
                marghood_ala_lik,
                logl=loglik,
                glogl=gloglik,
                hlogl=hloglik,
                logpr=logpr,
            )
        )

    ## loop, we need equal shapes for vmap, hence calculations are vectorized ##
    ## on a number of selected variables basis ##
    for n_vars in range(1, p + 1):
        if models is None:
            models_iter = jnp.array(list(combinations(range(p), n_vars)))
            n_iter = models_iter.shape[0]
            model_mask = jnp.full((n_models,), False)
            # The block for this model size starts at model_i and holds
            # n_iter entries, so the slice must end at model_i + n_iter.
            model_mask = model_mask.at[model_i : model_i + n_iter].set(True)
            model_i += n_iter
        else:
            model_mask = models.sum(axis=1) == n_vars
            if model_mask.sum() == 0:
                continue
            # Column indices of the selected variables, row by row. Every
            # selected row has exactly n_vars entries set, so the row-major
            # list of column indices reshapes to (n_selected, n_vars).
            models_iter = jnp.nonzero(models[model_mask, :])[1].reshape(
                (-1, n_vars)
            )

        b0 = jnp.zeros(models_iter.shape)
        X_iter = apply_mask_2d(X, models_iter)
        ytX_iter = apply_mask_1d(ytX, models_iter)
        W_iter = apply_mask_matrix(W, models_iter)

        if method == "post":
            args = [ytX_iter, X_iter, W_iter]
            margs = jitted_ala(b0=b0, args=args)
        else:
            argspr = (W_iter,)
            argsl = (ytX_iter, X_iter)
            margs = jitted_ala(b0=b0, argsl=argsl, argspr=argspr)

        if prior == "mom":
            p_j_iter = apply_mask_1d(p_j, models_iter)
            Winv_iter = apply_mask_matrix(Winv, models_iter)
            XtX_iter = apply_mask_matrix(XtX, models_iter)
            margs += vmap(gmomprior_correction, (None, 0, 0, 0, 0), 0)(
                1, Winv_iter, p_j_iter, XtX_iter, ytX_iter
            )

        modelprobs = modelprobs.at[model_mask].set(margs)

    return models, modelprobs
def _V_Cycle_GMRES(x, f, num_cycle, shapebc='R', k=0, aspect_ratio=1.0):
    """One multigrid V-cycle that uses 5 GMRES iterations as the smoother.

    See https://en.wikipedia.org/wiki/Multigrid_method.

    Args:
        x: Current iterate on the grid; boundary values are not included.
        f: Right-hand side on the same grid as ``x``.
        num_cycle: Current recursion level. At level 3 the coarse problem
            is smoothed directly with GMRES instead of recursing.
        shapebc: ``'R'`` or ``'L'``; selects which pair of mask functions
            from ``equations`` is used.
        k: Forwarded to ``equations.helmholtz``.
        aspect_ratio: Forwarded to the mask functions and to
            ``equations.helmholtz``.

    Returns:
        The updated iterate, with the same shape as ``x``.

    Raises:
        ValueError: If ``shapebc`` is neither ``'R'`` nor ``'L'``.
    """
    if shapebc == 'R':
        mask_f = equations.make_mask
        mask_f_dual = equations.make_mask_dual
    elif shapebc == 'L':
        mask_f = equations.make_mask_L
        mask_f_dual = equations.make_mask_L_dual
    else:
        # Previously fell through with mask_f unbound and failed later
        # with an unrelated error.
        raise ValueError(
            f"unknown shapebc {shapebc!r}; expected 'R' or 'L'")

    # bc are not included
    h = 1.0 / (x.shape[0] + 1)

    def helmholtz(u, step):
        # Helmholtz operator on a grid with spacing `step`.
        return equations.helmholtz(u, k, step=step, aspect_ratio=aspect_ratio,
                                   mask_f=mask_f, mask_f_dual=mask_f_dual)

    # Flat matvec on the fine grid, shared by pre- and post-smoothing.
    A = Partial(lambda z: helmholtz(z.reshape(x.shape), h).ravel())

    # Pre-Smoothing
    r = f - helmholtz(x, h)
    x = x + gmres.gmres(A, r.ravel(), n=5).reshape(x.shape)

    # Compute Residual Errors
    # no bc here because we assume they are 0
    r = f - helmholtz(x, h)

    # Restriction from h to 2h
    rhs = restriction(r)
    eps = np.zeros(rhs.shape)
    mask = mask_f(eps.shape[0], aspect_ratio)
    eps = np.multiply(eps, mask)

    # stop recursion after 3 cycles
    if num_cycle == 3:
        r = rhs - helmholtz(eps, 2 * h)
        A1 = Partial(lambda z: helmholtz(z.reshape(eps.shape), 2 * h).ravel())
        eps = eps + gmres.gmres(A1, r.ravel(), n=5).reshape(eps.shape)
    else:
        # NOTE(review): this recurses into _V_Cycle, not _V_Cycle_GMRES, so
        # coarser levels use whatever smoother _V_Cycle uses - confirm that
        # this is intended.
        eps = _V_Cycle(eps, rhs, num_cycle + 1, shapebc, k=k,
                       aspect_ratio=aspect_ratio)

    # Prolongation and Correction
    x = x + prolongation(eps)
    mask = mask_f(x.shape[0], aspect_ratio)
    x = np.multiply(x, mask)

    # Post-Smoothing
    r = f - helmholtz(x, h)
    x = x + gmres.gmres(A, r.ravel(), n=5).reshape(x.shape)
    return x
def __get__(self, instance: Any, owner: Any = None) -> Callable[..., R_co]:
    """Descriptor hook: bind ``self`` to ``instance`` on attribute access.

    Access through the class (``instance is None``) returns ``self``
    unchanged; access through an instance returns ``Partial(self, instance)``.
    """
    if instance is not None:
        # Partial application that plays the role of a bound method.
        return Partial(self, instance)  # type: ignore[no-untyped-call]
    return self
def mat_vec_chunked_factory(forward_fn, params, model_state, samples):
    """Build a ``matvec_chunked_transposable`` partial bound to the model.

    The forward function is bound statically with ``functools.partial``,
    while ``params`` and ``samples`` are bound through ``Partial``.
    """

    def fun(W, samples):
        # Evaluate the model with parameters W and the fixed model_state.
        variables = {"params": W, **model_state}
        return forward_fn(variables, samples)

    matvec = partial(matvec_chunked_transposable, fun)
    return Partial(matvec, params, samples)