# Example 1
    def test_value_error_high_without_syncs(self):
        """Without target syncs during training, the value error stays high."""
        model = self._model_fn(mode='train')
        train_task = value_tasks.ValueTrainTask(
            self._trajectory_batch_stream,
            optimizer=opt.Adam(),
            lr_schedule=lr_schedules.constant(1e-3),
            advantage_estimator=advantages.td_k(
                gamma=self._task.gamma, margin=1
            ),
            model=model,
            # Synchronize the target network only once, at the final step.
            sync_at=(lambda step: step == 100),
        )
        loop = training.Loop(model=model, tasks=[train_task])

        # Before any training, the error should be large.
        self.assertGreater(self._value_error(train_task.value), 2.0)

        loop.run(n_steps=100)

        # Training shrinks the error, but without intermediate syncs it
        # should remain above 0.8.
        final_error = self._value_error(train_task.value)
        self.assertLess(final_error, 2.0)
        self.assertGreater(final_error, 0.8)
# Example 2
    def test_value_error_low_with_syncs(self):
        """With frequent target syncs, some trial reaches a low value error."""
        min_error = np.inf
        # The training is stochastic, so allow up to 5 attempts.
        for _ in range(5):
            model = self._model_fn(mode='train')
            train_task = value_tasks.ValueTrainTask(
                self._trajectory_batch_stream,
                optimizer=opt.Adam(),
                lr_schedule=lr_schedules.constant(1e-3),
                advantage_estimator=advantages.td_k(
                    gamma=self._task.gamma, margin=1
                ),
                model=model,
                # Synchronize the target network every 10 steps.
                sync_at=(lambda step: step % 10 == 0),
            )
            loop = training.Loop(model=model, tasks=[train_task])

            # Before any training, the error should be large.
            self.assertGreater(self._value_error(train_task.value), 2.0)

            loop.run(n_steps=100)

            error_after = self._value_error(train_task.value)
            if error_after < 0.8:
                # Success: this trial reached a low error.
                return
            min_error = min(min_error, error_after)

        self.fail(
            f'Even after 5 trials, min error_after({min_error}) is not < 0.8')
# Example 3
    def test_train_save_restore_sharded(self):
        """Saves and restores a sharded checkpoint to check for equivalence."""
        if fastmath.local_device_count() < 2:
            return  # multi-accelerator only
        # Shard the weights across all local devices for this test. The
        # original code reset this global only on the success path, so any
        # failure leaked sharded state into later tests; try/finally fixes it.
        base.N_WEIGHTS_SHARDS = fastmath.local_device_count()
        try:
            train_data = data.Serial(lambda _: _very_simple_data(2, 2),
                                     data.CountAndSkip('simple_data'))
            task = training.TrainTask(train_data(), tl.L2Loss(),
                                      optimizers.Adam(.0001))
            eval_task = training.EvalTask(
                _very_simple_data(2, 2),  # deliberately re-using training data
                [tl.L2Loss()],
                metric_names=['SGD.L2Loss'])
            tmp_dir = self.create_tempdir().full_path

            def _make_model_and_session():
                # Fresh model + Loop sharing output_dir, so the second session
                # restores the checkpoint written by the first.
                m = tl.Serial(tl.Dense(2))
                ts = training.Loop(m, [task],
                                   eval_tasks=[eval_task],
                                   eval_at=lambda step_n: step_n % 2 == 0,
                                   output_dir=tmp_dir)
                return m, ts

            _, training_session = _make_model_and_session()
            self.assertEqual(0, training_session.step)
            training_session.run(n_steps=1)
            training_session.save_checkpoint('model')
            _, training_session2 = _make_model_and_session()
            training_session2.run(n_steps=1)
        finally:
            # Always restore the global, even if an assertion above fails.
            base.N_WEIGHTS_SHARDS = 1
# Example 4
 def test_loop_no_eval_task_tfnp(self):
     """Training loop with no eval task(s) works on the TFNP backend."""
     with fastmath.use_backend(fastmath.Backend.TFNP):
         dense_model = tl.Serial(tl.Dense(1))
         train_task = training.TrainTask(
             _very_simple_data(), tl.L2Loss(), optimizers.Adam(.01))
         # Loop should initialize and run successfully, even with no eval task.
         loop = training.Loop(dense_model, [train_task])
         loop.run(n_steps=5)
