Example #1
0
def grad_expect_operator_kernel(
    machine_pow: int,
    model_apply_fun: Callable,
    local_kernel: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[Stats, PyTree, PyTree]:
    """Estimate an operator's expectation value and its gradient w.r.t. ``parameters``.

    This is an experimental code path: it refuses to run unless
    ``config.FLAGS["NETKET_EXPERIMENTAL"]`` is enabled.

    Args:
        machine_pow: Factor multiplying ``Re(log ψ)`` to build the log-pdf
            handed to ``nkjax.expect``.
        model_apply_fun: Evaluates ``log ψ`` when called as
            ``model_apply_fun({"params": w, **model_state}, σ[, mutable=...])``.
        local_kernel: Per-sample kernel, called as
            ``local_kernel(logpsi, w, σ_i, σp_i, mels_i)``; it is vmapped with
            ``in_axes=(None, 0, 0, 0)``.
        mutable: ``False`` for a model without mutable state; any other value is
            forwarded as ``mutable=`` to ``model_apply_fun``. The updated state
            it produces is currently discarded (see TODO below).
        parameters: Parameters the gradient is taken with respect to.
        model_state: Remaining model variables, merged with ``parameters``
            under the ``"params"`` key on every model call.
        σ: Samples. Any input with ``ndim != 2`` is flattened to
            ``(-1, σ.shape[-1])``.
        σp: Connected configurations, vmapped over axis 0 alongside σ.
            NOTE(review): not reshaped here — assumes its leading axis already
            matches the flattened σ; confirm with callers.
        mels: Matrix elements, vmapped over axis 0 alongside σ (same caveat
            as ``σp``).

    Returns:
        A tuple ``(stats, grad, model_state)``:
        ``stats`` is the auxiliary output of ``nkjax.vjp`` (presumably the
        statistics of the estimator); ``grad`` is the VJP with a cotangent of
        ones, reduced with ``sum_inplace`` and divided by ``utils.n_nodes``
        (presumably an average over MPI processes); ``model_state`` is
        returned unchanged.

    Raises:
        RuntimeError: If the ``NETKET_EXPERIMENTAL`` flag is not enabled.
    """
    if not config.FLAGS["NETKET_EXPERIMENTAL"]:
        raise RuntimeError(
            "Computing the gradient of a squared or non-hermitian operator is an "
            "experimental feature under development and is known to return wrong "
            "values sometimes.\n\n"
            "If you want to debug it, set the environment variable "
            "NETKET_EXPERIMENTAL=1"
        )

    # Keep the original shape: its leading axis is used as the number of chains.
    # NOTE(review): for 3-D input this is presumably (n_chains, n_per_chain, N);
    # for 2-D input every sample is counted as its own chain — confirm intent.
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    has_aux = mutable is not False

    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    # Reuse `logpsi` so the sampling density applies the model exactly like
    # the local kernel does, including the `mutable` handling.
    log_pdf = lambda w, σ: machine_pow * logpsi(w, σ).real

    def expect_closure(*args):
        # Map the per-sample kernel over the sample axis; parameters are shared.
        local_kernel_vmap = jax.vmap(
            partial(local_kernel, logpsi), in_axes=(None, 0, 0, 0), out_axes=0
        )

        return nkjax.expect(log_pdf, local_kernel_vmap, *args, n_chains=σ_shape[0])

    def expect_closure_pars(pars):
        # Close over the samples so the VJP is taken w.r.t. parameters only.
        return expect_closure(pars, σ, σp, mels)

    Ō, Ō_pb, Ō_stats = nkjax.vjp(expect_closure_pars, parameters, has_aux=True)
    Ō_pars_grad = Ō_pb(jnp.ones_like(Ō))

    return (
        Ō_stats,
        tree_map(lambda x: sum_inplace(x) / utils.n_nodes, Ō_pars_grad),
        model_state,
    )
Example #2
0
 def n_accepted(self) -> int:
     """Number of accepted moves, summed over all processes, since the last reset.

     The per-process counter ``self.n_accepted_proc`` is reduced across
     processes with ``sum_inplace``.
     """
     per_process_count = self.n_accepted_proc
     return sum_inplace(per_process_count)