def symm_input_warning(x_shape, new_x_shape, name):
    """Emit a deprecation warning for a non-3-dimensional input to a layer.

    Args:
        x_shape: shape of the input as supplied by the caller.
        new_x_shape: shape the input has been reshaped to.
        name: name of the layer, interpolated into the message.
    """
    message = (
        f"{len(x_shape)}-dimensional input to {name} layer is deprecated.\n"
        f"Input shape {x_shape} has been reshaped to {new_x_shape}, where "
        "the middle dimension encodes different input channels.\n"
        "Please provide a 3-dimensional input.\nThis warning will become an "
        "error in the future."
    )
    warn_deprecation(message)
def graph_to_N_depwarn(N, graph):
    """Translate the deprecated ``graph`` argument of hilbert spaces into ``N``.

    Args:
        N: number of nodes requested by the caller.
        graph: deprecated graph object. When not None, its ``n_nodes`` is used.

    Returns:
        ``graph.n_nodes`` when ``graph`` is given (and ``N == 1``), otherwise ``N``.

    Raises:
        ValueError: if ``graph`` is given together with a value of ``N`` other than 1.
    """
    if graph is None:
        return N

    warn_deprecation(
        r"""
        The ``graph`` argument for hilbert spaces has been deprecated in v3.0.
        It has been replaced by the argument ``N`` accepting an integer, with
        the number of nodes in the graph.

        You can update your code by passing `N=_your_graph.n_nodes`.

        If you are also using `Ising`, `Heisenberg`, `BoseHubbard` or `GraphOperator`
        Hamiltonians you must now provide them with the extra argument
        ``graph=_your_graph``, as they no longer take it from the Hilbert space.
        """
    )

    # NOTE(review): N == 1 appears to be treated as "N not specified" — confirm
    # it matches the callers' default.
    if N == 1:
        return graph.n_nodes
    raise ValueError(
        "Graph object can only take one argument among N and graph "
        "(deprecated)."
    )
def QGTOnTheFly(vstate=None, **kwargs) -> "QGTOnTheFlyT":
    """
    Lazy representation of an S Matrix computed by performing 2 jvp
    and 1 vjp products, using the variational state's model, the
    samples that have already been computed, and the vector.

    The S matrix is not computed yet, but can be computed by calling
    :code:`to_dense`.
    The details on how the ⟨S⟩⁻¹⟨F⟩ system is solved are contained in
    the field `sr`.

    Args:
        vstate: The variational State. When omitted, a constructor with the
            given keyword arguments already bound is returned instead.
    """
    # Deferred construction: bind the options now, supply the state later.
    if vstate is None:
        return partial(QGTOnTheFly, **kwargs)

    if "centered" in kwargs:
        warn_deprecation(
            "The argument `centered` is deprecated. The implementation now always behaves as if centered=False."
        )

    state_fields = dict(
        apply_fun=vstate._apply_fun,
        params=vstate.parameters,
        samples=vstate.samples,
        model_state=vstate.model_state,
    )
    return QGTOnTheFlyT(**state_fields, **kwargs)
def QGTOnTheFly(vstate=None, **kwargs) -> "QGTOnTheFlyT":
    """
    Lazy representation of an S Matrix computed by performing 2 jvp
    and 1 vjp products, using the variational state's model, the
    samples that have already been computed, and the vector.

    The S matrix is not computed yet, but can be computed by calling
    :code:`to_dense`.
    The details on how the ⟨S⟩⁻¹⟨F⟩ system is solved are contained in
    the field `sr`.

    Args:
        vstate: The variational State. When omitted, a constructor with the
            given keyword arguments already bound is returned instead.

    Raises:
        TypeError: if `vstate` is an `ExactState`.
    """
    # Deferred construction: bind the options now, supply the state later.
    if vstate is None:
        return partial(QGTOnTheFly, **kwargs)

    if "centered" in kwargs:
        warn_deprecation(
            "The argument `centered` is deprecated. The implementation now always behaves as if centered=False."
        )
        # The value is ignored, so it must not reach QGTOnTheFlyT.
        kwargs.pop("centered")

    # TODO: Find a better way to handle this case
    from netket.vqs import ExactState

    if isinstance(vstate, ExactState):
        raise TypeError("Only QGTJacobianPyTree works with ExactState.")

    # Collapse every leading axis so that samples is a 2D (n_samples, N) array.
    samples = vstate.samples
    if jnp.ndim(samples) != 2:
        samples = samples.reshape((-1, samples.shape[-1]))

    # Only chunk when a chunk size is set and it is smaller than the sample count.
    chunk_size = vstate.chunk_size
    if chunk_size is not None and chunk_size < samples.shape[0]:
        samples, _ = nkjax.chunk(samples, chunk_size)
        mv_factory, chunking = mat_vec_chunked_factory, True
    else:
        mv_factory, chunking = mat_vec_factory, False

    mat_vec = mv_factory(
        forward_fn=vstate._apply_fun,
        params=vstate.parameters,
        model_state=vstate.model_state,
        samples=samples,
    )
    return QGTOnTheFlyT(
        _mat_vec=mat_vec,
        _params=vstate.parameters,
        _chunking=chunking,
        **kwargs,
    )
