Example #1
from typing import Any, Callable

import jax.numpy as jnp
from jax import lax
from flax.linen import Module, compact, initializers


class Dense(Module):
    """A linear transformation applied over the last axis of the input."""
    features: int
    use_bias: bool = True
    kernel_init: Callable = initializers.lecun_normal()
    bias_init: Callable = initializers.zeros
    dtype: Any = jnp.float32
    precision: Any = None

    @compact
    def __call__(self, inputs):
        inputs = jnp.asarray(inputs, self.dtype)
        # Kernel maps the last input axis to `features` outputs.
        kernel = self.param('kernel', self.kernel_init,
                            (inputs.shape[-1], self.features))
        kernel = jnp.asarray(kernel, self.dtype)
        # Contract the last axis of `inputs` with axis 0 of `kernel`;
        # there are no batch dimensions.
        y = lax.dot_general(inputs,
                            kernel, (((inputs.ndim - 1,), (0,)), ((), ())),
                            precision=self.precision)
        if self.use_bias:
            bias = self.param('bias', self.bias_init, (self.features,))
            bias = jnp.asarray(bias, self.dtype)
            y = y + bias
        return y
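
A minimal usage sketch for this module, assuming the flax.linen imports added above; the module is created and run through the standard linen init/apply calls:

import jax

model = Dense(features=4)
x = jnp.ones((2, 8))
variables = model.init(jax.random.PRNGKey(0), x)  # creates 'kernel' and 'bias'
y = model.apply(variables, x)                     # y.shape == (2, 4)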
Example #2
"""Linear modules."""

from collections.abc import Iterable  # pylint: disable=g-importing-member

from flax.nn import initializers

from flax.core import Scope

from flax import struct

from jax import lax

import jax.numpy as jnp
import numpy as onp

default_kernel_init = initializers.lecun_normal()


def _normalize_axes(axes, ndim):
    # Convert negative axis indices to non-negative ones, e.g.
    # _normalize_axes((-1,), ndim=3) == (2,). Returns a tuple by convention,
    # so len(axes) also gives the rank efficiently.
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def dense_general(scope,
                  inputs,
                  features,
                  axis=-1,
                  batch_dims=(),
                  bias=True,
                  dtype=jnp.float32,
                  kernel_init=default_kernel_init,
                  bias_init=initializers.zeros,
                  precision=None):
    """Applies a linear transformation over one or more axes of the inputs."""
    # Simplified sketch of the body, assuming batch_dims=(); the full Flax
    # version also contracts batch dimensions.
    inputs = jnp.asarray(inputs, dtype)
    if not isinstance(axis, Iterable):
        axis = (axis,)
    axis = _normalize_axes(tuple(axis), inputs.ndim)
    kernel_shape = tuple(inputs.shape[ax] for ax in axis) + (features,)
    kernel = jnp.asarray(scope.param('kernel', kernel_init, kernel_shape),
                         dtype)
    contract_ind = tuple(range(len(axis)))
    y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())),
                        precision=precision)
    if bias:
        b = jnp.asarray(scope.param('bias', bias_init, (features,)), dtype)
        y = y + b
    return y
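
A short usage sketch, assuming flax.core's functional init/apply wrappers (exact entry points may differ across Flax versions):

import jax
from flax.core import init, apply

x = jnp.ones((2, 8))
y, params = init(dense_general)(jax.random.PRNGKey(0), x, features=4)
y = apply(dense_general)(params, x, features=4)  # y.shape == (2, 4)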
Example #3
import jax
from flax.nn import initializers

import conv_layers  # assumed project-local module providing the Conv variants


def get_conv(activation_f, bias_scale, weight_norm, compensate_padding,
             normalization):
    """Create a Conv constructor based on architecture settings."""

    if weight_norm == 'fixed':
        conv = conv_layers.ConvFixedScale
    elif weight_norm == 'learned_b':
        conv = conv_layers.ConvLearnedScale
    elif weight_norm == 'ws_sqrt':
        conv = conv_layers.ConvWS.partial(kaiming_scaling=True)
    elif weight_norm == 'ws':
        conv = conv_layers.ConvWS.partial(kaiming_scaling=False)
    elif weight_norm == 'learned':
        conv = conv_layers.Conv
    elif weight_norm in ['none']:
        conv = conv_layers.Conv
    else:
        raise ValueError('Invalid weight_norm option: %s' % weight_norm)

    conv = conv.partial(compensate_padding=compensate_padding)

    if normalization in ['bn', 'bn_sync', 'gn_16', 'gn_32', 'gn_4', 'frn']:
        bias = False
    elif normalization in ['none']:
        bias = True
    else:
        raise ValueError('Unknown normalization option: %s' % normalization)

    bias_init = None
    # TODO(basv): refactor to use config.activation_f.
    if activation_f.__name__ in [
            'relu',
            'tlu',
            'none',
            'tldu',
            'tlduz',
            'tlum',
            'relu_unitvar',
            'swish',
    ]:
        kernel_init = initializers.kaiming_normal()
        bias_init = jax.nn.initializers.normal(bias_scale)
    elif activation_f.__name__ in [
            'bias_relu_norm', 'bias_SELU_norm', 'SELU_norm_rebias',
            'bias_scale_relu_norm', 'bias_scale_SELU_norm',
            'bias_scale_SELU_norm_gb'
    ]:
        # TODO(basv): parametrize normalized initialization using lecun_normed().
        kernel_init = initializers.lecun_normal()
        bias = False
    elif activation_f.__name__ in [
            'selu', 'relu_norm', 'capped', 'evonorm_s0', 'evonorm_b0'
    ]:
        kernel_init = initializers.lecun_normal()
        bias_init = jax.nn.initializers.normal(bias_scale)
    elif activation_f.__name__ == 'tanh':
        # Scale according to:
        # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#calculate_gain
        kernel_init = initializers.variance_scaling((5.0 / 3)**2, 'fan_in',
                                                    'truncated_normal')
        bias_init = jax.nn.initializers.normal(bias_scale)
    else:
        raise ValueError('Unsupported activation_f: %s' %
                         activation_f.__name__)
    conv = conv.partial(kernel_init=kernel_init,
                        bias=bias,
                        bias_init=bias_init)
    return conv
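
A hypothetical call, just to illustrate the parameters; since conv_layers is project-local, the returned constructor depends entirely on that module:

conv = get_conv(activation_f=jax.nn.relu, bias_scale=0.1,
                weight_norm='none', compensate_padding=True,
                normalization='bn')
# conv is a Conv constructor with kernel_init, bias, and bias_init pre-bound;
# with normalization='bn' the bias is disabled.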