def EinsumDense(d_input, d_output, use_bias):
  """Returns a reimplementation of the Dense layer, using einsum.

  While this layer is equivalent to Dense, it seems to be faster in decoding
  when used with bias (see decoding_timing_test.py). It can be removed once
  the reason for the difference in decoding speed is better understood.

  Args:
    d_input: Dimensionality of the input tensor.
    d_output: Dimensionality of the output tensor.
    use_bias: Whether to use bias.
  """
  layers = [
      tl.Weights(init.GlorotUniformInitializer(), [d_output, d_input]),
      tl.Fn('EinsumDense',
            lambda kernel, embeds:  # pylint: disable=g-long-lambda
            jnp.einsum('xd,...d->...x', kernel, embeds)),
  ]
  if use_bias:
    layers.extend([
        tl.Weights(init.RandomNormalInitializer(1e-6), [d_output]),
        tl.Add(),
    ])
  return tl.Serial(layers)

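# A minimal usage sketch, not part of the original file: it checks only that
# EinsumDense maps the last axis from d_input to d_output, like tl.Dense
# would. Assumes `numpy` and `trax.shapes` are importable as below.
def _einsum_dense_usage_sketch():
  import numpy as np
  from trax import shapes
  layer = EinsumDense(d_input=4, d_output=8, use_bias=True)
  x = np.ones((2, 4), dtype=np.float32)
  layer.init(shapes.signature(x))
  y = layer(x)
  assert y.shape == (2, 8)  # Last axis mapped from d_input to d_output.
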
def test_custom_initializer_shape(self):
  layer = tl.Weights(
      lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32), (2, 2))
  layer.init(())
  y = layer(())
  self.assertEqual(y.tolist(), [[0., 0.], [0., 0.]])

  layer = tl.Weights(init.RandomNormalInitializer(), (2, 2))
  layer.init(())
  y = layer(())
  self.assertEqual(y.shape, (2, 2))
  self.assertNotEqual(y.tolist(), [[0., 0.], [0., 0.]])

def MultiplicativeSparseDense(sparsity, d_input, d_output=None,
                              use_bias=True, use_bfloat16=False):
  """Returns a replacement for the Dense layer that uses fewer parameters.

  The layer uses a number of modules equal to `sparsity`. It multiplies each
  dimension of the input tensor by a scalar specific to each dimension and
  each module separately; then it applies Dense(d_output/sparsity) to each
  module. Compared to the standard Dense layer, MultiplicativeSparseDense
  uses fewer parameters while still being able to express many interesting
  functions (for example, a permutation).

  Args:
    sparsity: The sparsity of the layer; the output vector is divided into
        this number of modules.
    d_input: Dimensionality of the input tensor.
    d_output: Dimensionality of the output tensor; by default equal to
        d_input.
    use_bias: Whether to use bias.
    use_bfloat16: Whether to use bfloat16 for weights.
  """
  d_output = d_output or d_input  # Apply the default named in the docstring.
  assert d_output % sparsity == 0
  d_module = d_output // sparsity

  layers = [
      # Weight below is used for per-head preprocessing of an embedding.
      tl.Weights(init.RandomNormalInitializer(stddev=0.5),
                 shape=[sparsity, d_input], use_bfloat16=use_bfloat16),
      # Weight below is a dense kernel, shared across heads.
      tl.Weights(init.GlorotUniformInitializer(), [d_input, d_module],
                 use_bfloat16=use_bfloat16),
      # To save memory, the per-head preprocessing and the multiplication by
      # the kernel are done in the same einsum.
      tl.Fn('AttentionEinsum',
            lambda kernel, multiplier, embeds:  # pylint: disable=g-long-lambda
            jnp.einsum('dx,hd,...d->...hx', kernel, multiplier, embeds)),
      MergeLastTwoAxes(),
  ]
  if use_bias:
    layers.extend([
        # Weight below is the bias after dense, per-head.
        tl.Weights(init.RandomNormalInitializer(1e-6), [d_output],
                   use_bfloat16=use_bfloat16),
        tl.Add(),
    ])
  return tl.Serial(layers)

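# A hedged sketch, not from the original file, showing shape behavior and the
# parameter saving. With sparsity=4 and d_input=16, the multiplicative
# weights hold 4*16 = 64 entries and the shared kernel 16*(16/4) = 64, versus
# 16*16 = 256 for a full Dense kernel. `numpy`/`trax.shapes` are assumed.
def _multiplicative_sparse_dense_sketch():
  import numpy as np
  from trax import shapes
  layer = MultiplicativeSparseDense(sparsity=4, d_input=16)
  x = np.ones((2, 16), dtype=np.float32)
  layer.init(shapes.signature(x))
  y = layer(x)
  assert y.shape == (2, 16)  # d_output defaults to d_input.
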
def ResidualZero(*layers, shortcut=None):
  """Wraps a series of layers with a ReZero-style residual connection.

  Instead of computing `(shortcut) + (output of layers)`, as in a classical
  Residual connection, ResidualZero computes
  `(shortcut) + alpha * (output of layers)`, where `alpha` is a learnable
  scalar initialized to zero.

  Args:
    *layers: One or more layers, to be applied in series.
    shortcut: If None (the usual case), the Residual layer computes the
        element-wise sum of the stack-top input with the output of the layer
        series. If specified, the `shortcut` layer applies to a copy of the
        inputs and (elementwise) adds its output to the output from the main
        layer series.

  Returns:
    A layer representing a residual connection paired with a layer series.
  """
  layers = _ensure_flat(layers)
  layer = layers[0] if len(layers) == 1 else tl.Serial(layers)
  # TODO(jaszczur): perhaps change the inner Serial to Branch?
  return tl.Serial(
      tl.Branch(
          shortcut,
          tl.Serial(
              layer,
              # The ReZero scalar `alpha`, initialized to zero.
              tl.Weights(
                  lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32)),
              tl.Multiply())),
      tl.Add(),  # pylint: disable=no-value-for-parameter
  )

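# A hedged usage sketch, not in the original file: since `alpha` is zero at
# initialization, ResidualZero should act as the identity on its input right
# after init. `numpy`/`trax.shapes` are assumed imports.
def _residual_zero_sketch():
  import numpy as np
  from trax import shapes
  block = ResidualZero(tl.Dense(8))
  x = np.ones((2, 8), dtype=np.float32)
  block.init(shapes.signature(x))
  y = block(x)
  assert np.allclose(y, x)  # alpha == 0 at init, so output == shortcut.
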
def MultiplicativeModularSparseDense(sparsity, d_feature):
  """Returns a replacement for the Dense layer that uses fewer parameters.

  The layer uses a number of modules equal to `sparsity`. It is a combination
  of multiplicative dense and locally connected dense layers.

  Args:
    sparsity: The sparsity of the layer; the output vector is divided into
        this number of modules.
    d_feature: Dimensionality of the input and output tensor.
  """
  assert d_feature % sparsity == 0
  d_module = d_feature // sparsity

  return tl.Serial(
      # Weight below is used for per-head preprocessing of an embedding.
      tl.Weights(init.RandomNormalInitializer(stddev=0.5),
                 shape=[sparsity, d_feature]),
      # Weight below is a kernel of multiplicative dense, shared across heads.
      tl.Weights(init.GlorotUniformInitializer(), [d_feature, d_module]),
      # Weight below is a kernel of modular dense.
      tl.Weights(
          functools.partial(init.GlorotUniformInitializer(),
                            nonreceptive_dims=[0]),
          [sparsity, d_module, d_module]),
      # To save memory, the per-head preprocessing and the multiplication by
      # the kernels are done in a single einsum.
      tl.Fn('SparseDenseEinsum',
            lambda kmod, kmult, multiplier, embeds:  # pylint: disable=g-long-lambda
            jnp.einsum('hxo,dx,hd,...d->...ho',
                       kmod, kmult, multiplier, embeds)),
      MergeLastTwoAxes(),
      # Weight below is the bias after dense, per-head.
      tl.Weights(init.RandomNormalInitializer(1e-6), [d_feature]),
      tl.Add(),
  )

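# A hedged shape-check sketch, not in the original file, analogous to the one
# above: input and output dimensionality are both d_feature, with the einsum
# producing (..., sparsity, d_module) before MergeLastTwoAxes flattens it.
def _multiplicative_modular_sparse_dense_sketch():
  import numpy as np
  from trax import shapes
  layer = MultiplicativeModularSparseDense(sparsity=4, d_feature=16)
  x = np.ones((2, 16), dtype=np.float32)
  layer.init(shapes.signature(x))
  y = layer(x)
  assert y.shape == (2, 16)
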
def test_simple_custom_initializer(self):
  layer = tl.Weights(init.RandomNormalInitializer())
  layer.init(())
  y = layer(())
  self.assertEqual(y.shape, ())
  self.assertNotEqual(y.tolist(), 0.)

def test_shape(self):
  layer = tl.Weights(init.RandomNormalInitializer(), (5, 10, 3))
  layer.init(())
  y = layer(())
  self.assertEqual(y.shape, (5, 10, 3))

def test_simple(self):
  layer = tl.Weights(lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32))
  layer.init(())
  y = layer(())
  self.assertEqual(y.tolist(), 0.)