def init_state(
    sampler,
    machine: Union[Callable, nn.Module],
    parameters: PyTree,
    seed: Optional[SeedType] = None,
) -> SamplerState:
    """
    Builds the structure that holds the state of the sampler.

    Pass `seed` to obtain reproducible samples; when it is left out the
    state is initialised randomly.

    When running on several MPI processes, every rank is meant to end up
    in a different (but deterministic) state: the seeds given on the
    individual ranks are reduced (summed), `n_rank` seeds are derived from
    the reduced one, and each rank is initialised with one of them.

    The returned state is a frozen python dataclass (a flax dataclass), so
    it can be serialized with the Flax serialization methods.

    Args:
        machine: a Flax module or callable with the forward pass of the
            log-pdf.
        parameters: The PyTree of parameters of the model.
        seed: An optional seed or jax PRNGKey. A random seed is used when
            it is not given.

    Returns:
        The structure holding the state of the sampler. Do not expect it
        to be in a valid state: reset it before use.
    """
    # Build the base key first, then resolve the apply function, and only
    # then derive the per-rank key.
    base_key = nkjax.PRNGKey(seed)
    apply_fun = get_afun_if_module(machine)
    rank_key = nkjax.mpi_split(base_key)

    return sampler._init_state(apply_fun, parameters, rank_key)
def to_array(hilbert, apply_fun, variables, normalize=True, allgather=True):
    """
    Evaluates `apply_fun(variables, states)` on every state of `hilbert`
    and returns the results as a vector.

    The basis states are divided into equally sized chunks, one per MPI
    rank; this rank only converts its own chunk to states before handing
    over to `_to_array_rank`.

    Args:
        normalize: If True, the vector is normalized to have L2-norm 1.
        allgather: If True, the final wave function is stored in full at
            all MPI ranks.

    Raises:
        RuntimeError: if the hilbert space is not indexable.
    """
    if not hilbert.is_indexable:
        raise RuntimeError("The hilbert space is not indexable")

    apply_fun = get_afun_if_module(apply_fun)

    # mpi4jax does not (yet) provide allgatherv, so the index list is padded
    # up to a multiple of the number of ranks to make all chunks equal-sized.
    n_states = hilbert.n_states
    chunk_len = int(np.ceil(n_states / mpi.n_nodes))
    n_padding = chunk_len * mpi.n_nodes - n_states

    # The padding entries simply re-use the first `n_padding` state indices.
    padded_indices = np.concatenate([np.arange(n_states), np.arange(n_padding)])

    # One chunk of indices per rank; keep only the one owned by this rank.
    local_indices = np.split(padded_indices, mpi.n_nodes)[mpi.rank]
    xs = hilbert.numbers_to_states(local_indices)

    return _to_array_rank(apply_fun, variables, xs, n_states, normalize, allgather)
def log_pdf(self, model: Union[Callable, nn.Module]) -> Callable:
    """
    Builds the log_pdf function encoded by this sampler.

    The returned callable takes `(pars, σ)` and evaluates
    `self.machine_pow * apply_fun(pars, σ).real`.

    Note:
        The result is wrapped in a HashablePartial so that the closure
        does not trigger recompilation.

    Args:
        model: The machine, or apply_fun

    Returns:
        the log probability density function
    """

    # NOTE(review): a new inner function object is created on every call;
    # whether partials from separate calls compare equal depends on how
    # HashablePartial hashes the wrapped function — confirm.
    def _scaled_real_part(apply_fun, pars, σ):
        # machine_pow is looked up when the function is evaluated.
        return self.machine_pow * apply_fun(pars, σ).real

    return HashablePartial(_scaled_real_part, get_afun_if_module(model))
def to_array(hilbert, machine, params, normalize=True):
    """
    Evaluates `machine` on all states of `hilbert` and returns the
    exponentiated result as a vector.

    The machine output is shifted by the maximum of its real part before
    exponentiating.

    Args:
        hilbert: The hilbert space; it must be indexable.
        machine: A module or apply function; it is passed through
            `get_afun_if_module` and then called as `machine(params, xs)`.
        params: The parameters forwarded to the machine.
        normalize: If True, the vector is divided by its L2-norm.

    Returns:
        The (optionally normalized) vector of exponentiated outputs.

    Raises:
        RuntimeError: if the hilbert space is not indexable.
    """
    from jax import numpy as jnp

    from netket.utils import get_afun_if_module

    machine = get_afun_if_module(machine)

    if not hilbert.is_indexable:
        raise RuntimeError("The hilbert space is not indexable")

    xs = hilbert.all_states()
    psi = machine(params, xs)

    # Subtract the largest real part so the exponential cannot overflow.
    logmax = psi.real.max()
    psi = jnp.exp(psi - logmax)

    if normalize:
        psi = psi / jnp.linalg.norm(psi)

    return psi
