def policy_and_value_net(n_actions, n_controls, vocab_size, bottom_layers_fn, two_towers): """A policy and value net function.""" # Layers. # On top of the logits produced by the bottom layers, one head computes action probabilities and the # other computes the value function. # NOTE: We use LogSoftmax instead of Softmax for numerical stability. @tl.layer() def FlattenControlsIntoTime(x, **unused_kwargs): # pylint: disable=invalid-name """Splits logits for actions in different controls and flattens controls.""" return np.reshape(x, (x.shape[0], -1, n_actions)) if vocab_size is None: # In continuous policies every element of the output sequence corresponds to # an observation. n_preds_per_input = n_controls kwargs = {} else: # In discrete policies every element of the output sequence corresponds to # a symbol in the discrete representation, and each control takes 1 symbol. n_preds_per_input = 1 kwargs = {"vocab_size": vocab_size} if two_towers: layers = [ tl.Dup(), tl.Parallel( [ bottom_layers_fn(**kwargs), tl.Dense(n_preds_per_input * n_actions), FlattenControlsIntoTime(), # pylint: disable=no-value-for-parameter tl.LogSoftmax() ], [ bottom_layers_fn(**kwargs), tl.Dense(n_preds_per_input), tl.Flatten() ], ) ] else: layers = [ bottom_layers_fn(**kwargs), tl.Dup(), tl.Parallel( [ tl.Dense(n_preds_per_input * n_actions), FlattenControlsIntoTime(), # pylint: disable=no-value-for-parameter tl.LogSoftmax() ], [tl.Dense(n_preds_per_input), tl.Flatten()], ) ] return tl.Model(layers)
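# A minimal numpy sketch (not part of the original module) of what FlattenControlsIntoTime does:
# per-control action logits laid out along the feature axis are folded into the time axis, so each
# control becomes its own "timestep" carrying n_actions logits. All shapes below are hypothetical.
import numpy as np

batch, time, n_controls, n_actions = 2, 3, 4, 5
logits = np.zeros((batch, time, n_controls * n_actions))
flat = np.reshape(logits, (logits.shape[0], -1, n_actions))
assert flat.shape == (batch, time * n_controls, n_actions)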
def test_awrtrainer_cartpole(self): """Test-runs AWR on cartpole.""" task = rl_task.RLTask('CartPole-v0', initial_trajectories=1000, max_steps=200) policy_model = lambda mode: tl.Serial( # pylint: disable=g-long-lambda tl.Dense(64), tl.Relu(), tl.Dense(2), tl.LogSoftmax()) value_model = lambda mode: tl.Serial( # pylint: disable=g-long-lambda tl.Dense(64), tl.Relu(), tl.Dense(1)) lr = lambda h: lr_schedules.MultifactorSchedule( # pylint: disable=g-long-lambda h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup') trainer = actor_critic.AWRTrainer(task, n_shared_layers=0, value_model=value_model, value_optimizer=opt.Adam, value_lr_schedule=lr, value_batch_size=32, value_train_steps_per_epoch=1000, policy_model=policy_model, policy_optimizer=opt.Adam, policy_lr_schedule=lr, policy_batch_size=32, policy_train_steps_per_epoch=1000, collect_per_epoch=10) trainer.run(1) self.assertEqual(1, trainer.current_epoch) self.assertGreater(trainer.avg_returns[-1], 180.0)
def classifier(vocab_size=len(Vocab), embedding_dim=256, output_dim=2, mode='predict'): # Create the embedding layer embed_layer = tl.Embedding( vocab_size=vocab_size, # Size of the vocabulary d_feature=embedding_dim) # Embedding dimension # Create a mean layer to compute an "average" word embedding mean_layer = tl.Mean(axis=1) # Create a dense layer, one unit for each output class dense_output_layer = tl.Dense(n_units=output_dim) # Create the log softmax layer (no parameters needed) log_softmax_layer = tl.LogSoftmax() # Use tl.Serial to combine all layers into the classifier # (of type trax.layers.combinators.Serial) model = tl.Serial( embed_layer, # embedding layer mean_layer, # mean layer dense_output_layer, # dense output layer log_softmax_layer # log softmax layer ) # Return the model return model
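# A minimal usage sketch, assuming the `classifier` above is importable and using a small
# made-up vocabulary; the token IDs below are illustrative only.
import numpy as np
from trax import shapes

model = classifier(vocab_size=100, embedding_dim=16, output_dim=2, mode='train')
tokens = np.array([[3, 7, 12, 0, 0]])  # one padded example of token IDs
model.init(shapes.signature(tokens))   # initialize weights from the input signature
log_probs = model(tokens)              # shape (1, 2): log-probabilities per class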
def TransformerEncoder(vocab_size=vocab_size, n_classes=10, d_model=512, d_ff=2048, n_layers=6, n_heads=8, dropout=0.1, dropout_shared_axes=None, max_len=2048, mode='train', ff_activation=tl.Relu, EncoderBlock=EncoderBlock): """ Returns a Transformer encoder model. The input to the model is a tensor of tokens. Args: vocab_size (int): vocab size. Defaults to vocab_size. n_classes (int): how many classes on output. Defaults to 10. d_model (int): depth of embedding. Defaults to 512. d_ff (int): depth of feed-forward layer. Defaults to 2048. n_layers (int): number of encoder layers. Defaults to 6. n_heads (int): number of attention heads. Defaults to 8. dropout (float): dropout rate (how much to drop out). Defaults to 0.1. dropout_shared_axes (int): axes on which to share dropout mask. Defaults to None. max_len (int): maximum symbol length for positional encoding. Defaults to 2048. mode (str): 'train' or 'eval'. Defaults to 'train'. ff_activation (function): the non-linearity in feed-forward layer. Defaults to tl.Relu. EncoderBlock (function): Returns the encoder block. Defaults to EncoderBlock. Returns: trax.layers.combinators.Serial: A Transformer model as a layer that maps from a tensor of tokens to activations over a set of output classes. """ positional_encoder = [ tl.Embedding(vocab_size, d_model), tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), tl.PositionalEncoding(max_len=max_len) ] # Repeat the encoder block n_layers times encoder_blocks = [ EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation) for _ in range(n_layers) ] # Encoder model return tl.Serial( tl.Branch( positional_encoder, tl.PaddingMask(), ), encoder_blocks, tl.Select([0], n_in=2), tl.LayerNorm(), tl.Mean(axis=1), tl.Dense(n_classes), tl.LogSoftmax(), )
def NMTAttn(input_vocab_size=33300, target_vocab_size=33300, d_model=1024, n_encoder_layers=2, n_decoder_layers=2, n_attention_heads=4, attention_dropout=0.0, mode='train'): """Returns an LSTM sequence-to-sequence model with attention. The input to the model is a pair (input tokens, target tokens), e.g., an English sentence (tokenized) and its translation into German (tokenized). Args: input_vocab_size: int: vocab size of the input target_vocab_size: int: vocab size of the target d_model: int: depth of embedding (n_units in the LSTM cell) n_encoder_layers: int: number of LSTM layers in the encoder n_decoder_layers: int: number of LSTM layers in the decoder after attention n_attention_heads: int: number of attention heads attention_dropout: float, dropout for the attention layer mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference Returns: An LSTM sequence-to-sequence model with attention. """ # Create the input encoder, which produces the encoder activations. input_encoder = input_encoder_fn(input_vocab_size, d_model, n_encoder_layers) # Create the layers of the pre-attention decoder. pre_attention_decoder = pre_attention_decoder_fn(mode, target_vocab_size, d_model) # Model model = tl.Serial( # Copy input tokens and target tokens for later use. tl.Select([0, 1, 0, 1]), # Run the input encoder on the input and the pre-attention decoder on the target, in parallel. tl.Parallel(input_encoder, pre_attention_decoder), # Prepare queries, keys, values and mask for attention. tl.Fn('PrepareAttentionInput', prepare_attention_input, n_out=4), # Nest the AttentionQKV layer inside a Residual layer so its output is added to the pre-attention decoder activations. tl.Residual(tl.AttentionQKV(d_model, n_heads=n_attention_heads, dropout=attention_dropout, mode=mode)), tl.Select([0, 2]), # Run the rest of the RNN decoder. [tl.LSTM(n_units=d_model) for _ in range(n_decoder_layers)], # Dense layer of target vocabulary size. tl.Dense(target_vocab_size), # Log-softmax for output. tl.LogSoftmax() ) return model
def test_train_mnist(self): """Train MNIST model (almost) fully, to compare to other implementations. Evals for cross-entropy loss and accuracy are run every 50 steps; their values are visible in the test log. """ mnist_model = tl.Serial( tl.Flatten(), tl.Dense(512), tl.Relu(), tl.Dense(512), tl.Relu(), tl.Dense(10), tl.LogSoftmax(), ) task = training.TrainTask( itertools.cycle(_mnist_dataset().train_stream(1)), tl.CrossEntropyLoss(), adafactor.Adafactor(.02)) eval_task = training.EvalTask( itertools.cycle(_mnist_dataset().eval_stream(1)), [tl.CrossEntropyLoss(), tl.Accuracy()], n_eval_batches=10) training_session = training.Loop( mnist_model, [task], eval_tasks=[eval_task], eval_at=lambda step_n: step_n % 50 == 0) training_session.run(n_steps=1000) self.assertEqual(training_session.step, 1000)
def test_a2ctrainer_cartpole(self): """Test-runs a2c on cartpole.""" task = rl_task.RLTask('CartPole-v0', initial_trajectories=1, max_steps=2) policy_model = lambda mode: tl.Serial( # pylint: disable=g-long-lambda tl.Dense(64), tl.Relu(), tl.Dense(2), tl.LogSoftmax()) value_model = lambda mode: tl.Serial( # pylint: disable=g-long-lambda tl.Dense(64), tl.Relu(), tl.Dense(1)) lr = lambda h: lr_schedules.MultifactorSchedule( # pylint: disable=g-long-lambda h, constant=1e-4, warmup_steps=100, factors='constant * linear_warmup') trainer = actor_critic.AdvantageActorCriticTrainer( task, n_shared_layers=1, value_model=value_model, value_optimizer=opt.Adam, value_lr_schedule=lr, value_batch_size=2, value_train_steps_per_epoch=2, policy_model=policy_model, policy_optimizer=opt.Adam, policy_lr_schedule=lr, policy_batch_size=2, policy_train_steps_per_epoch=2, collect_per_epoch=2) trainer.run(2) self.assertEqual(2, trainer.current_epoch)
def test_policytrainer_cartpole(self): """Trains a policy on cartpole.""" task = rl_task.RLTask('CartPole-v0', initial_trajectories=1, max_steps=200) # TODO(pkozakowski): Use Distribution.n_inputs to initialize the action # head. model = lambda mode: tl.Serial( # pylint: disable=g-long-lambda tl.Dense(64), tl.Relu(), tl.Dense(64), tl.Relu(), tl.Dense(2), tl.LogSoftmax()) lr = lambda h: lr_schedules.MultifactorSchedule( # pylint: disable=g-long-lambda h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup') trainer = training.PolicyGradientTrainer( task, policy_model=model, policy_optimizer=opt.Adam, policy_lr_schedule=lr, policy_batch_size=128, policy_train_steps_per_epoch=1, collect_per_epoch=2) # Assert that we get to 200 at some point and then exit so the test is as # fast as possible. for ep in range(200): trainer.run(1) self.assertEqual(trainer.current_epoch, ep + 1) if trainer.avg_returns[-1] == 200.0: return self.fail( 'The expected score of 200 has not been reached. ' 'Maximum was {}.'.format(max(trainer.avg_returns)) )
def _mnist_tasks(head=None): """Creates MNIST training and evaluation tasks. Args: head: Adaptor layer to put before loss and accuracy layers in the tasks. Returns: A pair (train_task, eval_task) consisting of the MNIST training task and the MNIST evaluation task using cross-entropy as loss and accuracy as metric. """ loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss()) accuracy = tl.Accuracy() if head is not None: loss = tl.Serial(head, loss) accuracy = tl.Serial(head, accuracy) task = training.TrainTask( itertools.cycle(_mnist_dataset().train_stream(1)), loss, adam.Adam(0.001), ) eval_task = training.EvalTask( itertools.cycle(_mnist_dataset().eval_stream(1)), [loss, accuracy], n_eval_batches=10, metric_names=['CrossEntropy', 'Accuracy'], ) return (task, eval_task)
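# The composed loss above first maps the model output (logits, after the optional `head`)
# to log-probabilities and then applies cross-entropy. A plain-numpy sketch of roughly the
# quantity it computes (illustrative, not the Trax implementation):
import numpy as np

def cross_entropy_from_logits(logits, targets, weights):
  log_probs = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))  # LogSoftmax
  picked = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]  # log p(target)
  return -np.sum(picked * weights) / np.sum(weights)                           # weighted mean NLL

logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
targets = np.array([0, 2])
weights = np.ones_like(targets, dtype=np.float32)
print(cross_entropy_from_logits(logits, targets, weights))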
def PositionLookupTransformerLM(vocab_size=128, d_model=256, d_ff=512, n_layers=3, n_heads=4, dropout=0.1, max_len=100, mode='train'): """Transformer language model (only uses the decoder part of Transformer). Args: vocab_size: int: vocab size d_model: int: depth of embedding d_ff: int: depth of feed-forward layer n_layers: int: number of layers n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) max_len: maximal length mode: str: 'train' or 'eval' Returns: the layer. """ positions = _POSITIONS[:max_len, :] return tl.Serial( tl.ShiftRight(), tl.Embedding(d_model, vocab_size), tl.Dropout(rate=dropout, mode=mode), NewPositionalEncoding(positions=positions), [ DecoderLayer(positions, d_model, d_ff, n_heads, dropout, mode) for _ in range(n_layers) ], PreservePosition(tl.LayerNorm()), tl.Dense(vocab_size), tl.LogSoftmax())
def test_train_mnist(self): """Train MNIST model (almost) fully, to compare to other implementations. Evals for cross-entropy loss and accuracy are run every 50 steps; their values are visible in the test log. """ gin.parse_config([ 'batch_fn.batch_size_per_device = 256', 'batch_fn.eval_batch_size = 256', ]) mnist_model = tl.Serial( tl.Flatten(), tl.Dense(512), tl.Relu(), tl.Dense(512), tl.Relu(), tl.Dense(10), tl.LogSoftmax(), ) task = training.TrainTask( itertools.cycle(_mnist_dataset().train_stream(1)), tl.CrossEntropyLoss(), adafactor.Adafactor(.02)) eval_task = training.EvalTask( itertools.cycle(_mnist_dataset().eval_stream(1)), [tl.CrossEntropyLoss(), tl.AccuracyScalar()], names=['CrossEntropyLoss', 'AccuracyScalar'], eval_at=lambda step_n: step_n % 50 == 0, eval_N=10) training_session = training.Loop(mnist_model, task, eval_task=eval_task) training_session.run(n_steps=1000) self.assertEqual(training_session.current_step(), 1000)
def test_jointawrtrainer_cartpole(self): """Test-runs joint AWR on cartpole.""" task = rl_task.RLTask('CartPole-v0', initial_trajectories=1000, max_steps=200) shared_model = lambda mode: tl.Serial(tl.Dense(64), tl.Relu()) policy_top = lambda mode: tl.Serial(tl.Dense(2), tl.LogSoftmax()) value_top = lambda mode: tl.Dense(1) lr = lambda h: lr_schedules.MultifactorSchedule( # pylint: disable=g-long-lambda h, constant=1e-2, warmup_steps=100, factors='constant * linear_warmup') trainer = actor_critic_joint.AWRJointTrainer( task, shared_model=shared_model, policy_top=policy_top, value_top=value_top, optimizer=opt.Adam, lr_schedule=lr, batch_size=32, train_steps_per_epoch=1000, collect_per_epoch=10) trainer.run(1) self.assertEqual(1, trainer.current_epoch)
def test_call(self): layer = tl.LogSoftmax() x = np.array([[2., 1., -10.], [1., 1., -10.]]) y = layer(x) np.testing.assert_allclose( y, [[-0.313, -1.313, -12.313], [-0.693, -0.693, -11.693]], atol=.001)
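# The expected values in the test above can be reproduced by hand: LogSoftmax computes
# x - logsumexp(x) along the last axis. A quick numpy check (illustrative only):
import numpy as np

x = np.array([[2., 1., -10.], [1., 1., -10.]])
log_softmax = x - np.log(np.sum(np.exp(x), axis=-1, keepdims=True))
print(np.round(log_softmax, 3))  # first row ~ [-0.313, -1.313, -12.313]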
def GRULM(vocab_size=256, d_model=512, n_layers=1, mode='train'): """Returns a GRU (gated recurrent unit) language model. This model performs autoregressive language modeling: - input: rank 2 tensor representing a batch of text strings via token IDs plus padding markers; shape is (batch_size, sequence_length). The tensor elements are integers in `range(vocab_size)`, and `0` values mark padding positions. - output: rank 3 tensor representing a batch of log-probability distributions for each sequence position over possible token IDs; shape is (batch_size, sequence_length, `vocab_size`). Args: vocab_size: Input vocabulary size -- each element of the input tensor should be an integer in `range(vocab_size)`. These integers typically represent token IDs from a vocabulary-based tokenizer. d_model: Embedding depth throughout the model. n_layers: Number of GRU layers. mode: If `'predict'`, use fast inference (and omit the right shift). Returns: A GRU language model as a layer that maps from a tensor of tokens to activations over a vocab set. """ return tl.Serial(tl.ShiftRight(mode=mode), tl.Embedding(vocab_size, d_model), [tl.GRU(d_model) for _ in range(n_layers)], tl.Dense(vocab_size), tl.LogSoftmax())
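# A minimal forward-pass sketch for the GRULM above; sizes and token IDs are illustrative.
import numpy as np
from trax import shapes

model = GRULM(vocab_size=256, d_model=32, n_layers=1)
tokens = np.array([[5, 12, 7, 0]], dtype=np.int32)  # (batch_size, sequence_length)
model.init(shapes.signature(tokens))
log_probs = model(tokens)                           # (1, 4, 256) per-position log-probabilities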
def WideResnet(n_blocks=3, widen_factor=1, n_output_classes=10, bn_momentum=0.9, mode='train'): """WideResnet from https://arxiv.org/pdf/1605.07146.pdf. Args: n_blocks: int, number of blocks in a group. total layers = 6n + 4. widen_factor: int, widening factor of each group. k=1 is vanilla resnet. n_output_classes: int, number of distinct output classes. bn_momentum: float, momentum in BatchNorm. mode: Whether we are training or evaluating or doing inference. Returns: The list of layers comprising a WideResnet model with the given parameters. """ return tl.Serial( tl.ToFloat(), tl.Conv(16, (3, 3), padding='SAME'), WideResnetGroup(n_blocks, 16 * widen_factor, bn_momentum=bn_momentum, mode=mode), WideResnetGroup(n_blocks, 32 * widen_factor, (2, 2), bn_momentum=bn_momentum, mode=mode), WideResnetGroup(n_blocks, 64 * widen_factor, (2, 2), bn_momentum=bn_momentum, mode=mode), tl.BatchNorm(momentum=bn_momentum, mode=mode), tl.Relu(), tl.AvgPool(pool_size=(8, 8)), tl.Flatten(), tl.Dense(n_output_classes), tl.LogSoftmax(), )
def NMTAttn(input_vocab_size=33300, target_vocab_size=33300, d_model=1024, n_encoder_layers=2, n_decoder_layers=2, n_attention_heads=4, attention_dropout=0.0, mode='train'): input_encoder = input_encoder_fn(input_vocab_size, d_model, n_encoder_layers) pre_attention_decoder = pre_attention_decoder_fn(mode, target_vocab_size, d_model) model = tl.Serial( tl.Select([0, 1, 0, 1]), tl.Parallel(input_encoder, pre_attention_decoder), tl.Fn('PrepareAttentionInput', prepare_attention_input, n_out=4), # Nest AttentionQKV inside a Residual layer to add its output to the pre-attention decoder activations (i.e. the queries). tl.Residual( tl.AttentionQKV(d_model, n_heads=n_attention_heads, dropout=attention_dropout, mode=mode)), # Step 6: drop the attention mask (i.e. keep only outputs 0 and 2). tl.Select([0, 2]), [tl.LSTM(d_model) for _ in range(n_decoder_layers)], tl.Dense(target_vocab_size), tl.LogSoftmax()) return model
def TransformerLM(vocab_size=33300, d_model=512, d_ff=2048, n_layers=6, n_heads=8, dropout=0.1, max_len=4096, mode='train', ff_activation=tl.Relu): positional_encoder = [ tl.Embedding(vocab_size, d_model), tl.Dropout(rate=dropout, mode=mode), tl.PositionalEncoding(max_len=max_len, mode=mode) ] decoder_blocks = [ DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation) for _ in range(n_layers) ] # Assemble the shift, embedding, decoder blocks and output layers into one serial model. return tl.Serial( tl.ShiftRight(mode=mode), positional_encoder, decoder_blocks, tl.LayerNorm(), tl.Dense(vocab_size), tl.LogSoftmax(), )
def SkippingTransformerLM(vocab_size, d_model=512, d_ff=2048, n_layers=6, n_heads=8, d_attention_key=None, d_attention_value=None, attention_type=tl.DotProductCausalAttention, dropout=0.1, share_qk=False, max_len=2048, mode='train', ff_activation=tl.Relu): """Returns a Skipping Transformer language model. The input to the model is a tensor of tokens. (This model uses only the decoder part of the overall Transformer.) Args: vocab_size: int: vocab size d_model: int: depth of embedding d_ff: int: depth of feed-forward layer n_layers: int: number of encoder/decoder layers n_heads: int: number of attention heads d_attention_key: int: depth of key vector for each attention head (default is d_model // n_heads) d_attention_value: int: depth of value vector for each attention head (default is d_model // n_heads) attention_type: subclass of tl.BaseCausalAttention: attention class to use dropout: float: dropout rate (how much to drop out) share_qk: bool, whether to share queries and keys in decoder attention max_len: int: maximum symbol length for positional encoding mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference ff_activation: the non-linearity in feed-forward layer Returns: A Transformer language model as a layer that maps from a tensor of tokens to activations over a vocab set. """ embedder = [ tl.Embedding(d_model, vocab_size), tl.Dropout(rate=dropout, name='embedding', mode=mode), tl.PositionalEncoding(max_len=max_len, mode=mode), ] return tl.Serial( tl.ShiftRight(mode=mode), embedder, SkippingSerial( [ transformer._DecoderBlock( # pylint: disable=g-complex-comprehension,protected-access d_model, d_ff, n_heads, d_attention_key, d_attention_value, attention_type, dropout, share_qk, i, mode, ff_activation) for i in range(n_layers) ], mode=mode), tl.LayerNorm(), tl.Dense(vocab_size), tl.LogSoftmax(), )
def GRULM(vocab_size=256, d_model=512, n_layers=2, mode='train'): """Returns a GRU language model. Args: vocab_size (int, optional): Size of the vocabulary. Defaults to 256. d_model (int, optional): Depth of embedding (n_units in the GRU cell). Defaults to 512. n_layers (int, optional): Number of GRU layers. Defaults to 2. mode (str, optional): 'train', 'eval' or 'predict', predict mode is for fast inference. Defaults to 'train'. Returns: trax.layers.combinators.Serial: A GRU language model as a layer that maps from a tensor of tokens to activations over a vocab set. """ model = tl.Serial( tl.ShiftRight(mode=mode), # Shift right so predictions use only previous tokens tl.Embedding(vocab_size=vocab_size, d_feature=d_model), # Embedding layer [tl.GRU(n_units=d_model) for _ in range(n_layers)], # Stack n_layers GRU layers of d_model units tl.Dense(n_units=vocab_size), # Dense layer projecting to the vocabulary tl.LogSoftmax() # Log softmax ) return model
def GRULM(vocab_size=256, d_model=512, n_layers=1, mode='train'): """Returns a GRU language model. The input to the model is a tensor of tokens (ints). Args: vocab_size: int: vocab size d_model: int: depth of embedding (n_units in the RNN cell) n_layers: int: number of RNN layers mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference Returns: An RNN language model as a layer that maps from a tensor of tokens to activations over a vocab set. """ return tl.Serial( tl.ShiftRight(mode=mode), tl.Embedding(d_model, vocab_size), [tl.GRU(d_model) for _ in range(n_layers)], tl.Dense(vocab_size), tl.LogSoftmax() )
def log_prob(self, inputs, point): inputs = tl.LogSoftmax()(self._unflatten_inputs(inputs)) return jnp.sum( # Select the logits specified by point. inputs * tl.one_hot(point, self._n_categories), # Sum over the parameter dimensions. axis=[-a for a in range(1, len(self._shape) + 2)], )
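# For a single categorical distribution the computation above reduces to picking the
# log-probability of the chosen category. A plain-numpy sketch (illustrative shapes,
# not the class internals):
import numpy as np

logits = np.array([1.0, 0.0, -2.0, 0.5])             # unnormalized inputs
log_probs = logits - np.log(np.sum(np.exp(logits)))  # LogSoftmax
point = 2                                            # sampled category index
one_hot = np.eye(len(logits))[point]
log_prob_of_point = np.sum(log_probs * one_hot)      # same as log_probs[point]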
def classifier(vocab_size=1, embedding_dim=256, output_dim=2, mode='train'): embed_layer = tl.Embedding(vocab_size=vocab_size, d_feature=embedding_dim) mean_layer = tl.Mean(axis=1) dense_output_layer = tl.Dense(n_units=output_dim) log_softmax_layer = tl.LogSoftmax() model = tl.Serial(embed_layer, mean_layer, dense_output_layer, log_softmax_layer) return model
def test_run_reversible_same_as_default_extended(self): """Runs the reversible trainer, check results are the same as default.""" inputs_batch = np.arange(8).reshape((2, 4)) targets_batch = 2 * inputs_batch labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) # We want to test rng propagation too, so adding some dropout layers. first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup()) rev_layers1 = [ tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)), tl.ReversibleSwap(), tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)), tl.ReversibleSwap() ] mid_layer = tl.Serial(tl.Add(), tl.Dense(4), tl.Dup()) rev_layers2 = [ tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.3)), tl.ReversibleSwap() ] loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(19), tl.Dropout(0.3), tl.LogSoftmax(), tl.CrossEntropyLoss()) model = tl.Serial([first_layer] + rev_layers1 + [mid_layer] + rev_layers2 + [loss_layer]) rng_init = fastmath.random.get_prng(12) model.init(labeled_batch, rng=rng_init) optimizer_fn = optimizers.Adam # to test slots # Make 3 steps with the original trainer. optimizer = optimizer_fn() optimizer.tree_init(model.weights) trainer = optimizers.Trainer(model, optimizer) rng_step1 = fastmath.random.get_prng(7) rng_step2 = fastmath.random.get_prng(8) rng_step3 = fastmath.random.get_prng(9) trainer.one_step(labeled_batch, rng_step1) trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) first_layer_weights1 = first_layer.weights rev_layer12_weights1 = rev_layers1[2].weights mid_layer_weights1 = mid_layer.weights rev_layer20_weights1 = rev_layers2[0].weights loss_layer_weights1 = loss_layer.weights # Now make 3 steps with reversible trainer. model.init(labeled_batch, rng=rng_init) trainer = optimizers.ReversibleSerialTrainer( [(first_layer.sublayers, rev_layers1), (mid_layer.sublayers, rev_layers2)], loss_layer, optimizer_fn) trainer.one_step(labeled_batch, rng_step1) trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) # Check that weights end up the same. self._assert_all_equal(loss_layer_weights1, loss_layer.weights) self._assert_all_equal(rev_layer20_weights1, rev_layers2[0].weights) self._assert_all_equal(mid_layer_weights1, mid_layer.weights) self._assert_all_equal(rev_layer12_weights1, rev_layers1[2].weights) self._assert_all_equal(first_layer_weights1, first_layer.weights)
def BERTClassifierHead(n_classes): return tl.Serial([ tl.Select([0], n_in=2), tl.Dense(n_classes, kernel_initializer=tl.RandomNormalInitializer(0.02), bias_initializer=tl.RandomNormalInitializer(1e-6), ), tl.LogSoftmax(), ])
def Resnet50(d_hidden=64, n_output_classes=1001, mode='train', norm=tl.BatchNorm, non_linearity=tl.Relu): """ResNet. Args: d_hidden: Dimensionality of the first hidden layer (multiplied later). n_output_classes: Number of distinct output classes. mode: Whether we are training or evaluating or doing inference. norm: `Layer` used for normalization, Ex: BatchNorm or FilterResponseNorm. non_linearity: `Layer` used as a non-linearity, Ex: If norm is BatchNorm then this is a Relu, otherwise for FilterResponseNorm this should be ThresholdedLinearUnit. Returns: The list of layers comprising a ResNet model with the given parameters. """ # A ConvBlock configured with the given norm, non-linearity and mode. def Resnet50ConvBlock(filter_multiplier=1, strides=(2, 2)): filters = ([ filter_multiplier * dim for dim in [d_hidden, d_hidden, 4 * d_hidden] ]) return ConvBlock(3, filters, strides, norm, non_linearity, mode) # Same as above for IdentityBlock. def Resnet50IdentityBlock(filter_multiplier=1): filters = ([ filter_multiplier * dim for dim in [d_hidden, d_hidden, 4 * d_hidden] ]) return IdentityBlock(3, filters, norm, non_linearity, mode) return tl.Serial( tl.ToFloat(), tl.Conv(d_hidden, (7, 7), (2, 2), 'SAME'), norm(mode=mode), non_linearity(), tl.MaxPool(pool_size=(3, 3), strides=(2, 2)), Resnet50ConvBlock(strides=(1, 1)), [Resnet50IdentityBlock() for _ in range(2)], Resnet50ConvBlock(2), [Resnet50IdentityBlock(2) for _ in range(3)], Resnet50ConvBlock(4), [Resnet50IdentityBlock(4) for _ in range(5)], Resnet50ConvBlock(8), [Resnet50IdentityBlock(8) for _ in range(2)], tl.AvgPool(pool_size=(7, 7)), tl.Flatten(), tl.Dense(n_output_classes), tl.LogSoftmax(), )
def TransformerEncoder(vocab_size, n_classes=10, d_model=512, d_ff=2048, n_layers=6, n_heads=8, dropout=0.1, dropout_shared_axes=None, max_len=2048, mode='train', ff_activation=tl.Relu): """Returns a Transformer encoder model. The input to the model is a tensor of tokens. Args: vocab_size: int: vocab size n_classes: how many classes on output d_model: int: depth of embedding d_ff: int: depth of feed-forward layer n_layers: int: number of encoder/decoder layers n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) dropout_shared_axes: axes on which to share dropout mask max_len: int: maximum symbol length for positional encoding mode: str: 'train' or 'eval' ff_activation: the non-linearity in feed-forward layer Returns: A Transformer model as a layer that maps from a tensor of tokens to activations over a set of output classes. """ positional_encoder = [ tl.Embedding(vocab_size, d_model), tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode), tl.PositionalEncoding(max_len=max_len) ] encoder_blocks = [ _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation) for i in range(n_layers) ] # Assemble and return the model. return tl.Serial( # toks # Encode. tl.Branch(positional_encoder, tl.PaddingMask()), # vecs masks encoder_blocks, # vecs masks tl.Select([0], n_in=2), # vecs tl.LayerNorm(), # vecs # Map to output categories. tl.Mean(axis=1), # vecs tl.Dense(n_classes), # vecs tl.LogSoftmax(), # vecs )
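# A minimal forward-pass sketch. It uses the library's built-in trax.models.TransformerEncoder
# as a stand-in for the function above (whose _EncoderBlock helper is defined elsewhere in the
# module); sizes and token IDs are illustrative.
import numpy as np
from trax import models, shapes

model = models.TransformerEncoder(vocab_size=100, n_classes=10, d_model=32, d_ff=64,
                                  n_layers=2, n_heads=2, max_len=64, mode='eval')
tokens = np.array([[5, 12, 7, 0, 0]], dtype=np.int32)
model.init(shapes.signature(tokens))
log_probs = model(tokens)  # (1, 10) class log-probabilities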
def get_model(num_classes): return tl.Serial( tl.Flatten(), tl.Dense(512), tl.Relu(), tl.Dense(512), tl.Relu(), tl.Dense(num_classes), tl.LogSoftmax(), )
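# A minimal usage sketch for get_model, assuming MNIST-shaped image batches (illustrative):
import numpy as np
from trax import shapes

model = get_model(num_classes=10)
images = np.zeros((8, 28, 28, 1), dtype=np.float32)  # dummy batch of images
model.init(shapes.signature(images))
log_probs = model(images)                            # (8, 10) class log-probabilities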
def test_run_reversible_same_as_default_terraformer(self): """Runs the reversible trainer, check results are the same as default.""" inputs_batch = np.arange(8).reshape((2, 4)) + 1 targets_batch = 2 * inputs_batch labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch)) int_sig = shapes.ShapeDtype((2, 4), dtype=np.int32) input_sig = (int_sig, int_sig, int_sig) # We want to test rng propagation too, so adding some dropout layers. model = terraformer.ConfigurableTerraformer(20, d_model=8, d_ff=32, n_heads=1, dropout=0.0, n_encoder_layers=2, n_decoder_layers=2, ff_sparsity=(4, 8, 0.0, 1.0), pos_type=None, reversible_encoder=True) loss = tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss()) optimizer_fn = optimizers.Adafactor blocks, loss_layer = optimizers.trainer.extract_reversible_blocks( [model, loss], loss_chunk_size=4) blocks_serial = [(tl.Serial(std), rev) for (std, rev) in blocks] model_with_loss = tl.Serial(model, loss) rng_init = fastmath.random.get_prng(12) model_with_loss.init(input_sig, rng=rng_init) # Make 3 steps with the original trainer. optimizer = optimizer_fn() optimizer.tree_init(model_with_loss.weights) trainer = optimizers.Trainer(model_with_loss, optimizer) rng_step1 = fastmath.random.get_prng(7) rng_step2 = fastmath.random.get_prng(8) rng_step3 = fastmath.random.get_prng(9) trainer.one_step(labeled_batch, rng_step1) trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) first_weights = blocks_serial[0][0].weights first_rev_weights = blocks[0][1][0].weights loss_weights = loss_layer.weights # Now make 3 steps with reversible trainer. model_with_loss.init(input_sig, rng=rng_init) trainer = optimizers.ReversibleSerialTrainer(blocks, loss_layer, optimizer_fn) trainer.one_step(labeled_batch, rng_step1) trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02) trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03) # Check that weights end up the same. self._assert_all_equal(loss_weights, loss_layer.weights) self._assert_all_equal(first_rev_weights, blocks[0][1][0].weights) self._assert_all_equal(first_weights, blocks_serial[0][0].weights)
def test_policytrainer_cartpole(self): """Trains a policy on cartpole.""" task = rl_task.RLTask('CartPole-v0', initial_trajectories=100, max_steps=200) model = lambda mode: tl.Serial( # pylint: disable=g-long-lambda tl.Dense(32), tl.Relu(), tl.Dense(3), tl.LogSoftmax()) lr = lambda h: lr_schedules.MultifactorSchedule( # pylint: disable=g-long-lambda h, constant=1e-3, warmup_steps=100, factors='constant * linear_warmup') trainer = training.ExamplePolicyTrainer(task, model, opt.Adam, lr) trainer.run(1) self.assertEqual(1, trainer.current_epoch)
def test_run_reversible_slots(self): """Tests that slots can be read and assigned in reversible trainer.""" layers = [tl.Dense(4), tl.Dup()] rev_layers = [tl.ReversibleHalfResidual(tl.Dense(4)), tl.ReversibleSwap()] loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(4), tl.LogSoftmax(), tl.CrossEntropyLoss()) trainer = optimizers.ReversibleSerialTrainer( [(layers, rev_layers)], loss_layer, optimizers.Adam) slots = trainer.slots trainer.slots = slots self.assertEqual(slots, trainer.slots)