Ejemplo n.º 1
0
def SRU(n_units, activation=None):
    """Builds one SRU (Simple Recurrent Unit) layer.

    Reference: https://arxiv.org/abs/1709.02755. The paper defines:
      (1) y_t = W x_t (+ B optionally, which is done here)
      (2) f_t = sigmoid(Wf x_t + bf)
      (3) r_t = sigmoid(Wr x_t + br)
      (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
      (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t

    The input is assumed to have shape [batch, length, depth], and the
    recurrence is scanned over the length dimension (axis 1). Only a single
    layer is returned; the paper suggests stacking at least two of them,
    except inside a Transformer.

    Args:
      n_units: Output depth of the SRU layer.
      activation: Optional activation layer applied to c_t. A falsy value
        leaves c_t untouched.

    Returns:
      The SRU layer.
    """
    # pylint: disable=no-value-for-parameter
    gate = activation_fns.Sigmoid()  # Single instance used for both r and f.
    stages = [
        # One Dense of width 3 * n_units yields r, f and y in a single pass.
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(gate, gate),  # r, f, y, x
        base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation if activation else [],
        # Equation (5): mix the cell output with the untouched input x.
        base.Fn(lambda c, r, x: c * r + x * (1 - r)),
    ]
    return cb.Serial(*stages)
Ejemplo n.º 2
0
def CausalAttention(d_feature, n_heads=1, dropout=0.0,
                    max_inference_length=2048, mode='train'):
  """Builds a multi-head self-attention layer that uses causal masking.

  This is the causal counterpart of `Attention`: a single pass of multi-head
  self-attention where masking is causal instead of padding-based.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    mode: One of `'train'`, `'eval'`, or `'predict'`.

  Returns:
    A layer mapping activations to activations.

  Raises:
    ValueError: If `d_feature` is not a multiple of `n_heads`.
  """
  d_head, leftover = divmod(d_feature, n_heads)
  if leftover:
    raise ValueError(
        f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
        f'of the requested number of attention heads ({n_heads}).')

  def to_heads(x):
    # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
    n_batch, length = x.shape[0], x.shape[1]
    per_head = x.reshape((n_batch, length, n_heads, d_head))
    per_head = per_head.transpose((0, 2, 1, 3))
    return per_head.reshape((-1, length, d_head))

  def from_heads(x):
    # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
    length = x.shape[1]
    merged = x.reshape((-1, n_heads, length, d_head))
    merged = merged.transpose((0, 2, 1, 3))
    return merged.reshape((-1, length, n_heads * d_head))

  def _projection():
    # A fresh Dense projection followed by the reshape into separate heads.
    return [core.Dense(d_feature), Fn('SplitIntoHeads', to_heads)]

  qkv = cb.Branch(_projection(), _projection(), _projection())
  attend = DotProductCausalAttention(
      dropout=dropout, max_inference_length=max_inference_length, mode=mode)
  return cb.Serial(
      qkv,
      attend,
      Fn('MergeHeads', from_heads),
      core.Dense(d_feature),
  )
