def _expect(
    machine_pow: int,
    model_apply_fun: Callable,
    local_value_kernel: Callable,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Stats:
    """Evaluate ``nkjax.expect`` of a vmapped local-value kernel over ``σ``.

    Args:
        machine_pow: Factor applied to the real part of the model output to
            form the log-density ``machine_pow * Re(model(...))``.
        model_apply_fun: Model function, called as
            ``model_apply_fun({"params": w, **model_state}, σ)``.
        local_value_kernel: Kernel that receives the log-amplitude function
            as its first argument. After ``jax.vmap`` it takes four more
            arguments: the first is broadcast, the other three are mapped
            over axis 0. These are presumably ``(parameters, σ, σp, mels)``
            as forwarded by ``nkjax.expect`` (TODO confirm).
        parameters: Passed to the model as the ``"params"`` collection.
        model_state: Extra variable collections merged next to ``"params"``.
        σ: Samples. Any array that is not 2-D is reshaped to
            ``(-1, σ.shape[-1])``.
        σp: Mapped over its leading axis together with ``σ``.
            NOTE(review): not reshaped here, so it presumably already matches
            the flattened ``σ``. Confirm with the caller.
        mels: Mapped over its leading axis like ``σp``, with the same caveat.

    Returns:
        The second element of the pair returned by ``nkjax.expect``. The
        first element is discarded.
    """
    original_shape = σ.shape
    if jnp.ndim(σ) != 2:
        # Collapse all leading axes into a single sample axis.
        σ = σ.reshape((-1, original_shape[-1]))

    def log_psi(w, x):
        return model_apply_fun({"params": w, **model_state}, x)

    def log_pdf(w, x):
        return machine_pow * model_apply_fun({"params": w, **model_state}, x).real

    # The parameters are shared (None). The remaining three arguments are
    # mapped over their leading axis.
    batched_kernel = jax.vmap(
        partial(local_value_kernel, log_psi),
        in_axes=(None, 0, 0, 0),
        out_axes=0,
    )

    # n_chains is taken from the ORIGINAL leading axis. When σ is already
    # 2-D this is the number of rows, so each row counts as one chain.
    _, stats = nkjax.expect(
        log_pdf,
        batched_kernel,
        parameters,
        σ,
        σp,
        mels,
        n_chains=original_shape[0],
    )
    return stats
def _expect(
    local_value_kernel: Callable,
    model_apply_fun: Callable,
    machine_pow: int,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    local_value_args: PyTree,
) -> Stats:
    """Evaluate ``nkjax.expect`` of a local-value kernel over the samples ``σ``.

    Args:
        local_value_kernel: Kernel that receives the log-amplitude function
            as its first argument. The remaining arguments are supplied by
            ``nkjax.expect``; presumably these are
            ``(parameters, σ, local_value_args)`` (TODO confirm). No ``vmap``
            is applied here.
        model_apply_fun: Model function, called as
            ``model_apply_fun({"params": w, **model_state}, σ)``.
        machine_pow: Factor applied to the real part of the model output to
            form the log-density.
        parameters: Passed to the model as the ``"params"`` collection.
        model_state: Extra variable collections merged next to ``"params"``.
        σ: Samples. Any array that is not 2-D is reshaped to
            ``(-1, σ.shape[-1])``.
        local_value_args: Forwarded unchanged to ``nkjax.expect``.

    Returns:
        The second element of the pair returned by ``nkjax.expect``. The
        first element is discarded.
    """
    sample_shape = σ.shape
    if jnp.ndim(σ) != 2:
        # Collapse all leading axes into a single sample axis.
        σ = σ.reshape((-1, sample_shape[-1]))

    def log_psi(w, x):
        return model_apply_fun({"params": w, **model_state}, x)

    def log_pdf(w, x):
        # Same model call as log_psi, scaled on its real part.
        return machine_pow * log_psi(w, x).real

    kernel = partial(local_value_kernel, log_psi)

    # n_chains is taken from the ORIGINAL leading axis. When σ is already
    # 2-D this is the number of rows.
    _, stats = nkjax.expect(
        log_pdf,
        kernel,
        parameters,
        σ,
        local_value_args,
        n_chains=sample_shape[0],
    )
    return stats
def expect_closure_pars(pars):
    """Call ``nkjax.expect`` with everything fixed except the parameters.

    ``log_pdf``, ``logpsi``, ``local_value_kernel``, ``σ``,
    ``local_value_args`` and ``σ_shape`` are free variables. They are read
    from an enclosing scope that is not visible here.
    NOTE(review): this is presumably a closure meant to be differentiated
    with respect to ``pars`` (e.g. via ``jax.vjp``). Confirm with the caller.

    Args:
        pars: Parameters forwarded to ``nkjax.expect``.

    Returns:
        Whatever ``nkjax.expect`` returns, unchanged.
    """
    kernel = partial(local_value_kernel, logpsi)
    chain_count = σ_shape[0]
    return nkjax.expect(
        log_pdf,
        kernel,
        pars,
        σ,
        local_value_args,
        n_chains=chain_count,
    )
def expect_closure(*args):
    """Call ``nkjax.expect`` with a vmapped local kernel and the given arguments.

    ``local_kernel``, ``logpsi``, ``log_pdf`` and ``σ_shape`` are free
    variables. They are read from an enclosing scope that is not visible here.

    Args:
        *args: Forwarded positionally to ``nkjax.expect`` after ``log_pdf``
            and the vmapped kernel. The kernel is vmapped with
            ``in_axes=(None, 0, 0, 0)``, so it presumably receives four
            arguments with the first one broadcast (TODO confirm).

    Returns:
        Whatever ``nkjax.expect`` returns, unchanged.
    """
    # The first argument is shared (None). The other three are mapped over
    # their leading axis.
    batched_kernel = jax.vmap(
        partial(local_kernel, logpsi),
        in_axes=(None, 0, 0, 0),
        out_axes=0,
    )
    chain_count = σ_shape[0]
    return nkjax.expect(log_pdf, batched_kernel, *args, n_chains=chain_count)