Пример #1
0
 def test_select_computes_n_in(self):
     """Without an explicit n_in, Select derives it from its largest index."""
     expectations = (([0, 0], 1), ([1, 0], 2), ([2], 3))
     for indices, expected_n_in in expectations:
         layer = cb.Select(indices)
         self.assertEqual(layer.n_in, expected_n_in)
Пример #2
0
def RelativeAttentionLayer(d_feature,
                           context_bias_layer,
                           location_bias_layer,
                           total_kv_pooling,
                           separate_cls,
                           n_heads=1,
                           dropout=0.0,
                           mode='train'):
    """Returns a layer that maps (q, k, v, masks) to (activations, masks).

    Cost: when there are fewer keys than queries the layer works in
    O(q^2 * d); otherwise it is O(q * k * d). The reason is that relative
    distances must be shifted by the current pooling, which is a fraction < 1
    when upsampling. Visual explanation:

      [01][23][45][67] -> [0][1][2][3][4][5][6][7]

    For token [0] the relative distances are:
      * 0 2 4 6
    For token [1] they must be shifted by 1:
      * -1 1 3 5
    So besides the distances matching the spacing between keys, the in-between
    distances are needed too: several query tokens (at different positions,
    hence with different relative distances) share a single key token.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      context_bias_layer: Global context bias from Transformer XL's attention.
        There should be one such layer shared for all relative attention
        layers.
      location_bias_layer: Global location bias from Transformer XL's
        attention. There should be one such layer shared for all relative
        attention layers.
      total_kv_pooling: Accumulated pool size of keys/values used at this
        layer.
      separate_cls: Boolean flag forwarded to both PositionalEmbeddings and
        RelativeAttention (presumably: whether the CLS token is handled
        separately in the calculations).
      n_heads: Number of attention heads.
      dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
      mode: One of `'train'`, `'eval'`, or `'predict'`.
    """
    # Output of PositionalEmbeddings, followed by copies of stack items 0 and 1.
    positional_branch = cb.Branch(
        PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling),
        cb.Select([0]),
        cb.Select([1]))
    # Four independent Dense projections, one per leading stack item.
    projections = cb.Parallel(*[core.Dense(d_feature) for _ in range(4)])
    relative_attention = RelativeAttention(  # pylint: disable=no-value-for-parameter
        separate_cls=separate_cls,
        n_heads=n_heads,
        dropout=dropout,
        mode=mode)
    return cb.Serial(
        positional_branch,
        projections,
        context_bias_layer,
        location_bias_layer,
        relative_attention,
        core.Dense(d_feature),
    )
Пример #3
0
def SRU(n_units, activation=None):
    """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

    As defined in the paper:
      (1) y_t = W x_t (+ B optionally, which we do)
      (2) f_t = sigmoid(Wf x_t + bf)
      (3) r_t = sigmoid(Wr x_t + br)
      (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
      (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t

    The three projections are computed by a single Dense(3 * n_units) layer
    whose output is split into r, f and y. The highway term adds x
    element-wise to n_units-wide values (NOTE(review): presumably the input
    depth must therefore equal n_units -- confirm).

    We assume the input is of shape [batch, length, depth] and recurrence
    happens on the length dimension. This returns a single layer. The paper
    recommends stacking at least 2 of them, except inside a Transformer.

    Args:
      n_units: output depth of the SRU layer.
      activation: Optional activation function applied to c_t; when falsy,
        no activation is applied.

    Returns:
      The SRU layer.
    """
    # One Sigmoid instance is reused for both the r and f gates below.
    sigmoid_activation = activation_fns.Sigmoid()
    # pylint: disable=no-value-for-parameter
    # NOTE(review): base.Fn is called with bare lambdas and no explicit output
    # count; presumably this trax version infers it from the lambda. Confirm
    # before editing the Fn lines.
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
        base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation or [],
        base.Fn(lambda c, r, x: c * r + x * (1 - r)))
