class Dense(Module):
  """A linear transformation applied over the last dimension of the input.

  Attributes:
    features: the number of output features.
    kernel_init: initializer function for the weight matrix.
    bias_init: initializer function for the bias.
    use_bias: whether to add a bias to the output (default: True).
  """
  features: int
  # Use the module-level defaults so all Dense-style layers in this file
  # share one kernel/bias initializer convention.
  kernel_init: Callable = default_kernel_init
  bias_init: Callable = zeros
  use_bias: bool = True

  @compact
  def __call__(self, inputs):
    """Applies a linear transformation to the inputs along the last dimension.

    Args:
      inputs: nd-array to be transformed; the contraction runs over its
        last axis, so any leading (batch) shape is preserved.

    Returns:
      The transformed input, with trailing dimension `self.features`.
    """
    kernel = self.param('kernel', self.kernel_init,
                        (inputs.shape[-1], self.features))
    # Contract the last axis of `inputs` with axis 0 of `kernel`
    # (equivalent to `inputs @ kernel`, but valid for any input rank).
    y = lax.dot_general(inputs, kernel,
                        (((inputs.ndim - 1,), (0,)), ((), ())),)
    if self.use_bias:
      bias = self.param('bias', self.bias_init, (self.features,))
      y = y + bias
    return y
from flax.linen.module import Module, compact from flax.linen.initializers import lecun_normal, variance_scaling, zeros from jax import lax import jax.numpy as jnp import numpy as np PRNGKey = Any Shape = Iterable[int] Dtype = Any # this could be a real type? Array = Any default_kernel_init = lecun_normal() def _normalize_axes(axes, ndim): # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. return tuple(sorted([ax if ax >= 0 else ndim + ax for ax in axes])) def _canonicalize_tuple(x): if isinstance(x, Iterable): return tuple(x) else: return (x,) class DenseGeneral(Module):