Example #1
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
                 n_heads, attention_type, dropout, ff_activation,
                 ff_dropout, ff_use_sru, ff_chunk_size, ff_sparsity,
                 attention_chunk_size, n_attention_layers=1,
                 n_feedforward_layers=1, center_layernorm=True,
                 use_bfloat16=False, mode='train'):
  """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_attention_layers: how many residual causal attention layers should we
      have before the feed-forward block (default: 1, the standard block)
    n_feedforward_layers: how many FFNN layers should we have (default 1).
    center_layernorm: whether to use centering in LayerNorm (default) or if
      to skip it, which is known as RMS normalization.
    use_bfloat16: whether to use bfloat16 for weights (default: False).
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  # pylint: disable=g-complex-comprehension
  attention_half_residuals = [
      [tl.ReversibleHalfResidual(
          tl.LayerNorm(center=center_layernorm),
          attention_layer=ct.ApplyAttentionLayer(
              attention_type, d_model, n_heads, d_attention_key,
              d_attention_value, True, False, dropout, dropout,
              attention_chunk_size, mode),
          name='ReversibleHalfResidualDecoderAttn'),
       tl.ReversibleSwap()
      ] for _ in range(n_attention_layers)]

  feed_forwards = [
      [tl.ReversibleHalfResidual(
          ct.FeedForwardWithOptions(
              d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
              ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm,
              mode, use_bfloat16),
          name='ReversibleHalfResidualDecoderFF'),
       tl.ReversibleSwap()
      ] for _ in range(n_feedforward_layers)]
  # pylint: enable=g-complex-comprehension
  return attention_half_residuals + feed_forwards
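A minimal usage sketch, not part of the original example: it assumes the imports used throughout these snippets (import trax.layers as tl), picks illustrative hyperparameter values, and uses tl.LSHSelfAttention as just one possible attention_type.

import trax.layers as tl

# Illustrative values only; tl.LSHSelfAttention is one attention class that
# can be passed as attention_type.
decoder_layers = DecoderBlock(
    d_model=512, d_ff=2048, d_attention_key=64, d_attention_value=64,
    n_heads=8, attention_type=tl.LSHSelfAttention, dropout=0.1,
    ff_activation=tl.Relu, ff_dropout=0.1, ff_use_sru=0, ff_chunk_size=0,
    ff_sparsity=0, attention_chunk_size=0, mode='train')
# The result is a list of reversible layers that Trax combinators such as
# tl.ReversibleSerial accept directly.
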
Example #2
def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                        ff_dropout, mode, ff_use_sru=0, ff_chunk_size=0,
                        ff_sparsity=0):
  """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate for feed-forward layer
    mode: str: 'train' or 'eval'
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity

  Returns:
    the layer.
  """
  enc_dec_attention = tl.EncDecAttention(
      n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads,
      attention_dropout=dropout, output_dropout=dropout,
      mode=mode)
  enc_dec_attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=enc_dec_attention,
  )

  causal_attention = tl.SelfAttention(
      n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads,
      causal=True,
      attention_dropout=dropout, output_dropout=dropout,
      mode=mode)
  causal_attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=causal_attention,
  )

  feed_forward = ct.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, mode)

  return [                             # vec_d1 vec_d2 vec_e masks
      causal_attention_half_residual,
      tl.ReversibleSwap(),
      enc_dec_attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
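A minimal usage sketch with illustrative values, assuming import trax.layers as tl. At run time the returned layers expect the two reversible decoder streams plus the encoder activations and mask on the stack, as the vec_d1 vec_d2 vec_e masks comment above indicates.

import trax.layers as tl

enc_dec_layers = EncoderDecoderBlock(
    d_model=512, d_ff=2048, n_heads=8, dropout=0.1,
    ff_activation=tl.Relu, ff_dropout=0.1, mode='train')
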
Example #3
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 attention_type, dropout, ff_activation, ff_dropout,
                 ff_use_sru, ff_chunk_size, ff_sparsity, attention_chunk_size,
                 mode):
    """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    attention = ct.ApplyAttentionLayer(attention_type, d_model, n_heads,
                                       d_attention_key, d_attention_value,
                                       True, False, dropout, dropout,
                                       attention_chunk_size, mode)
    attention_half_residual = tl.ReversibleHalfResidual(
        tl.LayerNorm(),
        attention_layer=attention,
    )

    feed_forward = ct.FeedForwardWithOptions(d_model, d_ff, dropout, [-2],
                                             ff_activation, ff_dropout,
                                             ff_chunk_size, ff_use_sru,
                                             ff_sparsity, mode)

    return [
        attention_half_residual,
        tl.ReversibleSwap(),
        tl.ReversibleHalfResidual(feed_forward),
        tl.ReversibleSwap(),
    ]
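The same kind of usage sketch for this variant, in which every argument is required; the values are illustrative and tl.LSHSelfAttention again stands in for any tl.BaseCausalAttention subclass.

import trax.layers as tl

decoder_layers = DecoderBlock(
    d_model=512, d_ff=2048, d_attention_key=64, d_attention_value=64,
    n_heads=8, attention_type=tl.LSHSelfAttention, dropout=0.1,
    ff_activation=tl.Relu, ff_dropout=0.1, ff_use_sru=0, ff_chunk_size=0,
    ff_sparsity=0, attention_chunk_size=0, mode='train')
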
Example #4
def EncoderBlock(d_model,
                 d_ff,
                 n_heads,
                 attention_type,
                 dropout,
                 ff_activation,
                 ff_dropout,
                 ff_use_sru=0,
                 ff_chunk_size=0,
                 ff_sparsity=0,
                 attention_chunk_size=0,
                 use_bfloat16=False,
                 mode='train'):
    """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    mode: str: 'train' or 'eval'

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
    if mode == 'predict':
        # Mode 'predict' means that the decoder should be run one token at a time.
        # The encoder only ever runs over full sequences, which is why it's switched
        # to 'eval' mode instead.
        mode = 'eval'

    attention = ct.ApplyAttentionLayer(
        attention_type=attention_type,
        d_model=d_model,
        n_heads=n_heads,
        d_qk=d_model // n_heads,
        d_v=d_model // n_heads,
        masked=True,
        causal=False,
        attention_dropout=dropout,
        output_dropout=dropout,
        attention_chunk_size=attention_chunk_size,
        mode=mode)
    # TODO(lukaszkaiser): refactor efficient attention layers to unify the API
    # If we're using standard attention, we need to pass reshaped mask and not
    # return the mask to be compatible with the EfficientAttention API.
    if attention.n_out == 2:

        def reshape_mask(mask):
            return jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))

        attention = tl.Serial(
            tl.Fn('ReshapeMask', lambda x, y: (x, reshape_mask(y)), n_out=2),
            attention, tl.Select([0], n_in=2))

    attention_half_residual = tl.ReversibleHalfResidual(
        tl.LayerNorm(),
        attention_layer=attention,
    )

    feed_forward = ct.FeedForwardWithOptions(d_model, d_ff, dropout, [-2],
                                             ff_activation, ff_dropout,
                                             ff_chunk_size, ff_use_sru,
                                             ff_sparsity, mode, use_bfloat16)

    return [
        attention_half_residual,
        tl.ReversibleSwap(),
        tl.ReversibleHalfResidual(feed_forward),
        tl.ReversibleSwap(),
    ]
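A minimal usage sketch with illustrative values, assuming import trax.layers as tl; tl.SelfAttention is just one attention class that ApplyAttentionLayer can wrap. The returned layers map an (activations, mask) pair to an (activations, mask) pair, as the docstring states.

import trax.layers as tl

encoder_layers = EncoderBlock(
    d_model=512, d_ff=2048, n_heads=8, attention_type=tl.SelfAttention,
    dropout=0.1, ff_activation=tl.Relu, ff_dropout=0.1, mode='train')
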
Example #5
# Helper defined inside a larger block constructor: d_model, d_ff, dropout and
# the other names are free variables captured from the enclosing scope, so
# each call builds a fresh, identically configured feed-forward block.
def _FF():
  return ct.FeedForwardWithOptions(d_model, d_ff, dropout, [-2],
                                   ff_activation, ff_dropout,
                                   ff_chunk_size, ff_use_sru,
                                   ff_sparsity, center_layernorm, mode,
                                   use_bfloat16)
Example #6
def EncoderBlock(d_model,
                 d_ff,
                 n_heads,
                 attention_type,
                 dropout,
                 ff_activation,
                 ff_dropout,
                 ff_use_sru=0,
                 ff_chunk_size=0,
                 ff_sparsity=0,
                 mode='train'):
    """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    mode: str: 'train' or 'eval'

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
    if mode == 'predict':
        # Mode 'predict' means that the decoder should be run one token at a time.
        # The encoder only ever runs over full sequences, which is why it's switched
        # to 'eval' mode instead.
        mode = 'eval'

    attention = configurable_transformer.ApplyAttentionLayer(
        attention_type=attention_type,
        d_model=d_model,
        n_heads=n_heads,
        d_qk=d_model // n_heads,
        d_v=d_model // n_heads,
        masked=True,
        causal=False,
        attention_dropout=dropout,
        output_dropout=dropout,
        mode=mode)
    attention_half_residual = tl.ReversibleHalfResidual(
        tl.LayerNorm(),
        attention_layer=attention,
    )

    feed_forward = configurable_transformer.FeedForwardWithOptions(
        d_model, d_ff, dropout, [-2], ff_activation, ff_dropout, ff_chunk_size,
        ff_use_sru, ff_sparsity, mode)

    return [
        attention_half_residual,
        tl.ReversibleSwap(),
        tl.ReversibleHalfResidual(feed_forward),
        tl.ReversibleSwap(),
    ]
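A minimal usage sketch for this older variant (illustrative values, assuming import trax.layers as tl), stacking two blocks the way a Reformer encoder typically does.

import trax.layers as tl

encoder_blocks = [
    EncoderBlock(d_model=512, d_ff=2048, n_heads=8,
                 attention_type=tl.SelfAttention, dropout=0.1,
                 ff_activation=tl.Relu, ff_dropout=0.1, mode='train')
    for _ in range(2)
]
# In the full model these reversible blocks are usually wrapped roughly as
# tl.Serial(tl.Dup(), tl.ReversibleSerial(encoder_blocks), ...).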