def test_restores_history(self):
  """Training restores history from directory where it saved it."""
  model = tl.Serial(tl.Dense(1))
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()])
  tmp_dir = self.create_tempdir().full_path
  loop = training.Loop(model, [task],
                       eval_tasks=[eval_task],
                       eval_at=lambda step_n: step_n % 2 == 0,
                       checkpoint_at=lambda step_n: step_n % 2 == 0,
                       output_dir=tmp_dir)
  loop.run(4)
  loop2 = training.Loop(model, [task], output_dir=tmp_dir)
  self.assertLen(loop2.history.modes, 2)
  self.assertLen(loop2.history.metrics_for_mode('train'), 6)
  self.assertLen(loop2.history.metrics_for_mode('eval'), 1)
  for mode, metric in [
      ('train', 'metrics/L2Loss'),
      ('train', 'training/learning_rate'),
      ('train', 'training/steps per second'),
      ('train', 'training/gradients_l2'),
      ('train', 'training/loss'),
      ('train', 'training/weights_l2'),
      ('eval', 'metrics/L2Loss'),
  ]:
    self.assertLen(loop2.history.get(mode, metric), 1)
    self.assertEqual(2, loop2.history.get(mode, metric)[0][0])

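# Note: a hedged sketch, not part of the original file. The `_very_simple_data()`
# helper used throughout these tests is defined elsewhere in the test module;
# this illustrative version only shows the interface the tests above and below
# rely on: an endless stream of (inputs, targets, weights) batches of size 8,
# matching the (8, output_dim) shapes asserted in the prediction tests.
import numpy as np


def _very_simple_data(output_dim=1, input_dim=1):
  """Yields the same tiny labeled batch forever: (inputs, targets, weights)."""
  inputs_batch = np.arange(8).reshape((8, input_dim))  # 8 examples per batch
  targets_batch = np.pi * np.ones((8, output_dim))     # constant regression target
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  while True:
    yield labeled_batch
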
def test_restores_step(self):
  """Training restores step from directory where it saved it."""
  model = tl.Serial(tl.Dense(1))
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  tmp_dir = self.create_tempdir().full_path
  loop = training.Loop(model, [task],
                       checkpoint_at=lambda step_n: step_n % 2 == 0,
                       output_dir=tmp_dir)
  loop.run(4)
  loop2 = training.Loop(model, [task], output_dir=tmp_dir)
  self.assertEqual(4, loop2.step)

def test_restore_fails_different_model(self):
  """Restoring fails when the checkpoint was created with a different model."""
  model1 = tl.Serial(tl.Dense(1))
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  tmp_dir = self.create_tempdir().full_path
  loop = training.Loop(model1, [task],
                       checkpoint_at=lambda step_n: step_n % 2 == 0,
                       output_dir=tmp_dir)
  loop.run(2)
  model2 = tl.Serial(tl.Dense(2))
  with self.assertRaises(IndexError):
    training.Loop(model2, [task], output_dir=tmp_dir)

def test_restores_from_smaller_model(self):
  """Training restores from a checkpoint created with smaller model."""
  model1 = tl.Serial(tl.Dense(1))
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.Adam(.01))
  tmp_dir = self.create_tempdir().full_path
  loop = training.Loop(model1, [task],
                       checkpoint_at=lambda step_n: step_n % 2 == 0,
                       output_dir=tmp_dir)
  loop.run(2)
  model2 = tl.Serial(tl.Dense(1), tl.Dense(1))
  loop2 = training.Loop(model2, [task], output_dir=tmp_dir)
  self.assertEqual(2, loop2.step)

def test_restores_step_sharded_bfloat16(self):
  """Training restores step from where it saved it, sharded and bfloat16."""
  model = tl.Serial(tl.Dense(1, use_bfloat16=True))
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD)
  tmp_dir = self.create_tempdir().full_path
  loop = training.Loop(model, [task],
                       checkpoint_at=lambda step_n: step_n % 2 == 0,
                       output_dir=tmp_dir,
                       use_memory_efficient_trainer=True)
  loop.run(4)
  loop2 = training.Loop(model, [task],
                        output_dir=tmp_dir,
                        use_memory_efficient_trainer=True)
  self.assertEqual(4, loop2.step)
  loop2.run(2)  # check that continued training works
  self.assertEqual(6, loop2.step)

def test_train_one_task_eval_two_tasks(self):
  """Trains a very simple network on one task and evaluates on two tasks."""
  model = tl.Serial(tl.Dense(3), tl.Dense(1))
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  export_prefix_1 = 'eval_1'
  eval_task_1 = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
      export_prefix=export_prefix_1,
  )
  export_prefix_2 = 'eval_2'
  eval_task_2 = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
      export_prefix=export_prefix_2,
  )
  training_session = training.Loop(
      model,
      tasks=(task,),
      eval_tasks=(eval_task_1, eval_task_2),
  )
  self.assertEqual(0, training_session.step)
  training_session.run(n_steps=5)
  self.assertEqual(5, training_session.step)
  export_prefixes = [
      task.export_prefix for task in training_session.eval_tasks
  ]
  self.assertCountEqual([export_prefix_1, export_prefix_2], export_prefixes)

def test_train_mnist(self):
  """Train MNIST model (almost) fully, to compare to other implementations.

  Evals for cross-entropy loss and accuracy are run every 50 steps; their
  values are visible in the test log.
  """
  mnist_model = tl.Serial(
      tl.Flatten(),
      tl.Dense(512),
      tl.Relu(),
      tl.Dense(512),
      tl.Relu(),
      tl.Dense(10),
      tl.LogSoftmax(),
  )
  task = training.TrainTask(
      itertools.cycle(_mnist_dataset().train_stream(1)),
      tl.CrossEntropyLoss(),
      adafactor.Adafactor(.02))
  eval_task = training.EvalTask(
      itertools.cycle(_mnist_dataset().eval_stream(1)),
      [tl.CrossEntropyLoss(), tl.Accuracy()],
      n_eval_batches=10)

  training_session = training.Loop(
      mnist_model,
      [task],
      eval_tasks=[eval_task],
      eval_at=lambda step_n: step_n % 50 == 0)

  training_session.run(n_steps=1000)
  self.assertEqual(training_session.step, 1000)

def test_train_mnist_multitask(self, mock_stdout):
  """Train two-head MNIST model a bit, to compare to other implementations."""
  mnist_model = _build_model(two_heads=True)
  # MNIST classification task.
  (cls_task, cls_eval_task) = _mnist_tasks(head=tl.Select([0], n_in=2))
  # Auxiliary brightness prediction task.
  reg_task = training.TrainTask(
      itertools.cycle(_mnist_brightness_dataset().train_stream(1)),
      tl.Serial(tl.Select([1]), tl.L2Loss()),
      adam.Adam(0.001),
  )
  reg_eval_task = training.EvalTask(
      itertools.cycle(_mnist_brightness_dataset().eval_stream(1)),
      [tl.Serial(tl.Select([1]), tl.L2Loss())],
      n_eval_batches=1,
      metric_names=['L2'],
  )
  training_session = training.Loop(
      mnist_model,
      tasks=[cls_task, reg_task],
      eval_tasks=[cls_eval_task, reg_eval_task],
      eval_at=lambda step_n: step_n % 20 == 0,
      which_task=lambda step_n: step_n % 2,
  )
  training_session.run(n_steps=100)
  self.assertEqual(training_session.step, 100)

  # Assert that we reach at least 80% eval accuracy on MNIST.
  self.assertGreater(_read_metric('Accuracy', mock_stdout), 0.8)
  # Assert that we get below 0.03 brightness prediction error.
  self.assertLess(_read_metric('L2', mock_stdout), 0.03)

def _make_model_and_session():
  m = transformer.TransformerLM(
      vocab_size, d_model=4, d_ff=4, n_layers=1, n_heads=2, dropout=0.)
  ts = training.Loop(m, [task],
                     eval_tasks=[eval_task],
                     eval_at=lambda step_n: step_n % 2 == 0,
                     output_dir=tmp_dir)
  return m, ts

def test_value_error_low_with_syncs(self):
  min_error = np.inf
  for _ in range(5):
    model = self._model_fn(mode='train')
    train_task = value_tasks.ValueTrainTask(
        self._trajectory_batch_stream,
        optimizer=opt.Adam(),
        lr_schedule=lr_schedules.constant(1e-3),
        advantage_estimator=advantages.td_k(gamma=self._task.gamma, margin=1),
        model=model,
        # Synchronize often throughout training.
        sync_at=(lambda step: step % 10 == 0),
    )
    loop = training.Loop(
        model=model,
        tasks=[train_task],
    )

    # Assert that before training, the error is high.
    error_before = self._value_error(train_task.value)
    self.assertGreater(error_before, 2.0)

    loop.run(n_steps=100)

    # Assert that after training, the error is small.
    error_after = self._value_error(train_task.value)
    if error_after < 0.8:
      return
    min_error = min(min_error, error_after)

  self.fail(
      f'Even after 5 trials, min error_after({min_error}) is not < 0.8')

def test_initializes_step_callbacks_with_loop_instance(self):
  """Runs a training loop, asserting that callbacks are initialized."""

  class ActualLoop:
    # Wrapper object to make the Loop reference mutable.
    loop = None

  class TestCallback(callbacks.TrainingStepCallback):

    def __init__(self, loop):
      super().__init__(loop)
      ActualLoop.loop = loop

    def call_at(self, step):
      return False

    def on_step_begin(self, step):
      del step

    def on_step_end(self, step):
      del step

  model = tl.Serial(tl.Dense(1))
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  expected_loop = training.Loop(model, [task], callbacks=[TestCallback])
  self.assertIs(ActualLoop.loop, expected_loop)

def test_calls_step_callbacks(self):
  """Runs a training loop, asserting that callbacks are called."""
  call_at_steps = [1, 3, 4]
  begin_steps = []
  end_steps = []
  test_case = self

  class TestCallback(callbacks.TrainingStepCallback):

    def call_at(self, step):
      return step in call_at_steps

    def on_step_begin(self, step):
      begin_steps.append(step)

    def on_step_end(self, step):
      # Assert that on_step_begin() was called before.
      test_case.assertIn(step, begin_steps)
      end_steps.append(step)

  model = tl.Serial(tl.Dense(1))
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  loop = training.Loop(model, [task], callbacks=[TestCallback])
  loop.run(n_steps=5)

  # Assert that the callback has been called at the appropriate steps.
  self.assertEqual(begin_steps, call_at_steps)
  self.assertEqual(end_steps, call_at_steps)

def training_loop(TransformerLM, train_gen, eval_gen, output_dir="./model"):
  output_dir = os.path.expanduser(output_dir)
  lr_schedule = trax.lr.warmup_and_rsqrt_decay(n_warmup_steps=1000,
                                               max_value=0.01)

  # Set up the loss function and the Adam optimizer used to fit the data
  # efficiently.
  train_task = training.TrainTask(labeled_data=train_gen,
                                  loss_layer=tl.CrossEntropyLoss(),
                                  optimizer=trax.optimizers.Adam(0.01),
                                  lr_schedule=lr_schedule,
                                  n_steps_per_checkpoint=10)

  # We evaluate on a different dataset to ensure no overfitting.
  eval_task = training.EvalTask(
      labeled_data=eval_gen,
      metrics=[tl.CrossEntropyLoss(), tl.Accuracy()])

  loop = training.Loop(TransformerLM(d_model=512, d_ff=2048, n_layers=6,
                                     n_heads=8, mode='train'),
                       train_task,
                       eval_tasks=[eval_task],
                       output_dir=output_dir)
  return loop

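# A hedged usage sketch, not part of the original: it assumes the standard
# `trax.models.TransformerLM` constructor and a hypothetical `_toy_lm_stream`
# generator of (inputs, targets, weights) token batches; shapes, vocab size and
# step counts are illustrative only, not the configuration used elsewhere here.
def _toy_lm_stream(vocab_size=32, batch_size=4, length=16):
  """Yields the same tiny language-modeling batch forever."""
  tokens = np.random.randint(1, vocab_size, size=(batch_size, length))
  while True:
    # For a language model, inputs double as targets; weights are all ones.
    yield (tokens, tokens, np.ones_like(tokens, dtype=np.float32))


def _example_run_training_loop():
  # Build the Loop with toy streams and run a few steps; checkpoints are
  # written every 10 steps as configured by n_steps_per_checkpoint above.
  loop = training_loop(trax.models.TransformerLM,
                       _toy_lm_stream(), _toy_lm_stream(),
                       output_dir='./model')
  loop.run(10)
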
def test_can_predict_with_trained_model(self):
  model = tl.Serial(tl.Dense(3), tl.Branch(tl.Dense(1), tl.Dense(2)))
  train_tasks, eval_tasks = [], []
  for output_dim in [1, 2]:
    # The head we select from the model: 0 for output_dim 1 and 1 for 2.
    head_index = output_dim - 1
    train_tasks.append(
        training.TrainTask(
            _very_simple_data(output_dim),
            tl.Serial(tl.Select([head_index], n_in=2), tl.L2Loss()),
            optimizers.SGD(.01)))
    eval_tasks.append(
        training.EvalTask(
            _very_simple_data(output_dim),  # deliberately re-use training data
            [tl.Serial(tl.Select([head_index], n_in=2), tl.L2Loss())]))
  tmp_dir = self.create_tempdir().full_path
  training_session = training.Loop(
      model,
      tasks=train_tasks,
      eval_tasks=eval_tasks,
      checkpoint_at=lambda step_n: step_n == 1,
      output_dir=tmp_dir,
      which_task=lambda step_n: step_n % 2,
  )
  training_session.run(n_steps=2)

  trained_model = training_session.eval_model
  inp = next(_very_simple_data())[0]
  out = trained_model(inp)
  self.assertEqual(
      shapes.signature(out),
      (shapes.ShapeDtype((8, 1)), shapes.ShapeDtype((8, 2))),
  )

def test_train_mnist(self):
  """Train MNIST model (almost) fully, to compare to other implementations.

  Evals for cross-entropy loss and accuracy are run every 50 steps; their
  values are visible in the test log.
  """
  gin.parse_config([
      'batch_fn.batch_size_per_device = 256',
      'batch_fn.eval_batch_size = 256',
  ])

  mnist_model = tl.Serial(
      tl.Flatten(),
      tl.Dense(512),
      tl.Relu(),
      tl.Dense(512),
      tl.Relu(),
      tl.Dense(10),
      tl.LogSoftmax(),
  )
  task = training.TrainTask(
      itertools.cycle(_mnist_dataset().train_stream(1)),
      tl.CrossEntropyLoss(),
      adafactor.Adafactor(.02))
  eval_task = training.EvalTask(
      itertools.cycle(_mnist_dataset().eval_stream(1)),
      [tl.CrossEntropyLoss(), tl.AccuracyScalar()],
      names=['CrossEntropyLoss', 'AccuracyScalar'],
      eval_at=lambda step_n: step_n % 50 == 0,
      eval_N=10)

  training_session = training.Loop(mnist_model, task, eval_task=eval_task)
  training_session.run(n_steps=1000)
  self.assertEqual(training_session.current_step(), 1000)

def _make_model_and_session():
  m = tl.Serial(tl.Dense(1))
  ts = training.Loop(m, [task],
                     eval_tasks=[eval_task],
                     eval_at=lambda step_n: step_n % 2 == 0,
                     output_dir=tmp_dir)
  return m, ts

def test_summaries_are_written(self):
  """Training writes down metrics when writing is turned on."""
  model = tl.Serial(tl.Dense(1))
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
      metric_names=['SGD.L2Loss'])
  tmp_dir = self.create_tempdir().full_path
  training_session = training.Loop(model, [task],
                                   eval_tasks=[eval_task],
                                   eval_at=lambda step_n: step_n % 2 == 0,
                                   output_dir=tmp_dir)
  expected_train_metric_dir = os.path.join(tmp_dir, 'train')
  expected_eval_metric_dir = os.path.join(tmp_dir, 'eval')
  for directory in [expected_train_metric_dir, expected_eval_metric_dir]:
    self.assertFalse(
        os.path.isdir(directory), 'Failed for directory %s.' % directory)

  training_session.run(n_steps=15)
  time.sleep(1)  # wait for the files to be closed
  for directory in [expected_train_metric_dir, expected_eval_metric_dir]:
    self.assertTrue(
        os.path.isdir(directory), 'Failed for directory %s.' % directory)
    self.assertEqual(
        1, _count_files(directory), 'Failed for directory %s.' % directory)

  training_session.run(n_steps=5)
  time.sleep(1)  # wait for the files to be closed
  for directory in [expected_train_metric_dir, expected_eval_metric_dir]:
    self.assertEqual(
        2, _count_files(directory), 'Failed for directory %s.' % directory)

def training_loop(n_steps=50, cutoff=0.05, output_dir="./model/"):
  train_gen, eval_gen, vocab_size = generate_data(cutoff)

  lr_schedule = trax.lr.warmup_and_rsqrt_decay(n_warmup_steps=1000,
                                               max_value=0.01)

  train_task = training.TrainTask(
      # labeled data
      labeled_data=train_gen,
      # loss layer
      loss_layer=tl.CrossEntropyLoss(),
      # optimizer
      optimizer=trax.optimizers.Adam(0.01),
      # lr_schedule
      lr_schedule=lr_schedule,
      # n_steps
      n_steps_per_checkpoint=n_steps)

  eval_task = training.EvalTask(
      # labeled data
      labeled_data=eval_gen,
      # metrics
      metrics=[tl.CrossEntropyLoss(), tl.Accuracy()])

  loop = training.Loop(ReformerLM(vocab_size, 6, mode='train'),
                       train_task,
                       eval_tasks=[eval_task],
                       output_dir=output_dir)
  return loop

def test_value_error_high_without_syncs(self):
  model = self._model_fn(mode='train')
  train_task = value_tasks.ValueTrainTask(
      self._trajectory_batch_stream,
      optimizer=opt.Adam(),
      lr_schedule=lr_schedules.constant(1e-3),
      advantage_estimator=advantages.td_k(gamma=self._task.gamma, margin=1),
      model=model,
      # Synchronize just once, at the end of training.
      sync_at=(lambda step: step == 100),
  )
  loop = training.Loop(
      model=model,
      tasks=[train_task],
  )

  # Assert that before training, the error is high.
  error_before = self._value_error(train_task.value)
  self.assertGreater(error_before, 2.0)

  loop.run(n_steps=100)

  # Assert that after training, the error is smaller, but still high.
  error_after = self._value_error(train_task.value)
  self.assertLess(error_after, 2.0)
  self.assertGreater(error_after, 0.8)

def test_can_predict_with_trained_model(self):
  model = tl.Serial(tl.Dense(3), tl.Branch(tl.Dense(1), tl.Dense(2)))
  tasks = tuple(
      training.TrainTask(  # pylint: disable=g-complex-comprehension
          _very_simple_data(output_dim),
          tl.L2Loss(),
          optimizers.SGD(.01),
      ) for output_dim in (1, 2))
  eval_tasks = tuple([
      training.EvalTask(  # pylint: disable=g-complex-comprehension
          # deliberately re-using training data
          _very_simple_data(output_dim),
          [tl.L2Loss()],
      )
  ] for output_dim in (1, 2))
  tmp_dir = self.create_tempdir().full_path
  training_session = training.Loop(
      model,
      tasks=tasks,
      eval_tasks=eval_tasks,
      checkpoint_at=lambda step_n: step_n == 1,
      output_dir=tmp_dir,
      which_task=lambda step_n: step_n % 2,
  )
  training_session.run(n_steps=2)

  trained_model = training_session.eval_model
  inp = next(_very_simple_data())[0]
  out = trained_model(inp)
  self.assertEqual(
      shapes.signature(out),
      (shapes.ShapeDtype((8, 1)), shapes.ShapeDtype((8, 2))),
  )

def test_restores_step_bfloat16(self):
  """Training restores step from directory where it saved it, w/ bfloat16."""
  model = tl.Serial(tl.Dense(1, use_bfloat16=True))
  # We'll also use Adafactor with bfloat16 to check restoring bfloat slots.
  opt = optimizers.Adafactor(.01, do_momentum=True, momentum_in_bfloat16=True)
  task = training.TrainTask(_very_simple_data(), tl.L2Loss(), opt)
  tmp_dir = self.create_tempdir().full_path
  loop = training.Loop(model, [task],
                       checkpoint_at=lambda step_n: step_n % 2 == 0,
                       output_dir=tmp_dir)
  loop.run(4)
  loop2 = training.Loop(model, [task], output_dir=tmp_dir)
  self.assertEqual(4, loop2.step)
  loop2.run(2)  # check that continued training works
  self.assertEqual(6, loop2.step)

def train_model(classifier, train_task, eval_task, n_steps, output_dir):
  training_loop = training.Loop(classifier,
                                train_task,
                                eval_tasks=[eval_task],
                                output_dir=output_dir,
                                random_seed=31)
  training_loop.run(n_steps=n_steps)
  return training_loop

def test_loop_no_eval_task(self):
  """Runs a training loop with no eval task(s)."""
  model = tl.Serial(tl.Dense(1))
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  training_session = training.Loop(model, [task])
  # Loop should initialize and run successfully, even with no eval task.
  training_session.run(n_steps=5)

def test_restores_memory_efficient_from_standard(self):
  """Training restores step from directory where it saved it."""
  model = tl.Serial(tl.Dense(4), tl.Dense(1))
  task_std = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.Adam(.0001))
  tmp_dir = self.create_tempdir().full_path
  loop = training.Loop(model, [task_std],
                       checkpoint_at=lambda step_n: step_n % 2 == 0,
                       output_dir=tmp_dir)
  loop.run(4)
  task_memeff = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.Adam)
  loop2 = training.Loop(model, [task_memeff],
                        output_dir=tmp_dir,
                        use_memory_efficient_trainer=True)
  loop2.run(2)
  self.assertEqual(6, loop2.step)

def test_loop_no_eval_task_tfnp(self):
  """Runs a training loop with no eval task(s), TFNP backend."""
  with fastmath.use_backend(fastmath.Backend.TFNP):
    model = tl.Serial(tl.Dense(1))
    task = training.TrainTask(
        _very_simple_data(), tl.L2Loss(), optimizers.Adam(.01))
    training_session = training.Loop(model, [task])
    # Loop should initialize and run successfully, even with no eval task.
    training_session.run(n_steps=5)

def test_integration_with_policy_tasks(self):
  # Integration test for policy + value training and eval.
  optimizer = opt.Adam()
  lr_schedule = lr_schedules.constant(1e-3)
  advantage_estimator = advantages.td_k(gamma=self._task.gamma, margin=1)
  policy_dist = distributions.create_distribution(self._task.action_space)
  body = lambda mode: tl.Dense(64)
  train_model = models.PolicyAndValue(policy_dist, body=body)
  eval_model = models.PolicyAndValue(policy_dist, body=body)

  head_selector = tl.Select([1])
  value_train_task = value_tasks.ValueTrainTask(
      self._trajectory_batch_stream,
      optimizer,
      lr_schedule,
      advantage_estimator,
      model=train_model,
      target_model=eval_model,
      head_selector=head_selector,
  )
  value_eval_task = value_tasks.ValueEvalTask(
      value_train_task, head_selector=head_selector
  )

  # Drop the value head - just tl.Select([0]) would pass it, and it would
  # override the targets.
  head_selector = tl.Select([0], n_in=2)
  policy_train_task = policy_tasks.PolicyTrainTask(
      self._trajectory_batch_stream,
      optimizer,
      lr_schedule,
      policy_dist,
      advantage_estimator,
      # Plug a trained critic as our value estimate.
      value_fn=value_train_task.value,
      head_selector=head_selector,
  )
  policy_eval_task = policy_tasks.PolicyEvalTask(
      policy_train_task, head_selector=head_selector
  )

  loop = training.Loop(
      model=train_model,
      eval_model=eval_model,
      tasks=[policy_train_task, value_train_task],
      eval_tasks=[policy_eval_task, value_eval_task],
      eval_at=(lambda _: True),
      # Switch the task every step.
      which_task=(lambda step: step % 2),
  )

  # Run for a couple of steps to make sure there are a few task switches.
  loop.run(n_steps=10)

def test_train_dense_layer_with_momentum(self):
  """Trains with an optimizer that has slots / requires initialization."""
  model = tl.Serial(tl.Dense(1))
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.Momentum(.01))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
      metric_names=['Momentum.L2Loss'])
  training_session = training.Loop(model, [task],
                                   eval_tasks=[eval_task],
                                   eval_at=lambda step_n: step_n % 2 == 0)
  self.assertEqual(0, training_session.step)
  training_session.run(n_steps=20)
  self.assertEqual(20, training_session.step)

def test_train_dense_layer_evals(self):
  """Trains a very simple network on a very simple task, 2 epochs."""
  model = tl.Serial(tl.Dense(1))
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()])
  training_session = training.Loop(model, [task],
                                   eval_tasks=[eval_task],
                                   eval_at=lambda step_n: False)
  self.assertEqual(0, training_session.step)
  training_session.run(n_steps=10)
  self.assertEqual(10, training_session.step)
  training_session.run_evals()
  self.assertEqual(10, training_session.step)  # Unchanged

def test_train_dense_layer_with_momentum(self):
  """Trains with an optimizer that has slots / requires initialization."""
  model = tl.Dense(1)
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), momentum.Momentum(.01))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
      names=['Momentum.L2Loss'],
      eval_at=lambda step_n: step_n % 2 == 0,
      eval_N=1)
  training_session = training.Loop(model, task, eval_task=eval_task)
  self.assertIsNone(training_session.current_step())
  training_session.run(n_steps=20)
  self.assertEqual(20, training_session.current_step())

def test_train_dense_layer(self):
  """Trains a very simple network on a very simple task."""
  model = tl.Dense(1)
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), sgd.SGD(.01))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
      names=['SGD.L2Loss'],
      eval_at=lambda step_n: step_n % 2 == 0,
      eval_N=1)
  training_session = training.Loop(model, task, eval_task=eval_task)
  self.assertIsNone(training_session.current_step())
  training_session.run(n_steps=20)
  self.assertEqual(20, training_session.current_step())