class Point1:
    """Point structure with two data fields and one non-pytree metadata field.

    Also exposes ``cached_node``, a ``struct.property_cached`` property whose
    value is always ``3``.
    """

    x: float
    y: float
    # Declared with ``pytree_node=False``, i.e. not treated as a pytree leaf.
    meta: Any = struct.field(pytree_node=False)

    @struct.property_cached
    def cached_node(self) -> int:
        """Constant value exposed through ``struct.property_cached``."""
        return 3
class Point0:
    """Point structure whose constructor also understands a ``z`` keyword.

    If ``z`` is passed, it is removed from the keyword arguments and ``x`` is
    set to ``z * 10`` instead.
    """

    x: float
    y: float
    # Declared with ``pytree_node=False``, i.e. not treated as a pytree leaf.
    meta: Any = struct.field(pytree_node=False)

    def __pre_init__(self, *args, **kwargs):
        """Rewrite constructor arguments, mapping ``z`` onto ``x``.

        Returns:
            The (possibly modified) positional and keyword arguments.
        """
        if "z" not in kwargs:
            return args, kwargs
        z_value = kwargs.pop("z")
        kwargs["x"] = z_value * 10
        return args, kwargs

    @struct.property_cached
    def cached_node(self) -> int:
        """Constant value exposed through ``struct.property_cached``."""
        return 3
class MetropolisSampler(Sampler):
    r"""
    Metropolis-Hastings sampler for a Hilbert space according to a specific transition
    rule.

    The transition rule is used to generate a proposed state :math:`s^\prime`, starting
    from the current state :math:`s`. The move is accepted with probability

    .. math::
        A(s \rightarrow s^\prime) = \mathrm{min} \left( 1,\frac{P(s^\prime)}{P(s)} e^{L(s,s^\prime)} \right) ,

    where the probability being sampled from is :math:`P(s)=|M(s)|^p`. Here :math:`M(s)`
    is a user-provided function (the machine), :math:`p` is also user-provided with
    default value :math:`p=2`, and :math:`L(s,s^\prime)` is a suitable correcting factor
    computed by the transition kernel.

    The dtype of the sampled states can be chosen.
    """

    rule: MetropolisRule = None
    """The Metropolis transition rule."""
    n_sweeps: int = struct.field(pytree_node=False, default=None)
    """Number of sweeps for each step along the chain. Defaults to the number of sites
    in the Hilbert space."""
    reset_chains: bool = struct.field(pytree_node=False, default=False)
    """If True, resets the chain state when reset is called (every new sampling)."""

    def __pre_init__(self, hilbert, rule, **kwargs):
        r"""
        Constructs a Metropolis Sampler.

        Args:
            hilbert: The Hilbert space to sample.
            rule: A `MetropolisRule` to generate random transitions from a given state
                as well as uniform random states.
            n_sweeps: Number of sweeps for each step along the chain. If None, it is set
                to the number of degrees of freedom being sampled (the size of the input
                vector s to the machine).
            reset_chains: If False, the state configuration is not reset when `reset()`
                is called.
            reset_chain: Deprecated alias of `reset_chains`.
            n_chains: The total number of independent Markov chains across all MPI
                ranks. Either specify this or `n_chains_per_rank`.
            n_chains_per_rank: Number of independent chains on every MPI rank.
            machine_pow: The power to which the machine should be exponentiated to
                generate the pdf (default = 2).
            dtype: The dtype of the states sampled (default = np.float64).
        """
        # process arguments in the base
        args, kwargs = super().__pre_init__(hilbert=hilbert, **kwargs)
        kwargs["rule"] = rule

        # deprecation warnings
        if "reset_chain" in kwargs:
            warn_deprecation(
                "The keyword argument `reset_chain` is deprecated in favour of `reset_chains`"
            )
            kwargs["reset_chains"] = kwargs.pop("reset_chain")

        return args, kwargs

    def __post_init__(self):
        super().__post_init__()
        # Validate the inputs
        if not isinstance(self.rule, MetropolisRule):
            raise TypeError("rule must be a MetropolisRule.")

        if not isinstance(self.reset_chains, bool):
            raise TypeError("reset_chains must be a boolean.")

        # Default value of n_sweeps. The dataclass is frozen, hence object.__setattr__.
        if self.n_sweeps is None:
            object.__setattr__(self, "n_sweeps", self.hilbert.size)

    def _init_state(sampler, machine, params, key):
        """Builds the initial sampler state; chains start random unless `reset_chains`."""
        key_state, key_rule = jax.random.split(key, 2)
        rule_state = sampler.rule.init_state(sampler, machine, params, key_rule)
        σ = jnp.zeros((sampler.n_chains_per_rank, sampler.hilbert.size), dtype=sampler.dtype)

        state = MetropolisSamplerState(σ=σ, rng=key_state, rule_state=rule_state)

        # If we don't reset the chain at every sampling iteration, then reset it
        # now.
        if not sampler.reset_chains:
            key_state, rng = jax.random.split(key_state)
            σ = sampler.rule.random_state(sampler, machine, params, state, rng)
            state = state.replace(σ=σ, rng=key_state)

        return state

    def _reset(sampler, machine, parameters, state):
        """Clears the step counters and, if `reset_chains`, draws fresh configurations."""
        new_rng, rng = jax.random.split(state.rng)

        if sampler.reset_chains:
            σ = sampler.rule.random_state(sampler, machine, parameters, state, rng)
        else:
            σ = state.σ

        rule_state = sampler.rule.reset(sampler, machine, parameters, state)

        return state.replace(
            σ=σ, rng=new_rng, rule_state=rule_state, n_steps_proc=0, n_accepted_proc=0
        )

    def _sample_next(sampler, machine, parameters, state):
        """
        Performs `n_sweeps` Metropolis-Hastings steps on every chain.

        Returns:
            The new sampler state and its batch of configurations.
        """

        def loop_body(i, s):
            # 1 key to propagate to the next iteration, 1 for the transition kernel and
            # 1 for the uniform numbers of the acceptance test.
            key, key_transition, key_accept = jax.random.split(s["key"], 3)

            σp, log_prob_correction = sampler.rule.transition(
                sampler, machine, parameters, state, key_transition, s["σ"]
            )
            proposal_log_prob = sampler.machine_pow * machine.apply(parameters, σp).real

            uniform = jax.random.uniform(key_accept, shape=(sampler.n_chains_per_rank,))
            if log_prob_correction is not None:
                do_accept = uniform < jnp.exp(
                    proposal_log_prob - s["log_prob"] + log_prob_correction
                )
            else:
                do_accept = uniform < jnp.exp(proposal_log_prob - s["log_prob"])

            # do_accept must match ndim of proposal and state (which is 2)
            return {
                "key": key,
                "σ": jnp.where(do_accept.reshape(-1, 1), σp, s["σ"]),
                "log_prob": jnp.where(
                    do_accept.reshape(-1), proposal_log_prob, s["log_prob"]
                ),
                "accepted": s["accepted"] + do_accept.sum(),
            }

        new_rng, rng = jax.random.split(state.rng)

        init = {
            "key": rng,
            "σ": state.σ,
            "log_prob": sampler.machine_pow * machine.apply(parameters, state.σ).real,
            # for logging
            "accepted": state.n_accepted_proc,
        }
        s = jax.lax.fori_loop(0, sampler.n_sweeps, loop_body, init)

        new_state = state.replace(
            rng=new_rng,
            σ=s["σ"],
            n_accepted_proc=s["accepted"],
            n_steps_proc=state.n_steps_proc
            + sampler.n_sweeps * sampler.n_chains_per_rank,
        )

        return new_state, new_state.σ

    def __repr__(sampler):
        return (
            f"{type(sampler).__name__}("
            + "\n  hilbert = {},".format(sampler.hilbert)
            + "\n  rule = {},".format(sampler.rule)
            + "\n  n_chains = {},".format(sampler.n_chains)
            + "\n  machine_power = {},".format(sampler.machine_pow)
            + "\n  reset_chains = {},".format(sampler.reset_chains)
            + "\n  n_sweeps = {},".format(sampler.n_sweeps)
            + "\n  dtype = {}".format(sampler.dtype)
            + ")"
        )

    def __str__(sampler):
        return (
            f"{type(sampler).__name__}("
            + "rule = {}, ".format(sampler.rule)
            + "n_chains = {}, ".format(sampler.n_chains)
            + "machine_power = {}, ".format(sampler.machine_pow)
            + "n_sweeps = {}, ".format(sampler.n_sweeps)
            + "dtype = {})".format(sampler.dtype)
        )
