Code Example #1
File: sparsity.py  Project: piotrekp1/trax
def EinsumDense(d_input, d_output, use_bias):
    """Returns a reimplementation of Dense layer, using einsum.

  While this is an equivalent of a Dense layer, it seems to be faster when used
  in decoding if used with bias (see decoding_timing_test.py ).
  This layer can be removed when we understand better the reason for the
  difference in decoding speed.

  Args:
    d_input: Dimensionality of the input tensor.
    d_output: Dimensionality of the output tensor.
    use_bias: Whether to use bias.
  """
    layers = [
        tl.Weights(init.GlorotUniformInitializer(), [d_output, d_input]),
        tl.Fn(
            'EinsumDense',
            (
                lambda kernel, embeds:  # pylint: disable=g-long-lambda
                jnp.einsum('xd,...d->...x', kernel, embeds)))
    ]
    if use_bias:
        layers.extend([
            tl.Weights(init.RandomNormalInitializer(1e-6), [d_output]),
            tl.Add()
        ])
    return tl.Serial(layers)
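A minimal usage sketch follows (not part of the original file); the input shapes and the initialization/call pattern via `trax.shapes.signature` are assumptions about how this layer would typically be exercised.

# Usage sketch (assumed shapes; requires trax and the EinsumDense definition above).
import numpy as np
from trax import shapes

layer = EinsumDense(d_input=16, d_output=8, use_bias=True)
x = np.ones((2, 4, 16), dtype=np.float32)  # (batch, length, d_input)
layer.init(shapes.signature(x))            # creates the kernel (and bias) weights
y = layer(x)                               # einsum against the kernel, then bias add
print(y.shape)                             # (2, 4, 8)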
Code Example #2
    def test_custom_initializer_shape(self):
        layer = tl.Weights(
            lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32), (2, 2))
        layer.init(())
        y = layer(())
        self.assertEqual(y.tolist(), [[0., 0.], [0., 0.]])

        layer = tl.Weights(init.RandomNormalInitializer(), (2, 2))
        layer.init(())
        y = layer(())
        self.assertEqual(y.shape, (2, 2))
        self.assertNotEqual(y.tolist(), [[0., 0.], [0., 0.]])
Code Example #3
File: sparsity.py  Project: piotrekp1/trax
def MultiplicativeSparseDense(sparsity,
                              d_input,
                              d_output=None,
                              use_bias=True,
                              use_bfloat16=False):
    """Returns a replacement of Dense layer which uses less parameters.

  The layer uses number of modules equal to `sparsity`. It multiplies each
  dimension of the input tensor by a scalar specific to each dimension and each
  module separately; then it applies Dense(d_output/sparsity) to each module.
  Compared to standard dense layer, MultiplicativeSparseDense uses less
  parameters while still being able to express many interesting functions (for
  example a permutation).

  Args:
    sparsity: The sparsity of the layer; the output vector is divided into this
        number of modules.
    d_input: Dimensionality of input tensor.
    d_output: Dimensionality of output tensor; by default equal to d_input.
    use_bias: Whether to use bias.
    use_bfloat16: Whether to use bfloat16 for weights.
  """

    d_output = d_input if d_output is None else d_output  # default documented above
    assert d_output % sparsity == 0
    d_module = d_output // sparsity

    layers = [
        # Weight below is used for per-head preprocessing of an embedding.
        tl.Weights(init.RandomNormalInitializer(stddev=0.5),
                   shape=[sparsity, d_input],
                   use_bfloat16=use_bfloat16),
        # Weight below is dense kernel, shared across heads.
        tl.Weights(init.GlorotUniformInitializer(), [d_input, d_module],
                   use_bfloat16=use_bfloat16),
        # To save memory the per-head preprocessing and multiplying by the
        # kernel is done in the same einsum.
        tl.Fn(
            'AttentionEinsum',
            (
                lambda kernel, multiplier, embeds:  # pylint: disable=g-long-lambda
                jnp.einsum('dx,hd,...d->...hx', kernel, multiplier, embeds))),
        MergeLastTwoAxes(),
    ]
    if use_bias:
        layers.extend([
            # Weight below is bias after dense, per-head.
            tl.Weights(init.RandomNormalInitializer(1e-6), [d_output],
                       use_bfloat16=use_bfloat16),
            tl.Add(),
        ])
    return tl.Serial(layers)
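A minimal usage sketch (not from the original project): the shapes are illustrative, `sparsity` must divide `d_output`, and `MergeLastTwoAxes` from the same sparsity.py module must be importable.

# Usage sketch (assumed shapes; sparsity must divide d_output).
import numpy as np
from trax import shapes

layer = MultiplicativeSparseDense(sparsity=4, d_input=16, d_output=32)
x = np.ones((2, 6, 16), dtype=np.float32)  # (batch, length, d_input)
layer.init(shapes.signature(x))            # per-module multipliers, shared kernel, bias
y = layer(x)
print(y.shape)                             # (2, 6, 32): 4 modules of size 8, merged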
Code Example #4
def ResidualZero(*layers, shortcut=None):
    """Wraps a series of layers with a ReZero-style residual connection.

  Instead of computing `(shortcut) + (output of layers)`, like in classical
  Residual connection, ResidualZero computes
  `(shortcut) + alpha * (output of layers)`, where `alpha` is a learnable scalar
  initialized with zero.

  Args:
    *layers: One or more layers, to be applied in series.
    shortcut: If None (the usual case), the Residual layer computes the
        element-wise sum of the stack-top input with the output of the layer
        series. If specified, the `shortcut` layer applies to a copy of the
        inputs and (elementwise) adds its output to the output from the main
        layer series.

  Returns:
      A layer representing a residual connection paired with a layer series.
  """
    layers = _ensure_flat(layers)
    layer = layers[0] if len(layers) == 1 else tl.Serial(layers)
    # TODO(jaszczur): perhaps change inner Serial to Branch?
    return tl.Serial(
        tl.Branch(
            shortcut,
            tl.Serial(
                layer,
                tl.Weights(
                    lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32)),
                tl.Multiply())),
        tl.Add(),  # pylint: disable=no-value-for-parameter
    )
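A short sketch of the intended behavior (an assumption based on the docstring, not code from the original file): since `alpha` is initialized to zero, the wrapped block starts out as the identity function.

# Usage sketch: at initialization alpha == 0, so the block acts as the identity.
import numpy as np
from trax import layers as tl
from trax import shapes

block = ResidualZero(tl.Dense(8))
x = np.ones((2, 8), dtype=np.float32)
block.init(shapes.signature(x))
y = block(x)
np.testing.assert_allclose(y, x)  # residual branch is scaled by alpha == 0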
Code Example #5
File: sparsity.py  Project: piotrekp1/trax
def MultiplicativeModularSparseDense(sparsity, d_feature):
    """Returns a replacement of Dense layer which uses less parameters.

  The layer uses number of modules equal to `sparsity`. It is a combination of
  multiplicative dense and locally connected dense layers.

  Args:
    sparsity: The sparsity of the layer; the output vector is divided into this
        number of modules.
    d_feature: Dimensionality of input and output tensor.
  """

    assert d_feature % sparsity == 0
    d_module = d_feature // sparsity

    return tl.Serial(
        # Weight below is used for per-head preprocessing of an embedding.
        tl.Weights(init.RandomNormalInitializer(stddev=0.5),
                   shape=[sparsity, d_feature]),
        # Weight below is a kernel of multiplicative dense, shared across heads.
        tl.Weights(init.GlorotUniformInitializer(), [d_feature, d_module]),
        # Weight below is a kernel of modular dense.
        tl.Weights(
            functools.partial(init.GlorotUniformInitializer(),
                              nonreceptive_dims=[0]),
            [sparsity, d_module, d_module]),
        # To save memory the per-head preprocessing and multiplying by
        # kernels is done in a single einsum.
        tl.Fn(
            'SparseDenseEinsum',
            (
                lambda kmod, kmult, multiplier, embeds:  # pylint: disable=g-long-lambda
                jnp.einsum('hxo,dx,hd,...d->...ho', kmod, kmult, multiplier,
                           embeds))),
        MergeLastTwoAxes(),
        # Weight below is bias after dense, per-head.
        tl.Weights(init.RandomNormalInitializer(1e-6), [d_feature]),
        tl.Add(),
    )
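A minimal usage sketch (assumed shapes; `sparsity` must divide `d_feature`, and the surrounding sparsity.py module must be importable for `MergeLastTwoAxes`).

# Usage sketch (assumed shapes; sparsity must divide d_feature).
import numpy as np
from trax import shapes

layer = MultiplicativeModularSparseDense(sparsity=4, d_feature=16)
x = np.ones((2, 6, 16), dtype=np.float32)  # (batch, length, d_feature)
layer.init(shapes.signature(x))
y = layer(x)
print(y.shape)                             # (2, 6, 16): output keeps d_feature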
Code Example #6
    def test_simple_custom_initializer(self):
        layer = tl.Weights(init.RandomNormalInitializer())
        layer.init(())
        y = layer(())
        self.assertEqual(y.shape, ())
        self.assertNotEqual(y.tolist(), 0.)
Code Example #7
    def test_shape(self):
        layer = tl.Weights(init.RandomNormalInitializer(), (5, 10, 3))
        layer.init(())
        y = layer(())
        self.assertEqual(y.shape, (5, 10, 3))
Code Example #8
    def test_simple(self):
        layer = tl.Weights(
            lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32))
        layer.init(())
        y = layer(())
        self.assertEqual(y.tolist(), 0.)