def test_sanity_ppo_cartpole(self):
    """Run PPO and check whether it correctly runs for 2 epochs."""
    task = rl_task.RLTask('CartPole-v1', initial_trajectories=0, max_steps=200)
    # Single schedule shared by the value and policy optimizers.
    lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
        h, constant=1e-3, warmup_steps=100, factors='constant * linear_warmup')
    # Both heads share the same small dense body.
    body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
    policy_model = functools.partial(models.Policy, body=body)
    value_model = functools.partial(models.Value, body=body)
    trainer = actor_critic.PPOTrainer(
        task,
        n_shared_layers=1,
        value_model=value_model,
        value_optimizer=opt.Adam,
        value_lr_schedule=lr,
        value_batch_size=128,
        value_train_steps_per_epoch=10,
        policy_model=policy_model,
        policy_optimizer=opt.Adam,
        policy_lr_schedule=lr,
        policy_batch_size=128,
        policy_train_steps_per_epoch=10,
        n_trajectories_per_epoch=10)
    trainer.run(2)
    # Run for 2 epochs because the first epoch takes a different code path
    # than subsequent ones.
    self.assertEqual(2, trainer.current_epoch)
def test_boxing(self):
    """Test-runs PPO on Boxing."""
    env = environments.load_from_settings(
        platform='atari',
        settings={
            'levelName': 'boxing',
            'interleaved_pixels': True,
            'zero_indexed_actions': True
        })
    # Stack a single frame so the CNN body sees one observation at a time.
    env = atari_wrapper.AtariWrapper(environment=env, num_stacked_frames=1)
    task = rl_task.RLTask(
        env, initial_trajectories=20, dm_suite=True, max_steps=200)
    # Both heads share the same Atari CNN body.
    body = lambda mode: atari_cnn.AtariCnnBody()
    policy_model = functools.partial(models.Policy, body=body)
    value_model = functools.partial(models.Value, body=body)
    # pylint: disable=g-long-lambda
    lr_value = lambda h: lr_schedules.MultifactorSchedule(
        h, constant=3e-3, warmup_steps=100, factors='constant * linear_warmup')
    lr_policy = lambda h: lr_schedules.MultifactorSchedule(
        h, constant=1e-3, warmup_steps=100, factors='constant * linear_warmup')
    # pylint: enable=g-long-lambda
    # NOTE(review): renamed `collect_per_epoch` -> `n_trajectories_per_epoch`
    # to match the kwarg used by the CartPole PPO test's PPOTrainer call;
    # confirm PPOTrainer's constructor accepts this name.
    trainer = actor_critic.PPOTrainer(
        task,
        n_shared_layers=0,
        value_model=value_model,
        value_optimizer=opt.Adam,
        value_lr_schedule=lr_value,
        value_batch_size=1,
        value_train_steps_per_epoch=1,
        policy_model=policy_model,
        policy_optimizer=opt.Adam,
        policy_lr_schedule=lr_policy,
        policy_batch_size=1,
        policy_train_steps_per_epoch=1,
        n_trajectories_per_epoch=10)
    trainer.run(2)
    # Make sure that we test everywhere at least for 2 epochs, because
    # the first epoch is different.
    self.assertEqual(2, trainer.current_epoch)