Example #1
0
    def init_state(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        seed: Optional[SeedType] = None,
    ) -> SamplerState:
        """
        Builds the data structure that stores the state of the sampler.

        Pass `seed` to obtain reproducible samples; when it is omitted the state is
        initialised randomly.

        When running on several MPI processes, every rank ends up in a different (but
        deterministic) state: the seeds given on all ranks are reduced (summed), n_rank
        seeds are derived from the reduced one, and each rank is initialised with one
        of them.

        The returned state is guaranteed to be a frozen python dataclass (specifically a
        flax dataclass), so it can be serialized with the Flax serialization methods.

        Args:
            machine: a Flax module or callable with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            seed: An optional seed or jax PRNGKey. A random seed is used when it is
                not specified.

        Returns:
            The structure holding the state of the sampler. Do not assume it is in a
            valid state: reset it before use.
        """
        # Build the base key first, then derive the key through `nkjax.mpi_split`.
        base_key = nkjax.PRNGKey(seed)
        apply_fun = get_afun_if_module(machine)
        rank_key = nkjax.mpi_split(base_key)

        return sampler._init_state(apply_fun, parameters, rank_key)
Example #2
0
def to_array(hilbert, apply_fun, variables, normalize=True, allgather=True):
    """
    Evaluates `apply_fun(variables, states)` on every state of `hilbert` and
    returns the results as a vector.

    The state indices are padded to a multiple of the number of MPI nodes and
    split into equally sized chunks; this rank converts its own chunk to states
    and hands it to `_to_array_rank`.

    Args:
        hilbert: The hilbert space; it must be indexable.
        apply_fun: The function to evaluate (passed through `get_afun_if_module`).
        variables: The variables forwarded to `apply_fun`.
        normalize: If True, the vector is normalized to have L2-norm 1.
        allgather: If True, the final wave function is stored in full at all MPI ranks.

    Raises:
        RuntimeError: if `hilbert` is not indexable.
    """
    if not hilbert.is_indexable:
        raise RuntimeError("The hilbert space is not indexable")

    apply_fun = get_afun_if_module(apply_fun)

    # mpi4jax does not (yet) provide allgatherv, so every rank must handle a
    # chunk of identical length: pad the list of indices up to a multiple of
    # the number of nodes. This could be simplified by updating mpi4jax.
    n_states = hilbert.n_states
    chunks_total = int(np.ceil(n_states / mpi.n_nodes)) * mpi.n_nodes
    n_padding = chunks_total - n_states

    # The padding entries are simply the first `n_padding` indices again.
    padded_indices = np.concatenate([np.arange(n_states), np.arange(n_padding)])

    # One equally sized chunk of indices per node; keep the one of this rank.
    local_indices = np.split(padded_indices, mpi.n_nodes)[mpi.rank]
    xs = hilbert.numbers_to_states(local_indices)

    return _to_array_rank(apply_fun, variables, xs, n_states, normalize, allgather)
Example #3
0
    def log_pdf(self, model: Union[Callable, nn.Module]) -> Callable:
        """
        Builds the log_pdf function encoded by this sampler.

        The returned callable takes `(pars, σ)` and evaluates
        `self.machine_pow * apply_fun(pars, σ).real`, where `apply_fun` is
        obtained from `model` through `get_afun_if_module`.

        Note: the result is wrapped in an HashablePartial so that the closure
        does not trigger recompilation.

        Args:
            model: The machine, or apply_fun

        Returns:
            the log probability density function
        """

        def _scaled_real_part(apply_fun, pars, σ):
            # `self.machine_pow` is looked up at call time, not captured by value.
            return self.machine_pow * apply_fun(pars, σ).real

        return HashablePartial(_scaled_real_part, get_afun_if_module(model))
Example #4
0
def to_array(hilbert, machine, params, normalize=True):
    """
    Evaluates `machine` on all states of `hilbert` and returns the exponentiated
    result as an array.

    The output of `machine(params, xs)` is shifted by the maximum of its real
    part before being exponentiated, so that the largest entry has modulus 1
    prior to normalization.

    Args:
        hilbert: The hilbert space; it must be indexable.
        machine: A module or callable; it is passed through
            `get_afun_if_module` and then called as `machine(params, xs)`.
        params: The parameters forwarded to `machine`.
        normalize: If True, the result is divided by `jnp.linalg.norm` of itself.

    Returns:
        The array `exp(machine(params, xs) - logmax)`, optionally normalized.

    Raises:
        RuntimeError: if `hilbert` is not indexable.
    """
    # Validate first so that invalid input fails before any import or evaluation.
    if not hilbert.is_indexable:
        raise RuntimeError("The hilbert space is not indexable")

    from jax import numpy as jnp
    from netket.utils import get_afun_if_module

    machine = get_afun_if_module(machine)

    xs = hilbert.all_states()
    psi = machine(params, xs)

    # Subtract the largest real part before exponentiating to avoid overflow.
    logmax = psi.real.max()
    psi = jnp.exp(psi - logmax)

    if normalize:
        psi = psi / jnp.linalg.norm(psi)

    return psi
