Example #1
0
    def Encoder(source, source_mask):
        """Assemble the Transformer encoder stack.

        Args:
          source: layer variable: raw source sequences
          source_mask: layer variable: self-attention mask

        Returns:
          Layer variable that outputs encoded source.
        """
        # Self-attention sub-layer: the normalized input is branched four ways,
        # three identity branches (query, key, value) plus the attention mask.
        self_attention = layers.Residual(
            layers.LayerNorm(),
            layers.Branch(size=4),
            layers.Parallel(
                layers.Identity(),  # query
                layers.Identity(),  # key
                layers.Identity(),  # value
                source_mask,  # attention mask
            ),
            multi_attention,
            layers.Dropout(dropout, mode=mode),
        )
        # Feed-forward sub-layer.
        feed_forward = ResidualFeedForward(
            feature_depth, feedforward_depth, dropout, mode=mode)
        encoder_layer = layers.Serial(self_attention, feed_forward)

        # Source -> embedding -> encoder layer passed through layers.repeat
        # with num_layers -> final layer normalization.
        pipeline = (
            source,
            source_embedding_layer,
            layers.repeat(encoder_layer, num_layers),
            layers.LayerNorm(),
        )
        return layers.Serial(*pipeline)
Example #2
0
def TransformerRevnetLM(vocab_size,
                        d_feature=512,
                        d_feedforward=2048,
                        d_attention_key=64,
                        d_attention_value=64,
                        n_layers=6,
                        n_heads=8,
                        dropout=0.1,
                        max_len=2048,
                        n_chunks=32,
                        n_attention_chunks=8,
                        attention_loop_stride=0,
                        mode='train'):
    """Build a reversible, decoder-only Transformer language model.

    Args:
      vocab_size: int: vocab size
      d_feature: int:  depth of *each half* of the two-part features
      d_feedforward: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head
      d_attention_value: int: depth of value vector for each attention head
      n_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      n_chunks: int: number of chunks (must match input pipeline)
      n_attention_chunks: int: number of chunks for attention
      attention_loop_stride: int: number of query elements to compute
        attention for in parallel. Set to 0 to disable memory-efficient
        attention.
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    # Token embedding followed by positional encoding.
    # TODO(kitaev): add dropout
    positional_embedder = [
        tl.Embedding(d_feature, vocab_size),
        tl.PositionalEncoding(max_len=max_len),
    ]

    # Input side: join the chunks, shift right, embed, then duplicate the
    # activations to feed the two-part reversible stack.
    join_chunks = tl.Concatenate(n_items=n_chunks)
    shift = tl.ShiftRight()
    duplicate = tl.Dup()

    # One reversible decoder block per layer.
    decoder_blocks = []
    for _ in range(n_layers):
        decoder_blocks.append(
            DecoderBlock(d_feature, d_feedforward, d_attention_key,
                         d_attention_value, n_heads, n_attention_chunks,
                         attention_loop_stride, dropout, mode))
    reversible_stack = ReversibleSerial(decoder_blocks)

    # Output side: normalize both halves, merge them, split back into chunks
    # and map the vocabulary projection over each chunk.
    normalize_halves = tl.Parallel(tl.LayerNorm(), tl.LayerNorm())
    merge_halves = tl.Concatenate()
    split_chunks = Split(sections=n_chunks, axis=-2)  # pylint: disable=no-value-for-parameter
    output_head = Map(
        [tl.Dense(vocab_size), tl.LogSoftmax()],
        sections=n_chunks,
    )

    return tl.Model(
        join_chunks,
        shift,
        positional_embedder,
        duplicate,
        reversible_stack,
        normalize_halves,
        merge_halves,
        split_chunks,
        output_head,
    )
Example #3
0
def TransformerRevnetLM(vocab_size,
                        d_model=512,
                        d_ff=2048,
                        d_attention_key=64,
                        d_attention_value=64,
                        n_layers=6,
                        n_heads=8,
                        dropout=0.1,
                        max_len=2048,
                        n_chunks=32,
                        n_attention_chunks=8,
                        attention_type=DotProductAttention,
                        mode='train'):
    """Build a reversible, decoder-only Transformer language model.

    Args:
      vocab_size: int: vocab size
      d_model: int:  depth of *each half* of the two-part features
      d_ff: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head
      d_attention_value: int: depth of value vector for each attention head
      n_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      n_chunks: int: number of chunks (must match input pipeline)
      n_attention_chunks: int: number of chunks for attention
      attention_type: class: attention class to use, such as
        DotProductAttention.
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    # Token embedding, broadcasted dropout, then positional encoding.
    positional_embedder = [
        tl.Embedding(d_model, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        tl.PositionalEncoding(max_len=max_len),
    ]

    # Input side: join the chunks, shift right, embed, then duplicate the
    # activations to feed the two-part reversible stack.
    join_chunks = tl.Concatenate(n_items=n_chunks)
    shift = tl.ShiftRight()
    duplicate = tl.Dup()

    # One reversible decoder block per layer.
    decoder_blocks = []
    for _ in range(n_layers):
        decoder_blocks.append(
            DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
                         n_heads, n_attention_chunks, attention_type, dropout,
                         mode))
    reversible_stack = tl.ReversibleSerial(decoder_blocks)

    # Output side: normalize both halves, merge them, split back into chunks
    # and map the vocabulary projection over each chunk.
    normalize_halves = tl.Parallel(tl.LayerNorm(), tl.LayerNorm())
    merge_halves = tl.Concatenate()
    split_chunks = Split(n_sections=n_chunks, axis=-2)  # pylint: disable=no-value-for-parameter
    output_head = Map(
        [tl.Dense(vocab_size), tl.LogSoftmax()],
        n_sections=n_chunks,
    )

    return tl.Model(
        join_chunks,
        shift,
        positional_embedder,
        duplicate,
        reversible_stack,
        normalize_halves,
        merge_halves,
        split_chunks,
        output_head,
    )
