class SimpleCNN(nn.Module):
  """Defines a simple CNN model.

  The model assumes the input shape is [batch, H, W, C].
  """
  num_outputs: int
  num_filters: Sequence[int]
  kernel_sizes: Sequence[int]
  activation_function: str
  kernel_init: model_utils.Initializer = initializers.lecun_normal()
  bias_init: model_utils.Initializer = initializers.zeros

  @nn.compact
  def __call__(self, x, train):
    for num_filters, kernel_size in zip(self.num_filters, self.kernel_sizes):
      x = nn.Conv(
          num_filters, (kernel_size, kernel_size), (1, 1),
          kernel_init=self.kernel_init,
          bias_init=self.bias_init)(x)
      x = model_utils.ACTIVATIONS[self.activation_function](x)
    x = jnp.reshape(x, (x.shape[0], -1))
    x = nn.Dense(
        self.num_outputs,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init)(x)
    return x
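# A minimal usage sketch (not from the original source). The input shape, the
# filter configuration, and the 'relu' key in model_utils.ACTIVATIONS are
# assumptions made for illustration.
import jax
import jax.numpy as jnp

model = SimpleCNN(
    num_outputs=10,
    num_filters=(32, 64),
    kernel_sizes=(3, 3),
    activation_function='relu')  # assumed key in model_utils.ACTIVATIONS
x = jnp.ones((8, 28, 28, 1))  # [batch, H, W, C]
params = model.init(jax.random.PRNGKey(0), x, train=False)
logits = model.apply(params, x, train=False)  # shape (8, 10)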
class WideResnetGroup(nn.Module):
  """Defines a WideResnetGroup."""
  blocks_per_group: int
  channels: int
  strides: Tuple[int, int] = (1, 1)
  conv_kernel_init: model_utils.Initializer = initializers.lecun_normal()
  normalizer: str = 'batch_norm'
  dropout_rate: float = 0.0
  activation_function: str = 'relu'
  batch_size: Optional[int] = None
  virtual_batch_size: Optional[int] = None
  total_batch_size: Optional[int] = None

  @nn.compact
  def __call__(self, x, train):
    for i in range(self.blocks_per_group):
      x = WideResnetBlock(
          self.channels,
          self.strides if i == 0 else (1, 1),
          conv_kernel_init=self.conv_kernel_init,
          normalizer=self.normalizer,
          dropout_rate=self.dropout_rate,
          activation_function=self.activation_function,
          batch_size=self.batch_size,
          virtual_batch_size=self.virtual_batch_size,
          total_batch_size=self.total_batch_size)(x, train=train)
    return x
# Legacy (pre-Linen) flax.nn style: modules are applied functionally, with the
# input array as the first argument.
def apply(self,
          x,
          blocks_per_group,
          channel_multiplier,
          num_outputs,
          conv_kernel_init=initializers.lecun_normal(),
          dense_kernel_init=initializers.lecun_normal(),
          normalizer='batch_norm',
          train=True):
  x = nn.Conv(
      x,
      16, (3, 3),
      padding='SAME',
      name='init_conv',
      kernel_init=conv_kernel_init,
      bias=False)
  x = WideResnetGroup(
      x,
      blocks_per_group,
      16 * channel_multiplier,
      conv_kernel_init=conv_kernel_init,
      normalizer=normalizer,
      train=train)
  x = WideResnetGroup(
      x,
      blocks_per_group,
      32 * channel_multiplier, (2, 2),
      conv_kernel_init=conv_kernel_init,
      normalizer=normalizer,
      train=train)
  x = WideResnetGroup(
      x,
      blocks_per_group,
      64 * channel_multiplier, (2, 2),
      conv_kernel_init=conv_kernel_init,
      normalizer=normalizer,
      train=train)
  maybe_normalize = model_utils.get_normalizer(normalizer, train)
  x = maybe_normalize(x)
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs, kernel_init=dense_kernel_init)
  return x
class WideResnetBlock(nn.Module):
  """Defines a single WideResnetBlock."""
  channels: int
  strides: Tuple[int, int]
  conv_kernel_init: model_utils.Initializer = initializers.lecun_normal()
  normalizer: str = 'batch_norm'
  dropout_rate: float = 0.0
  activation_function: str = 'relu'
  batch_size: Optional[int] = None
  virtual_batch_size: Optional[int] = None
  total_batch_size: Optional[int] = None

  @nn.compact
  def __call__(self, x, train):
    maybe_normalize = model_utils.get_normalizer(
        self.normalizer,
        train,
        batch_size=self.batch_size,
        virtual_batch_size=self.virtual_batch_size,
        total_batch_size=self.total_batch_size)
    y = maybe_normalize(name='bn1')(x)
    y = model_utils.ACTIVATIONS[self.activation_function](y)

    # Apply an up projection in case of channel mismatch.
    if (x.shape[-1] != self.channels) or self.strides != (1, 1):
      x = nn.Conv(
          self.channels,
          (1, 1),  # Note: Some implementations use (3, 3) here.
          self.strides,
          padding='SAME',
          kernel_init=self.conv_kernel_init,
          use_bias=False)(y)

    y = nn.Conv(
        self.channels, (3, 3),
        self.strides,
        padding='SAME',
        name='conv1',
        kernel_init=self.conv_kernel_init,
        use_bias=False)(y)
    y = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(y)
    y = maybe_normalize(name='bn2')(y)
    y = model_utils.ACTIVATIONS[self.activation_function](y)
    y = nn.Conv(
        self.channels, (3, 3),
        padding='SAME',
        name='conv2',
        kernel_init=self.conv_kernel_init,
        use_bias=False)(y)

    if self.normalizer == 'none':
      y = model_utils.ScalarMultiply()(y)

    return x + y
def apply(self,
          x,
          blocks_per_group,
          channels,
          strides=(1, 1),
          conv_kernel_init=initializers.lecun_normal(),
          normalizer='batch_norm',
          train=True):
  for i in range(blocks_per_group):
    x = WideResnetBlock(
        x,
        channels,
        strides if i == 0 else (1, 1),
        conv_kernel_init=conv_kernel_init,
        normalizer=normalizer,
        train=train)
  return x
class MaxPoolingCNN(nn.Module):
  """Defines a CNN model with max pooling.

  The model assumes the input shape is [batch, H, W, C].
  """
  num_outputs: int
  num_filters: Sequence[int]
  kernel_sizes: Sequence[int]
  kernel_paddings: Sequence[str]
  window_sizes: Sequence[int]
  window_paddings: Sequence[str]
  strides: Sequence[int]
  num_dense_units: Sequence[int]
  activation_fn: str
  normalizer: str = 'none'
  kernel_init: model_utils.Initializer = initializers.lecun_normal()
  bias_init: model_utils.Initializer = initializers.zeros

  @nn.compact
  def __call__(self, x, train):
    maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
    iterator = zip(self.num_filters, self.kernel_sizes, self.kernel_paddings,
                   self.window_sizes, self.window_paddings, self.strides)
    for (num_filters, kernel_size, kernel_padding, window_size, window_padding,
         stride) in iterator:
      x = nn.Conv(
          num_filters, (kernel_size, kernel_size), (1, 1),
          padding=kernel_padding,
          kernel_init=self.kernel_init,
          bias_init=self.bias_init)(x)
      x = model_utils.ACTIVATIONS[self.activation_fn](x)
      x = maybe_normalize()(x)
      x = nn.max_pool(
          x,
          window_shape=(window_size, window_size),
          strides=(stride, stride),
          padding=window_padding)
    x = jnp.reshape(x, (x.shape[0], -1))
    for num_units in self.num_dense_units:
      x = nn.Dense(
          num_units, kernel_init=self.kernel_init, bias_init=self.bias_init)(x)
      x = model_utils.ACTIVATIONS[self.activation_fn](x)
      x = maybe_normalize()(x)
    x = nn.Dense(
        self.num_outputs,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init)(x)
    return x
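# A hypothetical configuration sketch (not from the source): position i across
# the six zipped sequences below defines conv/pool stage i, so all of them
# must have the same length. 'relu' is an assumed ACTIVATIONS key.
model = MaxPoolingCNN(
    num_outputs=10,
    num_filters=(32, 64),
    kernel_sizes=(5, 5),
    kernel_paddings=('SAME', 'SAME'),
    window_sizes=(2, 2),
    window_paddings=('VALID', 'VALID'),
    strides=(2, 2),
    num_dense_units=(128,),
    activation_fn='relu')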
def apply(self,
          x,
          channels,
          strides=(1, 1),
          conv_kernel_init=initializers.lecun_normal(),
          normalizer='batch_norm',
          train=True):
  maybe_normalize = model_utils.get_normalizer(normalizer, train)
  y = maybe_normalize(x, name='bn1')
  y = jax.nn.relu(y)

  # Apply an up projection in case of channel mismatch.
  if (x.shape[-1] != channels) or strides != (1, 1):
    x = nn.Conv(
        y,
        channels,
        (1, 1),  # Note: Some implementations use (3, 3) here.
        strides,
        padding='SAME',
        kernel_init=conv_kernel_init,
        bias=False)

  y = nn.Conv(
      y,
      channels, (3, 3),
      strides,
      padding='SAME',
      name='conv1',
      kernel_init=conv_kernel_init,
      bias=False)
  y = maybe_normalize(y, name='bn2')
  y = jax.nn.relu(y)
  y = nn.Conv(
      y,
      channels, (3, 3),
      padding='SAME',
      name='conv2',
      kernel_init=conv_kernel_init,
      bias=False)

  if normalizer == 'none':
    y = model_utils.ScalarMultiply(y)

  return x + y
def apply(self,
          x,
          num_outputs,
          num_filters,
          kernel_sizes,
          kernel_paddings,
          window_sizes,
          window_paddings,
          strides,
          num_dense_units,
          activation_fn,
          normalizer='none',
          kernel_init=initializers.lecun_normal(),
          bias_init=initializers.zeros,
          train=True):
  maybe_normalize = model_utils.get_normalizer(normalizer, train)
  for (num_filters, kernel_size, kernel_padding, window_size, window_padding,
       stride) in zip(num_filters, kernel_sizes, kernel_paddings, window_sizes,
                      window_paddings, strides):
    x = nn.Conv(
        x,
        num_filters, (kernel_size, kernel_size), (1, 1),
        padding=kernel_padding,
        kernel_init=kernel_init,
        bias_init=bias_init)
    x = model_utils.ACTIVATIONS[activation_fn](x)
    x = maybe_normalize(x)
    x = nn.max_pool(
        x,
        window_shape=(window_size, window_size),
        strides=(stride, stride),
        padding=window_padding)
  x = jnp.reshape(x, (x.shape[0], -1))
  for num_units in num_dense_units:
    x = nn.Dense(x, num_units, kernel_init=kernel_init, bias_init=bias_init)
    x = model_utils.ACTIVATIONS[activation_fn](x)
    x = maybe_normalize(x)
  x = nn.Dense(x, num_outputs, kernel_init=kernel_init, bias_init=bias_init)
  return x
def apply(self,
          x,
          num_outputs,
          num_filters,
          kernel_sizes,
          activation_function,
          kernel_init=initializers.lecun_normal(),
          bias_init=initializers.zeros,
          train=True):
  for num_filters, kernel_size in zip(num_filters, kernel_sizes):
    x = nn.Conv(
        x,
        num_filters, (kernel_size, kernel_size), (1, 1),
        kernel_init=kernel_init,
        bias_init=bias_init)
    x = model_utils.ACTIVATIONS[activation_function](x)
  x = jnp.reshape(x, (x.shape[0], -1))
  x = nn.Dense(x, num_outputs, kernel_init=kernel_init, bias_init=bias_init)
  return x
def __call__(self, x):
  kernel = self.param(
      'kernel', initializers.lecun_normal(), (x.shape[-1], self.features))
  y = jnp.dot(x, kernel)
  return y
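# A self-contained sketch wrapping the __call__ above in a Linen module; the
# class name SimpleDense and the `features` attribute are assumptions made so
# the example runs standalone.
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax.nn import initializers


class SimpleDense(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    # self.param registers 'kernel' in the params collection and initializes
    # it with lecun_normal on the first call.
    kernel = self.param(
        'kernel', initializers.lecun_normal(), (x.shape[-1], self.features))
    return jnp.dot(x, kernel)


variables = SimpleDense(features=3).init(jax.random.PRNGKey(0), jnp.ones((2, 4)))
# variables['params']['kernel'].shape == (4, 3)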
from typing import Any, Callable, Iterable, Optional, Tuple, Union

import flax
import jax.numpy as jnp
import numpy as np
from flax.linen.module import Module, compact
from jax import lax
from jax.nn.initializers import lecun_normal, zeros

PRNGKey = Any
Shape = Iterable[int]
Dtype = Any  # this could be a real type?
Array = Any

default_kernel_init = lecun_normal()


def _normalize_axes(axes, ndim):
  # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
  return tuple([ax if ax >= 0 else ndim + ax for ax in axes])


def _canonicalize_tuple(x):
  if isinstance(x, Iterable):
    return tuple(x)
  else:
    return (x,)


class DenseGeneral(Module):
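# Tiny illustration (not from the source) of the helpers above:
# _normalize_axes maps negative axes to non-negative ones given the rank, and
# _canonicalize_tuple wraps scalars into 1-tuples.
assert _normalize_axes((-1, 0), ndim=3) == (2, 0)
assert _canonicalize_tuple(5) == (5,)
assert _canonicalize_tuple([1, 2]) == (1, 2)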
from netket.utils import HashableArray, warn_deprecation
from netket.utils.types import NNInitFunc
from netket.utils.group import PermutationGroup
from netket.graph import Graph, Lattice
from netket.nn.activation import reim_selu
from netket.nn.symmetric_linear import (
    DenseSymmMatrix,
    DenseSymmFFT,
    DenseEquivariantFFT,
    DenseEquivariantIrrep,
)

# Same as netket.nn.symmetric_linear.default_equivariant_initializer.
# All GCNN layers have kernels of shape [out_features, in_features, n_symm].
default_gcnn_initializer = lecun_normal(in_axis=1, out_axis=0)


def identity(x):
    return x


class GCNN_FFT(nn.Module):
    r"""Implements a GCNN using a fast Fourier transform over the translation group.

    The group convolution can be written in terms of translational convolutions
    with symmetry-transformed filters, as described in
    `Cohen et al. <http://proceedings.mlr.press/v48/cohenc16.pdf>`_.
    The translational convolutions are then implemented with fast Fourier
    transforms.
    """

    symmetries: HashableArray
    """A group of symmetry operations (or array of permutation indices) over
    which the network should be equivariant."""
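# A small sketch (not from the source) of the kernel-shape convention noted
# above: sampling a GCNN kernel of shape [out_features, in_features, n_symm].
# lecun_normal(in_axis=1, out_axis=0) treats axis 1 as the input-feature axis
# and axis 0 as the output-feature axis when computing fan-in (the remaining
# axis counts as receptive-field size). n_symm = 8 is a made-up value.
import jax

key = jax.random.PRNGKey(0)
kernel = default_gcnn_initializer(key, (4, 2, 8))  # [out, in, n_symm]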
class DeepSetRelDistance(nn.Module):
    r"""Implements an equivariant version of the DeepSets architecture given by
    (https://arxiv.org/abs/1703.06114)

    .. math ::

        f(x_1,...,x_N) = \rho\left(\sum_i \phi(x_i)\right)

    that is suitable for the simulation of periodic systems.
    Additionally, one can add a cusp condition by specifying the asymptotic
    exponent.

    For helium the Ansatz reads (https://arxiv.org/abs/2112.11957):

    .. math ::

        \psi(x_1,...,x_N) = \rho\left(\sum_i \phi(d_{\sin}(x_i,x_j))\right) \cdot
        \exp\left[-\frac{1}{2}\left(b/d_{\sin}(x_i,x_j)\right)^5\right]
    """

    hilbert: ContinuousHilbert
    """The hilbert space defining the periodic box where this ansatz is defined."""

    layers_phi: int
    """Number of layers in the phi network."""
    layers_rho: int
    """Number of layers in the rho network."""

    features_phi: Union[Tuple, int]
    """Number of features in each layer of the phi network."""
    features_rho: Union[Tuple, int]
    """Number of features in each layer of the rho network.
    If specified as a list, the last layer must have 1 feature."""

    cusp_exponent: Optional[int] = None
    """Exponent of Kato's cusp condition."""

    param_dtype: Any = jnp.float64
    """The dtype of the weights."""

    activation: Any = jax.nn.gelu
    """The nonlinear activation function between hidden layers."""

    pooling: Any = jnp.sum
    """The pooling operation to be used after the phi-transformation."""

    use_bias: bool = True
    """If True uses a bias in all layers."""

    kernel_init: NNInitFunc = lecun_normal()
    """Initializer for the Dense layer matrix."""
    bias_init: NNInitFunc = zeros
    """Initializer for the hidden bias."""
    params_init: NNInitFunc = ones
    """Initializer for the parameter in the cusp."""

    def setup(self):
        if not all(self.hilbert.pbc):
            raise ValueError(
                "The DeepSetRelDistance model only works with "
                "hilbert spaces with periodic boundary conditions "
                "among all directions."
            )

        features_phi = self.features_phi
        if isinstance(features_phi, int):
            features_phi = [features_phi] * self.layers_phi
        check_features_length(features_phi, self.layers_phi, "phi")

        features_rho = self.features_rho
        if isinstance(features_rho, int):
            features_rho = [features_rho] * (self.layers_rho - 1) + [1]
        check_features_length(features_rho, self.layers_rho, "rho")
        assert features_rho[-1] == 1

        self.phi = [
            nn.Dense(
                feat,
                use_bias=self.use_bias,
                param_dtype=self.param_dtype,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )
            for feat in features_phi
        ]

        self.rho = [
            nn.Dense(
                feat,
                use_bias=self.use_bias,
                param_dtype=self.param_dtype,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )
            for feat in features_rho
        ]

    def distance(self, x, sdim, L):
        n_particles = x.shape[0] // sdim
        x = x.reshape(-1, sdim)

        # Pairwise displacements between all distinct particle pairs.
        dis = -x[jnp.newaxis, :, :] + x[:, jnp.newaxis, :]
        dis = dis[jnp.triu_indices(n_particles, 1)]

        # Periodic (sine) distance in a box of extent L.
        dis = L[jnp.newaxis, :] / 2.0 * jnp.sin(jnp.pi * dis / L[jnp.newaxis, :])
        return dis

    @nn.compact
    def __call__(self, x):
        sha = x.shape

        param = self.param("cusp", self.params_init, (1,), self.param_dtype)

        L = jnp.array(self.hilbert.extent)
        sdim = L.size

        d = jax.vmap(self.distance, in_axes=(0, None, None))(x, sdim, L)
        dis = jnp.linalg.norm(d, axis=-1)

        cusp = 0.0
        if self.cusp_exponent is not None:
            cusp = -0.5 * jnp.sum(param / dis**self.cusp_exponent, axis=-1)

        y = (d / L[jnp.newaxis, :]) ** 2

        # The phi transformation.
        for layer in self.phi:
            y = self.activation(layer(y))

        # Pooling operation.
        y = self.pooling(y, axis=-2)

        # The rho transformation.
        for i, layer in enumerate(self.rho):
            y = layer(y)
            if i == len(self.rho) - 1:
                break
            y = self.activation(y)

        return (y.reshape(-1) + cusp).reshape(sha[0])
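# A standalone check (not from the source) of the periodic "sine" distance
# computed in distance() above: shifting a particle by the box extent L leaves
# |d_sin| unchanged, which is what makes the ansatz periodic.
import jax.numpy as jnp

L = jnp.array([1.0])
x1, x2 = jnp.array([0.1]), jnp.array([0.85])
d_sin = L / 2.0 * jnp.sin(jnp.pi * (x1 - x2) / L)
d_sin_shifted = L / 2.0 * jnp.sin(jnp.pi * ((x1 + L) - x2) / L)
assert jnp.allclose(jnp.abs(d_sin), jnp.abs(d_sin_shifted))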
from typing import Sequence

import numpy as np

import jax.numpy as jnp
from jax import lax
from jax.nn.initializers import zeros, lecun_normal

from flax.linen.module import Module, compact

from netket.utils import warn_deprecation
from netket.utils import HashableArray
from netket.utils.types import Array, DType, NNInitFunc
from netket.utils.group import PermutationGroup
from netket.graph import Graph, Lattice

# All layers defined here have kernels of shape [out_features, in_features, n_symm]
default_equivariant_initializer = lecun_normal(in_axis=1, out_axis=0)


def _normalise_mask(mask, new_norm):
    mask = jnp.asarray(mask)
    return mask / jnp.linalg.norm(mask) * new_norm**0.5


def symm_input_warning(x_shape, new_x_shape, name):
    warn_deprecation(
        f"{len(x_shape)}-dimensional input to {name} layer is deprecated.\n"
        f"Input shape {x_shape} has been reshaped to {new_x_shape}, where "
        "the middle dimension encodes different input channels.\n"
        "Please provide a 3-dimensional input.\nThis warning will become an "
        "error in the future."
    )
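# A quick check (not from the source) of _normalise_mask: the rescaled mask
# has L2 norm sqrt(new_norm), regardless of how many entries are nonzero.
import jax.numpy as jnp

mask = jnp.array([1.0, 0.0, 1.0, 1.0])
scaled = _normalise_mask(mask, new_norm=4.0)
assert jnp.allclose(jnp.linalg.norm(scaled), 2.0)  # sqrt(4.0)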
def __call__(self, key, shape, dtype=None):
    if dtype is None:
        dtype = "float32"
    initializer_fn = jax_initializers.lecun_normal()
    return initializer_fn(key, shape, dtype)
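# For reference (a sketch, not from the source): assuming the __call__ above
# belongs to a LecunNormal-style initializer wrapper, it is equivalent to
# calling JAX's initializer directly.
import jax
from jax.nn import initializers as jax_initializers

key = jax.random.PRNGKey(0)
w = jax_initializers.lecun_normal()(key, (4, 3), "float32")
# Values are drawn from a truncated normal scaled by sqrt(1 / fan_in).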
class WideResnet(nn.Module):
  """Defines the WideResnet Model."""
  blocks_per_group: int
  channel_multiplier: int
  group_strides: List[Tuple[int, int]]
  num_outputs: int
  conv_kernel_init: model_utils.Initializer = initializers.lecun_normal()
  dense_kernel_init: model_utils.Initializer = initializers.lecun_normal()
  normalizer: str = 'batch_norm'
  dropout_rate: float = 0.0
  activation_function: str = 'relu'
  batch_size: Optional[int] = None
  virtual_batch_size: Optional[int] = None
  total_batch_size: Optional[int] = None

  @nn.compact
  def __call__(self, x, train):
    x = nn.Conv(
        16, (3, 3),
        padding='SAME',
        name='init_conv',
        kernel_init=self.conv_kernel_init,
        use_bias=False)(x)
    x = WideResnetGroup(
        self.blocks_per_group,
        16 * self.channel_multiplier,
        self.group_strides[0],
        conv_kernel_init=self.conv_kernel_init,
        normalizer=self.normalizer,
        dropout_rate=self.dropout_rate,
        activation_function=self.activation_function,
        batch_size=self.batch_size,
        virtual_batch_size=self.virtual_batch_size,
        total_batch_size=self.total_batch_size)(x, train=train)
    x = WideResnetGroup(
        self.blocks_per_group,
        32 * self.channel_multiplier,
        self.group_strides[1],
        conv_kernel_init=self.conv_kernel_init,
        normalizer=self.normalizer,
        dropout_rate=self.dropout_rate,
        activation_function=self.activation_function,
        batch_size=self.batch_size,
        virtual_batch_size=self.virtual_batch_size,
        total_batch_size=self.total_batch_size)(x, train=train)
    x = WideResnetGroup(
        self.blocks_per_group,
        64 * self.channel_multiplier,
        self.group_strides[2],
        conv_kernel_init=self.conv_kernel_init,
        normalizer=self.normalizer,
        dropout_rate=self.dropout_rate,
        activation_function=self.activation_function,
        batch_size=self.batch_size,
        virtual_batch_size=self.virtual_batch_size,
        total_batch_size=self.total_batch_size)(x, train=train)
    maybe_normalize = model_utils.get_normalizer(
        self.normalizer,
        train,
        batch_size=self.batch_size,
        virtual_batch_size=self.virtual_batch_size,
        total_batch_size=self.total_batch_size)
    x = maybe_normalize()(x)
    x = model_utils.ACTIVATIONS[self.activation_function](x)
    x = nn.avg_pool(x, (8, 8))
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(self.num_outputs, kernel_init=self.dense_kernel_init)(x)
    return x
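# A hypothetical construction sketch (not from the source): the WRN-28-10
# hyperparameters and the CIFAR-10 input shape below are assumptions.
import jax
import jax.numpy as jnp

model = WideResnet(
    blocks_per_group=4,  # depth 28 = 6 * 4 + 4
    channel_multiplier=10,  # width factor
    group_strides=[(1, 1), (2, 2), (2, 2)],
    num_outputs=10)
x = jnp.ones((2, 32, 32, 3))
# With the default 'batch_norm' normalizer, init also creates batch statistics.
variables = model.init(jax.random.PRNGKey(0), x, train=False)
logits = model.apply(variables, x, train=False)  # shape (2, 10)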