def random_state(
    self,
    key=NoneType(),
    size: Optional[int] = NoneType(),
    dtype=np.float32,
    out: Optional[np.ndarray] = None,
    rgen=None,
) -> jnp.ndarray:
    r"""Generates either a single or a batch of uniformly distributed random states.

    Runs as :code:`random_state(self, key, size=None, dtype=np.float32)` by default.

    Args:
        key: rng state from a jax-style functional generator.
        size: If provided, returns a batch of configurations of the form
              :code:`(size, N)` if size is an integer or :code:`(*size, N)` if it is
              a tuple and where :math:`N` is the Hilbert space size.
              By default, a single random configuration with shape :code:`(#,)` is
              returned.
        dtype: DType of the resulting vector.
        out: Deprecated. Will be removed in v3.1
        rgen: Deprecated. Will be removed in v3.1

    Returns:
        A state or batch of states sampled from the uniform distribution on the
        hilbert space.

    Example:

        >>> import netket, jax
        >>> hi = netket.hilbert.Qubit(N=2)
        >>> k1, k2 = jax.random.split(jax.random.PRNGKey(1))
        >>> print(hi.random_state(key=k1))
        [1. 0.]
        >>> print(hi.random_state(key=k2, size=2))
        [[0. 0.]
         [0. 1.]]
    """
    # legacy support
    # TODO: Remove in 3.1
    key_missing = isinstance(key, NoneType)
    size_missing = isinstance(size, NoneType)

    # No key at all -> legacy (stateful generator) call.
    if key_missing:
        warn_deprecation(legacy_warn_str)
        return self._random_state_legacy(
            size=None if size_missing else size, out=out, rgen=rgen
        )

    # A single positional argument shaped like a size -> legacy call.
    # NOTE: `and` binds tighter than `or`, so a tuple `key` is always treated
    # as a legacy size, even when `size` is also given.
    if isinstance(key, tuple) or (isinstance(key, int) and size_missing):
        warn_deprecation(legacy_warn_str)
        return self._random_state_legacy(size=key, out=out, rgen=rgen)

    from netket.hilbert import random

    return random.random_state(self, key, None if size_missing else size, dtype=dtype)
def n_discard(self) -> int:
    """
    DEPRECATED: Use `n_discard_per_chain` instead.

    Number of discarded samples at the beginning of the markov chain.
    Emits a deprecation warning and returns `self.n_discard_per_chain`.
    """
    warn_deprecation(
        "`n_discard` has been renamed to `n_discard_per_chain` and deprecated. "
        "Please update your code to use `n_discard_per_chain`."
    )
    return self.n_discard_per_chain