Example #4
0
def Transformer(vocab_size,
                d_feature=512,
                d_feedforward=2048,
                n_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Build an encoder-decoder Transformer.

    This model expects on input a pair (source, target).

    Args:
      vocab_size: int: vocab size (shared source and target).
      d_feature: int:  depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_layers: int: number of encoder/decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer model.
    """
    # The very same embedder list is handed to both the encoder and the
    # decoder input below.
    positional_embedder = [
        tl.Embedding(d_feature, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    # Encoder: branch into (embedded source, padding mask), run the encoder
    # blocks, then normalize.
    embed_and_mask = tl.Branch(positional_embedder, tl.PaddingMask())
    encoder_blocks = [
        EncoderBlock(d_feature, d_feedforward, n_heads, dropout, mode)
        for _ in range(n_layers)
    ]
    encoder = [embed_and_mask, encoder_blocks, tl.LayerNorm()]

    shift_target = tl.Parallel([], tl.ShiftRight())
    embed_both = tl.Parallel(encoder, positional_embedder)
    rearrange = tl.Select(
        inputs=(('encoder', 'mask'), 'decoder'),
        output=('decoder', ('mask', 'decoder'), 'encoder'),
    )
    # (encoder_mask, decoder_input) -> encoder-decoder mask
    build_mask = tl.Parallel([], tl.EncoderDecoderMask(), [])
    decoder_blocks = [
        EncoderDecoder(d_feature, d_feedforward, n_heads, dropout, mode)
        for _ in range(n_layers)
    ]
    keep_decoder = tl.Select(0)  # Drop mask and encoder.

    return tl.Model(
        shift_target,
        embed_both,
        rearrange,
        build_mask,
        decoder_blocks,
        keep_decoder,
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )
def Transformer(vocab_size,
                feature_depth=512,
                feedforward_depth=2048,
                num_layers=6,
                num_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Build an encoder-decoder Transformer.

    This model expects on input a pair (source, target).

    Args:
      vocab_size: int: vocab size (shared source and target).
      feature_depth: int:  depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_layers: int: number of encoder/decoder layers
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer model.
    """
    # The same embedding layer object is used for the encoder input and for
    # the decoder input below.
    embedding = layers.Serial(
        layers.Embedding(feature_depth, vocab_size),
        layers.Dropout(rate=dropout, mode=mode),
        layers.PositionalEncoding(max_len=max_len),
    )

    encoder = layers.Serial(
        layers.Branch(),  # Branch input to create embedding and mask.
        layers.Parallel(embedding, layers.PaddingMask()),
        layers.Serial(*(
            EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                         mode)
            for _ in range(num_layers)
        )),
        layers.Parallel(layers.LayerNorm(), layers.Identity()),
    )

    decoder_stack = []
    for _ in range(num_layers):
        decoder_stack.append(
            EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                                dropout, mode))

    shift_target = layers.Parallel(layers.Identity(), layers.ShiftRight())
    embed_both = layers.Parallel(encoder, embedding)
    flatten = layers.UnnestBranches()  # (encoder, encoder_mask, decoder_input)
    rearrange = layers.Reorder(output=(0, (1, 2), 2))
    # (encoder_mask, decoder_input) -> encoder-decoder mask
    build_mask = layers.Parallel(
        layers.Identity(),
        layers.EncoderDecoderMask(),
        layers.Identity(),
    )

    return layers.Serial(
        shift_target,
        embed_both,
        flatten,
        rearrange,
        build_mask,
        layers.Serial(*decoder_stack),
        layers.ThirdBranch(),
        layers.LayerNorm(),
        layers.Dense(vocab_size),
        layers.LogSoftmax(),
    )