Ejemplo n.º 3
0
def RelativeAttentionLMLayer(d_feature,
                             total_kv_pooling,
                             n_heads=1,
                             dropout=0.0,
                             n_raw_tokens_generated=1,
                             max_inference_length=3072,
                             chunk_len=None,
                             chunk_offset=None,
                             mode='train'):
  """Returns a layer that maps (q, k, v) to (activations).

  Wraps the standard relative attention layer and, in addition, builds a mask
  layer (from the same pooling/chunking settings) whose output is fed to the
  attention layer. Masking out the future is the concept primarily used for
  language modelling.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    n_raw_tokens_generated: Number of tokens generated in a single pass through
      this layer. Used only in 'predict' non-training mode.
    max_inference_length: Maximum sequence length allowed in non-training
      modes.
    chunk_len (optional): Number of tokens per chunk. Setting this option will
      enable chunked attention.
    chunk_offset (optional): Offset for shifting chunks, for shifted chunked
      attention.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  # Settings that the attention layer and the mask layer both receive.
  shared_settings = dict(
      n_raw_tokens_generated=n_raw_tokens_generated,
      max_inference_length=max_inference_length,
      chunk_len=chunk_len,
      chunk_offset=chunk_offset,
      mode=mode)

  relative_attention = RelativeAttentionLayer(
      d_feature, total_kv_pooling, n_heads=n_heads, dropout=dropout,
      **shared_settings)
  future_mask = AttentionMaskLayer(
      total_kv_pooling=total_kv_pooling, **shared_settings)

  add_mask = cb.Branch(None, future_mask)  # vecs, mask
  drop_mask = cb.Select([0], n_in=2)  # vecs
  return cb.Serial(
      add_mask,
      relative_attention,  # vecs, mask
      drop_mask,
  )
Ejemplo n.º 4
0
def LSTM(n_units):
    """Builds an LSTM layer whose recurrence runs along axis 1.

    Args:
      n_units: `n_units` passed to the underlying `LSTMCell`.

    Returns:
      A layer that pairs the input with a zero initial state, scans an
      `LSTMCell` over axis 1 and keeps only the first of the two results.
    """
    # pylint: disable=no-value-for-parameter
    with_zero_state = cb.Branch([], MakeZeroState(depth_multiplier=2))
    recurrence = cb.Scan(LSTMCell(n_units=n_units), axis=1)
    outputs_only = cb.Select([0], n_in=2)  # Drop RNN state.
    return cb.Serial(with_zero_state, recurrence, outputs_only)
Ejemplo n.º 5
0
def GeneralGRUCell(candidate_transform,
                   memory_transform_fn=None,
                   gate_nonlinearity=activation_fns.Sigmoid,
                   candidate_nonlinearity=activation_fns.Tanh,
                   dropout_rate_c=0.1,
                   sigmoid_bias=0.5):
    r"""Constructs a parametrized Gated Recurrent Unit (GRU) cell.

    GRU update equations:
    $$ Update gate: u_t = \sigmoid(U' * s_{t-1} + B') $$
    $$ Reset gate: r_t = \sigmoid(U'' * s_{t-1} + B'') $$
    $$ Candidate memory: c_t = \tanh(U * (r_t \odot s_{t-1}) + B) $$
    $$ New State: s_t = u_t \odot s_{t-1} + (1 - u_t) \odot c_t $$

    See combinators.Gate for details on the gating function.

    Args:
      candidate_transform: Zero-argument factory for the transform used in the
        update gate, the reset gate and the candidate branch. It is applied
        before the nonlinearities and is called once per use, so each use gets
        its own instance.
      memory_transform_fn: Optional zero-argument factory for a transformation
        applied to the memory before gating.
      gate_nonlinearity: Factory for the gate activation. Allows trying
        alternatives to Sigmoid, such as HardSigmoid.
      candidate_nonlinearity: Factory for the nonlinearity applied after the
        candidate branch. Allows trying alternatives to the traditional Tanh,
        such as HardTanh.
      dropout_rate_c: Amount of dropout on the transform (c) gate. Dropout
        works best in a GRU when applied exclusively to this branch.
      sigmoid_bias: Constant added before the gate nonlinearities. Generally
        want to start off with a positive bias.

    Returns:
      A model representing a GRU cell with specified transforms.
    """
    def _biased_gate():
        # transform -> add constant bias -> gate nonlinearity; all fresh layers.
        return [
            candidate_transform(),
            base.Fn(lambda x: x + sigmoid_bias),  # Want bias to start positive.
            gate_nonlinearity(),
        ]

    update_gate = _biased_gate()  # u_t
    reset_gate = _biased_gate()  # r_t

    candidate = [
        cb.Dup(),
        reset_gate,
        cb.Multiply(),  # Gate S{t-1} with sigmoid(candidate_transform(S{t-1}))
        candidate_transform(),  # Final projection + tanh to get Ct
        candidate_nonlinearity(),  # Candidate gate
        # Dropout is confined to the C gate; the paper reports 0.1 as a good
        # default.
        core.Dropout(rate=dropout_rate_c),
    ]

    if memory_transform_fn:
        memory = memory_transform_fn()
    else:
        memory = []

    branches = cb.Branch(memory, update_gate, candidate)
    return cb.Serial(branches, cb.Gate())
Ejemplo n.º 6
0
def RelativeAttentionLayer(d_feature,
                           context_bias_layer,
                           location_bias_layer,
                           total_kv_pooling,
                           separate_cls,
                           n_heads=1,
                           dropout=0.0,
                           mode='train'):
    """Returns a layer that maps (q, k, v, masks) to (activations, masks).

    When the number of keys is smaller than the number of queries, the layer
    works in O(q^2*d); otherwise it is O(q*k*d). The reason is that relative
    distances have to be shifted by current_pooling, and when upsampling the
    current pooling is a fraction < 1.

    Visual explanation:
    [01][23][45][67] -> [0][1][2][3][4][5][6][7]
    For token [0] the relative distances are:
    * 0 2 4 6
    For token [1] the relative distances must be shifted by 1:
    * -1 1 3 5
    So distances are needed not only for the spacing between keys but also for
    the positions in between, because several query tokens (at different
    positions, hence different relative distances) map to a single key token.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      context_bias_layer: Global context bias from Transformer XL's attention.
        There should be one such layer shared for all relative attention
        layers.
      location_bias_layer: Global location bias from Transformer XL's
        attention. There should be one such layer shared for all relative
        attention layers.
      total_kv_pooling: Accumulated pool size of keys/values used at this
        layer.
      separate_cls: True/False if we separate_cls in calculations.
      n_heads: Number of attention heads.
      dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
      mode: One of `'train'`, `'eval'`, or `'predict'`.
    """
    # Positional embeddings alongside pass-through selections of items 0 and 1.
    embed_positions = cb.Branch(
        PositionalEmbeddings(d_feature, separate_cls, total_kv_pooling),
        cb.Select([0]),
        cb.Select([1]))

    # Four separate Dense projections applied side by side.
    projections = cb.Parallel(*[core.Dense(d_feature) for _ in range(4)])

    attention_core = RelativeAttention(  # pylint: disable=no-value-for-parameter
        separate_cls=separate_cls,
        n_heads=n_heads,
        dropout=dropout,
        mode=mode)

    output_projection = core.Dense(d_feature)

    return cb.Serial(
        embed_positions,
        projections,
        context_bias_layer,
        location_bias_layer,
        attention_core,
        output_projection,
    )
Ejemplo n.º 7
0
def LSTM(n_units, mode='train', return_state=False, initial_state=False):
    """Builds an LSTM layer whose recurrence runs along axis 1.

    Args:
      n_units: `n_units` for the `LSTMCell`.
      mode: Forwarded to `cb.Scan`; if 'predict' then the previous state is
        saved for one-by-one inference.
      return_state: Boolean. When True the RNN state is kept in the output
        next to the activations; when False it is dropped. Default: False.
      initial_state: Boolean. When True the RNN state (c, h) is expected to be
        supplied on the stack; when False a zero state is created.
        Default: False.

    Returns:
      A LSTM layer.
    """
    layers = []

    if not initial_state:
        # No state supplied by the caller: fill the RNN state with zeros.
        zero_state = MakeZeroState(depth_multiplier=2)  # pylint: disable=no-value-for-parameter
        layers.append(cb.Branch([], zero_state))

    layers.append(cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode))

    if not return_state:
        layers.append(cb.Select([0], n_in=2))  # Drop RNN state.

    # Name the combinator after the LSTM and hide its sublayers when printing.
    return cb.Serial(*layers, name=f'LSTM_{n_units}', sublayers_to_print=[])
Ejemplo n.º 8
0
def GRU(n_units):
  """Builds a GRU layer whose recurrence runs along axis 1.

  Args:
    n_units: `n_units` passed to the underlying `GRUCell`.

  Returns:
    A layer named `GRU_<n_units>` that pairs the input with a zero initial
    state, scans a `GRUCell` over axis 1 and keeps only the first result.
  """
  # pylint: disable=no-value-for-parameter
  with_zero_state = cb.Branch([], MakeZeroState(depth_multiplier=1))
  recurrence = cb.Scan(GRUCell(n_units=n_units), axis=1)
  outputs_only = cb.Select([0], n_in=2)  # Drop RNN state.
  return cb.Serial(
      with_zero_state,
      recurrence,
      outputs_only,
      name=f'GRU_{n_units}',
      sublayers_to_print=[],  # Print as a single GRU layer.
  )
Ejemplo n.º 9
0
def LSTM(n_units, mode='train'):
  """Builds an LSTM layer whose recurrence runs along axis 1.

  Args:
    n_units: `n_units` passed to the underlying `LSTMCell`.
    mode: Forwarded unchanged to `cb.Scan`.

  Returns:
    A layer named `LSTM_<n_units>` that pairs the input with a zero initial
    state, scans an `LSTMCell` over axis 1 and keeps only the first result.
  """
  # pylint: disable=no-value-for-parameter
  with_zero_state = cb.Branch([], MakeZeroState(depth_multiplier=2))
  recurrence = cb.Scan(LSTMCell(n_units=n_units), axis=1, mode=mode)
  outputs_only = cb.Select([0], n_in=2)  # Drop RNN state.
  return cb.Serial(
      with_zero_state,
      recurrence,
      outputs_only,
      name=f'LSTM_{n_units}',
      sublayers_to_print=[],  # Print as a single LSTM layer.
  )
Ejemplo n.º 10
0
def SRU(n_units, activation=None, mode='train'):
    r"""Builds one SRU (Simple Recurrent Unit) layer.

    Reference: https://arxiv.org/abs/1709.02755. As defined in the paper:

    .. math::
      y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
      f_t &= \sigma(Wf x_t + bf) \\
      r_t &= \sigma(Wr x_t + br) \\
      c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
      h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

    Note that this implementation additionally multiplies the
    :math:`(1 - r_t) \times x_t` term by :math:`\sqrt{3}`.

    The input is assumed to have shape [batch, length, depth], and the
    recurrence happens on the length dimension. Only a single layer is
    returned; the paper suggests stacking at least two, except inside a
    Transformer.

    Args:
      n_units: Output depth of the SRU layer.
      activation: Optional activation layer; `None` leaves c_t unchanged.
      mode: Forwarded to `ScanSRUCell`; if 'predict' then the previous state is
        saved for one-by-one inference.

    Returns:
      The SRU layer.
    """
    def _prepare_scan_inputs(r, f, y):
        return y * (1.0 - f), f, r  # y * (1 - f), f, r, x

    def _final_gate(c, r, x):
        return c * r + x * (1 - r) * (3**0.5)

    gate = activation_fns.Sigmoid()  # Single instance used for both r and f.

    if activation is not None:
        cell_activation = activation
    else:
        cell_activation = []

    stages = [
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(gate, gate),  # r, f, y, x
        base.Fn('', _prepare_scan_inputs, n_out=3),
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        ScanSRUCell(mode=mode),
        cb.Select([0], n_in=2),  # act(c), r, x
        cell_activation,
        base.Fn('FinalSRUGate', _final_gate),
    ]
    # Name the combinator after the SRU and hide its sublayers when printing.
    return cb.Serial(*stages, name=f'SRU_{n_units}', sublayers_to_print=[])
Ejemplo n.º 11
0
def ConfigurableAttention(
        q_layer,
        k_layer,
        v_layer,
        final_layer,  # pylint: disable=invalid-name
        qkv_attention_layer,
        n_heads=1):
    """Assembles a multi-head attention layer from caller-supplied parts.

    Args:
      q_layer: Layer used on the query branch, followed by a split into heads.
      k_layer: Layer used on the key branch, followed by a split into heads.
      v_layer: Layer used on the value branch, followed by a split into heads.
      final_layer: Layer applied last, after the heads have been merged.
      qkv_attention_layer: Layer applied to the three branch outputs.
      n_heads: Number of attention heads, passed to `SplitIntoHeads` and
        `MergeHeads`.

    Returns:
      A `cb.Serial` combinator chaining the pieces above.
    """
    def _per_head(projection):
        # Follow the projection with a reshape into separate heads.
        return [projection, SplitIntoHeads(n_heads)]

    qkv_branches = cb.Branch(
        _per_head(q_layer),
        _per_head(k_layer),
        _per_head(v_layer),
    )
    return cb.Serial(
        qkv_branches,
        qkv_attention_layer,
        MergeHeads(n_heads),
        final_layer,
    )
Ejemplo n.º 12
0
def SRU(n_units, activation=None):
    r"""Builds one SRU (Simple Recurrent Unit) layer.

    Reference: https://arxiv.org/abs/1709.02755. As defined in the paper:

    .. math::
      y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
      f_t &= \sigma(Wf x_t + bf) \\
      r_t &= \sigma(Wr x_t + br) \\
      c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
      h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

    Note that this implementation additionally multiplies the
    :math:`(1 - r_t) \times x_t` term by :math:`\sqrt{3}`.

    The input is assumed to have shape [batch, length, depth], and the
    recurrence happens on the length dimension. Only a single layer is
    returned; the paper suggests stacking at least two, except inside a
    Transformer.

    Args:
      n_units: Output depth of the SRU layer.
      activation: Optional activation layer; a falsy value leaves c_t
        unchanged.

    Returns:
      The SRU layer.
    """
    def _prepare_scan_inputs(r, f, y):
        return y * (1.0 - f), f, r  # y * (1 - f), f, r, x

    def _final_gate(c, r, x):
        return c * r + x * (1 - r) * (3**0.5)

    gate = activation_fns.Sigmoid()  # Single instance used for both r and f.
    stages = [
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(gate, gate),  # r, f, y, x
        base.Fn('', _prepare_scan_inputs, n_out=3),
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation if activation else [],
        base.Fn('FinalSRUGate', _final_gate),
    ]
    return cb.Serial(*stages)
Ejemplo n.º 13
0
def SRU(n_units, activation=None, rescale=False, highway_bias=0):
    """Builds one SRU (Simple Recurrent Unit) layer.

    Reference: https://arxiv.org/abs/1709.02755. The paper defines:
      (1) y_t = W x_t (+ B optionally, which is done here)
      (2) f_t = sigmoid(Wf x_t + bf)
      (3) r_t = sigmoid(Wr x_t + br)
      (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
      (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t * alpha

    The input is assumed to have shape [batch, length, depth], and the
    recurrence happens on the length dimension (axis 1). Only a single layer
    is returned; the paper suggests stacking at least two, except inside a
    Transformer.

    Args:
      n_units: Output depth of the SRU layer.
      activation: Optional activation layer; a falsy value leaves c_t
        unchanged.
      rescale: When True, the highway term is multiplied by the scaling
        correction alpha = (1 + exp(highway_bias) * 2)**0.5, which offsets
        vanishing gradients in h_t caused by light recurrence and highway
        computation in deeper layers (see section 3.2 "Initialization", page 4
        of the paper). When False, alpha is 1.
      highway_bias: Value plugged into the alpha formula; described upstream
        as the initial bias of the highway gates.

    Returns:
      The SRU layer.
    """
    # pylint: disable=no-value-for-parameter
    stages = [
        cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
        cb.Split(n_items=3),  # r, f, y, x
        cb.Parallel(core.Sigmoid(), core.Sigmoid()),  # r, f, y, x
        base.Fn(lambda r, f, y: (y * (1.0 - f), f, r)),  # y * (1 - f), f, r, x
        cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
        cb.Scan(InnerSRUCell(), axis=1),
        cb.Select([0], n_in=2),  # act(c), r, x
        activation if activation else [],
        # Equation (5); alpha is evaluated inside the function on each call.
        base.Fn(lambda c, r, x: c * r + x * (1 - r) *
                ((1 + np.exp(highway_bias) * 2)**0.5 if rescale else 1)),
    ]
    return cb.Serial(*stages)
Ejemplo n.º 14
0
def AttentionResampling(shorten_factor, d_model, is_upsampling, d_ff, n_heads,
                        dropout, dropout_shared_axes, mode, ff_activation,
                        context_bias_layer, location_bias_layer, total_pooling,
                        resampling_fn):
    """Builds the list of layers making up an attention resampling block.

    Args:
      shorten_factor: Passed to `resampling_fn` as its first argument.
      d_model: Model depth, used by the attention, feed-forward and resampling
        layers.
      is_upsampling: When True, an extra select-and-add step is inserted right
        after the resampling branch; otherwise that slot is an empty list.
      d_ff: Passed to `FeedForwardBlock`.
      n_heads: Number of attention heads.
      dropout: Dropout rate for the attention layer, the feed-forward block
        and both residual dropout layers.
      dropout_shared_axes: `shared_axes` for the dropout layers.
      mode: Mode string forwarded to every sublayer that accepts one.
      ff_activation: Passed to `FeedForwardBlock`.
      context_bias_layer: Passed to `RelativeAttentionLMLayer`.
      location_bias_layer: Passed to `RelativeAttentionLMLayer`.
      total_pooling: Passed to `RelativeAttentionLMLayer`.
      resampling_fn: Callable `(shorten_factor, d_model, mode=...)` returning
        the resampling layer.

    Returns:
      A list of layers (not wrapped in a combinator).
    """
    relative_attention = RelativeAttentionLMLayer(
        d_model,
        context_bias_layer,
        location_bias_layer,
        total_pooling,
        n_heads=n_heads,
        dropout=dropout,
        mode=mode)
    ff_block = FeedForwardBlock(
        d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)
    resample = resampling_fn(shorten_factor, d_model, mode=mode)

    def _residual_dropout():
        # Each residual branch gets its own dropout layer instance.
        return core.Dropout(
            rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

    layers = [LayerNorm()]  # h
    layers.append(cb.Branch(cb.Serial(resample, LayerNorm()), None))  # h', h

    if is_upsampling:
        layers.append(cb.Serial(cb.Select([0, 2, 1, 2]), cb.Add()))
    else:
        layers.append([])

    layers.append(
        cb.Residual(
            cb.Select([0, 1, 1]),  # h', h, h
            relative_attention,
            _residual_dropout(),
        ))
    layers.append(
        cb.Residual(
            LayerNorm(),
            ff_block,
            _residual_dropout(),
        ))
    return layers
Ejemplo n.º 15
0
def CausalAttention(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Transformer-style multi-headed causal attention.

  Args:
    d_feature: int:  dimensionality of feature embedding
    n_heads: int: number of attention heads
    dropout: float: attention dropout
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention result.

  Raises:
    ValueError: If `d_feature` is not a multiple of `n_heads`.
  """
  # Validate with an explicit exception rather than `assert`: assertions are
  # stripped when Python runs with -O, which would let an invalid head count
  # through and fail later inside a reshape.
  if d_feature % n_heads != 0:
    raise ValueError(
        f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
        f'of the requested number of attention heads ({n_heads}).')
  d_head = d_feature // n_heads

  def compute_attention_heads(x):
    batch_size = x.shape[0]
    seqlen = x.shape[1]
    # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
    x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))
    # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
    x = jnp.transpose(x, (0, 2, 1, 3))
    # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
    return jnp.reshape(x, (-1, seqlen, d_head))

  # One layer object, reused on the query, key and value branches below.
  ComputeAttentionHeads = Fn('ComputeAttentionHeads', compute_attention_heads)

  def compute_attention_output(x):
    # Inverse of compute_attention_heads:
    # n_batch*n_heads, seqlen, d_head -> n_batch, seqlen, n_heads*d_head
    seqlen = x.shape[1]
    x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
    x = jnp.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
    return jnp.reshape(x, (-1, seqlen, n_heads * d_head))

  return cb.Serial(
      cb.Branch(
          [core.Dense(d_feature), ComputeAttentionHeads],
          [core.Dense(d_feature), ComputeAttentionHeads],
          [core.Dense(d_feature), ComputeAttentionHeads],
      ),
      DotProductCausalAttention(dropout=dropout, mode=mode),
      Fn('ComputeAttentionOutput', compute_attention_output),
      core.Dense(d_feature)
  )
