def test_batching_wrapper():
    """Check that ``wrap_to_support_scalar`` lets a batch-only apply function
    be called on unbatched (1D) inputs, both with and without mutable state.
    """
    from netket.utils import wrap_to_support_scalar

    def batched_only_apply(pars, x, mutable=False):
        # The wrapper must always forward a batched input; if it passes a
        # 1D array through unchanged, this assertion trips.
        assert x.ndim > 1
        total = x.sum(axis=-1)
        return (total, {}) if mutable else total

    # Wrapping the same function twice must yield hash-equal objects.
    assert hash(wrap_to_support_scalar(batched_only_apply)) == hash(
        wrap_to_support_scalar(batched_only_apply)
    )

    wrapped = wrap_to_support_scalar(batched_only_apply)
    single = jnp.ones(5)
    batch = jnp.ones((1, 5))
    expected = jnp.sum(single, axis=-1)

    # Cases run in order: no mutable state first (1D, then batched),
    # then mutable state (1D, then batched), keeping only the output part.
    for use_mutable in (False, True):
        for inp, want_shape in ((single, ()), (batch, (1,))):
            if use_mutable:
                out = wrapped(None, inp, mutable=True)[0]
            else:
                out = wrapped(None, inp)
            assert out.shape == want_shape
            assert out == expected
def __init__(
    self,
    sampler: Sampler,
    model=None,
    *,
    n_samples: Optional[int] = None,
    n_samples_per_rank: Optional[int] = None,
    n_discard: Optional[int] = None,  # deprecated
    n_discard_per_chain: Optional[int] = None,
    chunk_size: Optional[int] = None,
    variables: Optional[PyTree] = None,
    init_fun: NNInitFunc = None,
    apply_fun: Callable = None,
    sample_fun: Callable = None,
    seed: Optional[SeedT] = None,
    sampler_seed: Optional[SeedT] = None,
    mutable: bool = False,
    training_kwargs: Optional[Dict] = None,
):
    """
    Constructs the MCState.

    Args:
        sampler: The sampler.
        model: (Optional) The model. If not provided, you must provide
            `apply_fun` (and `init_fun` unless `variables` is given).
        n_samples: the total number of samples across chains and processes
            when sampling. If neither `n_samples` nor `n_samples_per_rank`
            is given, defaults to the smallest multiple of
            `sampler.n_chains` that is >= 1000.
        n_samples_per_rank: the total number of samples across chains on one
            process when sampling. Cannot be specified together with
            `n_samples` (default=None).
        n_discard_per_chain: number of discarded samples at the beginning of
            each monte-carlo chain (default=0 for exact sampler, and
            n_samples/10 for approximate sampler).
        chunk_size: Optional value assigned to the `chunk_size` property of
            the state (default=None).
        variables: Optional initial value for the variables (parameters and
            model state) of the model. If not given, the variables are
            initialised using `seed`.
        init_fun: Function used to initialise the variables when `variables`
            is not given. Required if `apply_fun` is passed without
            `variables`; ignored when `model` is given.
        apply_fun: Function that evaluates the model, called as
            `apply_fun(variables, σ, **kwargs)`. Only used when `model` is
            not given; specify it only if your network has a non-standard
            apply method.
        sample_fun: Optional function used to sample the state, if it is not
            the same as `apply_fun`.
        seed: rng seed used to generate a set of parameters (only if
            `variables` is not passed). Defaults to a random one.
        sampler_seed: rng seed used to initialise the sampler. If not given
            but `seed` is, it is derived deterministically from `seed`;
            otherwise defaults to a random one.
        mutable: Which model collections may change during evaluation but
            should not be optimised. Follows the `mutable` argument of
            flax.linen.Module.apply (default=False).
        training_kwargs: a dict containing the optional keyword arguments to
            be passed to the apply_fun during training. Useful for example
            when you have a batchnorm layer that constructs the average/mean
            only during training. Defaults to an empty dict.
        n_discard: DEPRECATED. Please use `n_discard_per_chain` which has the
            same behaviour. Cannot be given together with
            `n_discard_per_chain`.

    Raises:
        ValueError: if neither `model` nor `apply_fun` is given; if
            `apply_fun` is given without `init_fun` and without `variables`;
            if both `n_samples` and `n_samples_per_rank` are given; or if
            both `n_discard` and `n_discard_per_chain` are given.
    """
    super().__init__(sampler.hilbert)

    # Init type 1: pass in a model
    if model is not None:
        # extract init and apply functions
        # Wrap them in a HashablePartial because if two instances of the same
        # model are provided, model.apply and model2.apply will be different
        # methods forcing recompilation, but model and model2 will have the
        # same hash.
        _, model = maybe_wrap_module(model)

        self._model = model

        self._init_fun = nkjax.HashablePartial(
            lambda model, *args, **kwargs: model.init(*args, **kwargs), model
        )
        self._apply_fun = wrap_to_support_scalar(
            nkjax.HashablePartial(
                lambda model, pars, x, **kwargs: model.apply(pars, x, **kwargs),
                model,
            )
        )
    # Init type 2: pass in the raw apply (and optionally init) function
    elif apply_fun is not None:
        self._apply_fun = wrap_to_support_scalar(apply_fun)

        if init_fun is not None:
            self._init_fun = init_fun
        elif variables is None:
            raise ValueError(
                "If you don't provide variables, you must pass a valid init_fun."
            )

        self._model = wrap_afun(apply_fun)
    else:
        raise ValueError(
            "Must either pass the model or apply_fun, otherwise how do you think we "
            "gonna evaluate the model?"
        )

    # default argument for n_samples/n_samples_per_rank
    if n_samples is None and n_samples_per_rank is None:
        # use the smallest multiple of sampler.n_chains that is >= 1000 to
        # avoid printing a warning on construction
        n_samples = int(np.ceil(1000 / sampler.n_chains) * sampler.n_chains)
    elif n_samples is not None and n_samples_per_rank is not None:
        raise ValueError(
            "Only one argument between `n_samples` and `n_samples_per_rank` "
            "can be specified at the same time."
        )

    if n_discard is not None and n_discard_per_chain is not None:
        raise ValueError(
            "`n_discard` has been renamed to `n_discard_per_chain` and deprecated. "
            "Specify only `n_discard_per_chain`."
        )
    elif n_discard is not None:
        warn_deprecation(
            "`n_discard` has been renamed to `n_discard_per_chain` and deprecated. "
            "Please update your code to use `n_discard_per_chain`."
        )
        n_discard_per_chain = n_discard

    # Sampling uses the evaluation function unless a dedicated one is given.
    self._sample_fun = sample_fun if sample_fun is not None else self._apply_fun

    self.mutable = mutable
    # A `None` default avoids sharing one mutable dict across all calls.
    self.training_kwargs = flax.core.freeze(
        training_kwargs if training_kwargs is not None else {}
    )

    if variables is not None:
        self.variables = variables
    else:
        self.init(seed, dtype=sampler.dtype)

    if sampler_seed is None and seed is not None:
        # Derive the sampler seed deterministically from the parameter seed.
        _, sampler_seed = jax.random.split(nkjax.PRNGKey(seed), 2)

    self._sampler_seed = sampler_seed
    self.sampler = sampler

    if n_samples is not None:
        self.n_samples = n_samples
    else:
        self.n_samples_per_rank = n_samples_per_rank

    self.n_discard_per_chain = n_discard_per_chain

    self.chunk_size = chunk_size
