def test_value_error_high_without_syncs(self):
  """Value error stays high when the target network syncs only at the end."""
  model = self._model_fn(mode='train')
  train_task = value_tasks.ValueTrainTask(
      self._trajectory_batch_stream,
      optimizer=opt.Adam(),
      lr_schedule=lr_schedules.constant(1e-3),
      advantage_estimator=advantages.td_k(gamma=self._task.gamma, margin=1),
      model=model,
      # Synchronize just once, at the end of training.
      sync_at=(lambda step: step == 100),
  )
  loop = training.Loop(
      model=model,
      tasks=[train_task],
  )
  # Assert that before training, the error is high.
  error_before = self._value_error(train_task.value)
  self.assertGreater(error_before, 2.0)
  loop.run(n_steps=100)
  # Assert that after training, the error is smaller, but still high.
  error_after = self._value_error(train_task.value)
  self.assertLess(error_after, 2.0)
  self.assertGreater(error_after, 0.8)
def test_value_error_low_with_syncs(self):
  """Value error drops when the target network syncs often during training."""
  min_error = np.inf
  for _ in range(5):
    model = self._model_fn(mode='train')
    train_task = value_tasks.ValueTrainTask(
        self._trajectory_batch_stream,
        optimizer=opt.Adam(),
        lr_schedule=lr_schedules.constant(1e-3),
        advantage_estimator=advantages.td_k(gamma=self._task.gamma, margin=1),
        model=model,
        # Synchronize often throughout training.
        sync_at=(lambda step: step % 10 == 0),
    )
    loop = training.Loop(
        model=model,
        tasks=[train_task],
    )
    # Assert that before training, the error is high.
    error_before = self._value_error(train_task.value)
    self.assertGreater(error_before, 2.0)
    loop.run(n_steps=100)
    # Assert that after training, the error is small.
    error_after = self._value_error(train_task.value)
    if error_after < 0.8:
      return
    min_error = min(min_error, error_after)
  self.fail(
      f'Even after 5 trials, min error_after ({min_error}) is not < 0.8')
def test_train_save_restore_sharded(self):
  """Saves and restores a sharded checkpoint to check for equivalence."""
  if fastmath.local_device_count() < 2:
    return  # multi-accelerator only
  base.N_WEIGHTS_SHARDS = fastmath.local_device_count()
  train_data = data.Serial(lambda _: _very_simple_data(2, 2),
                           data.CountAndSkip('simple_data'))
  task = training.TrainTask(
      train_data(), tl.L2Loss(), optimizers.Adam(.0001))
  eval_task = training.EvalTask(
      _very_simple_data(2, 2),  # deliberately re-using training data
      [tl.L2Loss()],
      metric_names=['SGD.L2Loss'])
  tmp_dir = self.create_tempdir().full_path

  def _make_model_and_session():
    m = tl.Serial(tl.Dense(2))
    ts = training.Loop(m, [task], eval_tasks=[eval_task],
                       eval_at=lambda step_n: step_n % 2 == 0,
                       output_dir=tmp_dir)
    return m, ts

  _, training_session = _make_model_and_session()
  self.assertEqual(0, training_session.step)
  training_session.run(n_steps=1)
  training_session.save_checkpoint('model')
  _, training_session2 = _make_model_and_session()
  training_session2.run(n_steps=1)
  base.N_WEIGHTS_SHARDS = 1
def test_loop_no_eval_task_tfnp(self):
  """Runs a training loop with no eval task(s), TFNP backend."""
  with fastmath.use_backend(fastmath.Backend.TFNP):
    model = tl.Serial(tl.Dense(1))
    task = training.TrainTask(
        _very_simple_data(), tl.L2Loss(), optimizers.Adam(.01))
    training_session = training.Loop(model, [task])
    # Loop should initialize and run successfully, even with no eval task.
    training_session.run(n_steps=5)
def test_integration_with_policy_tasks(self):
  """Integration test for policy + value training and eval."""
  optimizer = opt.Adam()
  lr_schedule = lr_schedules.constant(1e-3)
  advantage_estimator = advantages.td_k(gamma=self._task.gamma, margin=1)
  policy_dist = distributions.create_distribution(self._task.action_space)
  body = lambda mode: tl.Dense(64)
  train_model = models.PolicyAndValue(policy_dist, body=body)
  eval_model = models.PolicyAndValue(policy_dist, body=body)

  head_selector = tl.Select([1])
  value_train_task = value_tasks.ValueTrainTask(
      self._trajectory_batch_stream,
      optimizer,
      lr_schedule,
      advantage_estimator,
      model=train_model,
      target_model=eval_model,
      head_selector=head_selector,
  )
  value_eval_task = value_tasks.ValueEvalTask(
      value_train_task, head_selector=head_selector
  )

  # Drop the value head - just tl.Select([0]) would pass it, and it would
  # override the targets.
  head_selector = tl.Select([0], n_in=2)
  policy_train_task = policy_tasks.PolicyTrainTask(
      self._trajectory_batch_stream,
      optimizer,
      lr_schedule,
      policy_dist,
      advantage_estimator,
      # Plug a trained critic as our value estimate.
      value_fn=value_train_task.value,
      head_selector=head_selector,
  )
  policy_eval_task = policy_tasks.PolicyEvalTask(
      policy_train_task, head_selector=head_selector
  )

  loop = training.Loop(
      model=train_model,
      eval_model=eval_model,
      tasks=[policy_train_task, value_train_task],
      eval_tasks=[policy_eval_task, value_eval_task],
      eval_at=(lambda _: True),
      # Switch the task every step.
      which_task=(lambda step: step % 2),
  )

  # Run for a couple of steps to make sure there are a few task switches.
  loop.run(n_steps=10)
