class RBMSymm(nn.Module):
    """Restricted Boltzmann Machine symmetrized through :ref:`netket.nn.DenseSymm`.

    The hidden layer is a ``DenseSymm`` layer in ``"matrix"`` mode built over
    ``symmetries``. The activated hidden units are summed, and, optionally, a
    visible-bias term proportional to the sum of the inputs is added.
    """

    symmetries: Union[HashableArray, PermutationGroup]
    """Group of symmetry operations (or array of permutation indices) under which the layer is invariant.

    Numpy/Jax arrays must be wrapped into an :class:`netket.utils.HashableArray`.
    """
    dtype: Any = np.float64
    """Dtype of the trainable weights."""
    activation: Any = nknn.log_cosh
    """Nonlinearity applied to the hidden units."""
    alpha: Union[float, int] = 1
    """Feature density: the total number of hidden units is about alpha * input.shape[-1]."""
    use_hidden_bias: bool = True
    """Whether the hidden (DenseSymm) layer carries a bias."""
    use_visible_bias: bool = True
    """Whether to add a bias term on the inputs that bypasses the nonlinearity."""
    precision: Any = None
    """Numerical precision of the computation; see `jax.lax.Precision` for details."""
    kernel_init: NNInitFunc = normal(stddev=0.1)
    """Initializer for the DenseSymm kernel."""
    hidden_bias_init: NNInitFunc = normal(stddev=0.1)
    """Initializer for the hidden bias."""
    visible_bias_init: NNInitFunc = normal(stddev=0.1)
    """Initializer for the visible bias."""

    def setup(self):
        # One row per symmetry operation, one column per site.
        symm_table = np.asarray(self.symmetries)
        self.n_symm, self.n_sites = symm_table.shape
        # Features per symmetry operation, so that features * n_symm ~ alpha * n_sites.
        self.features = int(self.alpha * self.n_sites / self.n_symm)

        too_small = self.alpha > 0 and self.features == 0
        if too_small:
            raise ValueError(
                f"RBMSymm: alpha={self.alpha} is too small "
                f"for {self.n_symm} permutations, alpha ≥ {self.n_symm / self.n_sites} is needed."
            )

    @nn.compact
    def __call__(self, x_in):
        dense = nknn.DenseSymm(
            name="Dense",
            mode="matrix",
            symmetries=self.symmetries,
            features=self.features,
            dtype=self.dtype,
            use_bias=self.use_hidden_bias,
            kernel_init=self.kernel_init,
            bias_init=self.hidden_bias_init,
            precision=self.precision,
        )
        activated = self.activation(dense(x_in))

        # Group the activations into rows of features * n_symm hidden units
        # and sum each row.
        n_hidden = self.features * self.n_symm
        result = jnp.sum(activated.reshape(-1, n_hidden), axis=-1)

        if not self.use_visible_bias:
            return result

        visible_bias = self.param(
            "visible_bias", self.visible_bias_init, (1,), self.dtype
        )
        return result + visible_bias[0] * jnp.sum(x_in, axis=-1)
class RBMSymm(nn.Module):
    """Restricted Boltzmann Machine symmetrized through :ref:`netket.nn.DenseSymm`.

    The hidden layer is a ``DenseSymm`` layer built from ``permutations``. The
    activated hidden units are summed over the last axis, and, optionally, a
    visible-bias term proportional to the sum of the inputs is added.
    """

    permutations: Callable[[], Array]
    """Callable returning the permutation table; see :ref:`netket.nn.DenseSymm`."""
    dtype: Any = np.float64
    """Dtype of the trainable weights."""
    activation: Any = nknn.logcosh
    """Nonlinearity applied to the hidden units."""
    alpha: Union[float, int] = 1
    """Feature density: the total number of hidden units is about alpha * input.shape[-1]."""
    use_hidden_bias: bool = True
    """Whether the hidden (DenseSymm) layer carries a bias."""
    use_visible_bias: bool = True
    """Whether to add a bias term on the inputs that bypasses the nonlinearity."""
    precision: Any = None
    """Numerical precision of the computation; see `jax.lax.Precision` for details."""
    kernel_init: NNInitFunc = normal(stddev=0.1)
    """Initializer for the DenseSymm kernel."""
    hidden_bias_init: NNInitFunc = normal(stddev=0.1)
    """Initializer for the hidden bias."""
    visible_bias_init: NNInitFunc = normal(stddev=0.1)
    """Initializer for the visible bias."""

    def setup(self):
        # The permutation table has one row per symmetry operation and one
        # column per site.
        perm_table = self.permutations()
        self.n_symm, self.n_sites = perm_table.shape
        # Features per symmetry operation, so that features * n_symm ~ alpha * n_sites.
        self.features = int(self.alpha * self.n_sites / self.n_symm)

        too_small = self.alpha > 0 and self.features == 0
        if too_small:
            raise ValueError(
                f"RBMSymm: alpha={self.alpha} is too small "
                f"for {self.n_symm} permutations, alpha ≥ {self.n_symm / self.n_sites} is needed."
            )

    @nn.compact
    def __call__(self, x_in):
        dense = nknn.DenseSymm(
            name="Dense",
            permutations=self.permutations,
            features=self.features,
            dtype=self.dtype,
            use_bias=self.use_hidden_bias,
            kernel_init=self.kernel_init,
            bias_init=self.hidden_bias_init,
            precision=self.precision,
        )
        result = jnp.sum(self.activation(dense(x_in)), axis=-1)

        if not self.use_visible_bias:
            return result

        visible_bias = self.param(
            "visible_bias", self.visible_bias_init, (1,), self.dtype
        )
        return result + visible_bias[0] * jnp.sum(x_in, axis=-1)