Пример #4
0
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
    """Returns a layer that maps (activations, mask) to (new_activations, mask).

    This layer type represents one pass of multi-head self-attention, best
    known for its central role in Transformer models. Internally, it:

      - maps activations to `(queries, keys, values)` triples,
      - splits `queries`, `keys`, and `values` into multiple 'heads',
      - computes per-head attention weights from per-head `(queries, keys)`,
      - applies `mask` to screen out positions that come from padding tokens,
      - optionally applies dropout to attention weights,
      - uses attention weights to combine per-head `values` vectors, and
      - fuses per-head results into activations matching original input shapes.

    Self-attention is obtained by copying the incoming activations three times
    and handing the copies to `AttentionQKV` as queries, keys and values.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      dropout: Probabilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with values.
      mode: Either 'train' or 'eval'.
    """
    copy_to_qkv = cb.Select([0, 0, 0])  # activations -> q, k, v; mask stays below
    attend = AttentionQKV(d_feature, n_heads=n_heads, dropout=dropout,
                          mode=mode)
    return cb.Serial(copy_to_qkv, attend)
Пример #5
0
 def test_select_second_of_3(self):
     """Select([1], n_in=3) outputs only the second of three inputs."""
     layer = cb.Select([1], n_in=3)
     shapes = [(3, 2), (4, 7), (11, 13)]
     input_signature = tuple(ShapeDtype(shape) for shape in shapes)
     output_shape = base.check_shape_agreement(layer, input_signature)
     self.assertEqual(output_shape, shapes[1])