def setup(self):
    # TODO: eventually remove this warning
    # `extra_bias` is a deprecated attribute that is still honoured below.
    if self.extra_bias:
        warn_deprecation(
            (
                "`extra_bias` is detrimental for performance and is deprecated. "
                "Please switch to the default `extra_bias=False`. Previously saved "
                "parameters can be migrated using `nk.models.update_GCNN_parity`."
            )
        )

    # Number of symmetry operations: length of the first axis of `symmetries`.
    self.n_symm = np.asarray(self.symmetries).shape[0]

    # First layer, built directly from `self.symmetries`.
    self.dense_symm = DenseSymmFFT(
        space_group=self.symmetries,
        shape=self.shape,
        features=self.features[0],
        dtype=self.dtype,
        use_bias=self.use_bias,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
        precision=self.precision,
    )

    # Keyword arguments shared by every hidden equivariant layer.
    layer_kwargs = dict(
        product_table=self.product_table,
        shape=self.shape,
        dtype=self.dtype,
        precision=self.precision,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
    )
    # Widths of the hidden layers: features[1], ..., features[layers - 1].
    hidden_features = [self.features[i] for i in range(1, self.layers)]

    self.equivariant_layers = [
        DenseEquivariantFFT(features=width, use_bias=self.use_bias, **layer_kwargs)
        for width in hidden_features
    ]

    # The flipped branch only carries a bias when `extra_bias` is set;
    # otherwise it would bias the same outputs as `self.equivariant_layers`.
    flip_bias = self.extra_bias and self.use_bias
    self.equivariant_layers_flip = [
        DenseEquivariantFFT(features=width, use_bias=flip_bias, **layer_kwargs)
        for width in hidden_features
    ]
def SRLazyGMRES(diag_shift: float = 0.01, centered: bool = None, **kwargs):
    """Build an `SR` preconditioner using a lazy `QGTOnTheFly` and GMRES.

    Args:
        diag_shift: diagonal shift passed to `SR`.
        centered: deprecated and ignored; passing it only emits a warning.
        **kwargs: forwarded both to `jax.scipy.sparse.linalg.gmres` (bound via
            `partial`) and to `SR`.
    """
    if centered is not None:
        warn_deprecation(
            "The argument `centered` is deprecated. The implementation now always behaves as if centered=False."
        )

    # NOTE(review): the same kwargs reach both the solver and SR. A key that
    # only one of them accepts will likely fail in the other — confirm intent.
    gmres_solver = partial(jax.scipy.sparse.linalg.gmres, **kwargs)
    return SR(qgt.QGTOnTheFly, solver=gmres_solver, diag_shift=diag_shift, **kwargs)
