def sample(
    sampler,
    machine: Union[Callable, nn.Module],
    parameters: PyTree,
    *,
    state: Optional[SamplerState] = None,
    chain_length: int = 1,
) -> Tuple[jnp.ndarray, SamplerState]:
    """
    Draw `chain_length` consecutive batches of samples from every chain.

    Arguments:
        machine: Flax module, or a callable computing the log-pdf with the
            signature :code:`f(parameters, σ) -> jnp.ndarray`.
        parameters: PyTree holding the parameters of the model.
        state: Sampler state to continue from. When omitted, a fresh state is
            obtained through :code:`sampler.reset(machine, parameters)`.
        chain_length: Number of batches drawn along each chain (default = 1).

    Returns:
        σ: The generated batches of samples.
        state: The updated state of the sampler.
    """
    current_state = (
        sampler.reset(machine, parameters) if state is None else state
    )
    log_pdf = wrap_afun(machine)
    return sampler._sample_chain(log_pdf, parameters, current_state, chain_length)
def samples(
    sampler,
    machine: Union[Callable, nn.Module],
    parameters: PyTree,
    *,
    state: Optional[SamplerState] = None,
    chain_length: int = 1,
) -> Iterator[jnp.ndarray]:
    """
    Lazily draw `chain_length` batches of samples, one batch per iteration.

    Nothing is computed until the generator is first advanced.

    Arguments:
        machine: Flax module, or a callable computing the log-pdf with the
            signature :code:`f(parameters, σ) -> jnp.ndarray`.
        parameters: PyTree holding the parameters of the model.
        state: Sampler state to continue from. When omitted, a fresh state is
            obtained through :code:`sampler.reset(machine, parameters)`.
        chain_length: Number of batches to yield (default = 1).

    Yields:
        One batch of samples per step: the first (and only) entry along the
        leading axis of the array returned by a length-1 `_sample_chain` call.
    """
    if state is None:
        state = sampler.reset(machine, parameters)
    log_pdf = wrap_afun(machine)
    for _ in range(chain_length):
        # Advance the chains by a single step and keep the new state for the
        # next iteration.
        batch, state = sampler._sample_chain(log_pdf, parameters, state, 1)
        yield batch[0, :, :]
def init_state(
    sampler,
    machine: Union[Callable, nn.Module],
    parameters: PyTree,
    seed: Optional[SeedT] = None,
) -> SamplerState:
    """
    Build the structure that holds the state of the sampler.

    Pass `seed` to obtain reproducible samples. When it is omitted, the state
    is initialised randomly.

    When running across several MPI processes, every rank is guaranteed to
    end up with a different (but deterministic) `sampler_state`:
    1. the seed provided on each rank is reduced (summed);
    2. `n_rank` seeds are generated from the reduced value;
    3. each rank is initialised with one of those seeds.

    The resulting state is guaranteed to be a frozen Python dataclass (in
    particular a Flax dataclass), and it can be serialized with the Flax
    serialization methods.

    Args:
        machine: Flax module, or a callable computing the log-pdf with the
            signature :code:`f(parameters, σ) -> jnp.ndarray`.
        parameters: PyTree holding the parameters of the model.
        seed: Optional seed or jax PRNGKey. A random seed is used when it is
            not given.

    Returns:
        The structure holding the state of the sampler. In general it should
        not be expected to be in a valid state, and it should be reset before
        use.
    """
    # Build the key from the seed, then give each MPI rank its own key.
    rng_key = nkjax.mpi_split(nkjax.PRNGKey(seed))
    return sampler._init_state(wrap_afun(machine), parameters, rng_key)
def sample_next(
    sampler,
    machine: Union[Callable, nn.Module],
    parameters: PyTree,
    state: Optional[SamplerState] = None,
) -> Tuple[SamplerState, jnp.ndarray]:
    """
    Advance the Markov chain by one step and return the new samples.

    Args:
        machine: Flax module, or a callable computing the log-pdf with the
            signature :code:`f(parameters, σ) -> jnp.ndarray`.
        parameters: PyTree holding the parameters of the model.
        state: Sampler state to advance. When omitted, a fresh one is
            produced with :code:`sampler.reset(machine, parameters)`.

    Returns:
        state: The updated state of the sampler.
        σ: The next batch of samples.

    Note:
        The outputs come in the opposite order to `sample`. When this is used
        inside a scan function, the first returned value must be the state.
    """
    current = state if state is not None else sampler.reset(machine, parameters)
    return sampler._sample_next(wrap_afun(machine), parameters, current)