Example #6
0
def TransformerRevnetLM(vocab_size,
                        d_feature=512,
                        d_feedforward=2048,
                        n_layers=6,
                        n_heads=8,
                        dropout=0.1,
                        max_len=2048,
                        n_chunks=32,
                        n_attention_chunks=8,
                        mode='train'):
  """Build a reversible, decoder-only Transformer language model.

  Args:
    vocab_size: int: vocab size
    d_feature: int:  depth of *each half* of the two-part features
    d_feedforward: int: depth of feed-forward layer
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline)
    n_attention_chunks: int: number of chunks for memory-efficient attention
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  # Token embedding followed by positional encoding.
  # TODO(kitaev): dropout is disabled to save memory; when re-enabled,
  # tl.Dropout(rate=dropout, mode=mode) belongs between the two layers.
  positional_embedder = [
      tl.Embedding(d_feature, vocab_size),
      tl.PositionalEncoding(max_len=max_len),
  ]

  # Input side: join the chunks, shift right, embed, then duplicate the
  # activations to feed the two-part reversible stack.
  join_chunks = tl.Concatenate()
  shift = tl.ShiftRight()
  duplicate = Duplicate()  # pylint: disable=no-value-for-parameter

  # One reversible decoder block per layer.
  decoder_blocks = []
  for _ in range(n_layers):
    decoder_blocks.append(
        DecoderBlock(d_feature, d_feedforward, n_heads, n_attention_chunks,
                     dropout, mode))
  reversible_stack = ReversibleSerial(decoder_blocks)

  # Output side: normalize both halves, merge them, split back into chunks
  # and map the vocabulary projection over the result.
  normalize_halves = tl.Parallel(tl.LayerNorm(), tl.LayerNorm())
  merge_halves = tl.Concatenate()
  split_chunks = Split(sections=n_chunks, axis=-2)  # pylint: disable=no-value-for-parameter
  output_head = Map([tl.Dense(vocab_size), tl.LogSoftmax()])

  return tl.Model(
      join_chunks,
      shift,
      positional_embedder,
      duplicate,
      reversible_stack,
      normalize_halves,
      merge_halves,
      split_chunks,
      output_head,
  )
Example #7
0
def Transformer(vocab_size,
                feature_depth=512,
                feedforward_depth=2048,
                num_layers=6,
                num_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Build an encoder-decoder Transformer.

    This model expects on input a pair (source, target).

    Args:
      vocab_size: int: vocab size (shared source and target).
      feature_depth: int:  depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_layers: int: number of encoder/decoder layers
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer model.
    """
    # The same embedding layer object is used for the encoder input and for
    # the decoder input below.
    embedding = tl.Serial(
        tl.Embedding(feature_depth, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    )

    # Encoder: branch into (embedding, padding mask), run the encoder layers,
    # then normalize the activations while passing the mask through.
    encoder = tl.Serial(
        tl.Branch(embedding, tl.PaddingMask()),
        tl.Serial(*(
            EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                         mode)
            for _ in range(num_layers)
        )),
        tl.Parallel(tl.LayerNorm(), tl.NoOp()),
    )

    decoder_stack = []
    for _ in range(num_layers):
        decoder_stack.append(
            EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                                dropout, mode))

    shift_target = tl.Parallel(tl.NoOp(), tl.ShiftRight())
    embed_both = tl.Parallel(encoder, embedding)
    rearrange = tl.Select(
        inputs=(('encoder', 'mask'), 'decoder'),
        output=('encoder', ('mask', 'decoder'), 'decoder'),
    )
    # (encoder_mask, decoder_input) -> encoder-decoder mask
    build_mask = tl.Parallel(tl.NoOp(), tl.EncoderDecoderMask(), tl.NoOp())

    return tl.Serial(
        shift_target,
        embed_both,
        rearrange,
        build_mask,
        tl.Serial(*decoder_stack),
        tl.Select(2),  # Drop encoder and mask.
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )
Example #8
0
def TransformerLM(vocab_size,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_layers=6,
                  num_heads=8,
                  dropout=0.1,
                  max_len=2048,
                  mode='train'):
    """Build a Transformer language model (decoder part of Transformer only).

    Args:
      vocab_size: int: vocab size
      feature_depth: int:  depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_layers: int: number of encoder/decoder layers
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    pipeline = [
        # Input: shift right, embed, apply dropout, add positional encoding.
        tl.ShiftRight(),
        tl.Embedding(feature_depth, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
        # Stack of num_layers decoder layers.
        tl.Serial(*(
            DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                         mode)
            for _ in range(num_layers)
        )),
        # Output: normalize and project to log-probabilities over the vocab.
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    ]
    return tl.Serial(*pipeline)
Example #9
0
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Build one Transformer decoder layer.

    Args:
      feature_depth: int:  depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    # Self-attention block with a residual connection.
    self_attention = tl.Residual(
        tl.LayerNorm(),
        tl.Branch(tl.Copy(), tl.CausalMask(axis=-2)),  # Create mask.
        tl.MultiHeadedAttention(
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        tl.Select(0),  # Drop the mask.
        tl.Dropout(rate=dropout, mode=mode),
    )
    feed_forward = ResidualFeedForward(
        feature_depth, feedforward_depth, dropout, mode=mode)
    return tl.Serial(self_attention, feed_forward)
Example #10
0
def EncoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
    """Build the layers of one Transformer encoder block.

    The input to the encoder is a pair (embedded source, mask) where
    the mask is created from the original source to prevent attending
    to the padding part of the input.

    Args:
      d_feature: int:  depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer, returning a pair (activations, mask).
    """
    attention_layers = [
        tl.LayerNorm(),
        tl.MultiHeadedAttention(
            d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
        tl.Dropout(rate=dropout, mode=mode),
    ]
    feed_forward_layers = [
        FeedForward(d_feature, d_feedforward, dropout, mode=mode),
    ]
    # Each sub-layer group gets its own residual connection.
    sublayer_groups = (attention_layers, feed_forward_layers)
    return [tl.Residual(group) for group in sublayer_groups]
Example #11
0
def DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
    """Build the layers of one Transformer decoder block.

    Args:
      d_feature: int:  depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    attention_layers = [
        tl.LayerNorm(),
        tl.Branch([], tl.CausalMask(axis=-2)),  # Create mask.
        tl.MultiHeadedAttention(
            d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
        tl.Select(0),  # Drop mask.
        tl.Dropout(rate=dropout, mode=mode),
    ]
    feed_forward_layers = [
        FeedForward(d_feature, d_feedforward, dropout, mode=mode),
    ]
    # Each sub-layer group gets its own residual connection.
    sublayer_groups = (attention_layers, feed_forward_layers)
    return [tl.Residual(group) for group in sublayer_groups]
Example #12
0
def ChunkedDecoderLayer(d_feature, d_feedforward, n_heads, dropout,
                        chunk_selector, mode):
    """Build a Transformer decoder layer that operates on chunks.

    Args:
      d_feature: int:  depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      chunk_selector: a function from chunk number to list of chunks to
        attend.
      mode: str: 'train' or 'eval'

    Returns:
      The layers comprising a chunked decoder.
    """
    # Self-attention block; layer norm and dropout are mapped over the chunks.
    self_attention = Residual(
        tl.Map(tl.LayerNorm()),
        ChunkedCausalMultiHeadedAttention(
            d_feature,
            n_heads=n_heads,
            dropout=dropout,
            chunk_selector=chunk_selector,
            mode=mode,
        ),
        tl.Map(tl.Dropout(rate=dropout, mode=mode)),
    )
    # Residual feed-forward block mapped over the chunks.
    feed_forward = tl.Map(
        ResidualFeedForward(d_feature, d_feedforward, dropout, mode=mode))
    return [self_attention, feed_forward]
Example #13
0
def EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Build one Transformer encoder layer.

    The input to the encoder is a pair (embedded source, mask) where
    the mask is created from the original source to prevent attending
    to the padding part of the input.

    Args:
      feature_depth: int:  depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer, returning a pair (activations, mask).
    """
    # Attention block: the first element of the pair is normalized / dropped
    # out while the second is copied alongside it.
    attention_block = tl.Residual(
        tl.Parallel(tl.LayerNorm(), tl.Copy()),
        tl.MultiHeadedAttention(
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        tl.Parallel(tl.Dropout(rate=dropout, mode=mode), tl.Copy()),
    )
    feed_forward = ResidualFeedForward(
        feature_depth, feedforward_depth, dropout, mode=mode)
    # Mask added to itself in the residual, divide.
    halve_mask = tl.Div(divisor=2.0)
    return tl.Serial(attention_block, tl.Parallel(feed_forward, halve_mask))
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 n_attention_chunks, attention_type, dropout, mode):
    """Build the layers of one reversible Transformer decoder block.

    Args:
      d_model: int:  depth of embedding
      d_ff: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head
      d_attention_value: int: depth of value vector for each attention head
      n_heads: int: number of attention heads
      n_attention_chunks: int: number of chunks for attention
      attention_type: class: attention class to use, such as
        DotProductAttention.
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    # Per-branch head depths: the first two branches use the key depth and
    # the third uses the value depth (presumably query, key, value).
    head_depths = (d_attention_key, d_attention_key, d_attention_value)

    # Chunk and normalize the input, make three copies of it, and compute
    # attention heads on each copy in parallel.
    pre_attention = [
        Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(),
        tl.Dup(),
        tl.Parallel(*[
            [tl.ComputeAttentionHeads(n_heads=n_heads, d_head=depth)]
            for depth in head_depths
        ]),
    ]

    attention = attention_type(mode=mode)

    # ReversibleAttentionHalfResidual requires that post_attention be linear
    # in its input (so the backward pass can be computed without knowing the
    # input)
    post_attention = [
        tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
        Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
    ]

    feed_forward = [FeedForward(d_model, d_ff, dropout, mode=mode)]

    # Attention half-residual, swap, feed-forward half-residual, swap.
    attention_half = ReversibleAttentionHalfResidual(
        pre_attention, attention, post_attention)
    first_swap = tl.ReversibleSwap()
    feed_forward_half = ReversibleHalfResidual(feed_forward)
    second_swap = tl.ReversibleSwap()
    return [attention_half, first_swap, feed_forward_half, second_swap]
Example #15
0
def EncoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode):
    """Build the layer sequence of one Transformer encoder block.

    The input to the layer sequence is a pair, (activations, mask), where the
    mask was created from the original source tokens to prevent attending to
    the padding part of the input.

    Args:
      d_model: int:  depth of embedding
      d_ff: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      layer_idx: which layer are we at (for bookkeeping)
      mode: str: 'train' or 'eval'

    Returns:
      A sequence of layers that maps an (activations, mask) pair to an
      (activations, mask) pair.
    """
    attention_layers = [
        tl.LayerNorm(),
        tl.Attention(d_model, n_heads=n_heads, dropout=dropout, mode=mode),
        tl.Dropout(rate=dropout, name='enc_attn_dropout', mode=mode),
    ]
    feed_forward_layers = [
        FeedForward(d_model, d_ff, dropout, layer_idx=layer_idx, mode=mode),
    ]
    # Each sub-layer group gets its own residual connection.
    sublayer_groups = (attention_layers, feed_forward_layers)
    return [tl.Residual(group) for group in sublayer_groups]