def __init__(
    self,
    sampler,
    model=None,
    *,
    sampler_diag: Sampler = None,
    n_samples_diag: int = None,
    n_samples_per_rank_diag: Optional[int] = None,
    n_discard_per_chain_diag: Optional[int] = None,
    n_discard_diag: Optional[int] = None,  # deprecated
    seed=None,
    sampler_seed: Optional[int] = None,
    variables=None,
    **kwargs,
):
    """
    Constructs the MCMixedState.

    Arguments are the same as :class:`MCState`, plus those configuring the
    auxiliary :class:`MCState` used to sample the diagonal of the density matrix.

    Arguments:
        sampler: The sampler. `sampler.hilbert.physical` is used as the
            physical hilbert space.
        model: (Optional) The model. If not provided, you must provide init_fun and apply_fun.
        sampler_diag: (Optional) Sampler for the diagonal. Defaults to `sampler`
            with its hilbert space replaced by the physical one. In both cases
            it is replaced with `machine_pow=1`.
        n_samples: the total number of samples across chains and processes when sampling (default=1000).
        n_samples_per_rank: the total number of samples across chains on one process when sampling.
            Cannot be specified together with n_samples (default=None).
        n_discard_per_chain: number of discarded samples at the beginning of each monte-carlo chain.
        n_samples_diag: the total number of samples across chains and processes when sampling the diagonal
            of the density matrix (default=1000).
        n_samples_per_rank_diag: the total number of samples across chains on one process when sampling the diagonal.
            Cannot be specified together with `n_samples_diag` (default=None).
        n_discard_per_chain_diag: number of discarded samples at the beginning of each monte-carlo chain used when sampling
            the diagonal of the density matrix for observables.
        n_discard_diag: DEPRECATED. Use `n_discard_per_chain_diag`, which has the same behaviour.
        variables: Optional initial value for the variables (parameters and model state) of the model.
            They are shared with the diagonal state.
        seed: rng seed used to generate a set of parameters (only if variables is not passed).
            It is split to also seed the diagonal state. Defaults to a random one.
        sampler_seed: rng seed used to initialise the sampler. When given, it is split to also seed the
            diagonal sampler. Defaults to a random one.
        mutable: Dict specifying mutable arguments. Use it to specify if the model has a state that can change
            during evaluation, but that should not be optimised. See also flax.linen.module.apply documentation
            (default=False)
        init_fun: Function of the signature f(model, shape, rng_key, dtype) -> Optional_state, parameters used to
            initialise the parameters. Defaults to the standard flax initialiser. Only specify if your network has
            a non-standard init method.
        apply_fun: Function of the signature f(model, variables, σ) that should evaluate the model. Defaults to
            `model.apply(variables, σ)`. specify only if your network has a non-standard apply method.
        training_kwargs: a dict containing the optional keyword arguments to be passed to the apply_fun during training.
            Useful for example when you have a batchnorm layer that constructs the average/mean only during training.

    Raises:
        ValueError: if both `n_discard_diag` and `n_discard_per_chain_diag` are given.
    """
    # Independent rng streams for the full state and for the diagonal state.
    seed, seed_diag = jax.random.split(nkjax.PRNGKey(seed))
    if sampler_seed is None:
        sampler_seed_diag = None
    else:
        sampler_seed, sampler_seed_diag = jax.random.split(
            nkjax.PRNGKey(sampler_seed)
        )

    self._diagonal = None

    hilbert_physical = sampler.hilbert.physical

    super().__init__(
        hilbert_physical,
        sampler,
        model,
        **kwargs,
        seed=seed,
        sampler_seed=sampler_seed,
        variables=variables,
    )

    if sampler_diag is None:
        sampler_diag = sampler.replace(hilbert=hilbert_physical)

    # NOTE(review): the diagonal sampler always uses machine_pow=1 (presumably
    # sampling |rho(x, x)| rather than its square) — confirm.
    sampler_diag = sampler_diag.replace(machine_pow=1)

    diagonal_apply_fun = nkjax.HashablePartial(apply_diagonal, self._apply_fun)

    # These options configure only the full state; the diagonal state receives
    # its own *_diag values below, so they must not be forwarded twice.
    # `n_samples_per_rank` must be stripped too: otherwise the MCState call
    # below would receive it both explicitly and via **kwargs.
    # TODO remove n_discard after deprecation.
    for kw in ["n_samples", "n_samples_per_rank", "n_discard", "n_discard_per_chain"]:
        kwargs.pop(kw, None)

    # TODO: remove deprecation.
    if n_discard_diag is not None and n_discard_per_chain_diag is not None:
        raise ValueError(
            "`n_discard_diag` has been renamed to `n_discard_per_chain_diag` and deprecated. "
            "Specify only `n_discard_per_chain_diag`."
        )
    elif n_discard_diag is not None:
        warn_deprecation(
            "`n_discard_diag` has been renamed to `n_discard_per_chain_diag` and deprecated. "
            "Please update your code to `n_discard_per_chain_diag`."
        )
        n_discard_per_chain_diag = n_discard_diag

    self._diagonal = MCState(
        sampler_diag,
        apply_fun=diagonal_apply_fun,
        n_samples=n_samples_diag,
        n_samples_per_rank=n_samples_per_rank_diag,
        n_discard_per_chain=n_discard_per_chain_diag,
        variables=self.variables,
        seed=seed_diag,
        sampler_seed=sampler_seed_diag,
        **kwargs,
    )
def n_discard_diag(self, val) -> None:
    """
    DEPRECATED setter: use `n_discard_per_chain_diag` instead.

    Emits a deprecation warning and assigns `val` to `n_discard_per_chain_diag`.
    """
    warn_deprecation(
        "`n_discard_diag` has been renamed to `n_discard_per_chain_diag` and deprecated. "
        "Please update your code to use `n_discard_per_chain_diag`."
    )
    self.n_discard_per_chain_diag = val
