Example #1
# Assumed import: this snippet uses the early trax combinator API;
# ResidualFeedForward is defined elsewhere in the same module.
from trax import layers


def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        mode):
    """Transformer encoder-decoder layer.

  The input is a triple pair (encoder, mask, decoder_input) where
  the mask is created from the original source to prevent attending
  to the padding part of the encoder.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (encoder, mask, decoder_activations).
  """
    # Decoder self-attending to decoder.
    self_attention = layers.Residual(
        layers.LayerNorm(),
        layers.Branch(),
        layers.Parallel(
            layers.Identity(),  # activation for (q, k, v)
            layers.CausalMask(axis=-2)),  # attention mask
        layers.MultiHeadedAttention(feature_depth,
                                    num_heads=num_heads,
                                    dropout=dropout,
                                    mode=mode),
        layers.Dropout(rate=dropout, mode=mode))
    # Decoder attending to encoder.
    encoder_decoder_attention = layers.Serial(
        layers.Reorder(output=((2, 0, 0), 1)),  # ((dec, enc, enc), mask)
        layers.MultiHeadedAttentionQKV(  # ((q, k, v), mask) --> new v
            feature_depth,
            num_heads=num_heads,
            dropout=dropout,
            mode=mode),
        layers.Dropout(rate=dropout, mode=mode),
    )
    return layers.Serial(
        layers.Parallel(layers.Identity(), layers.Identity(), self_attention),
        layers.Branch(),
        layers.Parallel(layers.Identity(), encoder_decoder_attention),
        layers.UnnestBranches(),  # (encoder, mask, old_act, new_act)
        layers.Reorder(output=(0, 1, (2, 3))),
        layers.Parallel(  # Residual after encoder-decoder attention.
            layers.Identity(), layers.Identity(), layers.SumBranches()),
        layers.Parallel(  # Feed-forward on the third component (decoder).
            layers.Identity(), layers.Identity(),
            ResidualFeedForward(feature_depth,
                                feedforward_depth,
                                dropout,
                                mode=mode)))
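
A minimal composition sketch (assuming the same layers API as above): each
EncoderDecoderLayer maps a (encoder, mask, decoder_activations) triple to an
updated triple, so several copies chain directly with layers.Serial, exactly
as the Transformer model in Example #2 does. The hyperparameter values below
are illustrative.

# Hypothetical hyperparameter values, matching the defaults in Example #2.
decoder_stack = layers.Serial(*[
    EncoderDecoderLayer(feature_depth=512, feedforward_depth=2048,
                        num_heads=8, dropout=0.1, mode='train')
    for _ in range(6)
])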
Example #2
# Assumed import, as in Example #1; EncoderLayer and EncoderDecoderLayer
# (the latter shown in Example #1) are defined in the same module.
from trax import layers


def Transformer(vocab_size,
                feature_depth=512,
                feedforward_depth=2048,
                num_layers=6,
                num_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Transformer.

  This model expects on input a pair (source, target).

  Args:
    vocab_size: int: vocab size (shared source and target).
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer model.
  """
    embedding = layers.Serial(layers.Embedding(feature_depth, vocab_size),
                              layers.Dropout(rate=dropout, mode=mode),
                              layers.PositionalEncoding(max_len=max_len))
    encoder = layers.Serial(
        layers.Branch(),  # Branch input to create embedding and mask.
        layers.Parallel(embedding, layers.PaddingMask()),
        layers.Serial(*[
            EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                         mode) for _ in range(num_layers)
        ]),
        layers.Parallel(layers.LayerNorm(), layers.Identity()))
    stack = [
        EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                            dropout, mode) for _ in range(num_layers)
    ]
    return layers.Serial(
        layers.Parallel(layers.Identity(), layers.ShiftRight()),
        layers.Parallel(encoder, embedding),
        layers.UnnestBranches(),  # (encoder, encoder_mask, decoder_input)
        layers.Reorder(output=(0, (1, 2), 2)),
        # (encoder_mask, decoder_input) -> encoder-decoder mask.
        layers.Parallel(
            layers.Identity(), layers.EncoderDecoderMask(), layers.Identity()),
        layers.Serial(*stack),
        layers.ThirdBranch(),
        layers.LayerNorm(),
        layers.Dense(vocab_size),
        layers.LogSoftmax())
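
A minimal construction sketch, grounded only in the signature above (the
training and decoding APIs for this trax version are not shown here, so only
model construction is illustrated; the argument values are hypothetical):

# Hypothetical vocab size and a shallow stack for illustration.
model = Transformer(vocab_size=32000, num_layers=2, mode='eval')
# `model` is a layer mapping a (source, target) pair of token-id batches
# to per-token log-probabilities over the shared vocabulary.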