def TransformerLM(vocab_size,  # pylint: disable=invalid-name
                  mode='train',
                  num_layers=6,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_heads=8,
                  dropout=0.9,
                  max_len=256):
  """Builds a decoder-only Transformer language model.

  The model shifts its input right, embeds it, adds positional encodings,
  runs `num_layers` repetitions of one decoder layer (causally masked
  self-attention followed by a feed-forward block, each wrapped in a
  residual), and finishes with layer norm, a vocabulary projection and
  log-softmax.

  Args:
    vocab_size: int: vocab size
    mode: str: 'train' or 'eval'
    num_layers: int: number of times the decoder layer is repeated
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: value handed to every dropout/attention layer - Stax
      follows TF's KEEP probability convention
    max_len: int: maximum symbol length for positional encoding

  Returns:
    init and apply.
  """
  attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)

  feed_forward_layers = [
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(dropout, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform()),
  ]
  feed_forward = stax.serial(*feed_forward_layers)

  # Target attends to itself; the causal mask is supplied as the fourth
  # multiplexed input next to query, key and value.
  self_attention_block = stax.residual(
      stax.LayerNorm(feature_depth),
      stax.multiplex(stax.Identity,  # query
                     stax.Identity,  # key
                     stax.Identity,  # value
                     stax.CausalMask(axis=-2)),  # attention mask
      attention,
      stax.Dropout(dropout, mode=mode))

  feed_forward_block = stax.residual(
      stax.LayerNorm(feature_depth),
      feed_forward,
      stax.Dropout(dropout, mode=mode))

  decoder_layer = stax.serial(self_attention_block, feed_forward_block)

  model_layers = [
      stax.ShiftRight(),
      stax.Embedding(feature_depth, vocab_size),
      stax.PositionalEncoding(feature_depth, max_len=max_len),
      stax.Dropout(dropout, mode=mode),
      stax.repeat(decoder_layer, num_layers),
      stax.LayerNorm(feature_depth),
      stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
      stax.LogSoftmax,
  ]
  return stax.serial(*model_layers)
def MLP(num_hidden_layers=2,
        hidden_size=512,
        activation_fn=stax.Relu,
        num_output_classes=10):
  """Assembles a flatten -> hidden layers -> log-softmax classifier.

  Args:
    num_hidden_layers: how many times the Dense/activation pair is repeated.
    hidden_size: width passed to the hidden Dense layer.
    activation_fn: layer placed after each hidden Dense layer.
    num_output_classes: width of the final Dense layer.

  Returns:
    The result of stax.serial applied to the assembled layers.
  """
  flatten = stax.Flatten()
  # A single Dense/activation pair is created once and repeated, so every
  # hidden position refers to the same pair of objects.
  hidden_pair = [stax.Dense(hidden_size), activation_fn]
  output_head = [stax.Dense(num_output_classes), stax.LogSoftmax]
  all_layers = [flatten] + hidden_pair * num_hidden_layers + output_head
  return stax.serial(*all_layers)
def MLP(num_hidden_layers=2,
        hidden_size=512,
        activation_fn=stax.Relu,
        num_output_classes=10,
        mode="train"):
  """Multi-layer feed-forward neural network with non-linear activations.

  Args:
    num_hidden_layers: number of Dense + activation pairs to stack.
    hidden_size: width passed to each hidden Dense layer.
    activation_fn: callable invoked once per hidden layer to build the
      activation layer.
    num_output_classes: width of the final Dense layer.
    mode: accepted for interface compatibility; not used.

  Returns:
    A stax.Serial combinator over flatten, the hidden stack, the output
    Dense layer and log-softmax.
  """
  del mode  # Unused.
  flatten = stax.Flatten()
  # A fresh Dense layer and a fresh activation are built for every hidden
  # layer.
  hidden_stack = [
      layer
      for _ in range(num_hidden_layers)
      for layer in (stax.Dense(hidden_size), activation_fn())
  ]
  output_head = [stax.Dense(num_output_classes), stax.LogSoftmax()]
  return stax.Serial(*([flatten] + hidden_stack + output_head))
def ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode):
  """Feed-forward sub-block wrapped in a residual, normalized at its input.

  Args:
    feature_depth: width of the second (output) Dense layer.
    feedforward_depth: width of the first (inner) Dense layer.
    dropout: value passed to both Dropout layers.
    mode: mode string passed to both Dropout layers.

  Returns:
    The stax.residual combinator over the sub-layers.
  """
  sublayers = [
      stax.LayerNorm(),
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(dropout, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform()),
      stax.Dropout(dropout, mode=mode),
  ]
  return stax.residual(*sublayers)
def test_dense_param_sharing(self):
  """Reusing one Dense instance in a Serial yields only one kernel."""
  distinct_model = stax.Serial(stax.Dense(32), stax.Dense(32))
  shared_layer = stax.Dense(32)
  shared_model = stax.Serial(shared_layer, shared_layer)

  distinct_init, _ = distinct_model
  shared_init, _ = shared_model

  rng = random.get_prng(0)
  _, distinct_params = distinct_init(rng, [-1, 32])
  _, shared_params = shared_init(rng, [-1, 32])

  kernel_shape = (32, 32)
  # Two separate Dense layers: each position holds its own (32, 32) kernel.
  for position in (0, 1):
    self.assertEqual(kernel_shape, distinct_params[position][0].shape)

  # One Dense layer used twice: the first position holds the (32, 32)
  # kernel, the second position compares equal to an empty tuple.
  self.assertEqual(kernel_shape, shared_params[0][0].shape)
  self.assertEqual((), shared_params[1])