Example #16
0
def EncoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
    """Build the layer sequence of one Transformer encoder block.

    The input to the layer sequence is a pair, (activations, mask), where the
    mask was created from the original source tokens to prevent attending to
    the padding part of the input.

    Args:
      d_feature: int:  depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      A sequence of layers that maps an (activations, mask) pair to an
      (activations, mask) pair.
    """
    attention_layers = [
        tl.LayerNorm(),
        tl.MultiHeadedAttention(
            d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
        tl.Dropout(rate=dropout, mode=mode),
    ]
    feed_forward_layers = [
        FeedForward(d_feature, d_feedforward, dropout, mode=mode),
    ]
    # Each sub-layer group gets its own residual connection.
    sublayer_groups = (attention_layers, feed_forward_layers)
    return [tl.Residual(group) for group in sublayer_groups]
Example #17
0
def ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode):
    """Residual feed-forward layer with normalization at start.

    Args:
      feature_depth: int: size of the second (output) Dense layer
      feedforward_depth: int: size of the first (hidden) Dense layer
      dropout: float: rate passed to both Dropout layers
      mode: str: mode passed to both Dropout layers

    Returns:
      the residual layer.
    """
    sublayers = (
        layers.LayerNorm(),
        # Hidden projection, ReLU, dropout.
        layers.Dense(feedforward_depth),
        layers.Relu(),
        layers.Dropout(rate=dropout, mode=mode),
        # Output projection, dropout.
        layers.Dense(feature_depth),
        layers.Dropout(rate=dropout, mode=mode),
    )
    return layers.Residual(*sublayers)
