Code Example #1
File: simple_cnn.py  Project: google/init2winit
class SimpleCNN(nn.Module):
    """Defines a simple CNN model.

  The model assumes the input shape is [batch, H, W, C].
  """
    num_outputs: int
    num_filters: Sequence[int]
    kernel_sizes: Sequence[int]
    activation_function: str
    kernel_init: model_utils.Initializer = initializers.lecun_normal()
    bias_init: model_utils.Initializer = initializers.zeros

    @nn.compact
    def __call__(self, x, train):
        for num_filters, kernel_size in zip(self.num_filters,
                                            self.kernel_sizes):
            x = nn.Conv(num_filters, (kernel_size, kernel_size), (1, 1),
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)(x)
            x = model_utils.ACTIVATIONS[self.activation_function](x)
        x = jnp.reshape(x, (x.shape[0], -1))
        x = nn.Dense(self.num_outputs,
                     kernel_init=self.kernel_init,
                     bias_init=self.bias_init)(x)
        return x
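A minimal usage sketch (not part of the original file) shows how such a Linen module is initialized and applied; the hyperparameter values and the 'relu' key into model_utils.ACTIVATIONS are assumptions.

import jax
import jax.numpy as jnp

# Hypothetical configuration; adjust to your dataset.
model = SimpleCNN(num_outputs=10,
                  num_filters=[32, 64],
                  kernel_sizes=[3, 3],
                  activation_function='relu')  # assumes 'relu' is a valid ACTIVATIONS key
rng = jax.random.PRNGKey(0)
dummy_batch = jnp.ones((8, 32, 32, 3))  # [batch, H, W, C]
variables = model.init(rng, dummy_batch, train=False)
logits = model.apply(variables, dummy_batch, train=False)  # shape (8, 10)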
Code Example #2
class WideResnetGroup(nn.Module):
    """Defines a WideResnetGroup."""
    blocks_per_group: int
    channels: int
    strides: Tuple[int, int] = (1, 1)
    conv_kernel_init: model_utils.Initializer = initializers.lecun_normal()
    normalizer: str = 'batch_norm'
    dropout_rate: float = 0.0
    activation_function: str = 'relu'
    batch_size: Optional[int] = None
    virtual_batch_size: Optional[int] = None
    total_batch_size: Optional[int] = None

    @nn.compact
    def __call__(self, x, train):
        for i in range(self.blocks_per_group):
            x = WideResnetBlock(self.channels,
                                self.strides if i == 0 else (1, 1),
                                conv_kernel_init=self.conv_kernel_init,
                                normalizer=self.normalizer,
                                dropout_rate=self.dropout_rate,
                                activation_function=self.activation_function,
                                batch_size=self.batch_size,
                                virtual_batch_size=self.virtual_batch_size,
                                total_batch_size=self.total_batch_size)(
                                    x, train=train)
        return x
Code Example #3
File: wide_resnet.py  Project: cshallue/init2winit
    def apply(
        self,
        x,
        blocks_per_group,
        channel_multiplier,
        num_outputs,
        conv_kernel_init=initializers.lecun_normal(),
        dense_kernel_init=initializers.lecun_normal(),
        normalizer='batch_norm',
        train=True,
    ):

        x = nn.Conv(x,
                    16, (3, 3),
                    padding='SAME',
                    name='init_conv',
                    kernel_init=conv_kernel_init,
                    bias=False)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            16 * channel_multiplier,
                            conv_kernel_init=conv_kernel_init,
                            normalizer=normalizer,
                            train=train)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            32 * channel_multiplier, (2, 2),
                            conv_kernel_init=conv_kernel_init,
                            normalizer=normalizer,
                            train=train)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            64 * channel_multiplier, (2, 2),
                            conv_kernel_init=conv_kernel_init,
                            normalizer=normalizer,
                            train=train)
        maybe_normalize = model_utils.get_normalizer(normalizer, train)
        x = maybe_normalize(x)
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs, kernel_init=dense_kernel_init)
        return x
Code Example #4
class WideResnetBlock(nn.Module):
    """Defines a single WideResnetBlock."""
    channels: int
    strides: Tuple[int, int]
    conv_kernel_init: model_utils.Initializer = initializers.lecun_normal()
    normalizer: str = 'batch_norm'
    dropout_rate: float = 0.0
    activation_function: str = 'relu'
    batch_size: Optional[int] = None
    virtual_batch_size: Optional[int] = None
    total_batch_size: Optional[int] = None

    @nn.compact
    def __call__(self, x, train):
        maybe_normalize = model_utils.get_normalizer(
            self.normalizer,
            train,
            batch_size=self.batch_size,
            virtual_batch_size=self.virtual_batch_size,
            total_batch_size=self.total_batch_size)
        y = maybe_normalize(name='bn1')(x)
        y = model_utils.ACTIVATIONS[self.activation_function](y)

        # Apply an up projection in case of channel mismatch
        if (x.shape[-1] != self.channels) or self.strides != (1, 1):
            x = nn.Conv(
                self.channels,
                (1, 1),  # Note: Some implementations use (3, 3) here.
                self.strides,
                padding='SAME',
                kernel_init=self.conv_kernel_init,
                use_bias=False)(y)

        y = nn.Conv(self.channels, (3, 3),
                    self.strides,
                    padding='SAME',
                    name='conv1',
                    kernel_init=self.conv_kernel_init,
                    use_bias=False)(y)
        y = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(y)
        y = maybe_normalize(name='bn2')(y)
        y = model_utils.ACTIVATIONS[self.activation_function](y)
        y = nn.Conv(self.channels, (3, 3),
                    padding='SAME',
                    name='conv2',
                    kernel_init=self.conv_kernel_init,
                    use_bias=False)(y)

        if self.normalizer == 'none':
            y = model_utils.ScalarMultiply()(y)

        return x + y
Code Example #5
File: wide_resnet.py  Project: cshallue/init2winit
 def apply(self,
           x,
           blocks_per_group,
           channels,
           strides=(1, 1),
           conv_kernel_init=initializers.lecun_normal(),
           normalizer='batch_norm',
           train=True):
     for i in range(blocks_per_group):
         x = WideResnetBlock(x,
                             channels,
                             strides if i == 0 else (1, 1),
                             conv_kernel_init=conv_kernel_init,
                             normalizer=normalizer,
                             train=train)
     return x
Code Example #6
File: max_pooling_cnn.py  Project: google/init2winit
class MaxPoolingCNN(nn.Module):
    """Defines a CNN model with max pooling.

  The model assumes the input shape is [batch, H, W, C].
  """
    num_outputs: int
    num_filters: Sequence[int]
    kernel_sizes: Sequence[int]
    kernel_paddings: Sequence[str]
    window_sizes: Sequence[int]
    window_paddings: Sequence[str]
    strides: Sequence[int]
    num_dense_units: Sequence[int]
    activation_fn: Any
    normalizer: str = 'none'
    kernel_init: model_utils.Initializer = initializers.lecun_normal()
    bias_init: model_utils.Initializer = initializers.zeros

    @nn.compact
    def __call__(self, x, train):
        maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
        iterator = zip(self.num_filters, self.kernel_sizes,
                       self.kernel_paddings, self.window_sizes,
                       self.window_paddings, self.strides)
        for num_filters, kernel_size, kernel_padding, window_size, window_padding, stride in iterator:
            x = nn.Conv(num_filters, (kernel_size, kernel_size), (1, 1),
                        padding=kernel_padding,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)(x)
            x = model_utils.ACTIVATIONS[self.activation_fn](x)
            x = maybe_normalize()(x)
            x = nn.max_pool(x,
                            window_shape=(window_size, window_size),
                            strides=(stride, stride),
                            padding=window_padding)
        x = jnp.reshape(x, (x.shape[0], -1))
        for num_units in self.num_dense_units:
            x = nn.Dense(num_units,
                         kernel_init=self.kernel_init,
                         bias_init=self.bias_init)(x)
            x = model_utils.ACTIVATIONS[self.activation_fn](x)
            x = maybe_normalize()(x)
        x = nn.Dense(self.num_outputs,
                     kernel_init=self.kernel_init,
                     bias_init=self.bias_init)(x)
        return x
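Because all of the per-layer sequences are zipped together in __call__, they must have the same length. The following is a hedged configuration sketch; the concrete values and the 'relu' activation key are assumptions, not taken from the project.

import jax
import jax.numpy as jnp

model = MaxPoolingCNN(num_outputs=10,
                      num_filters=[32, 64],
                      kernel_sizes=[3, 3],
                      kernel_paddings=['SAME', 'SAME'],
                      window_sizes=[2, 2],
                      window_paddings=['VALID', 'VALID'],
                      strides=[2, 2],
                      num_dense_units=[256],
                      activation_fn='relu')  # assumed ACTIVATIONS key
variables = model.init(jax.random.PRNGKey(0),
                       jnp.ones((1, 32, 32, 3)),  # [batch, H, W, C]
                       train=False)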
Code Example #7
File: wide_resnet.py  Project: cshallue/init2winit
    def apply(self,
              x,
              channels,
              strides=(1, 1),
              conv_kernel_init=initializers.lecun_normal(),
              normalizer='batch_norm',
              train=True):
        maybe_normalize = model_utils.get_normalizer(normalizer, train)
        y = maybe_normalize(x, name='bn1')
        y = jax.nn.relu(y)

        # Apply an up projection in case of channel mismatch
        if (x.shape[-1] != channels) or strides != (1, 1):
            x = nn.Conv(
                y,
                channels,
                (1, 1),  # Note: Some implementations use (3, 3) here.
                strides,
                padding='SAME',
                kernel_init=conv_kernel_init,
                bias=False)

        y = nn.Conv(y,
                    channels, (3, 3),
                    strides,
                    padding='SAME',
                    name='conv1',
                    kernel_init=conv_kernel_init,
                    bias=False)
        y = maybe_normalize(y, name='bn2')
        y = jax.nn.relu(y)
        y = nn.Conv(y,
                    channels, (3, 3),
                    padding='SAME',
                    name='conv2',
                    kernel_init=conv_kernel_init,
                    bias=False)

        if normalizer == 'none':
            y = model_utils.ScalarMultiply(y)

        return x + y
Code Example #8
  def apply(self,
            x,
            num_outputs,
            num_filters,
            kernel_sizes,
            kernel_paddings,
            window_sizes,
            window_paddings,
            strides,
            num_dense_units,
            activation_fn,
            normalizer='none',
            kernel_init=initializers.lecun_normal(),
            bias_init=initializers.zeros,
            train=True):

    maybe_normalize = model_utils.get_normalizer(normalizer, train)
    for num_filters, kernel_size, kernel_padding, window_size, window_padding, stride in zip(
        num_filters, kernel_sizes, kernel_paddings, window_sizes,
        window_paddings, strides):
      x = nn.Conv(
          x,
          num_filters, (kernel_size, kernel_size), (1, 1),
          padding=kernel_padding,
          kernel_init=kernel_init,
          bias_init=bias_init)
      x = model_utils.ACTIVATIONS[activation_fn](x)
      x = maybe_normalize(x)
      x = nn.max_pool(
          x,
          window_shape=(window_size, window_size),
          strides=(stride, stride),
          padding=window_padding)
    x = jnp.reshape(x, (x.shape[0], -1))
    for num_units in num_dense_units:
      x = nn.Dense(
          x, num_units, kernel_init=kernel_init, bias_init=bias_init)
      x = model_utils.ACTIVATIONS[activation_fn](x)
      x = maybe_normalize(x)
    x = nn.Dense(x, num_outputs, kernel_init=kernel_init, bias_init=bias_init)
    return x
Code Example #9
    def apply(self,
              x,
              num_outputs,
              num_filters,
              kernel_sizes,
              activation_function,
              kernel_init=initializers.lecun_normal(),
              bias_init=initializers.zeros,
              train=True):

        for num_filters, kernel_size in zip(num_filters, kernel_sizes):
            x = nn.Conv(x,
                        num_filters, (kernel_size, kernel_size), (1, 1),
                        kernel_init=kernel_init,
                        bias_init=bias_init)
            x = model_utils.ACTIVATIONS[activation_function](x)
        x = jnp.reshape(x, (x.shape[0], -1))
        x = nn.Dense(x,
                     num_outputs,
                     kernel_init=kernel_init,
                     bias_init=bias_init)
        return x
Code Example #10
File: module_test.py  Project: joelgarde/flax
 def __call__(self, x):
   kernel = self.param('kernel',
                       initializers.lecun_normal(),
                       (x.shape[-1], self.features))
   y = jnp.dot(x, kernel)
   return y
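For context, a self-contained sketch of a module built around this __call__ pattern might look as follows; the class name MinimalDense and the surrounding code are hypothetical, not the original test file.

import jax
import jax.numpy as jnp
from flax import linen as nn
from jax.nn import initializers

class MinimalDense(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        # Declare a kernel parameter with a LeCun-normal initializer and no bias.
        kernel = self.param('kernel',
                            initializers.lecun_normal(),
                            (x.shape[-1], self.features))
        return jnp.dot(x, kernel)

y, params = MinimalDense(features=4).init_with_output(
    jax.random.PRNGKey(0), jnp.ones((2, 3)))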
Code Example #11
File: linear.py  Project: tobiaswiener/netket
from typing import Any, Callable, Iterable, Optional, Tuple, Union

import flax
import jax.numpy as jnp
import numpy as np
from flax.linen.module import Module, compact
from jax import lax
from jax.nn.initializers import lecun_normal, zeros

PRNGKey = Any
Shape = Iterable[int]
Dtype = Any  # this could be a real type?
Array = Any

default_kernel_init = lecun_normal()


def _normalize_axes(axes, ndim):
    # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
    return tuple([ax if ax >= 0 else ndim + ax for ax in axes])


def _canonicalize_tuple(x):
    if isinstance(x, Iterable):
        return tuple(x)
    else:
        return (x, )


class DenseGeneral(Module):
Code Example #12
from netket.utils import HashableArray, warn_deprecation
from netket.utils.types import NNInitFunc
from netket.utils.group import PermutationGroup
from netket.graph import Graph, Lattice
from netket.nn.activation import reim_selu
from netket.nn.symmetric_linear import (
    DenseSymmMatrix,
    DenseSymmFFT,
    DenseEquivariantFFT,
    DenseEquivariantIrrep,
)

# Same as netket.nn.symmetric_linear.default_equivariant_initializer
# All GCNN layers have kernels of shape [out_features, in_features, n_symm]
default_gcnn_initializer = lecun_normal(in_axis=1, out_axis=0)


def identity(x):
    return x


class GCNN_FFT(nn.Module):
    r"""Implements a GCNN using a fast fourier transform over the translation group.
    The group convolution can be written in terms of translational convolutions with
    symmetry transformed filters as desribed in ` Cohen et. *al* <http://proceedings.mlr.press/v48/cohenc16.pdf>`_
    The translational convolutions are then implemented with Fast Fourier Transforms.
    """

    symmetries: HashableArray
    """A group of symmetry operations (or array of permutation indices) over which the network should be equivariant.
Code Example #13
File: deepset.py  Project: yannra/netket
class DeepSetRelDistance(nn.Module):
    r"""Implements an equivariant version of the DeepSets architecture
    given by (https://arxiv.org/abs/1703.06114)

    .. math ::

        f(x_1,...,x_N) = \rho\left(\sum_i \phi(x_i)\right)

    that is suitable for the simulation of periodic systems.
    Additionally one can add a cusp condition by specifying the
    asymptotic exponent.
    For helium the Ansatz reads (https://arxiv.org/abs/2112.11957):

    .. math ::

        \psi(x_1,...,x_N) = \rho\left(\sum_i \phi(d_{\sin}(x_i,x_j))\right) \cdot \exp\left[-\frac{1}{2}\left(b/d_{\sin}(x_i,x_j)\right)^5\right]

    """

    hilbert: ContinuousHilbert
    """The hilbert space defining the periodic box where this ansatz is defined."""

    layers_phi: int
    """Number of layers in phi network."""
    layers_rho: int
    """Number of layers in rho network."""

    features_phi: Union[Tuple, int]
    """Number of features in each layer for phi network."""
    features_rho: Union[Tuple, int]
    """
    Number of features in each layer for rho network.
    If specified as a list, the last layer must have 1 feature.
    """

    cusp_exponent: Optional[int] = None
    """exponent of Katos cusp condition"""

    param_dtype: Any = jnp.float64
    """The dtype of the weights."""

    activation: Any = jax.nn.gelu
    """The nonlinear activation function between hidden layers."""

    pooling: Any = jnp.sum
    """The pooling operation to be used after the phi-transformation"""

    use_bias: bool = True
    """if True uses a bias in all layers."""

    kernel_init: NNInitFunc = lecun_normal()
    """Initializer for the Dense layer matrix"""
    bias_init: NNInitFunc = zeros
    """Initializer for the hidden bias"""
    params_init: NNInitFunc = ones
    """Initializer for the parameter in the cusp"""

    def setup(self):

        if not all(self.hilbert.pbc):
            raise ValueError(
                "The DeepSetRelDistance model only works with "
                "hilbert spaces with periodic boundary conditions "
                "among all directions."
            )

        features_phi = self.features_phi
        if isinstance(features_phi, int):
            features_phi = [features_phi] * self.layers_phi

        check_features_length(features_phi, self.layers_phi, "phi")

        features_rho = self.features_rho
        if isinstance(features_rho, int):
            features_rho = [features_rho] * (self.layers_rho - 1) + [1]

        check_features_length(features_rho, self.layers_rho, "rho")
        assert features_rho[-1] == 1

        self.phi = [
            nn.Dense(
                feat,
                use_bias=self.use_bias,
                param_dtype=self.param_dtype,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )
            for feat in features_phi
        ]

        self.rho = [
            nn.Dense(
                feat,
                use_bias=self.use_bias,
                param_dtype=self.param_dtype,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )
            for feat in features_rho
        ]

    def distance(self, x, sdim, L):
        n_particles = x.shape[0] // sdim
        x = x.reshape(-1, sdim)

        dis = -x[jnp.newaxis, :, :] + x[:, jnp.newaxis, :]

        dis = dis[jnp.triu_indices(n_particles, 1)]
        dis = L[jnp.newaxis, :] / 2.0 * jnp.sin(jnp.pi * dis / L[jnp.newaxis, :])
        return dis

    @nn.compact
    def __call__(self, x):
        sha = x.shape
        param = self.param("cusp", self.params_init, (1,), self.param_dtype)

        L = jnp.array(self.hilbert.extent)
        sdim = L.size

        d = jax.vmap(self.distance, in_axes=(0, None, None))(x, sdim, L)
        dis = jnp.linalg.norm(d, axis=-1)

        cusp = 0.0
        if self.cusp_exponent is not None:
            cusp = -0.5 * jnp.sum(param / dis**self.cusp_exponent, axis=-1)

        y = (d / L[jnp.newaxis, :]) ** 2

        """ The phi transformation """
        for layer in self.phi:
            y = self.activation(layer(y))

        """ Pooling operation """
        y = self.pooling(y, axis=-2)

        """ The rho transformation """
        for i, layer in enumerate(self.rho):
            y = layer(y)
            if i == len(self.rho) - 1:
                break
            y = self.activation(y)

        return (y.reshape(-1) + cusp).reshape(sha[0])
Code Example #14
import numpy as np
import jax.numpy as jnp

from jax import lax
from jax.nn.initializers import zeros, lecun_normal
from flax.linen.module import Module, compact

from netket.utils import warn_deprecation
from netket.utils import HashableArray
from netket.utils.types import Array, DType, NNInitFunc
from netket.utils.group import PermutationGroup
from typing import Sequence
from netket.graph import Graph, Lattice

# All layers defined here have kernels of shape [out_features, in_features, n_symm]
default_equivariant_initializer = lecun_normal(in_axis=1, out_axis=0)


def _normalise_mask(mask, new_norm):
    mask = jnp.asarray(mask)
    return mask / jnp.linalg.norm(mask) * new_norm**0.5


def symm_input_warning(x_shape, new_x_shape, name):
    warn_deprecation(
        (f"{len(x_shape)}-dimensional input to {name} layer is deprecated.\n"
         f"Input shape {x_shape} has been reshaped to {new_x_shape}, where "
         "the middle dimension encodes different input channels.\n"
         "Please provide a 3-dimensional input.\nThis warning will become an "
         "error in the future."))
Code Example #15
 def __call__(self, key, shape, dtype=None):
     if dtype is None:
         dtype = "float32"
     initializer_fn = jax_initializers.lecun_normal()
     return initializer_fn(key, shape, dtype)
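A hedged sketch of how such an initializer wrapper might be defined and plugged into a Flax layer; the class name LecunNormalInitializer and the usage are assumptions, not the original file.

import jax
import jax.numpy as jnp
from flax import linen as nn
from jax.nn import initializers as jax_initializers

class LecunNormalInitializer:
    def __call__(self, key, shape, dtype=None):
        if dtype is None:
            dtype = "float32"
        initializer_fn = jax_initializers.lecun_normal()
        return initializer_fn(key, shape, dtype)

# Any callable with a (key, shape, dtype) signature can serve as kernel_init.
layer = nn.Dense(8, kernel_init=LecunNormalInitializer())
params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))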
Code Example #16
class WideResnet(nn.Module):
    """Defines the WideResnet Model."""
    blocks_per_group: int
    channel_multiplier: int
    group_strides: List[Tuple[int, int]]
    num_outputs: int
    conv_kernel_init: model_utils.Initializer = initializers.lecun_normal()
    dense_kernel_init: model_utils.Initializer = initializers.lecun_normal()
    normalizer: str = 'batch_norm'
    dropout_rate: float = 0.0
    activation_function: str = 'relu'
    batch_size: Optional[int] = None
    virtual_batch_size: Optional[int] = None
    total_batch_size: Optional[int] = None

    @nn.compact
    def __call__(self, x, train):
        x = nn.Conv(16, (3, 3),
                    padding='SAME',
                    name='init_conv',
                    kernel_init=self.conv_kernel_init,
                    use_bias=False)(x)
        x = WideResnetGroup(self.blocks_per_group,
                            16 * self.channel_multiplier,
                            self.group_strides[0],
                            conv_kernel_init=self.conv_kernel_init,
                            normalizer=self.normalizer,
                            dropout_rate=self.dropout_rate,
                            activation_function=self.activation_function,
                            batch_size=self.batch_size,
                            virtual_batch_size=self.virtual_batch_size,
                            total_batch_size=self.total_batch_size)(
                                x, train=train)
        x = WideResnetGroup(self.blocks_per_group,
                            32 * self.channel_multiplier,
                            self.group_strides[1],
                            conv_kernel_init=self.conv_kernel_init,
                            normalizer=self.normalizer,
                            dropout_rate=self.dropout_rate,
                            activation_function=self.activation_function,
                            batch_size=self.batch_size,
                            virtual_batch_size=self.virtual_batch_size,
                            total_batch_size=self.total_batch_size)(
                                x, train=train)
        x = WideResnetGroup(self.blocks_per_group,
                            64 * self.channel_multiplier,
                            self.group_strides[2],
                            conv_kernel_init=self.conv_kernel_init,
                            dropout_rate=self.dropout_rate,
                            normalizer=self.normalizer,
                            activation_function=self.activation_function,
                            batch_size=self.batch_size,
                            virtual_batch_size=self.virtual_batch_size,
                            total_batch_size=self.total_batch_size)(
                                x, train=train)
        maybe_normalize = model_utils.get_normalizer(
            self.normalizer,
            train,
            batch_size=self.batch_size,
            virtual_batch_size=self.virtual_batch_size,
            total_batch_size=self.total_batch_size)
        x = maybe_normalize()(x)
        x = model_utils.ACTIVATIONS[self.activation_function](x)
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(self.num_outputs, kernel_init=self.dense_kernel_init)(x)
        return x
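A minimal instantiation sketch (not from the original file); the WRN-28-10 style configuration, the CIFAR-sized input, and the assumption that the batch-norm wrapper uses running averages at evaluation time are all illustrative.

import jax
import jax.numpy as jnp

model = WideResnet(blocks_per_group=4,          # hypothetical WRN-28-10 setup
                   channel_multiplier=10,
                   group_strides=[(1, 1), (2, 2), (2, 2)],
                   num_outputs=10)
dummy_batch = jnp.ones((2, 32, 32, 3))          # CIFAR-sized [batch, H, W, C]
variables = model.init(jax.random.PRNGKey(0), dummy_batch, train=False)
# Assuming batch statistics are not updated in eval mode, a plain apply works:
logits = model.apply(variables, dummy_batch, train=False)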