def TransformerLM(vocab_size,  # pylint: disable=invalid-name
                  mode='train',
                  num_layers=6,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_heads=8,
                  dropout=0.9,
                  max_len=256):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    mode: str: 'train' or 'eval'
    num_layers: int: number of encoder/decoder layers
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate - Stax follows TF's KEEP probability convention
    max_len: int: maximum symbol length for positional encoding

  Returns:
    init and apply.
  """
  # Multi-headed Attention and Feed-forward layers
  multi_attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)

  feed_forward = stax.serial(
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(dropout, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform())
  )

  # Single decoder layer
  decoder_layer = stax.serial(
      # target attends to self
      stax.residual(stax.LayerNorm(feature_depth),
                    stax.multiplex(stax.Identity,  # query
                                   stax.Identity,  # key
                                   stax.Identity,  # value
                                   stax.CausalMask(axis=-2)),  # attention mask
                    multi_attention,
                    stax.Dropout(dropout, mode=mode)),
      # feed-forward
      stax.residual(stax.LayerNorm(feature_depth),
                    feed_forward,
                    stax.Dropout(dropout, mode=mode))
  )

  return stax.serial(
      stax.ShiftRight(),
      stax.Embedding(feature_depth, vocab_size),
      stax.PositionalEncoding(feature_depth, max_len=max_len),
      stax.Dropout(dropout, mode=mode),
      stax.repeat(decoder_layer, num_layers),
      stax.LayerNorm(feature_depth),
      stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
      stax.LogSoftmax
  )
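A minimal usage sketch for the (init, apply) pair returned above, following the init_fun(rng, input_shape) convention used by the parameter-sharing test further down this page; the vocab size, sequence shape, and rng handling here are illustrative assumptions rather than part of the original snippet.

from jax import random
import jax.numpy as jnp

init_fun, apply_fun = TransformerLM(vocab_size=1000, mode='train')
rng = random.PRNGKey(0)
# -1 stands in for the batch dimension; 256 matches the default max_len.
_, params = init_fun(rng, (-1, 256))
tokens = jnp.zeros((2, 256), dtype=jnp.int32)   # dummy batch of token ids
log_probs = apply_fun(params, tokens, rng=rng)  # expected shape: (2, 256, vocab_size)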
Example #2
def MLP(num_hidden_layers=2,
        hidden_size=512,
        activation_fn=stax.Relu,
        num_output_classes=10):
    layers = [stax.Flatten()]
    layers += [stax.Dense(hidden_size), activation_fn] * num_hidden_layers
    layers += [stax.Dense(num_output_classes), stax.LogSoftmax]
    return stax.serial(*layers)
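The MLP above also returns a plain stax (init, apply) pair; a short, hedged usage sketch follows (input shape and batch size are illustrative assumptions).

from jax import random
import jax.numpy as jnp

init_fun, apply_fun = MLP(num_hidden_layers=2, hidden_size=512, num_output_classes=10)
rng = random.PRNGKey(0)
_, params = init_fun(rng, (-1, 28, 28))               # Flatten handles the 28x28 inputs
log_probs = apply_fun(params, jnp.ones((8, 28, 28)))  # expected shape: (8, 10)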
Example #3
def MLP(num_hidden_layers=2,
        hidden_size=512,
        activation_fn=stax.Relu,
        num_output_classes=10,
        mode="train"):
    """Multi-layer feed-forward neural network with non-linear activations."""
    del mode
    layers = [stax.Flatten()]
    for _ in range(num_hidden_layers):
        layers += [stax.Dense(hidden_size), activation_fn()]
    layers += [stax.Dense(num_output_classes), stax.LogSoftmax()]
    return stax.Serial(*layers)
Example #4
def ResidualFeedForward(feature_depth,
                        feedforward_depth,
                        dropout,
                        mode):
  """Residual feed-forward layer with normalization at start."""
  return stax.residual(
      stax.LayerNorm(),
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(dropout, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform()),
      stax.Dropout(dropout, mode=mode)
  )
Example #5
def test_dense_param_sharing(self):
  model1 = stax.Serial(stax.Dense(32), stax.Dense(32))
  layer = stax.Dense(32)
  model2 = stax.Serial(layer, layer)
  init_fun1, _ = model1
  init_fun2, _ = model2
  rng = random.get_prng(0)
  _, params1 = init_fun1(rng, [-1, 32])
  _, params2 = init_fun2(rng, [-1, 32])
  # The first model's parameters have 2 kernels of size (32, 32).
  self.assertEqual((32, 32), params1[0][0].shape)
  self.assertEqual((32, 32), params1[1][0].shape)
  # The second model's parameters have 1 kernel of size (32, 32); the shared
  # layer's second occurrence contributes an empty tuple.
  self.assertEqual((32, 32), params2[0][0].shape)
  self.assertEqual((), params2[1])
Example #6
def Resnet50(hidden_size=64, num_output_classes=1001):
  """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
  return stax.serial(
      stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
      stax.BatchNorm(),
      stax.Relu,
      stax.MaxPool((3, 3), strides=(2, 2)),
      ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)),
      IdentityBlock(3, [hidden_size, hidden_size]),
      IdentityBlock(3, [hidden_size, hidden_size]),
      ConvBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size], (2, 2)),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
      ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size], (2, 2)),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size], (2, 2)),
      IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]),
      IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]),
      stax.AvgPool((7, 7)),
      stax.Flatten(),
      stax.Dense(num_output_classes),
      stax.LogSoftmax)
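A hedged usage sketch for the ResNet above; the NHWC layout and 224x224 resolution are assumptions based on typical ResNet-50 inputs, not something the snippet itself specifies.

from jax import random
import jax.numpy as jnp

init_fun, apply_fun = Resnet50(hidden_size=64, num_output_classes=1001)
rng = random.PRNGKey(0)
_, params = init_fun(rng, (-1, 224, 224, 3))  # assumed NHWC image batches
images = jnp.zeros((4, 224, 224, 3))
log_probs = apply_fun(params, images)         # expected shape: (4, 1001)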
Example #7
def policy_and_value_net(rng_key,
                         batch_observations_shape,
                         num_actions,
                         bottom_layers=None):
  """A policy and value net function."""

  # Layers.
  layers = []
  if bottom_layers is not None:
    layers.extend(bottom_layers)

  # Now, with the current logits, one head computes action probabilities and the
  # other computes the value function.
  layers.extend([stax.FanOut(), stax.Parallel(
      stax.Serial(stax.Dense(num_actions), stax.Softmax()),
      stax.Dense(1)
  )])

  net_init, net_apply = stax.Serial(layers)

  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
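A sketch of how the returned (params, apply) pair might be used, assuming the FanOut/Parallel head emits a pair of outputs; the observation shape, action count, and bottom layers below are purely illustrative.

from jax import random
import jax.numpy as jnp

rng = random.PRNGKey(0)
obs_shape = (-1, 4)  # e.g. CartPole-style observations, batch dimension first
params, net_apply = policy_and_value_net(
    rng, obs_shape, num_actions=2, bottom_layers=[stax.Dense(16)])
observations = jnp.zeros((8, 4))
# One head emits action probabilities, the other a scalar value estimate.
action_probs, values = net_apply(params, observations)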
def test_training_loop(self):
  env = gym.make("CartPole-v0")
  # Usually gym envs are wrapped in TimeLimit wrapper.
  env = gym_utils.remove_time_limit_wrapper(env)
  # Limit this to a small number for tests.
  env = gym.wrappers.TimeLimit(env, max_episode_steps=2)
  num_epochs = 2
  batch_size = 2
  # Run the training loop.
  _, rewards, val_losses, ppo_objectives = ppo.training_loop(
      env=env,
      epochs=num_epochs,
      policy_net_fun=functools.partial(ppo.policy_net,
                                       bottom_layers=[stax.Dense(1)]),
      value_net_fun=functools.partial(ppo.value_net,
                                      bottom_layers=[stax.Dense(1)]),
      batch_size=batch_size,
      num_optimizer_steps=1,
      random_seed=0)
  self.assertLen(rewards, num_epochs)
  self.assertLen(val_losses, num_epochs)
  self.assertLen(ppo_objectives, num_epochs)
Example #9
def policy_net(rng_key,
               batch_observations_shape,
               num_actions,
               bottom_layers=None):
  """A policy net function."""
  # Use the bottom_layers as the bottom part of the network and just add the
  # required layers on top of it.
  if bottom_layers is None:
    bottom_layers = []
  bottom_layers.extend([stax.Dense(num_actions), stax.Softmax()])

  net_init, net_apply = stax.Serial(bottom_layers)

  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
Example #10
def value_net(rng_key,
              batch_observations_shape,
              num_actions,
              bottom_layers=None):
  """A value net function."""
  del num_actions

  if bottom_layers is None:
    bottom_layers = []
  bottom_layers.extend([
      stax.Dense(1),
  ])

  net_init, net_apply = stax.Serial(bottom_layers)

  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    num_blocks: int, number of blocks in a group.
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: int, number of classes to distinguish.

  Returns:
    The WideResnet model with given layer and output sizes.
  """
  return stax.serial(
      stax.Conv(hidden_size, (3, 3), padding='SAME'),
      WideResnetGroup(num_blocks, hidden_size),
      WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
      WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
      stax.BatchNorm(), stax.Relu, stax.AvgPool((8, 8)),
      stax.Flatten(), stax.Dense(num_output_classes),
      stax.LogSoftmax)
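A hedged usage sketch; the 32x32x3 NHWC input (CIFAR-10-sized images) is an assumption suggested by the 8x8 average pool and the 10-class default, not stated in the snippet.

from jax import random
import jax.numpy as jnp

init_fun, apply_fun = WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10)
rng = random.PRNGKey(0)
_, params = init_fun(rng, (-1, 32, 32, 3))
log_probs = apply_fun(params, jnp.zeros((4, 32, 32, 3)))  # expected shape: (4, 10)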
Example #12
def Resnet50(hidden_size=64, num_output_classes=1001, mode='train'):
  """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.
    mode: whether we are training or evaluating or doing inference.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
  del mode
  return stax.Serial(
      stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
      stax.BatchNorm(),
      stax.Relu(),
      stax.MaxPool(pool_size=(3, 3), strides=(2, 2)),
      ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)),
      IdentityBlock(3, [hidden_size, hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [hidden_size, hidden_size, 4 * hidden_size]),
      ConvBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size], (2, 2)),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
      ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size], (2, 2)),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
      ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size], (2, 2)),
      IdentityBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size]),
      IdentityBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size]),
      stax.AvgPool(pool_size=(7, 7)),
      stax.Flatten(),
      stax.Dense(num_output_classes),
      stax.LogSoftmax())
Example #13
def TransformerLM(vocab_size,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_layers=6,
                  num_heads=8,
                  dropout=0.1,
                  max_len=2048,
                  mode='train'):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    init and apply.
  """
  return stax.serial(
      stax.ShiftRight(),
      stax.Embedding(feature_depth, vocab_size),
      stax.Dropout(dropout, mode=mode),
      stax.PositionalEncoding(feature_depth, max_len=max_len),
      stax.repeat(
          DecoderLayer(
              feature_depth, feedforward_depth, num_heads, dropout, mode),
          num_layers),
      stax.LayerNorm(),
      stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
      stax.LogSoftmax
  )
def TransformerEncoder(mode='train',  # pylint: disable=invalid-name
                       num_layers=6,
                       feature_depth=512,
                       feedforward_depth=2048,
                       num_heads=8,
                       dropout=0.9):
  """Transformer Encoder Stack.

  Args:
    mode: str: 'train' or 'eval'
    num_layers: int: number of encoder/decoder layers
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate - Stax follows TF's KEEP probability convention

  Returns:
    A staxlayer for implementing a raw Transformer encoder stack.  No embedding
    or positional signals are added by this layer.
  """
  # Multi-headed Attention and Feed-forward layers
  multi_attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)

  feed_forward = stax.serial(
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(dropout, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform())
  )

  @stax.Lambda
  def encoder(embedded_source, source_mask):
    """Transformer encoder stack.

    Args:
      embedded_source: staxlayer variable: embedded source sequences
      source_mask: staxlayer variable: self-attention mask

    Returns:
      Staxlayer variable that outputs encoded source.
    """
    encoder_layer = stax.serial(
        # input attends to self
        stax.residual(stax.LayerNorm(feature_depth),
                      stax.multiplex(stax.Identity,  # query
                                     stax.Identity,  # key
                                     stax.Identity,  # value
                                     source_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # feed-forward
        stax.residual(stax.LayerNorm(feature_depth),
                      feed_forward,
                      stax.Dropout(dropout, mode=mode))
    )
    return stax.serial(
        embedded_source,
        stax.repeat(encoder_layer, num_layers),
        stax.LayerNorm(feature_depth),
    )

  return encoder
Example #15
def common_stax_layers():
  return [stax.Dense(16), stax.Relu, stax.Dense(4), stax.Relu]
Example #16
def common_stax_layers():
    layers = []
    if FLAGS.env_name == "Pong-v0":
        layers = [stax.Div(divisor=255.0), stax.Flatten(num_axis_to_keep=2)]
    return layers + [stax.Dense(16), stax.Relu(), stax.Dense(4), stax.Relu()]
def generator(encoded_target):
  return stax.serial(
      encoded_target,
      stax.Dense(target_vocab_size, W_init=stax.xavier_uniform()),
      stax.LogSoftmax
  )
def Transformer(source_vocab_size,  # pylint: disable=invalid-name
                target_vocab_size,
                mode='train',
                num_layers=6,
                feature_depth=512,
                feedforward_depth=2048,
                num_heads=8,
                dropout=0.9,
                shared_embedding=True,
                max_len=200,
                return_evals=False):
  """Transformer model.

  Args:
    source_vocab_size: int: source vocab size
    target_vocab_size: int: target vocab size
    mode: str: 'train' or 'eval'
    num_layers: int: number of encoder/decoder layers
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate - Stax follows TF's KEEP probability convention
    shared_embedding: bool: specify whether source/target embeddings are tied.
    max_len: int: maximum symbol length for positional encoding
    return_evals: bool: whether to generate decode-time evaluation functions

  Returns:
    A namedtuple containing the model's 'init' and 'apply' functions for
    training, and an 'evals' function that itself returns a namedtuple of
    evaluation functions for the trained encoder, decoder, and generator
    substax.
  """

  # Input embedding and positional encoding
  inject_position = stax.serial(
      stax.PositionalEncoding(feature_depth, max_len=max_len),
      stax.Dropout(dropout, mode=mode)
  )
  if shared_embedding:
    assert source_vocab_size == target_vocab_size
    # Weight-shared Embedding
    embedding = stax.Share(stax.Embedding(feature_depth, source_vocab_size))
    source_embedding_layer = stax.serial(embedding, inject_position)
    target_embedding_layer = source_embedding_layer
  else:
    source_embedding = stax.Embedding(feature_depth, source_vocab_size)
    target_embedding = stax.Embedding(feature_depth, target_vocab_size)
    source_embedding_layer = stax.serial(source_embedding, inject_position)
    target_embedding_layer = stax.serial(target_embedding, inject_position)

  # Multi-headed Attention and Feed-forward layers
  multi_attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)

  feed_forward = stax.serial(
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(dropout, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform())
  )

  # Encoder
  @stax.Lambda
  def encoder(source, source_mask):
    """Transformer encoder stack.

    Args:
      source: staxlayer variable: raw source sequences
      source_mask: staxlayer variable: self-attention mask

    Returns:
      Staxlayer variable that outputs encoded source.
    """
    encoder_layer = stax.serial(
        # input attends to self
        stax.residual(stax.LayerNorm(feature_depth),
                      stax.multiplex(stax.Identity,  # query
                                     stax.Identity,  # key
                                     stax.Identity,  # value
                                     source_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # feed-forward
        stax.residual(stax.LayerNorm(feature_depth),
                      feed_forward,
                      stax.Dropout(dropout, mode=mode))
    )
    return stax.serial(
        source,
        source_embedding_layer,
        stax.repeat(encoder_layer, num_layers),
        stax.LayerNorm(feature_depth),
    )

  # Decoder
  @stax.Lambda
  def decoder(memory, target, target_mask, memory_mask):
    """Transformer decoder stack.

    Args:
      memory: staxlayer variable: encoded source sequences
      target: staxlayer variable: raw target sequences
      target_mask: staxlayer variable: self-attention mask
      memory_mask: staxlayer variable: memory attention mask

    Returns:
      Staxlayer variable that outputs the decoded target.
    """
    decoder_layer = stax.serial(
        # target attends to self
        stax.residual(stax.LayerNorm(feature_depth),
                      stax.multiplex(stax.Identity,  # query
                                     stax.Identity,  # key
                                     stax.Identity,  # value
                                     target_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # target attends to encoded source
        stax.residual(stax.LayerNorm(feature_depth),
                      stax.multiplex(stax.Identity,  # query
                                     memory,  # key
                                     memory,  # value
                                     memory_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # feed-forward
        stax.residual(stax.LayerNorm(feature_depth),
                      feed_forward,
                      stax.Dropout(dropout, mode=mode))
    )
    return stax.serial(
        target,
        target_embedding_layer,
        stax.repeat(decoder_layer, num_layers),
        stax.LayerNorm(feature_depth),
    )

  # The Transformer
  @stax.Lambda
  def transformer(source, target, source_mask, target_mask, memory_mask):
    encoded_source = encoder(source, source_mask)
    return decoder(encoded_source, target, target_mask, memory_mask)

  # Finally, bind the generator transform to use later for inference.
  @stax.Lambda
  def generator(encoded_target):
    return stax.serial(
        encoded_target,
        stax.Dense(target_vocab_size, W_init=stax.xavier_uniform()),
        stax.LogSoftmax
    )

  # Model-Building and Evaluation Functions
  # Get entire model's init and apply pair
  top_init, top_apply = generator(transformer)

  # By default act as a normal Stax constructor and emit an (init, apply) pair.
  if not return_evals:
    return (top_init, top_apply)
  else:
    # Inference-time function for binding trained params to model and returning
    # the python-bound sub-expressions for evaluation and sequence generation.
    def make_namedtuple(**kwargs):
      return collections.namedtuple('Model', kwargs.keys())(**kwargs)

    def get_evals(params):
      # We need to feed _concrete_ trained parameters through the network once.
      # Otherwise the bound parameters point to abstract tracer values.
      # The inputs don't matter.
      fake_inputs = 5 * (np.ones((1), dtype=np.int32),)
      fake_key = random.PRNGKey(1)
      top_apply(params, fake_inputs, rng=fake_key)
      # We can now return eval functions from the bound pieces of the model.
      return make_namedtuple(
          encoder=stax.make_apply_fun(encoder),
          generator=stax.make_apply_fun(generator),
          decoder=stax.make_apply_fun(decoder),
      )

    # We return the functions needed to train and evaluate the Transformer.
    return make_namedtuple(
        init=top_init,
        apply=top_apply,
        evals=get_evals,
    )
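A sketch of the two return modes described in the docstring above; the vocab sizes are placeholders and the training loop that would produce trained_params is out of scope here.

# Default: a plain (init, apply) pair for training.
init_fun, apply_fun = Transformer(source_vocab_size=1000, target_vocab_size=1000)

# With return_evals=True: a namedtuple with init, apply and an evals function
# that, given trained params, returns encoder/decoder/generator apply functions.
model = Transformer(source_vocab_size=1000, target_vocab_size=1000,
                    return_evals=True)
# evals = model.evals(trained_params)  # trained_params would come from a training loop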