def sample_next(
    sampler,
    machine: Union[Callable, nn.Module],
    parameters: PyTree,
    state: Optional[SamplerState] = None,
) -> Tuple[SamplerState, jnp.ndarray]:
    """
    Samples the next state in the Markov chain.

    Args:
        machine: a Flax module or callable apply function with the forward pass
            of the log-pdf.
        parameters: The PyTree of parameters of the model.
        state: The current state of the sampler. If it's not provided, it will
            be constructed by calling :code:`sampler.reset(machine, parameters)`.

    Returns:
        state: The new state of the sampler.
        σ: The next batch of samples.

    Note:
        The return order is inverted wrt `sample` because, when called inside
        of a scan function, the first returned argument should be the state.
    """
    if state is None:
        # A merely initialised state is not guaranteed to be valid for
        # sampling, so build it with `reset` (init + reset) as documented
        # above, instead of only initialising it.
        state = sampler.reset(machine, parameters)
    return sampler._sample_next(wrap_afun(machine), parameters, state)
def reset(
    sampler,
    machine: Union[Callable, nn.Module],
    parameters: PyTree,
    state: Optional[SamplerState] = None,
) -> SamplerState:
    """
    Reset the state of the sampler; call this whenever the parameters change.

    Args:
        machine: Flax module, or a callable with the forward pass of the
            log-pdf.
        parameters: PyTree holding the parameters of the model.
        state: Sampler state to reset. When omitted, a new one is built with
            :code:`sampler_state(sampler, machine, parameters)` (presumably a
            fresh :code:`sampler.init_state(machine, parameters)` with a
            random seed — confirm against `sampler_state`).

    Returns:
        A valid sampler state.
    """
    current = (
        state
        if state is not None
        else sampler_state(sampler, machine, parameters)
    )
    return sampler._reset(wrap_afun(machine), parameters, current)
def __init__(
    self,
    sampler: Sampler,
    model=None,
    *,
    n_samples: Optional[int] = None,
    n_samples_per_rank: Optional[int] = None,
    n_discard: Optional[int] = None,  # deprecated
    n_discard_per_chain: Optional[int] = None,
    variables: Optional[PyTree] = None,
    init_fun: Optional[NNInitFunc] = None,
    apply_fun: Optional[Callable] = None,
    sample_fun: Optional[Callable] = None,
    seed: Optional[SeedT] = None,
    sampler_seed: Optional[SeedT] = None,
    mutable: bool = False,
    training_kwargs: Optional[Dict] = None,
):
    """
    Constructs the MCState.

    Args:
        sampler: The sampler. Its `hilbert` is passed to the parent
            constructor, and its `dtype` is used when initialising the
            parameters.
        model: (Optional) The model. If not provided, you must provide
            `apply_fun`, and also `init_fun` unless `variables` is given.
        n_samples: the total number of samples across chains and processes
            when sampling (default=1000 when neither this nor
            `n_samples_per_rank` is given).
        n_samples_per_rank: the total number of samples across chains on one
            process when sampling. Cannot be specified together with
            `n_samples` (default=None).
        n_discard_per_chain: number of discarded samples at the beginning of
            each Monte Carlo chain (default=0 for exact sampler, and
            n_samples/10 for approximate sampler).
        variables: Optional initial value for the variables (parameters and
            model state) of the model.
        init_fun: Function of the signature
            f(model, shape, rng_key, dtype) -> Optional_state, parameters
            used to initialise the parameters. Defaults to the standard flax
            initialiser. Only specify if your network has a non-standard
            init method.
        apply_fun: Function of the signature f(model, variables, σ) that
            should evaluate the model. Defaults to `model.apply(variables, σ)`.
            Specify only if your network has a non-standard apply method.
        sample_fun: Optional function used to sample the state, if it is not
            the same as `apply_fun`.
        seed: rng seed used to generate a set of parameters (only if
            `variables` is not passed). If `sampler_seed` is not given, the
            sampler seed is also derived from it. Defaults to a random one.
        sampler_seed: rng seed used to initialise the sampler. Defaults to a
            seed derived from `seed` if that is given, otherwise a random one.
        mutable: Specifies which parts of the model are mutable. Use it to
            specify if the model has a state that can change during
            evaluation, but that should not be optimised. See also
            flax.linen.Module.apply documentation (default=False).
        training_kwargs: a dict containing the optional keyword arguments to
            be passed to the apply_fun during training. Useful for example
            when you have a batchnorm layer that constructs the average/mean
            only during training. Defaults to no extra arguments.
        n_discard: DEPRECATED. Please use `n_discard_per_chain`, which has the
            same behaviour.

    Raises:
        ValueError: if neither `model` nor `apply_fun` is given; if `apply_fun`
            is given without `init_fun` and without `variables`; if both
            `n_samples` and `n_samples_per_rank` are given; or if both
            `n_discard` and `n_discard_per_chain` are given.
    """
    super().__init__(sampler.hilbert)

    # Init type 1: pass in a model
    if model is not None:
        # Extract init and apply functions.
        # Wrap them in a HashablePartial: if two instances of the same model
        # are provided, model.apply and model2.apply are different bound
        # methods (forcing recompilation), but model and model2 hash equally.
        _, model = maybe_wrap_module(model)

        self._model = model

        self._init_fun = nkjax.HashablePartial(
            lambda model, *args, **kwargs: model.init(*args, **kwargs), model
        )
        self._apply_fun = nkjax.HashablePartial(
            lambda model, *args, **kwargs: model.apply(*args, **kwargs), model
        )

    # Init type 2: pass in an apply function (and an init function, unless
    # the variables are given explicitly)
    elif apply_fun is not None:
        self._apply_fun = apply_fun

        if init_fun is not None:
            self._init_fun = init_fun
        elif variables is None:
            raise ValueError(
                "If you don't provide variables, you must pass a valid init_fun."
            )

        self._model = wrap_afun(apply_fun)

    else:
        raise ValueError(
            "Must either pass the model or apply_fun, otherwise how do you think we "
            "gonna evaluate the model?"
        )

    # default argument for n_samples/n_samples_per_rank
    if n_samples is None and n_samples_per_rank is None:
        n_samples = 1000
    elif n_samples is not None and n_samples_per_rank is not None:
        raise ValueError(
            "Only one argument between `n_samples` and `n_samples_per_rank` "
            "can be specified at the same time."
        )

    if n_discard is not None and n_discard_per_chain is not None:
        raise ValueError(
            "`n_discard` has been renamed to `n_discard_per_chain` and deprecated. "
            "Specify only `n_discard_per_chain`."
        )
    elif n_discard is not None:
        warn_deprecation(
            "`n_discard` has been renamed to `n_discard_per_chain` and deprecated. "
            "Please update your code to use `n_discard_per_chain`."
        )
        n_discard_per_chain = n_discard

    # Sampling uses the evaluation function unless a dedicated one is given.
    if sample_fun is not None:
        self._sample_fun = sample_fun
    else:
        self._sample_fun = self._apply_fun

    self.mutable = mutable
    # `None` (instead of a shared `{}` default) means "no extra kwargs".
    if training_kwargs is None:
        training_kwargs = {}
    self.training_kwargs = flax.core.freeze(training_kwargs)

    if variables is not None:
        self.variables = variables
    else:
        self.init(seed, dtype=sampler.dtype)

    if sampler_seed is None and seed is not None:
        # Derive the sampler seed deterministically from `seed`, so that a
        # single seed makes the whole state reproducible.
        _, sampler_seed = jax.random.split(nkjax.PRNGKey(seed), 2)

    self._sampler_seed = sampler_seed
    self.sampler = sampler

    if n_samples is not None:
        self.n_samples = n_samples
    else:
        self.n_samples_per_rank = n_samples_per_rank

    self.n_discard_per_chain = n_discard_per_chain