Пример #6
0
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
    """Returns a layer that maps (activations, mask) to (new_activations, mask).

    This layer type represents one pass of multi-head self-attention, best
    known for its central role in Transformer models. Internally, it:

      - maps incoming sequence of activations to sequence of (query, key, value)
        triples,
      - splits queries, keys, and values into multiple 'heads',
      - computes per-head attention weights from per-head (queries, keys),
      - applies mask to screen out positions that come from padding tokens,
      - [in ``'train'`` mode] applies dropout to attention weights,
      - uses attention weights to combine per-head values vectors, and
      - fuses per-head results into outgoing activations matching original input
        activation shapes.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      dropout: Probabilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with values.
      mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    """
    layers = [
        # Self-attention: activations are reused as queries, keys and values.
        cb.Select([0, 0, 0]),
        AttentionQKV(d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
    ]
    return cb.Serial(*layers)
Пример #7
0
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
    """Returns a layer that maps (vectors, mask) to (new_vectors, mask).

    This layer type represents one pass of multi-head self-attention, from vector
    set to vector set, using masks to represent out-of-bound (e.g., padding)
    positions. It:

      - maps incoming sequence of activations vectors to sequence of (query, key,
        value) triples,
      - splits queries, keys, and values into multiple 'heads',
      - computes per-head attention weights from per-head (queries, keys),
      - applies mask to screen out positions that come from padding tokens,
      - [in ``'train'`` mode] applies dropout to attention weights,
      - uses attention weights to combine per-head values vectors, and
      - fuses per-head results into outgoing activations matching original input
        activation shapes.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      dropout: Probabilistic rate for attention dropout, which overrides
          (sets to zero) some attention strengths derived from query-key
          matching. As a result, on a given forward pass, some value vectors
          don't contribute to the output, analogous to how regular dropout can
          cause some node activations to be ignored.
      mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    """
    qkv = cb.Select([0, 0, 0])  # one input vector set -> three identical copies
    return cb.Serial(qkv, AttentionQKV(d_feature, n_heads=n_heads,
                                       dropout=dropout, mode=mode))
Пример #8
0
def RelativeAttentionLMLayer(d_feature,
                             total_kv_pooling,
                             n_heads=1,
                             dropout=0.0,
                             n_raw_tokens_generated=1,
                             max_inference_length=3072,
                             chunk_len=None,
                             chunk_offset=None,
                             mode='train'):
  """Returns a causal (future-masking) relative-attention layer.

  Same as the standard relative attention layer, but it additionally builds,
  from the sizes of queries and keys, a mask that hides the future. Masking
  the future is primarily used for language modelling.

  Stack flow (per the comments below): vecs -> (vecs, mask) -> attention ->
  vecs. NOTE(review): an earlier summary described the mapping as
  (q, k, v) -> activations; confirm which stack signature callers use.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    n_raw_tokens_generated: Number of tokens generated in a single pass through
      this layer. Used only in 'predict' non-training mode.
    max_inference_length: Maximum sequence length allowed in non-training
      modes.
    chunk_len (optional): Number of tokens per chunk. Setting this option will
      enable chunked attention.
    chunk_offset (optional): Offset for shifting chunks, for shifted chunked
      attention.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  # Settings that the attention layer and its mask must agree on.
  shared = dict(
      n_raw_tokens_generated=n_raw_tokens_generated,
      max_inference_length=max_inference_length,
      chunk_len=chunk_len,
      chunk_offset=chunk_offset,
      mode=mode)

  attention = RelativeAttentionLayer(
      d_feature, total_kv_pooling, n_heads=n_heads, dropout=dropout, **shared)
  mask_layer = AttentionMaskLayer(total_kv_pooling=total_kv_pooling, **shared)

  return cb.Serial(
      cb.Branch(None, mask_layer),  # vecs, mask
      attention,  # vecs, mask
      cb.Select([0], n_in=2),  # vecs (the mask is dropped)
  )
Пример #9
0
def LSTM(n_units):
    """LSTM running on axis 1.

    A zero initial state (from MakeZeroState) is placed next to the input,
    LSTMCell is scanned over axis 1, and the final RNN state is dropped.
    """
    zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
    with_initial_state = cb.Branch([], zero_state)  # x, zero_state
    scan_over_time = cb.Scan(LSTMCell(n_units=n_units), axis=1)
    drop_final_state = cb.Select([0], n_in=2)
    return cb.Serial(with_initial_state, scan_over_time, drop_final_state)
Пример #10
0
def LSTM(n_units, mode='train', return_state=False, initial_state=False):
    """LSTM running on axis 1.

    Args:
      n_units: `n_units` for the `LSTMCell`.
      mode: if 'predict' then we save the previous state for one-by-one inference.
      return_state: Boolean. Whether to return the latest state in addition to
        the output. Default: False.
      initial_state: Boolean. If True, the RNN state (c, h) is taken from the
        stack; otherwise a zero state is created. Default: False.

    Returns:
      An LSTM layer.
    """
    layers = []
    if not initial_state:
        zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
        layers.append(cb.Branch([], zero_state))  # fill state RNN with zero.
    layers.append(cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode))
    if not return_state:
        layers.append(cb.Select([0], n_in=2))  # Drop RNN state.
    # Name the composite LSTM and don't print its sublayers.
    return cb.Serial(*layers, name=f'LSTM_{n_units}', sublayers_to_print=[])
Пример #11
0
def AttentionResampling(shorten_factor, d_model, is_upsampling, d_ff, n_heads,
                        dropout, dropout_shared_axes, mode, ff_activation,
                        context_bias_layer, location_bias_layer, total_pooling,
                        resampling_fn):
    """Attention resampling.

    Builds the layers of one resampling step. The activations ``h`` are
    normalized and resampled with ``resampling_fn`` into ``h'``. They are then
    refined by a residual relative-attention sublayer fed ``(h', h, h)`` and a
    residual feed-forward sublayer. Both residual branches end with dropout.

    Args:
      shorten_factor: Resampling factor passed to ``resampling_fn``.
      d_model: Model (feature) depth.
      is_upsampling: If True, the third stack item is added to ``h'`` right
        after resampling (NOTE(review): presumably a skip connection saved by
        the caller -- confirm).
      d_ff: Passed to FeedForwardBlock (presumably its hidden depth).
      n_heads: Number of attention heads.
      dropout: Dropout rate for attention, feed-forward and residual dropout.
      dropout_shared_axes: Axes along which dropout masks are shared.
      mode: Mode string forwarded to all sublayers.
      ff_activation: Activation passed to FeedForwardBlock.
      context_bias_layer: Shared context-bias layer for relative attention.
      location_bias_layer: Shared location-bias layer for relative attention.
      total_pooling: Accumulated pooling passed to the attention layer.
      resampling_fn: Callable ``(shorten_factor, d_model, mode=...)`` returning
        the resampling layer.

    Returns:
      A list of layers (not a combinator).
    """
    attention = RelativeAttentionLMLayer(d_model,
                                         context_bias_layer,
                                         location_bias_layer,
                                         total_pooling,
                                         n_heads=n_heads,
                                         dropout=dropout,
                                         mode=mode)
    feed_forward = FeedForwardBlock(d_model, d_ff, dropout,
                                    dropout_shared_axes, mode, ff_activation)
    resampling = resampling_fn(shorten_factor, d_model, mode=mode)

    def _Dropout():
        # A fresh Dropout layer on every call; the two residuals don't share one.
        return core.Dropout(rate=dropout,
                            shared_axes=dropout_shared_axes,
                            mode=mode)

    block = [LayerNorm()]  # h
    block.append(cb.Branch(cb.Serial(resampling, LayerNorm()), None))  # h', h
    if is_upsampling:
        # h', h, s -> h', s, h, s -> h' + s, h, s
        block.append(cb.Serial(cb.Select([0, 2, 1, 2]), cb.Add()))
    else:
        block.append([])  # empty placeholder, kept so the list layout matches
    block.append(cb.Residual(
        cb.Select([0, 1, 1]),  # h', h, h
        attention,
        _Dropout(),
    ))
    block.append(cb.Residual(LayerNorm(), feed_forward, _Dropout()))
    return block
Пример #12
0
def GRU(n_units):
  """GRU running on axis 1.

  A zero initial state is placed next to the input, GRUCell is scanned over
  axis 1, and the final RNN state is dropped.
  """
  zero_state = MakeZeroState(depth_multiplier=1)  # pylint: disable=no-value-for-parameter
  add_zero_state = cb.Branch([], zero_state)  # x, zero_state
  scan_cell = cb.Scan(GRUCell(n_units=n_units), axis=1)
  drop_state = cb.Select([0], n_in=2)
  # Name the composite GRU and hide its sublayers when printed.
  return cb.Serial(add_zero_state, scan_cell, drop_state,
                   name=f'GRU_{n_units}', sublayers_to_print=[])
Пример #13
0
def LSTM(n_units, mode='train'):
  """LSTM running on axis 1.

  A zero initial state is placed next to the input, LSTMCell is scanned over
  axis 1 (with `mode` forwarded to Scan), and the final RNN state is dropped.
  """
  zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
  add_zero_state = cb.Branch([], zero_state)  # x, zero_state
  scan_cell = cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode)
  drop_state = cb.Select([0], n_in=2)
  # Name the composite LSTM and hide its sublayers when printed.
  return cb.Serial(add_zero_state, scan_cell, drop_state,
                   name=f'LSTM_{n_units}', sublayers_to_print=[])
Пример #14
0
def RelativeAttentionLMLayer(d_feature,
                             context_bias_layer,
                             location_bias_layer,
                             total_kv_pooling,
                             separate_cls=False,
                             n_heads=1,
                             dropout=0.0,
                             n_raw_tokens_generated=1,
                             max_inference_length=3072,
                             mode='train'):
    """Returns a layer that maps (q, k, v) to (activations).

    Same as the standard relative attention layer, but it additionally builds,
    from the sizes of queries and keys, a mask that hides the future. Masking
    the future is primarily used for language modelling.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      context_bias_layer: Global context bias from Transformer XL's attention.
        There should be one such layer shared for all relative attention layers.
      location_bias_layer: Global location bias from Transformer XL's attention.
        There should be one such layer shared for all relative attention layers.
      total_kv_pooling: Accumulated pool size of keys/values used at this layer.
      separate_cls: Boolean flag forwarded to RelativeAttentionLayer.
      n_heads: Number of attention heads.
      dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
      n_raw_tokens_generated: Number of tokens generated in a single pass through
        this layer. Used only in 'predict' non-training mode.
      max_inference_length: Maximum sequence length allowed in non-training
        modes.
      mode: One of `'train'`, `'eval'`, or `'predict'`.
    """
    # Settings that the attention layer and its causal mask must agree on.
    shared_kwargs = dict(n_raw_tokens_generated=n_raw_tokens_generated,
                         max_inference_length=max_inference_length,
                         mode=mode)

    attention = RelativeAttentionLayer(
        d_feature,
        context_bias_layer,
        location_bias_layer,
        total_kv_pooling,
        separate_cls,
        n_heads=n_heads,
        dropout=dropout,
        **shared_kwargs)
    causal_mask = AttentionMaskLayer(total_kv_pooling=total_kv_pooling,
                                     **shared_kwargs)

    return cb.Serial(
        causal_mask,  # q, k, v, mask
        attention,  # vecs, mask
        cb.Select([0], n_in=2),  # vecs
    )
Пример #15
0
def _WeightedMaskedMean(metric_layer, id_to_mask, has_weights):
  """Computes weighted masked mean of metric_layer(predictions, targets).

  The resulting layer takes 2 inputs (predictions, targets) or, when
  `has_weights` is True, 3 inputs (predictions, targets, weights). It applies
  the metric to the batch and reduces the per-element values to one scalar
  with `_WeightedMean`.
  """
  if has_weights:
    combine_mask_and_weights = cb.Multiply()
  else:
    combine_mask_and_weights = []  # no weights: pass the mask through
  return cb.Serial(
      cb.Select([0, 1, 1]),  # predictions, targets, targets, [weights]
      cb.Parallel(metric_layer, _ElementMask(id_to_mask=id_to_mask)),
      cb.Parallel([], combine_mask_and_weights),  # metric_values, weights
      _WeightedMean()
  )
Пример #16
0
def recombine(eqns, inputs, outputs):
    """Implement derived equations via layer-applications and combinators.

    The runtime data stack is simulated symbolically. Before each equation's
    layer, a `cb.Select` is inserted only when the current layout differs from
    the required one. It brings the equation's source symbols to the top and
    keeps below them every symbol still needed by a later equation or by the
    outputs; all other symbols are discarded.

    Args:
      eqns: list of ApplyEqns derived from dataflow traces. Each exposes `src`
        and `dst` (tuples of symbols, since they are concatenated with tuples
        below) and `lyr` (the layer to apply).
      inputs: list of strings representing input symbols.
      outputs: list of strings representing output symbols.

    Returns:
      Trax layer object that implements the given dataflow on provided layers.
    """
    # live_after[i]: symbols that must survive equation i, i.e. the outputs
    # plus the source symbols of every later equation.
    live_after = []
    still_needed = set(outputs)
    for eqn in reversed(eqns):
        live_after.append(still_needed)
        still_needed = still_needed | set(eqn.src)
    live_after.reverse()

    stack = tuple(inputs)  # symbolic model of the data stack
    layers = []
    for eqn, keep in zip(eqns, live_after):
        survivors = tuple(sym for sym in stack if sym in keep)
        wanted = eqn.src + survivors
        if stack != wanted:  # route data only when the layout differs
            indices = [stack.index(sym) for sym in wanted]
            layers.append(cb.Select(indices, len(stack)))
        layers.append(eqn.lyr)
        stack = eqn.dst + survivors

    # Finally, reorder/trim the stack to exactly the requested outputs.
    if stack != tuple(outputs):
        indices = [stack.index(sym) for sym in outputs]
        layers.append(cb.Select(indices, len(stack)))
    return cb.Serial(*layers)
Пример #17
0
def SRU(n_units, activation=None, mode='train'):
    r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
    y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
    f_t &= \sigma(Wf x_t + bf) \\
    r_t &= \sigma(Wr x_t + br) \\
    c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
    h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    mode: if 'predict' then we save the previous state for one-by-one inference

  Returns:
    The SRU layer.
  """
    sigmoid_activation = activation_fns.Sigmoid()
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
        base.Fn(
            '',
            lambda r, f, y: (y * (1.0 - f), f, r),  # y * (1 - f), f, r, x
            n_out=3),
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        ScanSRUCell(mode=mode),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation if activation is not None else [],
        base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) *
                (3**0.5)),
        # Set the name to SRU and don't print sublayers.
        name=f'SRU_{n_units}',
        sublayers_to_print=[])