def __init__( self, hilbert: AbstractHilbert, model=None, *, variables: Optional[PyTree] = None, init_fun: NNInitFunc = None, apply_fun: Callable = None, seed: Optional[SeedT] = None, mutable: bool = False, training_kwargs: Dict = {}, dtype=float, ): """ Constructs the ExactState. Args: hilbert: The Hilbert space model: (Optional) The model. If not provided, you must provide init_fun and apply_fun. parameters: Optional PyTree of weights from which to start. seed: rng seed used to generate a set of parameters (only if parameters is not passed). Defaults to a random one. mutable: Dict specifing mutable arguments. Use it to specify if the model has a state that can change during evaluation, but that should not be optimised. See also flax.linen.module.apply documentation (default=False) init_fun: Function of the signature f(model, shape, rng_key, dtype) -> Optional_state, parameters used to initialise the parameters. Defaults to the standard flax initialiser. Only specify if your network has a non-standard init method. variables: Optional initial value for the variables (parameters and model state) of the model. apply_fun: Function of the signature f(model, variables, σ) that should evaluate the model. Defafults to `model.apply(variables, σ)`. specify only if your network has a non-standard apply method. training_kwargs: a dict containing the optionaal keyword arguments to be passed to the apply_fun during training. Useful for example when you have a batchnorm layer that constructs the average/mean only during training. """ super().__init__(hilbert) # Init type 1: pass in a model if model is not None: # extract init and apply functions # Wrap it in an HashablePartial because if two instances of the same model are provided, # model.apply and model2.apply will be different methods forcing recompilation, but # model and model2 will have the same hash. 
_, model = maybe_wrap_module(model) self._model = model self._init_fun = nkjax.HashablePartial( lambda model, *args, **kwargs: model.init(*args, **kwargs), model) self._apply_fun = wrap_to_support_scalar( nkjax.HashablePartial( lambda model, *args, **kwargs: model.apply( *args, **kwargs), model)) elif apply_fun is not None: self._apply_fun = wrap_to_support_scalar(apply_fun) if init_fun is not None: self._init_fun = init_fun elif variables is None: raise ValueError( "If you don't provide variables, you must pass a valid init_fun." ) self._model = wrap_afun(apply_fun) else: raise ValueError( "Must either pass the model or apply_fun, otherwise how do you think we" "gonna evaluate the model?") self.mutable = mutable self.training_kwargs = flax.core.freeze(training_kwargs) if variables is not None: self.variables = variables else: self.init(seed, dtype=dtype) self._states = None """ Caches the output of `self._all_states()`. """ self._array = None """ Caches the output of `self.to_array()`. """ self._pdf = None """