def __init__(
    self,
    hamiltonian: AbstractOperator,
    optimizer,
    *args,
    variational_state=None,
    preconditioner: PreconditionerT = None,
    sr: PreconditionerT = None,
    sr_restart: bool = None,
    **kwargs,
):
    """
    Initializes the driver class.

    Args:
        hamiltonian: The Hamiltonian of the system.
        optimizer: Determines how optimization steps are performed given the
            bare energy gradient.
        *args: forwarded to :class:`MCState` when `variational_state` is None.
        variational_state: The variational state to optimize. If None, an
            :class:`MCState` is built from `*args` and `**kwargs`.
        preconditioner: Determines which preconditioner to use for the loss gradient.
            This must be a tuple of `(object, solver)` as documented in the section
            `preconditioners` in the documentation. The standard preconditioner
            included with NetKet is Stochastic Reconfiguration. By default, no
            preconditioner is used and the bare gradient is passed to the optimizer.
        sr: DEPRECATED alias of `preconditioner`. Cannot be given together with it.
        sr_restart: DEPRECATED. If given, it is stored as `solver_restart` on the
            preconditioner, which must then be provided.
        **kwargs: forwarded to :class:`MCState` when `variational_state` is None.

    Raises:
        TypeError: if the hilbert spaces of the variational state and of the
            hamiltonian differ.
        ValueError: if both `sr` and `preconditioner` are given, or if
            `sr_restart` is given without a preconditioner.
    """
    if variational_state is None:
        variational_state = MCState(*args, **kwargs)

    if variational_state.hilbert != hamiltonian.hilbert:
        raise TypeError(
            dedent(
                f"""the variational_state has hilbert space {variational_state.hilbert}
                (this is normally defined by the hilbert space in the sampler), but
                the hamiltonian has hilbert space {hamiltonian.hilbert}.
                The two should match.
                """
            )
        )

    if sr is not None:
        if preconditioner is not None:
            raise ValueError(
                "sr is deprecated in favour of preconditioner kwarg. You should not pass both"
            )
        else:
            preconditioner = sr
            warn_deprecation(
                (
                    "The `sr` keyword argument is deprecated in favour of `preconditioner`. "
                    "Please update your code to `VMC(.., preconditioner=your_sr)`"
                )
            )

    if sr_restart is not None:
        if preconditioner is None:
            raise ValueError(
                "sr_restart only makes sense if you have a preconditioner/SR."
            )
        else:
            preconditioner.solver_restart = sr_restart
            warn_deprecation(
                (
                    "The `sr_restart` keyword argument is deprecated in favour of specifying "
                    "`solver_restart` in the constructor of the SR object. "
                    "Please update your code to `VMC(.., preconditioner=nk.optimizer.SR(..., solver_restart=True/False))`"
                )
            )

    # move as kwarg once deprecations are removed
    if preconditioner is None:
        preconditioner = identity_preconditioner

    super().__init__(variational_state, optimizer, minimized_quantity_name="Energy")

    self._ham = hamiltonian.collect()  # type: AbstractOperator

    self.preconditioner = preconditioner

    self._dp = None  # type: PyTree
    self._S = None
    self._sr_info = None
