Example #1
0
def _expect(
    machine_pow: int,
    model_apply_fun: Callable,
    local_value_kernel: Callable,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Stats:
    """Estimate the expectation value of a local quantity over samples ``σ``.

    The per-sample kernel is evaluated with ``jax.vmap`` and the resulting
    local values are reduced by ``nkjax.expect`` using the sampling
    log-density ``machine_pow * Re[logpsi]``.

    Args:
        machine_pow: Factor multiplying the real part of the model output to
            form the sampling log-density.
        model_apply_fun: Called as
            ``model_apply_fun({"params": w, **model_state}, σ)``; its output
            is used as ``logpsi``.
        local_value_kernel: Per-sample kernel. It is partially applied to
            ``logpsi`` and vmapped with ``in_axes=(None, 0, 0, 0)``. The first
            remaining argument is broadcast and the next three are mapped over
            axis 0, presumably ``(parameters, σ, σp, mels)`` as forwarded by
            ``nkjax.expect`` (confirm).
        parameters: Model parameters, placed under the ``"params"`` key.
        model_state: Extra variable collections merged alongside ``"params"``.
        σ: Samples. Flattened to 2-D (keeping the last axis) if not already 2-D.
        σp: Per-sample extra input, mapped over axis 0 by the vmapped kernel.
        mels: Per-sample extra input, mapped over axis 0 by the vmapped kernel.

    Returns:
        The second output of ``nkjax.expect`` (the statistics object); the
        first output is discarded.
    """
    σ_shape = σ.shape

    # Collapse leading axes so each row is one configuration. The original
    # shape is kept so the leading dimension can still serve as n_chains.
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))
    # NOTE(review): σp and mels are not reshaped alongside σ. This assumes the
    # caller already supplies them with a leading axis matching the flattened σ.
    # Confirm against callers.

    def logpsi(w, σ):
        return model_apply_fun({"params": w, **model_state}, σ)

    def log_pdf(w, σ):
        # machine_pow * Re(logpsi) == log(|exp(logpsi)| ** machine_pow)
        return machine_pow * logpsi(w, σ).real

    # Broadcast the parameters (None) and map every other argument over axis 0.
    local_value_vmap = jax.vmap(
        partial(local_value_kernel, logpsi),
        in_axes=(None, 0, 0, 0),
        out_axes=0,
    )

    # NOTE(review): n_chains is the leading dim of the ORIGINAL σ, presumably
    # laid out as (n_chains, ..., N). If σ arrives already 2-D, every row is
    # counted as its own chain. Confirm this is intended.
    _, Ō_stats = nkjax.expect(
        log_pdf, local_value_vmap, parameters, σ, σp, mels, n_chains=σ_shape[0]
    )

    return Ō_stats
Example #2
0
def _expect(
    local_value_kernel: Callable,
    model_apply_fun: Callable,
    machine_pow: int,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Stats:
    """Reduce a local-value kernel over samples ``σ`` via ``nkjax.expect``.

    Args:
        local_value_kernel: Kernel partially applied to ``logpsi``. It is
            passed to ``nkjax.expect`` without any vmap.
        model_apply_fun: Called as
            ``model_apply_fun({"params": w, **model_state}, σ)``.
        machine_pow: Factor applied to the real part of the model output to
            form the sampling log-density.
        parameters: Model parameters, placed under the ``"params"`` key.
        model_state: Extra variable collections merged alongside ``"params"``.
        σ: Samples. Flattened to 2-D (last axis kept) when not already 2-D.
        local_value_args: Forwarded unchanged to ``nkjax.expect`` after ``σ``.

    Returns:
        The second output of ``nkjax.expect``; the first is discarded.
    """
    original_shape = σ.shape
    # The leading dim of the un-flattened samples is used as the chain count.
    chain_count = original_shape[0]
    flat_samples = (
        σ if jnp.ndim(σ) == 2 else σ.reshape((-1, original_shape[-1]))
    )

    def logpsi(w, σ):
        variables = {"params": w, **model_state}
        return model_apply_fun(variables, σ)

    def log_pdf(w, σ):
        return machine_pow * logpsi(w, σ).real

    bound_kernel = partial(local_value_kernel, logpsi)
    _, stats = nkjax.expect(
        log_pdf,
        bound_kernel,
        parameters,
        flat_samples,
        local_value_args,
        n_chains=chain_count,
    )
    return stats
Example #3
0
 def expect_closure_pars(pars):
     """Evaluate ``nkjax.expect`` at the given parameters ``pars``.

     All other inputs are taken from the enclosing scope: ``log_pdf``,
     ``logpsi``, ``local_value_kernel``, ``σ``, ``local_value_args`` and
     ``σ_shape``. Presumably written as a closure over ``pars`` so it can be
     differentiated with respect to the parameters. Confirm in the enclosing
     function.

     Returns:
         The output of ``nkjax.expect``, unmodified.
     """
     kernel = partial(local_value_kernel, logpsi)
     chains = σ_shape[0]
     result = nkjax.expect(
         log_pdf,
         kernel,
         pars,
         σ,
         local_value_args,
         n_chains=chains,
     )
     return result
Example #4
0
    def expect_closure(*args):
        """Forward ``args`` to ``nkjax.expect`` using a vmapped local kernel.

        ``local_kernel`` is partially applied to ``logpsi`` and vmapped with
        ``in_axes=(None, 0, 0, 0)``. Its first remaining argument is broadcast
        and the next three are mapped over axis 0. ``log_pdf``, ``logpsi``,
        ``local_kernel`` and ``σ_shape`` come from the enclosing scope.

        Returns:
            The output of ``nkjax.expect``, unmodified.
        """
        kernel_with_logpsi = partial(local_kernel, logpsi)
        batched_kernel = jax.vmap(
            kernel_with_logpsi,
            in_axes=(None, 0, 0, 0),
            out_axes=0,
        )
        chain_count = σ_shape[0]
        return nkjax.expect(log_pdf, batched_kernel, *args, n_chains=chain_count)