Пример #18
0
def RelativeAttentionWrapper(d_feature,
                             n_heads=1,
                             dropout=0.0,
                             max_inference_length=2048,
                             mode='train',
                             context_bias_layer=None,
                             location_bias_layer=None,
                             total_pooling=None):
    """Relative attention wrapper.

    Adapts RelativeAttentionLMLayer to the configurable-attention interface so
    that it can be called by `ApplyAttentionLayer`. The single incoming
    activation tensor is duplicated and used as queries, keys and values.

    Args:
      d_feature: Last/innermost dimension of activations in the input to and
        output from this layer.
      n_heads: Number of attention heads. Attention heads effectively split
        activation vectors into ``n_heads`` subvectors, of size ``d_feature /
        n_heads``.
      dropout: dropout rate.
      max_inference_length: Ignored; accepted only for interface compatibility.
      mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
      context_bias_layer: context bias layer.
      location_bias_layer: location bias layer.
      total_pooling: total pooling.

    Returns:
      relative attention layer.
    """
    del max_inference_length  # unused by relative attention

    relative_attention = RelativeAttentionLMLayer(d_feature,
                                                  context_bias_layer,
                                                  location_bias_layer,
                                                  total_pooling,
                                                  n_heads=n_heads,
                                                  dropout=dropout,
                                                  mode=mode)

    to_qkv = cb.Select([0, 0, 0])  # x -> x, x, x
    return cb.Serial(to_qkv, relative_attention)