def Resnet50(hidden_size=64, num_output_classes=1001):
  """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
  # One row per stage: (bottleneck width, ConvBlock output width,
  # ConvBlock strides, number of IdentityBlocks following the ConvBlock).
  stage_specs = (
      (hidden_size, 4 * hidden_size, (1, 1), 2),
      (2 * hidden_size, 8 * hidden_size, (2, 2), 3),
      (4 * hidden_size, 16 * hidden_size, (2, 2), 5),
      (8 * hidden_size, 32 * hidden_size, (2, 2), 2),
  )

  layers = [
      stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
      stax.BatchNorm(),
      stax.Relu,
      stax.MaxPool((3, 3), strides=(2, 2)),
  ]
  for width, out_width, strides, num_identity in stage_specs:
    layers.append(ConvBlock(3, [width, width, out_width], strides))
    for _ in range(num_identity):
      layers.append(IdentityBlock(3, [width, width]))
  layers.extend([
      stax.AvgPool((7, 7)),
      stax.Flatten(),
      stax.Dense(num_output_classes),
      stax.LogSoftmax,
  ])
  return stax.serial(*layers)
def policy_and_value_net(rng_key,
                         batch_observations_shape,
                         num_actions,
                         bottom_layers=None):
  """Builds and initializes a network with a policy head and a value head.

  Args:
    rng_key: random key handed to the network's init function.
    batch_observations_shape: input shape handed to the init function.
    num_actions: width of the policy head's Dense layer.
    bottom_layers: optional iterable of layers placed below the two heads;
      it is copied, not modified.

  Returns:
    A (net_params, net_apply) pair.
  """
  # Start from a private copy of the caller's trunk (or an empty trunk).
  trunk = [] if bottom_layers is None else list(bottom_layers)

  # The trunk output is fanned out to two parallel heads: a softmax over
  # actions and a single-unit value estimate.
  fan_out = stax.FanOut()
  policy_head = stax.Serial(stax.Dense(num_actions), stax.Softmax())
  value_head = stax.Dense(1)
  heads = [fan_out, stax.Parallel(policy_head, value_head)]

  net_init, net_apply = stax.Serial(trunk + heads)
  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
def test_training_loop(self):
  """Runs a tiny PPO training loop and checks one result per epoch."""
  # Gym envs usually come wrapped in a TimeLimit; strip it and re-wrap with
  # a very small step limit so the test stays fast.
  unlimited_env = gym_utils.remove_time_limit_wrapper(gym.make("CartPole-v0"))
  env = gym.wrappers.TimeLimit(unlimited_env, max_episode_steps=2)

  num_epochs = 2
  batch_size = 2

  policy_net_fun = functools.partial(
      ppo.policy_net, bottom_layers=[stax.Dense(1)])
  value_net_fun = functools.partial(
      ppo.value_net, bottom_layers=[stax.Dense(1)])

  results = ppo.training_loop(
      env=env,
      epochs=num_epochs,
      policy_net_fun=policy_net_fun,
      value_net_fun=value_net_fun,
      batch_size=batch_size,
      num_optimizer_steps=1,
      random_seed=0)
  _, rewards, val_losses, ppo_objectives = results

  for per_epoch_values in (rewards, val_losses, ppo_objectives):
    self.assertLen(per_epoch_values, num_epochs)
def policy_net(rng_key,
               batch_observations_shape,
               num_actions,
               bottom_layers=None):
  """Builds and initializes a policy network.

  The given bottom layers form the lower part of the network; a Dense layer
  of width `num_actions` followed by a softmax is stacked on top.

  Args:
    rng_key: random key handed to the network's init function.
    batch_observations_shape: input shape handed to the init function.
    num_actions: width of the final Dense layer.
    bottom_layers: optional iterable of layers for the bottom of the
      network. It is copied and never modified, so the same list can safely
      be reused across calls (e.g. when bound via functools.partial).

  Returns:
    A (net_params, net_apply) pair.
  """
  # Work on a copy: extending the caller's list in place would append the
  # head layers again on every call that reuses the same list object.
  layers = [] if bottom_layers is None else list(bottom_layers)
  layers.extend([stax.Dense(num_actions), stax.Softmax()])

  net_init, net_apply = stax.Serial(layers)
  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
def value_net(rng_key,
              batch_observations_shape,
              num_actions,
              bottom_layers=None):
  """Builds and initializes a value network.

  The given bottom layers form the lower part of the network; a single-unit
  Dense layer is stacked on top.

  Args:
    rng_key: random key handed to the network's init function.
    batch_observations_shape: input shape handed to the init function.
    num_actions: unused; accepted so the signature matches the policy net.
    bottom_layers: optional iterable of layers for the bottom of the
      network. It is copied and never modified, so the same list can safely
      be reused across calls (e.g. when bound via functools.partial).

  Returns:
    A (net_params, net_apply) pair.
  """
  del num_actions  # Unused.

  # Work on a copy: extending the caller's list in place would append the
  # head layer again on every call that reuses the same list object.
  layers = [] if bottom_layers is None else list(bottom_layers)
  layers.append(stax.Dense(1))

  net_init, net_apply = stax.Serial(layers)
  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    num_blocks: int, number of blocks in a group.
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: int, number of classes to distinguish.

  Returns:
    The WideResnet model with given layer and output sizes.
  """
  layers = [
      stax.Conv(hidden_size, (3, 3), padding='SAME'),
      # The first group uses WideResnetGroup's default strides.
      WideResnetGroup(num_blocks, hidden_size),
  ]
  # The two later groups widen the layer and pass (2, 2) strides.
  for widening in (2, 4):
    layers.append(
        WideResnetGroup(num_blocks, hidden_size * widening, (2, 2)))
  layers.extend([
      stax.BatchNorm(),
      stax.Relu,
      stax.AvgPool((8, 8)),
      stax.Flatten(),
      stax.Dense(num_output_classes),
      stax.LogSoftmax,
  ])
  return stax.serial(*layers)