def log_pdf(self, model: Union[Callable, nn.Module]) -> Callable:
    """
    Builds the log-pdf function encoded by this sampler.

    The returned callable takes `(pars, σ)` and evaluates
    `self.machine_pow * apply_fun(pars, σ).real`.

    Args:
        model: A Flax module or callable with the forward pass of the
            log-pdf. If it is a callable, it should have the signature
            :code:`f(parameters, σ) -> jnp.ndarray`.

    Returns:
        The log-probability density function.

    Note:
        The result is returned as a `HashablePartial` so that the closure
        does not trigger recompilation.
    """
    forward = get_afun_if_module(model)

    # NOTE(review): a new inner function object is created on every call;
    # whether partials from separate calls compare equal depends on how
    # HashablePartial hashes the wrapped function — confirm.
    def _log_prob(apply_fun, pars, σ):
        # machine_pow is looked up when the function is evaluated.
        return self.machine_pow * apply_fun(pars, σ).real

    return HashablePartial(_log_prob, forward)
def to_array(hilbert, apply_fun, variables, normalize=True):
    """
    Evaluates `apply_fun(variables, states)` on all states of `hilbert`
    and returns the results as a vector.

    The basis states are divided into equally sized chunks, one per MPI
    rank; this rank only converts its own chunk to states before handing
    over to `_to_array_rank`.

    Args:
        hilbert: The hilbert space; it must be indexable.
        apply_fun: A module or apply function; it is passed through
            `get_afun_if_module`.
        variables: The variables forwarded to `_to_array_rank`.
        normalize: Forwarded to `_to_array_rank`.

    Returns:
        Whatever `_to_array_rank` returns for this rank's chunk.

    Raises:
        RuntimeError: if the hilbert space is not indexable.
    """
    if not hilbert.is_indexable:
        raise RuntimeError("The hilbert space is not indexable")

    apply_fun = get_afun_if_module(apply_fun)

    # mpi4jax does not have (yet) allgatherv so we need to be creative
    # could be made easier if we update mpi4jax
    n_states = hilbert.n_states

    # Round the number of states up to a multiple of the number of ranks so
    # that np.split can produce equally sized chunks.
    n_states_padded = int(np.ceil(n_states / mpi.n_nodes)) * mpi.n_nodes
    n_fake_states = n_states_padded - n_states

    states_n = np.arange(n_states)
    # The padding entries re-use the first `n_fake_states` state indices.
    fake_states_n = np.arange(n_fake_states)

    # divide the hilbert space in chunks for each node
    states_per_rank = np.split(np.concatenate([states_n, fake_states_n]), mpi.n_nodes)

    xs = hilbert.numbers_to_states(states_per_rank[mpi.rank])

    return _to_array_rank(apply_fun, variables, xs, n_states, normalize)
def reset(
    sampler,
    machine: Union[Callable, nn.Module],
    parameters: PyTree,
    state: Optional[SamplerState] = None,
) -> SamplerState:
    """
    Resets the state of the sampler. Call it whenever the parameters have
    changed.

    Args:
        machine: a Flax module or callable with the forward pass of the
            log-pdf.
        parameters: The PyTree of parameters of the model.
        state: The current state of the sampler. When omitted, a fresh one
            is built with :code:`sampler_state(sampler, machine, parameters)`.

    Returns:
        A valid sampler state.
    """
    # Only build a new state when the caller did not supply one.
    current = sampler_state(sampler, machine, parameters) if state is None else state

    return sampler._reset(get_afun_if_module(machine), parameters, current)
def to_matrix(hilbert, machine, params, normalize=True):
    """
    Evaluates `machine` on all states of `hilbert` and returns the
    exponentiated result reshaped into a square matrix.

    The machine output is shifted by the maximum of its real part before
    exponentiating, then reshaped to `(L, L)` with
    `L = hilbert.physical.n_states`.

    Args:
        hilbert: The hilbert space; it must be indexable and expose a
            `physical` attribute.
        machine: A module or apply function; it is passed through
            `get_afun_if_module` and then called as `machine(params, xs)`.
        params: The parameters forwarded to the machine.
        normalize: If True, the matrix is divided by its trace.

    Returns:
        The (optionally trace-normalized) matrix.

    Raises:
        RuntimeError: if the hilbert space is not indexable.
    """
    from jax import numpy as jnp

    from netket.utils import get_afun_if_module

    machine = get_afun_if_module(machine)

    if not hilbert.is_indexable:
        raise RuntimeError("The hilbert space is not indexable")

    xs = hilbert.all_states()
    psi = machine(params, xs)

    # Subtract the largest real part so the exponential cannot overflow.
    logmax = psi.real.max()
    psi = jnp.exp(psi - logmax)

    n_physical = hilbert.physical.n_states
    rho = psi.reshape((n_physical, n_physical))

    if normalize:
        rho = rho / jnp.trace(rho)

    return rho
def sample_next(
    sampler,
    machine: Union[Callable, nn.Module],
    parameters: PyTree,
    state: Optional[SamplerState] = None,
) -> Tuple[jnp.ndarray, SamplerState]:
    """
    Draws the next state of the markov chain.

    Args:
        machine: a Flax module or callable apply function with the forward
            pass of the log-pdf.
        parameters: The PyTree of parameters of the model.
        state: The current state of the sampler. When omitted, a fresh one
            is built with :code:`sampler_state(sampler, machine, parameters)`.

    Returns:
        state: The new state of the sampler
        σ: The next batch of samples.
    """
    # NOTE(review): the return annotation lists the samples first, while the
    # Returns section lists the state first; the actual order is whatever
    # `sampler._sample_next` yields — confirm and align the two.
    current = sampler_state(sampler, machine, parameters) if state is None else state

    return sampler._sample_next(get_afun_if_module(machine), parameters, current)