Example #1
# Note: this is a nested helper; `multi_attention`, `source_embedding_layer`,
# `feature_depth`, `feedforward_depth`, `num_layers`, `dropout`, `mode` and
# `ResidualFeedForward` come from the enclosing model-builder's scope.
def Encoder(source, source_mask):
    """Transformer encoder stack.

    Args:
      source: layer variable: raw source sequences
      source_mask: layer variable: self-attention mask

    Returns:
      Layer variable that outputs the encoded source.
    """
    encoder_layer = layers.Serial(
        # Input attends to itself.
        layers.Residual(
            layers.LayerNorm(),
            layers.Branch(size=4),
            layers.Parallel(
                layers.Identity(),  # query
                layers.Identity(),  # key
                layers.Identity(),  # value
                source_mask),  # attention mask
            multi_attention,
            layers.Dropout(dropout, mode=mode)),
        # Feed-forward block.
        ResidualFeedForward(feature_depth,
                            feedforward_depth,
                            dropout,
                            mode=mode),
    )
    return layers.Serial(
        source,
        source_embedding_layer,
        layers.repeat(encoder_layer, num_layers),
        layers.LayerNorm(),
    )
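
For context, here is a rough sketch of how the enclosing model builder might define the names this helper closes over; the definitions below are assumptions composed from calls that appear elsewhere on this page, not the original surrounding code.

# Hedged sketch (assumptions): possible closure-variable definitions.
source_embedding_layer = layers.Serial(
    layers.Embedding(feature_depth, vocab_size),
    layers.Dropout(rate=dropout, mode=mode),
    layers.PositionalEncoding(max_len=max_len))
multi_attention = layers.MultiHeadedAttentionQKV(
    feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
# With those in scope, the helper wires the encoder stack over the
# source layer variables:
encoded_source = Encoder(source, source_mask)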
def Transformer(vocab_size,
                feature_depth=512,
                feedforward_depth=2048,
                num_layers=6,
                num_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Transformer.

  This model expects on input a pair (source, target).

  Args:
    vocab_size: int: vocab size (shared source and target).
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer model.
  """
    embedding = layers.Serial(layers.Embedding(feature_depth, vocab_size),
                              layers.Dropout(rate=dropout, mode=mode),
                              layers.PositionalEncoding(max_len=max_len))
    encoder = layers.Serial(
        layers.Branch(),  # Branch input to create embedding and mask.
        layers.Parallel(embedding, layers.PaddingMask()),
        layers.Serial(*[
            EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                         mode) for _ in range(num_layers)
        ]),
        layers.Parallel(layers.LayerNorm(), layers.Identity()))
    stack = [
        EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                            dropout, mode) for _ in range(num_layers)
    ]
    return layers.Serial(
        layers.Parallel(layers.Identity(), layers.ShiftRight()),
        layers.Parallel(encoder, embedding),
        layers.UnnestBranches(),  # (encoder, encoder_mask, decoder_input)
        layers.Reorder(output=(0, (1, 2), 2)),
        layers.Parallel(  # (encoder_mask, decoder_input) -> encoder-decoder mask
            layers.Identity(), layers.EncoderDecoderMask(), layers.Identity()),
        layers.Serial(*stack),
        layers.ThirdBranch(),
        layers.LayerNorm(),
        layers.Dense(vocab_size),
        layers.LogSoftmax())
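
A minimal instantiation sketch, assuming the `layers` combinator module used above is importable; the vocabulary size is a hypothetical value and the other arguments just restate the defaults.

# Sketch only: build the model in training mode.
model = Transformer(vocab_size=32000,   # hypothetical vocabulary size
                    feature_depth=512,
                    feedforward_depth=2048,
                    num_layers=6,
                    num_heads=8,
                    dropout=0.1,
                    max_len=2048,
                    mode='train')
# The model consumes a (source, target) pair of token-id batches, shifts the
# target right internally, and outputs log-probabilities over the vocabulary.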
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Transformer decoder layer.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    return layers.Serial(
        layers.Residual(  # Self-attention block.
            layers.LayerNorm(),
            layers.Branch(),
            layers.Parallel(
                layers.Identity(),  # activation for (q, k, v)
                layers.CausalMask(axis=-2)),  # attention mask
            layers.MultiHeadedAttention(feature_depth,
                                        num_heads=num_heads,
                                        dropout=dropout,
                                        mode=mode),
            layers.Dropout(rate=dropout, mode=mode)),
        ResidualFeedForward(feature_depth,
                            feedforward_depth,
                            dropout,
                            mode=mode))
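
As a sketch of how such a layer is usually combined (this helper is not part of the original example and its name is made up), a decoder-only stack can be built by repeating DecoderLayer and finishing with a LayerNorm:

def DecoderStack(feature_depth, feedforward_depth, num_heads, dropout,
                 num_layers, mode):
    """Hypothetical helper: num_layers decoder layers plus a final norm."""
    return layers.Serial(
        *[DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                       mode) for _ in range(num_layers)],
        layers.LayerNorm())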
Example #4
# Note: this is a nested helper; `multi_attention`, `target_embedding_layer`,
# `feature_depth`, `feedforward_depth`, `num_layers`, `dropout`, `mode` and
# `ResidualFeedForward` come from the enclosing model-builder's scope.
def Decoder(memory, target, target_mask, memory_mask):
    """Transformer decoder stack.

    Args:
      memory: layer variable: encoded source sequences
      target: layer variable: raw target sequences
      target_mask: layer variable: self-attention mask
      memory_mask: layer variable: memory (encoder-decoder) attention mask

    Returns:
      Layer variable that outputs the decoded target.
    """
    decoder_layer = layers.Serial(
        # Target attends to itself.
        layers.Residual(
            layers.LayerNorm(),
            layers.Branch(size=4),
            layers.Parallel(
                layers.Identity(),  # query
                layers.Identity(),  # key
                layers.Identity(),  # value
                target_mask),  # attention mask
            multi_attention,
            layers.Dropout(dropout, mode=mode)),
        # Target attends to the encoded source.
        layers.Residual(
            layers.LayerNorm(),
            layers.Branch(size=4),
            layers.Parallel(
                layers.Identity(),  # query
                memory,  # key
                memory,  # value
                memory_mask),  # attention mask
            multi_attention,
            layers.Dropout(dropout, mode=mode)),
        # Feed-forward block.
        ResidualFeedForward(feature_depth,
                            feedforward_depth,
                            dropout,
                            mode=mode))
    return layers.Serial(
        target,
        target_embedding_layer,
        layers.repeat(decoder_layer, num_layers),
        layers.LayerNorm(),
    )
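
Roughly, the enclosing builder feeds the encoder output in as `memory`; a two-line sketch under that assumption, mirroring Example #1:

# Hedged sketch: wiring the two nested helpers together.
memory = Encoder(source, source_mask)                        # encoded source
decoded_target = Decoder(memory, target, target_mask, memory_mask)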
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        mode):
    """Transformer encoder-decoder layer.

  The input is a triple pair (encoder, mask, decoder_input) where
  the mask is created from the original source to prevent attending
  to the padding part of the encoder.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (encoder, mask, decoder_activations).
  """
    # Decoder self-attending to decoder.
    self_attention = layers.Residual(
        layers.LayerNorm(),
        layers.Branch(),
        layers.Parallel(
            layers.Identity(),  # activation for (q, k, v)
            layers.CausalMask(axis=-2)),  # attention mask
        layers.MultiHeadedAttention(feature_depth,
                                    num_heads=num_heads,
                                    dropout=dropout,
                                    mode=mode),
        layers.Dropout(rate=dropout, mode=mode))
    # Decoder attending to encoder.
    encoder_decoder_attention = layers.Serial(
        layers.Reorder(output=((2, 0, 0), 1)),  # ((dec, enc, enc), mask)
        layers.MultiHeadedAttentionQKV(  # ((q, k, v), mask) --> new v
            feature_depth,
            num_heads=num_heads,
            dropout=dropout,
            mode=mode),
        layers.Dropout(rate=dropout, mode=mode),
    )
    return layers.Serial(
        layers.Parallel(layers.Identity(), layers.Identity(), self_attention),
        layers.Branch(),
        layers.Parallel(layers.Identity(), encoder_decoder_attention),
        layers.UnnestBranches(),  # (encoder, mask, old_act, new_act)
        layers.Reorder(output=(0, 1, (2, 3))),
        layers.Parallel(  # Residual after encoder-decoder attention.
            layers.Identity(), layers.Identity(), layers.SumBranches()),
        layers.Parallel(  # Feed-forward on the third component (decoder).
            layers.Identity(), layers.Identity(),
            ResidualFeedForward(feature_depth,
                                feedforward_depth,
                                dropout,
                                mode=mode)))
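
A short trace of how the (encoder, mask, decoder) triple moves through this layer; the sizes in the constructor call are illustrative only.

# Data flow (as implemented above):
#   input : (encoder_activations, encoder_mask, decoder_activations)
#   1. self_attention updates only the decoder activations (causal mask).
#   2. Branch + encoder_decoder_attention lets the decoder attend over the
#      encoder activations, with encoder_mask hiding the padding positions.
#   3. Reorder + SumBranches add that attention output back onto the decoder
#      activations (the residual), then ResidualFeedForward is applied.
#   output: (encoder_activations, encoder_mask, new_decoder_activations)
layer = EncoderDecoderLayer(512, 2048, num_heads=8, dropout=0.1, mode='train')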
Example #6
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """WideResnet convolutational block."""
    main = layers.Serial(
        layers.BatchNorm(), layers.Relu(),
        layers.Conv(channels, (3, 3), strides, padding='SAME'),
        layers.BatchNorm(), layers.Relu(),
        layers.Conv(channels, (3, 3), padding='SAME'))
    shortcut = layers.Identity() if not channel_mismatch else layers.Conv(
        channels, (3, 3), strides, padding='SAME')
    return layers.Serial(layers.Branch(), layers.Parallel(main, shortcut),
                         layers.SumBranches())
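
A typical way to compose these blocks is into a group where only the first block changes the channel count or stride; the helper below is a hypothetical sketch in that spirit, not part of the original example.

def WideResnetGroup(n, channels, strides=(1, 1)):
    """Hypothetical helper: n blocks, only the first one reshapes."""
    blocks = [WideResnetBlock(channels, strides, channel_mismatch=True)]
    blocks += [WideResnetBlock(channels, (1, 1)) for _ in range(n - 1)]
    return layers.Serial(*blocks)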
Example #7
def IdentityBlock(kernel_size, filters):
    """ResNet identical size block."""
    ks = kernel_size
    filters1, filters2, filters3 = filters
    main = layers.Serial(layers.Conv(filters1, (1, 1)), layers.BatchNorm(),
                         layers.Relu(),
                         layers.Conv(filters2, (ks, ks), padding='SAME'),
                         layers.BatchNorm(), layers.Relu(),
                         layers.Conv(filters3, (1, 1)), layers.BatchNorm())
    return layers.Serial(layers.Branch(),
                         layers.Parallel(main, layers.Identity()),
                         layers.SumBranches(), layers.Relu())
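
Because the block preserves shape, several of them can be chained after any downsampling block; the stage below is illustrative only, with example filter sizes.

# Illustrative: chain identity blocks over a 256-channel activation.
stage = layers.Serial(
    IdentityBlock(3, [64, 64, 256]),
    IdentityBlock(3, [64, 64, 256]))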
Example #8
def EncoderLayer(feature_depth,
                 feedforward_depth,
                 num_heads,
                 dropout,
                 mode):
  """Transformer encoder layer.

  The input to the encoder is a pair (embedded source, mask) where
  the mask is created from the original source to prevent attending
  to the padding part of the input.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a pair (activations, mask).
  """
  # The encoder block expects (activations, mask) as input and returns only
  # the new activations; we add the mask back to the output below.
  encoder_block = layers.Serial(
      layers.Residual(  # Attention block here.
          layers.Parallel(layers.LayerNorm(), layers.Identity()),
          layers.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                      dropout=dropout, mode=mode),
          layers.Dropout(rate=dropout, mode=mode),
          shortcut=layers.FirstBranch()
      ),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode)
  )
  # Now we add the mask back.
  return layers.Serial(
      layers.Reorder(output=((0, 1), 1)),  # (x, mask) --> ((x, mask), mask)
      layers.Parallel(encoder_block, layers.Identity())
  )