Example #1
  def Decoder(memory, target, target_mask, memory_mask):
    """Transformer decoder stack.

    Args:
      memory: staxlayer variable: encoded source sequences
      target: staxlayer variable: raw target sequences
      target_mask: staxlayer variable: self-attention mask
      memory_mask: staxlayer variable: memory attention mask

    Returns:
      Staxlayer variable that outputs the decoded target sequence.
    """
    decoder_layer = stax.serial(
        # target attends to self
        stax.residual(stax.LayerNorm(),
                      stax.FanOut(4),
                      stax.parallel(stax.Identity,  # query
                                    stax.Identity,  # key
                                    stax.Identity,  # value
                                    target_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # target attends to encoded source
        stax.residual(stax.LayerNorm(),
                      stax.FanOut(4),
                      stax.parallel(stax.Identity,  # query
                                    memory,  # key
                                    memory,  # value
                                    memory_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # feed-forward
        ResidualFeedForward(
            feature_depth, feedforward_depth, dropout, mode=mode)
    )
    return stax.serial(
        target,
        target_embedding_layer,
        stax.repeat(decoder_layer, num_layers),
        stax.LayerNorm(),
    )
Example #2
def IdentityBlock(kernel_size, filters):
    """ResNet identical size block."""
    ks = kernel_size
    filters1, filters2 = filters

    def MakeMain(input_shape):
        # the number of output channels depends on the number of input channels
        return stax.serial(stax.Conv(filters1, (1, 1)), stax.BatchNorm(),
                           stax.Relu,
                           stax.Conv(filters2, (ks, ks), padding='SAME'),
                           stax.BatchNorm(), stax.Relu,
                           stax.Conv(input_shape[3], (1, 1)), stax.BatchNorm())

    main = stax.shape_dependent(MakeMain)
    return stax.serial(stax.FanOut(2), stax.parallel(main, stax.Identity),
                       stax.FanInSum, stax.Relu)
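
A minimal usage sketch for the block above, assuming the stax module in these
examples is API-compatible with jax.example_libraries.stax (formerly
jax.experimental.stax), where every layer is an (init_fun, apply_fun) pair and
inputs are NHWC; the 256-channel input shape is an illustrative assumption:

import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax  # serial, Conv, BatchNorm, Relu, ...

# The main path's last conv takes its output channels from the input
# (256 here), so the skip connection can be a plain identity.
init_fun, apply_fun = IdentityBlock(3, (64, 64))
rng = random.PRNGKey(0)
out_shape, params = init_fun(rng, (1, 56, 56, 256))  # NHWC input shape
x = jnp.zeros((1, 56, 56, 256))
y = apply_fun(params, x)  # y.shape == x.shape == (1, 56, 56, 256)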
Example #3
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10):
    """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

    Args:
      num_blocks: int, number of blocks in a group.
      hidden_size: the size of the first hidden layer (multiplied later).
      num_output_classes: int, number of classes to distinguish.

    Returns:
      The WideResnet model with given layer and output sizes.
    """
    return stax.serial(stax.Conv(hidden_size, (3, 3), padding='SAME'),
                       WideResnetGroup(num_blocks, hidden_size),
                       WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
                       WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
                       stax.BatchNorm(), stax.Relu, stax.AvgPool((8, 8)),
                       stax.Flatten(), stax.Dense(num_output_classes),
                       stax.LogSoftmax)
Example #4
def Resnet50(hidden_size=64, num_output_classes=1001, mode='train'):
    """ResNet.

    Args:
      hidden_size: the size of the first hidden layer (multiplied later).
      num_output_classes: how many classes to distinguish.
      mode: whether we are training or evaluating or doing inference.

    Returns:
      The ResNet model with the given layer and output sizes.
    """
    del mode
    return stax.serial(
        stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
        stax.BatchNorm(), stax.Relu,
        stax.MaxPool((3, 3), strides=(2, 2)),
        ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)),
        IdentityBlock(3, [hidden_size, hidden_size]),
        IdentityBlock(3, [hidden_size, hidden_size]),
        ConvBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
        ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
        ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]),
        IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]),
        stax.AvgPool((7, 7)), stax.Flatten(), stax.Dense(num_output_classes),
        stax.LogSoftmax)
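
ConvBlock, used above together with the IdentityBlock from Example #2, is not
shown on this page. The following is a hypothetical sketch, inferred only from
its call sites here (a kernel size, three filter sizes, and strides): the same
bottleneck layout as Example #2, but with a strided 1x1 projection on the
shortcut so the block can change resolution and channel count.

def ConvBlock(kernel_size, filters, strides=(2, 2)):
    """Hypothetical ResNet block with a strided 1x1 projection shortcut."""
    ks = kernel_size
    filters1, filters2, filters3 = filters
    main = stax.serial(stax.Conv(filters1, (1, 1), strides), stax.BatchNorm(),
                       stax.Relu,
                       stax.Conv(filters2, (ks, ks), padding='SAME'),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(filters3, (1, 1)), stax.BatchNorm())
    shortcut = stax.serial(stax.Conv(filters3, (1, 1), strides),
                           stax.BatchNorm())
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                       stax.FanInSum, stax.Relu)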
Example #5
def TransformerLM(vocab_size,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_layers=6,
                  num_heads=8,
                  dropout=0.1,
                  max_len=2048,
                  mode='train'):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    init and apply.
  """
  return stax.serial(
      stax.ShiftRight(),
      stax.Embedding(feature_depth, vocab_size),
      stax.Dropout(dropout, mode=mode),
      stax.PositionalEncoding(feature_depth, max_len=max_len),
      stax.repeat(
          DecoderLayer(
              feature_depth, feedforward_depth, num_heads, dropout, mode),
          num_layers),
      stax.LayerNorm(),
      stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
      stax.LogSoftmax
  )
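
A usage sketch for the (init, apply) pair returned above. The
init_fun(rng, input_shape) signature follows the jax stax convention and is an
assumption here; the apply call with an explicit rng mirrors the
top_apply(params, fake_inputs, rng=...) call in Example #10 further down.
Vocabulary size and sequence shape are illustrative.

import numpy as np
from jax import random

init_fun, apply_fun = TransformerLM(vocab_size=32000, mode='train')
rng = random.PRNGKey(0)
_, params = init_fun(rng, (1, 256))             # (batch, sequence_length)
tokens = np.zeros((1, 256), dtype=np.int32)     # integer token ids
log_probs = apply_fun(params, tokens, rng=rng)  # per-position log-probs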
Example #6
def lambda_fun3(x, y, z, w, v):
    input_tree = _build_combinator_tree(tree_spec, (x, y, z))
    return stax.serial(input_tree, stax.FanOut(3),
                       stax.parallel(w, v, stax.Identity),
                       stax.FanInSum)
Example #7
def WideResnetGroup(n, channels, strides=(1, 1)):
    """A group of n WideResnet blocks; only the first may change its shape."""
    blocks = []
    blocks += [WideResnetBlock(channels, strides, channel_mismatch=True)]
    for _ in range(n - 1):
        blocks += [WideResnetBlock(channels, (1, 1))]
    return stax.serial(*blocks)
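
WideResnetBlock, called above, is not defined on this page. A hypothetical
sketch consistent with how it is called here (channels, strides,
channel_mismatch), written as a pre-activation residual block over the same
stax combinators:

def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """Hypothetical pre-activation WideResnet block."""
    main = stax.serial(stax.BatchNorm(), stax.Relu,
                       stax.Conv(channels, (3, 3), strides, padding='SAME'),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(channels, (3, 3), padding='SAME'))
    # If channels or resolution change, the skip path needs a matching
    # (strided) convolution; otherwise it can stay the identity.
    shortcut = (stax.Identity if not channel_mismatch else
                stax.Conv(channels, (3, 3), strides, padding='SAME'))
    return stax.serial(stax.FanOut(2), stax.parallel(main, shortcut),
                       stax.FanInSum)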
Example #8
def generator(encoded_target):
  return stax.serial(
      encoded_target,
      stax.Dense(target_vocab_size, W_init=stax.xavier_uniform()),
      stax.LogSoftmax
  )
Example #9
def TransformerEncoder(mode='train',  # pylint: disable=invalid-name
                       num_layers=6,
                       feature_depth=512,
                       feedforward_depth=2048,
                       num_heads=8,
                       dropout=0.9):
  """Transformer Encoder Stack.

  Args:
    mode: str: 'train' or 'eval'
    num_layers: int: number of encoder/decoder layers
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: keep probability for dropout layers; this Stax follows
      TF's KEEP-probability convention, hence the 0.9 default.

  Returns:
    A staxlayer for implementing a raw Transformer encoder stack.  No embedding
    or positional signals are added by this layer.
  """
  # Multi-headed Attention and Feed-forward layers
  multi_attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)

  feed_forward = stax.serial(
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(dropout, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform())
  )

  @stax.Lambda
  def encoder(embedded_source, source_mask):
    """Transformer encoder stack.

    Args:
      embedded_source: staxlayer variable: embedded source sequences
      source_mask: staxlayer variable: self-attention mask

    Returns:
      Staxlayer variable that outputs encoded source.
    """
    encoder_layer = stax.serial(
        # input attends to self
        stax.residual(stax.LayerNorm(feature_depth),
                      stax.multiplex(stax.Identity,  # query
                                     stax.Identity,  # key
                                     stax.Identity,  # value
                                     source_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # feed-forward
        stax.residual(stax.LayerNorm(feature_depth),
                      feed_forward,
                      stax.Dropout(dropout, mode=mode))
    )
    return stax.serial(
        embedded_source,
        stax.repeat(encoder_layer, num_layers),
        stax.LayerNorm(feature_depth),
    )

  return encoder
Example #10
def Transformer(source_vocab_size,  # pylint: disable=invalid-name
                target_vocab_size,
                mode='train',
                num_layers=6,
                feature_depth=512,
                feedforward_depth=2048,
                num_heads=8,
                dropout=0.9,
                shared_embedding=True,
                max_len=200,
                return_evals=False):
  """Transformer model.

  Args:
    source_vocab_size: int: source vocab size
    target_vocab_size: int: target vocab size
    mode: str: 'train' or 'eval'
    num_layers: int: number of encoder/decoder layers
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: keep probability for dropout layers; this Stax follows
      TF's KEEP-probability convention, hence the 0.9 default.
    shared_embedding: bool: specify whether source/target embeddings are tied.
    max_len: int: maximum symbol length for positional encoding
    return_evals: bool: whether to generate decode-time evaluation functions

  Returns:
    A namedtuple containing the model's 'init' and 'apply' functions for
    training, and an 'evals' function that itself returns a namedtuple of
    evaluation functions for the trained encoder, decoder, and generator
    substax.
  """

  # Input embedding and positional encoding
  inject_position = stax.serial(
      stax.PositionalEncoding(feature_depth, max_len=max_len),
      stax.Dropout(dropout, mode=mode)
  )
  if shared_embedding:
    assert source_vocab_size == target_vocab_size
    # Weight-shared Embedding
    embedding = stax.Share(stax.Embedding(feature_depth, source_vocab_size))
    source_embedding_layer = stax.serial(embedding, inject_position)
    target_embedding_layer = source_embedding_layer
  else:
    source_embedding = stax.Embedding(feature_depth, source_vocab_size)
    target_embedding = stax.Embedding(feature_depth, target_vocab_size)
    source_embedding_layer = stax.serial(source_embedding, inject_position)
    target_embedding_layer = stax.serial(target_embedding, inject_position)

  # Multi-headed Attention and Feed-forward layers
  multi_attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)

  feed_forward = stax.serial(
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(dropout, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform())
  )

  # Encoder
  @stax.Lambda
  def encoder(source, source_mask):
    """Transformer encoder stack.

    Args:
      source: staxlayer variable: raw source sequences
      source_mask: staxlayer variable: self-attention mask

    Returns:
      Staxlayer variable that outputs encoded source.
    """
    encoder_layer = stax.serial(
        # input attends to self
        stax.residual(stax.LayerNorm(feature_depth),
                      stax.multiplex(stax.Identity,  # query
                                     stax.Identity,  # key
                                     stax.Identity,  # value
                                     source_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # feed-forward
        stax.residual(stax.LayerNorm(feature_depth),
                      feed_forward,
                      stax.Dropout(dropout, mode=mode))
    )
    return stax.serial(
        source,
        source_embedding_layer,
        stax.repeat(encoder_layer, num_layers),
        stax.LayerNorm(feature_depth),
    )

  # Decoder
  @stax.Lambda
  def decoder(memory, target, target_mask, memory_mask):
    """Transformer decoder stack.

    Args:
      memory: staxlayer variable: encoded source sequences
      target: staxlayer variable: raw target sequences
      target_mask: staxlayer variable: self-attention mask
      memory_mask: staxlayer variable: memory attention mask

    Returns:
      Staxlayer variable that outputs the decoded target sequence.
    """
    decoder_layer = stax.serial(
        # target attends to self
        stax.residual(stax.LayerNorm(feature_depth),
                      stax.multiplex(stax.Identity,  # query
                                     stax.Identity,  # key
                                     stax.Identity,  # value
                                     target_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # target attends to encoded source
        stax.residual(stax.LayerNorm(feature_depth),
                      stax.multiplex(stax.Identity,  # query
                                     memory,  # key
                                     memory,  # value
                                     memory_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # feed-forward
        stax.residual(stax.LayerNorm(feature_depth),
                      feed_forward,
                      stax.Dropout(dropout, mode=mode))
    )
    return stax.serial(
        target,
        target_embedding_layer,
        stax.repeat(decoder_layer, num_layers),
        stax.LayerNorm(feature_depth),
    )

  # The Transformer
  @stax.Lambda
  def transformer(source, target, source_mask, target_mask, memory_mask):
    encoded_source = encoder(source, source_mask)
    return decoder(encoded_source, target, target_mask, memory_mask)

  # Finally, bind the generator transform to use later for inference.
  @stax.Lambda
  def generator(encoded_target):
    return stax.serial(
        encoded_target,
        stax.Dense(target_vocab_size, W_init=stax.xavier_uniform()),
        stax.LogSoftmax
    )

  # Model-Building and Evaluation Functions
  # Get entire model's init and apply pair
  top_init, top_apply = generator(transformer)

  # By default act as a normal Stax constructor and emit an (init, apply) pair.
  if not return_evals:
    return (top_init, top_apply)
  else:
    # Inference-time function for binding trained params to model and returning
    # the python-bound sub-expressions for evaluation and sequence generation.
    def make_namedtuple(**kwargs):
      return collections.namedtuple('Model', kwargs.keys())(**kwargs)

    def get_evals(params):
      # We need to feed _concrete_ trained parameters through the network once.
      # Otherwise the bound parameters point to abstract tracer values.
      # The input values don't matter, but we need one array per input: a
      # 5-tuple for (source, target, source_mask, target_mask, memory_mask).
      fake_inputs = 5 * (np.ones((1), dtype=np.int32),)
      fake_key = random.PRNGKey(1)
      top_apply(params, fake_inputs, rng=fake_key)
      # We can now return eval functions from the bound pieces of the model.
      return make_namedtuple(
          encoder=stax.make_apply_fun(encoder),
          generator=stax.make_apply_fun(generator),
          decoder=stax.make_apply_fun(decoder),
      )

    # We return the functions needed to train and evaluate the Transformer.
    return make_namedtuple(
        init=top_init,
        apply=top_apply,
        evals=get_evals,
    )
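
A short sketch of how the return_evals=True path above is meant to be
consumed. The namedtuple fields come directly from make_namedtuple and
get_evals in this example; the vocabulary sizes and the downstream use of the
eval functions are assumptions.

model = Transformer(source_vocab_size=32000, target_vocab_size=32000,
                    mode='eval', return_evals=True)
init_fun, apply_fun = model.init, model.apply
# ... train (or restore) concrete params with init_fun / apply_fun ...
evals = model.evals(params)  # namedtuple with .encoder, .decoder, .generator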
Example #11
def TransformerLM(vocab_size,  # pylint: disable=invalid-name
                  mode='train',
                  num_layers=6,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_heads=8,
                  dropout=0.1,
                  max_len=2048):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    mode: str: 'train' or 'eval'
    num_layers: int: number of decoder layers
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding

  Returns:
    init and apply.
  """
  keep_rate = 1.0 - dropout
  # Multi-headed Attention and Feed-forward layers
  multi_attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)

  feed_forward = stax.serial(
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(keep_rate, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform())
  )

  # Single decoder layer
  decoder_layer = stax.serial(
      # target attends to self
      stax.residual(stax.LayerNorm(),
                    stax.FanOut(4),
                    stax.parallel(stax.Identity,  # query
                                  stax.Identity,  # key
                                  stax.Identity,  # value
                                  stax.CausalMask(axis=-2)),  # attention mask
                    multi_attention,
                    stax.Dropout(keep_rate, mode=mode)),
      # feed-forward
      stax.residual(stax.LayerNorm(),
                    feed_forward,
                    stax.Dropout(keep_rate, mode=mode))
  )

  return stax.serial(
      stax.ShiftRight(),
      stax.Embedding(feature_depth, vocab_size),
      stax.Dropout(keep_rate, mode=mode),
      stax.PositionalEncoding(feature_depth, max_len=max_len),
      stax.repeat(decoder_layer, num_layers),
      stax.LayerNorm(),
      stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
      stax.LogSoftmax
  )
Example #12
def Transformer(source_vocab_size,
                target_vocab_size,
                mode='train',
                num_layers=6,
                feature_depth=512,
                feedforward_depth=2048,
                num_heads=8,
                dropout=0.1,
                shared_embedding=True,
                max_len=200,
                return_evals=False):
  """Transformer model.

  Args:
    source_vocab_size: int: source vocab size
    target_vocab_size: int: target vocab size
    mode: str: 'train' or 'eval'
    num_layers: int: number of encoder/decoder layers
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    shared_embedding: bool: specify whether source/target embeddings are tied.
    max_len: int: maximum symbol length for positional encoding
    return_evals: bool: whether to generate decode-time evaluation functions

  Returns:
    A namedtuple containing the model's 'init' and 'apply' functions for
    training, and an 'evals' function that itself returns a namedtuple of
    evaluation functions for the trained encoder, decoder, and generator
    substax.
  """
  # Input embedding and positional encoding
  inject_position = stax.serial(
      stax.Dropout(dropout, mode=mode),
      stax.PositionalEncoding(feature_depth, max_len=max_len)
  )
  if shared_embedding:
    assert source_vocab_size == target_vocab_size
    # Weight-shared Embedding
    embedding = stax.Share(stax.Embedding(feature_depth, source_vocab_size))
    source_embedding_layer = stax.serial(embedding, inject_position)
    target_embedding_layer = source_embedding_layer
  else:
    source_embedding = stax.Embedding(feature_depth, source_vocab_size)
    target_embedding = stax.Embedding(feature_depth, target_vocab_size)
    source_embedding_layer = stax.serial(source_embedding, inject_position)
    target_embedding_layer = stax.serial(target_embedding, inject_position)

  # Multi-headed Attention and Feed-forward layers
  multi_attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)

  # Encoder
  @stax.Lambda
  def Encoder(source, source_mask):
    """Transformer encoder stack.

    Args:
      source: staxlayer variable: raw source sequences
      source_mask: staxlayer variable: self-attention mask

    Returns:
      Staxlayer variable that outputs encoded source.
    """
    encoder_layer = stax.serial(
        # input attends to self
        stax.residual(stax.LayerNorm(),
                      stax.FanOut(4),
                      stax.parallel(stax.Identity,  # query
                                    stax.Identity,  # key
                                    stax.Identity,  # value
                                    source_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # feed-forward
        ResidualFeedForward(
            feature_depth, feedforward_depth, dropout, mode=mode),
    )
    return stax.serial(
        source,
        source_embedding_layer,
        stax.repeat(encoder_layer, num_layers),
        stax.LayerNorm(),
    )

  # Decoder
  @stax.Lambda
  def Decoder(memory, target, target_mask, memory_mask):
    """Transformer decoder stack.

    Args:
      memory: staxlayer variable: encoded source sequences
      target: staxlayer variable: raw target sequences
      target_mask: staxlayer variable: self-attention mask
      memory_mask: staxlayer variable: memory attention mask

    Returns:
      Staxlayer variable that outputs the decoded target sequence.
    """
    decoder_layer = stax.serial(
        # target attends to self
        stax.residual(stax.LayerNorm(),
                      stax.FanOut(4),
                      stax.parallel(stax.Identity,  # query
                                    stax.Identity,  # key
                                    stax.Identity,  # value
                                    target_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # target attends to encoded source
        stax.residual(stax.LayerNorm(),
                      stax.FanOut(4),
                      stax.parallel(stax.Identity,  # query
                                    memory,  # key
                                    memory,  # value
                                    memory_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # feed-forward
        ResidualFeedForward(
            feature_depth, feedforward_depth, dropout, mode=mode)
    )
    return stax.serial(
        target,
        target_embedding_layer,
        stax.repeat(decoder_layer, num_layers),
        stax.LayerNorm(),
    )

  # The Transformer
  @stax.Lambda
  def transformer(source, target, source_mask, target_mask, memory_mask):  # pylint: disable=invalid-name
    encoded_source = Encoder(source, source_mask)
    return Decoder(encoded_source, target, target_mask, memory_mask)

  # Finally, bind the generator transform to use later for inference.
  @stax.Lambda
  def Generator(encoded_target):
    return stax.serial(
        encoded_target,
        stax.Dense(target_vocab_size, W_init=stax.xavier_uniform()),
        stax.LogSoftmax
    )

  # Model-Building and Evaluation Functions
  # Get entire model's init and apply pair
  top_init, top_apply = Generator(transformer)

  # By default act as a normal Stax constructor and emit an (init, apply) pair.
  if not return_evals:
    return (top_init, top_apply)
  else:
    raise ValueError('inference in this model is still a work in progress')
Example #13
def lambda_fun2(x, y, z, w, v):
    input_tree = _build_combinator_tree(tree_spec, (x, y, z))
    return stax.serial(input_tree,
                       stax.multiplex(w, stax.Identity, v),
                       stax.FanInSum)