Ejemplo n.º 16
0
def ConfigurableAttention(
        q_layer,
        k_layer,
        v_layer,
        final_layer,  # pylint: disable=invalid-name
        qkv_attention_layer,
        n_heads=1):
    """Returns a configured multi-head self-attention layer.

    A :py:class:`ConfigurableAttention` layer behaves like an
    :py:class:`Attention` layer whose components are supplied by the caller.
    It

      - makes three copies of incoming activations and uses ``q_layer``,
        ``k_layer``, and ``v_layer`` to map activations to multi-head query
        (Q) vectors, key (K) vectors, and value (V) vectors, respectively;
      - uses ``qkv_attention_layer`` to compute per-head attention, similar to
        :py:class:`DotProductAttention` or
        :py:class:`DotProductCausalAttention`;
      - concatenates and fuses resulting per-head vectors into activations
        matching original input activation shapes; and
      - applies a final layer, ``final_layer``, mapping activations to
        activations (with shape matching the original input activations).

    Args:
      q_layer: Layer that maps input activations to per-head query
          activations.
      k_layer: Layer that maps input activations to per-head key activations.
      v_layer: Layer that maps input activations to per-head value
          activations.
      final_layer: After main multi-head computation and rejoining of heads,
          layer that maps activations to activations (with shape matching the
          original input activations).
      qkv_attention_layer: Layer that does the core multi-head self-attention
          computation.
      n_heads: Number of attention heads. Attention heads effectively split
          activation vectors into ``n_heads`` subvectors, of size
          ``d_feature / n_heads``.
    """
    # Each projection is followed by its own SplitIntoHeads layer.
    per_head_branches = [
        [projection, SplitIntoHeads(n_heads)]
        for projection in (q_layer, k_layer, v_layer)
    ]
    return cb.Serial(
        cb.Branch(*per_head_branches),
        qkv_attention_layer,
        MergeHeads(n_heads),
        final_layer,
    )
