Example #1
    def test_sanity_ppo_cartpole(self):
        """Run PPO and check whether it correctly runs for 2 epochs.s."""
        task = rl_task.RLTask('CartPole-v1',
                              initial_trajectories=0,
                              max_steps=200)

        lr = lambda h: lr_schedules.MultifactorSchedule(  # pylint: disable=g-long-lambda
            h,
            constant=1e-3,
            warmup_steps=100,
            factors='constant * linear_warmup')

        body = lambda mode: tl.Serial(tl.Dense(64), tl.Relu())
        policy_model = functools.partial(models.Policy, body=body)
        value_model = functools.partial(models.Value, body=body)
        trainer = actor_critic.PPOTrainer(task,
                                          n_shared_layers=1,
                                          value_model=value_model,
                                          value_optimizer=opt.Adam,
                                          value_lr_schedule=lr,
                                          value_batch_size=128,
                                          value_train_steps_per_epoch=10,
                                          policy_model=policy_model,
                                          policy_optimizer=opt.Adam,
                                          policy_lr_schedule=lr,
                                          policy_batch_size=128,
                                          policy_train_steps_per_epoch=10,
                                          n_trajectories_per_epoch=10)

        trainer.run(2)
        self.assertEqual(2, trainer.current_epoch)
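
This test method comes from Trax's actor-critic test suite; to run it on its own it needs the surrounding imports and a test-case class. A minimal sketch, assuming standard Trax module paths (these have moved between Trax versions, so treat the exact paths as assumptions):

import functools

from absl.testing import absltest

from trax import layers as tl
from trax import lr_schedules  # in some versions: from trax.supervised import lr_schedules
from trax import models
from trax import optimizers as opt
from trax.rl import actor_critic
from trax.rl import task as rl_task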
Example #2
    def test_boxing(self):
        """Test-runs PPO on Boxing."""
        env = environments.load_from_settings(platform='atari',
                                              settings={
                                                  'levelName': 'boxing',
                                                  'interleaved_pixels': True,
                                                  'zero_indexed_actions': True
                                              })
        env = atari_wrapper.AtariWrapper(environment=env, num_stacked_frames=1)

        task = rl_task.RLTask(env,
                              initial_trajectories=20,
                              dm_suite=True,
                              max_steps=200)

        body = lambda mode: atari_cnn.AtariCnnBody()

        policy_model = functools.partial(models.Policy, body=body)
        value_model = functools.partial(models.Value, body=body)

        # pylint: disable=g-long-lambda
        lr_value = lambda h: lr_schedules.MultifactorSchedule(
            h,
            constant=3e-3,
            warmup_steps=100,
            factors='constant * linear_warmup')
        lr_policy = lambda h: lr_schedules.MultifactorSchedule(
            h,
            constant=1e-3,
            warmup_steps=100,
            factors='constant * linear_warmup')
        # pylint: enable=g-long-lambda

        trainer = actor_critic.PPOTrainer(task,
                                          n_shared_layers=0,
                                          value_model=value_model,
                                          value_optimizer=opt.Adam,
                                          value_lr_schedule=lr_value,
                                          value_batch_size=1,
                                          value_train_steps_per_epoch=1,
                                          policy_model=policy_model,
                                          policy_optimizer=opt.Adam,
                                          policy_lr_schedule=lr_policy,
                                          policy_batch_size=1,
                                          policy_train_steps_per_epoch=1,
                                          collect_per_epoch=10)
        trainer.run(2)
        # Make sure that we test everywhere for at least 2 epochs, because
        # the first epoch is different from the rest.
        self.assertEqual(2, trainer.current_epoch)
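
Example #2 additionally uses an Atari CNN body from Trax and a DeepMind-style environment loader and wrapper (environments.load_from_settings, atari_wrapper.AtariWrapper) that are referenced but not defined here and are not part of open-source Trax. A minimal sketch of the extra imports and the test-runner boilerplate; the class name below is an illustrative assumption, not from the source:

from trax.models import atari_cnn

# 'environments' and 'atari_wrapper' must be supplied by your own
# dm_env-compatible Atari tooling; substitute them as available.


class PPOTrainerTest(absltest.TestCase):
    """Wraps the test methods shown in Examples #1 and #2."""
    # def test_sanity_ppo_cartpole(self): ...
    # def test_boxing(self): ...


if __name__ == '__main__':
    absltest.main()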