Example #1
def TransformerEncoder(vocab_size,
                       num_classes=10,
                       feature_depth=512,
                       feedforward_depth=2048,
                       num_layers=6,
                       num_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
  """Transformer encoder.

  Args:
    vocab_size: int: vocab size
    num_classes: int: number of output classes
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer encoder layer.
  """
  input_embedding = layers.Serial(
      layers.Embedding(feature_depth, vocab_size),
      layers.Dropout(rate=dropout, mode=mode),
      layers.PositionalEncoding(max_len=max_len)
  )
  return layers.Serial(
      layers.Branch(),  # Branch input to create embedding and mask.
      layers.Parallel(input_embedding, layers.PaddingMask()),
      layers.Serial(*[EncoderLayer(feature_depth, feedforward_depth, num_heads,
                                   dropout, mode)
                      for _ in range(num_layers)]),
      layers.FirstBranch(),  # Drop the mask.
      layers.LayerNorm(),
      layers.Mean(axis=1),  # Average on length.
      layers.Dense(num_classes),
      layers.LogSoftmax()
  )
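A minimal usage sketch follows. The import path and the vocabulary size are assumptions (these combinators come from the early Trax `layers` module shipped with Tensor2Tensor); adjust them to your installation.

from trax import layers  # assumed import path; older releases used tensor2tensor.trax

# Build a 6-layer encoder classifier; vocab_size=32000 is an illustrative value.
model = TransformerEncoder(vocab_size=32000,
                           num_classes=10,
                           mode='train')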
Example #2
def EncoderLayer(feature_depth,
                 feedforward_depth,
                 num_heads,
                 dropout,
                 mode):
  """Transformer encoder layer.

  The input to the encoder is a pair (embedded source, mask) where
  the mask is created from the original source to prevent attending
  to the padding part of the input.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a pair (activations, mask).
  """
  # The encoder block expects (activations, mask) as input and returns only
  # the new activations; we add the mask back to the output below.
  encoder_block = layers.Serial(
      layers.Residual(  # Attention block here.
          layers.Parallel(layers.LayerNorm(), layers.Identity()),
          layers.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                      dropout=dropout, mode=mode),
          layers.Dropout(rate=dropout, mode=mode),
          shortcut=layers.FirstBranch()
      ),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode)
  )
  # Now we add the mask back.
  return layers.Serial(
      layers.Reorder(output=((0, 1), 1)),  # (x, mask) --> ((x, mask), mask)
      layers.Parallel(encoder_block, layers.Identity())
  )
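Example #2 also calls ResidualFeedForward, which is not shown on this page. Below is a minimal sketch of such a block, reconstructed from the same combinators used above; treat it as an assumption rather than the verbatim original.

def ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode):
  """Residual feed-forward block: LayerNorm, then a two-layer MLP (sketch)."""
  # Assumption: the exact original implementation may differ slightly.
  return layers.Residual(
      layers.LayerNorm(),
      layers.Dense(feedforward_depth),
      layers.Relu(),
      layers.Dropout(rate=dropout, mode=mode),
      layers.Dense(feature_depth),
      layers.Dropout(rate=dropout, mode=mode)
  )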