Ejemplo n.º 17
0
def CausalFavor(
        d_feature,
        n_heads=1,
        dropout=0.0,  # pylint: disable=invalid-name
        numerical_stabilizer=0.001,
        precision=None,
        mode='train'):
    """Returns a layer that maps activations to activations, with causal masking.

    Like `CausalAttention`, this layer type represents one pass of multi-head
    causal attention, but using FAVOR fast attention as in the following paper:
    https://arxiv.org/abs/2006.03555

    Queries and keys are passed through a ReLU and shifted by
    `numerical_stabilizer`; the attention output is then computed as a ratio of
    two prefix-sum scans (numerator and denominator), each with a hand-written
    backward pass.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      dropout: Probabilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with
          values. Currently unused: the argument is deleted on entry.
      numerical_stabilizer: float, small number used for numerical stability.
      precision: passed to np.einsum to define arithmetic precision.
      mode: One of `'train'`, `'eval'`, or `'predict'`. Currently unused: the
          argument is deleted on entry.

    Raises:
      ValueError: If `d_feature` is not a multiple of `n_heads`.
    """
    del dropout, mode  # not implemented yet but needed in the API

    # TODO(lukaszkaiser): make an API for split/merge heads in core layers,
    # and use it here so we don't duplicate these functions.
    if d_feature % n_heads != 0:
        raise ValueError(
            f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
            f'of the requested number of attention heads ({n_heads}).')

    d_head = d_feature // n_heads

    def _split_into_heads():
        """Returns a layer that reshapes tensors for multi-headed computation."""
        def f(x):
            batch_size = x.shape[0]
            seq_len = x.shape[1]

            # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
            x = x.reshape((batch_size, seq_len, n_heads, d_head))
            x = x.transpose((0, 2, 1, 3))
            x = x.reshape((-1, seq_len, d_head))
            return x

        return base.Fn('SplitIntoHeads', f)

    def _merge_heads():
        """Returns a layer that undoes splitting, after multi-head computation."""
        def f(x):
            seq_len = x.shape[1]

            # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
            x = x.reshape((-1, n_heads, seq_len, d_head))
            x = x.transpose((0, 2, 1, 3))
            x = x.reshape((-1, seq_len, n_heads * d_head))
            return x

        return base.Fn('MergeHeads', f)

    # Forward pass of the numerator: a scan that carries a running prefix sum
    # `p` of key/value outer products and emits, per step, the query
    # contracted with that prefix sum. Returns the outputs plus the residuals
    # the backward pass needs (final prefix sum and the scanned inputs).
    def favor_numerator_fwd(init_prefix_sum_value, precision, query_prime,
                            key_prime, value):
        def body(p, qkv):
            (q, k, v) = qkv
            p += np.einsum('...m,...d->...md', k, v, precision=precision)
            x_slice = np.einsum('...m,...md->...d', q, p, precision=precision)
            return p, x_slice

        p, w = fastmath.scan(body, init_prefix_sum_value,
                             (query_prime, key_prime, value))
        return w, (p, query_prime, key_prime, value)

    # Backward pass of the numerator: scans in reverse starting from the final
    # prefix sum. Each step subtracts that step's k/v contribution from `p`,
    # so earlier prefix sums are recovered on the fly instead of being stored.
    # `p_ct` accumulates the cotangent of the prefix sum.
    def favor_numerator_bwd(init_prefix_sum_value, precision, pqkv, w_ct):
        del init_prefix_sum_value

        def body(carry, qkv_xct):
            p, p_ct = carry
            q, k, v, x_ct = qkv_xct
            q_ct = np.einsum('...d,...md->...m', x_ct, p, precision=precision)
            p_ct += np.einsum('...d,...m->...md', x_ct, q, precision=precision)
            k_ct = np.einsum('...md,...d->...m', p_ct, v, precision=precision)
            v_ct = np.einsum('...md,...m->...d', p_ct, k, precision=precision)
            p -= np.einsum('...m,...d->...md', k, v, precision=precision)
            return (p, p_ct), (q_ct, k_ct, v_ct)

        p, qs, ks, vs = pqkv
        _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(body, (p, np.zeros_like(p)),
                                                 (qs, ks, vs, w_ct),
                                                 reverse=True)
        return qs_ct, ks_ct, vs_ct

    # Primal numerator function: the forward pass with the residuals dropped.
    def favor_numerator(init_prefix_sum_value, precision, query_prime,
                        key_prime, value):
        w, _ = favor_numerator_fwd(init_prefix_sum_value, precision,
                                   query_prime, key_prime, value)
        return w

    # Attach the custom backward pass; arguments 0 and 1 (initial prefix sum
    # and precision) are declared non-differentiable.
    favor_numerator = fastmath.custom_vjp(favor_numerator,
                                          favor_numerator_fwd,
                                          favor_numerator_bwd,
                                          nondiff_argnums=(0, 1))

    # Forward pass of the denominator: same scan structure as the numerator,
    # but the prefix sum runs over keys only and each step emits the dot
    # product of the query with that running sum.
    def favor_denominator_fwd(init_prefix_sum_value, precision, query_prime,
                              key_prime):
        def body(p, qk):
            q, k = qk
            p += k
            x = np.einsum('...m,...m->...', q, p, precision=precision)
            return p, x

        p, r = fastmath.scan(body, init_prefix_sum_value,
                             (query_prime, key_prime))
        return r, (query_prime, key_prime, p)

    # Backward pass of the denominator: reverse scan that rebuilds earlier
    # prefix sums by subtracting each key from `p` after it has been used.
    def favor_denominator_bwd(init_prefix_sum_value, precision, qkp, r_ct):
        del init_prefix_sum_value

        def body(carry, qkx):
            p, p_ct = carry
            q, k, x_ct = qkx
            q_ct = np.einsum('...,...m->...m', x_ct, p, precision=precision)
            p_ct += np.einsum('...,...m->...m', x_ct, q, precision=precision)
            k_ct = p_ct
            p -= k
            return (p, p_ct), (q_ct, k_ct)

        qs, ks, p = qkp
        _, (qs_ct, ks_ct) = fastmath.scan(body, (p, np.zeros_like(p)),
                                          (qs, ks, r_ct),
                                          reverse=True)
        return (qs_ct, ks_ct)

    # Primal denominator function: the forward pass with the residuals dropped.
    def favor_denominator(init_prefix_sum_value, precision, query_prime,
                          key_prime):
        r, _ = favor_denominator_fwd(init_prefix_sum_value, precision,
                                     query_prime, key_prime)
        return r

    favor_denominator = fastmath.custom_vjp(favor_denominator,
                                            favor_denominator_fwd,
                                            favor_denominator_bwd,
                                            nondiff_argnums=(0, 1))

    # NOTE(review): the fwd/bwd pair was already handed to custom_vjp above, so
    # this call looks redundant, and the numerator has no matching call. It
    # also assumes the object returned by fastmath.custom_vjp exposes
    # `defvjp` -- confirm for every backend before relying on it.
    favor_denominator.defvjp(favor_denominator_fwd, favor_denominator_bwd)

    # ReLU expressed through np.where: non-positive entries become zero.
    def relu(x):
        return np.where(x <= 0, np.zeros_like(x), x)

    def favor(query, key, value):
        # Shift by the stabilizer so the feature maps are strictly positive.
        query_prime = relu(query) + numerical_stabilizer
        key_prime = relu(key) + numerical_stabilizer
        # Initial carries for the two scans, sized from the first and last
        # dimensions of key/value.
        prefix_sum_tensor_shape = (key.shape[0], key.shape[-1],
                                   value.shape[-1])
        t_slice_shape = (key.shape[0], key.shape[-1])
        init_prefix_sum_value_numerator = np.zeros(prefix_sum_tensor_shape)
        init_prefix_sum_value_denominator = np.zeros(t_slice_shape)

        # Axis 1 is moved to the front before scanning and moved back after;
        # per the shape comments in the head-splitting helper this is the
        # seq_len axis -- presumably scan iterates over the leading axis.
        w = favor_numerator(init_prefix_sum_value_numerator, precision,
                            np.moveaxis(query_prime, 1, 0),
                            np.moveaxis(key_prime, 1, 0),
                            np.moveaxis(value, 1, 0))
        r = favor_denominator(init_prefix_sum_value_denominator, precision,
                              np.moveaxis(query_prime, 1, 0),
                              np.moveaxis(key_prime, 1, 0))
        w = np.moveaxis(w, 0, 1)
        r = np.moveaxis(r, 0, 1)

        # Where |r| is within the stabilizer of zero, push it away from zero
        # before taking the reciprocal.
        r = r + 2 * numerical_stabilizer * (np.abs(r) <= numerical_stabilizer)
        r = np.reciprocal(r)
        # Add a trailing axis so r broadcasts against w's last dimension.
        r = np.expand_dims(r, len(r.shape))
        renormalized_attention = w * r
        return renormalized_attention

    return cb.Serial(
        cb.Branch(
            [core.Dense(d_feature), _split_into_heads()],
            [core.Dense(d_feature), _split_into_heads()],
            [core.Dense(d_feature), _split_into_heads()],
        ),
        base.Fn('FAVOR', favor),
        _merge_heads(),
        core.Dense(d_feature),
    )