def Resnet50(hidden_size=64, num_output_classes=1001, mode='train'):
  """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.
    mode: whether we are training or evaluating or doing inference.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
  del mode  # Unused.

  # One row per stage: (bottleneck width, block output width,
  # ConvBlock strides, number of IdentityBlocks following the ConvBlock).
  stage_specs = (
      (hidden_size, 4 * hidden_size, (1, 1), 2),
      (2 * hidden_size, 8 * hidden_size, (2, 2), 3),
      (4 * hidden_size, 16 * hidden_size, (2, 2), 5),
      (8 * hidden_size, 32 * hidden_size, (2, 2), 2),
  )

  layers = [
      stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
      stax.BatchNorm(),
      stax.Relu(),
      stax.MaxPool(pool_size=(3, 3), strides=(2, 2)),
  ]
  for width, out_width, strides, num_identity in stage_specs:
    layers.append(ConvBlock(3, [width, width, out_width], strides))
    for _ in range(num_identity):
      layers.append(IdentityBlock(3, [width, width, out_width]))
  layers.extend([
      stax.AvgPool(pool_size=(7, 7)),
      stax.Flatten(),
      stax.Dense(num_output_classes),
      stax.LogSoftmax(),
  ])
  return stax.Serial(*layers)
def TransformerLM(vocab_size,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_layers=6,
                  num_heads=8,
                  dropout=0.1,
                  max_len=2048,
                  mode='train'):
  """Builds a decoder-only Transformer language model.

  Args:
    vocab_size: int: vocab size
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of times the decoder layer is repeated
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    init and apply.
  """
  # Input pipeline: shift right, embed, drop out, add positions.
  layers = [
      stax.ShiftRight(),
      stax.Embedding(feature_depth, vocab_size),
      stax.Dropout(dropout, mode=mode),
      stax.PositionalEncoding(feature_depth, max_len=max_len),
  ]

  # A single DecoderLayer is built and repeated num_layers times.
  decoder_layer = DecoderLayer(
      feature_depth, feedforward_depth, num_heads, dropout, mode)
  layers.append(stax.repeat(decoder_layer, num_layers))

  # Output pipeline: normalize, project to the vocabulary, log-softmax.
  layers.extend([
      stax.LayerNorm(),
      stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
      stax.LogSoftmax,
  ])
  return stax.serial(*layers)
def TransformerEncoder(mode='train',  # pylint: disable=invalid-name
                       num_layers=6,
                       feature_depth=512,
                       feedforward_depth=2048,
                       num_heads=8,
                       dropout=0.9):
  """Builds a bare Transformer encoder stack.

  Args:
    mode: str: 'train' or 'eval'
    num_layers: int: number of times the encoder layer is repeated
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: value handed to every dropout/attention layer - Stax
      follows TF's KEEP probability convention

  Returns:
    A staxlayer for implementing a raw Transformer encoder stack. No
    embedding or positional signals are added by this layer.
  """
  attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)

  feed_forward_layers = [
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(dropout, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform()),
  ]
  position_wise_ff = stax.serial(*feed_forward_layers)

  @stax.Lambda
  def encoder(embedded_source, source_mask):
    """Transformer encoder stack.

    Args:
      embedded_source: staxlayer variable: embedded source sequences
      source_mask: staxlayer variable: self-attention mask

    Returns:
      Staxlayer variable that outputs encoded source.
    """
    # Input attends to itself; the mask is the fourth multiplexed input
    # next to query, key and value.
    self_attention_block = stax.residual(
        stax.LayerNorm(feature_depth),
        stax.multiplex(stax.Identity,  # query
                       stax.Identity,  # key
                       stax.Identity,  # value
                       source_mask),  # attention mask
        attention,
        stax.Dropout(dropout, mode=mode))

    feed_forward_block = stax.residual(
        stax.LayerNorm(feature_depth),
        position_wise_ff,
        stax.Dropout(dropout, mode=mode))

    encoder_layer = stax.serial(self_attention_block, feed_forward_block)

    return stax.serial(
        embedded_source,
        stax.repeat(encoder_layer, num_layers),
        stax.LayerNorm(feature_depth))

  return encoder
def common_stax_layers():
  """Returns the layer list shared by the networks: Dense(16), Dense(4), each followed by Relu."""
  layers = []
  for width in (16, 4):
    layers.append(stax.Dense(width))
    layers.append(stax.Relu)
  return layers
def common_stax_layers():
  """Returns the shared bottom layers, with extra preprocessing for Pong-v0.

  When FLAGS.env_name is "Pong-v0" the list starts with a division by 255.0
  and a Flatten(num_axis_to_keep=2); in every case it ends with Dense(16),
  Relu, Dense(4), Relu.
  """
  layers = []
  if FLAGS.env_name == "Pong-v0":
    layers.append(stax.Div(divisor=255.0))
    layers.append(stax.Flatten(num_axis_to_keep=2))
  for width in (16, 4):
    layers.append(stax.Dense(width))
    layers.append(stax.Relu())
  return layers
def generator(encoded_target):
  """Chains the encoded target with a vocabulary projection and log-softmax.

  Args:
    encoded_target: first element of the serial chain.

  Returns:
    The stax.serial combination of `encoded_target`, a Dense layer of width
    `target_vocab_size` (taken from the enclosing scope) and LogSoftmax.
  """
  vocab_projection = stax.Dense(
      target_vocab_size, W_init=stax.xavier_uniform())
  return stax.serial(encoded_target, vocab_projection, stax.LogSoftmax)
def Transformer(source_vocab_size,  # pylint: disable=invalid-name
                target_vocab_size,
                mode='train',
                num_layers=6,
                feature_depth=512,
                feedforward_depth=2048,
                num_heads=8,
                dropout=0.9,
                shared_embedding=True,
                max_len=200,
                return_evals=False):
  """Transformer model.

  Builds an encoder stack, a decoder stack that attends to the encoder
  output, and a generator (vocabulary projection + log-softmax), then
  composes them as generator(transformer).

  Args:
    source_vocab_size: int: source vocab size
    target_vocab_size: int: target vocab size
    mode: str: 'train' or 'eval'
    num_layers: int: number of encoder/decoder layers
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate - Stax follows TF's KEEP probability convention
    shared_embedding: bool: specify whether source/target embeddings are tied.
      When True, source_vocab_size must equal target_vocab_size (asserted).
    max_len: int: maximum symbol length for positional encoding
    return_evals: bool: whether to generate decode-time evaluation functions

  Returns:
    When return_evals is False (the default): the model's (init, apply) pair.
    When return_evals is True: a namedtuple ('Model') with fields 'init' and
    'apply' for training, and 'evals', a function that takes trained params
    and returns a namedtuple of evaluation functions for the trained encoder,
    generator and decoder.
  """
  # Input embedding and positional encoding
  inject_position = stax.serial(
      stax.PositionalEncoding(feature_depth, max_len=max_len),
      stax.Dropout(dropout, mode=mode)
  )
  if shared_embedding:
    assert source_vocab_size == target_vocab_size
    # Weight-shared Embedding
    embedding = stax.Share(stax.Embedding(feature_depth, source_vocab_size))
    source_embedding_layer = stax.serial(embedding, inject_position)
    # Source and target use the very same embedding layer object.
    target_embedding_layer = source_embedding_layer
  else:
    source_embedding = stax.Embedding(feature_depth, source_vocab_size)
    target_embedding = stax.Embedding(feature_depth, target_vocab_size)
    source_embedding_layer = stax.serial(source_embedding, inject_position)
    target_embedding_layer = stax.serial(target_embedding, inject_position)

  # Multi-headed Attention and Feed-forward layers
  # NOTE(review): these two objects are reused by every encoder layer and
  # every decoder layer below (including both decoder attention blocks);
  # whether that ties their weights depends on stax's sharing semantics -
  # confirm.
  multi_attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)

  feed_forward = stax.serial(
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(dropout, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform())
  )

  # Encoder
  @stax.Lambda
  def encoder(source, source_mask):
    """Transformer encoder stack.

    Args:
      source: staxlayer variable: raw source sequences
      source_mask: staxlayer variable: self-attention mask

    Returns:
      Staxlayer variable that outputs encoded source.
    """
    encoder_layer = stax.serial(
        # input attends to self
        stax.residual(stax.LayerNorm(feature_depth),
                      stax.multiplex(stax.Identity,  # query
                                     stax.Identity,  # key
                                     stax.Identity,  # value
                                     source_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # feed-forward
        stax.residual(stax.LayerNorm(feature_depth),
                      feed_forward,
                      stax.Dropout(dropout, mode=mode))
    )
    # Embed the raw source, run the repeated encoder layer, then normalize.
    return stax.serial(
        source,
        source_embedding_layer,
        stax.repeat(encoder_layer, num_layers),
        stax.LayerNorm(feature_depth),
    )

  # Decoder
  @stax.Lambda
  def decoder(memory, target, target_mask, memory_mask):
    """Transformer decoder stack.

    Args:
      memory: staxlayer variable: encoded source sequences
      target: staxlayer variable: raw target sequences
      target_mask: staxlayer variable: self-attention mask
      memory_mask: staxlayer variable: memory attention mask

    Returns:
      Staxlayer variable that outputs the decoded target (the embedded
      target after the repeated decoder layers and a final layer norm).
    """
    decoder_layer = stax.serial(
        # target attends to self
        stax.residual(stax.LayerNorm(feature_depth),
                      stax.multiplex(stax.Identity,  # query
                                     stax.Identity,  # key
                                     stax.Identity,  # value
                                     target_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # target attends to encoded source
        stax.residual(stax.LayerNorm(feature_depth),
                      stax.multiplex(stax.Identity,  # query
                                     memory,  # key
                                     memory,  # value
                                     memory_mask),  # attention mask
                      multi_attention,
                      stax.Dropout(dropout, mode=mode)),
        # feed-forward
        stax.residual(stax.LayerNorm(feature_depth),
                      feed_forward,
                      stax.Dropout(dropout, mode=mode))
    )
    # Embed the raw target, run the repeated decoder layer, then normalize.
    return stax.serial(
        target,
        target_embedding_layer,
        stax.repeat(decoder_layer, num_layers),
        stax.LayerNorm(feature_depth),
    )

  # The Transformer
  # Encodes the source, then decodes the target against the encoded source.
  @stax.Lambda
  def transformer(source, target, source_mask, target_mask, memory_mask):
    encoded_source = encoder(source, source_mask)
    return decoder(encoded_source, target, target_mask, memory_mask)

  # Finally, bind the generator transform to use later for inference.
  # It projects its input to target_vocab_size and applies log-softmax.
  @stax.Lambda
  def generator(encoded_target):
    return stax.serial(
        encoded_target,
        stax.Dense(target_vocab_size, W_init=stax.xavier_uniform()),
        stax.LogSoftmax
    )

  # Model-Building and Evaluation Functions

  # Get entire model's init and apply pair
  top_init, top_apply = generator(transformer)

  # By default act as a normal Stax constructor and emit an (init, apply) pair.
  if not return_evals:
    return (top_init, top_apply)
  else:
    # Inference-time function for binding trained params to model and returning
    # the python-bound sub-expressions for evaluation and sequence generation.

    # Helper: builds a one-off 'Model' namedtuple whose fields are the
    # keyword names it is called with.
    def make_namedtuple(**kwargs):
      return collections.namedtuple('Model', kwargs.keys())(**kwargs)

    def get_evals(params):
      # We need to feed _concrete_ trained parameters through the network once.
      # Otherwise the bound parameters point to abstract tracer values.
      # The inputs don't matter.
      # Five dummy inputs, one per argument of `transformer`.
      fake_inputs = 5 * (np.ones((1), dtype=np.int32),)
      fake_key = random.PRNGKey(1)
      # The result is discarded; the call is made only for its binding effect.
      top_apply(params, fake_inputs, rng=fake_key)
      # We can now return eval functions from the bound pieces of the model.
      return make_namedtuple(
          encoder=stax.make_apply_fun(encoder),
          generator=stax.make_apply_fun(generator),
          decoder=stax.make_apply_fun(decoder),
      )

    # We return the functions needed to train and evaluate the Transformer.
    return make_namedtuple(
        init=top_init,
        apply=top_apply,
        evals=get_evals,
    )