Example #18
0
def DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
    """Build the layer sequence of one Transformer decoder block.

    The input to the layer sequence is an activation tensor.

    Args:
      d_feature: int:  depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      A sequence of layers that maps an activation tensor to an activation
      tensor.
    """
    # Trailing comments track the values flowing between the layers.
    attention_layers = [
        tl.LayerNorm(),  # vec
        tl.Dup(),  # vec vec
        tl.Parallel([], tl.CausalMask(axis=-2)),  # vec mask
        tl.MultiHeadedAttention(
            d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
        tl.Parallel([], tl.Drop()),  # vec
        tl.Dropout(rate=dropout, mode=mode),  # vec
    ]
    feed_forward_layers = [
        FeedForward(d_feature, d_feedforward, dropout, mode=mode),
    ]
    # Each sub-layer group gets its own residual connection.
    sublayer_groups = (attention_layers, feed_forward_layers)
    return [tl.Residual(group) for group in sublayer_groups]
def ChunkedDecoderLayer(feature_depth,
                        feedforward_depth,
                        num_heads,
                        dropout,
                        chunk_selector,
                        mode):
  """Build a Transformer decoder layer that operates on chunks.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    chunk_selector: a function from chunk number to list of chunks to attend.
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  # Self-attention block; layer norm and dropout are mapped over the chunks.
  self_attention = tl.Residual(
      tl.Map(tl.LayerNorm()),
      ChunkedCausalMultiHeadedAttention(
          feature_depth,
          num_heads=num_heads,
          dropout=dropout,
          chunk_selector=chunk_selector,
          mode=mode,
      ),
      tl.Map(tl.Dropout(rate=dropout, mode=mode)),
  )
  # Residual feed-forward block mapped over the chunks.
  feed_forward = tl.Map(
      ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                          mode=mode))
  return tl.Serial(self_attention, feed_forward)
Example #20
0
def PositionLookupTransformerLM(vocab_size=128,
                                d_feature=256,
                                d_feedforward=512,
                                n_layers=3,
                                n_heads=4,
                                dropout=0.1,
                                max_len=100,
                                mode='train'):
  """Build a Transformer language model (decoder part of Transformer only).

  Args:
    vocab_size: int: vocab size
    d_feature: int:  depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_layers: int: number of layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: maximal length
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  # Keep only the first max_len rows of the module-level position table.
  positions = _POSITIONS[:max_len, :]

  # Input: shift right, embed, apply dropout, add the positional encoding.
  model_layers = [
      tl.ShiftRight(),
      tl.Embedding(d_feature, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      NewPositionalEncoding(positions=positions),
  ]
  # Decoder layers go in as a single nested list.
  model_layers.append([
      DecoderLayer(positions, d_feature, d_feedforward, n_heads, dropout, mode)
      for _ in range(n_layers)
  ])
  # Output: normalize and project to log-probabilities over the vocab.
  model_layers.extend([
      PreservePosition(tl.LayerNorm()),
      tl.Dense(vocab_size),
      tl.LogSoftmax(),
  ])
  return tl.Serial(model_layers)
Example #21
0
def DecoderLayer(positions, d_feature, d_feedforward, n_heads, dropout, mode):
    """Build the layers of one position-aware Transformer decoder layer.

    Args:
      positions: random vectors for positions
      d_feature: int:  depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    # Pass the activation through unchanged and build the causal mask from
    # its duplicate.
    activation_and_mask = tl.Parallel(
        [],  # activation for (q, k, v)
        tl.CausalMask(axis=-2),  # attention mask
    )
    # Self-attention block with a residual connection.
    self_attention = tl.Residual(
        PreservePosition(tl.LayerNorm()),
        tl.Dup(),
        activation_and_mask,
        MultiHeadedAttentionPosition(
            positions, d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
        PreservePosition(tl.Dropout(rate=dropout, mode=mode)),
    )
    feed_forward = ResidualFeedForward(
        d_feature, d_feedforward, dropout, mode=mode)
    return [self_attention, feed_forward]
Example #22
0
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Build one Transformer decoder layer.

    Args:
      feature_depth: int:  depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    # Self-attention block with a residual connection.
    self_attention = layers.Residual(
        layers.LayerNorm(),
        layers.Branch(),
        layers.Parallel(
            layers.Identity(),  # activation for (q, k, v)
            layers.CausalMask(axis=-2),  # attention mask
        ),
        layers.MultiHeadedAttention(
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        layers.Dropout(rate=dropout, mode=mode),
    )
    feed_forward = ResidualFeedForward(
        feature_depth, feedforward_depth, dropout, mode=mode)
    return layers.Serial(self_attention, feed_forward)
Example #23
0
def TransformerDecoder(d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
  """Build a Transformer decoder model.

  The input to the model is a continuous tensor. Does not shift the input to
  the right, i.e. the output for timestep t is based on inputs up to timestep
  t inclusively.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer decoder as a layer that maps from a continuous tensor to
    a continuous tensor.
  """
  # Positional encoding is applied first, then a Dense layer of size d_model.
  positional_encoding = tl.PositionalEncoding(max_len=max_len)
  input_projection = tl.Dense(d_model)  # vecs

  decoder_blocks = []  # vecs
  for _ in range(n_layers):
    decoder_blocks.append(DecoderBlock(d_model, d_ff, n_heads, dropout, mode))

  final_norm = tl.LayerNorm()  # vecs
  return tl.Model(
      positional_encoding,
      input_projection,
      decoder_blocks,
      final_norm,
  )
Example #24
0
def DecoderBlock(d_model, d_ff, n_heads, dropout, mode):
  """Builds the layer sequence for a single Transformer decoder block.

  The block consists of two residual sub-blocks applied one after the other:
  causal self-attention, then a feed-forward network. The input to the layer
  sequence is an activation tensor.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    A list of two residual layers that maps an activation tensor to an
    activation tensor.
  """
  # Self-attention branch: normalize, attend causally, then apply dropout.
  attention_norm = tl.LayerNorm()
  causal_attention = tl.BasicCausalAttention(
      d_model, n_heads=n_heads, dropout=dropout, mode=mode)
  attention_dropout = tl.Dropout(rate=dropout, mode=mode)
  attention_branch = [attention_norm, causal_attention, attention_dropout]

  # Feed-forward branch; FeedForward is defined elsewhere in the module.
  ff_branch = [FeedForward(d_model, d_ff, dropout, mode=mode)]

  return [
      tl.Residual(attention_branch),
      tl.Residual(ff_branch),
  ]
def DecoderBlock(d_feature, d_feedforward, n_heads, n_attention_chunks,
                 attention_loop_stride, dropout, mode):
    """Builds one reversible Transformer decoder layer.

    The layer is a list of four reversible pieces: an attention half-residual,
    a swap, a feed-forward half-residual, and a second swap.

    Args:
      d_feature: int:  depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      n_attention_chunks: int: number of chunks for attention
      attention_loop_stride: int: number of query elements to compute
        attention for in parallel. Set to 0 to disable memory-efficient
        attention.
      dropout: float: dropout rate (how much to drop out); used by the
        feed-forward part only, the attention layers receive dropout=None
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """

    def _head_projection():
        # One Dense projection followed by a split into attention heads; the
        # same recipe is used for each of the three branches below.
        return [
            tl.Dense(d_feature),
            SplitHeads(n_heads=n_heads),  # pylint: disable=no-value-for-parameter
        ]

    # Chunk and normalize the input, duplicate it twice to obtain three
    # copies, then project every copy independently.
    pre_attention = [
        Chunk(sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(),
        tl.Dup(),
        tl.Parallel(
            _head_projection(),
            _head_projection(),
            _head_projection(),
        ),
    ]

    # TODO(kitaev): add dropout
    use_standard_attention = attention_loop_stride < 1
    if use_standard_attention:
        # No loop stride was provided: fall back to the standard
        # implementation.
        attention = DotProductAttention(dropout=None, mode=mode)
    else:
        attention = MemoryEfficientDotProductAttention(
            loop_stride=attention_loop_stride, dropout=None, mode=mode)

    # ReversibleAttentionHalfResidual requires that post_attention be linear
    # in its input (so the backward pass can be computed without knowing the
    # input).
    post_attention = [
        JoinHeads(),  # pylint: disable=no-value-for-parameter
        tl.Dense(d_feature),
        Unchunk(sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
    ]

    feed_forward = [FeedForward(d_feature, d_feedforward, dropout, mode=mode)]

    attention_half = ReversibleAttentionHalfResidual(
        pre_attention, attention, post_attention)
    first_swap = ReversibleSwap()
    feed_forward_half = ReversibleHalfResidual(feed_forward)
    second_swap = ReversibleSwap()
    return [attention_half, first_swap, feed_forward_half, second_swap]
Example #26
0
def EncoderDecoder(d_model, d_ff, n_heads, dropout, layer_idx, mode):
    """Builds one Transformer encoder-decoder layer.

    The layer operates on a triple (decoder_input, mask, encoder) where the
    mask is created from the original source to prevent attending to the
    padding part of the encoder.

    Args:
      d_model: int:  depth of embedding
      d_ff: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      layer_idx: which layer are we at (for bookkeeping); forwarded to
        FeedForward
      mode: str: 'train' or 'eval'

    Returns:
      the layer, a list of three residual sub-blocks (causal self-attention,
      attention over the encoder activations, feed-forward), returning a
      triple (decoder_activations, mask, encoder).
    """
    # Causal self-attention over the decoder activations.
    # Stack throughout: vecs_d mask vecs_e (only vecs_d is transformed).
    self_attention_sublayers = [
        tl.LayerNorm(),
        tl.BasicCausalAttention(
            d_model, n_heads=n_heads, dropout=dropout, mode=mode),
        tl.Dropout(rate=dropout, mode=mode),
    ]

    # Attention from the decoder to the encoder. The Parallel layers rearrange
    # the stack so that AttentionQKV sees (q, k, v, mask, ...).
    encoder_attention_sublayers = [
        # vecs_d mask vecs_e
        tl.LayerNorm(),
        # vecs_d mask vecs_e vecs_e
        tl.Parallel([], [], tl.Dup()),
        # vecs_d vecs_e mask vecs_e
        tl.Parallel([], tl.Swap()),
        # vecs_d vecs_e vecs_e mask vecs_e
        tl.Parallel([], tl.Dup()),
        # (q k v masks ... --> vecs_d masks ...)
        tl.AttentionQKV(
            d_model, n_heads=n_heads, dropout=dropout, mode=mode),
        # vecs_d mask vecs_e
        tl.Dropout(rate=dropout, mode=mode),
    ]

    ff_sublayers = [
        FeedForward(d_model, d_ff, dropout, layer_idx=layer_idx, mode=mode),
    ]

    # Each residual sub-block maps (vecs_d, mask, vecs_e) to the same triple.
    residual_self_attention = tl.Residual(self_attention_sublayers)
    residual_encoder_attention = tl.Residual(encoder_attention_sublayers)
    residual_feed_forward = tl.Residual(ff_sublayers)
    return [
        residual_self_attention,
        residual_encoder_attention,
        residual_feed_forward,
    ]
Example #27
0
    def Decoder(memory, target, target_mask, memory_mask):
        """Transformer decoder stack.

        Args:
          memory: layer variable: encoded source sequences
          target: layer variable: raw target sequences
          target_mask: layer variable: self-attention mask
          memory_mask: layer variable: memory attention mask

        Returns:
          Layer variable that passes the target through the target embedding
          layer, num_layers repetitions of the decoder layer and a final
          layer normalization.
        """
        # Target attends to itself: query, key and value all come from the
        # normalized input of the block.
        self_attention_block = layers.Residual(
            layers.LayerNorm(),
            layers.Branch(size=4),
            layers.Parallel(
                layers.Identity(),  # query
                layers.Identity(),  # key
                layers.Identity(),  # value
                target_mask),  # attention mask
            multi_attention,
            layers.Dropout(dropout, mode=mode))

        # Target attends to the encoded source: key and value come from
        # memory, the query from the normalized input of the block.
        memory_attention_block = layers.Residual(
            layers.LayerNorm(),
            layers.Branch(size=4),
            layers.Parallel(
                layers.Identity(),  # query
                memory,  # key
                memory,  # value
                memory_mask),  # attention mask
            multi_attention,
            layers.Dropout(dropout, mode=mode))

        # Position-wise feed-forward network with its own residual connection.
        feed_forward_block = ResidualFeedForward(feature_depth,
                                                 feedforward_depth,
                                                 dropout,
                                                 mode=mode)

        single_layer = layers.Serial(
            self_attention_block,
            memory_attention_block,
            feed_forward_block)

        stacked_layers = layers.repeat(single_layer, num_layers)
        final_norm = layers.LayerNorm()
        return layers.Serial(
            target,
            target_embedding_layer,
            stacked_layers,
            final_norm,
        )
Example #28
0
def FeedForward(d_model, d_ff, dropout, layer_idx, mode):
    """Feed-forward block with layer normalization at start.

    Args:
      d_model: int: output depth of the final Dense layer
      d_ff: int: depth of the first (hidden) Dense layer
      dropout: float: rate passed to both Dropout layers
      layer_idx: index of the enclosing layer; formatted with %d into the
        names of the two Dropout layers
      mode: str: passed to both Dropout layers

    Returns:
      A list of layers: LayerNorm, Dense(d_ff), Relu, Dropout, Dense(d_model),
      Dropout.
    """
    input_norm = tl.LayerNorm()
    hidden = tl.Dense(d_ff)
    activation = tl.Relu()
    hidden_dropout = tl.Dropout(
        rate=dropout, name='ff_middle_%d' % layer_idx, mode=mode)
    output = tl.Dense(d_model)
    output_dropout = tl.Dropout(
        rate=dropout, name='ff_final_%d' % layer_idx, mode=mode)
    return [
        input_norm,
        hidden,
        activation,
        hidden_dropout,
        output,
        output_dropout,
    ]
Example #29
0
def FeedForward(d_feature, d_feedforward, dropout, mode):
    """Feed-forward block with layer normalization at start.

    Args:
      d_feature: int: output depth of the final Dense layer
      d_feedforward: int: depth of the first (hidden) Dense layer
      dropout: float: rate passed to both Dropout layers
      mode: str: passed to both Dropout layers

    Returns:
      A list of layers: LayerNorm, Dense(d_feedforward), Relu, Dropout,
      Dense(d_feature), Dropout.
    """
    # Expand to the hidden depth, apply the non-linearity and dropout.
    expand = [
        tl.LayerNorm(),
        tl.Dense(d_feedforward),
        tl.Relu(),
        tl.Dropout(rate=dropout, mode=mode),
    ]
    # Project back to the feature depth and apply dropout once more.
    project = [
        tl.Dense(d_feature),
        tl.Dropout(rate=dropout, mode=mode),
    ]
    return expand + project
Example #30
0
def FeedForward(d_model, d_ff, dropout, mode):
    """Feed-forward block with layer normalization at start.

    Note that the first dropout is applied before the ReLU, not after it.

    Args:
      d_model: int: output depth of the final Dense layer
      d_ff: int: depth of the first (hidden) Dense layer
      dropout: float: rate passed to both BroadcastedDropout layers
      mode: str: passed to both BroadcastedDropout layers

    Returns:
      A list of layers: LayerNorm, Dense(d_ff), BroadcastedDropout, Relu,
      Dense(d_model), BroadcastedDropout.
    """
    input_norm = tl.LayerNorm()
    hidden = tl.Dense(d_ff)
    hidden_dropout = BroadcastedDropout(rate=dropout, mode=mode)  # pylint: disable=no-value-for-parameter
    activation = tl.Relu()
    output = tl.Dense(d_model)
    output_dropout = BroadcastedDropout(rate=dropout, mode=mode)  # pylint: disable=no-value-for-parameter
    return [
        input_norm,
        hidden,
        hidden_dropout,
        activation,
        output,
        output_dropout,
    ]