def __init__(
    self,
    sampler: Sampler,
    model=None,
    *,
    n_samples: int = None,
    n_samples_per_rank: Optional[int] = None,
    n_discard: Optional[int] = None,  # deprecated
    n_discard_per_chain: Optional[int] = None,
    variables: Optional[PyTree] = None,
    init_fun: NNInitFunc = None,
    apply_fun: Callable = None,
    sample_fun: Callable = None,
    seed: Optional[SeedT] = None,
    sampler_seed: Optional[SeedT] = None,
    mutable: bool = False,
    training_kwargs: Optional[Dict] = None,
):
    """
    Constructs the MCState.

    Args:
        sampler: The sampler
        model: (Optional) The model. If not provided, you must provide init_fun and apply_fun.
        n_samples: the total number of samples across chains and processes when sampling (default=1000).
        n_samples_per_rank: the total number of samples across chains on one process when sampling. Cannot be
            specified together with n_samples (default=None).
        n_discard_per_chain: number of discarded samples at the beginning of each monte-carlo chain (default=0 for exact sampler,
            and n_samples/10 for approximate sampler).
        variables: Optional initial value for the variables (parameters and model state) of the model.
            If not given, they are initialised by calling `self.init(seed, ...)`.
        init_fun: Function of the signature f(model, shape, rng_key, dtype) -> Optional_state, parameters used to
            initialise the parameters. Defaults to the standard flax initialiser. Only specify if your network has
            a non-standard init method.
        apply_fun: Function of the signature f(model, variables, σ) that should evaluate the model. Defaults to
            `model.apply(variables, σ)`. specify only if your network has a non-standard apply method.
        sample_fun: Optional function used to sample the state, if it is not the same as `apply_fun`.
        seed: rng seed used to generate a set of parameters (only if variables is not passed). Defaults to a random one.
        sampler_seed: rng seed used to initialise the sampler. If not given but `seed` is, it is derived
            from `seed`. Defaults to a random one.
        mutable: Dict specifying mutable arguments. Use it to specify if the model has a state that can change
            during evaluation, but that should not be optimised. See also flax.linen.module.apply documentation
            (default=False)
        training_kwargs: a dict containing the optional keyword arguments to be passed to the apply_fun during training.
            Useful for example when you have a batchnorm layer that constructs the average/mean only during training.
            Defaults to an empty dict.
        n_discard: DEPRECATED. Please use `n_discard_per_chain` which has the same behaviour.

    Raises:
        ValueError: if neither `model` nor `apply_fun` is given; if `apply_fun` is given
            without `init_fun` and without `variables`; if both `n_samples` and
            `n_samples_per_rank` are given; or if both `n_discard` and
            `n_discard_per_chain` are given.
    """
    super().__init__(sampler.hilbert)

    # Init type 1: pass in a model
    if model is not None:
        # extract init and apply functions
        # Wrap it in an HashablePartial because if two instances of the same model are provided,
        # model.apply and model2.apply will be different methods forcing recompilation, but
        # model and model2 will have the same hash.
        _, model = maybe_wrap_module(model)

        self._model = model

        self._init_fun = nkjax.HashablePartial(
            lambda model, *args, **kwargs: model.init(*args, **kwargs), model
        )
        self._apply_fun = nkjax.HashablePartial(
            lambda model, *args, **kwargs: model.apply(*args, **kwargs), model
        )

    elif apply_fun is not None:
        self._apply_fun = apply_fun

        if init_fun is not None:
            self._init_fun = init_fun
        elif variables is None:
            raise ValueError(
                "If you don't provide variables, you must pass a valid init_fun."
            )

        self._model = wrap_afun(apply_fun)

    else:
        raise ValueError(
            "Must either pass the model or apply_fun, otherwise how do you think we "
            "gonna evaluate the model?"
        )

    # default argument for n_samples/n_samples_per_rank
    if n_samples is None and n_samples_per_rank is None:
        n_samples = 1000
    elif n_samples is not None and n_samples_per_rank is not None:
        raise ValueError(
            "Only one argument between `n_samples` and `n_samples_per_rank` "
            "can be specified at the same time."
        )

    if n_discard is not None and n_discard_per_chain is not None:
        raise ValueError(
            "`n_discard` has been renamed to `n_discard_per_chain` and deprecated. "
            "Specify only `n_discard_per_chain`."
        )
    elif n_discard is not None:
        warn_deprecation(
            "`n_discard` has been renamed to `n_discard_per_chain` and deprecated. "
            "Please update your code to use `n_discard_per_chain`."
        )
        n_discard_per_chain = n_discard

    if sample_fun is not None:
        self._sample_fun = sample_fun
    else:
        self._sample_fun = self._apply_fun

    self.mutable = mutable
    # A fresh dict per call instead of a shared mutable default argument.
    self.training_kwargs = flax.core.freeze(
        {} if training_kwargs is None else training_kwargs
    )

    if variables is not None:
        self.variables = variables
    else:
        self.init(seed, dtype=sampler.dtype)

    if sampler_seed is None and seed is not None:
        # Derive the sampler seed from the parameter seed.
        _, sampler_seed = jax.random.split(nkjax.PRNGKey(seed), 2)

    self._sampler_seed = sampler_seed
    self.sampler = sampler

    if n_samples is not None:
        self.n_samples = n_samples
    else:
        self.n_samples_per_rank = n_samples_per_rank

    self.n_discard_per_chain = n_discard_per_chain
