class Dense(Module):
  """A linear transformation applied over the last dimension of the input."""
  features: int
  use_bias: bool = True
  kernel_init: Callable = initializers.lecun_normal()
  bias_init: Callable = initializers.zeros
  dtype: Any = jnp.float32
  precision: Any = None

  @compact
  def __call__(self, inputs):
    inputs = jnp.asarray(inputs, self.dtype)
    kernel = self.param('kernel', self.kernel_init,
                        (inputs.shape[-1], self.features))
    kernel = jnp.asarray(kernel, self.dtype)
    # Contract the last axis of `inputs` with the first axis of `kernel`.
    y = lax.dot_general(inputs, kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())),
                        precision=self.precision)
    if self.use_bias:
      bias = self.param('bias', self.bias_init, (self.features,))
      bias = jnp.asarray(bias, self.dtype)
      y = y + bias
    return y
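
# --- Usage sketch (illustrative, not part of the original module) -----------
# Assumes `Dense` above follows the flax.linen Module API, so `init`/`apply`
# are available; the helper name, shapes, and values below are hypothetical.
def _dense_usage_example():
  import jax
  import jax.numpy as jnp

  x = jnp.ones((1, 8))
  model = Dense(features=4)
  # `init` creates the 'kernel' param of shape (8, 4) and, since
  # use_bias=True by default, the 'bias' param of shape (4,).
  variables = model.init(jax.random.PRNGKey(0), x)
  y = model.apply(variables, x)  # y has shape (1, 4)
  return y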
"""Linear modules.""" from collections.abc import Iterable # pylint: disable=g-importing-member from flax.nn import initializers from flax.core import Scope from flax import struct from jax import lax import jax.numpy as jnp import numpy as onp default_kernel_init = initializers.lecun_normal() def _normalize_axes(axes, ndim): # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. return tuple([ax if ax >= 0 else ndim + ax for ax in axes]) def dense_general(scope, inputs, features, axis=-1, batch_dims=(), bias=True, dtype=jnp.float32, kernel_init=default_kernel_init,
def get_conv(activation_f, bias_scale, weight_norm, compensate_padding,
             normalization):
  """Create Conv constructor based on architecture settings."""
  if weight_norm == 'fixed':
    conv = conv_layers.ConvFixedScale
  elif weight_norm == 'learned_b':
    conv = conv_layers.ConvLearnedScale
  elif weight_norm == 'ws_sqrt':
    conv = conv_layers.ConvWS.partial(kaiming_scaling=True)
  elif weight_norm == 'ws':
    conv = conv_layers.ConvWS.partial(kaiming_scaling=False)
  elif weight_norm == 'learned':
    conv = conv_layers.Conv
  elif weight_norm in ['none']:
    conv = conv_layers.Conv
  else:
    raise ValueError('weight_norm invalid option %s' % weight_norm)
  conv = conv.partial(compensate_padding=compensate_padding)

  if normalization in ['bn', 'bn_sync', 'gn_16', 'gn_32', 'gn_4', 'frn']:
    bias = False
  elif normalization in ['none']:
    bias = True
  else:
    raise ValueError('Unknown normalization option %s' % normalization)

  bias_init = None
  # TODO(basv): refactor to use config.activation_f.
  if activation_f.__name__ in [
      'relu', 'tlu', 'none', 'tldu', 'tlduz', 'tlum', 'relu_unitvar', 'swish',
  ]:
    kernel_init = initializers.kaiming_normal()
    bias_init = jax.nn.initializers.normal(bias_scale)
  elif activation_f.__name__ in [
      'bias_relu_norm', 'bias_SELU_norm', 'SELU_norm_rebias',
      'bias_scale_relu_norm', 'bias_scale_SELU_norm', 'bias_scale_SELU_norm_gb'
  ]:
    # TODO(basv): parametrize normalized initialization using lecun_normed().
    kernel_init = initializers.lecun_normal()
    bias = False
  elif activation_f.__name__ in [
      'selu', 'relu_norm', 'capped', 'evonorm_s0', 'evonorm_b0'
  ]:
    kernel_init = initializers.lecun_normal()
    bias_init = jax.nn.initializers.normal(bias_scale)
  elif activation_f.__name__ == 'tanh':
    # Scale according to:
    # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#calculate_gain
    kernel_init = initializers.variance_scaling((5.0 / 3)**2, 'fan_in',
                                                'truncated_normal')
    bias_init = jax.nn.initializers.normal(bias_scale)
  else:
    raise ValueError('Not prepared for activation_f %s' % activation_f.__name__)

  conv = conv.partial(kernel_init=kernel_init, bias=bias, bias_init=bias_init)
  return conv
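
# --- Usage sketch (illustrative; `conv_layers` and the `.partial` Module API
# come from this repo, everything else below is a hypothetical example) ------
def _get_conv_usage_example():
  import jax.numpy as jnp

  def relu(x):  # stand-in activation; get_conv dispatches on __name__ == 'relu'
    return jnp.maximum(x, 0)

  conv = get_conv(
      activation_f=relu,
      bias_scale=1e-4,
      weight_norm='none',        # selects plain conv_layers.Conv
      compensate_padding=True,
      normalization='bn')        # bias disabled when a norm layer is used
  # `conv` is a partially-applied constructor: kernel_init, bias and bias_init
  # are already bound; features, kernel_size, etc. are supplied at the call site.
  return conv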