def __init__(
    self,
    sampler: Sampler,
    model=None,
    *,
    n_samples: int = 1000,
    n_discard: Optional[int] = None,
    variables: Optional[PyTree] = None,
    init_fun: NNInitFunc = None,
    apply_fun: Callable = None,
    sample_fun: Callable = None,
    seed: Optional[SeedT] = None,
    sampler_seed: Optional[SeedT] = None,
    mutable: bool = False,
    training_kwargs: Dict = {},
):
    """
    Constructs the MCState.

    Arguments:
        sampler: The sampler.
        model: (Optional) The model. If not provided, you must provide
            init_fun and apply_fun.

    Keyword Arguments:
        n_samples: the total number of samples across chains and processes
            when sampling (default=1000).
        n_discard: number of discarded samples at the beginning of each
            Monte Carlo chain (default=n_samples/10).
        variables: Optional PyTree of weights from which to start.
        seed: rng seed used to generate a set of parameters (only if
            variables is not passed). Defaults to a random one.
        sampler_seed: rng seed used to initialise the sampler. Defaults to
            a random one.
        mutable: Dict specifying mutable arguments. Use it to specify if the
            model has a state that can change during evaluation, but that
            should not be optimised. See also the flax.linen.Module.apply
            documentation (default=False).
        init_fun: Function of the signature f(model, shape, rng_key, dtype)
            -> Optional_state, parameters used to initialise the parameters.
            Defaults to the standard flax initialiser. Only specify if your
            network has a non-standard init method.
        apply_fun: Function of the signature f(model, variables, σ) that
            should evaluate the model. Defaults to `model.apply(variables, σ)`.
            Specify only if your network has a non-standard apply method.
        sample_fun: Optional function used to sample the state, if it is not
            the same as `apply_fun`.
        training_kwargs: a dict containing the optional keyword arguments to
            be passed to the apply_fun during training. Useful for example
            when you have a batchnorm layer that constructs the average/mean
            only during training.
    """
    super().__init__(sampler.hilbert)

    # Init type 1: pass in a model
    if model is not None:
        # Extract init and apply functions.
        # Wrap them in a HashablePartial because if two instances of the same
        # model are provided, model.apply and model2.apply will be different
        # methods forcing recompilation, but model and model2 will have the
        # same hash.
        _, model = maybe_wrap_module(model)

        self.model = model

        self._init_fun = nkjax.HashablePartial(
            lambda model, *args, **kwargs: model.init(*args, **kwargs), model
        )
        self._apply_fun = nkjax.HashablePartial(
            lambda model, *args, **kwargs: model.apply(*args, **kwargs), model
        )
    # Init type 2: pass apply_fun directly, plus init_fun or variables.
    elif apply_fun is not None:
        self._apply_fun = apply_fun

        if init_fun is not None:
            self._init_fun = init_fun
        elif variables is None:
            raise ValueError(
                "If you don't provide variables, you must pass a valid init_fun."
            )
    else:
        raise ValueError(
            "Must either pass the model or apply_fun, otherwise how do you "
            "think we're gonna evaluate the model?"
        )

    if sample_fun is not None:
        self._sample_fun = sample_fun
    else:
        self._sample_fun = self._apply_fun

    self.mutable = mutable
    self.training_kwargs = flax.core.freeze(training_kwargs)

    if variables is not None:
        self.variables = variables
    else:
        self.init(seed, dtype=sampler.dtype)

    # Derive the sampler seed from the parameter seed if none was given,
    # so that a run remains reproducible from a single seed.
    if sampler_seed is None and seed is not None:
        key, key2 = jax.random.split(nkjax.PRNGKey(seed), 2)
        sampler_seed = key2

    self._sampler_seed = sampler_seed
    self.sampler = sampler
    self.n_samples = n_samples
    self.n_discard = n_discard
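# A minimal usage sketch of the two construction paths above (hedged:
# assumes NetKet is importable as `nk`; the hilbert size, model, and
# sampler below are arbitrary choices, not requirements):
import netket as nk

hi = nk.hilbert.Spin(s=1 / 2, N=10)
sampler = nk.sampler.MetropolisLocal(hi)

# Init type 1: pass a model; init_fun/apply_fun are derived from it.
# (1008 is a multiple of the sampler's default 16 chains.)
vstate = nk.vqs.MCState(sampler, nk.models.RBM(alpha=1), n_samples=1008)

# Init type 2: pass apply_fun directly, together with init_fun or variables:
# vstate = nk.vqs.MCState(sampler, apply_fun=my_apply, variables=my_variables)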
def __init__(
    self,
    sampler,
    model=None,
    *,
    sampler_diag: Optional[Sampler] = None,
    n_samples_diag: Optional[int] = None,
    n_samples_per_rank_diag: Optional[int] = None,
    n_discard_per_chain_diag: Optional[int] = None,
    n_discard_diag: Optional[int] = None,  # deprecated
    seed=None,
    sampler_seed: Optional[int] = None,
    variables=None,
    **kwargs,
):
    """
    Constructs the MCMixedState. Arguments are the same as :class:`MCState`.

    Arguments:
        sampler: The sampler.
        model: (Optional) The model. If not provided, you must provide
            init_fun and apply_fun.
        n_samples: the total number of samples across chains and processes
            when sampling (default=1000).
        n_samples_per_rank: the total number of samples across chains on one
            process when sampling. Cannot be specified together with
            n_samples (default=None).
        n_discard_per_chain: number of discarded samples at the beginning of
            each Monte Carlo chain (default=n_samples/10).
        n_samples_diag: the total number of samples across chains and
            processes when sampling the diagonal of the density matrix
            (default=1000).
        n_samples_per_rank_diag: the total number of samples across chains on
            one process when sampling the diagonal. Cannot be specified
            together with `n_samples_diag` (default=None).
        n_discard_per_chain_diag: number of discarded samples at the beginning
            of each Monte Carlo chain used when sampling the diagonal of the
            density matrix for observables (default=n_samples_diag/10).
        n_discard_diag: DEPRECATED. Please use `n_discard_per_chain_diag`,
            which has the same behaviour.
        variables: Optional PyTree of weights from which to start.
        seed: rng seed used to generate a set of parameters (only if
            variables is not passed). Defaults to a random one.
        sampler_seed: rng seed used to initialise the sampler. Defaults to a
            random one.
        mutable: Dict specifying mutable arguments. Use it to specify if the
            model has a state that can change during evaluation, but that
            should not be optimised. See also the flax.linen.Module.apply
            documentation (default=False).
        init_fun: Function of the signature f(model, shape, rng_key, dtype)
            -> Optional_state, parameters used to initialise the parameters.
            Defaults to the standard flax initialiser. Only specify if your
            network has a non-standard init method.
        apply_fun: Function of the signature f(model, variables, σ) that
            should evaluate the model. Defaults to `model.apply(variables, σ)`.
            Specify only if your network has a non-standard apply method.
        training_kwargs: a dict containing the optional keyword arguments to
            be passed to the apply_fun during training. Useful for example
            when you have a batchnorm layer that constructs the average/mean
            only during training.
    """
    # Split the seeds so that the full-space state and the diagonal state
    # are initialised independently but reproducibly.
    seed, seed_diag = jax.random.split(nkjax.PRNGKey(seed))
    if sampler_seed is None:
        sampler_seed_diag = None
    else:
        sampler_seed, sampler_seed_diag = jax.random.split(
            nkjax.PRNGKey(sampler_seed)
        )

    self._diagonal = None

    hilbert_physical = sampler.hilbert.physical

    super().__init__(
        sampler.hilbert.physical,
        sampler,
        model,
        **kwargs,
        seed=seed,
        sampler_seed=sampler_seed,
        variables=variables,
    )

    if sampler_diag is None:
        sampler_diag = sampler.replace(hilbert=hilbert_physical)

    sampler_diag = sampler_diag.replace(machine_pow=1)

    diagonal_apply_fun = nkjax.HashablePartial(apply_diagonal, self._apply_fun)

    # Remove the sampling kwargs consumed by the full-space state: the
    # diagonal state uses its own *_diag counterparts.
    for kw in [
        "n_samples",
        "n_discard",
        "n_discard_per_chain",
    ]:  # TODO: remove n_discard after deprecation.
        if kw in kwargs:
            kwargs.pop(kw)

    # TODO: remove deprecation.
    if n_discard_diag is not None and n_discard_per_chain_diag is not None:
        raise ValueError(
            "`n_discard_diag` has been renamed to `n_discard_per_chain_diag` "
            "and deprecated. Specify only `n_discard_per_chain_diag`."
        )
    elif n_discard_diag is not None:
        warn_deprecation(
            "`n_discard_diag` has been renamed to `n_discard_per_chain_diag` "
            "and deprecated. Please update your code to use "
            "`n_discard_per_chain_diag`."
        )
        n_discard_per_chain_diag = n_discard_diag

    self._diagonal = MCState(
        sampler_diag,
        apply_fun=diagonal_apply_fun,
        n_samples=n_samples_diag,
        n_samples_per_rank=n_samples_per_rank_diag,
        n_discard_per_chain=n_discard_per_chain_diag,
        variables=self.variables,
        seed=seed_diag,
        sampler_seed=sampler_seed_diag,
        **kwargs,
    )
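# A minimal usage sketch (hedged: assumes NetKet is importable as `nk`;
# NDM plus a MetropolisLocal sampler on the doubled Hilbert space is one
# standard combination, not the only one). The main sampler walks the
# doubled space of the density matrix; the diagonal sampler, configured
# through the *_diag arguments, is used to estimate observables:
import netket as nk

hi = nk.hilbert.Spin(s=1 / 2, N=4)
sampler = nk.sampler.MetropolisLocal(nk.hilbert.DoubledHilbert(hi))
rho = nk.vqs.MCMixedState(
    sampler, nk.models.NDM(), n_samples=1008, n_samples_diag=512
)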
def __init__(
    self,
    sampler: Sampler,
    model=None,
    *,
    n_samples: Optional[int] = None,
    n_samples_per_rank: Optional[int] = None,
    n_discard: Optional[int] = None,  # deprecated
    n_discard_per_chain: Optional[int] = None,
    chunk_size: Optional[int] = None,
    variables: Optional[PyTree] = None,
    init_fun: NNInitFunc = None,
    apply_fun: Callable = None,
    sample_fun: Callable = None,
    seed: Optional[SeedT] = None,
    sampler_seed: Optional[SeedT] = None,
    mutable: bool = False,
    training_kwargs: Dict = {},
):
    """
    Constructs the MCState.

    Args:
        sampler: The sampler.
        model: (Optional) The model. If not provided, you must provide
            init_fun and apply_fun.
        n_samples: the total number of samples across chains and processes
            when sampling (default=1000).
        n_samples_per_rank: the total number of samples across chains on one
            process when sampling. Cannot be specified together with
            n_samples (default=None).
        n_discard_per_chain: number of discarded samples at the beginning of
            each Monte Carlo chain (default=0 for exact samplers, and
            n_samples/10 for approximate samplers).
        chunk_size: If specified, the model is evaluated on at most
            `chunk_size` samples at a time, trading speed for lower memory
            use (default=None).
        variables: Optional initial value for the variables (parameters and
            model state) of the model.
        seed: rng seed used to generate a set of parameters (only if
            variables is not passed). Defaults to a random one.
        sampler_seed: rng seed used to initialise the sampler. Defaults to a
            random one.
        mutable: Dict specifying mutable arguments. Use it to specify if the
            model has a state that can change during evaluation, but that
            should not be optimised. See also the flax.linen.Module.apply
            documentation (default=False).
        init_fun: Function of the signature f(model, shape, rng_key, dtype)
            -> Optional_state, parameters used to initialise the parameters.
            Defaults to the standard flax initialiser. Only specify if your
            network has a non-standard init method.
        apply_fun: Function of the signature f(model, variables, σ) that
            should evaluate the model. Defaults to `model.apply(variables, σ)`.
            Specify only if your network has a non-standard apply method.
        sample_fun: Optional function used to sample the state, if it is not
            the same as `apply_fun`.
        training_kwargs: a dict containing the optional keyword arguments to
            be passed to the apply_fun during training. Useful for example
            when you have a batchnorm layer that constructs the average/mean
            only during training.
        n_discard: DEPRECATED. Please use `n_discard_per_chain`, which has
            the same behaviour.
    """
    super().__init__(sampler.hilbert)

    # Init type 1: pass in a model
    if model is not None:
        # Extract init and apply functions.
        # Wrap them in a HashablePartial because if two instances of the same
        # model are provided, model.apply and model2.apply will be different
        # methods forcing recompilation, but model and model2 will have the
        # same hash.
        _, model = maybe_wrap_module(model)

        self._model = model

        self._init_fun = nkjax.HashablePartial(
            lambda model, *args, **kwargs: model.init(*args, **kwargs), model
        )
        self._apply_fun = wrap_to_support_scalar(
            nkjax.HashablePartial(
                lambda model, pars, x, **kwargs: model.apply(pars, x, **kwargs),
                model,
            )
        )
    # Init type 2: pass apply_fun directly, plus init_fun or variables.
    elif apply_fun is not None:
        self._apply_fun = wrap_to_support_scalar(apply_fun)

        if init_fun is not None:
            self._init_fun = init_fun
        elif variables is None:
            raise ValueError(
                "If you don't provide variables, you must pass a valid init_fun."
            )

        self._model = wrap_afun(apply_fun)
    else:
        raise ValueError(
            "Must either pass the model or apply_fun, otherwise how do you "
            "think we're gonna evaluate the model?"
        )

    # default argument for n_samples/n_samples_per_rank
    if n_samples is None and n_samples_per_rank is None:
        # get the first multiple of sampler.n_chains above 1000 to avoid
        # printing a warning on construction
        n_samples = int(np.ceil(1000 / sampler.n_chains) * sampler.n_chains)
    elif n_samples is not None and n_samples_per_rank is not None:
        raise ValueError(
            "Only one argument between `n_samples` and `n_samples_per_rank` "
            "can be specified at the same time."
        )

    if n_discard is not None and n_discard_per_chain is not None:
        raise ValueError(
            "`n_discard` has been renamed to `n_discard_per_chain` and "
            "deprecated. Specify only `n_discard_per_chain`."
        )
    elif n_discard is not None:
        warn_deprecation(
            "`n_discard` has been renamed to `n_discard_per_chain` and "
            "deprecated. Please update your code to use `n_discard_per_chain`."
        )
        n_discard_per_chain = n_discard

    if sample_fun is not None:
        self._sample_fun = sample_fun
    else:
        self._sample_fun = self._apply_fun

    self.mutable = mutable
    self.training_kwargs = flax.core.freeze(training_kwargs)

    if variables is not None:
        self.variables = variables
    else:
        self.init(seed, dtype=sampler.dtype)

    # Derive the sampler seed from the parameter seed if none was given,
    # so that a run remains reproducible from a single seed.
    if sampler_seed is None and seed is not None:
        key, key2 = jax.random.split(nkjax.PRNGKey(seed), 2)
        sampler_seed = key2

    self._sampler_seed = sampler_seed
    self.sampler = sampler

    if n_samples is not None:
        self.n_samples = n_samples
    else:
        self.n_samples_per_rank = n_samples_per_rank

    self.n_discard_per_chain = n_discard_per_chain
    self.chunk_size = chunk_size
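# Standalone sketch of the default n_samples rounding above: take the first
# multiple of sampler.n_chains at or above 1000, so that the n_samples
# setter never warns about rounding (pure arithmetic, no NetKet required):
import numpy as np

for n_chains in (16, 24, 1024):
    n_samples = int(np.ceil(1000 / n_chains) * n_chains)
    assert n_samples >= 1000 and n_samples % n_chains == 0
    print(n_chains, "->", n_samples)  # 16 -> 1008, 24 -> 1008, 1024 -> 1024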
def get_local_kernel(  # noqa: F811
    vstate: MCState, Ô: ContinuousOperator, chunk_size: int
):
    return nkjax.HashablePartial(_local_continuous_kernel, Ô._expect_kernel)
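# The HashablePartial above matters for jit caching: two plain
# functools.partial objects built from the same function and arguments do
# not compare equal, so JAX would re-trace on every reconstruction. A
# simplified, standalone stand-in illustrating the idea (hedged: this is
# not nkjax.HashablePartial's actual implementation):
import functools


class HashablePartialSketch(functools.partial):
    """A partial that hashes and compares by (func, args, kwargs)."""

    def __eq__(self, other):
        return (
            type(other) is type(self)
            and self.func == other.func
            and self.args == other.args
            and self.keywords == other.keywords
        )

    def __hash__(self):
        return hash((self.func, self.args, tuple(sorted(self.keywords.items()))))


f1 = HashablePartialSketch(pow, 2)
f2 = HashablePartialSketch(pow, 2)
assert f1 == f2 and hash(f1) == hash(f2)  # equal cache keys -> one jit trace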
def __init__(
    self,
    hilbert: AbstractHilbert,
    model=None,
    *,
    variables: Optional[PyTree] = None,
    init_fun: NNInitFunc = None,
    apply_fun: Callable = None,
    seed: Optional[SeedT] = None,
    mutable: bool = False,
    training_kwargs: Dict = {},
    dtype=float,
):
    """
    Constructs the ExactState.

    Args:
        hilbert: The Hilbert space.
        model: (Optional) The model. If not provided, you must provide
            init_fun and apply_fun.
        variables: Optional initial value for the variables (parameters and
            model state) of the model.
        seed: rng seed used to generate a set of parameters (only if
            variables is not passed). Defaults to a random one.
        mutable: Dict specifying mutable arguments. Use it to specify if the
            model has a state that can change during evaluation, but that
            should not be optimised. See also the flax.linen.Module.apply
            documentation (default=False).
        init_fun: Function of the signature f(model, shape, rng_key, dtype)
            -> Optional_state, parameters used to initialise the parameters.
            Defaults to the standard flax initialiser. Only specify if your
            network has a non-standard init method.
        apply_fun: Function of the signature f(model, variables, σ) that
            should evaluate the model. Defaults to `model.apply(variables, σ)`.
            Specify only if your network has a non-standard apply method.
        training_kwargs: a dict containing the optional keyword arguments to
            be passed to the apply_fun during training. Useful for example
            when you have a batchnorm layer that constructs the average/mean
            only during training.
    """
    super().__init__(hilbert)

    # Init type 1: pass in a model
    if model is not None:
        # Extract init and apply functions.
        # Wrap them in a HashablePartial because if two instances of the same
        # model are provided, model.apply and model2.apply will be different
        # methods forcing recompilation, but model and model2 will have the
        # same hash.
        _, model = maybe_wrap_module(model)

        self._model = model

        self._init_fun = nkjax.HashablePartial(
            lambda model, *args, **kwargs: model.init(*args, **kwargs), model
        )
        self._apply_fun = wrap_to_support_scalar(
            nkjax.HashablePartial(
                lambda model, *args, **kwargs: model.apply(*args, **kwargs),
                model,
            )
        )
    # Init type 2: pass apply_fun directly, plus init_fun or variables.
    elif apply_fun is not None:
        self._apply_fun = wrap_to_support_scalar(apply_fun)

        if init_fun is not None:
            self._init_fun = init_fun
        elif variables is None:
            raise ValueError(
                "If you don't provide variables, you must pass a valid init_fun."
            )

        self._model = wrap_afun(apply_fun)
    else:
        raise ValueError(
            "Must either pass the model or apply_fun, otherwise how do you "
            "think we're gonna evaluate the model?"
        )

    self.mutable = mutable
    self.training_kwargs = flax.core.freeze(training_kwargs)

    if variables is not None:
        self.variables = variables
    else:
        self.init(seed, dtype=dtype)

    self._states = None
    """
    Caches the output of `self._all_states()`.
    """

    self._array = None
    """
    Caches the output of `self.to_array()`.
    """

    self._pdf = None
    """
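# A minimal usage sketch (hedged: assumes NetKet is importable as `nk`).
# ExactState enumerates the full Hilbert space instead of sampling, so it
# is only practical when all basis states fit in memory:
import netket as nk

hi = nk.hilbert.Spin(s=1 / 2, N=8)
vstate = nk.vqs.ExactState(hi, nk.models.RBM(alpha=1), seed=0)
psi = vstate.to_array()  # dense, normalised amplitudes over all 2**8 states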