def __init__(
    self,
    lindbladian,
    optimizer,
    *args,
    variational_state=None,
    preconditioner=None,
    sr=None,
    sr_restart=None,
    **kwargs,
):
    """
    Initializes the driver class.

    Args:
        lindbladian: The Lindbladian of the system. Must be an
            `AbstractSuperOperator`.
        optimizer: Determines how optimization steps are performed given the
            bare energy gradient.
        *args: forwarded to :class:`MCMixedState` when `variational_state` is None.
        variational_state: The variational state to optimize. If None, an
            :class:`MCMixedState` is built from `*args` and `**kwargs`.
        preconditioner: Determines which preconditioner to use for the loss gradient.
            This must be a tuple of `(object, solver)` as documented in the section
            `preconditioners` in the documentation. The standard preconditioner
            included with NetKet is Stochastic Reconfiguration. By default, no
            preconditioner is used and the bare gradient is passed to the optimizer.
        sr: DEPRECATED alias of `preconditioner`. Cannot be given together with it.
        sr_restart: DEPRECATED. If given, it is stored as `solver_restart` on the
            preconditioner, which must then be provided.
        **kwargs: forwarded to :class:`MCMixedState` when `variational_state` is None.

    Raises:
        TypeError: if `lindbladian` is not an `AbstractSuperOperator`.
        ValueError: if both `sr` and `preconditioner` are given, or if
            `sr_restart` is given without a preconditioner.
    """
    if variational_state is None:
        variational_state = MCMixedState(*args, **kwargs)

    if not isinstance(lindbladian, AbstractSuperOperator):
        raise TypeError("The first argument must be a super-operator")

    if sr is not None:
        if preconditioner is not None:
            raise ValueError(
                "sr is deprecated in favour of preconditioner kwarg. You should not pass both"
            )
        else:
            preconditioner = sr
            warn_deprecation(
                (
                    "The `sr` keyword argument is deprecated in favour of `preconditioner`. "
                    "Please update your code to `SteadyState(.., preconditioner=your_sr)`"
                )
            )

    if sr_restart is not None:
        if preconditioner is None:
            raise ValueError(
                "sr_restart only makes sense if you have a preconditioner/SR."
            )
        else:
            preconditioner.solver_restart = sr_restart
            warn_deprecation(
                (
                    "The `sr_restart` keyword argument is deprecated in favour of specifying "
                    "`solver_restart` in the constructor of the SR object. "
                    "Please update your code to `SteadyState(.., preconditioner=nk.optimizer.SR(..., solver_restart=True/False))`"
                )
            )

    # move as kwarg once deprecations are removed
    if preconditioner is None:
        preconditioner = identity_preconditioner

    super().__init__(variational_state, optimizer, minimized_quantity_name="LdagL")

    self._lind = lindbladian
    self._ldag_l = Squared(lindbladian)

    self.preconditioner = preconditioner

    self._dp = None
    self._S = None
    self._sr_info = None