Ejemplo n.º 18
0
 def test_branch_name(self):
     """The string form of a Branch combinator should mention 'Branch'."""
     # pylint: disable=no-value-for-parameter
     branch_repr = str(cb.Branch(cb.Add(), divide_by(0.5)))
     self.assertIn('Branch', branch_repr)
Ejemplo n.º 19
0
 def test_branch_one_layer(self):
     """A Branch over a single layer reports that layer's (3, 2) output shape."""
     branch = cb.Branch(divide_by(0.5))
     actual = base.check_shape_agreement(branch, ShapeDtype((3, 2)))
     self.assertEqual(actual, (3, 2))
Ejemplo n.º 20
0
 def test_branch_add_div(self):
     """Branch(Add, divide_by) on two (3, 2) inputs yields two (3, 2) outputs."""
     branch = cb.Branch(cb.Add(), divide_by(0.5))
     two_inputs = (ShapeDtype((3, 2)), ShapeDtype((3, 2)))
     actual = base.check_shape_agreement(branch, two_inputs)
     self.assertEqual(actual, ((3, 2), (3, 2)))
Ejemplo n.º 21
0
 def test_branch_noop_dup(self):
     """Branch(no-op, Dup) on one (3, 2) input yields three (3, 2) outputs."""
     branch = cb.Branch([], cb.Dup())
     actual = base.check_shape_agreement(branch, ShapeDtype((3, 2)))
     self.assertEqual(actual, ((3, 2), (3, 2), (3, 2)))