Пример #19
0
def Attention(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps `(vectors, mask)` to `(new_vectors, mask)`.

  This layer type represents one pass of multi-head self-attention, from vector
  set to vector set, using masks to represent out-of-bound (e.g., padding)
  positions. It:

    - makes three copies of incoming activations and maps these to multi-head
      query (Q) vectors, key (K) vectors, and value (V) vectors, respectively;
    - for each head, computes the scaled dot product of each Q-K pair;
    - applies mask to screen out positions that come from padding tokens
      (indicated by 0 value);
    - [in ``'train'`` mode] applies dropout to Q-K dot products;
    - for each head, computes Q-K attention strengths using a per-query softmax
      of the Q-K dot products;
    - for each head, for each query position, combines V vectors according
      to the Q-K attention strengths; and
    - concatenates and fuses resulting per-head vectors into outgoing
      activations matching original input activation shapes.

  Args:
    d_feature: Last/innermost dimension of activations in the input to and
        output from this layer.
    n_heads: Number of attention heads. Attention heads effectively split
        activation vectors into ``n_heads`` subvectors, of size
        ``d_feature / n_heads``.
    dropout: Probabilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored. Applies only if layer is
        created in ``'train'`` mode.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
  """
  self_attention_inputs = cb.Select([0, 0, 0])  # the three copies (Q, K, V)
  multi_head_attention = AttentionQKV(
      d_feature, n_heads=n_heads, dropout=dropout, mode=mode)
  return cb.Serial(self_attention_inputs, multi_head_attention)
Пример #20
0
def SRU(n_units, activation=None):
    r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

    As defined in the paper:

    .. math::
      y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
      f_t &= \sigma(Wf x_t + bf) \\
      r_t &= \sigma(Wr x_t + br) \\
      c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
      h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

    Unlike the last formula, the final gate below also multiplies the highway
    term :math:`(1 - r_t) \times x_t` by :math:`\sqrt{3}`. This is the paper's
    scaling correction :math:`\sqrt{1 + 2 e^{b}}` at highway bias :math:`b = 0`.

    We assume the input is of shape [batch, length, depth] and recurrence
    happens on the length dimension. This returns a single layer. The paper
    recommends stacking at least 2 of them, except inside a Transformer.

    Args:
      n_units: output depth of the SRU layer.
      activation: Optional activation layer applied to c_t; `None` means no
        activation.

    Returns:
      The SRU layer.
    """
    sigmoid_activation = activation_fns.Sigmoid()
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(sigmoid_activation, sigmoid_activation),  # r, f, y, x
        base.Fn(
            '',
            lambda r, f, y: (y * (1.0 - f), f, r),  # y * (1 - f), f, r, x
            n_out=3),
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # c, r, x
        # Explicit None check rather than relying on a layer's truthiness.
        activation if activation is not None else [],  # act(c), r, x
        base.Fn('FinalSRUGate', lambda c, r, x: c * r + x * (1 - r) *
                (3**0.5)))
