Example #1
from typing import Any

from netket.utils import struct


@struct.dataclass
class Point1:
    x: float
    y: float
    meta: Any = struct.field(pytree_node=False)

    @struct.property_cached
    def cached_node(self) -> int:
        return 3
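
A minimal usage sketch (not from the source; it assumes the `@struct.dataclass` decorator applied above): `meta` is a static field that pytree transformations leave untouched, while `x` and `y` are traced leaves.

import jax

p = Point1(x=1.0, y=2.0, meta={"name": "origin"})
p2 = jax.tree_map(lambda leaf: 2 * leaf, p)  # doubles x and y, leaves meta alone
print(p2.x, p2.y, p2.meta)  # 2.0 4.0 {'name': 'origin'}
print(p.cached_node)  # 3, computed once and cached by property_cached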
Example #2
from typing import Any

from netket.utils import struct


@struct.dataclass
class Point0:
    x: float
    y: float
    meta: Any = struct.field(pytree_node=False)

    def __pre_init__(self, *args, **kwargs):
        if "z" in kwargs:
            kwargs["x"] = kwargs.pop("z") * 10

        return args, kwargs

    @struct.property_cached
    def cached_node(self) -> int:
        return 3
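
A short sketch of the `__pre_init__` hook in action (hypothetical values): the hook rewrites keyword arguments before the generated constructor runs.

p = Point0(z=1.5, y=2.0, meta=None)
print(p.x)  # 15.0 -- `z` was popped and remapped to x = z * 10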
Example #3
class MetropolisSampler(Sampler):
    r"""
    Metropolis-Hastings sampler for a Hilbert space according to a specific transition rule.

    The transition rule is used to generate a proposed state :math:`s^\prime`, starting from the
    current state :math:`s`. The move is accepted with probability

    .. math::

        A(s \rightarrow s^\prime) = \mathrm{min} \left( 1,\frac{P(s^\prime)}{P(s)} e^{L(s,s^\prime)} \right) ,

    where the probability being sampled from is :math:`P(s)=|M(s)|^p`. Here :math:`M(s)` is a
    user-provided function (the machine), :math:`p` is also user-provided with default value :math:`p=2`,
    and :math:`L(s,s^\prime)` is a suitable correcting factor computed by the transition kernel.

    The dtype of the sampled states can be chosen.
    """

    rule: MetropolisRule = None
    """The Metropolis transition rule."""
    n_sweeps: int = struct.field(pytree_node=False, default=None)
    """Number of sweeps for each step along the chain. Defaults to the number of sites in the Hilbert space."""
    reset_chains: bool = struct.field(pytree_node=False, default=False)
    """If True, resets the chain state when `reset` is called on every new sampling."""
    def __pre_init__(self, hilbert, rule, **kwargs):
        r"""
        Constructs a Metropolis Sampler.

        Args:
            hilbert: The Hilbert space to sample.
            rule: A `MetropolisRule` to generate random transitions from a given state as
                    well as uniform random states.
            n_sweeps: The number of exchanges that compose a single sweep.
                    If None, it equals the number of degrees of freedom being sampled
                    (the size of the input vector s to the machine).
            reset_chains: If False, the state configuration is not reset when reset() is called.
            n_chains: The number of Markov chains to be run in parallel on a single process.
            machine_pow: The power to which the machine should be exponentiated to generate the pdf (default = 2).
            dtype: The dtype of the states sampled (default = np.float64).
        """
        # process arguments in the base
        args, kwargs = super().__pre_init__(hilbert=hilbert, **kwargs)

        kwargs["rule"] = rule

        # deprecation warnings
        if "reset_chain" in kwargs:
            warn_deprecation(
                "The keyword argument `reset_chain` is deprecated in favour of `reset_chains`"
            )
            kwargs["reset_chains"] = kwargs.pop("reset_chain")

        return args, kwargs

    def __post_init__(self):
        super().__post_init__()
        # Validate the inputs
        if not isinstance(self.rule, MetropolisRule):
            raise TypeError("rule must be a MetropolisRule.")

        if not isinstance(self.reset_chains, bool):
            raise TypeError("reset_chains must be a boolean.")

        #  Default value of n_sweeps
        if self.n_sweeps is None:
            object.__setattr__(self, "n_sweeps", self.hilbert.size)

    def _init_state(sampler, machine, params, key):
        key_state, key_rule = jax.random.split(key, 2)
        rule_state = sampler.rule.init_state(sampler, machine, params,
                                             key_rule)
        σ = jnp.zeros((sampler.n_chains_per_rank, sampler.hilbert.size),
                      dtype=sampler.dtype)

        state = MetropolisSamplerState(σ=σ,
                                       rng=key_state,
                                       rule_state=rule_state)

        # If we don't reset the chain at every sampling iteration, draw a
        # random initial state now.
        if not sampler.reset_chains:
            key_state, rng = jax.random.split(key_state)
            σ = sampler.rule.random_state(sampler, machine, params, state, rng)
            state = state.replace(σ=σ, rng=key_state)

        return state

    def _reset(sampler, machine, parameters, state):
        new_rng, rng = jax.random.split(state.rng)

        if sampler.reset_chains:
            σ = sampler.rule.random_state(sampler, machine, parameters, state,
                                          rng)
        else:
            σ = state.σ

        rule_state = sampler.rule.reset(sampler, machine, parameters, state)

        return state.replace(σ=σ,
                             rng=new_rng,
                             rule_state=rule_state,
                             n_steps_proc=0,
                             n_accepted_proc=0)

    def _sample_next(sampler, machine, parameters, state):
        new_rng, rng = jax.random.split(state.rng)

        with loops.Scope() as s:
            s.key = rng
            s.σ = state.σ
            s.log_prob = sampler.machine_pow * machine.apply(
                parameters, state.σ).real

            # for logging
            s.accepted = state.n_accepted_proc

            for i in s.range(sampler.n_sweeps):
                # 1 to propagate for next iteration, 1 for uniform rng and n_chains for transition kernel
                s.key, key1, key2 = jax.random.split(s.key, 3)

                σp, log_prob_correction = sampler.rule.transition(
                    sampler, machine, parameters, state, key1, s.σ)
                proposal_log_prob = (sampler.machine_pow *
                                     machine.apply(parameters, σp).real)

                uniform = jax.random.uniform(
                    key2, shape=(sampler.n_chains_per_rank, ))
                if log_prob_correction is not None:
                    do_accept = uniform < jnp.exp(proposal_log_prob -
                                                  s.log_prob +
                                                  log_prob_correction)
                else:
                    do_accept = uniform < jnp.exp(proposal_log_prob -
                                                  s.log_prob)

                # do_accept must match ndim of proposal and state (which is 2)
                s.σ = jnp.where(do_accept.reshape(-1, 1), σp, s.σ)
                s.accepted += do_accept.sum()

                s.log_prob = jax.numpy.where(do_accept.reshape(-1),
                                             proposal_log_prob, s.log_prob)

            new_state = state.replace(
                rng=new_rng,
                σ=s.σ,
                n_accepted_proc=s.accepted,
                n_steps_proc=state.n_steps_proc +
                sampler.n_sweeps * sampler.n_chains_per_rank,
            )

        return new_state, new_state.σ

    def __repr__(sampler):
        return (f"{type(sampler).__name__}(" +
                "\n  hilbert = {},".format(sampler.hilbert) +
                "\n  rule = {},".format(sampler.rule) +
                "\n  n_chains = {},".format(sampler.n_chains) +
                "\n  machine_power = {},".format(sampler.machine_pow) +
                "\n  reset_chains = {},".format(sampler.reset_chains) +
                "\n  n_sweeps = {},".format(sampler.n_sweeps) +
                "\n  dtype = {}".format(sampler.dtype) + ")")

    def __str__(sampler):
        return (f"{type(sampler).__name__}(" +
                "rule = {}, ".format(sampler.rule) +
                "n_chains = {}, ".format(sampler.n_chains) +
                "machine_power = {}, ".format(sampler.machine_pow) +
                "n_sweeps = {}, ".format(sampler.n_sweeps) +
                "dtype = {})".format(sampler.dtype))
Example #4
File: base.py Project: tobiaswiener/netket
class Sampler(abc.ABC):
    """
    Abstract base class for all samplers.

    It contains the fields that all of them should possess, defining the common
    API.
    Note that fields marked with `pytree_node=False` are treated as static arguments
    when jitting.
    """

    hilbert: AbstractHilbert = struct.field(pytree_node=False)
    """The Hilbert space to sample."""

    n_chains_per_rank: int = struct.field(pytree_node=False, default=None)
    """Number of independent chains on every MPI rank."""

    machine_pow: int = struct.field(default=2)
    """The power to which the machine should be exponentiated to generate the pdf."""

    dtype: DType = struct.field(pytree_node=False, default=np.float64)
    """The dtype of the states sampled."""
    def __pre_init__(self,
                     hilbert: AbstractHilbert,
                     n_chains: Optional[int] = None,
                     **kwargs):
        """
        Construct a Monte Carlo sampler.

        Args:
            hilbert: The Hilbert space to sample.
            n_chains: The total number of independent chains across all MPI ranks. Either specify this or `n_chains_per_rank`.
            n_chains_per_rank: Number of independent chains on every MPI rank (default = 1).
            machine_pow: The power to which the machine should be exponentiated to generate the pdf (default = 2).
            dtype: The dtype of the states sampled (default = np.float64).
        """

        if "n_chains_per_rank" in kwargs:
            if n_chains is not None:
                raise ValueError(
                    "Cannot specify both `n_chains` and `n_chains_per_rank`")
        else:
            if n_chains is None:
                # Default value
                n_chains_per_rank = 1
            else:
                n_chains_per_rank = max(int(np.ceil(n_chains / mpi.n_nodes)),
                                        1)
                if mpi.n_nodes > 1 and mpi.rank == 0:
                    if n_chains_per_rank * mpi.n_nodes != n_chains:
                        import warnings

                        warnings.warn(
                            f"Using {n_chains_per_rank} chains per rank among {mpi.n_nodes} ranks "
                            f"(total={n_chains_per_rank * mpi.n_nodes} instead of n_chains={n_chains}). "
                            f"To directly control the number of chains on every rank, specify "
                            f"`n_chains_per_rank` when constructing the sampler. "
                            f"To silence this warning, either use `n_chains_per_rank` or use `n_chains` "
                            f"that is a multiple of the number of MPI ranks.",
                            category=UserWarning,
                        )

            kwargs["n_chains_per_rank"] = n_chains_per_rank

        return (hilbert, ), kwargs

    def __post_init__(self):
        # Raise an error if hilbert is not an AbstractHilbert
        if not isinstance(self.hilbert, AbstractHilbert):
            raise ValueError(
                "hilbert must be a subtype of netket.hilbert.AbstractHilbert, "
                + "instead, type {} is not.".format(type(self.hilbert)))

        # workaround Jax bug under pmap
        # might be removed in the future
        if type(self.machine_pow) != object:
            if not np.issubdtype(numbers.dtype(self.machine_pow), np.integer):
                raise ValueError(
                    f"machine_pow ({self.machine_pow}) must be a positive integer"
                )

    @property
    def n_chains(self) -> int:
        """
        The total number of independent chains across all MPI ranks.

        If you are not using MPI, this is equal to `n_chains_per_rank`.
        """
        return self.n_chains_per_rank * mpi.n_nodes

    @property
    def n_batches(self) -> int:
        r"""
        The batch size of the configuration $\sigma$ used by this sampler.

        In general, it is equivalent to :attr:`~Sampler.n_chains_per_rank`.
        """
        return self.n_chains_per_rank

    @property
    def is_exact(self) -> bool:
        """
        Returns `True` if the sampler is exact.

        The sampler is exact if all the samples are exactly distributed according to the
        chosen power of the variational state, and there is no correlation among them.
        """
        return False

    def log_pdf(self, model: Union[Callable, nn.Module]) -> Callable:
        """
        Returns a closure with the log-pdf function encoded by this sampler.

        Args:
            model: A Flax module or callable with the forward pass of the log-pdf.
                If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`.

        Returns:
            The log-probability density function.

        Note:
            The result is returned as a `HashablePartial` so that the closure
            does not trigger recompilation.
        """
        apply_fun = get_afun_if_module(model)
        log_pdf = HashablePartial(
            lambda apply_fun, pars, σ: self.machine_pow * apply_fun(pars, σ).
            real,
            apply_fun,
        )
        return log_pdf

    def init_state(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        seed: Optional[SeedT] = None,
    ) -> SamplerState:
        """
        Creates the structure holding the state of the sampler.

        If you want reproducible samples, you should specify `seed`, otherwise the state
        will be initialised randomly.

        If running across several MPI processes, all `sampler_state`s are guaranteed to be
        in a different (but deterministic) state.
        This is achieved by first reducing (summing) the seed provided to every MPI rank,
        then generating `n_rank` seeds starting from the reduced one, and every rank is
        initialized with one of those seeds.

        The resulting state is guaranteed to be a frozen Python dataclass (in particular,
        a Flax dataclass), and it can be serialized using Flax serialization methods.

        Args:
            machine: A Flax module or callable with the forward pass of the log-pdf.
                If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`.
            parameters: The PyTree of parameters of the model.
            seed: An optional seed or jax PRNGKey. If not specified, a random seed will be used.

        Returns:
            The structure holding the state of the sampler. In general you should not expect
            it to be in a valid state, and should reset it before use.
        """
        key = nkjax.PRNGKey(seed)
        key = nkjax.mpi_split(key)

        return sampler._init_state(wrap_afun(machine), parameters, key)

    def reset(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> SamplerState:
        """
        Resets the state of the sampler. To be used every time the parameters are changed.

        Args:
            machine: A Flax module or callable with the forward pass of the log-pdf.
                If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. If not specified, it will be constructed
                by calling :code:`sampler.init_state(machine, parameters)` with a random seed.

        Returns:
            A valid sampler state.
        """
        if state is None:
            state = sampler.init_state(machine, parameters)

        return sampler._reset(wrap_afun(machine), parameters, state)

    def sample(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        *,
        state: Optional[SamplerState] = None,
        chain_length: int = 1,
    ) -> Tuple[jnp.ndarray, SamplerState]:
        """
        Samples `chain_length` batches of samples along the chains.

        Arguments:
            machine: A Flax module or callable with the forward pass of the log-pdf.
                If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. If not specified, then initialize and reset it.
            chain_length: The length of the chains (default = 1).

        Returns:
            σ: The generated batches of samples.
            state: The new state of the sampler.
        """
        if state is None:
            state = sampler.reset(machine, parameters)

        return sampler._sample_chain(wrap_afun(machine), parameters, state,
                                     chain_length)

    def samples(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        *,
        state: Optional[SamplerState] = None,
        chain_length: int = 1,
    ) -> Iterator[jnp.ndarray]:
        """
        Returns a generator sampling `chain_length` batches of samples along the chains.

        Arguments:
            machine: A Flax module or callable with the forward pass of the log-pdf.
                If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. If not specified, then initialize and reset it.
            chain_length: The length of the chains (default = 1).
        """
        if state is None:
            state = sampler.reset(machine, parameters)

        machine = wrap_afun(machine)

        for i in range(chain_length):
            samples, state = sampler._sample_chain(machine, parameters, state,
                                                   1)
            yield samples[0, :, :]

    @abc.abstractmethod
    def _sample_chain(
        sampler,
        machine: nn.Module,
        parameters: PyTree,
        state: SamplerState,
        chain_length: int,
    ) -> Tuple[jnp.ndarray, SamplerState]:
        """
        Implementation of `sample` for subclasses of `Sampler`.

        If you subclass `Sampler`, you should override this and not `sample`
        itself, because `sample` contains some common logic.

        If using Jax, this function should be jitted.

        Arguments:
            machine: A Flax module with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler.
            chain_length: The length of the chains.

        Returns:
            σ: The generated batches of samples.
            state: The new state of the sampler.
        """

    @abc.abstractmethod
    def _init_state(sampler, machine, params, seed) -> SamplerState:
        """
        Implementation of `init_state` for subclasses of `Sampler`.

        If you subclass `Sampler`, you should override this and not `init_state`
        itself, because `init_state` contains some common logic.
        """

    @abc.abstractmethod
    def _reset(sampler, machine, parameters, state):
        """
Example #5
class MetropolisPtSampler(MetropolisSampler):
    """
    Metropolis-Hastings with Parallel Tempering sampler.

    This sampler samples a Hilbert space, producing samples of a specific dtype.
    The samples are generated according to a transition rule that must be
    specified.
    """

    n_replicas: int = struct.field(pytree_node=False, default=32)
    """The number of replicas"""
    def __post_init__(self):
        if not config.FLAGS["NETKET_EXPERIMENTAL"]:
            raise RuntimeError("""
                               Parallel Tempering samplers are under development and
                               are known not to work.

                               If you want to debug it, set the environment variable
                               NETKET_EXPERIMENTAL=1
                               """)

        super().__post_init__()
        if not (isinstance(self.n_replicas, int) and self.n_replicas > 0
                and np.mod(self.n_replicas, 2) == 0):
            raise ValueError("n_replicas must be an even integer > 0.")

    @property
    def n_batches(self):
        return self.n_chains * self.n_replicas

    def _init_state(sampler, machine, params: PyTree,
                    key: PRNGKeyT) -> MetropolisPtSamplerState:
        key_state, key_rule = jax.random.split(key, 2)
        σ = jnp.zeros(
            (sampler.n_batches, sampler.hilbert.size),
            dtype=sampler.dtype,
        )
        rule_state = sampler.rule.init_state(sampler, machine, params,
                                             key_rule)

        beta = 1.0 - jnp.arange(sampler.n_replicas) / sampler.n_replicas
        beta = jnp.tile(beta, (sampler.n_chains, 1))

        return MetropolisPtSamplerState(
            σ=σ,
            rng=key_state,
            rule_state=rule_state,
            n_steps_proc=0,
            n_accepted_proc=0,
            beta=beta,
            beta_0_index=jnp.zeros((sampler.n_chains, ), dtype=int),
            n_accepted_per_beta=jnp.zeros(
                (sampler.n_chains, sampler.n_replicas), dtype=int),
            beta_position=jnp.zeros((sampler.n_chains, )),
            beta_diffusion=jnp.zeros((sampler.n_chains, )),
            exchange_steps=0,
        )

    def _reset(sampler, machine, parameters: PyTree,
               state: MetropolisPtSamplerState):
        new_rng, rng = jax.random.split(state.rng)

        σ = sampler.rule.random_state(sampler, machine, parameters, state, rng)

        rule_state = sampler.rule.reset(sampler, machine, parameters, state)

        beta = 1.0 - jnp.arange(sampler.n_replicas) / sampler.n_replicas
        beta = jnp.tile(beta, (sampler.n_chains, 1))

        return state.replace(
            σ=σ,
            rng=new_rng,
            rule_state=rule_state,
            n_steps_proc=0,
            n_accepted_proc=0,
            n_accepted_per_beta=jnp.zeros(
                (sampler.n_chains, sampler.n_replicas)),
            beta_position=jnp.zeros((sampler.n_chains, )),
            beta_diffusion=jnp.zeros((sampler.n_chains)),
            exchange_steps=0,
            # beta=beta,
            # beta_0_index=jnp.zeros((sampler.n_chains,), dtype=jnp.int32),
        )

    def _sample_next(sampler, machine, parameters: PyTree,
                     state: MetropolisPtSamplerState):
        new_rng, rng = jax.random.split(state.rng)
        # def cbr(data):
        #    new_rng, rng = data
        #    print("sample_next newrng:\n", new_rng,  "\nand rng:\n", rng)
        #    return new_rng
        # new_rng = hcb.call(
        #   cbr,
        #   (new_rng, rng),
        #   result_shape=jax.ShapeDtypeStruct(new_rng.shape, new_rng.dtype),
        # )

        with loops.Scope() as s:
            s.key = rng
            s.σ = state.σ
            s.log_prob = sampler.machine_pow * machine.apply(
                parameters, state.σ).real
            s.beta = state.beta

            # for logging
            s.beta_0_index = state.beta_0_index
            s.n_accepted_per_beta = state.n_accepted_per_beta
            s.beta_position = state.beta_position
            s.beta_diffusion = state.beta_diffusion

            for i in s.range(sampler.n_sweeps):
                # 1 to propagate for next iteration, 1 for uniform rng and n_chains for transition kernel
                s.key, key1, key2, key3, key4 = jax.random.split(s.key, 5)

                # def cbi(data):
                #    i, beta = data
                #    print("sweep #", i, " for beta=\n", beta)
                #    return beta
                # beta = hcb.call(
                #   cbi,
                #   (i, s.beta),
                #   result_shape=jax.ShapeDtypeStruct(s.beta.shape, s.beta.dtype),
                # )
                beta = s.beta

                σp, log_prob_correction = sampler.rule.transition(
                    sampler, machine, parameters, state, key1, s.σ)
                proposal_log_prob = (sampler.machine_pow *
                                     machine.apply(parameters, σp).real)

                uniform = jax.random.uniform(key2, shape=(sampler.n_batches, ))
                if log_prob_correction is not None:
                    do_accept = uniform < jnp.exp(
                        beta.reshape((-1, )) *
                        (proposal_log_prob - s.log_prob + log_prob_correction))
                else:
                    do_accept = uniform < jnp.exp(
                        beta.reshape(
                            (-1, )) * (proposal_log_prob - s.log_prob))

                # do_accept must match ndim of proposal and state (which is 2)
                s.σ = jnp.where(do_accept.reshape(-1, 1), σp, s.σ)
                n_accepted_per_beta = s.n_accepted_per_beta + do_accept.reshape(
                    (sampler.n_chains, sampler.n_replicas))

                s.log_prob = jax.numpy.where(do_accept.reshape(-1),
                                             proposal_log_prob, s.log_prob)

                # exchange betas

                # randomly decide if every set of replicas should be swapped in even or odd order
                swap_order = jax.random.randint(
                    key3,
                    minval=0,
                    maxval=2,
                    shape=(sampler.n_chains, ),
                )  # 0 or 1
                iswap_order = jnp.mod(swap_order + 1, 2)  #  1 or 0

                # indices of even swapped elements (per-row)
                idxs = jnp.arange(0, sampler.n_replicas, 2).reshape(
                    (1, -1)) + swap_order.reshape((-1, 1))
                # indices of odd swapped elements (per-row)
                inn = (idxs + 1) % sampler.n_replicas

                # for every row of the input, swap elements at idxs with elements at inn
                @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
                def swap_rows(beta_row, idxs, inn):
                    proposed_beta = jax.ops.index_update(
                        beta_row,
                        idxs,
                        beta_row[inn],
                        unique_indices=True,
                        indices_are_sorted=True,
                    )
                    proposed_beta = jax.ops.index_update(
                        proposed_beta,
                        inn,
                        beta_row[idxs],
                        unique_indices=True,
                        indices_are_sorted=False,
                    )
                    return proposed_beta

                proposed_beta = swap_rows(beta, idxs, inn)

                @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
                def compute_proposed_prob(prob, idxs, inn):
                    prob_rescaled = prob[idxs] + prob[inn]
                    return prob_rescaled

                # compute the probability of the swaps
                log_prob = (proposed_beta - state.beta) * s.log_prob.reshape(
                    (sampler.n_chains, sampler.n_replicas))

                prob_rescaled = jnp.exp(
                    compute_proposed_prob(log_prob, idxs, inn))

                uniform = jax.random.uniform(key4,
                                             shape=(sampler.n_chains,
                                                    sampler.n_replicas // 2))

                do_swap = uniform < prob_rescaled

                do_swap = jnp.dstack((do_swap, do_swap)).reshape(
                    (-1, sampler.n_replicas))  #  concat along last dimension
                # roll if swap_order is odd
                @partial(jax.vmap, in_axes=(0, 0), out_axes=0)
                def fix_swap(do_swap, swap_order):
                    return jax.lax.cond(swap_order == 0, lambda x: x,
                                        lambda x: jnp.roll(x, 1), do_swap)

                do_swap = fix_swap(do_swap, swap_order)
                # jax.experimental.host_callback.id_print(state.beta)
                # jax.experimental.host_callback.id_print(proposed_beta)

                new_beta = jax.numpy.where(do_swap, proposed_beta, beta)

                def cb(data):
                    _bt, _pbt, new_beta, so, do_swap, log_prob, prob = data
                    print("--------.---------.---------.--------")
                    print("     cur beta:\n", _bt)
                    print("proposed beta:\n", _pbt)
                    print("     new beta:\n", new_beta)
                    print("swaporder :", so)
                    print("do_swap :\n", do_swap)
                    print("log_prob;\n", log_prob)
                    print("prob_rescaled;\n", prob)
                    return new_beta

                # new_beta = hcb.call(
                #    cb,
                #    (
                #        beta,
                #        proposed_beta,
                #        new_beta,
                #        swap_order,
                #        do_swap,
                #        log_prob,
                #        prob_rescaled,
                #    ),
                #    result_shape=jax.ShapeDtypeStruct(new_beta.shape, new_beta.dtype),
                # )
                # s.beta = new_beta

                swap_order = swap_order.reshape(-1)

                beta_0_moved = jax.vmap(lambda do_swap, i: do_swap[i],
                                        in_axes=(0, 0),
                                        out_axes=0)(do_swap,
                                                    state.beta_0_index)
                proposed_beta_0_index = jnp.mod(
                    state.beta_0_index + (-jnp.mod(swap_order, 2) * 2 + 1) *
                    (-jnp.mod(state.beta_0_index, 2) * 2 + 1),
                    sampler.n_replicas,
                )

                s.beta_0_index = jnp.where(beta_0_moved, proposed_beta_0_index,
                                           s.beta_0_index)

                # swap acceptances
                swapped_n_accepted_per_beta = swap_rows(
                    n_accepted_per_beta, idxs, inn)
                s.n_accepted_per_beta = jax.numpy.where(
                    do_swap,
                    swapped_n_accepted_per_beta,
                    n_accepted_per_beta,
                )

                # Update statistics to compute diffusion coefficient of replicas
                # Total exchange steps performed
                delta = s.beta_0_index - s.beta_position
                s.beta_position = s.beta_position + delta / (
                    state.exchange_steps + i)
                delta2 = s.beta_0_index - s.beta_position
                s.beta_diffusion = s.beta_diffusion + delta * delta2

            new_state = state.replace(
                rng=new_rng,
                σ=s.σ,
                # n_accepted=s.accepted,
                n_steps_proc=state.n_steps_proc +
                sampler.n_sweeps * sampler.n_chains,
                beta=s.beta,
                beta_0_index=s.beta_0_index,
                beta_position=s.beta_position,
                beta_diffusion=s.beta_diffusion,
                exchange_steps=state.exchange_steps + sampler.n_sweeps,
                n_accepted_per_beta=s.n_accepted_per_beta,
            )

        offsets = jnp.arange(0, sampler.n_chains * sampler.n_replicas,
                             sampler.n_replicas)

        return new_state, new_state.σ[new_state.beta_0_index + offsets, :]

    def __repr__(sampler):
        return ("MetropolisPTSampler(" +
                "\n  hilbert = {},".format(sampler.hilbert) +
                "\n  rule = {},".format(sampler.rule) +
                "\n  n_chains = {},".format(sampler.n_chains) +
                "\n  machine_power = {},".format(sampler.machine_pow) +
                "\n  reset_chain = {},".format(sampler.reset_chain) +
                "\n  n_sweeps = {},".format(sampler.n_sweeps) +
                "\n  dtype = {},".format(sampler.dtype) + ")")

    def __str__(sampler):
        return ("MetropolisPTSampler(" + "rule = {}, ".format(sampler.rule) +
                "n_chains = {}, ".format(sampler.n_chains) +
                "machine_power = {}, ".format(sampler.machine_pow) +
                "reset_chain = {}, ".format(sampler.reset_chain) +
                "n_sweeps = {}, ".format(sampler.n_sweeps) +
                "dtype = {})".format(sampler.dtype))
Example #6
class MetropolisSampler(Sampler):
    r"""
    Metropolis-Hastings sampler for a Hilbert space according to a specific transition rule.

    The transition rule is used to generate a proposed state :math:`s^\prime`, starting from the
    current state :math:`s`. The move is accepted with probability

    .. math::

        A(s \rightarrow s^\prime) = \mathrm{min} \left( 1,\frac{P(s^\prime)}{P(s)} e^{L(s,s^\prime)} \right) ,

    where the probability being sampled from is :math:`P(s)=|M(s)|^p`. Here :math:`M(s)` is a
    user-provided function (the machine), :math:`p` is also user-provided with default value :math:`p=2`,
    and :math:`L(s,s^\prime)` is a suitable correcting factor computed by the transition kernel.

    The dtype of the sampled states can be chosen.
    """

    rule: MetropolisRule = None
    """The Metropolis transition rule."""
    n_sweeps: int = struct.field(pytree_node=False, default=None)
    """Number of sweeps for each step along the chain. Defaults to the number of sites in the Hilbert space."""
    reset_chains: bool = struct.field(pytree_node=False, default=False)
    """If True, resets the chain state when `reset` is called on every new sampling."""
    def __pre_init__(self, hilbert, rule, **kwargs):
        """
        Constructs a Metropolis Sampler.

        Args:
            hilbert: The Hilbert space to sample.
            rule: A `MetropolisRule` to generate random transitions from a given state as
                    well as uniform random states.
            n_chains: The total number of independent Markov chains across all MPI ranks. Either specify this or `n_chains_per_rank`.
            n_chains_per_rank: Number of independent chains on every MPI rank (default = 16).
            n_sweeps: Number of sweeps for each step along the chain. Defaults to the number of sites in the Hilbert space.
                    This is equivalent to subsampling the Markov chain.
            reset_chains: If True, resets the chain state when `reset` is called on every new sampling (default = False).
            machine_pow: The power to which the machine should be exponentiated to generate the pdf (default = 2).
            dtype: The dtype of the states sampled (default = np.float64).
        """
        if "n_chains" not in kwargs and "n_chains_per_rank" not in kwargs:
            kwargs["n_chains_per_rank"] = 16

        # process arguments in the base
        args, kwargs = super().__pre_init__(hilbert=hilbert, **kwargs)

        kwargs["rule"] = rule

        # deprecation warnings
        if "reset_chain" in kwargs:
            warn_deprecation(
                "The keyword argument `reset_chain` is deprecated in favour of `reset_chains`"
            )
            kwargs["reset_chains"] = kwargs.pop("reset_chain")

        return args, kwargs

    def __post_init__(self):
        super().__post_init__()
        # Validate the inputs
        if not isinstance(self.rule, MetropolisRule):
            raise TypeError("rule must be a MetropolisRule.")

        if not isinstance(self.reset_chains, bool):
            raise TypeError("reset_chains must be a boolean.")

        # Default value of n_sweeps
        if self.n_sweeps is None:
            object.__setattr__(self, "n_sweeps", self.hilbert.size)

    def sample_next(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> Tuple[SamplerState, jnp.ndarray]:
        """
        Samples the next state in the Markov chain.

        Args:
            machine: A Flax module or callable with the forward pass of the log-pdf.
                If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. If not specified, then initialize and reset it.

        Returns:
            state: The new state of the sampler.
            σ: The next batch of samples.

        Note:
            The return order is inverted wrt `sample` because when called inside of
            a scan function the first returned argument should be the state.
        """
        if state is None:
            state = sampler.reset(machine, parameters)

        return sampler._sample_next(wrap_afun(machine), parameters, state)

    def _init_state(sampler, machine, params, key):
        key_state, key_rule = jax.random.split(key, 2)
        rule_state = sampler.rule.init_state(sampler, machine, params,
                                             key_rule)
        σ = jnp.zeros((sampler.n_chains_per_rank, sampler.hilbert.size),
                      dtype=sampler.dtype)

        state = MetropolisSamplerState(σ=σ,
                                       rng=key_state,
                                       rule_state=rule_state)

        # If we don't reset the chain at every sampling iteration, draw a
        # random initial state now.
        if not sampler.reset_chains:
            key_state, rng = jax.random.split(key_state)
            σ = sampler.rule.random_state(sampler, machine, params, state, rng)
            state = state.replace(σ=σ, rng=key_state)

        return state

    def _reset(sampler, machine, parameters, state):
        new_rng, rng = jax.random.split(state.rng)

        if sampler.reset_chains:
            σ = sampler.rule.random_state(sampler, machine, parameters, state,
                                          rng)
        else:
            σ = state.σ

        rule_state = sampler.rule.reset(sampler, machine, parameters, state)

        return state.replace(σ=σ,
                             rng=new_rng,
                             rule_state=rule_state,
                             n_steps_proc=0,
                             n_accepted_proc=0)

    def _sample_next(sampler, machine, parameters, state):
        """
        Implementation of `sample_next` for subclasses of `MetropolisSampler`.

        If you subclass `MetropolisSampler`, you should override this and not `sample_next`
        itself, because `sample_next` contains some common logic.
        """
        def loop_body(i, s):
            # 1 to propagate for next iteration, 1 for uniform rng and n_chains for transition kernel
            s["key"], key1, key2 = jax.random.split(s["key"], 3)

            σp, log_prob_correction = sampler.rule.transition(
                sampler, machine, parameters, state, key1, s["σ"])
            proposal_log_prob = sampler.machine_pow * machine.apply(
                parameters, σp).real

            uniform = jax.random.uniform(key2,
                                         shape=(sampler.n_chains_per_rank, ))
            if log_prob_correction is not None:
                do_accept = uniform < jnp.exp(proposal_log_prob -
                                              s["log_prob"] +
                                              log_prob_correction)
            else:
                do_accept = uniform < jnp.exp(proposal_log_prob -
                                              s["log_prob"])

            # do_accept must match ndim of proposal and state (which is 2)
            s["σ"] = jnp.where(do_accept.reshape(-1, 1), σp, s["σ"])
            s["accepted"] += do_accept.sum()

            s["log_prob"] = jax.numpy.where(do_accept.reshape(-1),
                                            proposal_log_prob, s["log_prob"])

            return s

        new_rng, rng = jax.random.split(state.rng)

        s = {
            "key": rng,
            "σ": state.σ,
            "log_prob":
            sampler.machine_pow * machine.apply(parameters, state.σ).real,
            # for logging
            "accepted": state.n_accepted_proc,
        }
        s = jax.lax.fori_loop(0, sampler.n_sweeps, loop_body, s)

        new_state = state.replace(
            rng=new_rng,
            σ=s["σ"],
            n_accepted_proc=s["accepted"],
            n_steps_proc=state.n_steps_proc +
            sampler.n_sweeps * sampler.n_chains_per_rank,
        )

        return new_state, new_state.σ

    def _sample_chain(sampler, machine, parameters, state, chain_length):
        return _sample_chain(sampler, machine, parameters, state, chain_length)

    def __repr__(sampler):
        return (f"{type(sampler).__name__}(" +
                "\n  hilbert = {},".format(sampler.hilbert) +
                "\n  rule = {},".format(sampler.rule) +
                "\n  n_chains = {},".format(sampler.n_chains) +
                "\n  n_sweeps = {},".format(sampler.n_sweeps) +
                "\n  reset_chains = {},".format(sampler.reset_chains) +
                "\n  machine_power = {},".format(sampler.machine_pow) +
                "\n  dtype = {}".format(sampler.dtype) + ")")

    def __str__(sampler):
        return (f"{type(sampler).__name__}(" +
                "rule = {}, ".format(sampler.rule) +
                "n_chains = {}, ".format(sampler.n_chains) +
                "n_sweeps = {}, ".format(sampler.n_sweeps) +
                "reset_chains = {}, ".format(sampler.reset_chains) +
                "machine_power = {}, ".format(sampler.machine_pow) +
                "dtype = {})".format(sampler.dtype))
Example #7
class MetropolisSamplerState(SamplerState):
    """
    State for a Metropolis sampler.

    Contains the current configuration, the RNG state and the (optional)
    state of the transition rule.
    """

    σ: jnp.ndarray
    """Current batch of configurations in the Markov chain."""
    rng: jnp.ndarray
    """State of the random number generator (key, in jax terms)."""
    rule_state: Optional[Any]
    """Optional state of the transition rule."""

    # These are initialised to zero-dimensional arrays rather than Python ints,
    # because they can be passed to jitted jax functions, which require type
    # invariance to avoid recompilation.
    n_steps_proc: int = struct.field(
        default_factory=lambda: jnp.zeros((), dtype=jnp.int64))
    """Number of moves performed along the chains in this process since the last reset."""
    n_accepted_proc: int = struct.field(
        default_factory=lambda: jnp.zeros((), dtype=jnp.int64))
    """Number of accepted transitions among the chains in this process since the last reset."""
    @property
    def acceptance(self) -> float:
        """The fraction of accepted moves across all chains and MPI processes.

        The rate is computed since the last reset of the sampler.
        Will return None if no sampling has been performed since then.
        """
        if self.n_steps == 0:
            return None

        return self.n_accepted / self.n_steps

    @property
    @deprecated("""Please use the attribute `.acceptance` instead of
        `.acceptance_ratio`. The new attribute `.acceptance` returns the
        acceptance ratio ∈ [0,1], instead of the current `acceptance_ratio`
        returning a percentage, which is a bug.""")
    def acceptance_ratio(self):
        """DEPRECATED: Please use the attribute `.acceptance` instead of
        `.acceptance_ratio`. The new attribute `.acceptance` returns the
        acceptance ratio ∈ [0,1], instead of the current `acceptance_ratio`
        returning a percentage, which is a bug.

        The percentage of accepted moves across all chains and MPI processes.

        The rate is computed since the last reset of the sampler.
        Will return None if no sampling has been performed since then.
        """
        return self.acceptance * 100

    @property
    def n_steps(self) -> int:
        """Total number of moves performed across all processes since the last reset."""
        return self.n_steps_proc * mpi.n_nodes

    @property
    def n_accepted(self) -> int:
        """Total number of moves accepted across all processes since the last reset."""
        res, _ = mpi.mpi_sum_jax(self.n_accepted_proc)
        return res

    def __repr__(self):
        if self.n_steps > 0:
            acc_string = "# accepted = {}/{} ({}%), ".format(
                self.n_accepted, self.n_steps, self.acceptance * 100)
        else:
            acc_string = ""

        return f"{type(self).__name__}({acc_string}rng state={self.rng})"
Example #8
class RungeKuttaIntegrator:
    tableau: rkt.NamedTableau

    f: Callable = field(repr=False)
    t0: float
    y0: Array = field(repr=False)

    initial_dt: float

    use_adaptive: bool
    norm: Callable

    atol: float = 0.0
    rtol: float = 1e-7
    dt_limits: Optional[LimitsType] = None

    def __post_init__(self):
        if self.use_adaptive and not self.tableau.data.is_adaptive:
            raise RuntimeError(
                f"Solver {self.tableau} does not support adaptive step size")
        if self.use_adaptive:
            self._do_step = self._do_step_adaptive
        else:
            self._do_step = self._do_step_fixed

        if self.norm is None:
            self.norm = euclidean_norm

        if self.dt_limits is None:
            self.dt_limits = (None, 10 * self.initial_dt)

        self._rkstate = RungeKuttaState(
            step_no=0,
            step_no_total=0,
            t=nk.utils.KahanSum(self.t0),
            y=self.y0,
            dt=self.initial_dt,
            last_norm=0.0 if self.use_adaptive else None,
            flags=SolverFlags(0),
        )

    def step(self, max_dt=None):
        """
        Perform one full Runge-Kutta step by min(self.dt, max_dt).

        Returns:
            A boolean indicating whether the step was successful or
            was rejected by the step controller and should be retried.

            Note that the step size can be adjusted by the step controller
            in both cases, so the integrator state will have changed
            even after a rejected step.
        """
        self._rkstate = self._do_step(self._rkstate, max_dt)
        return self._rkstate.accepted

    def _do_step_fixed(self, rk_state, max_dt=None):
        return general_time_step_fixed(
            self.tableau.data,
            self.f,
            rk_state,
            max_dt=max_dt,
        )

    def _do_step_adaptive(self, rk_state, max_dt=None):
        return general_time_step_adaptive(
            self.tableau.data,
            self.f,
            rk_state,
            atol=self.atol,
            rtol=self.rtol,
            norm_fn=self.norm,
            max_dt=max_dt,
            dt_limits=self.dt_limits,
        )

    @property
    def t(self):
        return self._rkstate.t.value

    @property
    def y(self):
        return self._rkstate.y

    @property
    def dt(self):
        return self._rkstate.dt

    def _get_solver_flags(self, intersect=SolverFlags.NONE) -> SolverFlags:
        """Returns the currently set flags of the solver, intersected with `intersect`."""
        # _rkstate.flags is turned into an int-valued DeviceArray by JAX,
        # so we convert it back.
        return SolverFlags(int(self._rkstate.flags) & intersect)

    @property
    def errors(self) -> SolverFlags:
        """Returns the currently set error flags of the solver."""
        return self._get_solver_flags(SolverFlags.ERROR_FLAGS)

    @property
    def warnings(self) -> SolverFlags:
        """Returns the currently set warning flags of the solver."""
        return self._get_solver_flags(SolverFlags.WARNINGS_FLAGS)
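
A minimal driver-loop sketch (`tableau`, `f` and `y0` are assumptions, not part of the source):

integrator = RungeKuttaIntegrator(
    tableau=tableau,
    f=f,
    t0=0.0,
    y0=y0,
    initial_dt=1e-3,
    use_adaptive=True,
    norm=None,  # __post_init__ falls back to euclidean_norm
)
t_end = 1.0
while integrator.t < t_end:
    accepted = integrator.step(max_dt=t_end - integrator.t)
    if not accepted and integrator.errors:
        raise RuntimeError(f"integration failed: {integrator.errors!r}")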
Example #9
class MetropolisPtSampler(MetropolisSampler):
    """
    Metropolis-Hastings with Parallel Tempering sampler.

    This sampler samples a Hilbert space, producing samples of a specific dtype.
    The samples are generated according to a transition rule that must be
    specified.
    """

    n_replicas: int = struct.field(pytree_node=False, default=32)
    """The number of replicas"""
    def __post_init__(self):
        super().__post_init__()
        if not (isinstance(self.n_replicas, int) and self.n_replicas > 0
                and np.mod(self.n_replicas, 2) == 0):
            raise ValueError("n_replicas must be an even integer > 0.")

    @property
    def n_batches(self):
        return self.n_chains * self.n_replicas

    def _init_state(sampler, machine, params: PyTree,
                    key: PRNGKeyT) -> MetropolisPtSamplerState:
        key_state, key_rule = jax.random.split(key, 2)
        σ = jnp.zeros(
            (sampler.n_batches, sampler.hilbert.size),
            dtype=sampler.dtype,
        )
        rule_state = sampler.rule.init_state(sampler, machine, params,
                                             key_rule)

        beta = 1.0 - jnp.arange(sampler.n_replicas) / sampler.n_replicas
        beta = jnp.tile(beta, (sampler.n_chains, 1))

        return MetropolisPtSamplerState(
            σ=σ,
            rng=key_state,
            rule_state=rule_state,
            n_steps_proc=0,
            n_accepted_proc=0,
            beta=beta,
            beta_0_index=jnp.zeros((sampler.n_chains, ), dtype=jnp.int64),
            n_accepted_per_beta=jnp.zeros(
                (sampler.n_chains, sampler.n_replicas), dtype=jnp.int64),
            beta_position=jnp.zeros((sampler.n_chains, )),
            beta_diffusion=jnp.zeros((sampler.n_chains, )),
            exchange_steps=0,
        )

    def _reset(sampler, machine, parameters: PyTree,
               state: MetropolisPtSamplerState):
        new_rng, rng = jax.random.split(state.rng)

        σ = sampler.rule.random_state(sampler, machine, parameters, state, rng)

        rule_state = sampler.rule.reset(sampler, machine, parameters, state)

        beta = 1.0 - jnp.arange(sampler.n_replicas) / sampler.n_replicas
        beta = jnp.tile(beta, (sampler.n_chains, 1))

        return state.replace(
            σ=σ,
            rng=new_rng,
            rule_state=rule_state,
            n_steps_proc=0,
            n_accepted_proc=0,
            n_accepted_per_beta=jnp.zeros(
                (sampler.n_chains, sampler.n_replicas), dtype=jnp.int64),
            beta_position=jnp.zeros((sampler.n_chains, )),
            beta_diffusion=jnp.zeros((sampler.n_chains)),
            exchange_steps=0,
            # beta=beta,
            # beta_0_index=jnp.zeros((sampler.n_chains,), dtype=jnp.int64),
        )

    def _sample_next(sampler, machine, parameters: PyTree,
                     state: MetropolisPtSamplerState):
        def loop_body(i, s):
            # 1 to propagate for next iteration, 1 for uniform rng and n_chains for transition kernel
            s["key"], key1, key2, key3, key4 = jax.random.split(s["key"], 5)

            # def cbi(data):
            #    i, beta = data
            #    print("sweep #", i, " for beta=\n", beta)
            #    return beta
            #
            # beta = hcb.call(
            #   cbi,
            #   (i, s["beta"]),
            #   result_shape=jax.ShapeDtypeStruct(s["beta"].shape, s["beta"].dtype),
            # )

            beta = s["beta"]

            σp, log_prob_correction = sampler.rule.transition(
                sampler, machine, parameters, state, key1, s["σ"])
            proposal_log_prob = sampler.machine_pow * machine.apply(
                parameters, σp).real

            uniform = jax.random.uniform(key2, shape=(sampler.n_batches, ))
            if log_prob_correction is not None:
                do_accept = uniform < jnp.exp(
                    beta.reshape((-1, )) *
                    (proposal_log_prob - s["log_prob"] + log_prob_correction))
            else:
                do_accept = uniform < jnp.exp(
                    beta.reshape((-1, )) * (proposal_log_prob - s["log_prob"]))

            # do_accept must match ndim of proposal and state (which is 2)
            s["σ"] = jnp.where(do_accept.reshape(-1, 1), σp, s["σ"])
            n_accepted_per_beta = s["n_accepted_per_beta"] + do_accept.reshape(
                (sampler.n_chains, sampler.n_replicas))

            s["log_prob"] = jax.numpy.where(do_accept.reshape(-1),
                                            proposal_log_prob, s["log_prob"])

            # exchange betas

            # randomly decide if every set of replicas should be swapped in even or odd order
            swap_order = jax.random.randint(
                key3,
                minval=0,
                maxval=2,
                shape=(sampler.n_chains, ),
            )  # 0 or 1
            # iswap_order = jnp.mod(swap_order + 1, 2)  #  1 or 0

            # indices of even swapped elements (per-row)
            idxs = jnp.arange(0, sampler.n_replicas, 2).reshape(
                (1, -1)) + swap_order.reshape((-1, 1))
            # indices of odd swapped elements (per-row)
            inn = (idxs + 1) % sampler.n_replicas

            # for every row of the input, swap elements at idxs with elements at inn
            @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
            def swap_rows(beta_row, idxs, inn):
                proposed_beta = beta_row.at[idxs].set(beta_row[inn],
                                                      unique_indices=True,
                                                      indices_are_sorted=True)
                proposed_beta = proposed_beta.at[inn].set(
                    beta_row[idxs],
                    unique_indices=True,
                    indices_are_sorted=False)
                return proposed_beta

            proposed_beta = swap_rows(beta, idxs, inn)

            @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
            def compute_proposed_prob(prob, idxs, inn):
                prob_rescaled = prob[idxs] + prob[inn]
                return prob_rescaled

            # compute the probability of the swaps
            log_prob = (proposed_beta - state.beta) * s["log_prob"].reshape(
                (sampler.n_chains, sampler.n_replicas))

            prob_rescaled = jnp.exp(compute_proposed_prob(log_prob, idxs, inn))

            uniform = jax.random.uniform(key4,
                                         shape=(sampler.n_chains,
                                                sampler.n_replicas // 2))

            do_swap = uniform < prob_rescaled

            do_swap = jnp.dstack((do_swap, do_swap)).reshape(
                (-1, sampler.n_replicas))  # concat along last dimension

            # roll if swap_order is odd
            @partial(jax.vmap, in_axes=(0, 0), out_axes=0)
            def fix_swap(do_swap, swap_order):
                return jax.lax.cond(swap_order == 0, lambda x: x,
                                    lambda x: jnp.roll(x, 1), do_swap)

            do_swap = fix_swap(do_swap, swap_order)
            # jax.experimental.host_callback.id_print(state.beta)
            # jax.experimental.host_callback.id_print(proposed_beta)

            # new_beta = jax.numpy.where(do_swap, proposed_beta, beta)

            # def cb(data):
            #     _bt, _pbt, new_beta, so, do_swap, log_prob, prob = data
            #     print("--------.---------.---------.--------")
            #     print("     cur beta:\n", _bt)
            #     print("proposed beta:\n", _pbt)
            #     print("     new beta:\n", new_beta)
            #     print("swaporder :", so)
            #     print("do_swap :\n", do_swap)
            #     print("log_prob;\n", log_prob)
            #     print("prob_rescaled;\n", prob)
            #     return new_beta
            #
            # new_beta = hcb.call(
            #    cb,
            #    (
            #        beta,
            #        proposed_beta,
            #        new_beta,
            #        swap_order,
            #        do_swap,
            #        log_prob,
            #        prob_rescaled,
            #    ),
            #    result_shape=jax.ShapeDtypeStruct(new_beta.shape, new_beta.dtype),
            # )
            # s["beta"] = new_beta

            swap_order = swap_order.reshape(-1)

            beta_0_moved = jax.vmap(lambda do_swap, i: do_swap[i],
                                    in_axes=(0, 0),
                                    out_axes=0)(do_swap, state.beta_0_index)
            proposed_beta_0_index = jnp.mod(
                state.beta_0_index + (-jnp.mod(swap_order, 2) * 2 + 1) *
                (-jnp.mod(state.beta_0_index, 2) * 2 + 1),
                sampler.n_replicas,
            )

            s["beta_0_index"] = jnp.where(beta_0_moved, proposed_beta_0_index,
                                          s["beta_0_index"])

            # swap acceptances
            swapped_n_accepted_per_beta = swap_rows(n_accepted_per_beta, idxs,
                                                    inn)
            s["n_accepted_per_beta"] = jax.numpy.where(
                do_swap,
                swapped_n_accepted_per_beta,
                n_accepted_per_beta,
            )

            # Update statistics to compute diffusion coefficient of replicas
            # Total exchange steps performed
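            # Welford-style online update: beta_position tracks the running
            # mean of the β=1 replica index, beta_diffusion its accumulated
            # second moment (delta uses the old mean, delta2 the updated one)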
            delta = s["beta_0_index"] - s["beta_position"]
            s["beta_position"] = s["beta_position"] + delta / (
                state.exchange_steps + jnp.asarray(i, dtype=jnp.int64))
            delta2 = s["beta_0_index"] - s["beta_position"]
            s["beta_diffusion"] = s["beta_diffusion"] + delta * delta2

            return s

        new_rng, rng = jax.random.split(state.rng)

        s = {
            "key": rng,
            "σ": state.σ,
            "log_prob":
            sampler.machine_pow * machine.apply(parameters, state.σ).real,
            "beta": state.beta,
            # for logging
            "beta_0_index": state.beta_0_index,
            "n_accepted_per_beta": state.n_accepted_per_beta,
            "beta_position": state.beta_position,
            "beta_diffusion": state.beta_diffusion,
        }
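        # run n_sweeps iterations of loop_body over the loop-state dict s
        # (each iteration performs the Metropolis update and the
        # replica-exchange move shown above)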
        s = jax.lax.fori_loop(0, sampler.n_sweeps, loop_body, s)

        new_state = state.replace(
            rng=new_rng,
            σ=s["σ"],
            # n_accepted=s["accepted"],
            n_steps_proc=state.n_steps_proc +
            sampler.n_sweeps * sampler.n_chains,
            beta=s["beta"],
            beta_0_index=s["beta_0_index"],
            beta_position=s["beta_position"],
            beta_diffusion=s["beta_diffusion"],
            exchange_steps=state.exchange_steps + sampler.n_sweeps,
            n_accepted_per_beta=s["n_accepted_per_beta"],
        )

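        # σ is stored with the replica axis flattened into the chain axis;
        # offsets pick out, for every chain, the row of its β=1 replica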
        offsets = jnp.arange(0, sampler.n_chains * sampler.n_replicas,
                             sampler.n_replicas)

        return new_state, new_state.σ[new_state.beta_0_index + offsets, :]
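
For reference, the exchange move implemented above accepts a swap between the
replicas at inverse temperatures :math:`\beta_i` and :math:`\beta_j` with the
usual parallel-tempering probability; in the code, ``log_prob`` holds the
per-replica :math:`\Delta\beta \, \log P(\sigma)` terms and
``compute_proposed_prob`` sums them over each pair:

.. math::

    A(i \leftrightarrow j) = \mathrm{min}\left(1, \; e^{(\beta_i - \beta_j)\,(\log P(\sigma_j) - \log P(\sigma_i))}\right)
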
Example #10
0
class Sampler(abc.ABC):
    """
    Abstract base class for all samplers.

    It contains the fields that all samplers should possess, defining the common
    API.
    Note that fields marked with pytree_node=False are treated as static arguments
    when jitting.
    """

    hilbert: AbstractHilbert = struct.field(pytree_node=False)
    """Hilbert space to be sampled."""

    n_chains: int = struct.field(pytree_node=False, default=16)
    """Number of batches along the chain"""

    machine_pow: int = struct.field(default=2)
    """Exponent of the pdf sampled"""

    dtype: type = struct.field(pytree_node=False, default=np.float64)
    """DType of the states returned."""

    def __post_init__(self):
        # Raise an error if hilbert is not an AbstractHilbert
        if not isinstance(self.hilbert, AbstractHilbert):
            raise ValueError(
                "hilbert must be a subtype of netket.hilbert.AbstractHilbert, "
                + "instead got type {}.".format(type(self.hilbert))
            )

        if not isinstance(self.n_chains, int) or self.n_chains < 0:
            raise ValueError("n_chains must be a non-negative integer")

        # if not isinstance(self.machine_pow, int) or self.machine_pow < 0:
        #    raise ValueError("machine_pow must be a non-negative integer")

    @property
    def n_batches(self) -> int:
        """
        The batch size of the configuration $\sigma$ used by this sampler.

        In general, it is equivalent to :attr:`~Sampler.n_chains`.
        """
        return self.n_chains

    def log_pdf(self, model: Union[Callable, nn.Module]) -> Callable:
        """
        Returns a closure with the log_pdf function encoded by this sampler.

        Note: the result is returned as a HashablePartial so that the closure
        does not trigger recompilation.

        Args:
            model: The machine, or apply_fun

        Returns:
            the log probability density function
        """
        apply_fun = get_afun_if_module(model)
        log_pdf = HashablePartial(
            lambda apply_fun, pars, σ: self.machine_pow * apply_fun(pars, σ).real,
            apply_fun,
        )
        return log_pdf

    def init_state(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        seed: Optional[SeedType] = None,
    ) -> SamplerState:
        """
        Creates the structure holding the state of the sampler.

        If you want reproducible samples, you should specify `seed`, otherwise the state
        will be initialised randomly.

        If running across several MPI processes, all sampler states are guaranteed to be
        in different (but deterministic) states.
        This is achieved by first reducing (summing) the seeds provided on every MPI
        rank, then generating `n_rank` seeds starting from the reduced one; every rank
        is initialized with one of those seeds.

        The resulting state is guaranteed to be a frozen python dataclass (in particular,
        a flax dataclass), and it can be serialized using Flax serialization methods.

        Args:
            machine: a Flax module or callable with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            seed: An optional seed or jax PRNGKey. If not specified, a random seed will be used.

        Returns:
            The structure holding the state of the sampler. In general you should not expect
            it to be in a valid state, and should reset it before use.
        """
        key = nkjax.PRNGKey(seed)
        key = nkjax.mpi_split(key)

        return sampler._init_state(get_afun_if_module(machine), parameters, key)

    def reset(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> SamplerState:
        """
        Resets the state of the sampler. To be used every time the parameters are changed.

        Args:
            machine: a Flax module or callable with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. If it's not provided, it will be constructed
                by calling :code:`sampler.init_state(machine, parameters)` with a random seed.

        Returns:
            A valid sampler state.
        """
        if state is None:
            state = sampler.init_state(machine, parameters)

        return sampler._reset(get_afun_if_module(machine), parameters, state)

    def sample_next(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> Tuple[SamplerState, jnp.ndarray]:
        """
        Samples the next state in the Markov chain.

        Args:
            machine: a Flax module or callable apply function with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. If it's not provided, it will be constructed
                by calling :code:`sampler.reset(machine, parameters)` with a random seed.

        Returns:
            state: The new state of the sampler
            σ: The next batch of samples.
        """
        if state is None:
            state = sampler.reset(machine, parameters)

        return sampler._sample_next(get_afun_if_module(machine), parameters, state)

    def sample(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        *,
        state: Optional[SamplerState] = None,
        chain_length: int = 1,
    ) -> Tuple[SamplerState, jnp.ndarray]:
        """
        Samples chain_length elements along the chains.

        Arguments:
            sampler: The Monte Carlo sampler.
            machine: The model or callable to sample from (if it's a function it should have
                the signature :code:`f(parameters, σ) -> jnp.ndarray`).
            parameters: The PyTree of parameters of the model.
            state: current state of the sampler. If None, then initialises it.
            chain_length: (default=1), the length of the chains.

        Returns:
            state: The new state of the sampler
            σ: The next batch of samples.
        """

        return sample(
            sampler, machine, parameters, state=state, chain_length=chain_length
        )

    def _sample_chain(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: SamplerState,
        chain_length: int,
    ) -> Tuple[SamplerState, jnp.ndarray]:
        """
        Samples chain_length elements along the chains.

        In general this should not be overridden unless you want to modify the
        logic by which the whole sampling is performed.
        If using JAX, this function should be jitted.

        Arguments:
            sampler: The Monte Carlo sampler.
            machine: The model or callable to sample from (if it's a function it should have
                the signature :code:`f(parameters, σ) -> jnp.ndarray`).
            parameters: The PyTree of parameters of the model.
            state: current state of the sampler. If None, then initialises it.
            chain_length: (default=1), the length of the chains.

        Returns:
            state: The new state of the sampler
            σ: The next batch of samples.
        """
        return _sample_chain(sampler, machine, parameters, state, chain_length)

    @abc.abstractmethod
    def _init_state(sampler, machine, params, seed) -> SamplerState:
        """
        Implementation of init_state for subclasses of Sampler.

        If you sub-class Sampler, you should define this and not init_state
        itself, because init_state contains some common logic.
        """
        raise NotImplementedError("init_state Not Implemented")

    @abc.abstractmethod
    def _reset(sampler, machine, parameters, state):
        """
        Implementation of reset for subclasses of Sampler.

        If you sub-class Sampler, you should define _reset and not reset
        itself, because reset contains some common logic.
        """
        raise NotImplementedError("reset Not Implemented")

    @abc.abstractmethod
    def _sample_next(sampler, machine, parameters, state=None):
        """
        Implementation of sample_next for subclasses of Sampler.

        If you sub-class Sampler, you should define _sample_next and not sample_next
        itself, because sample_next contains some common logic.
        """
        raise NotImplementedError("sample_next Not Implemented")