class Gaussian(nn.Module):
    r"""
    Multivariate Gaussian function with mean 0 and a parametrised matrix
    :math:`\Sigma_{ij}`.

    The wavefunction is given by the formula:
    :math:`\Psi(x) = \exp(-\frac{1}{2}\sum_{ij} x_i \Sigma_{ij} x_j)`,
    and ``__call__`` returns the exponent
    :math:`-\frac{1}{2}\sum_{ij} x_i \Sigma_{ij} x_j`. In this form
    :math:`\Sigma` plays the role of the inverse covariance (precision) matrix.

    The positive semi-definite matrix :math:`\Sigma = A^T A` is not stored
    directly. The trainable parameter ``"kernel"`` is an unconstrained square
    matrix :math:`A`, from which :math:`\Sigma` is rebuilt on every call.
    """

    dtype: DType = jnp.float64
    """The dtype of the weights."""
    kernel_init: NNInitFunc = normal(stddev=1.0)
    """Initializer for the weights."""

    @nn.compact
    def __call__(self, x_in: Array):
        nv = x_in.shape[-1]

        # Promote the input to a dtype compatible with the weights.
        dtype = jnp.promote_types(x_in.dtype, self.dtype)
        x_in = jnp.asarray(x_in, dtype=dtype)

        kernel = self.param("kernel", self.kernel_init, (nv, nv), self.dtype)
        # Sigma = A^T A is symmetric positive semi-definite for any real A,
        # so the parameter itself needs no constraint.
        kernel = jnp.dot(kernel.T, kernel)

        # Quadratic form over the last axis; leading axes are treated as batch.
        y = -0.5 * jnp.einsum("...i,ij,...j", x_in, kernel, x_in)
        return y
class Jastrow(nn.Module):
    r"""Jastrow wave function :math:`\Psi(s) = \exp(\sum_{ij} s_i W_{ij} s_j)`.

    ``__call__`` returns the exponent :math:`\sum_{ij} s_i W_{ij} s_j`. The
    matrix :math:`W` (parameter ``"kernel"``) is used exactly as stored,
    without symmetrization. Only its symmetric part :math:`(W + W^T)/2`
    affects the result.
    """

    dtype: DType = jnp.complex128
    """The dtype of the weights."""
    kernel_init: NNInitFunc = normal()
    """Initializer for the weights."""

    @nn.compact
    def __call__(self, x_in: Array):
        nv = x_in.shape[-1]

        # Promote the input to a dtype compatible with the (possibly complex) weights.
        dtype = jnp.promote_types(x_in.dtype, self.dtype)
        x_in = jnp.asarray(x_in, dtype=dtype)

        kernel = self.param("kernel", self.kernel_init, (nv, nv), self.dtype)

        # Quadratic form over the last axis; leading axes are treated as batch.
        y = jnp.einsum("...i,ij,...j", x_in, kernel, x_in)
        return y
class Jastrow(nn.Module):
    r"""
    Jastrow wave function :math:`\Psi(s) = \exp(\sum_{ij} s_i W_{ij} s_j)`.

    The trainable ``"kernel"`` parameter is an unconstrained (non-symmetric)
    square matrix. The :math:`W` in the formula is its symmetrized form
    ``kernel + kernel.T``, which is rebuilt on every call. ``__call__`` returns
    the exponent of the formula.
    """

    dtype: DType = jnp.complex128
    """Dtype of the trainable weights."""
    kernel_init: NNInitFunc = normal()
    """Initializer for the raw (unsymmetrized) kernel."""

    @nn.compact
    def __call__(self, x_in: Array):
        n_visible = x_in.shape[-1]
        common_dtype = jnp.promote_types(x_in.dtype, self.dtype)
        x = jnp.asarray(x_in, dtype=common_dtype)

        raw_kernel = self.param(
            "kernel", self.kernel_init, (n_visible, n_visible), self.dtype
        )
        symmetric_kernel = raw_kernel + raw_kernel.T

        # Quadratic form over the last axis; leading axes act as batch.
        return jnp.einsum("...i,ij,...j", x, symmetric_kernel, x)
from typing import Any, Callable, Iterable, Optional, Tuple, Union import flax from flax.linen.module import Module, compact from jax import lax import jax.numpy as jnp import numpy as np from netket.nn.initializers import normal, zeros PRNGKey = Any Shape = Iterable[int] Dtype = Any # this could be a real type? Array = Any default_kernel_init = normal(stddev=0.01) def _normalize_axes(axes, ndim): # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. return tuple([ax if ax >= 0 else ndim + ax for ax in axes]) def _canonicalize_tuple(x): if isinstance(x, Iterable): return tuple(x) else: return (x,) class DenseGeneral(Module):
import flax
from flax.linen.module import Module, compact
from netket.nn.initializers import lecun_normal, normal, variance_scaling, zeros
from netket import jax as nkjax
from jax import lax
import jax.numpy as jnp
import numpy as np

# Loose type aliases used in annotations.
PRNGKey = Any
Shape = Iterable[int]
Dtype = Any  # NOTE: placeholder; could be narrowed to a concrete dtype type.
Array = Any

# Default kernel initializer: ``normal()`` with its default arguments.
default_kernel_init = normal()


def _normalize_axes(axes, ndim):
    """Convert every axis index in ``axes`` to its non-negative equivalent.

    Negative indices are offset by ``ndim``. The result is a tuple, so its
    length directly gives the number of axes.
    """
    return tuple(ax + ndim if ax < 0 else ax for ax in axes)


def _canonicalize_tuple(x):
    """Return ``x`` as a tuple: iterables are converted, anything else is wrapped."""
    return tuple(x) if isinstance(x, Iterable) else (x,)