# Example 5
  def test_integration_with_policy_tasks(self):
    """Integration test for policy + value training and eval."""
    optimizer = opt.Adam()
    lr_schedule = lr_schedules.constant(1e-3)
    advantage_estimator = advantages.td_k(gamma=self._task.gamma, margin=1)
    policy_dist = distributions.create_distribution(self._task.action_space)
    body = lambda mode: tl.Dense(64)
    train_model = models.PolicyAndValue(policy_dist, body=body)
    eval_model = models.PolicyAndValue(policy_dist, body=body)

    # The value tasks look only at the value head of the two-headed model.
    value_head_selector = tl.Select([1])
    value_train_task = value_tasks.ValueTrainTask(
        self._trajectory_batch_stream,
        optimizer,
        lr_schedule,
        advantage_estimator,
        model=train_model,
        target_model=eval_model,
        head_selector=value_head_selector,
    )
    value_eval_task = value_tasks.ValueEvalTask(
        value_train_task, head_selector=value_head_selector
    )

    # Drop the value head - just tl.Select([0]) would pass it, and it would
    # override the targets.
    policy_head_selector = tl.Select([0], n_in=2)
    policy_train_task = policy_tasks.PolicyTrainTask(
        self._trajectory_batch_stream,
        optimizer,
        lr_schedule,
        policy_dist,
        advantage_estimator,
        # Plug a trained critic as our value estimate.
        value_fn=value_train_task.value,
        head_selector=policy_head_selector,
    )
    policy_eval_task = policy_tasks.PolicyEvalTask(
        policy_train_task, head_selector=policy_head_selector
    )

    loop = training.Loop(
        model=train_model,
        eval_model=eval_model,
        tasks=[policy_train_task, value_train_task],
        eval_tasks=[policy_eval_task, value_eval_task],
        eval_at=(lambda _: True),
        # Alternate between the two tasks on every step.
        which_task=(lambda step: step % 2),
    )
    # Run for a couple of steps to make sure there are a few task switches.
    loop.run(n_steps=10)
# Example 6
 def test_run_simple_task_tfnp(self):
   """Runs an accelerated optimizer on a simple task, TFNP backend."""
   with fastmath.use_backend(fastmath.Backend.TFNP):
     batch_size = 8
     x = np.arange(batch_size).reshape((batch_size, 1))
     y = np.pi * np.ones_like(x)
     # (inputs, targets, weights) triple with unit example weights.
     labeled_batch = (x, y, np.ones_like(y))
     loss_layer = tl.Serial(tl.Dense(1), tl.L2Loss())
     loss_layer.init(labeled_batch)
     adam = optimizers.Adam(.01)
     adam.tree_init(loss_layer.weights)
     trainer = optimizers.Trainer(loss_layer, adam)
     trainer.one_step(labeled_batch, fastmath.random.get_prng(0))
# Example 7
 def test_restores_from_smaller_model(self):
     """Training restores from a checkpoint created with smaller model."""
     small_model = tl.Serial(tl.Dense(1))
     task = training.TrainTask(_very_simple_data(), tl.L2Loss(),
                               optimizers.Adam(.01))
     tmp_dir = self.create_tempdir().full_path
     # Train the small model for 2 steps, checkpointing on even steps.
     loop = training.Loop(small_model, [task],
                          checkpoint_at=lambda step_n: step_n % 2 == 0,
                          output_dir=tmp_dir)
     loop.run(2)
     # A bigger model reading the same directory should pick up the step.
     bigger_model = tl.Serial(tl.Dense(1), tl.Dense(1))
     restored_loop = training.Loop(bigger_model, [task], output_dir=tmp_dir)
     self.assertEqual(2, restored_loop.step)