Пример #21
0
def SRU(n_units, activation=None, rescale=False, highway_bias=0):
    """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

    As defined in the paper:
      (1) y_t = W x_t (+ B optionally, which we do)
      (2) f_t = sigmoid(Wf x_t + bf)
      (3) r_t = sigmoid(Wr x_t + br)
      (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
      (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t * alpha

    where alpha = (1 + exp(highway_bias) * 2)**0.5 when `rescale` is True and
    alpha = 1 otherwise.

    We assume the input is of shape [batch, length, depth] and recurrence
    happens on the length dimension. This returns a single layer. The paper
    recommends stacking at least 2 of them, except inside a Transformer.

    Args:
      n_units: output depth of the SRU layer.
      activation: Optional activation function applied to c_t; when falsy, no
        activation is applied.
      rescale: Whether to multiply the highway term by alpha. The correction
        offsets gradient vanishing in h_t caused by the light recurrence and
        the highway computation in deeper layers; see
        https://arxiv.org/abs/1709.02755, page 4, section 3.2 Initialization.
      highway_bias: Initial bias of the highway gates. NOTE(review): in this
        function it is only used to compute alpha; it is not passed to any
        layer, so the gate bias itself is not initialized from it -- confirm
        this is intended.

    Returns:
      The SRU layer.
    """
    # NOTE(review): base.Fn below is called with bare lambdas and no explicit
    # output count; presumably this trax version infers it from the lambda.
    # Confirm before editing the Fn lines.
    # pylint: disable=no-value-for-parameter
    return cb.Serial(  # x
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(core.Sigmoid(), core.Sigmoid()),  # r, f, y, x
        base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation or [],
        base.Fn(lambda c, r, x: c * r + x * (1 - r) *
                ((1 + np.exp(highway_bias) * 2)**0.5 if rescale else 1)))
