def encoder(embedded_source, source_mask):
    """Transformer encoder stack.

    Args:
      embedded_source: staxlayer variable: embedded source sequences
      source_mask: staxlayer variable: self-attention mask

    Returns:
      Staxlayer variable that outputs encoded source.
    """
    encoder_layer = stax.serial(
        # input attends to self
        stax.residual(stax.LayerNorm(feature_depth),
                      stax.multiplex(stax.Identity,  # query
                                     stax.Identity,  # key
                                     stax.Identity,  # value
                                     source_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # feed-forward
        stax.residual(stax.LayerNorm(feature_depth),
                      feed_forward,
                      stax.Dropout(dropout, mode=mode))
    )
    return stax.serial(
        embedded_source,
        stax.repeat(encoder_layer, num_layers),
        stax.LayerNorm(feature_depth),
    )


def TransformerLM(vocab_size,  # pylint: disable=invalid-name
                  mode='train',
                  num_layers=6,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_heads=8,
                  dropout=0.9,
                  max_len=256):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    mode: str: 'train' or 'eval'
    num_layers: int: number of encoder/decoder layers
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout keep probability (Stax follows TF's keep-probability convention, so the default 0.9 keeps 90% of activations)
    max_len: int: maximum symbol length for positional encoding

  Returns:
    init and apply.
  """
  # Multi-headed Attention and Feed-forward layers
  multi_attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)

  feed_forward = stax.serial(
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(dropout, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform())
  )

  # Single decoder layer
  decoder_layer = stax.serial(
      # target attends to self
      stax.residual(stax.LayerNorm(feature_depth),
                    stax.multiplex(stax.Identity,  # query
                                   stax.Identity,  # key
                                   stax.Identity,  # value
                                   stax.CausalMask(axis=-2)),  # attention mask
                    multi_attention,
                    stax.Dropout(dropout, mode=mode)),
      # feed-forward
      stax.residual(stax.LayerNorm(feature_depth),
                    feed_forward,
                    stax.Dropout(dropout, mode=mode))
  )

  return stax.serial(
      stax.ShiftRight(),
      stax.Embedding(feature_depth, vocab_size),
      stax.PositionalEncoding(feature_depth, max_len=max_len),
      stax.Dropout(dropout, mode=mode),
      stax.repeat(decoder_layer, num_layers),
      stax.LayerNorm(feature_depth),
      stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
      stax.LogSoftmax
  )
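The TransformerLM above returns the usual Stax (init, apply) pair. A minimal usage sketch, assuming the standard stax convention init_fun(rng, input_shape) -> (output_shape, params) and apply_fun(params, inputs, rng=...); the vocab size and shapes are illustrative and this Stax variant's exact signatures may differ:

# Usage sketch only; not part of the original example.
from jax import random
import numpy as onp

init_fun, apply_fun = TransformerLM(vocab_size=32000, mode='train')
rng = random.PRNGKey(0)
# Initialize parameters for batches of 8 sequences of 256 token ids.
_, params = init_fun(rng, (8, 256))
tokens = onp.zeros((8, 256), dtype=onp.int32)   # placeholder token ids
log_probs = apply_fun(params, tokens, rng=rng)  # (8, 256, vocab_size) log-probabilities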
Example #3
  def encoder(source, source_mask):
    """Transformer encoder stack.

    Args:
      source: staxlayer variable: raw source sequences
      source_mask: staxlayer variable: self-attention mask

    Returns:
      Staxlayer variable that outputs encoded source.
    """
    encoder_layer = stax.serial(
        # input attends to self
        stax.residual(stax.LayerNorm(),
                      stax.FanOut(4),
                      stax.parallel(stax.Identity,  # query
                                    stax.Identity,  # key
                                    stax.Identity,  # value
                                    source_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(keep_rate, mode=mode)),
        # feed-forward
        stax.residual(stax.LayerNorm(),
                      feed_forward,
                      stax.Dropout(keep_rate, mode=mode))
    )
    return stax.serial(
        source,
        source_embedding_layer,
        stax.repeat(encoder_layer, num_layers),
        stax.LayerNorm(),
    )
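This encoder fragment relies on names from its enclosing scope (multi_attention, feed_forward, source_embedding_layer, keep_rate, mode, num_layers). A sketch of what those enclosing definitions might look like, built only from layers used in the other examples; the names and values are illustrative, not taken from the original source:

# Illustrative enclosing-scope definitions; values are placeholders.
mode = 'train'
num_layers = 6
feature_depth = 512
feedforward_depth = 2048
num_heads = 8
keep_rate = 0.9          # Dropout here follows a keep-probability convention
source_vocab_size = 32000
max_len = 256

multi_attention = stax.MultiHeadedAttention(
    feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)

feed_forward = stax.serial(
    stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
    stax.Relu,
    stax.Dropout(keep_rate, mode=mode),
    stax.Dense(feature_depth, W_init=stax.xavier_uniform()))

source_embedding_layer = stax.serial(
    stax.Embedding(feature_depth, source_vocab_size),
    stax.PositionalEncoding(feature_depth, max_len=max_len),
    stax.Dropout(keep_rate, mode=mode))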
Example #4
def DecoderLayer(feature_depth,
                 feedforward_depth,
                 num_heads,
                 dropout,
                 mode):
  """Transformer decoder layer.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    init and apply.
  """
  return stax.serial(
      stax.residual(  # Self-attention block.
          stax.LayerNorm(),
          stax.FanOut(4),
          stax.parallel(stax.Identity,  # query
                        stax.Identity,  # key
                        stax.Identity,  # value
                        stax.CausalMask(axis=-2)),  # attention mask
          stax.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                    dropout=dropout, mode=mode),
          stax.Dropout(dropout, mode=mode)
      ),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode)
  )
Example #5
    def decoder(memory, target, target_mask, memory_mask):
        """Transformer decoder stack.

    Args:
      memory: staxlayer variable: encoded source sequences
      target: staxlayer variable: raw target sequences
      target_mask: staxlayer variable: self-attention mask
      memory_mask: staxlayer variable: memory attention mask

    Returns:
      Staxlayer variable that outputs encoded source.
    """
        decoder_layer = stax.serial(
            # target attends to self
            stax.residual(
                stax.LayerNorm(),
                stax.FanOut(4),
                stax.parallel(
                    stax.Identity,  # query
                    stax.Identity,  # key
                    stax.Identity,  # value
                    target_mask),  # attention mask
                multi_attention,
                stax.Dropout(keep_rate, mode=mode)),
            # target attends to encoded source
            stax.residual(
                stax.LayerNorm(),
                stax.FanOut(4),
                stax.parallel(
                    stax.Identity,  # query
                    memory,  # key
                    memory,  # value
                    memory_mask),  # attention mask
                multi_attention,
                stax.Dropout(keep_rate, mode=mode)),
            # feed-forward
            stax.residual(stax.LayerNorm(), feed_forward,
                          stax.Dropout(keep_rate, mode=mode)))
        return stax.serial(
            target,
            target_embedding_layer,
            stax.repeat(decoder_layer, num_layers),
            stax.LayerNorm(),
        )
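Here memory is typically the output of an encoder stack such as the one in Example #3. The wiring below is illustrative only; it assumes the staxlayer variables source, source_mask, target, target_mask, memory_mask and a target_vocab_size are already in scope, and the original source may plumb these differently:

# Illustrative wiring of the encoder and decoder fragments; not from the original source.
encoded_source = encoder(source, source_mask)          # encoder stack as in Example #3
decoded_target = decoder(encoded_source, target, target_mask, memory_mask)
log_probs = stax.serial(                               # final projection, mirroring TransformerLM
    decoded_target,
    stax.Dense(target_vocab_size, W_init=stax.xavier_uniform()),
    stax.LogSoftmax)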
Example #6
def ResidualFeedForward(feature_depth,
                        feedforward_depth,
                        dropout,
                        mode):
  """Residual feed-forward layer with normalization at start."""
  return stax.residual(
      stax.LayerNorm(),
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(dropout, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform()),
      stax.Dropout(dropout, mode=mode)
  )
Example #7
def TransformerLM(vocab_size,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_layers=6,
                  num_heads=8,
                  dropout=0.1,
                  max_len=2048,
                  mode='train'):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    init and apply.
  """
  return stax.serial(
      stax.ShiftRight(),
      stax.Embedding(feature_depth, vocab_size),
      stax.Dropout(dropout, mode=mode),
      stax.PositionalEncoding(feature_depth, max_len=max_len),
      stax.repeat(
          DecoderLayer(
              feature_depth, feedforward_depth, num_heads, dropout, mode),
          num_layers),
      stax.LayerNorm(),
      stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
      stax.LogSoftmax
  )
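Because the model ends in LogSoftmax, a next-token negative log-likelihood can be read directly off the output. A minimal sketch of a hypothetical loss helper, assuming an apply_fun obtained from TransformerLM follows the usual stax (params, inputs, rng=...) calling convention:

# Hypothetical loss helper; not part of the library.
import jax.numpy as np

def neg_log_likelihood(params, tokens, rng):
  """Mean next-token NLL; ShiftRight inside the model makes position t predict token t."""
  log_probs = apply_fun(params, tokens, rng=rng)   # (batch, length, vocab) log-probabilities
  picked = np.take_along_axis(log_probs, tokens[..., None], axis=-1)
  return -np.mean(picked)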