# Example 8
 def test_restores_memory_efficient_from_standard(self):
     """Training restores step from directory where it saved it."""
     model = tl.Serial(tl.Dense(4), tl.Dense(1))
     standard_task = training.TrainTask(_very_simple_data(), tl.L2Loss(),
                                        optimizers.Adam(.0001))
     tmp_dir = self.create_tempdir().full_path
     # Run 4 steps with the standard trainer, checkpointing on even steps.
     standard_loop = training.Loop(model, [standard_task],
                                   checkpoint_at=lambda step_n: step_n % 2 == 0,
                                   output_dir=tmp_dir)
     standard_loop.run(4)
     # Note: this task passes the optimizer class itself, not an instance,
     # unlike the standard task above.
     memeff_task = training.TrainTask(_very_simple_data(), tl.L2Loss(),
                                      optimizers.Adam)
     memeff_loop = training.Loop(model, [memeff_task],
                                 output_dir=tmp_dir,
                                 use_memory_efficient_trainer=True)
     memeff_loop.run(2)
     # 4 standard steps + 2 memory-efficient steps after restore.
     self.assertEqual(6, memeff_loop.step)
# Example 9
    def test_train_save_restore_dense(self):
        """Saves and restores a checkpoint to check for equivalence."""
        # Wrap the data stream in a counter so we can verify below that the
        # input-pipeline position is saved and restored with the checkpoint.
        train_data = data.Serial(lambda _: _very_simple_data(),
                                 data.CountAndSkip('simple_data'))
        task = training.TrainTask(train_data(), tl.L2Loss(),
                                  optimizers.Adam(.0001))
        eval_task = training.EvalTask(
            _very_simple_data(),  # deliberately re-using training data
            [tl.L2Loss()],
            metric_names=['SGD.L2Loss'])
        tmp_dir = self.create_tempdir().full_path

        def _make_model_and_session():
            # Fresh model + Loop sharing output_dir, so a later session
            # restores the checkpoint written by an earlier one.
            m = tl.Serial(tl.Dense(1))
            ts = training.Loop(m, [task],
                               eval_tasks=[eval_task],
                               eval_at=lambda step_n: step_n % 2 == 0,
                               output_dir=tmp_dir)
            return m, ts

        model, training_session = _make_model_and_session()
        self.assertEqual(0, training_session.step)
        training_session.run(n_steps=1)
        training_session.save_checkpoint()
        # NOTE(review): the counter reaches 2 after a single step — presumably
        # one training draw plus one additional draw by the Loop; confirm
        # against Loop internals if this ever changes.
        self.assertEqual(data.inputs.data_counters['simple_data'], 2)
        data.inputs.data_counters['simple_data'] = 0  # reset manually
        self.assertEqual(data.inputs.data_counters['simple_data'], 0)  # check
        model2, training_session2 = _make_model_and_session()
        self.assertEqual(data.inputs.data_counters['simple_data'],
                         2)  # restored

        # With restored weights and the same rng, outputs must match exactly.
        x = np.ones((8, 1))
        y1 = model(x, rng=fastmath.random.get_prng(0))
        y2 = model2(x, rng=fastmath.random.get_prng(0))
        self.assertEqual(str(y1), str(y2))

        # Training only the restored session makes the two models diverge.
        training_session2.run(n_steps=1)
        y1 = model(x, rng=fastmath.random.get_prng(0))
        y2 = model2(x, rng=fastmath.random.get_prng(0))
        self.assertNotEqual(str(y1), str(y2))

        # Compare optimizer slot values (e.g. Adam moments) across sessions.
        slots1 = training_session._trainer_per_task[0].slots
        slots2 = training_session2._trainer_per_task[0].slots
        np.testing.assert_array_equal(slots1, slots2)
# Example 10
 def test_value_tasks_smoke(self):
     """Smoke test: one step of value training plus an eval task."""
     model = self._model_fn(mode='train')
     train_task = value_tasks.ValueTrainTask(
         self._trajectory_batch_stream,
         optimizer=opt.Adam(),
         lr_schedule=lr_schedules.constant(1e-3),
         advantage_estimator=advantages.td_k(
             gamma=self._task.gamma, margin=1
         ),
         model=model,
     )
     eval_task = value_tasks.ValueEvalTask(train_task)
     # Evaluate on every step; a single step suffices for a smoke test.
     loop = training.Loop(
         model=model,
         tasks=[train_task],
         eval_tasks=[eval_task],
         eval_at=(lambda _: True),
     )
     loop.run(n_steps=1)