Пример #22
0
def CreateAttentionMaskLayer():
    """Creates a causal attention mask layer.

    The returned layer reads (queries, keys, values) from the stack and passes
    all three through unchanged. It then appends a causal boolean mask computed
    from the queries and keys lengths, i.e. it maps (q, k, v) to
    (q, k, v, mask).

    Causal attention uses masking to prevent a given sequence position from
    attending to positions following it. This is used, for example, when
    training autoregressive sequence models or when decoding a sequence symbol
    by symbol.

    Returns:
      an attention mask layer.
    """

    def _funnel_mask(batch_size, keys_len, queries_len, funnel_factor,
                     is_upsampling):
        """Builds a lower-triangular mask, stretched for funnel resampling.

        When funnel_factor != 1, the triangle is repeated funnel_factor times:
          * downsampling: along the last axis, giving
            (queries_len, queries_len * funnel_factor), because after
            downsampling one query token attends to funnel_factor key tokens;
          * upsampling: along the second-to-last axis, giving
            (keys_len * funnel_factor, keys_len), because funnel_factor query
            tokens attend to a single key token.
        When funnel_factor == 1 the mask is (queries_len, queries_len).

        Args:
          batch_size: batch size; the mask is tiled this many times on axis 0.
          keys_len: keys length.
          queries_len: queries length.
          funnel_factor: funnel factor (presumably an int, as jnp.repeat needs).
          is_upsampling: True when upsampling, False when downsampling.

        Returns:
          boolean mask of shape (batch_size, 1, rows, cols).
        """
        if funnel_factor == 1:
            mask = jnp.tril(
                jnp.ones((queries_len, queries_len), dtype=jnp.bool_))
        elif is_upsampling:
            mask = jnp.repeat(
                jnp.tril(jnp.ones((keys_len, keys_len), dtype=jnp.bool_)),
                funnel_factor,
                axis=-2)
        else:
            mask = jnp.repeat(
                jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_)),
                funnel_factor,
                axis=-1)

        return jnp.repeat(mask[None, None, :, :], batch_size, axis=0)

    def calculate_mask(queries, keys):
        # Lengths are read from axis -2; batch size from axis 0.
        keys_len, queries_len = keys.shape[-2], queries.shape[-2]
        funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len)
        return _funnel_mask(queries.shape[0], keys_len, queries_len,
                            funnel_factor, is_upsampling)

    passthrough = [cb.Select([0]), cb.Select([1]), cb.Select([2])]  # q, k, v
    return cb.Branch(
        *passthrough,
        cb.Fn('create attention mask layer', calculate_mask, n_out=1))