Ejemplo n.º 22
0
 def test_branch_name(self):
   """The string form of a Branch combinator should mention 'Branch'."""
   description = str(cb.Branch(cb.Add(), divide_by(0.5)))
   self.assertIn('Branch', description)
 def test_branch_op_not_defined(self):
     """Constructing a Branch from two empty lists must raise AttributeError."""
     self.assertRaises(AttributeError, cb.Branch, [], [])
Ejemplo n.º 24
0
def RelativeAttentionLayer(d_feature,
                           total_kv_pooling,
                           n_heads=1,
                           dropout=0.0,
                           n_raw_tokens_generated=1,
                           max_inference_length=3072,
                           chunk_len=None,
                           chunk_offset=None,
                           mode='train'):
  """Returns a layer that maps (q, k, v, masks) to (activations, masks).

  When the number of keys is smaller than the number of queries, the layer
  works in O(q^2*d). Otherwise it is O(q*k*d). That is because we need to shift
  relative distances by the current pooling. When we upsample, the current
  pooling is a fraction < 1.
  Visual explanation:
  [01][23][45][67] -> [0][1][2][3][4][5][6][7]
  For token [0] we calculate relative distances as follows:
  * 0 2 4 6
  However for token [1] we need relative distances changed by 1, specifically:
  * -1 1 3 5
  So we not only need to calculate the distances that correspond to the spacing
  between the keys but also the ones in between, because there is more than
  one query token (on different positions, which means different relative
  distances) for a single key token.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    total_kv_pooling: Accumulated pool size of keys/values used at this layer.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    n_raw_tokens_generated: Number of tokens generated in a single pass through
      this layer. Used only in 'predict' non-training mode.
    max_inference_length: Maximum sequence length allowed in non-training
      modes.
    chunk_len (optional): Number of tokens per chunk. Setting this option will
      enable chunked attention.
    chunk_offset (optional): Offset for shifting chunks, for shifted chunked
      attention
    mode: One of `'train'`, `'eval'`, or `'predict'`.

  Raises:
    AssertionError: If `d_feature` is not divisible by `n_heads`.
  """
  pos_emb = PositionalEmbeddings(
      d_feature,
      total_kv_pooling,
      max_inference_length=max_inference_length,
      chunk_len=chunk_len,
      chunk_offset=chunk_offset,
      n_raw_tokens_generated=n_raw_tokens_generated,
      mode=mode)

  # Bind the layer itself (not a one-element tuple) so that `attention` can be
  # used anywhere a single layer is expected.
  attention = RelativeAttention(  # pylint: disable=no-value-for-parameter
      total_kv_pooling=total_kv_pooling,
      n_heads=n_heads,
      dropout=dropout,
      n_raw_tokens_generated=n_raw_tokens_generated,
      max_inference_length=max_inference_length,
      chunk_len=chunk_len,
      chunk_offset=chunk_offset,
      mode=mode)

  assert d_feature % n_heads == 0, (
      f'd_feature ({d_feature}) must be divisible by n_heads ({n_heads}).')
  d_head = d_feature // n_heads

  # Two separately initialized bias weights with one d_head-sized vector per
  # head; they are placed in the pipeline right before the attention layer.
  context_bias_layer = core.Weights(
      init.RandomNormalInitializer(1e-6), shape=(1, n_heads, 1, d_head))
  location_bias_layer = core.Weights(
      init.RandomNormalInitializer(1e-6), shape=(1, n_heads, 1, d_head))

  return cb.Serial(
      cb.Branch(
          # Positional embeddings followed by their own dense projection.
          cb.Serial(pos_emb, core.Dense(d_feature)),
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
          cb.Select([1])  # mask
      ),
      context_bias_layer,
      location_bias_layer,
      attention,
      core.Dense(d_feature),  # Final output projection.
  )