class Sampler(abc.ABC): """ Abstract base class for all samplers. It contains the fields that all of them should possess, defining the common API. Note that fields marked with `pytree_node=False` are treated as static arguments when jitting. """ hilbert: AbstractHilbert = struct.field(pytree_node=False) """The Hilbert space to sample.""" n_chains_per_rank: int = struct.field(pytree_node=False, default=None) """Number of independent chains on every MPI rank.""" machine_pow: int = struct.field(default=2) """The power to which the machine should be exponentiated to generate the pdf.""" dtype: DType = struct.field(pytree_node=False, default=np.float64) """The dtype of the states sampled.""" def __pre_init__(self, hilbert: AbstractHilbert, n_chains: Optional[int] = None, **kwargs): """ Construct a Monte Carlo sampler. Args: hilbert: The Hilbert space to sample. n_chains: The total number of independent chains across all MPI ranks. Either specify this or `n_chains_per_rank`. n_chains_per_rank: Number of independent chains on every MPI rank (default = 1). machine_pow: The power to which the machine should be exponentiated to generate the pdf (default = 2). dtype: The dtype of the states sampled (default = np.float64). """ if "n_chains_per_rank" in kwargs: if n_chains is not None: raise ValueError( "Cannot specify both `n_chains` and `n_chains_per_rank`") else: if n_chains is None: # Default value n_chains_per_rank = 1 else: n_chains_per_rank = max(int(np.ceil(n_chains / mpi.n_nodes)), 1) if mpi.n_nodes > 1 and mpi.rank == 0: if n_chains_per_rank * mpi.n_nodes != n_chains: import warnings warnings.warn( f"Using {n_chains_per_rank} chains per rank among {mpi.n_nodes} ranks " f"(total={n_chains_per_rank * mpi.n_nodes} instead of n_chains={n_chains}). " f"To directly control the number of chains on every rank, specify " f"`n_chains_per_rank` when constructing the sampler. 
" f"To silence this warning, either use `n_chains_per_rank` or use `n_chains` " f"that is a multiple of the number of MPI ranks.", category=UserWarning, ) kwargs["n_chains_per_rank"] = n_chains_per_rank return (hilbert, ), kwargs def __post_init__(self): # Raise errors if hilbert is not an Hilbert if not isinstance(self.hilbert, AbstractHilbert): raise ValueError( "hilbert must be a subtype of netket.hilbert.AbstractHilbert, " + "instead, type {} is not.".format(type(self.hilbert))) # workaround Jax bug under pmap # might be removed in the future if type(self.machine_pow) != object: if not np.issubdtype(numbers.dtype(self.machine_pow), np.integer): raise ValueError( f"machine_pow ({self.machine_pow}) must be a positive integer" ) @property def n_chains(self) -> int: """ The total number of independent chains across all MPI ranks. If you are not using MPI, this is equal to `n_chains_per_rank`. """ return self.n_chains_per_rank * mpi.n_nodes @property def n_batches(self) -> int: r""" The batch size of the configuration $\sigma$ used by this sampler. In general, it is equivalent to :attr:`~Sampler.n_chains_per_rank`. """ return self.n_chains_per_rank @property def is_exact(self) -> bool: """ Returns `True` if the sampler is exact. The sampler is exact if all the samples are exactly distributed according to the chosen power of the variational state, and there is no correlation among them. """ return False def log_pdf(self, model: Union[Callable, nn.Module]) -> Callable: """ Returns a closure with the log-pdf function encoded by this sampler. Args: model: A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`. Returns: The log-probability density function. Note: The result is returned as a `HashablePartial` so that the closure does not trigger recompilation. 
""" apply_fun = get_afun_if_module(model) log_pdf = HashablePartial( lambda apply_fun, pars, σ: self.machine_pow * apply_fun(pars, σ). real, apply_fun, ) return log_pdf def init_state( sampler, machine: Union[Callable, nn.Module], parameters: PyTree, seed: Optional[SeedT] = None, ) -> SamplerState: """ Creates the structure holding the state of the sampler. If you want reproducible samples, you should specify `seed`, otherwise the state will be initialised randomly. If running across several MPI processes, all `sampler_state`s are guaranteed to be in a different (but deterministic) state. This is achieved by first reducing (summing) the seed provided to every MPI rank, then generating `n_rank` seeds starting from the reduced one, and every rank is initialized with one of those seeds. The resulting state is guaranteed to be a frozen Python dataclass (in particular, a Flax dataclass), and it can be serialized using Flax serialization methods. Args: machine: A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`. parameters: The PyTree of parameters of the model. seed: An optional seed or jax PRNGKey. If not specified, a random seed will be used. Returns: The structure holding the state of the sampler. In general you should not expect it to be in a valid state, and should reset it before use. """ key = nkjax.PRNGKey(seed) key = nkjax.mpi_split(key) return sampler._init_state(wrap_afun(machine), parameters, key) def reset( sampler, machine: Union[Callable, nn.Module], parameters: PyTree, state: Optional[SamplerState] = None, ) -> SamplerState: """ Resets the state of the sampler. To be used every time the parameters are changed. Args: machine: A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`. parameters: The PyTree of parameters of the model. 
state: The current state of the sampler. If not specified, it will be constructed by calling :code:`sampler.init_state(machine, parameters)` with a random seed. Returns: A valid sampler state. """ if state is None: state = sampler.init_state(machine, parameters) return sampler._reset(wrap_afun(machine), parameters, state) def sample( sampler, machine: Union[Callable, nn.Module], parameters: PyTree, *, state: Optional[SamplerState] = None, chain_length: int = 1, ) -> Tuple[jnp.ndarray, SamplerState]: """ Samples `chain_length` batches of samples along the chains. Arguments: machine: A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`. parameters: The PyTree of parameters of the model. state: The current state of the sampler. If not specified, then initialize and reset it. chain_length: The length of the chains (default = 1). Returns: σ: The generated batches of samples. state: The new state of the sampler. """ if state is None: state = sampler.reset(machine, parameters) return sampler._sample_chain(wrap_afun(machine), parameters, state, chain_length) def samples( sampler, machine: Union[Callable, nn.Module], parameters: PyTree, *, state: Optional[SamplerState] = None, chain_length: int = 1, ) -> Iterator[jnp.ndarray]: """ Returns a generator sampling `chain_length` batches of samples along the chains. Arguments: machine: A Flax module or callable with the forward pass of the log-pdf. If it is a callable, it should have the signature :code:`f(parameters, σ) -> jnp.ndarray`. parameters: The PyTree of parameters of the model. state: The current state of the sampler. If not specified, then initialize and reset it. chain_length: The length of the chains (default = 1). 
""" if state is None: state = sampler.reset(machine, parameters) machine = wrap_afun(machine) for i in range(chain_length): samples, state = sampler._sample_chain(machine, parameters, state, 1) yield samples[0, :, :] @abc.abstractmethod def _sample_chain( sampler, machine: nn.Module, parameters: PyTree, state: SamplerState, chain_length: int, ) -> Tuple[jnp.ndarray, SamplerState]: """ Implementation of `sample` for subclasses of `Sampler`. If you subclass `Sampler`, you should override this and not `sample` itself, because `sample` contains some common logic. If using Jax, this function should be jitted. Arguments: machine: A Flax module with the forward pass of the log-pdf. parameters: The PyTree of parameters of the model. state: The current state of the sampler. chain_length: The length of the chains. Returns: σ: The generated batches of samples. state: The new state of the sampler. """ @abc.abstractmethod def _init_state(sampler, machine, params, seed) -> SamplerState: """ Implementation of `init_state` for subclasses of `Sampler`. If you subclass `Sampler`, you should override this and not `init_state` itself, because `init_state` contains some common logic. """ @abc.abstractmethod def _reset(sampler, machine, parameters, state): """
class MetropolisPtSampler(MetropolisSampler):
    """
    Metropolis-Hastings with Parallel Tempering sampler.

    This sampler samples a Hilbert space, producing samples of a specific dtype.
    The samples are generated according to a transition rule that must be
    specified.

    Every chain evolves `n_replicas` replicas at inverse temperatures
    1, 1 - 1/n_replicas, ..., 1/n_replicas. The returned samples are those of the
    replica currently holding beta = 1.
    """

    n_replicas: int = struct.field(pytree_node=False, default=32)
    """The number of replicas"""

    def __post_init__(self):
        if not config.FLAGS["NETKET_EXPERIMENTAL"]:
            raise RuntimeError(
                """
                Parallel Tempering samplers are under development and
                are known not to work.

                If you want to debug it, set the environment variable
                NETKET_EXPERIMENTAL=1
                """
            )

        super().__post_init__()

        # Reject anything that is not a strictly positive, even integer.
        if not (
            isinstance(self.n_replicas, int)
            and self.n_replicas > 0
            and self.n_replicas % 2 == 0
        ):
            raise ValueError("n_replicas must be an even integer > 0.")

    @property
    def n_batches(self):
        """Number of configurations evolved on this rank: one per replica of every
        local chain."""
        return self.n_chains_per_rank * self.n_replicas

    def _init_state(
        sampler, machine, params: PyTree, key: PRNGKeyT
    ) -> MetropolisPtSamplerState:
        """Builds the initial state: zero configurations and the beta ladder."""
        n_chains = sampler.n_chains_per_rank
        n_replicas = sampler.n_replicas

        key_state, key_rule = jax.random.split(key, 2)
        σ = jnp.zeros((sampler.n_batches, sampler.hilbert.size), dtype=sampler.dtype)
        rule_state = sampler.rule.init_state(sampler, machine, params, key_rule)

        # Inverse temperatures 1, 1 - 1/R, ..., 1/R, identical for every chain.
        beta = 1.0 - jnp.arange(n_replicas) / n_replicas
        beta = jnp.tile(beta, (n_chains, 1))

        return MetropolisPtSamplerState(
            σ=σ,
            rng=key_state,
            rule_state=rule_state,
            n_steps_proc=0,
            n_accepted_proc=0,
            beta=beta,
            beta_0_index=jnp.zeros((n_chains,), dtype=int),
            n_accepted_per_beta=jnp.zeros((n_chains, n_replicas), dtype=int),
            beta_position=jnp.zeros((n_chains,)),
            beta_diffusion=jnp.zeros((n_chains,)),
            exchange_steps=0,
        )

    def _reset(sampler, machine, parameters: PyTree, state: MetropolisPtSamplerState):
        """Draws fresh configurations and clears the per-beta statistics.

        `beta` and `beta_0_index` are kept from `state`.
        """
        n_chains = sampler.n_chains_per_rank

        new_rng, rng = jax.random.split(state.rng)
        σ = sampler.rule.random_state(sampler, machine, parameters, state, rng)
        rule_state = sampler.rule.reset(sampler, machine, parameters, state)

        return state.replace(
            σ=σ,
            rng=new_rng,
            rule_state=rule_state,
            n_steps_proc=0,
            n_accepted_proc=0,
            # Same integer dtype as in `_init_state`, so the state keeps a stable type.
            n_accepted_per_beta=jnp.zeros((n_chains, sampler.n_replicas), dtype=int),
            beta_position=jnp.zeros((n_chains,)),
            beta_diffusion=jnp.zeros((n_chains,)),
            exchange_steps=0,
        )

    def _sample_next(sampler, machine, parameters: PyTree, state: MetropolisPtSamplerState):
        """
        Performs `n_sweeps` sweeps. Each sweep is a Metropolis move on every replica
        followed by a replica-exchange step between neighbouring betas.

        Returns:
            The new sampler state and, for every chain, the configuration of the
            replica currently at beta = 1.
        """
        n_chains = sampler.n_chains_per_rank
        n_replicas = sampler.n_replicas
        n_batches = sampler.n_batches

        @jax.vmap
        def swap_rows(row, idxs, inn):
            # Exchange row[idxs[k]] <-> row[inn[k]] for every pair k.
            swapped = row.at[idxs].set(
                row[inn], unique_indices=True, indices_are_sorted=True
            )
            return swapped.at[inn].set(
                row[idxs], unique_indices=True, indices_are_sorted=False
            )

        @jax.vmap
        def pair_sum(values, idxs, inn):
            return values[idxs] + values[inn]

        @jax.vmap
        def align_swaps(do_swap, swap_order):
            # For the odd ordering the pairs are (1,2), (3,4), ..., (R-1,0): shift the
            # duplicated per-pair decisions by one so they line up with the replicas.
            return jnp.where(swap_order == 0, do_swap, jnp.roll(do_swap, 1))

        def loop_body(i, s):
            key, key_transition, key_accept, key_order, key_swap = jax.random.split(
                s["key"], 5
            )
            beta = s["beta"]

            # --- Metropolis move on every replica, at its own beta ---
            σp, log_prob_correction = sampler.rule.transition(
                sampler, machine, parameters, state, key_transition, s["σ"]
            )
            proposal_log_prob = sampler.machine_pow * machine.apply(parameters, σp).real

            log_ratio = proposal_log_prob - s["log_prob"]
            if log_prob_correction is not None:
                log_ratio = log_ratio + log_prob_correction
            uniform = jax.random.uniform(key_accept, shape=(n_batches,))
            do_accept = uniform < jnp.exp(beta.reshape((-1,)) * log_ratio)

            # do_accept must match ndim of proposal and state (which is 2)
            σ = jnp.where(do_accept.reshape(-1, 1), σp, s["σ"])
            log_prob = jnp.where(do_accept.reshape(-1), proposal_log_prob, s["log_prob"])
            n_accepted_per_beta = s["n_accepted_per_beta"] + do_accept.reshape(
                (n_chains, n_replicas)
            )

            # --- replica exchange ---
            # Randomly pair replicas as (0,1),(2,3),... (order 0) or (1,2),...,(R-1,0)
            # (order 1), independently for every chain.
            swap_order = jax.random.randint(
                key_order, minval=0, maxval=2, shape=(n_chains,)
            )
            idxs = jnp.arange(0, n_replicas, 2).reshape((1, -1)) + swap_order.reshape(
                (-1, 1)
            )
            inn = (idxs + 1) % n_replicas

            proposed_beta = swap_rows(beta, idxs, inn)
            # Swap acceptance: exp((beta_j - beta_i) * (logp_i - logp_j)) per pair.
            log_prob_swap = (proposed_beta - beta) * log_prob.reshape(
                (n_chains, n_replicas)
            )
            prob_swap = jnp.exp(pair_sum(log_prob_swap, idxs, inn))
            uniform = jax.random.uniform(key_swap, shape=(n_chains, n_replicas // 2))
            do_swap = uniform < prob_swap
            # One decision per pair -> one per replica (both members of the pair).
            do_swap = jnp.dstack((do_swap, do_swap)).reshape((-1, n_replicas))
            do_swap = align_swaps(do_swap, swap_order)

            # Commit the accepted swaps so that `beta` and `beta_0_index` stay consistent.
            new_beta = jnp.where(do_swap, proposed_beta, beta)

            # Track where beta = 1 lives: it moves to its pair partner when swapped,
            # i.e. +1 if it is the left member of the pair, -1 if the right one.
            beta_0_index = s["beta_0_index"]
            beta_0_moved = jax.vmap(lambda row, j: row[j])(do_swap, beta_0_index)
            step = (1 - 2 * jnp.mod(swap_order, 2)) * (1 - 2 * jnp.mod(beta_0_index, 2))
            beta_0_index = jnp.where(
                beta_0_moved, jnp.mod(beta_0_index + step, n_replicas), beta_0_index
            )

            # Acceptance counters follow their beta.
            n_accepted_per_beta = jnp.where(
                do_swap, swap_rows(n_accepted_per_beta, idxs, inn), n_accepted_per_beta
            )

            # Running (Welford) mean/variance of the beta = 1 position, used to estimate
            # the replica diffusion. The count includes the current exchange step, so it
            # is never zero.
            n_exchanges = state.exchange_steps + i + 1
            delta = beta_0_index - s["beta_position"]
            beta_position = s["beta_position"] + delta / n_exchanges
            beta_diffusion = s["beta_diffusion"] + delta * (beta_0_index - beta_position)

            return {
                "key": key,
                "σ": σ,
                "log_prob": log_prob,
                "beta": new_beta,
                "beta_0_index": beta_0_index,
                "n_accepted_per_beta": n_accepted_per_beta,
                "beta_position": beta_position,
                "beta_diffusion": beta_diffusion,
            }

        new_rng, rng = jax.random.split(state.rng)

        init = {
            "key": rng,
            "σ": state.σ,
            "log_prob": sampler.machine_pow * machine.apply(parameters, state.σ).real,
            "beta": state.beta,
            "beta_0_index": state.beta_0_index,
            "n_accepted_per_beta": state.n_accepted_per_beta,
            "beta_position": state.beta_position,
            "beta_diffusion": state.beta_diffusion,
        }
        s = jax.lax.fori_loop(0, sampler.n_sweeps, loop_body, init)

        # NOTE(review): n_accepted_proc is not updated by this sampler.
        new_state = state.replace(
            rng=new_rng,
            σ=s["σ"],
            n_steps_proc=state.n_steps_proc + sampler.n_sweeps * n_chains,
            beta=s["beta"],
            beta_0_index=s["beta_0_index"],
            beta_position=s["beta_position"],
            beta_diffusion=s["beta_diffusion"],
            exchange_steps=state.exchange_steps + sampler.n_sweeps,
            n_accepted_per_beta=s["n_accepted_per_beta"],
        )

        # Row offset of every chain's first replica in the flattened batch.
        offsets = jnp.arange(0, n_chains * n_replicas, n_replicas)
        return new_state, new_state.σ[new_state.beta_0_index + offsets, :]

    def __repr__(sampler):
        return (
            f"{type(sampler).__name__}("
            + "\n  hilbert = {},".format(sampler.hilbert)
            + "\n  rule = {},".format(sampler.rule)
            + "\n  n_chains = {},".format(sampler.n_chains)
            + "\n  machine_power = {},".format(sampler.machine_pow)
            + "\n  reset_chains = {},".format(sampler.reset_chains)
            + "\n  n_sweeps = {},".format(sampler.n_sweeps)
            + "\n  dtype = {}".format(sampler.dtype)
            + ")"
        )

    def __str__(sampler):
        return (
            f"{type(sampler).__name__}("
            + "rule = {}, ".format(sampler.rule)
            + "n_chains = {}, ".format(sampler.n_chains)
            + "machine_power = {}, ".format(sampler.machine_pow)
            + "reset_chains = {}, ".format(sampler.reset_chains)
            + "n_sweeps = {}, ".format(sampler.n_sweeps)
            + "dtype = {})".format(sampler.dtype)
        )
class MetropolisSampler(Sampler):
    r"""
    Metropolis-Hastings sampler for a Hilbert space according to a specific transition
    rule.

    The transition rule is used to generate a proposed state :math:`s^\prime`, starting
    from the current state :math:`s`. The move is accepted with probability

    .. math::
        A(s \rightarrow s^\prime) = \mathrm{min} \left( 1,\frac{P(s^\prime)}{P(s)} e^{L(s,s^\prime)} \right) ,

    where the probability being sampled from is :math:`P(s)=|M(s)|^p`. Here :math:`M(s)`
    is a user-provided function (the machine), :math:`p` is also user-provided with
    default value :math:`p=2`, and :math:`L(s,s^\prime)` is a suitable correcting factor
    computed by the transition kernel.

    The dtype of the sampled states can be chosen.
    """

    rule: MetropolisRule = None
    """The Metropolis transition rule."""
    n_sweeps: int = struct.field(pytree_node=False, default=None)
    """Number of sweeps for each step along the chain. Defaults to the number of sites
    in the Hilbert space."""
    reset_chains: bool = struct.field(pytree_node=False, default=False)
    """If True, resets the chain state when `reset` is called on every new sampling."""

    def __pre_init__(self, hilbert, rule, **kwargs):
        """
        Constructs a Metropolis Sampler.

        Args:
            hilbert: The Hilbert space to sample.
            rule: A `MetropolisRule` to generate random transitions from a given state
                as well as uniform random states.
            n_chains: The total number of independent Markov chains across all MPI
                ranks. Either specify this or `n_chains_per_rank`.
            n_chains_per_rank: Number of independent chains on every MPI rank
                (default = 16, also used when `n_chains` is None).
            n_sweeps: Number of sweeps for each step along the chain. Defaults to the
                number of sites in the Hilbert space. This is equivalent to subsampling
                the Markov chain.
            reset_chains: If True, resets the chain state when `reset` is called on
                every new sampling (default = False).
            reset_chain: Deprecated alias of `reset_chains`.
            machine_pow: The power to which the machine should be exponentiated to
                generate the pdf (default = 2).
            dtype: The dtype of the states sampled (default = np.float64).
        """
        # An explicit `n_chains=None` means "not specified": apply the documented
        # default of 16 chains per rank in that case too.
        if kwargs.get("n_chains") is None and "n_chains_per_rank" not in kwargs:
            kwargs["n_chains_per_rank"] = 16

        # process arguments in the base
        args, kwargs = super().__pre_init__(hilbert=hilbert, **kwargs)
        kwargs["rule"] = rule

        # deprecation warnings
        if "reset_chain" in kwargs:
            warn_deprecation(
                "The keyword argument `reset_chain` is deprecated in favour of `reset_chains`"
            )
            kwargs["reset_chains"] = kwargs.pop("reset_chain")

        return args, kwargs

    def __post_init__(self):
        super().__post_init__()
        # Validate the inputs
        if not isinstance(self.rule, MetropolisRule):
            raise TypeError("rule must be a MetropolisRule.")

        if not isinstance(self.reset_chains, bool):
            raise TypeError("reset_chains must be a boolean.")

        # Default value of n_sweeps. The dataclass is frozen, hence object.__setattr__.
        if self.n_sweeps is None:
            object.__setattr__(self, "n_sweeps", self.hilbert.size)

    def sample_next(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> Tuple[SamplerState, jnp.ndarray]:
        """
        Samples the next state in the Markov chain.

        Args:
            machine: A Flax module or callable with the forward pass of the log-pdf.
                If it is a callable, it should have the signature
                :code:`f(parameters, σ) -> jnp.ndarray`.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. If not specified, then initialize
                and reset it.

        Returns:
            state: The new state of the sampler.
            σ: The next batch of samples.

        Note:
            The return order is inverted wrt `sample` because when called inside of
            a scan function the first returned argument should be the state.
        """
        if state is None:
            state = sampler.reset(machine, parameters)

        return sampler._sample_next(wrap_afun(machine), parameters, state)

    def _init_state(sampler, machine, params, key):
        """Builds the initial sampler state; chains start random unless `reset_chains`."""
        key_state, key_rule = jax.random.split(key, 2)
        rule_state = sampler.rule.init_state(sampler, machine, params, key_rule)
        σ = jnp.zeros((sampler.n_chains_per_rank, sampler.hilbert.size), dtype=sampler.dtype)

        state = MetropolisSamplerState(σ=σ, rng=key_state, rule_state=rule_state)

        # If we don't reset the chain at every sampling iteration, then reset it
        # now.
        if not sampler.reset_chains:
            key_state, rng = jax.random.split(key_state)
            σ = sampler.rule.random_state(sampler, machine, params, state, rng)
            state = state.replace(σ=σ, rng=key_state)

        return state

    def _reset(sampler, machine, parameters, state):
        """Clears the step counters and, if `reset_chains`, draws fresh configurations."""
        new_rng, rng = jax.random.split(state.rng)

        if sampler.reset_chains:
            σ = sampler.rule.random_state(sampler, machine, parameters, state, rng)
        else:
            σ = state.σ

        rule_state = sampler.rule.reset(sampler, machine, parameters, state)

        return state.replace(
            σ=σ, rng=new_rng, rule_state=rule_state, n_steps_proc=0, n_accepted_proc=0
        )

    def _sample_next(sampler, machine, parameters, state):
        """
        Implementation of `sample_next` for subclasses of `MetropolisSampler`.

        If you subclass `MetropolisSampler`, you should override this and not
        `sample_next` itself, because `sample_next` contains some common logic.

        Performs `n_sweeps` Metropolis-Hastings steps on every chain and returns
        the new state together with its batch of configurations.
        """

        def loop_body(i, s):
            # 1 key to propagate to the next iteration, 1 for the transition kernel and
            # 1 for the uniform numbers of the acceptance test.
            key, key_transition, key_accept = jax.random.split(s["key"], 3)

            σp, log_prob_correction = sampler.rule.transition(
                sampler, machine, parameters, state, key_transition, s["σ"]
            )
            proposal_log_prob = sampler.machine_pow * machine.apply(parameters, σp).real

            uniform = jax.random.uniform(key_accept, shape=(sampler.n_chains_per_rank,))
            if log_prob_correction is not None:
                do_accept = uniform < jnp.exp(
                    proposal_log_prob - s["log_prob"] + log_prob_correction
                )
            else:
                do_accept = uniform < jnp.exp(proposal_log_prob - s["log_prob"])

            # Build a fresh carry instead of mutating the incoming one.
            # do_accept must match ndim of proposal and state (which is 2)
            return {
                "key": key,
                "σ": jnp.where(do_accept.reshape(-1, 1), σp, s["σ"]),
                "log_prob": jnp.where(
                    do_accept.reshape(-1), proposal_log_prob, s["log_prob"]
                ),
                "accepted": s["accepted"] + do_accept.sum(),
            }

        new_rng, rng = jax.random.split(state.rng)

        init = {
            "key": rng,
            "σ": state.σ,
            "log_prob": sampler.machine_pow * machine.apply(parameters, state.σ).real,
            # for logging
            "accepted": state.n_accepted_proc,
        }
        s = jax.lax.fori_loop(0, sampler.n_sweeps, loop_body, init)

        new_state = state.replace(
            rng=new_rng,
            σ=s["σ"],
            n_accepted_proc=s["accepted"],
            n_steps_proc=state.n_steps_proc
            + sampler.n_sweeps * sampler.n_chains_per_rank,
        )

        return new_state, new_state.σ

    def _sample_chain(sampler, machine, parameters, state, chain_length):
        """Delegates to the module-level `_sample_chain` implementation."""
        return _sample_chain(sampler, machine, parameters, state, chain_length)

    def __repr__(sampler):
        return (
            f"{type(sampler).__name__}("
            + "\n  hilbert = {},".format(sampler.hilbert)
            + "\n  rule = {},".format(sampler.rule)
            + "\n  n_chains = {},".format(sampler.n_chains)
            + "\n  n_sweeps = {},".format(sampler.n_sweeps)
            + "\n  reset_chains = {},".format(sampler.reset_chains)
            + "\n  machine_power = {},".format(sampler.machine_pow)
            + "\n  dtype = {}".format(sampler.dtype)
            + ")"
        )

    def __str__(sampler):
        return (
            f"{type(sampler).__name__}("
            + "rule = {}, ".format(sampler.rule)
            + "n_chains = {}, ".format(sampler.n_chains)
            + "n_sweeps = {}, ".format(sampler.n_sweeps)
            + "reset_chains = {}, ".format(sampler.reset_chains)
            + "machine_power = {}, ".format(sampler.machine_pow)
            + "dtype = {})".format(sampler.dtype)
        )
class MetropolisSamplerState(SamplerState):
    """
    State for a Metropolis sampler.

    Contains the current configuration, the RNG state and the (optional) state
    of the transition rule, together with the per-process counters of performed
    and accepted moves since the last reset.
    """

    σ: jnp.ndarray
    """Current batch of configurations in the Markov chain."""
    rng: jnp.ndarray
    """State of the random number generator (key, in jax terms)."""
    rule_state: Optional[Any]
    """Optional state of the transition rule."""

    # those are initialised to 0. We want to initialise them to zero arrays because they can
    # be passed to jax jitted functions that require type invariance to avoid recompilation
    n_steps_proc: int = struct.field(
        default_factory=lambda: jnp.zeros((), dtype=jnp.int64))
    """Number of moves performed along the chains in this process since the last reset."""
    n_accepted_proc: int = struct.field(
        default_factory=lambda: jnp.zeros((), dtype=jnp.int64))
    """Number of accepted transitions among the chains in this process since the last reset."""

    @property
    def acceptance(self) -> Optional[float]:
        """The fraction of accepted moves across all chains and MPI processes.

        The rate is computed since the last reset of the sampler.
        Will return None if no sampling has been performed since then.
        """
        if self.n_steps == 0:
            return None

        return self.n_accepted / self.n_steps

    @property
    @deprecated("""Please use the attribute `.acceptance` instead of `.acceptance_ratio`. The new attribute `.acceptance` returns the acceptance ratio ∈ [0,1], instead of the current `acceptance_ratio` returning a percentage, which is a bug.""")
    def acceptance_ratio(self) -> Optional[float]:
        """DEPRECATED: Please use the attribute `.acceptance` instead of
        `.acceptance_ratio`. The new attribute `.acceptance` returns the
        acceptance ratio ∈ [0,1], instead of the current `acceptance_ratio`
        returning a percentage, which is a bug.

        The percentage of accepted moves across all chains and MPI processes.

        The rate is computed since the last reset of the sampler.
        Will return None if no sampling has been performed since then.
        """
        acceptance = self.acceptance
        # `acceptance` is None when no step was performed yet; propagate it
        # instead of failing with `None * 100`.
        if acceptance is None:
            return None
        return acceptance * 100

    @property
    def n_steps(self) -> int:
        """Total number of moves performed across all processes since the last reset."""
        # Every process is assumed to perform the same number of moves.
        return self.n_steps_proc * mpi.n_nodes

    @property
    def n_accepted(self) -> int:
        """Total number of moves accepted across all processes since the last reset."""
        res, _ = mpi.mpi_sum_jax(self.n_accepted_proc)
        return res

    def __repr__(self):
        if self.n_steps > 0:
            acc_string = "# accepted = {}/{} ({}%), ".format(
                self.n_accepted, self.n_steps, self.acceptance * 100)
        else:
            acc_string = ""

        return f"{type(self).__name__}({acc_string}rng state={self.rng})"
class RungeKuttaIntegrator:
    """
    Stateful driver that advances the solution of an ODE `y' = f(t, y)` with a
    Runge-Kutta tableau, using either a fixed or an adaptive step size.

    The current position of the integration is kept in a `RungeKuttaState` and
    is exposed through the `t`, `y`, `dt`, `errors` and `warnings` properties.
    """

    tableau: rkt.NamedTableau
    f: Callable = field(repr=False)
    t0: float
    y0: Array = field(repr=False)
    initial_dt: float
    use_adaptive: bool
    norm: Callable
    atol: float = 0.0
    rtol: float = 1e-7
    dt_limits: Optional[LimitsType] = None

    def __post_init__(self):
        adaptive = self.use_adaptive

        # Adaptive stepping is only possible with tableaus flagged as adaptive.
        if adaptive and not self.tableau.data.is_adaptive:
            raise RuntimeError(
                f"Solver {self.tableau} does not support adaptive step size")

        # Choose the stepping strategy once, so `step` can dispatch blindly.
        self._do_step = (self._do_step_adaptive
                         if adaptive else self._do_step_fixed)

        # Fill in the defaults of the optional settings.
        if self.norm is None:
            self.norm = euclidean_norm
        if self.dt_limits is None:
            self.dt_limits = (None, 10 * self.initial_dt)

        # `last_norm` is only tracked when the adaptive controller is used.
        self._rkstate = RungeKuttaState(
            step_no=0,
            step_no_total=0,
            t=nk.utils.KahanSum(self.t0),
            y=self.y0,
            dt=self.initial_dt,
            last_norm=0.0 if adaptive else None,
            flags=SolverFlags(0),
        )

    def step(self, max_dt=None):
        """
        Advance the integrator by one full Runge-Kutta step of size
        min(self.dt, max_dt).

        Returns:
            True if the step controller accepted the step, False if it rejected
            it and the step should be retried. The controller may adjust the step
            size in both cases, so the integrator state changes even after a
            rejected step.
        """
        updated = self._do_step(self._rkstate, max_dt)
        self._rkstate = updated
        return updated.accepted

    def _do_step_fixed(self, rk_state, max_dt=None):
        """Perform one fixed-size step starting from `rk_state`."""
        tableau_data = self.tableau.data
        return general_time_step_fixed(tableau_data, self.f, rk_state,
                                       max_dt=max_dt)

    def _do_step_adaptive(self, rk_state, max_dt=None):
        """Perform one adaptive step starting from `rk_state`."""
        tableau_data = self.tableau.data
        return general_time_step_adaptive(
            tableau_data,
            self.f,
            rk_state,
            atol=self.atol,
            rtol=self.rtol,
            norm_fn=self.norm,
            max_dt=max_dt,
            dt_limits=self.dt_limits,
        )

    @property
    def t(self):
        """Current time of the integration."""
        kahan_time = self._rkstate.t
        return kahan_time.value

    @property
    def y(self):
        """Current solution vector."""
        current = self._rkstate
        return current.y

    @property
    def dt(self):
        """Current step size."""
        current = self._rkstate
        return current.dt

    def _get_solver_flags(self, intersect=SolverFlags.NONE) -> SolverFlags:
        """Return the solver flags currently set, masked with `intersect`."""
        # JAX turns `_rkstate.flags` into an int-valued array, so it is coerced
        # back to a Python int before rebuilding the flag value.
        raw_flags = int(self._rkstate.flags)
        return SolverFlags(raw_flags & intersect)

    @property
    def errors(self) -> SolverFlags:
        """Returns the currently set error flags of the solver."""
        mask = SolverFlags.ERROR_FLAGS
        return self._get_solver_flags(mask)

    @property
    def warnings(self) -> SolverFlags:
        """Returns the currently set warning flags of the solver."""
        mask = SolverFlags.WARNINGS_FLAGS
        return self._get_solver_flags(mask)
class MetropolisPtSampler(MetropolisSampler):
    """
    Metropolis-Hastings with Parallel Tempering sampler.

    This sampler samples an Hilbert space, producing samples of a specific dtype.
    The samples are generated according to a transition rule that must be
    specified. Every chain is made of `n_replicas` replicas, each sampled at its
    own inverse temperature `beta`.
    """

    n_replicas: int = struct.field(pytree_node=False, default=32)
    """The number of replicas"""

    def __post_init__(self):
        super().__post_init__()
        # Replica exchanges pair up neighbouring replicas, so an even, positive
        # number of them is required.
        if not (isinstance(self.n_replicas, int) and self.n_replicas > 0
                and self.n_replicas % 2 == 0):
            raise ValueError("n_replicas must be an even integer > 0.")

    @property
    def n_batches(self):
        """Number of configurations evolved at once: one per replica of every chain."""
        return self.n_chains * self.n_replicas

    def _init_state(sampler, machine, params: PyTree,
                    key: PRNGKeyT) -> MetropolisPtSamplerState:
        """Build the initial state with betas 1, 1-1/R, ..., 1/R for every chain."""
        key_state, key_rule = jax.random.split(key, 2)
        σ = jnp.zeros(
            (sampler.n_batches, sampler.hilbert.size),
            dtype=sampler.dtype,
        )

        rule_state = sampler.rule.init_state(sampler, machine, params, key_rule)

        beta = 1.0 - jnp.arange(sampler.n_replicas) / sampler.n_replicas
        beta = jnp.tile(beta, (sampler.n_chains, 1))

        return MetropolisPtSamplerState(
            σ=σ,
            rng=key_state,
            rule_state=rule_state,
            n_steps_proc=0,
            n_accepted_proc=0,
            beta=beta,
            beta_0_index=jnp.zeros((sampler.n_chains, ), dtype=jnp.int64),
            n_accepted_per_beta=jnp.zeros(
                (sampler.n_chains, sampler.n_replicas), dtype=jnp.int64),
            beta_position=jnp.zeros((sampler.n_chains, )),
            beta_diffusion=jnp.zeros((sampler.n_chains, )),
            exchange_steps=0,
        )

    def _reset(sampler, machine, parameters: PyTree,
               state: MetropolisPtSamplerState):
        """Draw new random configurations and clear the sampling statistics."""
        new_rng, rng = jax.random.split(state.rng)

        σ = sampler.rule.random_state(sampler, machine, parameters, state, rng)

        rule_state = sampler.rule.reset(sampler, machine, parameters, state)

        # NOTE(review): `beta` and `beta_0_index` are not reset here (they keep
        # their current values); confirm whether this is intended.
        return state.replace(
            σ=σ,
            rng=new_rng,
            rule_state=rule_state,
            n_steps_proc=0,
            n_accepted_proc=0,
            n_accepted_per_beta=jnp.zeros(
                (sampler.n_chains, sampler.n_replicas), dtype=jnp.int64),
            beta_position=jnp.zeros((sampler.n_chains, )),
            beta_diffusion=jnp.zeros((sampler.n_chains, )),
            exchange_steps=0,
        )

    def _sample_next(sampler, machine, parameters: PyTree,
                     state: MetropolisPtSamplerState):
        """
        Perform `n_sweeps` Metropolis + replica-exchange sweeps.

        Returns:
            The new state and, for every chain, the configuration of the replica
            at `beta_0_index`.
        """

        def loop_body(i, s):
            # 1 to propagate for next iteration, 1 for the proposal, 1 for the
            # acceptance uniforms, 1 for the swap order, 1 for the swap uniforms
            s["key"], key1, key2, key3, key4 = jax.random.split(s["key"], 5)

            beta = s["beta"]

            # --- Metropolis step on every replica of every chain ---
            σp, log_prob_correction = sampler.rule.transition(
                sampler, machine, parameters, state, key1, s["σ"])
            proposal_log_prob = sampler.machine_pow * machine.apply(
                parameters, σp).real

            uniform = jax.random.uniform(key2, shape=(sampler.n_batches, ))
            if log_prob_correction is not None:
                do_accept = uniform < jnp.exp(
                    beta.reshape((-1, )) *
                    (proposal_log_prob - s["log_prob"] + log_prob_correction))
            else:
                do_accept = uniform < jnp.exp(
                    beta.reshape((-1, )) * (proposal_log_prob - s["log_prob"]))

            # do_accept must match ndim of proposal and state (which is 2)
            s["σ"] = jnp.where(do_accept.reshape(-1, 1), σp, s["σ"])
            n_accepted_per_beta = s["n_accepted_per_beta"] + do_accept.reshape(
                (sampler.n_chains, sampler.n_replicas))

            s["log_prob"] = jax.numpy.where(do_accept.reshape(-1),
                                            proposal_log_prob, s["log_prob"])

            # --- Replica-exchange step ---
            # randomly decide if every set of replicas should be swapped in even or odd order
            swap_order = jax.random.randint(
                key3,
                minval=0,
                maxval=2,
                shape=(sampler.n_chains, ),
            )  # 0 or 1

            # indices of even swapped elements (per-row)
            idxs = jnp.arange(0, sampler.n_replicas, 2).reshape(
                (1, -1)) + swap_order.reshape((-1, 1))
            # indices of odd swapped elements (per-row); wraps around the ring
            inn = (idxs + 1) % sampler.n_replicas

            # for every row of the input, swap elements at idxs with elements at inn
            @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
            def swap_rows(beta_row, idxs, inn):
                proposed_beta = beta_row.at[idxs].set(beta_row[inn],
                                                      unique_indices=True,
                                                      indices_are_sorted=True)
                proposed_beta = proposed_beta.at[inn].set(
                    beta_row[idxs],
                    unique_indices=True,
                    indices_are_sorted=False)
                return proposed_beta

            proposed_beta = swap_rows(beta, idxs, inn)

            @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=0)
            def compute_proposed_prob(prob, idxs, inn):
                prob_rescaled = prob[idxs] + prob[inn]
                return prob_rescaled

            # compute the probability of the swaps
            log_prob = (proposed_beta - beta) * s["log_prob"].reshape(
                (sampler.n_chains, sampler.n_replicas))

            prob_rescaled = jnp.exp(compute_proposed_prob(log_prob, idxs, inn))

            uniform = jax.random.uniform(key4,
                                         shape=(sampler.n_chains,
                                                sampler.n_replicas // 2))

            do_swap = uniform < prob_rescaled

            do_swap = jnp.dstack((do_swap, do_swap)).reshape(
                (-1, sampler.n_replicas))  # concat along last dimension

            # roll if swap_order is odd
            @partial(jax.vmap, in_axes=(0, 0), out_axes=0)
            def fix_swap(do_swap, swap_order):
                return jax.lax.cond(swap_order == 0, lambda x: x,
                                    lambda x: jnp.roll(x, 1), do_swap)

            do_swap = fix_swap(do_swap, swap_order)

            # NOTE(review): s["beta"] and s["σ"] are not exchanged here; the
            # swap decision only moves `beta_0_index` and the per-beta counters.
            # Confirm this is the intended design.

            swap_order = swap_order.reshape(-1)

            # Use the loop-carried index so that multiple sweeps compose.
            beta_0_index = s["beta_0_index"]
            beta_0_moved = jax.vmap(lambda row, idx: row[idx],
                                    in_axes=(0, 0),
                                    out_axes=0)(do_swap, beta_0_index)
            # Even-order swaps pair (2k, 2k+1), odd-order swaps pair (2k+1, 2k+2),
            # so the tracked index moves by +1 or -1 depending on both parities.
            proposed_beta_0_index = jnp.mod(
                beta_0_index + (-jnp.mod(swap_order, 2) * 2 + 1) *
                (-jnp.mod(beta_0_index, 2) * 2 + 1),
                sampler.n_replicas,
            )

            s["beta_0_index"] = jnp.where(beta_0_moved, proposed_beta_0_index,
                                          beta_0_index)

            # swap acceptances
            swapped_n_accepted_per_beta = swap_rows(n_accepted_per_beta, idxs,
                                                    inn)
            s["n_accepted_per_beta"] = jax.numpy.where(
                do_swap,
                swapped_n_accepted_per_beta,
                n_accepted_per_beta,
            )

            # Running mean and sum of squared deviations (Welford) of the
            # tracked index, used to estimate the diffusion of the replica.
            # `n_samples` counts the samples including the current one, so it is
            # always >= 1 (the previous `exchange_steps + i` was 0 on the first
            # sweep, producing 0/0 = NaN).
            n_samples = state.exchange_steps + jnp.asarray(i,
                                                           dtype=jnp.int64) + 1
            delta = s["beta_0_index"] - s["beta_position"]
            s["beta_position"] = s["beta_position"] + delta / n_samples
            delta2 = s["beta_0_index"] - s["beta_position"]
            s["beta_diffusion"] = s["beta_diffusion"] + delta * delta2

            return s

        new_rng, rng = jax.random.split(state.rng)

        s = {
            "key": rng,
            "σ": state.σ,
            "log_prob": sampler.machine_pow *
            machine.apply(parameters, state.σ).real,
            "beta": state.beta,
            # for logging
            "beta_0_index": state.beta_0_index,
            "n_accepted_per_beta": state.n_accepted_per_beta,
            "beta_position": state.beta_position,
            "beta_diffusion": state.beta_diffusion,
        }

        s = jax.lax.fori_loop(0, sampler.n_sweeps, loop_body, s)

        # NOTE(review): n_accepted_proc is not updated by this sampler.
        new_state = state.replace(
            rng=new_rng,
            σ=s["σ"],
            n_steps_proc=state.n_steps_proc +
            sampler.n_sweeps * sampler.n_chains,
            beta=s["beta"],
            beta_0_index=s["beta_0_index"],
            beta_position=s["beta_position"],
            beta_diffusion=s["beta_diffusion"],
            exchange_steps=state.exchange_steps + sampler.n_sweeps,
            n_accepted_per_beta=s["n_accepted_per_beta"],
        )

        # Flat row index of the tracked replica of every chain.
        offsets = jnp.arange(0, sampler.n_chains * sampler.n_replicas,
                             sampler.n_replicas)

        return new_state, new_state.σ[new_state.beta_0_index + offsets, :]
class Sampler(abc.ABC):
    """
    Abstract base class for all samplers.

    It contains the fields that all of them should possess, defining the common
    API.
    Note that fields marked with pytree_node=False are treated as static arguments
    when jitting.
    """

    hilbert: AbstractHilbert = struct.field(pytree_node=False)
    """Hilbert space to be sampled."""

    n_chains: int = struct.field(pytree_node=False, default=16)
    """Number of batches along the chain"""

    machine_pow: int = struct.field(default=2)
    """Exponent of the pdf sampled"""

    dtype: type = struct.field(pytree_node=False, default=np.float64)
    """DType of the states returned."""

    def __post_init__(self):
        # Raise errors if hilbert is not an Hilbert
        if not isinstance(self.hilbert, AbstractHilbert):
            raise ValueError(
                "hilbert must be a subtype of netket.hilbert.AbstractHilbert, "
                + "instead, type {} is not.".format(type(self.hilbert)))

        # The previous check (`not isinstance(...) and n_chains >= 0`) accepted
        # any int, including non-positive ones.
        if not isinstance(self.n_chains, int) or self.n_chains <= 0:
            raise ValueError("n_chains must be a positive integer")

    @property
    def n_batches(self) -> int:
        r"""
        The batch size of the configuration $\sigma$ used by this sampler.

        In general, it is equivalent to :attr:`~Sampler.n_chains`.
        """
        return self.n_chains

    def log_pdf(self, model: Union[Callable, nn.Module]) -> Callable:
        """
        Returns a closure with the log_pdf function encoded by this sampler.

        Note: the result is returned as an HashablePartial so that the closure
        does not trigger recompilation.

        Args:
            model: The machine, or apply_fun

        Returns:
            the log probability density function
        """
        apply_fun = get_afun_if_module(model)
        log_pdf = HashablePartial(
            lambda apply_fun, pars, σ: self.machine_pow * apply_fun(pars, σ).real,
            apply_fun,
        )
        return log_pdf

    def init_state(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        seed: Optional[SeedType] = None,
    ) -> SamplerState:
        """
        Creates the structure holding the state of the sampler.

        If you want reproducible samples, you should specify `seed`, otherwise the
        state will be initialised randomly.

        If running across several MPI processes, all sampler_states are guaranteed
        to be in a different (but deterministic) state.
        This is achieved by first reducing (summing) the seed provided to every MPI
        rank, then generating n_rank seeds starting from the reduced one, and every
        rank is initialized with one of those seeds.

        The resulting state is guaranteed to be a frozen python dataclass (in
        particular, a flax's dataclass), and it can be serialized using Flax
        serialization methods.

        Args:
            machine: a Flax module or callable with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            seed: An optional seed or jax PRNGKey. If not specified, a random seed
                will be used.

        Returns:
            The structure holding the state of the sampler. In general you should
            not expect it to be in a valid state, and should reset it before use.
        """
        key = nkjax.PRNGKey(seed)
        key = nkjax.mpi_split(key)

        return sampler._init_state(get_afun_if_module(machine), parameters, key)

    def reset(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> SamplerState:
        """
        Resets the state of the sampler. To be used every time the parameters are
        changed.

        Args:
            machine: a Flax module or callable with the forward pass of the log-pdf.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. If it's not provided, it will
                be constructed by calling
                :code:`sampler.init_state(machine, parameters)` with a random seed.

        Returns:
            A valid sampler state.
        """
        if state is None:
            state = sampler.init_state(machine, parameters)

        return sampler._reset(get_afun_if_module(machine), parameters, state)

    def sample_next(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: Optional[SamplerState] = None,
    ) -> Tuple[SamplerState, jnp.ndarray]:
        """
        Samples the next state in the markov chain.

        Args:
            machine: a Flax module or callable apply function with the forward pass
                of the log-pdf.
            parameters: The PyTree of parameters of the model.
            state: The current state of the sampler. If it's not provided, it will
                be constructed by calling :code:`sampler.reset(machine, parameters)`
                with a random seed.

        Returns:
            state: The new state of the sampler
            σ: The next batch of samples.
        """
        if state is None:
            # A freshly initialised state is not valid until it has been reset.
            state = sampler.reset(machine, parameters)

        return sampler._sample_next(get_afun_if_module(machine), parameters, state)

    def sample(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        *,
        state: Optional[SamplerState] = None,
        chain_length: int = 1,
    ) -> Tuple[jnp.ndarray, SamplerState]:
        """
        Samples chain_length elements along the chains.

        Arguments:
            sampler: The Monte Carlo sampler.
            machine: The model or callable to sample from (if it's a function it
                should have the signature :code:`f(parameters, σ) -> jnp.ndarray`).
            parameters: The PyTree of parameters of the model.
            state: current state of the sampler. If None, then initialises it.
            chain_length: (default=1), the length of the chains.

        Returns:
            state: The new state of the sampler
            σ: The next batch of samples.
        """
        # Delegates to the module-level `sample` function.
        return sample(
            sampler, machine, parameters, state=state, chain_length=chain_length
        )

    def _sample_chain(
        sampler,
        machine: Union[Callable, nn.Module],
        parameters: PyTree,
        state: SamplerState,
        chain_length: int,
    ) -> Tuple[jnp.ndarray, SamplerState]:
        """
        Samples chain_length elements along the chains.

        In general this should not be overridden unless you want to modify the
        logic by which the whole sampling is performed.
        If using Jax, this function should be jitted

        Arguments:
            sampler: The Monte Carlo sampler.
            machine: The model or callable to sample from (if it's a function it
                should have the signature :code:`f(parameters, σ) -> jnp.ndarray`).
            parameters: The PyTree of parameters of the model.
            state: current state of the sampler. If None, then initialises it.
            chain_length: (default=1), the length of the chains.

        Returns:
            state: The new state of the sampler
            σ: The next batch of samples.
        """
        # Delegates to the module-level `_sample_chain` function.
        return _sample_chain(sampler, machine, parameters, state, chain_length)

    @abc.abstractmethod
    def _init_state(sampler, machine, params, seed) -> SamplerState:
        """
        Implementation of init_state for subclasses of Sampler.

        If you sub-class Sampler, you should define this and not init_state
        itself, because init_state contains some common logic.
        """
        raise NotImplementedError("init_state Not Implemented")

    @abc.abstractmethod
    def _reset(sampler, machine, parameters, state):
        """
        Implementation of reset for subclasses of Sampler.

        If you sub-class Sampler, you should define _reset and not reset itself,
        because reset contains some common logic.
        """
        raise NotImplementedError("reset Not Implemented")

    @abc.abstractmethod
    def _sample_next(sampler, machine, parameters, state=None):
        """
        Implementation of sample_next for subclasses of Sampler.

        If you sub-class Sampler, you should define _sample_next and not
        sample_next itself, because sample_next contains some common logic.
        """
        raise NotImplementedError("sample_next Not Implemented")