def __init__( self, hilbert: AbstractHilbert, model=None, *, variables: Optional[PyTree] = None, init_fun: NNInitFunc = None, apply_fun: Callable = None, seed: Optional[SeedT] = None, mutable: bool = False, training_kwargs: Dict = {}, dtype=float, ): """ Constructs the ExactState. Args: hilbert: The Hilbert space model: (Optional) The model. If not provided, you must provide init_fun and apply_fun. parameters: Optional PyTree of weights from which to start. seed: rng seed used to generate a set of parameters (only if parameters is not passed). Defaults to a random one. mutable: Dict specifing mutable arguments. Use it to specify if the model has a state that can change during evaluation, but that should not be optimised. See also flax.linen.module.apply documentation (default=False) init_fun: Function of the signature f(model, shape, rng_key, dtype) -> Optional_state, parameters used to initialise the parameters. Defaults to the standard flax initialiser. Only specify if your network has a non-standard init method. variables: Optional initial value for the variables (parameters and model state) of the model. apply_fun: Function of the signature f(model, variables, σ) that should evaluate the model. Defafults to `model.apply(variables, σ)`. specify only if your network has a non-standard apply method. training_kwargs: a dict containing the optionaal keyword arguments to be passed to the apply_fun during training. Useful for example when you have a batchnorm layer that constructs the average/mean only during training. """ super().__init__(hilbert) # Init type 1: pass in a model if model is not None: # extract init and apply functions # Wrap it in an HashablePartial because if two instances of the same model are provided, # model.apply and model2.apply will be different methods forcing recompilation, but # model and model2 will have the same hash. 
_, model = maybe_wrap_module(model) self._model = model self._init_fun = nkjax.HashablePartial( lambda model, *args, **kwargs: model.init(*args, **kwargs), model) self._apply_fun = wrap_to_support_scalar( nkjax.HashablePartial( lambda model, *args, **kwargs: model.apply( *args, **kwargs), model)) elif apply_fun is not None: self._apply_fun = wrap_to_support_scalar(apply_fun) if init_fun is not None: self._init_fun = init_fun elif variables is None: raise ValueError( "If you don't provide variables, you must pass a valid init_fun." ) self._model = wrap_afun(apply_fun) else: raise ValueError( "Must either pass the model or apply_fun, otherwise how do you think we" "gonna evaluate the model?") self.mutable = mutable self.training_kwargs = flax.core.freeze(training_kwargs) if variables is not None: self.variables = variables else: self.init(seed, dtype=dtype) self._states = None """ Caches the output of `self._all_states()`. """ self._array = None """ Caches the output of `self.to_array()`. """ self._pdf = None """