Code example #1
def network(inputs: jnp.ndarray) -> Tuple[Logits, Beta, Q, Logits_Omega]:
    flat_inputs = hk.Flatten()(inputs)  # Flatten inputs to (batch, features).
    # Shared state torso: two hidden layers of 64 units with ReLU in between.
    torso = hk.nets.MLP([64, 64])
    # Option heads. Note hk.Sequential expects a single list of callables.
    policy_over_options_head = hk.Sequential(
        [hk.Linear(n_options), Partial(jax.nn.softmax, axis=-1)])
    beta_head = hk.Sequential([hk.Linear(n_options), jax.nn.sigmoid])
    interest_head = hk.Sequential([hk.Linear(n_options), jax.nn.sigmoid])
    q_head = hk.Linear(n_options)
    policy_head = hk.Sequential([
        hk.Linear(action_spec.num_values * n_options),
        Partial(jnp.reshape,
                newshape=(-1, n_options, action_spec.num_values)),
        Partial(jax.nn.softmax, axis=-1),
    ])
    embedding = torso(flat_inputs)
    logits = policy_head(embedding)
    beta = beta_head(embedding)
    interest = interest_head(embedding)
    pi_omega = policy_over_options_head(embedding)
    # Interest-weighted policy over options, renormalized per state
    # (keepdims=True keeps the sum broadcastable against pi_omega).
    pi_omega = pi_omega * interest
    pi_omega = pi_omega / jnp.sum(pi_omega, axis=-1, keepdims=True)
    q = q_head(embedding)
    return logits, beta, q, pi_omega
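For context, a minimal sketch of how such a Haiku network function is typically transformed and applied; the option count, observation shape, and the action-spec stand-in below are hypothetical placeholders, not values from the original project:

import haiku as hk
import jax
import jax.numpy as jnp
from jax.tree_util import Partial

n_options = 4                 # hypothetical number of options

class _Spec:                  # hypothetical stand-in for a dm_env action spec
    num_values = 3

action_spec = _Spec()

net = hk.transform(network)   # `network` as defined above
params = net.init(jax.random.PRNGKey(0), jnp.zeros((1, 84)))
logits, beta, q, pi_omega = net.apply(params, None, jnp.zeros((1, 84)))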
Code example #2
def losses_gmres_flax(preconditioner, model, n, new_matvec, x0, b):
    A = Partial(new_matvec)
    M = Partial(preconditioner, model)
    x = gmres.gmres(A, b, x0, n=n, M=M)
    # Residual norm of the preconditioned solve, rescaled by a large constant.
    return np.linalg.norm(A(x) - b) * 10000000
Code example #3
def mat_vec_factory(forward_fn, params, model_state, samples):
    # "forward function" that maps params to outputs
    def fun(W):
        return forward_fn({"params": W, **model_state}, samples)

    _, jvp_fn = jax.linearize(fun, params)
    return Partial(mat_vec, jvp_fn)
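The `mat_vec` helper bound here is defined elsewhere in the project. As a hedged sketch of the pattern (not netket's actual helper), note that the linearized `jvp_fn` is a linear function, so it can be transposed to build a Gauss-Newton-style matrix-vector product:

import jax

def mat_vec(jvp_fn, vec):
    # Hypothetical helper: J v followed by J^T (J v), i.e. (J^T J) v.
    w = jvp_fn(vec)                             # J v
    vjp_fn = jax.linear_transpose(jvp_fn, vec)  # transpose of the linear map
    (out,) = vjp_fn(w)                          # J^T J v
    return out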
Code example #4
def losses_lin_supervised(lin_op, mesh, params, n, x0, b, solution):
    multigrid = lambda x: MG._V_Cycle(x, b.reshape(mesh.shape), 3, 'R')

    def iterator(x):
        # One multigrid step plus a learned correction on its update.
        mg = multigrid(x.reshape(mesh.shape)).ravel()
        return mg + lin_op(params, mg - x)

    # Partial makes the callable a pytree, so it can ride in lax.scan's carry.
    A = Partial(iterator)
    (_, x_opt), _ = jax.lax.scan(_lin_iter, (A, x0), np.arange(n))
    # Relative error against the reference solution, rescaled by a large constant.
    return np.linalg.norm(x_opt - solution) * 10000000 / np.linalg.norm(solution)
Code example #5
def losses_lin(lin_op, mesh, params, n, k, x0, b):
    multigrid = lambda x: MG._V_Cycle(x, b.reshape(mesh.shape), 3, 'R', k=k)

    def iterator(x):
        # One multigrid step plus a learned correction on its update.
        mg = multigrid(x.reshape(mesh.shape)).ravel()
        return mg + lin_op(params, mg - x)

    # Partial makes the callable a pytree, so it can ride in lax.scan's carry.
    A = Partial(iterator)
    (_, x_opt), _ = jax.lax.scan(_lin_iter, (A, x0), np.arange(n))
    new_matvec = lambda x: mesh.matvec_helmholtz(k, 1.0, equations.make_mask,
                                                 equations.make_mask_dual, x)
    # Relative residual of the Helmholtz system, rescaled by a large constant.
    return np.linalg.norm(new_matvec(x_opt) - b) * 10000000 / np.linalg.norm(b)
Code example #6
File: ad.py Project: jbampton/jax
def vjp(traceable, primals, has_aux=False, reduce_axes=()):
  if not has_aux:
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  else:
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

  def unbound_vjp(pvals, jaxpr, consts, *cts):
    cts = tuple(map(ignore_consts, cts, pvals))
    dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
    arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)
    return map(instantiate_zeros, arg_cts)

  # Ensure that vjp_ is a PyTree so that we can pass it from the forward to the backward
  # pass in a custom VJP.
  vjp_ = Partial(partial(unbound_vjp, pvals, jaxpr), consts)
  if not has_aux:
    return out_primals, vjp_
  else:
    return out_primals, vjp_, aux
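The comment above is the key reason `Partial` appears throughout these examples: unlike `functools.partial`, `jax.tree_util.Partial` is registered as a pytree, so the bound arguments become leaves and the resulting closure can cross jit/transform boundaries as an ordinary argument. A small self-contained demonstration:

import jax
import jax.numpy as jnp
from jax.tree_util import Partial

def add(a, b):
    return a + b

add_two = Partial(add, jnp.asarray(2.0))   # binds a=2.0 as a pytree leaf
print(jax.tree_util.tree_leaves(add_two))  # one leaf: the bound argument

@jax.jit
def apply_twice(f, x):                     # f is traced as a pytree argument
    return f(f(x))

print(apply_twice(add_two, jnp.asarray(1.0)))  # 2.0 + (2.0 + 1.0) = 5.0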
Code example #7
File: _vjp_chunked.py Project: yannra/netket
def vjp_chunked(
        fun,
        *primals,
        has_aux=False,
        chunk_argnums=(),
        chunk_size=None,
        nondiff_argnums=(),
        return_forward=False,
        conjugate=False,
):
    """calculate the vjp in small chunks for a function where the leading dimension of the output only depends on the leading dimension of some of the arguments

    Args:
        fun: Function to be differentiated. It must accept chunks of size chunk_size of the primals in chunk_argnums.
        primals:  A sequence of primal values at which the Jacobian of ``fun`` should be evaluated.
        has_aux: Optional, bool. Only False is implemented. Indicates whether ``fun`` returns a pair where the
           first element is considered the output of the mathematical function to be
           differentiated and the second element is auxiliary data. Default False.
        chunk_argnums: an integer or tuple of integers indicating the primals which should be chunked.
            The leading dimension of each of the indicated primals must match the leading dimension of the output of ``fun``.
        chunk_size: an integer indicating the size of the chunks over which the vjp is computed.
            It must be an integer divisor of the leading dimension of the primals specified in ``chunk_argnums``.
        nondiff_argnums: an integer or tuple of integers indicating the primals with respect to which ``fun`` should not be differentiated.
            Excluding arguments that are not needed should improve performance.
        return_forward: whether the returned function should also return the output of the forward pass

    Returns:
        a function corresponding to the vjp_fun returned by an equivalent ``jax.vjp(fun, *primals)[1]`` call,
        which computes the vjp in chunks (recomputing the forward pass on every subsequent call).
        If ``return_forward=True``, the returned function returns a tuple containing the output of the forward pass and the vjp.

    Example:
        >>> import jax
        >>> from netket.jax import vjp_chunked
        >>> from functools import partial
        >>>
        >>> @partial(jax.vmap, in_axes=(None, 0))
        ... def f(p, x):
        ...     return jax.lax.log(p.dot(jax.lax.sin(x)))
        >>>
        >>> k = jax.random.split(jax.random.PRNGKey(123), 4)
        >>> p = jax.random.uniform(k[0], shape=(8,))
        >>> v = jax.random.uniform(k[1], shape=(8,))
        >>> X = jax.random.uniform(k[2], shape=(1024,8))
        >>> w = jax.random.uniform(k[3], shape=(1024,))
        >>>
        >>> vjp_fun_chunked = vjp_chunked(f, p, X, chunk_argnums=(1,), chunk_size=32, nondiff_argnums=1)
        >>> vjp_fun = jax.vjp(f, p, X)[1]
        >>>
        >>> vjp_fun_chunked(w)
        (DeviceArray([106.76358917, 113.3123931 , 101.95475061, 104.11138622,
                      111.95590131, 109.17531467, 108.97138052, 106.89249739], dtype=float64),)
        >>> vjp_fun(w)[:1]
        (DeviceArray([106.76358917, 113.3123931 , 101.95475061, 104.11138622,
                      111.95590131, 109.17531467, 108.97138052, 106.89249739], dtype=float64),)
    """

    if not isinstance(primals, (tuple, list)):
        raise TypeError(
            "primal arguments to vjp_chunked must be a tuple or list; "
            f"found {type(primals).__name__}.")

    if isinstance(chunk_argnums, int):
        chunk_argnums = (chunk_argnums,)

    if not all(map(lambda x: (0 <= x) and (x < len(primals)), chunk_argnums)):
        raise ValueError(
            "chunk_argnums must index primals. Got chunk_argnums={} but len(primals)={}"
            .format(chunk_argnums, len(primals)))
        # TODO also check they are unique?

    if isinstance(nondiff_argnums, int):
        nondiff_argnums = (nondiff_argnums,)

    if chunk_argnums == ():
        chunk_size = None

    if chunk_size is not None:
        n_elements = jax.tree_leaves(primals[chunk_argnums[0]])[0].shape[0]

        # check that they are all the same size
        chunk_leaves = jax.tree_leaves([primals[i] for i in chunk_argnums])
        if not all(map(lambda x: x.shape[0] == n_elements, chunk_leaves)):
            raise ValueError(
                "The chunked arguments have inconsistent leading array dimensions"
            )

        if chunk_size >= n_elements:
            chunk_size = None

    if chunk_size is None:

        y, vjp_fun = nkvjp(fun, *primals, conjugate=conjugate, has_aux=has_aux)

        if return_forward:

            def __vjp_fun(y, vjp_fun, cotangents):
                res = vjp_fun(cotangents)
                res = _trash_tuple_elements(res, nondiff_argnums)
                return y, res

            return Partial(__vjp_fun, y, vjp_fun)
        else:

            def __vjp_fun(vjp_fun, cotangents):
                res = vjp_fun(cotangents)
                res = _trash_tuple_elements(res, nondiff_argnums)
                return res

            return Partial(__vjp_fun, vjp_fun)
    if has_aux:
        raise NotImplementedError
        # fun = compose(lambda x_aux: x_aux[0], fun)
        # TODO in principle we could also return the aux of the fwd pass for every chunk...

    _vjp_fun = _value_and_vjp_fun_chunked if return_forward else _vjp_fun_chunked

    return Partial(
        partial(
            _vjp_fun,
            fun,
            chunk_argnums=chunk_argnums,
            nondiff_argnums=nondiff_argnums,
            chunk_size=chunk_size,
            conjugate=conjugate,
        ),
        primals,
    )
Code example #8
def losses_gmres(preconditioner, params, n, new_matvec, x0, b):
    A = Partial(new_matvec)
    M = Partial(preconditioner, params)
    loss = gmres.gmres_training(A, b, x0, n=n, M=M)
    # Large constant rescales the (typically tiny) residual-based loss.
    return loss * 10000000
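Because `A` and `M` are `Partial` pytrees, the preconditioner parameters bound into `M` stay visible to JAX's tracing machinery, so the loss can be differentiated with respect to them. A hedged sketch of a training step under that assumption (and assuming `gmres.gmres_training` is itself differentiable):

import jax

# Hypothetical usage: differentiate the loss w.r.t. params (argnums=1).
loss, grads = jax.value_and_grad(losses_gmres, argnums=1)(
    preconditioner, params, n, new_matvec, x0, b)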
Code example #9
def losses_gmres_inf(preconditioner, params, n, new_matvec, x0, b):
    A = Partial(new_matvec)
    M = Partial(preconditioner, params)
    x_opt = gmres.gmres(A, b, x0, n=n, M=M)
    # Infinity-norm residual, rescaled by the problem size.
    return np.linalg.norm(A(x_opt) - b, np.inf) * 1000 * x_opt.shape[0]
Code example #10
File: modSel.py Project: OriolAbril/mombf.py
def modelSelection(
    X, y, models=None, prior="normal", family="logistic", groups=None, method="post"
):
    """Bayesian model selection for generalized linear models.

    ``modelSelection`` enumerates all the models or iterates over ``models`` to
    perform Bayesian model selection for linear, logistic or poisson regression
    using local or non-local priors.

    Parameters
    ----------
    X : array_like
        Design matrix
    y : array_like
        Observations
    models : array_like, optional
        If ``None``, enumeration will consider all possible models `ignoring groups`.
    prior : {mom, normal}, optional
    family : {logistic, poisson}, optional
    groups : array_like, optional
    method : {post, lik}, optional

    Returns
    -------
    models : array_like
        Models considered in enumeration.
    modelprobs : array_like
        Posterior probabilities assigned to each model.
    """
    if prior not in ("normal", "mom"):
        raise ValueError("prior not recognized")
    _, p = X.shape
    if models is None:
        n_models = 2 ** p
        model_i = 0
    else:
        n_models = models.shape[0]
    if groups is None:
        groups = jnp.arange(p)
    W, Winv = get_group_zellner(groups, X, prior == "mom")
    p_j = get_p_j(groups)
    modelprobs = jnp.empty(n_models)
    fact_y = jnp.sum(gammaln(y + 1))
    ytX = jnp.dot(y, X)
    if prior == "mom":
        XtX = jnp.dot(X.T, X)
    ## define vmapped functions ##
    # likelihood helpers
    if family == "poisson":
        _loglik = lambda b, ytx, x: poisson_log_lik(b, ytx, x, fact_y=fact_y)
    elif family == "logistic":
        _loglik = logistic_log_lik
    else:
        raise ValueError("family not recognized")
    if method == "post":
        _logpost = lambda b, ytx, x, w: (_loglik(b, ytx, x) + normalprior(b, 1, 1, w))
        vmaparg = (0, 0, 0, 0)
        logpost = jit(vmap(_logpost, vmaparg, 0))
        glogpost = jit(vmap(grad(_logpost, argnums=0), vmaparg, 0))
        hlogpost = jit(vmap(hessian(_logpost, argnums=0), vmaparg, 0))
        jitted_ala = jit(
            Partial(
                marghood_ala_post, logpost=logpost, glogpost=glogpost, hlogpost=hlogpost
            )
        )
    elif method == "lik":
        logpr = jit(vmap(lambda b, w: normalprior(b, 1, 1, w), (0, 0), 0))
        loglik = jit(vmap(_loglik, (0, 0, 0), 0))
        gloglik = jit(vmap(grad(_loglik, argnums=0), (0, 0, 0), 0))
        hloglik = jit(vmap(hessian(_loglik, argnums=0), (0, 0, 0), 0))
        jitted_ala = jit(
            Partial(
                marghood_ala_lik,
                logl=loglik,
                glogl=gloglik,
                hlogl=hloglik,
                logpr=logpr,
            )
        )
    else:
        raise ValueError("method not recognized")
    ## Loop: vmap needs equal shapes, so the calculations are vectorized ##
    ## over models with the same number of selected variables. ##
    for n_vars in range(1, p + 1):
        if models is None:
            models_iter = jnp.array(list(combinations(jnp.arange(p), n_vars)))
            model_mask = jnp.full((n_models,), False)
            model_mask = model_mask.at[model_i : model_i + models_iter.shape[0]].set(True)
            model_i += models_iter.shape[0]
        else:
            model_mask = models.sum(axis=1) == n_vars
            if model_mask.sum() == 0:
                continue
            models_iter = (
                jnp.arange(p)
                .reshape((1, -1))[models[model_mask, :]]
                .reshape((-1, n_vars))
            )
        b0 = jnp.zeros(models_iter.shape)
        X_iter = apply_mask_2d(X, models_iter)
        ytX_iter = apply_mask_1d(ytX, models_iter)
        W_iter = apply_mask_matrix(W, models_iter)
        if method == "post":
            args = [ytX_iter, X_iter, W_iter]
            margs = jitted_ala(b0=b0, args=args)
        elif method == "lik":
            argspr = (W_iter,)
            argsl = (ytX_iter, X_iter)
            margs = jitted_ala(b0=b0, argsl=argsl, argspr=argspr)
        if prior == "mom":
            p_j_iter = apply_mask_1d(p_j, models_iter)
            Winv_iter = apply_mask_matrix(Winv, models_iter)
            XtX_iter = apply_mask_matrix(XtX, models_iter)
            margs += vmap(gmomprior_correction, (None, 0, 0, 0, 0), 0)(
                1, Winv_iter, p_j_iter, XtX_iter, ytX_iter
            )
        modelprobs = modelprobs.at[model_mask].set(margs)
    return models, modelprobs
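A hedged usage sketch (the data below is synthetic and purely illustrative; with models=None, all 2**p variable subsets are enumerated):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (100, 3))                     # hypothetical design matrix
y = (X @ jnp.array([1.0, 0.5, 0.0]) > 0).astype(float)   # hypothetical binary response
models, modelprobs = modelSelection(X, y, prior="normal", family="logistic")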
Code example #11
def _V_Cycle_GMRES(x, f, num_cycle, shapebc='R', k=0, aspect_ratio=1.0):
    # Multigrid V-cycle (https://en.wikipedia.org/wiki/Multigrid_method) with
    # GMRES used as the pre- and post-smoother. Boundary conditions are not
    # included (assumed to be zero).

    h = 1.0 / (x.shape[0] + 1)
    if shapebc == 'R':
        mask_f = equations.make_mask
        mask_f_dual = equations.make_mask_dual
    elif shapebc == 'L':
        mask_f = equations.make_mask_L
        mask_f_dual = equations.make_mask_L_dual
    else:
        raise ValueError("shapebc must be 'R' or 'L'")

    r = f - equations.helmholtz(x,
                                k,
                                step=h,
                                aspect_ratio=aspect_ratio,
                                mask_f=mask_f,
                                mask_f_dual=mask_f_dual)

    # Fine-grid Helmholtz operator as a flat matvec; Partial makes the
    # callable a pytree so it can be passed to the GMRES solver.
    new_matvec = lambda z: equations.helmholtz(
        z.reshape(x.shape), k, step=h, aspect_ratio=aspect_ratio,
        mask_f=mask_f, mask_f_dual=mask_f_dual).ravel()

    A = Partial(new_matvec)
    x = x + gmres.gmres(A, r.ravel(), n=5).reshape(x.shape)

    # Compute Residual Errors

    # no bc here because we assume they are 0
    r = f - equations.helmholtz(x,
                                k,
                                step=h,
                                aspect_ratio=aspect_ratio,
                                mask_f=mask_f,
                                mask_f_dual=mask_f_dual)
    # Restriction from h to 2h
    rhs = restriction(r)
    eps = np.zeros(rhs.shape)
    mask = mask_f(eps.shape[0], aspect_ratio)
    eps = np.multiply(eps, mask)
    # Coarsest level: stop recursing at depth 3 and smooth with GMRES directly.
    if num_cycle == 3:
        r = rhs - equations.helmholtz(eps,
                                      k,
                                      step=2 * h,
                                      aspect_ratio=aspect_ratio,
                                      mask_f=mask_f,
                                      mask_f_dual=mask_f_dual)

        # Coarse-grid (2h) Helmholtz operator as a flat matvec.
        new_matvec1 = lambda z: equations.helmholtz(
            z.reshape(eps.shape), k, step=2 * h, aspect_ratio=aspect_ratio,
            mask_f=mask_f, mask_f_dual=mask_f_dual).ravel()

        A1 = Partial(new_matvec1)
        eps = eps + gmres.gmres(A1, r.ravel(), n=5).reshape(eps.shape)
    else:
        eps = _V_Cycle(eps,
                       rhs,
                       num_cycle + 1,
                       shapebc,
                       k=k,
                       aspect_ratio=aspect_ratio)

    # Prolongation and Correction
    x = x + prolongation(eps)
    mask = mask_f(x.shape[0], aspect_ratio)
    x = np.multiply(x, mask)
    # Post-Smoothing
    r = f - equations.helmholtz(x,
                                k,
                                step=h,
                                aspect_ratio=aspect_ratio,
                                mask_f=mask_f,
                                mask_f_dual=mask_f_dual)
    x = x + gmres.gmres(A, r.ravel(), n=5).reshape(x.shape)
    return x
Code example #12
File: shims.py Project: NeilGirdhar/tjax
 def __get__(self, instance: Any, owner: Any = None) -> Callable[..., R_co]:
     if instance is None:
         return self
     # Create a partial function application corresponding to a bound method.
     return Partial(self, instance)  # type: ignore[no-untyped-call]
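This `__get__` makes a callable wrapper class behave like a bound method: attribute access on an instance returns `Partial(self, instance)`, a bound-method-like object that is also a pytree. A hedged sketch of the pattern in isolation (the `PartialMethod` class and `Point` example below are hypothetical, not tjax's actual shim):

from typing import Any, Callable
from jax.tree_util import Partial

class PartialMethod:
    """Hypothetical descriptor that binds instances via Partial."""

    def __init__(self, fn: Callable):
        self.fn = fn

    def __call__(self, instance, *args):
        return self.fn(instance, *args)

    def __get__(self, instance: Any, owner: Any = None):
        if instance is None:
            return self
        return Partial(self, instance)  # bound-method-like pytree

class Point:
    def __init__(self, x):
        self.x = x

    @PartialMethod
    def shifted(self, dx):
        return self.x + dx

p = Point(1.0)
print(p.shifted(2.0))  # 3.0 == Partial(method, p)(2.0)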
Code example #13
def mat_vec_chunked_factory(forward_fn, params, model_state, samples):
    # Forward function mapping params (and a chunk of samples) to outputs.
    def fun(W, samples):
        return forward_fn({"params": W, **model_state}, samples)

    return Partial(partial(matvec_chunked_transposable, fun), params, samples)