def DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation):
    """Returns a list of layers that implements a Transformer decoder block.

    The input is an activation tensor.

    Args:
        d_model (int):  depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        mode (str): 'train' or 'eval'.
        ff_activation (function): the non-linearity in feed-forward layer.

    Returns:
        list: list of trax.layers.combinators.Serial that maps an activation tensor to an activation tensor.
    """

    # Return a list of two Residual blocks: causal attention (with layer norm and dropout), then the feed-forward block
    return [
        tl.Residual(
            # Normalize layer input
            tl.LayerNorm(),
            # Add causal attention
            tl.CausalAttention(d_model,
                               n_heads=n_heads,
                               dropout=dropout,
                               mode=mode)),
        tl.Residual(
            # Add feed-forward block
            # We don't need to normalize the layer inputs here. The feed-forward block takes care of that for us.
            FeedForward(d_model, d_ff, dropout, mode, ff_activation)),
    ]
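A minimal sketch of how a block like this is typically stacked into a decoder-only language model in trax is shown below. It assumes trax is installed and that the FeedForward helper used above is in scope; TransformerLMSketch, vocab_size, max_len and n_layers are illustrative placeholders, not part of the original example.

from trax import layers as tl

def TransformerLMSketch(vocab_size, d_model=512, d_ff=2048, n_layers=6,
                        n_heads=8, dropout=0.1, max_len=2048,
                        mode='train', ff_activation=tl.Relu):
    # Hypothetical wiring: embed token IDs, run them through a stack of
    # DecoderBlocks, then project back onto the vocabulary.
    return tl.Serial(
        tl.ShiftRight(mode=mode),            # shift targets right for teacher forcing
        tl.Embedding(vocab_size, d_model),   # token IDs -> d_model vectors
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
        [DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation)
         for _ in range(n_layers)],          # stack of decoder blocks
        tl.LayerNorm(),
        tl.Dense(vocab_size),                # logits over the vocabulary
        tl.LogSoftmax(),                     # log-probabilities
    )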
Example #2
def Encoder(d_model, d_ff, n_heads, dropout, mode, ff_activation):
    """Returns a list of layers that implements a Transformer decoder block.

    The input is an activation tensor.

    Args:
        d_model (int):  depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        mode (str): 'train' or 'eval'.
        ff_activation (function): the non-linearity in feed-forward layer.

    Returns:
        list: list of trax.layers.combinators.Serial that maps an activation tensor to an activation tensor.
    """
    causal_attention = tl.CausalAttention(d_model,
                                          n_heads=n_heads,
                                          dropout=dropout,
                                          mode=mode)
    feed_forward = [
        tl.LayerNorm(),
        tl.Dense(d_ff),
        ff_activation(),
        tl.Dropout(rate=dropout, mode=mode),
        tl.Dense(d_model),
        tl.Dropout(rate=dropout, mode=mode)
    ]
    return [
        tl.Residual(tl.LayerNorm(), causal_attention,
                    tl.Dropout(rate=dropout, mode=mode)),
        tl.Residual(feed_forward),
    ]
Example #3
def _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                         mode, ff_activation):
    """Returns a list of layers implementing a Transformer encoder-decoder block.

  The input is a triple (decoder_activations, mask, encoder_activations) where
  the mask is created from the original input token IDs to prevent attending to
  the padding part of the encoder.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'train'`, each block will include dropout; else, it will
        pass all values through unaltered.
    ff_activation: Type of activation function at the end of each block; must
        be an activation-type subclass of `Layer`.

  Returns:
    A list of layers which maps triples (decoder_activations, mask,
    encoder_activations) to triples of the same sort.
  """
    def _Dropout():
        return tl.Dropout(rate=dropout,
                          shared_axes=dropout_shared_axes,
                          mode=mode)

    attention_qkv = tl.AttentionQKV(d_model,
                                    n_heads=n_heads,
                                    dropout=dropout,
                                    mode=mode,
                                    cache_KV_in_predict=True)

    causal_attention = tl.CausalAttention(d_model, n_heads=n_heads, mode=mode)

    feed_forward = _FeedForwardBlock(d_model, d_ff, dropout,
                                     dropout_shared_axes, mode, ff_activation)

    return [  # vec_d masks vec_e
        tl.Residual(
            tl.LayerNorm(),  # vec_d ..... .....
            causal_attention,  # vec_d ..... .....
            _Dropout(),  # vec_d ..... .....
        ),
        tl.Residual(
            tl.LayerNorm(),  # vec_d ..... .....
            tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
            attention_qkv,  # vec_d masks vec_e
            _Dropout(),  # vec_d masks vec_e
        ),
        tl.Residual(feed_forward  # vec_d masks vec_e
                    ),
    ]
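As a quick aside, the sketch below (assuming trax and numpy are installed; the shapes are arbitrary placeholders) shows what tl.Select([0, 2, 2, 1, 2]) does to the (vec_d, masks, vec_e) stream: it copies and reorders the inputs so that AttentionQKV receives queries from the decoder and keys/values from the encoder, with the padding mask and one extra copy of the encoder activations left on the stack.

import numpy as np
from trax import layers as tl

vec_d = np.zeros((1, 3, 8))  # decoder activations (batch, dec_len, d_model)
masks = np.ones((1, 5))      # padding mask built from the source tokens
vec_e = np.ones((1, 5, 8))   # encoder activations (batch, enc_len, d_model)

select = tl.Select([0, 2, 2, 1, 2])  # (d, m, e) -> (d, e, e, m, e)
q, k, v, mask, extra_e = select((vec_d, masks, vec_e))
print(q.shape, k.shape, v.shape, mask.shape)  # (1, 3, 8) (1, 5, 8) (1, 5, 8) (1, 5)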
Example #4
  def test_simple_call(self):
    layer = tl.CausalAttention(d_feature=4, n_heads=2)
    x = np.array([[[2, 5, 3, 4],
                   [0, 1, 2, 3],
                   [0, 1, 2, 3],]])
    _, _ = layer.init(shapes.signature(x))

    y = layer(x)
    self.assertEqual(y.shape, (1, 3, 4))
Example #5
def _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                         mode, ff_activation):
    """Returns a list of layers implementing a Transformer encoder-decoder block.

  The input is a triple (decoder_input, mask, encoder) where the mask is
  created from the original source to prevent attending to the padding part
  of the encoder.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers which maps triples (decoder_activations, mask,
    encoder_activations) to triples of the same sort.
  """
    def _Dropout():
        return tl.Dropout(rate=dropout,
                          shared_axes=dropout_shared_axes,
                          mode=mode)

    attention_qkv = tl.AttentionQKV(d_model,
                                    n_heads=n_heads,
                                    dropout=dropout,
                                    mode=mode,
                                    cache_KV_in_predict=True)

    causal_attention = tl.CausalAttention(d_model, n_heads=n_heads, mode=mode)

    feed_forward = _FeedForwardBlock(d_model, d_ff, dropout,
                                     dropout_shared_axes, mode, ff_activation)

    return [  # vec_d masks vec_e
        ResidualZero(
            tl.LayerNorm(),  # vec_d ..... .....
            causal_attention,  # vec_d ..... .....
            _Dropout(),  # vec_d ..... .....
        ),
        ResidualZero(
            tl.LayerNorm(),  # vec_d ..... .....
            tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
            attention_qkv,  # vec_d masks vec_e
            _Dropout(),  # vec_d masks vec_e
        ),
        ResidualZero(
            tl.LayerNorm(),
            feed_forward,  # vec_d masks vec_e
            _Dropout(),
        ),
    ]
Example #6
def DecoderBlock(embeddingDepth, depth, n_heads, dropout, mode, ffActivation):
    return [
        tl.Residual(
            tl.LayerNorm(),
            tl.CausalAttention(embeddingDepth,
                               n_heads=n_heads,
                               dropout=dropout,
                               mode=mode)),
        tl.Residual(
            FeedForward(embeddingDepth, depth, dropout, mode, ffActivation)),
    ]
Example #7
def _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                  ff_activation):
    """Returns a list of layers that implements a Transformer decoder block.

  The input is an activation tensor.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'train'`, each block will include dropout; else, it will
        pass all values through unaltered.
    ff_activation: Type of activation function at the end of each block; must
        be an activation-type subclass of `Layer`.

  Returns:
    A list of layers that maps an activation tensor to an activation tensor.
  """
    causal_attention = tl.CausalAttention(d_model,
                                          n_heads=n_heads,
                                          dropout=dropout,
                                          mode=mode)

    feed_forward = _FeedForwardBlock(d_model, d_ff, dropout,
                                     dropout_shared_axes, mode, ff_activation)

    dropout_ = tl.Dropout(rate=dropout,
                          shared_axes=dropout_shared_axes,
                          mode=mode)

    return [
        tl.Residual(
            tl.LayerNorm(),
            causal_attention,
            dropout_,
        ),
        tl.Residual(feed_forward),
    ]
Example #8
def _DecoderBlock(d_model, d_ff, n_heads, d_attn_key, d_attn_value, attn_type,
                  dropout, share_qk, layer_idx, mode, ff_activation):
    """Returns a list of layers that implements a Transformer decoder block.

  The input is an activation tensor.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    d_attn_key: int: depth of key vector for each attention head
    d_attn_value: int: depth of value vector for each attention head
    attn_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers that maps an activation tensor to an activation tensor.
  """
    causal_attention = tl.CausalAttention(d_model,
                                          n_heads=n_heads,
                                          d_attention_key=d_attn_key,
                                          d_attention_value=d_attn_value,
                                          attention_type=attn_type,
                                          share_qk=share_qk,
                                          mode=mode)

    dropout_ = tl.Dropout(rate=dropout,
                          name='attention_%d' % layer_idx,
                          mode=mode)

    feed_forward = _FeedForwardBlock(d_model, d_ff, dropout, layer_idx, mode,
                                     ff_activation)

    return [
        tl.Residual(
            tl.LayerNorm(),
            causal_attention,
            dropout_,
        ),
        tl.Residual(feed_forward),
    ]
Example #9
def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation, mode):
    """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    pre_attention_qkv = [
        tl.LayerNorm(),
        tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
    ]
    attention_qkv = tl.AttentionQKV(d_model,
                                    n_heads=n_heads,
                                    dropout=dropout,
                                    mode=mode)
    # TODO(kitaev): BroadcastedDropout?
    post_attention_qkv = tl.Dropout(rate=dropout, mode=mode)

    pre_causal_attention = tl.LayerNorm()
    causal_attention = tl.CausalAttention(d_model, n_heads=n_heads, mode=mode)
    # TODO(kitaev): BroadcastedDropout?
    post_causal_attention = tl.Dropout(rate=dropout, mode=mode)

    feed_forward = FeedForward(d_model, d_ff, dropout, ff_activation, mode)

    return [  # vec_d1 vec_d2 masks vec_e
        # TODO(kitaev): consider ReversibleAttentionHalfResidual for efficiency
        ReversibleHalfResidual(
            [pre_causal_attention, causal_attention, post_causal_attention]),
        tl.ReversibleSwap(),
        ReversibleHalfResidual(
            [pre_attention_qkv, attention_qkv, post_attention_qkv]),
        tl.ReversibleSwap(),
        ReversibleHalfResidual(feed_forward),
        tl.ReversibleSwap(),  # vec_d1 vec_d2 masks vec_e
    ]
Example #10
def _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                  ff_activation):
    """Returns a list of layers that implements a Transformer decoder block.

  The input is an activation tensor.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers that maps an activation tensor to an activation tensor.
  """
    causal_attention = tl.CausalAttention(d_model,
                                          n_heads=n_heads,
                                          dropout=dropout,
                                          mode=mode)

    feed_forward = _FeedForwardBlock(d_model, d_ff, dropout,
                                     dropout_shared_axes, mode, ff_activation)

    dropout_ = tl.Dropout(rate=dropout,
                          shared_axes=dropout_shared_axes,
                          mode=mode)

    return [
        ResidualZero(
            tl.LayerNorm(),
            causal_attention,
            dropout_,
        ),
        ResidualZero(
            tl.LayerNorm(),
            feed_forward,
            dropout_,
        ),
    ]
Example #11
def DecoderBlock(d_model, d_ff, n_heads, d_attention_key, d_attention_value,
                 attention_type, dropout, share_qk, layer_idx, mode):
    """Returns a layer sequence that implements a Transformer decoder block.

  The input to the layer sequence is an activation tensor.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'

  Returns:
    A sequence of layers that maps an activation tensor to an activation tensor.
  """
    self_attention = [
        tl.LayerNorm(),  # vec
        tl.CausalAttention(d_model,
                           n_heads=n_heads,
                           d_attention_key=d_attention_key,
                           d_attention_value=d_attention_value,
                           attention_type=attention_type,
                           share_qk=share_qk,
                           mode=mode),
        tl.Dropout(rate=dropout, name='attention_%d' % layer_idx, mode=mode),
    ]
    feed_forward = [
        FeedForward(d_model, d_ff, dropout, layer_idx=layer_idx, mode=mode),
    ]
    return tl.Serial(
        tl.Residual(self_attention),
        tl.Residual(feed_forward),
    )
Example #12
def _DecoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode,
                  ff_activation):
    """Returns a list of layers that implements a Transformer decoder block.

  The input is an activation tensor.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers that maps an activation tensor to an activation tensor.
  """
    causal_attention = tl.CausalAttention(d_model,
                                          n_heads=n_heads,
                                          dropout=dropout,
                                          mode=mode)

    dropout_ = tl.Dropout(rate=dropout,
                          name='attention_%d' % layer_idx,
                          mode=mode)

    feed_forward = _FeedForwardBlock(d_model, d_ff, dropout, layer_idx, mode,
                                     ff_activation)

    return [
        tl.Residual(
            tl.LayerNorm(),
            causal_attention,
            dropout_,
        ),
        tl.Residual(feed_forward),
    ]
Example #13
def _CausalAttention():
    return tl.CausalAttention(d_model, n_heads=n_heads, mode=mode)
Example #14
def _CausalAttention():
    return tl.CausalAttention(d_model,
                              n_heads=n_heads,
                              dropout=dropout,
                              mode=mode)
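Finally, a minimal usage sketch (assuming trax and numpy are installed, with illustrative values for d_model, n_heads, dropout and mode) showing how such a helper plugs into a residual block and runs end to end:

import numpy as np
from trax import layers as tl
from trax import shapes

d_model, n_heads, dropout, mode = 4, 2, 0.0, 'train'

def _CausalAttention():
    return tl.CausalAttention(d_model, n_heads=n_heads, dropout=dropout, mode=mode)

block = tl.Serial(tl.Residual(tl.LayerNorm(), _CausalAttention()))
x = np.ones((1, 3, d_model), dtype=np.float32)
block.init(shapes.signature(x))  # initialize weights for the given input signature
y = block(x)
print(y.shape)  # (1, 3, 4)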