Example #1
0
def EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Transformer encoder layer.

  The input to the encoder is a pair (embedded source, mask) where
  the mask is created from the original source to prevent attending
  to the padding part of the input.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a pair (actiavtions, mask).
  """
    return tl.Serial(
        tl.Residual(  # Attention block here.
            tl.Parallel(tl.LayerNorm(), tl.Copy()),
            tl.MultiHeadedAttention(feature_depth,
                                    num_heads=num_heads,
                                    dropout=dropout,
                                    mode=mode),
            tl.Parallel(tl.Dropout(rate=dropout, mode=mode), tl.Copy())),
        tl.Parallel(
            ResidualFeedForward(feature_depth,
                                feedforward_depth,
                                dropout,
                                mode=mode),
            tl.Div(
                divisor=2.0)  # Mask added to itself in the residual, divide.
        ))
Example #2
0
def Transformer(vocab_size,
                feature_depth=512,
                feedforward_depth=2048,
                num_layers=6,
                num_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
  """Transformer.

  This model expects on input a pair (source, target).

  Args:
    vocab_size: int: vocab size (shared source and target).
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer model.
  """
  embedding = tl.Serial(
      tl.Embedding(feature_depth, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len)
  )
  encoder = tl.Serial(
      tl.Branch(embedding, tl.PaddingMask()),
      tl.Serial(*[EncoderLayer(feature_depth, feedforward_depth, num_heads,
                               dropout, mode)
                  for _ in range(num_layers)]),
      tl.Parallel(tl.LayerNorm(), tl.Copy())
  )
  stack = [EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                               dropout, mode)
           for _ in range(num_layers)]
  return tl.Serial(
      tl.Parallel(tl.Copy(), tl.ShiftRight()),
      tl.Parallel(encoder, embedding),
      tl.UnnestBranches(),  # (encoder, encoder_mask, decoder_input)
      tl.Select((0, (1, 2), 2)),
      tl.Parallel(  # (encoder_mask, decoder_input) -> encoder-decoder mask
          tl.Copy(), tl.EncoderDecoderMask(), tl.Copy()),
      tl.Serial(*stack),
      tl.Select(2),  # Drop encoder and mask.
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
Example #3
0
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Transformer decoder layer.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    return tl.Serial(
        tl.Residual(  # Self-attention block.
            tl.LayerNorm(),
            tl.Branch(tl.Copy(), tl.CausalMask(axis=-2)),  # Create mask.
            tl.MultiHeadedAttention(feature_depth,
                                    num_heads=num_heads,
                                    dropout=dropout,
                                    mode=mode),
            tl.Select(0),  # Drop the mask.
            tl.Dropout(rate=dropout, mode=mode)),
        ResidualFeedForward(feature_depth,
                            feedforward_depth,
                            dropout,
                            mode=mode))
Example #4
0
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """WideResnet convolutational block."""
    main = tl.Serial(tl.BatchNorm(), tl.Relu(),
                     tl.Conv(channels, (3, 3), strides, padding='SAME'),
                     tl.BatchNorm(), tl.Relu(),
                     tl.Conv(channels, (3, 3), padding='SAME'))
    shortcut = tl.Copy() if not channel_mismatch else tl.Conv(
        channels, (3, 3), strides, padding='SAME')
    return tl.Residual(main, shortcut=shortcut)
Example #5
0
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        mode):
    """Transformer encoder-decoder layer.

  The input is a triple pair (encoder, mask, decoder_input) where
  the mask is created from the original source to prevent attending
  to the padding part of the encoder.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (encoder, mask, decoder_activations).
  """
    # Decoder self-attending to decoder.
    self_attention = tl.Residual(
        tl.LayerNorm(),
        tl.Branch(tl.Copy(), tl.CausalMask(axis=-2)),  # create mask
        tl.MultiHeadedAttention(feature_depth,
                                num_heads=num_heads,
                                dropout=dropout,
                                mode=mode),
        tl.Select(0),  # drop mask
        tl.Dropout(rate=dropout, mode=mode))
    # Decoder attending to encoder.
    encoder_decoder_attention = tl.Serial(
        tl.Select(((2, 0, 0), 1)),  # ((dec, enc, enc), mask)
        tl.MultiHeadedAttentionQKV(  # ((q, k, v), mask) --> new, mask
            feature_depth,
            num_heads=num_heads,
            dropout=dropout,
            mode=mode),
        tl.Select(0),  # drop the mask
        tl.Dropout(rate=dropout, mode=mode),
    )
    return tl.Serial(
        tl.Parallel(tl.Copy(), tl.Copy(), self_attention),
        tl.Branch(tl.Copy(), encoder_decoder_attention),
        tl.UnnestBranches(),  # (encoder, mask, old_act, new_act)
        tl.Select((0, 1, (2, 3))),
        tl.Parallel(  # Residual after encoder-decoder attention.
            tl.Copy(), tl.Copy(), tl.Add()),
        tl.Parallel(  # Feed-forward on the third component (decoder).
            tl.Copy(), tl.Copy(),
            ResidualFeedForward(feature_depth,
                                feedforward_depth,
                                dropout,
                                mode=mode)))
def ChunkedCausalMultiHeadedAttention(
    feature_depth, num_heads=8, dropout=0.0, chunk_selector=None, mode='train'):
  """Transformer-style causal multi-headed attention operating on chunks.

  Accepts inputs that are a list of chunks and applies causal attention.

  Args:
    feature_depth: int:  depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    chunk_selector: a function from chunk number to list of chunks to attend.
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
  prepare_attention_input = tl.Serial(
      tl.Branch(
          tl.Branch(  # q = k = v = first input
              tl.Copy(), tl.Copy(), tl.Copy()),
          tl.CausalMask(axis=-2),
      ),
      tl.Parallel(
          tl.Parallel(
              tl.Dense(feature_depth),
              tl.Dense(feature_depth),
              tl.Dense(feature_depth),
          ),
          tl.Copy()
      )
  )
  return tl.Serial(
      tl.Map(prepare_attention_input),
      ChunkedAttentionSelector(selector=chunk_selector),  # pylint: disable=no-value-for-parameter
      tl.Map(tl.PureMultiHeadedAttention(
          feature_depth=feature_depth, num_heads=num_heads,
          dropout=dropout, mode=mode), check_shapes=False),
      tl.Map(tl.Select(0), check_shapes=False),  # drop masks
      tl.Map(tl.Dense(feature_depth))
  )