Example #5
0
    def log_pdf(self, model: Union[Callable, nn.Module]) -> Callable:
        """
        Builds a closure evaluating the log-pdf encoded by this sampler.

        Args:
            model: A Flax module or callable with the forward pass of the log-pdf.
                If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`.

        Returns:
            The log-probability density function, i.e. a callable of `(pars, σ)`
            returning `self.machine_pow * apply_fun(pars, σ).real`.

        Note:
            The result is wrapped in a `HashablePartial` so that the closure
            does not trigger recompilation.
        """
        forward = get_afun_if_module(model)

        def _log_prob(apply_fun, pars, σ):
            # Real part of the model output, scaled by the sampler's machine_pow.
            log_amplitude = apply_fun(pars, σ)
            return self.machine_pow * log_amplitude.real

        return HashablePartial(_log_prob, forward)
Example #6
0
def to_array(hilbert, apply_fun, variables, normalize=True):
    """
    Evaluates `apply_fun` on the states of `hilbert`, splitting the work in
    equally sized chunks across the MPI nodes.

    The state indices are padded to a multiple of `mpi.n_nodes`; this rank
    converts its own chunk of indices to states and delegates the evaluation to
    `_to_array_rank`, which also receives the true number of states.

    Args:
        hilbert: The hilbert space; it must be indexable.
        apply_fun: The function to evaluate (passed through `get_afun_if_module`).
        variables: The variables forwarded to `_to_array_rank`.
        normalize: Forwarded to `_to_array_rank`.

    Returns:
        Whatever `_to_array_rank` returns for this rank's chunk of states.

    Raises:
        RuntimeError: if `hilbert` is not indexable.
    """
    if not hilbert.is_indexable:
        raise RuntimeError("The hilbert space is not indexable")

    apply_fun = get_afun_if_module(apply_fun)

    # mpi4jax does not have (yet) allgatherv so we need to be creative
    # could be made easier if we update mpi4jax
    n_states = hilbert.n_states
    # Total number of indices after padding to a multiple of the node count.
    n_states_padded = int(np.ceil(n_states / mpi.n_nodes)) * mpi.n_nodes
    n_fake_states = n_states_padded - n_states

    states_n = np.arange(n_states)
    # The padding re-uses the first `n_fake_states` indices.
    fake_states_n = np.arange(n_fake_states)

    # divide the hilbert space in chunks for each node
    states_per_rank = np.split(
        np.concatenate([states_n, fake_states_n]), mpi.n_nodes
    )

    xs = hilbert.numbers_to_states(states_per_rank[mpi.rank])
    return _to_array_rank(apply_fun, variables, xs, n_states, normalize)
Example #7
0
    def reset(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> SamplerState:
        """
        Resets the sampler state. Call this every time the parameters are changed.

        Args:
            machine: a Flax module or callable with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. When omitted, a new one is built
                with :code:`sampler_state(sampler, machine, parameters)`, i.e. without
                an explicit seed.

        Returns:
            A valid sampler state.
        """
        # Only build a fresh state when the caller did not supply one.
        current = (
            state
            if state is not None
            else sampler_state(sampler, machine, parameters)
        )
        apply_fun = get_afun_if_module(machine)

        return sampler._reset(apply_fun, parameters, current)
Example #8
0
def to_matrix(hilbert, machine, params, normalize=True):
    """
    Evaluates `machine` on all states of `hilbert` and returns the exponentiated
    result reshaped as a square matrix.

    The output of `machine(params, xs)` is shifted by the maximum of its real
    part before being exponentiated, then reshaped to `(L, L)` with
    `L = hilbert.physical.n_states`.

    Args:
        hilbert: The hilbert space; it must be indexable and expose
            `physical.n_states`.
        machine: A module or callable; it is passed through
            `get_afun_if_module` and then called as `machine(params, xs)`.
        params: The parameters forwarded to `machine`.
        normalize: If True, the matrix is divided by its trace.

    Returns:
        The `(L, L)` matrix, optionally normalized to unit trace.

    Raises:
        RuntimeError: if `hilbert` is not indexable.
    """
    # Validate first so that invalid input fails before any import or evaluation.
    if not hilbert.is_indexable:
        raise RuntimeError("The hilbert space is not indexable")

    from jax import numpy as jnp
    from netket.utils import get_afun_if_module

    machine = get_afun_if_module(machine)

    xs = hilbert.all_states()
    psi = machine(params, xs)

    # Subtract the largest real part before exponentiating to avoid overflow.
    logmax = psi.real.max()
    psi = jnp.exp(psi - logmax)

    L = hilbert.physical.n_states
    rho = psi.reshape((L, L))

    if normalize:
        rho = rho / jnp.trace(rho)

    return rho
Example #9
0
    def sample_next(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> Tuple[jnp.ndarray, SamplerState]:
        """
        Draws the next state in the markov chain.

        Args:
            machine: a Flax module or callable apply function with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. When omitted, a new one is built
                with :code:`sampler_state(sampler, machine, parameters)`, i.e. without
                an explicit seed.

        Returns:
            state: The new state of the sampler
            σ: The next batch of samples.
        """
        # NOTE(review): the return annotation lists the samples array before the
        # state, while the docstring lists the state first; the actual order is
        # decided by `sampler._sample_next` — confirm and align the two.
        current = (
            state
            if state is not None
            else sampler_state(sampler, machine, parameters)
        )
        apply_fun = get_afun_if_module(machine)

        return sampler._sample_next(apply_fun, parameters, current)