def DenseEquivariant(
    symmetries,
    features: int = None,
    mode="auto",
    shape=None,
    point_group=None,
    in_features=None,
    **kwargs,
):
    r"""A group convolution operation that is equivariant over a symmetry group.

    Acts on a feature map of symmetry poses of shape
    [num_samples, in_features, num_symm]
    and returns a feature map of poses of shape
    [num_samples, features, num_symm]

    G-convolutions are described in
    `Cohen et al. <http://proceedings.mlr.press/v48/cohenc16.pdf>`_
    and applied to quantum many-body problems in
    `Roth et al. <https://arxiv.org/pdf/2104.05085.pdf>`_

    The G-convolution generalizes the convolution to non-commuting groups:

    .. math ::

        C^i_g = \sum_h {\bf W}_{g^{-1} h} \cdot {\bf f}_h

    Group elements that differ by the same symmetry operation (i.e.
    :math:`g = xh` and :math:`g' = xh'`) are connected by the same filter.

    This layer maps an input of shape `(..., in_features, num_symm)` to an
    output of shape `(..., features, num_symm)`.

    Args:
        symmetries: A specification of the symmetry group. Can be given by a
            nk.graph.Graph, an nk.utils.PermutationGroup, a list of irreducible
            representations or a product table.
        point_group: The point group, from which the space group is built.
            If symmetries is a graph the default point group is overwritten.
        mode: string "fft, irreps, matrix, auto" specifying whether to use a fast
            fourier transform over the translation group, a fourier transform using
            the irreducible representations or by constructing the full kernel matrix.
        shape: A tuple specifying the dimensions of the translation group.
        features: The number of output features. The full output shape
            is [n_batch,features,n_symm].
        in_features: DEPRECATED and ignored; the number of input features is
            detected from the input.
        use_bias: A bool specifying whether to add a bias to the output (default: True).
        mask: An optional array of shape [n_sites] consisting of ones and zeros
            that can be used to give the kernel a particular shape.
        dtype: The datatype of the weights. Defaults to a 64bit float.
        precision: Optional argument specifying numerical precision of the computation.
            see `jax.lax.Precision` for details.
        kernel_init: Optional kernel initialization function. Defaults to variance scaling.
        bias_init: Optional bias initialization function. Defaults to zero initialization.

    Raises:
        ValueError: if `features` is missing or given twice (also as the
            deprecated `out_features`), if `mode` is incompatible with the
            symmetry specification, or if `symmetries` cannot be interpreted.
        TypeError: if `mode="fft"` is requested without a translation-group shape.
    """
    # deprecate in_features
    if in_features is not None:
        warn_deprecation(
            (
                "`in_features` is now automatically detected from the input and deprecated. "
                "Please remove it when calling `DenseEquivariant`."
            )
        )
    if "out_features" in kwargs:
        warn_deprecation(
            "`out_features` has been renamed to `features` and the old name is "
            "now deprecated. Please update your code."
        )
        if features is not None:
            raise ValueError(
                "You must only specify `features`. `out_features` is deprecated."
            )
        features = kwargs.pop("out_features")

    if features is None:
        raise ValueError("`features` not specified (the number of output features).")

    kwargs["features"] = features

    if isinstance(symmetries, Lattice) and (
        point_group is not None or symmetries._point_group is not None
    ):
        shape = tuple(symmetries.extent)
        # With graph try to find point group, otherwise default to automorphisms
        sg = symmetries.space_group(point_group)
        if mode == "auto":
            mode = "fft"
    elif isinstance(symmetries, Graph):
        sg = symmetries.automorphisms()
        if mode == "auto":
            mode = "irreps"
        elif mode == "fft":
            raise ValueError(
                "When requesting 'mode=fft' a valid point group must be specified "
                "in order to construct the space group"
            )
    elif isinstance(symmetries, PermutationGroup):
        # If we get a group and default to irrep projection
        if mode == "auto":
            mode = "irreps"
        sg = symmetries
    elif isinstance(symmetries, Sequence):
        # A sequence is interpreted as a list of irreducible representations.
        if mode not in ["irreps", "auto"]:
            raise ValueError("Specification of symmetries incompatible with mode")
        return DenseEquivariantIrrep(symmetries, **kwargs)
    else:
        # A square 2D array is interpreted as a product table.
        if symmetries.ndim == 2 and symmetries.shape[0] == symmetries.shape[1]:
            if mode == "irreps":
                raise ValueError("Specification of symmetries incompatible with mode")
            elif mode == "matrix":
                return DenseEquivariantMatrix(symmetries, **kwargs)
            else:
                if shape is None:
                    raise TypeError(
                        "When requesting `mode=fft`, the shape of the translation group must be specified. "
                        "Either supply the `shape` keyword argument or pass a `netket.graph.Graph` object to "
                        "the symmetries keyword argument."
                    )
                else:
                    return DenseEquivariantFFT(symmetries, shape=shape, **kwargs)
        # Previously this exception was *returned* instead of raised, so callers
        # silently received a ValueError instance in place of a layer.
        raise ValueError("Invalid Specification of Symmetries")

    if mode == "fft":
        if shape is None:
            raise TypeError(
                "When requesting `mode=fft`, the shape of the translation group must be specified. "
                "Either supply the `shape` keyword argument or pass a `netket.graph.Graph` object to "
                "the symmetries keyword argument."
            )
        else:
            return DenseEquivariantFFT(
                HashableArray(sg.product_table), shape=shape, **kwargs
            )
    elif mode in ["irreps", "auto"]:
        irreps = tuple(HashableArray(irrep) for irrep in sg.irrep_matrices())
        return DenseEquivariantIrrep(irreps, **kwargs)
    elif mode == "matrix":
        return DenseEquivariantMatrix(HashableArray(sg.product_table), **kwargs)
    else:
        raise ValueError(
            f"Unknown mode={mode}. Valid modes are 'fft', 'matrix', 'irreps' or 'auto'."
        )