Ejemplo n.º 25
0
def ModularCausalAttention(
        d_feature,
        n_heads=1,
        dropout=0.0,  # pylint: disable=invalid-name
        max_inference_length=2048,
        n_modules=1,
        mode='train'):
    """Builds a causally masked multi-head self-attention layer.

    This mirrors `CausalAttention` (one pass of multi-head self-attention with
    causal rather than padding-based masking), except that the Q/K/V and
    output projections use `LocallyConnectedDense` whenever `n_modules` is
    not 1; with `n_modules == 1` a plain `Dense` layer is used.

    Args:
      d_feature: Depth/dimensionality of feature embedding.
      n_heads: Number of attention heads.
      dropout: Probabilistic rate for internal dropout applied to attention
          activations (based on query-key pairs) before dotting them with
          values.
      max_inference_length: Maximum length for inference.
      n_modules: Number of modules used in LocallyConnectedDense.
      mode: One of `'train'`, `'eval'`, or `'predict'`.

    Returns:
      A `Serial` layer: Q/K/V projection and head splitting, causal
      dot-product attention, head merging, and a final projection.

    Raises:
      ValueError: If `d_feature` is not a multiple of `n_heads`.
    """
    d_head, leftover = divmod(d_feature, n_heads)
    if leftover != 0:
        raise ValueError(
            f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
            f'of the requested number of attention heads ({n_heads}).')

    @assert_shape('bld->hlx')
    def _split_into_heads():
        """Returns a layer that folds the heads axis into the batch axis."""
        def split(x):
            n_batch, n_steps = x.shape[0], x.shape[1]
            # (batch, length, d_feature) --> (batch * n_heads, length, d_head)
            per_head = x.reshape((n_batch, n_steps, n_heads, d_head))
            heads_first = per_head.transpose((0, 2, 1, 3))
            return heads_first.reshape((n_batch * n_heads, n_steps, d_head))

        return tl.Fn('SplitIntoHeads', split)

    @assert_shape('hlx->bld')
    def _merge_heads():
        """Returns a layer that reverses `_split_into_heads`."""
        def merge(x):
            n_steps = x.shape[1]
            # (batch * n_heads, length, d_head) --> (batch, length, d_feature)
            heads_first = x.reshape((-1, n_heads, n_steps, d_head))
            per_head = heads_first.transpose((0, 2, 1, 3))
            return per_head.reshape((-1, n_steps, d_head * n_heads))

        return tl.Fn('MergeHeads', merge)

    @assert_shape('...a->...b')
    def ProcessingLayer():  # pylint: disable=invalid-name
        """Returns the projection layer selected by `n_modules`."""
        if n_modules != 1:
            assert d_feature % n_modules == 0
            return LocallyConnectedDense(n_modules, d_feature // n_modules)
        return tl.Dense(d_feature)

    # Three identical (projection, split) pipelines applied in parallel.
    qkv = cb.Branch(
        *[[ProcessingLayer(), _split_into_heads()] for _ in range(3)])
    causal_attention = tl.DotProductCausalAttention(
        dropout=dropout,
        max_inference_length=max_inference_length,
        mode=mode)
    return cb.Serial(qkv, causal_attention, _merge_heads(), ProcessingLayer())
Ejemplo n.º 26
0
def CreateAttentionMaskLayer():
    """Builds a layer that appends a causal attention mask to its inputs.

    The returned layer selects its first three inputs (q, k, v) unchanged and
    adds a fourth output: a boolean mask computed from the first two inputs
    (queries and keys). Causal attention uses such a mask to prevent a
    sequence position from attending to positions that follow it, e.g. when
    training autoregressive sequence models or decoding symbol by symbol.

    Returns:
      An attention mask layer.
    """
    def _funnel_mask(batch_size, keys_len, queries_len, funnel_factor,
                     is_upsampling):
        """Creates a lower-triangular mask adjusted for funnel attention.

        With `funnel_factor == 1` the mask is a square lower triangle of side
        `queries_len`. Otherwise each element is repeated `funnel_factor`
        times: along the last axis (starting from a `queries_len` square) when
        not upsampling, or along the second-to-last axis (starting from a
        `keys_len` square) when upsampling. After a funnel layer one token
        attends to `funnel_factor` tokens in downsampling, whereas in
        upsampling `funnel_factor` tokens attend to a single token.

        Args:
          batch_size: batch size.
          keys_len: keys length.
          queries_len: queries length.
          funnel_factor: funnel factor.
          is_upsampling: True or False.

        Returns:
          Boolean mask with leading dimensions (batch_size, 1).
        """
        if funnel_factor == 1:
            side, repeat_axis = queries_len, None
        elif is_upsampling:
            side, repeat_axis = keys_len, -2
        else:
            side, repeat_axis = queries_len, -1

        triangle = jnp.tril(jnp.ones((side, side), dtype=jnp.bool_))
        if repeat_axis is not None:
            triangle = jnp.repeat(triangle, funnel_factor, axis=repeat_axis)

        # Add batch and singleton axes, then tile over the batch.
        return jnp.repeat(triangle[None, None, :, :], batch_size, axis=0)

    def calculate_mask(queries, keys):
        """Derives the funnel mask from the shapes of `queries` and `keys`."""
        n_keys, n_queries = keys.shape[-2], queries.shape[-2]
        ratio, upsampling = calc_funnel_ratio(n_keys, n_queries)
        return _funnel_mask(queries.shape[0], n_keys, n_queries, ratio,
                            upsampling)

    passthrough = [cb.Select([0]), cb.Select([1]), cb.Select([2])]
    mask_layer = cb.Fn('create attention mask layer', calculate_mask, n_out=1)
    return cb.Branch(*passthrough, mask_layer)