def test_run_simple_task_tfnp(self):
  """Runs an accelerated optimizer on a simple task, TFNP backend."""
  with fastmath.use_backend(fastmath.Backend.TFNP):
    inputs_batch = np.arange(8).reshape((8, 1))  # 8 items per batch
    targets_batch = np.pi * np.ones_like(inputs_batch)
    labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
    loss_layer = tl.Serial(tl.Dense(1), tl.L2Loss())
    loss_layer.init(labeled_batch)
    optimizer = optimizers.Adam(.01)
    optimizer.tree_init(loss_layer.weights)
    trainer = optimizers.Trainer(loss_layer, optimizer)
    rng = fastmath.random.get_prng(0)
    trainer.one_step(labeled_batch, rng)
def test_restores_from_smaller_model(self):
  """Training restores from a checkpoint created with a smaller model."""
  model1 = tl.Serial(tl.Dense(1))
  task = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.Adam(.01))
  tmp_dir = self.create_tempdir().full_path
  loop = training.Loop(model1, [task],
                       checkpoint_at=lambda step_n: step_n % 2 == 0,
                       output_dir=tmp_dir)
  loop.run(2)
  model2 = tl.Serial(tl.Dense(1), tl.Dense(1))
  loop2 = training.Loop(model2, [task], output_dir=tmp_dir)
  self.assertEqual(2, loop2.step)
def test_restores_memory_efficient_from_standard(self):
  """Memory-efficient training restores from a standard checkpoint."""
  model = tl.Serial(tl.Dense(4), tl.Dense(1))
  task_std = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.Adam(.0001))
  tmp_dir = self.create_tempdir().full_path
  loop = training.Loop(model, [task_std],
                       checkpoint_at=lambda step_n: step_n % 2 == 0,
                       output_dir=tmp_dir)
  loop.run(4)
  task_memeff = training.TrainTask(
      _very_simple_data(), tl.L2Loss(), optimizers.Adam)
  loop2 = training.Loop(model, [task_memeff], output_dir=tmp_dir,
                        use_memory_efficient_trainer=True)
  loop2.run(2)
  self.assertEqual(6, loop2.step)
def test_train_save_restore_dense(self):
  """Saves and restores a checkpoint to check for equivalence."""
  train_data = data.Serial(lambda _: _very_simple_data(),
                           data.CountAndSkip('simple_data'))
  task = training.TrainTask(
      train_data(), tl.L2Loss(), optimizers.Adam(.0001))
  eval_task = training.EvalTask(
      _very_simple_data(),  # deliberately re-using training data
      [tl.L2Loss()],
      metric_names=['SGD.L2Loss'])
  tmp_dir = self.create_tempdir().full_path

  def _make_model_and_session():
    m = tl.Serial(tl.Dense(1))
    ts = training.Loop(m, [task], eval_tasks=[eval_task],
                       eval_at=lambda step_n: step_n % 2 == 0,
                       output_dir=tmp_dir)
    return m, ts

  model, training_session = _make_model_and_session()
  self.assertEqual(0, training_session.step)
  training_session.run(n_steps=1)
  training_session.save_checkpoint()
  self.assertEqual(data.inputs.data_counters['simple_data'], 2)
  data.inputs.data_counters['simple_data'] = 0  # reset manually
  self.assertEqual(data.inputs.data_counters['simple_data'], 0)  # check
  model2, training_session2 = _make_model_and_session()
  self.assertEqual(data.inputs.data_counters['simple_data'], 2)  # restored

  x = np.ones((8, 1))
  y1 = model(x, rng=fastmath.random.get_prng(0))
  y2 = model2(x, rng=fastmath.random.get_prng(0))
  self.assertEqual(str(y1), str(y2))

  training_session2.run(n_steps=1)
  y1 = model(x, rng=fastmath.random.get_prng(0))
  y2 = model2(x, rng=fastmath.random.get_prng(0))
  self.assertNotEqual(str(y1), str(y2))

  slots1 = training_session._trainer_per_task[0].slots
  slots2 = training_session2._trainer_per_task[0].slots
  np.testing.assert_array_equal(slots1, slots2)
def test_value_tasks_smoke(self):
  """Smoke test for train + eval."""
  model = self._model_fn(mode='train')
  train_task = value_tasks.ValueTrainTask(
      self._trajectory_batch_stream,
      optimizer=opt.Adam(),
      lr_schedule=lr_schedules.constant(1e-3),
      advantage_estimator=advantages.td_k(gamma=self._task.gamma, margin=1),
      model=model,
  )
  eval_task = value_tasks.ValueEvalTask(train_task)
  loop = training.Loop(
      model=model,
      tasks=[train_task],
      eval_tasks=[eval_task],
      eval_at=(lambda _: True),
  )
  loop.run(n_steps=1)