Пример #23
0
 def test_select_op_not_defined(self):
     """A bare int instead of an index list makes Select raise AttributeError."""
     shapes = ((3, 2), (4, 7))
     with self.assertRaises(AttributeError):
         cb.Select(1, shapes)
Пример #24
0
 def test_select_given_n_in(self):
     """An explicit n_in is reported as-is, even above the largest index + 1."""
     for requested_n_in in (2, 3):
         layer = cb.Select([0], n_in=requested_n_in)
         self.assertEqual(layer.n_in, requested_n_in)
Пример #25
0
def RelativeAttentionLayer(d_feature,
                           total_kv_pooling,
                           n_heads=1,
                           dropout=0.0,
                           n_raw_tokens_generated=1,
                           max_inference_length=3072,
                           chunk_len=None,
                           chunk_offset=None,
                           mode='train'):
  """Returns a relative-attention layer that outputs (activations, masks).

  Input: the Branch below feeds the top stack item to three Dense projections
  (queries, keys, values) and to the positional-embedding path, and copies the
  second stack item as the mask. NOTE(review): an earlier docstring described
  the input as (q, k, v, masks); the code reads like (activations, mask) --
  confirm against callers.

  When the number of keys is smaller than the number of queries the layer
  works in O(q^2*d); otherwise it is O(q*k*d). That is because relative
  distances must be shifted by the current pooling, which is a fraction < 1
  when upsampling. Visual explanation:
  [01][23][45][67] -> [0][1][2][3][4][5][6][7]
  For token [0] we calculate relative distances as follows:
  * 0 2 4 6
  However for token [1] we need relative distances changed by 1, specifically:
  * -1 1 3 5
  So besides the distances corresponding to the spacing between keys, we also
  need the in-between ones, because several query tokens (at different
  positions, hence with different relative distances) share a single key token.

  Args:
    d_feature: Depth/dimensionality of feature embedding. Must be divisible by
      `n_heads`.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    n_raw_tokens_generated: Number of tokens generated in a single pass through
      this layer. Used only in 'predict' non-training mode.
    max_inference_length: Maximum sequence length allowed in non-training
      modes.
    chunk_len (optional): Number of tokens per chunk. Setting this option will
      enable chunked attention.
    chunk_offset (optional): Offset for shifting chunks, for shifted chunked
      attention.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  pos_emb = PositionalEmbeddings(
      d_feature,
      total_kv_pooling,
      max_inference_length=max_inference_length,
      chunk_len=chunk_len,
      chunk_offset=chunk_offset,
      n_raw_tokens_generated=n_raw_tokens_generated,
      mode=mode)

  # Intentionally no trailing comma after this call: a trailing comma would
  # make `attention` a 1-tuple instead of a layer.
  attention = RelativeAttention(  # pylint: disable=no-value-for-parameter
      total_kv_pooling=total_kv_pooling,
      n_heads=n_heads,
      dropout=dropout,
      n_raw_tokens_generated=n_raw_tokens_generated,
      max_inference_length=max_inference_length,
      chunk_len=chunk_len,
      chunk_offset=chunk_offset,
      mode=mode)

  assert d_feature % n_heads == 0, (
      f'd_feature ({d_feature}) must be divisible by n_heads ({n_heads}).')
  d_head = d_feature // n_heads
  # Per-head learned biases, created fresh for each call of this function.
  context_bias_layer = core.Weights(
      init.RandomNormalInitializer(1e-6), shape=(1, n_heads, 1, d_head))
  location_bias_layer = core.Weights(
      init.RandomNormalInitializer(1e-6), shape=(1, n_heads, 1, d_head))

  return cb.Serial(
      cb.Branch(
          cb.Serial(pos_emb, core.Dense(d_feature)),
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
          cb.Select([1])  # mask
      ),
      context_bias_layer,
      location_bias_layer,
      attention,
      core.Dense(d_feature),
  )