Example #1
 def _SparsifiableDense(layer_sparsity):
   if layer_sparsity is None:
     return core.Dense(d_feature)
   elif layer_sparsity == 'noop':
     return cb.Serial()  # No-op layer.
   else:
     d_module = d_feature // layer_sparsity
     return cb.Serial(
         sparsity.FactoredDense(layer_sparsity, d_feature, d_feature),
         sparsity.LocallyConvDense(layer_sparsity, d_module, mode=mode,
                                   kernel_size=3, length_kernel_size=3)
     )
Example #2
  def test_dense_param_sharing(self):
    model1 = combinators.Serial(core.Dense(32), core.Dense(32))
    layer = core.Dense(32)
    model2 = combinators.Serial(layer, layer)

    input_signature = ShapeDtype((1, 32))
    params1, _ = model1.initialize_once(input_signature)
    params2, _ = model2.initialize_once(input_signature)
    # The first parameters have 2 kernels of size (32, 32).
    self.assertEqual((32, 32), params1[0][0].shape)
    self.assertEqual((32, 32), params1[1][0].shape)
    # The second parameters have 1 kernel of size (32, 32) and an empty tuple.
    self.assertEqual((32, 32), params2[0][0].shape)
    self.assertEqual((), params2[1])
Example #3
def CountWeights(mask_id=None, has_weights=False):
    """Sum the weights assigned to all elements."""
    if has_weights:
        return cb.Serial(
            cb.Drop(),  # Drop inputs.
            WeightMask(mask_id=mask_id),  # pylint: disable=no-value-for-parameter
            cb.Multiply(),  # Multiply with provided mask.
            core.Sum(axis=None)  # Sum all weights.
        )
    return cb.Serial(
        cb.Drop(),  # Drop inputs.
        WeightMask(mask_id=mask_id),  # pylint: disable=no-value-for-parameter
        core.Sum(axis=None)  # Sum all weights.
    )
Example #4
  def test_dense_param_sharing(self):
    model1 = combinators.Serial(core.Dense(32), core.Dense(32))
    layer = core.Dense(32)
    model2 = combinators.Serial(layer, layer)

    rng1, rng2 = backend.random.split(backend.random.get_prng(0), 2)
    params1, _ = model1.initialize_once((1, 32), onp.float32, rng1)
    params2, _ = model2.initialize_once((1, 32), onp.float32, rng2)
    # The first parameters have 2 kernels of size (32, 32).
    self.assertEqual((32, 32), params1[0][0].shape)
    self.assertEqual((32, 32), params1[1][0].shape)
    # The second parameters have 1 kernel of size (32, 32) and an empty tuple.
    self.assertEqual((32, 32), params2[0][0].shape)
    self.assertEqual((), params2[1])
Example #5
    def test_dense_weight_sharing(self):
        model1 = combinators.Serial(core.Dense(32), core.Dense(32))
        layer = core.Dense(32)
        model2 = combinators.Serial(layer, layer)

        input_signature = ShapeDtype((1, 32))
        weights1, _ = model1.init(input_signature)
        weights2, _ = model2.init(input_signature)
        # The first weights have 2 kernels of size (32, 32).
        self.assertEqual((32, 32), weights1[0][0].shape)
        self.assertEqual((32, 32), weights1[1][0].shape)
        # The second weights have 1 kernel of size (32, 32) and an empty tuple.
        self.assertEqual((32, 32), weights2[0][0].shape)
        self.assertEqual((), weights2[1])
Example #6
def Accuracy(classifier=core.ArgMax()):
    """Returns a layer that computes mean category prediction accuracy."""
    return cb.Serial(classifier,
                     _Accuracy(),
                     _WeightedMean(),
                     name='Accuracy',
                     sublayers_to_print=[])
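The computation this metric performs can be sketched directly in NumPy (a minimal illustration; the array names logits, targets and weights are ours, not part of the Trax API): take the argmax over model outputs, compare with the targets, and average with the per-element weights.

import numpy as np

logits = np.array([[2.0, 0.1, 0.3],   # predicted class 0
                   [0.2, 0.5, 3.0]])  # predicted class 2
targets = np.array([0, 1])
weights = np.array([1.0, 1.0])

predictions = np.argmax(logits, axis=-1)                # what core.ArgMax yields
correct = (predictions == targets).astype(np.float32)   # what _Accuracy yields
accuracy = np.sum(correct * weights) / np.sum(weights)  # _WeightedMean
print(accuracy)  # 0.5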
Example #7
def SequenceAccuracy(classifier=core.ArgMax()):
    """Returns a layer that computes mean sequence prediction accuracy."""
    return cb.Serial(classifier,
                     _Accuracy(),
                     _WeightedSequenceMean(),
                     name='SequenceAccuracy',
                     sublayers_to_print=[])
Example #8
def AssertFunction(specification, layer, message=None):  # pylint: disable=invalid-name
    """AssertFunction asserts shapes on the input/output tensors of a layer.

  It passes all inputs to the layer, and returns all outputs of the layer
  unchanged.

  Args:
    specification: A specification string. See the assert_shape decorator for
        full documentation.
    layer: A base.Layer to wrap around.
    message: An optional message to print if an assert fails. By default it
        prints the filename and the line number where AssertFunction was called.

  Returns:
    The given layer wrapped in asserts on its inputs and outputs.
  """
    if message is None:
        caller = inspect.getframeinfo(inspect.stack()[1][0])
        message = f'Defined at {caller.filename}:{caller.lineno}'
    before_spec, after_spec = specification.split('->')
    before_assert = AssertShape(before_spec,
                                message=message + ' function input')
    after_assert = AssertShape(after_spec,
                               message=message + ' function output')
    after_assert._create_link(before_assert)  # pylint: disable=protected-access
    return combinators.Serial(before_assert, layer, after_assert)
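A hypothetical usage sketch (the import path and the 'bld->bld' specification string are assumptions; see the assert_shape decorator's documentation for the exact grammar): wrapping a shape-preserving layer so that its 3-d input and output are checked on every call.

from trax import layers as tl
from trax.layers.assert_shape import AssertFunction  # import path assumed

# Assert that a (batch, length, depth) input comes out with the same shape.
checked_relu = AssertFunction('bld->bld', tl.Relu())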
Example #9
def SumOfWeights():
    """Returns a layer that computes sum of weights."""
    return cb.Serial(
        cb.Drop(),  # Drop inputs.
        cb.Drop(),  # Drop targets.
        core.Sum(axis=None)  # Sum weights.
    )
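In NumPy terms (a sketch with illustrative values), the layer above discards the (inputs, targets) entries from the stack and reduces the weights to a single scalar:

import numpy as np

weights = np.array([[1.0, 1.0, 0.0],
                    [1.0, 0.0, 0.0]])  # e.g. 1.0 for real tokens, 0.0 for padding
print(np.sum(weights))  # 3.0, the total weight of non-masked elements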
Example #10
def _WeightedMaskedMean(metric_layer, final_layer_override=None):
  """Computes weighted masked mean of metric_layer(predictions, targets)."""
  final_layer = final_layer_override or _WeightedMean()  # For sequence acc.
  return cb.Serial(
      metric_layer,
      final_layer
  )
Example #11
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
    """Returns a layer that maps (activations, mask) to (new_activations, mask).

  This layer type represents one pass of multi-head self-attention, best
  known for its central role in Transformer models. Internally, it:

    - maps incoming sequence of activations to sequence of (query, key, value)
      triples,
    - splits queries, keys, and values into multiple 'heads',
    - computes per-head attention weights from per-head (queries, keys),
    - applies mask to screen out positions that come from padding tokens,
    - [in ``'train'`` mode] applies dropout to attention weights,
    - uses attention weights to combine per-head values vectors, and
    - fuses per-head results into outgoing activations matching original input
      activation shapes.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
  """
    return cb.Serial(
        cb.Select([0, 0, 0]),
        AttentionQKV(d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
    )
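A brief usage sketch with the public trax.layers API (the shapes and the mask layout are illustrative assumptions; the mask shape mirrors what tl.PaddingMask() produces, so that it broadcasts over heads and query positions):

import numpy as np
from trax import layers as tl
from trax import shapes

batch, seq_len, d_feature = 2, 8, 16
activations = np.random.uniform(size=(batch, seq_len, d_feature)).astype(np.float32)
mask = np.ones((batch, 1, 1, seq_len), dtype=np.bool_)  # no padding positions

attention = tl.Attention(d_feature, n_heads=4, mode='eval')
attention.init(shapes.signature((activations, mask)))
new_activations, out_mask = attention((activations, mask))
print(new_activations.shape)  # (2, 8, 16)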
Example #12
 def QKVLayer():
     """Function returning the Q, K and V layer."""
     if use_dconv:
         return cb.Serial(core.Dense(d_feature),
                          convolution.CausalDepthwiseConv())
     else:
         return core.Dense(d_feature)
Example #13
def CrossEntropyLossWithLogSoftmax():
    """Mean prediction-target cross-entropy for multiclass classification."""
    return cb.Serial(core.LogSoftmax(),
                     _CrossEntropy(),
                     _WeightedMean(),
                     name='CrossEntropyLossWithLogSoftmax',
                     sublayers_to_print=[])
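The same computation sketched in NumPy (names and values are illustrative): log-softmax over the model outputs, the negative log-probability of each target class, then a weighted mean.

import numpy as np

logits = np.array([[2.0, 0.5, 0.1],
                   [0.1, 0.2, 3.0]])
targets = np.array([0, 2])
weights = np.array([1.0, 1.0])

log_probs = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))
nll = -log_probs[np.arange(len(targets)), targets]  # per-example cross-entropy
loss = np.sum(nll * weights) / np.sum(weights)      # weighted mean
print(loss)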
Example #14
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train',
                 cache_KV_in_predict=False, q_sparsity=None,
                 result_sparsity=None):
  """Returns a layer that maps `(AQ, AK, AV, mask)` to `(new-A, mask)`.

  Unlike :py:class:`Attention` above, :py:class:`AttentionQKV` allows the
  incoming activations (`AQ`, `AK`, and `AV`) to come from different sources.
  This is used, for instance, in encoder-decoder attention (Q-related
  activations `AQ` from the decoder, K- and V-related activations -- `AK` and
  `AV` -- from the encoder). Otherwise, see the :py:class:`Attention`
  description for further context/details.

  Args:
    d_feature: Last/innermost dimension of activations in the input to and
        output from this layer.
    n_heads: Number of attention heads. Attention heads effectively split
        activation vectors into ``n_heads`` subvectors, of size
        ``d_feature / n_heads``.
    dropout: Probabilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    cache_KV_in_predict: Whether to cache K/V arrays in ``'predict'`` mode.
    q_sparsity: Sparsity with which to process queries. If ``None``,
        :py:class:`Dense` is used; if ``'noop'``, no processing is used.
    result_sparsity: Sparsity with which to process result of the attention.
        If ``None``, :py:class:`Dense` is used; if ``'noop'``, no processing is
        used.
  """
  def _SparsifiableDense(layer_sparsity):
    if layer_sparsity is None:
      return core.Dense(d_feature)
    elif layer_sparsity == 'noop':
      return cb.Serial()  # No-op layer.
    else:
      d_module = d_feature // layer_sparsity
      return cb.Serial(
          sparsity.FactoredDense(layer_sparsity, d_feature, d_feature),
          sparsity.LocallyConvDense(layer_sparsity, d_module, mode=mode,
                                    kernel_size=3, length_kernel_size=3)
      )

  def _CacheableDense():
    if cache_KV_in_predict and mode == 'predict':
      return cb.Cache(core.Dense(d_feature))
    else:
      return core.Dense(d_feature)

  def _PureAttention():
    return PureAttention(n_heads=n_heads, dropout=dropout, mode=mode)

  return cb.Serial(
      cb.Parallel(_SparsifiableDense(q_sparsity),
                  _CacheableDense(),
                  _CacheableDense()),
      _PureAttention(),
      _SparsifiableDense(result_sparsity),
  )
Example #15
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
    """Returns a layer that maps (vectors, mask) to (new_vectors, mask).

  This layer type represents one pass of multi-head self-attention, from vector
  set to vector set, using masks to represent out-of-bound (e.g., padding)
  positions. It:

    - maps the incoming sequence of activation vectors to a sequence of
      (query, key, value) triples,
    - splits queries, keys, and values into multiple 'heads',
    - computes per-head attention weights from per-head (queries, keys),
    - applies mask to screen out positions that come from padding tokens,
    - [in ``'train'`` mode] applies dropout to attention weights,
    - uses attention weights to combine per-head values vectors, and
    - fuses per-head results into outgoing activations matching original input
      activation shapes.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
  """
    return cb.Serial(
        cb.Select([0, 0, 0]),
        AttentionQKV(d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
    )
Example #16
def SRU(n_units, activation=None):
    """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer; the paper
  recommends stacking at least 2 of them, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
    sigmoid_activation = activation_fns.Sigmoid()
    # pylint: disable=no-value-for-parameter
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
        base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation or [],
        base.Fn(lambda c, r, x: c * r + x * (1 - r)))
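A short usage sketch (shapes are illustrative; note that the residual term in equation (5) requires n_units to equal the input depth): the layer maps [batch, length, depth] to [batch, length, n_units], recurring over axis 1.

import numpy as np
from trax import layers as tl
from trax import shapes

x = np.random.uniform(size=(2, 10, 32)).astype(np.float32)
sru = tl.SRU(32)               # n_units matches the input depth
sru.init(shapes.signature(x))
y = sru(x)
print(y.shape)  # (2, 10, 32)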
Example #17
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
    """Transformer-style multi-headed attention.

  Accepts inputs of the form q, k, v, mask.

  Args:
    d_feature: int:  dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result and the mask.
  """
    return cb.Serial(
        cb.Parallel(
            core.Dense(d_feature),
            core.Dense(d_feature),
            core.Dense(d_feature),
        ),
        PureAttention(  # pylint: disable=no-value-for-parameter
            n_heads=n_heads,
            dropout=dropout,
            mode=mode),
        core.Dense(d_feature),
    )
Example #18
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
    """Returns a layer that maps (activations, mask) to (new_activations, mask).

  This layer type represents one pass of multi-head self-attention, best
  known for its central role in Transformer models. Internally, it:

    - maps activations to `(queries, keys, values)` triples,
    - splits `queries`, `keys`, and `values` into multiple 'heads',
    - computes per-head attention weights from per-head `(queries, keys)`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights,
    - uses attention weights to combine per-head `values` vectors, and
    - fuses per-head results into activations matching original input shapes.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: Either 'train' or 'eval'.
  """
    return cb.Serial(
        cb.Dup(),
        cb.Dup(),  # TODO(jonni): replace with Select([0, 0, 0])
        AttentionQKV(d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
    )
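A plain-Python sketch of the data-stack effect of the two cb.Dup() calls above (the list stands in for Trax's data stack, top element first): starting from (activations, mask), two Dup's leave (activations, activations, activations, mask), which is exactly what the Select([0, 0, 0]) mentioned in the TODO would produce.

stack = ['activations', 'mask']
for _ in range(2):
    stack.insert(0, stack[0])  # Dup: copy the top stack element
print(stack)  # ['activations', 'activations', 'activations', 'mask']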
Example #19
def SumOfWeights():
    """Returns a layer to compute sum of weights of all non-masked elements."""
    return cb.Serial(
        cb.Drop(),  # Drop inputs.
        cb.Drop(),  # Drop targets.
        core.Sum(axis=None)  # Sum weights.
    )
Example #20
def GeneralGRUCell(candidate_transform,
                   memory_transform_fn=None,
                   gate_nonlinearity=core.Sigmoid,
                   candidate_nonlinearity=core.Tanh,
                   dropout_rate_c=0.1,
                   sigmoid_bias=0.5):
  r"""Parametrized Gated Recurrent Unit (GRU) cell construction.

  GRU update equations:
  $$ \text{Update gate: } u_t = \mathrm{sigmoid}(U' \cdot s_{t-1} + B') $$
  $$ \text{Reset gate: } r_t = \mathrm{sigmoid}(U'' \cdot s_{t-1} + B'') $$
  $$ \text{Candidate memory: } c_t = \tanh(U \cdot (r_t \odot s_{t-1}) + B) $$
  $$ \text{New state: } s_t = u_t \odot s_{t-1} + (1 - u_t) \odot c_t $$

  See combinators.Gate for details on the gating function.


  Args:
    candidate_transform: Transform to apply inside the Candidate branch. Applied
      before nonlinearities.
    memory_transform_fn: Optional transformation on the memory before gating.
    gate_nonlinearity: Function to use as gate activation. Allows trying
      alternatives to Sigmoid, such as HardSigmoid.
    candidate_nonlinearity: Nonlinearity to apply after the candidate branch.
      Allows trying alternatives to the traditional Tanh, such as HardTanh.
    dropout_rate_c: Amount of dropout on the transform (c) gate. Dropout works
      best in a GRU when applied exclusively to this branch.
    sigmoid_bias: Constant to add before the sigmoid gates. It generally helps
      to start off with a positive bias.

  Returns:
    A model representing a GRU cell with specified transforms.
  """
  gate_block = [  # u_t
      candidate_transform(),
      core.AddConstant(constant=sigmoid_bias),
      gate_nonlinearity(),
  ]
  reset_block = [  # r_t
      candidate_transform(),
      core.AddConstant(constant=sigmoid_bias),  # Want bias to start positive.
      gate_nonlinearity(),
  ]
  candidate_block = [
      cb.Dup(),
      reset_block,
      cb.Multiply(),  # Gate S{t-1} with sigmoid(candidate_transform(S{t-1}))
      candidate_transform(),  # Final projection + tanh to get Ct
      candidate_nonlinearity(),  # Candidate gate

      # Only apply dropout on the C gate. Paper reports 0.1 as a good default.
      core.Dropout(rate=dropout_rate_c)
  ]
  memory_transform = memory_transform_fn() if memory_transform_fn else []
  return cb.Serial(
      cb.Dup(), cb.Dup(),
      cb.Parallel(memory_transform, gate_block, candidate_block),
      cb.Gate(),
  )
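A minimal sketch of how a concrete GRU cell can be assembled from this constructor (this mirrors how trax's own GRUCell is defined, but treat the exact defaults here as an assumption):

from trax import layers as tl

def MyGRUCell(n_units):
  """A GRU cell whose candidate transform is a plain Dense projection."""
  return tl.GeneralGRUCell(
      candidate_transform=lambda: tl.Dense(n_units),
      memory_transform_fn=None,
      gate_nonlinearity=tl.Sigmoid,
      candidate_nonlinearity=tl.Tanh)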
Example #21
def RelativeAttentionLMLayer(d_feature,
                             total_kv_pooling,
                             n_heads=1,
                             dropout=0.0,
                             n_raw_tokens_generated=1,
                             max_inference_length=3072,
                             chunk_len=None,
                             chunk_offset=None,
                             mode='train'):
  """Returns a layer that maps (q, k, v) to (activations).

  Same as the standard relative attention layer, but it additionally prepares
  a mask, based on the sizes of the queries and keys, that masks out the
  future. Masking the future is primarily used for language modelling.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    n_raw_tokens_generated: Number of tokens generated in a single pass through
      this layer. Used only in 'predict' non-training mode.
    max_inference_length: Maximum sequence length allowed in non-training
      modes.
    chunk_len (optional): Number of tokens per chunk. Setting this option will
      enable chunked attention.
    chunk_offset (optional): Offset for shifting chunks, for shifted chunked
      attention.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """

  attention = RelativeAttentionLayer(
      d_feature,
      total_kv_pooling,
      n_heads=n_heads,
      dropout=dropout,
      n_raw_tokens_generated=n_raw_tokens_generated,
      max_inference_length=max_inference_length,
      chunk_len=chunk_len,
      chunk_offset=chunk_offset,
      mode=mode)

  mask_layer = AttentionMaskLayer(
      total_kv_pooling=total_kv_pooling,
      max_inference_length=max_inference_length,
      chunk_len=chunk_len,
      chunk_offset=chunk_offset,
      n_raw_tokens_generated=n_raw_tokens_generated,
      mode=mode)

  return cb.Serial(
      cb.Branch(
          None,
          mask_layer,  # vecs, mask
      ),
      attention,  # vecs, mask
      cb.Select([0], n_in=2),  # vecs
  )
Example #22
def CausalAttention(d_feature, n_heads=1, dropout=0.0,
                    max_inference_length=2048, mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `Attention`, this layer type represents one pass of multi-head
  self-attention, but with causal masking rather than padding-based masking.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: maximum length for inference.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  if d_feature % n_heads != 0:
    raise ValueError(
        f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
        f'of the requested number of attention heads ({n_heads}).')

  d_head = d_feature // n_heads

  def _split_into_heads():
    """Returns a layer that reshapes tensors for multi-headed computation."""
    def f(x):
      batch_size = x.shape[0]
      seq_len = x.shape[1]

      # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
      x = x.reshape((batch_size, seq_len, n_heads, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((-1, seq_len, d_head))
      return x
    return Fn('SplitIntoHeads', f)

  def _merge_heads():
    """Returns a layer that undoes splitting, after multi-head computation."""
    def f(x):
      seq_len = x.shape[1]

      # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
      x = x.reshape((-1, n_heads, seq_len, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((-1, seq_len, n_heads * d_head))
      return x
    return Fn('MergeHeads', f)

  return cb.Serial(
      cb.Branch(
          [core.Dense(d_feature), _split_into_heads()],
          [core.Dense(d_feature), _split_into_heads()],
          [core.Dense(d_feature), _split_into_heads()],
      ),
      DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode),
      _merge_heads(),
      core.Dense(d_feature),
  )
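A NumPy check (values are illustrative) that the _split_into_heads and _merge_heads reshapes above are inverses of each other, i.e. splitting into heads and merging back recovers the original tensor:

import numpy as np

b_size, seq_len, n_heads, d_head = 2, 5, 4, 8
x = np.random.uniform(size=(b_size, seq_len, n_heads * d_head))

split = x.reshape((b_size, seq_len, n_heads, d_head))
split = split.transpose((0, 2, 1, 3)).reshape((-1, seq_len, d_head))

merged = split.reshape((-1, n_heads, seq_len, d_head))
merged = merged.transpose((0, 2, 1, 3)).reshape((-1, seq_len, n_heads * d_head))

assert merged.shape == x.shape and np.allclose(merged, x)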
Example #23
def LSTM(n_units):
    """LSTM running on axis 1."""
    zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
    return cb.Serial(
        cb.Branch([], zero_state),
        cb.Scan(LSTMCell(n_units=n_units), axis=1),
        cb.Select([0], n_in=2)  # Drop RNN state.
    )
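A usage sketch (shapes are illustrative; n_units is chosen equal to the input depth): the layer scans an LSTMCell over axis 1, so a [batch, length, depth] input yields a [batch, length, n_units] output.

import numpy as np
from trax import layers as tl
from trax import shapes

x = np.random.uniform(size=(2, 7, 16)).astype(np.float32)
lstm = tl.LSTM(n_units=16)
lstm.init(shapes.signature(x))
y = lstm(x)
print(y.shape)  # (2, 7, 16)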
Example #24
def RelativeAttentionLayer(d_feature,
                           context_bias_layer,
                           location_bias_layer,
                           total_kv_pooling,
                           separate_cls,
                           n_heads=1,
                           dropout=0.0,
                           mode='train'):
    """Returns a layer that maps (q, k, v, masks) to (activations, masks).

  When the number of keys is smaller than the number of queries, the layer
  works in O(q^2*d). Otherwise it is O(q*k*d). That is because we need to shift
  relative distances by current_pooling; when we upsample, this current pooling
  is a fraction < 1.
  Visual explanation:
  [01][23][45][67] -> [0][1][2][3][4][5][6][7]
  For token [0] we calculate relative distances as follows:
  * 0 2 4 6
  However, for token [1] we need relative distances shifted by 1, specifically:
  * -1 1 3 5
  So we need to calculate not only the distances that correspond to the spacing
  between the keys, but also the ones in between, because there is more than
  one query token (at different positions, hence different relative distances)
  per key token.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    context_bias_layer: Global context bias from Transformer XL's attention.
      There should be one such layer shared by all relative attention layers.
    location_bias_layer: Global location bias from Transformer XL's attention.
      There should be one such layer shared by all relative attention layers.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    separate_cls: True/False whether we treat the cls token separately in
      calculations.

    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """

    return cb.Serial(
        cb.Branch(
            PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling),
            cb.Select([0]), cb.Select([1])),
        cb.Parallel(
            core.Dense(d_feature),
            core.Dense(d_feature),
            core.Dense(d_feature),
            core.Dense(d_feature),
        ),
        context_bias_layer,
        location_bias_layer,
        RelativeAttention(  # pylint: disable=no-value-for-parameter
            separate_cls=separate_cls,
            n_heads=n_heads,
            dropout=dropout,
            mode=mode),
        core.Dense(d_feature),
    )
Example #25
def AttentionResampling(shorten_factor, d_model, is_upsampling, d_ff, n_heads,
                        dropout, dropout_shared_axes, mode, ff_activation,
                        context_bias_layer, location_bias_layer, total_pooling,
                        resampling_fn):
    """Attention resampling."""

    attention = RelativeAttentionLMLayer(d_model,
                                         context_bias_layer,
                                         location_bias_layer,
                                         total_pooling,
                                         n_heads=n_heads,
                                         dropout=dropout,
                                         mode=mode)

    feed_forward = FeedForwardBlock(d_model, d_ff, dropout,
                                    dropout_shared_axes, mode, ff_activation)

    resampling = resampling_fn(shorten_factor, d_model, mode=mode)

    def _Dropout():
        return core.Dropout(rate=dropout,
                            shared_axes=dropout_shared_axes,
                            mode=mode)

    return [
        LayerNorm(),  # h
        cb.Branch(cb.Serial(
            resampling,
            LayerNorm(),
        ), None),  # h', h
        cb.Serial(  # pylint: disable=g-long-ternary
            cb.Select([0, 2, 1, 2]),
            cb.Add(),
        ) if is_upsampling else [],
        cb.Residual(
            cb.Select([0, 1, 1]),  # h', h, h
            attention,
            _Dropout(),
        ),
        cb.Residual(
            LayerNorm(),
            feed_forward,
            _Dropout(),
        ),
    ]
Example #26
def LSTM(n_units, mode='train', return_state=False, initial_state=False):
    """LSTM running on axis 1.

  Args:
    n_units: `n_units` for the `LSTMCell`.
    mode: if 'predict', we save the previous state for one-by-one inference.
    return_state: Boolean. Whether to return the latest state in addition to
      the output. Default: False.
    initial_state: Boolean. Whether the initial RNN state (c, h) is to be
      obtained from the stack. Default: False.

  Returns:
    A LSTM layer.
  """

    if not initial_state:
        zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
        if return_state:
            return cb.Serial(cb.Branch([], zero_state),
                             cb.Scan(LSTMCell(n_units=n_units),
                                     axis=1,
                                     mode=mode),
                             name=f'LSTM_{n_units}',
                             sublayers_to_print=[])
        else:
            return cb.Serial(
                cb.Branch([], zero_state),  # fill state RNN with zero.
                cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
                cb.Select([0], n_in=2),  # Drop RNN state.
                # Set the name to LSTM and don't print sublayers.
                name=f'LSTM_{n_units}',
                sublayers_to_print=[])
    else:
        if return_state:
            return cb.Serial(cb.Scan(LSTMCell(n_units=n_units),
                                     axis=1,
                                     mode=mode),
                             name=f'LSTM_{n_units}',
                             sublayers_to_print=[])
        else:
            return cb.Serial(
                cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode),
                cb.Select([0], n_in=2),  # Drop RNN state.
                name=f'LSTM_{n_units}',
                sublayers_to_print=[])
Example #27
    def test_input_signatures_serial(self):
        layer = cb.Serial(core.Div(divisor=2.0), core.Div(divisor=5.0))
        self.assertIsNone(layer.input_signature)

        layer.input_signature = ShapeDtype((3, 2))
        self.assertEqual(layer.input_signature, ShapeDtype((3, 2)))
        self.assertLen(layer.sublayers, 2)
        for sublayer in layer.sublayers:
            self.assertEqual(sublayer.input_signature, ShapeDtype((3, 2)))
Example #28
def PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling):
  """Positional embedding for relative attention.

  Returns a layer that, based on the queries, the keys, and the accumulated
  pool size of keys/values up to this layer, calculates sinusoidal positional
  embeddings for relative attention calculations.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    separate_cls: True/False whether we treat the cls token separately in
      calculations.
    total_kv_pooling: Accumulated pool size of keys/values until this layer.

  Returns:
    Positional embedding.
  """

  def PositionsVectors(queries, keys):
    is_funnel_layer = queries.shape != keys.shape
    keys_len, queries_len = keys.shape[1], queries.shape[1]
    current_pooling_ratio = keys_len / queries_len

    # Special case of upsampling
    if is_funnel_layer and current_pooling_ratio < 1:
      # We should not be doing standard upsampling when we use separate_cls
      # Cls token is being used for classification
      assert not separate_cls
      assert (total_kv_pooling * keys_len) % queries_len == 0
      multiplier = ((total_kv_pooling * keys_len) // queries_len)
      positions = jnp.arange(-queries_len + 1, queries_len, 1.0) * multiplier
    else:
      positions = jnp.arange(-keys_len + 1, keys_len, 1.0) * total_kv_pooling

    if is_funnel_layer and separate_cls:
      # For pool_size 2 without separating cls we get
      # [0][1][2][3][4][5][6][7] -> [01][23][45][67]
      # With separating cls we get
      # [0][1][2][3][4][5][6][7] -> [0][12][34][56]

      # The first group will always consist of one token after pooling,
      # instead of (pool_size) tokens. We need to add a proper offset so
      # that our shift later on in calculating attention works properly.
      cls_offset = (current_pooling_ratio - 1) * total_kv_pooling
      positions = positions + cls_offset

    return positions

  def Sinusoidal_Embeddings(positions):
    inv_freq = 1 / (10000**(jnp.arange(0.0, d_feature, 2.0) / d_feature))
    sinusoid_freq = jnp.einsum('i,j->ij', positions, inv_freq)
    pos_emb = jnp.concatenate(
        [jnp.sin(sinusoid_freq), jnp.cos(sinusoid_freq)], axis=1)
    return pos_emb

  return cb.Serial(
      cb.Fn('Generate positions vectors', PositionsVectors, n_out=1),
      cb.Fn(
          'Transform to sinusoidal encodings', Sinusoidal_Embeddings, n_out=1))
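The Sinusoidal_Embeddings step above, sketched in NumPy with illustrative values: each relative position is mapped to d_feature features, half sines and half cosines.

import numpy as np

d_feature = 8
positions = np.arange(-3.0, 4.0)  # 7 relative distances
inv_freq = 1 / (10000 ** (np.arange(0.0, d_feature, 2.0) / d_feature))
sinusoid_freq = np.einsum('i,j->ij', positions, inv_freq)
pos_emb = np.concatenate(
    [np.sin(sinusoid_freq), np.cos(sinusoid_freq)], axis=1)
print(pos_emb.shape)  # (7, 8)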
Example #29
    def test_input_signatures_serial(self):
        layer = cb.Serial(divide_by(2.0), divide_by(5.0))
        self.assertIsNone(layer.input_signature)

        layer._set_input_signature_recursive(ShapeDtype((3, 2)))
        self.assertEqual(layer.input_signature, ShapeDtype((3, 2)))
        self.assertLen(layer.sublayers, 2)
        for sublayer in layer.sublayers:
            self.assertEqual(sublayer.input_signature, ShapeDtype((3, 2)))
Example #30
def SumOfWeights(id_to_mask=None, has_weights=False):
  """Returns a layer to compute sum of weights of all non-masked elements."""
  multiply_by_weights = cb.Multiply() if has_weights else []
  return cb.Serial(
      cb.Drop(),  # Drop inputs.
      _ElementMask(id_to_mask=id_to_mask),
      multiply_by_weights,
      core.Sum(axis=None)  # Sum all.
  )