Code example #1
class RBMSymm(nn.Module):
    """A symmetrized RBM using the :ref:`netket.nn.DenseSymm` layer internally."""

    symmetries: Union[HashableArray, PermutationGroup]
    """A group of symmetry operations (or array of permutation indices) over which the layer should be invariant.
    Numpy/Jax arrays must be wrapped into an :class:`netket.utils.HashableArray`.
    """
    dtype: Any = np.float64
    """The dtype of the weights."""
    activation: Any = nknn.log_cosh
    """The nonlinear activation function."""
    alpha: Union[float, int] = 1
    """feature density. Number of features equal to alpha * input.shape[-1]"""
    use_hidden_bias: bool = True
    """if True uses a bias in the dense layer (hidden layer bias)."""
    use_visible_bias: bool = True
    """if True adds a bias to the input not passed through the nonlinear layer."""
    precision: Any = None
    """numerical precision of the computation see `jax.lax.Precision`for details."""

    kernel_init: NNInitFunc = normal(stddev=0.1)
    """Initializer for the Dense layer matrix."""
    hidden_bias_init: NNInitFunc = normal(stddev=0.1)
    """Initializer for the hidden bias."""
    visible_bias_init: NNInitFunc = normal(stddev=0.1)
    """Initializer for the visible bias."""

    def setup(self):
        self.n_symm, self.n_sites = np.asarray(self.symmetries).shape
        self.features = int(self.alpha * self.n_sites / self.n_symm)
        if self.alpha > 0 and self.features == 0:
            raise ValueError(
                f"RBMSymm: alpha={self.alpha} is too small "
                f"for {self.n_symm} permutations, alpha ≥ {self.n_symm / self.n_sites} is needed."
            )

    @nn.compact
    def __call__(self, x_in):
        x = nknn.DenseSymm(
            name="Dense",
            mode="matrix",
            symmetries=self.symmetries,
            features=self.features,
            dtype=self.dtype,
            use_bias=self.use_hidden_bias,
            kernel_init=self.kernel_init,
            bias_init=self.hidden_bias_init,
            precision=self.precision,
        )(x_in)
        x = self.activation(x)

        x = x.reshape(-1, self.features * self.n_symm)
        x = jnp.sum(x, axis=-1)

        if self.use_visible_bias:
            v_bias = self.param("visible_bias", self.visible_bias_init, (1, ),
                                self.dtype)
            out_bias = v_bias[0] * jnp.sum(x_in, axis=-1)
            return x + out_bias
        else:
            return x
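A minimal usage sketch for this variant (the 4-site translation group below is purely hypothetical; imports follow NetKet 3 conventions, and whether `model.init` runs as-is depends on the installed NetKet version providing a `nknn.DenseSymm(mode="matrix")` with matching input conventions):

import numpy as np
import jax
import jax.numpy as jnp
from netket.utils import HashableArray

# Hypothetical symmetry group: cyclic translations of a 4-site chain,
# given as an (n_symm, n_sites) array of permutation indices.
perms = HashableArray(np.array([np.roll(np.arange(4), k) for k in range(4)]))

model = RBMSymm(symmetries=perms)

# Batch of two spin configurations.
x = jnp.array([[1.0, -1.0, 1.0, -1.0],
               [1.0, 1.0, -1.0, -1.0]])

params = model.init(jax.random.PRNGKey(0), x)  # standard Flax initialization
log_psi = model.apply(params, x)               # log-amplitudes, one per configuration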
Code example #2
File: rbm.py  Project: huihuangvv/netket
class RBMSymm(nn.Module):
    """A symmetrized RBM using the :ref:`netket.nn.DenseSymm` layer internally."""

    permutations: Callable[[], Array]
    """See documentation of :ref:`netket.nn.DenseSymm`."""
    dtype: Any = np.float64
    """The dtype of the weights."""
    activation: Any = nknn.logcosh
    """The nonlinear activation function."""
    alpha: Union[float, int] = 1
    """feature density. Number of features equal to alpha * input.shape[-1]"""
    use_hidden_bias: bool = True
    """if True uses a bias in the dense layer (hidden layer bias)."""
    use_visible_bias: bool = True
    """if True adds a bias to the input not passed through the nonlinear layer."""
    precision: Any = None
    """numerical precision of the computation see `jax.lax.Precision`for details."""

    kernel_init: NNInitFunc = normal(stddev=0.1)
    """Initializer for the Dense layer matrix."""
    hidden_bias_init: NNInitFunc = normal(stddev=0.1)
    """Initializer for the hidden bias."""
    visible_bias_init: NNInitFunc = normal(stddev=0.1)
    """Initializer for the visible bias."""

    def setup(self):
        self.n_symm, self.n_sites = self.permutations().shape
        self.features = int(self.alpha * self.n_sites / self.n_symm)
        if self.alpha > 0 and self.features == 0:
            raise ValueError(
                f"RBMSymm: alpha={self.alpha} is too small "
                f"for {self.n_symm} permutations, alpha ≥ {self.n_symm / self.n_sites} is needed."
            )

    @nn.compact
    def __call__(self, x_in):
        x = nknn.DenseSymm(
            name="Dense",
            permutations=self.permutations,
            features=self.features,
            dtype=self.dtype,
            use_bias=self.use_hidden_bias,
            kernel_init=self.kernel_init,
            bias_init=self.hidden_bias_init,
            precision=self.precision,
        )(x_in)
        x = self.activation(x)
        x = jnp.sum(x, axis=-1)

        if self.use_visible_bias:
            v_bias = self.param(
                "visible_bias", self.visible_bias_init, (1,), self.dtype
            )
            out_bias = v_bias[0] * jnp.sum(x_in, axis=-1)
            return x + out_bias
        else:
            return x
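In this older variant, `permutations` is a zero-argument callable returning the `(n_symm, n_sites)` array of permutation indices instead of a `HashableArray`. A hedged construction sketch (the 4-site translation group below is purely illustrative):

import numpy as np

# Illustrative: cyclic translations of a 4-site chain, exposed as the
# zero-argument callable this field expects.
def chain_translations():
    return np.array([np.roll(np.arange(4), k) for k in range(4)])

model = RBMSymm(permutations=chain_translations)  # alpha=1 -> features = n_sites / n_symm = 1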
Code example #3
class Gaussian(nn.Module):
    r"""
    Multivariate Gaussain function with mean 0 and parametrised covariance matrix
    :math:`\Sigma_{ij}`.

    The wavefunction is given by the formula: :math:`\Psi(x) = \exp(\sum_{ij} x_i \Sigma_{ij} x_j)`.
    The (positive definite) :math:`\Sigma_{ij} = AA^T` matrix is stored as
    non-positive definite matrix A.
    """

    dtype: DType = jnp.float64
    """The dtype of the weights."""
    kernel_init: NNInitFunc = normal(stddev=1.0)
    """Initializer for the weights."""

    @nn.compact
    def __call__(self, x_in: Array):
        nv = x_in.shape[-1]

        dtype = jnp.promote_types(x_in.dtype, self.dtype)
        x_in = jnp.asarray(x_in, dtype=dtype)

        kernel = self.param("kernel", self.kernel_init, (nv, nv), self.dtype)

        # Build the (positive semi-definite) matrix Sigma = A^T A from the stored kernel A.
        kernel = jnp.dot(kernel.T, kernel)

        # Log-amplitude: -1/2 * x^T Sigma x for each configuration in the batch.
        y = -0.5 * jnp.einsum("...i,ij,...j", x_in, kernel, x_in)

        return y
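A quick evaluation sketch using the standard Flax init/apply pattern (the shapes below are illustrative):

import jax
import jax.numpy as jnp

model = Gaussian()

# Batch of 3 configurations of 5 continuous degrees of freedom.
x = jax.random.normal(jax.random.PRNGKey(1), (3, 5))

params = model.init(jax.random.PRNGKey(0), x)
log_psi = model.apply(params, x)  # shape (3,): -1/2 * x^T (A^T A) x per configuration

# Since A^T A is positive semi-definite, every log-amplitude here is <= 0.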
Code example #4
File: jastrow.py  Project: huihuangvv/netket
class Jastrow(nn.Module):
    """Jastrow wave function :math:`\Psi(s) = \exp(\sum_{ij} s_i W_{ij} s_j)`."""

    dtype: DType = jnp.complex128
    """The dtype of the weights."""
    kernel_init: NNInitFunc = normal()
    """Initializer for the weights."""

    @nn.compact
    def __call__(self, x_in: Array):
        nv = x_in.shape[-1]

        dtype = jnp.promote_types(x_in.dtype, self.dtype)
        x_in = jnp.asarray(x_in, dtype=dtype)

        kernel = self.param("kernel", self.kernel_init, (nv, nv), self.dtype)
        y = jnp.einsum("...i,ij,...j", x_in, kernel, x_in)

        return y
Code example #5
class Jastrow(nn.Module):
    r"""
    Jastrow wave function :math:`\Psi(s) = \exp(\sum_{ij} s_i W_{ij} s_j)`.

    The W matrix is stored as a non-symmetric matrix and is symmetrized during
    the computation as :code:`W = W + W.T`.
    """

    dtype: DType = jnp.complex128
    """The dtype of the weights."""
    kernel_init: NNInitFunc = normal()
    """Initializer for the weights."""

    @nn.compact
    def __call__(self, x_in: Array):
        nv = x_in.shape[-1]

        dtype = jnp.promote_types(x_in.dtype, self.dtype)
        x_in = jnp.asarray(x_in, dtype=dtype)

        kernel = self.param("kernel", self.kernel_init, (nv, nv), self.dtype)
        kernel = kernel + kernel.T
        y = jnp.einsum("...i,ij,...j", x_in, kernel, x_in)

        return y
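A minimal sketch of evaluating this symmetrized Jastrow ansatz on ±1 spin configurations (complex128 weights are the default, so the log-amplitude is complex):

import jax
import jax.numpy as jnp

model = Jastrow()

# Batch of two configurations of 4 spins.
sigma = jnp.array([[1.0, -1.0, 1.0, -1.0],
                   [1.0, 1.0, -1.0, -1.0]])

params = model.init(jax.random.PRNGKey(0), sigma)
log_psi = model.apply(params, sigma)  # complex log-amplitudes, shape (2,)

# Because the kernel is symmetrized as W + W.T inside __call__, only the
# symmetric part of the stored parameters affects the wavefunction.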
Code example #6
File: linear.py  Project: kchoo1118/netket
from typing import Any, Callable, Iterable, Optional, Tuple, Union

import flax
from flax.linen.module import Module, compact
from jax import lax
import jax.numpy as jnp
import numpy as np

from netket.nn.initializers import normal, zeros

PRNGKey = Any
Shape = Iterable[int]
Dtype = Any  # this could be a real type?
Array = Any

default_kernel_init = normal(stddev=0.01)


def _normalize_axes(axes, ndim):
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple([ax if ax >= 0 else ndim + ax for ax in axes])


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    else:
        return (x,)


class DenseGeneral(Module):
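The two helpers above are used by `DenseGeneral` (truncated here) to bring axis arguments into a canonical form; a short illustration of what they return:

# Negative axes are wrapped modulo the rank:
_normalize_axes((0, -1), ndim=3)   # -> (0, 2)

# Scalars become one-element tuples, iterables are converted to tuples:
_canonicalize_tuple(-1)            # -> (-1,)
_canonicalize_tuple([0, 1])        # -> (0, 1)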
Code example #7
from typing import Any, Callable, Iterable, Optional, Tuple, Union

import flax
from flax.linen.module import Module, compact
from netket.nn.initializers import lecun_normal, normal, variance_scaling, zeros
from netket import jax as nkjax

from jax import lax
import jax.numpy as jnp
import numpy as np

PRNGKey = Any
Shape = Iterable[int]
Dtype = Any  # this could be a real type?
Array = Any

default_kernel_init = normal()
# complex_kernel_init = lecun_normal()


def _normalize_axes(axes, ndim):
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple([ax if ax >= 0 else ndim + ax for ax in axes])


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    else:
        return (x,)