Example No. 1
 def test_policytrainer_cartpole(self):
     """Trains a policy on cartpole."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=1,
                           max_steps=200)
     model = functools.partial(
         models.Policy,
         body=lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
             tl.Dense(64), tl.Relu(), tl.Dense(64), tl.Relu()),
     )
     lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
         h,
         constant=1e-2,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = training.PolicyGradientTrainer(
         task,
         policy_model=model,
         policy_optimizer=opt.Adam,
         policy_lr_schedule=lr,
         policy_batch_size=128,
         policy_train_steps_per_epoch=1,
         collect_per_epoch=2)
     # Assert that the average return reaches 200 at some point, then exit so
     # the test runs as fast as possible.
     for ep in range(200):
         trainer.run(1)
         self.assertEqual(trainer.current_epoch, ep + 1)
         if trainer.avg_returns[-1] == 200.0:
             return
     self.fail('The expected score of 200 has not been reached. '
               'Maximum was {}.'.format(max(trainer.avg_returns)))
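Each example on this page is a test method lifted out of its module, so it needs some surrounding scaffolding to run. Below is a minimal sketch of that scaffolding; the import paths and the test-class name are assumptions based on the trax package layout (they differ between trax versions), not something shown in the examples themselves.

import functools
import os
import pickle

from absl.testing import absltest
import tensorflow as tf

from trax import layers as tl
from trax import models
from trax import optimizers as opt
from trax.rl import task as rl_task
from trax.rl import training
from trax.supervised import lr_schedules  # module location varies across trax versions


class TrainingTest(absltest.TestCase):
  # Paste the test methods from the examples on this page here.
  pass


if __name__ == '__main__':
  absltest.main()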
Example No. 2
 def test_policytrainer_cartpole(self):
   """Trains a policy on cartpole."""
   task = rl_task.RLTask('CartPole-v0', initial_trajectories=1,
                         max_steps=2)
   # TODO(pkozakowski): Use Distribution.n_inputs to initialize the action
   # head.
   model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
       tl.Dense(64), tl.Relu(), tl.Dense(2), tl.LogSoftmax())
   lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
       h, constant=1e-4, warmup_steps=100, factors='constant * linear_warmup')
   trainer = training.PolicyGradientTrainer(
       task,
       policy_model=model,
       policy_optimizer=opt.Adam,
       policy_lr_schedule=lr,
       policy_batch_size=2,
       policy_train_steps_per_epoch=2,
       collect_per_epoch=2)
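    # Tiny rollouts and batch sizes: this test only smoke-checks that one epoch runs.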
   trainer.run(1)
   self.assertEqual(1, trainer.current_epoch)
Example No. 3
 def test_policytrainer_cartpole(self):
     """Trains a policy on cartpole."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=750,
                           max_steps=200)
     model = lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
         tl.Dense(64), tl.Relu(), tl.Dense(2), tl.LogSoftmax())
     lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
         h,
         constant=1e-4,
         warmup_steps=100,
         factors='constant * linear_warmup')
     trainer = training.PolicyGradientTrainer(task,
                                              model,
                                              opt.Adam,
                                              lr_schedule=lr,
                                              batch_size=128,
                                              train_steps_per_epoch=700,
                                              collect_per_epoch=50)
     trainer.run(1)
     # This assertion should pass *most* of the time, which makes this test flaky.
     self.assertGreater(trainer.avg_returns[-1], 35.0)
     self.assertEqual(1, trainer.current_epoch)
Example No. 4
 def test_policytrainer_save_restore(self):
     """Check save and restore of policy trainer."""
     task = rl_task.RLTask('CartPole-v0',
                           initial_trajectories=10,
                           max_steps=200)
     model = functools.partial(
         models.Policy,
         body=lambda mode: tl.Serial(  # pylint: disable=g-long-lambda
             tl.Dense(64), tl.Relu(), tl.Dense(64), tl.Relu()),
     )
     tmp_dir = self.create_tempdir().full_path
     trainer1 = training.PolicyGradientTrainer(
         task,
         policy_model=model,
         policy_optimizer=opt.Adam,
         policy_batch_size=128,
         policy_train_steps_per_epoch=1,
         collect_per_epoch=2,
         output_dir=tmp_dir)
     trainer1.run(1)
     trainer1.run(1)
     self.assertEqual(trainer1.current_epoch, 2)
     self.assertEqual(trainer1._policy_trainer.step, 2)
     # Trainer 2 starts where trainer 1 stopped.
     trainer2 = training.PolicyGradientTrainer(
         task,
         policy_model=model,
         policy_optimizer=opt.Adam,
         policy_batch_size=128,
         policy_train_steps_per_epoch=1,
         collect_per_epoch=2,
         output_dir=tmp_dir)
     trainer2.run(1)
     self.assertEqual(trainer2.current_epoch, 3)
     self.assertEqual(trainer2._policy_trainer.step, 3)
     # Trainer 3 has 2x steps-per-epoch, but epoch 3, should raise an error.
     trainer3 = training.PolicyGradientTrainer(
         task,
         policy_model=model,
         policy_optimizer=opt.Adam,
         policy_batch_size=128,
         policy_train_steps_per_epoch=2,
         collect_per_epoch=2,
         output_dir=tmp_dir)
     self.assertRaises(ValueError, trainer3.run)
     # Manually set saved epoch to 1.
     dictionary = {'epoch': 1, 'avg_returns': [0.0]}
     with tf.io.gfile.GFile(os.path.join(tmp_dir, 'rl.pkl'), 'wb') as f:
         pickle.dump(dictionary, f)
     # Trainer 3 still should fail as steps between evals are 2, cannot do 1.
     self.assertRaises(ValueError, trainer3.run)
     # Trainer 4 does 1 step per eval, should train 1 step in epoch 2.
     trainer4 = training.PolicyGradientTrainer(
         task,
         policy_model=model,
         policy_optimizer=opt.Adam,
         policy_batch_size=128,
         policy_train_steps_per_epoch=2,
         policy_evals_per_epoch=2,
         collect_per_epoch=2,
         output_dir=tmp_dir)
     trainer4.run(1)
     self.assertEqual(trainer4.current_epoch, 2)
     self.assertEqual(trainer4._policy_trainer.step, 4)
     trainer1.close()
     trainer2.close()
     trainer3.close()
     trainer4.close()
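Outside a unit test, the same calls compose into a small standalone training loop. The sketch below reuses only constructors, keyword arguments, and attributes that appear in the examples above; the hyperparameters, the output directory, and the assumption that trainer.run(n) advances n epochs (as Example No. 1 suggests) are illustrative, not prescribed.

task = rl_task.RLTask('CartPole-v0', initial_trajectories=10, max_steps=200)
model = lambda mode: tl.Serial(  # same small policy network as Example No. 2
    tl.Dense(64), tl.Relu(), tl.Dense(2), tl.LogSoftmax())
trainer = training.PolicyGradientTrainer(
    task,
    policy_model=model,
    policy_optimizer=opt.Adam,
    policy_batch_size=128,
    policy_train_steps_per_epoch=10,
    collect_per_epoch=10,
    output_dir='/tmp/cartpole_pg')  # checkpoints are written here, as in Example No. 4
trainer.run(5)                      # collect and train for 5 epochs
print(trainer.avg_returns[-1])      # average return from the most recent epoch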