Example 1
    def test_ppo_compilation_and_schedule_mixins(self):
        """Test whether PPO can be built with all frameworks."""

        # Build a PPOConfig object.
        config = (
            ppo.PPOConfig().training(
                num_sgd_iter=2,
                # Set up the lr schedule for testing.
                lr_schedule=[[0, 5e-5], [128, 0.0]],
                # Set entropy_coeff to a faulty value to prove that it'll get
                # overridden by the entropy_coeff_schedule below (which is expected).
                entropy_coeff=100.0,
                entropy_coeff_schedule=[[0, 0.1], [256, 0.0]],
            ).rollouts(
                num_rollout_workers=1,
                # Test with compression.
                compress_observations=True,
            ).training(
                train_batch_size=128,
                model=dict(
                    # Settings in case we use an LSTM.
                    lstm_cell_size=10,
                    max_seq_len=20,
                ),
            ).callbacks(MyCallbacks))  # For checking lr-schedule correctness.

        num_iterations = 2

        for fw in framework_iterator(config, with_eager_tracing=True):
            for env in ["FrozenLake-v1", "MsPacmanNoFrameskip-v4"]:
                print("Env={}".format(env))
                for lstm in [True, False]:
                    print("LSTM={}".format(lstm))
                    config.training(model=dict(
                        use_lstm=lstm,
                        lstm_use_prev_action=lstm,
                        lstm_use_prev_reward=lstm,
                    ))

                    trainer = config.build(env=env)
                    policy = trainer.get_policy()
                    entropy_coeff = policy.entropy_coeff
                    lr = policy.cur_lr
                    if fw == "tf":
                        entropy_coeff, lr = policy.get_session().run(
                            [entropy_coeff, lr])
                    check(entropy_coeff, 0.1)
                    check(lr, config.lr)

                    for i in range(num_iterations):
                        results = trainer.train()
                        check_train_results(results)
                        print(results)

                    check_compute_single_action(
                        trainer,
                        include_prev_action_reward=True,
                        include_state=lstm)
                    trainer.stop()
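
For quick reference, the builder pattern exercised above reduces to a short build-and-train loop. The sketch below only reuses calls that appear in Example 1 (PPOConfig().training()/.rollouts(), config.build(), trainer.train(), trainer.stop()); the environment id and iteration count are placeholders.

import ray.rllib.algorithms.ppo as ppo

# Minimal sketch of the chained-builder pattern exercised in Example 1.
config = (
    ppo.PPOConfig()
    .training(
        train_batch_size=128,
        # Anneal the learning rate from 5e-5 to 0.0 over the first 128 timesteps.
        lr_schedule=[[0, 5e-5], [128, 0.0]],
    )
    .rollouts(num_rollout_workers=1)
)
trainer = config.build(env="FrozenLake-v1")
for _ in range(2):
    print(trainer.train())
trainer.stop()
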
Example 2
    def test_ppo_free_log_std(self):
        """Tests the free log std option works."""
        config = (
            ppo.PPOConfig()
            .rollouts(
                num_rollout_workers=0,
            )
            .training(
                gamma=0.99,
                model=dict(
                    fcnet_hiddens=[10],
                    fcnet_activation="linear",
                    free_log_std=True,
                    vf_share_layers=True,
                ),
            )
        )

        for fw, sess in framework_iterator(config, session=True):
            trainer = ppo.PPO(config=config, env="CartPole-v0")
            policy = trainer.get_policy()

            # Check the free log std var is created.
            if fw == "torch":
                matching = [
                    v for (n, v) in policy.model.named_parameters() if "log_std" in n
                ]
            else:
                matching = [
                    v for v in policy.model.trainable_variables() if "log_std" in str(v)
                ]
            assert len(matching) == 1, matching
            log_std_var = matching[0]

            def get_value():
                if fw == "tf":
                    return policy.get_session().run(log_std_var)[0]
                elif fw == "torch":
                    return log_std_var.detach().cpu().numpy()[0]
                else:
                    return log_std_var.numpy()[0]

            # Check the variable is initially zero.
            init_std = get_value()
            assert init_std == 0.0, init_std
            batch = compute_gae_for_sample_batch(policy, FAKE_BATCH.copy())
            if fw == "torch":
                batch = policy._lazy_tensor_dict(batch)
            policy.learn_on_batch(batch)

            # Check the variable is updated.
            post_std = get_value()
            assert post_std != 0.0, post_std
            trainer.stop()
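
The free-log-std check above can be condensed to a torch-only snippet. This is a sketch, not part of the original test; it reuses the same model dict and named_parameters() lookup shown above, and the .framework("torch") call on the config builder is an assumption.

import ray.rllib.algorithms.ppo as ppo

# Sketch: build a torch PPO policy with the shared free log-std head and read it out.
config = (
    ppo.PPOConfig()
    .framework("torch")  # assumed available on the config builder
    .rollouts(num_rollout_workers=0)
    .training(
        model=dict(
            fcnet_hiddens=[10],
            fcnet_activation="linear",
            free_log_std=True,
            vf_share_layers=True,
        ),
    )
)
trainer = config.build(env="CartPole-v0")
policy = trainer.get_policy()
log_std = [v for (n, v) in policy.model.named_parameters() if "log_std" in n][0]
print(log_std.detach().cpu().numpy())  # starts at zeros, as asserted above
trainer.stop()
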
Example 3
    def test_ppo_exploration_setup(self):
        """Tests, whether PPO runs with different exploration setups."""
        config = (
            ppo.PPOConfig()
            .environment(
                env_config={"is_slippery": False, "map_name": "4x4"},
            )
            .rollouts(
                # Run locally.
                num_rollout_workers=0,
            )
        )
        obs = np.array(0)

        # Test against all frameworks.
        for fw in framework_iterator(config):
            # The default agent should be set up with StochasticSampling.
            trainer = ppo.PPO(config=config, env="FrozenLake-v1")
            # With explore=False, always expect the same (deterministic) action.
            a_ = trainer.compute_single_action(
                obs, explore=False, prev_action=np.array(2), prev_reward=np.array(1.0)
            )
            # Test whether this is really the argmax action over the logits.
            if fw != "tf":
                last_out = trainer.get_policy().model.last_output()
                if fw == "torch":
                    check(a_, np.argmax(last_out.detach().cpu().numpy(), 1)[0])
                else:
                    check(a_, np.argmax(last_out.numpy(), 1)[0])
            for _ in range(50):
                a = trainer.compute_single_action(
                    obs,
                    explore=False,
                    prev_action=np.array(2),
                    prev_reward=np.array(1.0),
                )
                check(a, a_)

            # With explore=True (default), expect stochastic actions.
            actions = []
            for _ in range(300):
                actions.append(
                    trainer.compute_single_action(
                        obs, prev_action=np.array(2), prev_reward=np.array(1.0)
                    )
                )
            check(np.mean(actions), 1.5, atol=0.2)
            trainer.stop()
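
The explore-flag behaviour the test relies on fits in a minimal sketch (not from the original test); it uses only compute_single_action() with and without explore=False, as above.

import numpy as np
import ray.rllib.algorithms.ppo as ppo

# Sketch: deterministic (explore=False) vs. stochastic (default) single-action queries.
config = ppo.PPOConfig().rollouts(num_rollout_workers=0)
trainer = config.build(env="FrozenLake-v1")
obs = np.array(0)
greedy = trainer.compute_single_action(obs, explore=False)  # always the argmax action
sampled = trainer.compute_single_action(obs)                # explore=True is the default
print(greedy, sampled)
trainer.stop()
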
Example 4
    def run_re3(self, rl_algorithm):
        """Tests RE3 for PPO and SAC.

        Both the on-policy and off-policy setups are validated.
        """
        if rl_algorithm == "PPO":
            config = ppo.PPOConfig().to_dict()
            trainer_cls = ppo.PPO
            beta_schedule = "constant"
        elif rl_algorithm == "SAC":
            config = sac.SACConfig().to_dict()
            trainer_cls = sac.SAC
            beta_schedule = "linear_decay"

        class RE3Callbacks(RE3UpdateCallbacks, config["callbacks"]):
            pass

        config["env"] = "Pendulum-v1"
        config["callbacks"] = RE3Callbacks
        config["exploration_config"] = {
            "type": "RE3",
            "embeds_dim": 128,
            "beta_schedule": beta_schedule,
            "sub_exploration": {
                "type": "StochasticSampling",
            },
        }

        num_iterations = 30
        trainer = trainer_cls(config=config)
        learnt = False
        for i in range(num_iterations):
            result = trainer.train()
            print(result)
            if result["episode_reward_max"] > -900.0:
                print("Reached goal after {} iters!".format(i))
                learnt = True
                break
        trainer.stop()
        self.assertTrue(learnt)
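
Stripped of the PPO/SAC branching, the RE3 wiring comes down to three config entries: the callback subclass, the exploration_config dict, and the environment. The sketch below mirrors the PPO branch of the test; the import path for RE3UpdateCallbacks (ray.rllib.utils.exploration.random_encoder) is an assumption, and the iteration count is a placeholder.

import ray.rllib.algorithms.ppo as ppo
from ray.rllib.utils.exploration.random_encoder import RE3UpdateCallbacks

# Sketch: the PPO branch of run_re3(), condensed to the config entries it needs.
config = ppo.PPOConfig().to_dict()

class RE3Callbacks(RE3UpdateCallbacks, config["callbacks"]):
    pass

config["env"] = "Pendulum-v1"
config["callbacks"] = RE3Callbacks
config["exploration_config"] = {
    "type": "RE3",
    "embeds_dim": 128,
    "beta_schedule": "constant",  # on-policy (PPO) branch from the test above
    "sub_exploration": {"type": "StochasticSampling"},
}
trainer = ppo.PPO(config=config)
for _ in range(5):  # placeholder; the test trains for up to 30 iterations
    print(trainer.train()["episode_reward_max"])
trainer.stop()
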
Example 5
def _import_ppo():
    import ray.rllib.algorithms.ppo as ppo

    return ppo.PPO, ppo.PPOConfig().to_dict()
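
The loader returns the trainer class together with its default config dict. A hypothetical registry sketch (all names below are made up for illustration) shows how such a helper is typically consumed:

# Hypothetical registry sketch; ALGORITHMS and get_algorithm() are illustrative names.
ALGORITHMS = {
    "PPO": _import_ppo,  # maps an algorithm name to a loader returning (trainer_cls, default_config)
}

def get_algorithm(name):
    trainer_cls, default_config = ALGORITHMS[name]()
    return trainer_cls, default_config

trainer_cls, default_config = get_algorithm("PPO")
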
Example 6
    def test_ppo_loss_function(self):
        """Tests the PPO loss function math."""
        config = (
            ppo.PPOConfig()
            .rollouts(
                num_rollout_workers=0,
            )
            .training(
                gamma=0.99,
                model=dict(
                    fcnet_hiddens=[10],
                    fcnet_activation="linear",
                    vf_share_layers=True,
                ),
            )
        )

        for fw, sess in framework_iterator(config, session=True):
            trainer = ppo.PPO(config=config, env="CartPole-v0")
            policy = trainer.get_policy()

            # Check no free log std var by default.
            if fw == "torch":
                matching = [
                    v for (n, v) in policy.model.named_parameters()
                    if "log_std" in n
                ]
            else:
                matching = [
                    v for v in policy.model.trainable_variables()
                    if "log_std" in str(v)
                ]
            assert len(matching) == 0, matching

            # Post-process (calculate simple (non-GAE) advantages) and attach
            # to train_batch dict.
            # A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
            # [0.50005, -0.505, 0.5]
            train_batch = compute_gae_for_sample_batch(policy,
                                                       FAKE_BATCH.copy())
            if fw == "torch":
                train_batch = policy._lazy_tensor_dict(train_batch)

            # Check Advantage values.
            check(train_batch[Postprocessing.VALUE_TARGETS],
                  [0.50005, -0.505, 0.5])

            # Calculate actual PPO loss.
            if fw in ["tf2", "tfe"]:
                PPOTF2Policy.loss(policy, policy.model, Categorical,
                                  train_batch)
            elif fw == "torch":
                PPOTorchPolicy.loss(policy, policy.model, policy.dist_class,
                                    train_batch)

            vars = (policy.model.variables()
                    if fw != "torch" else list(policy.model.parameters()))
            if fw == "tf":
                vars = policy.get_session().run(vars)
            expected_shared_out = fc(
                train_batch[SampleBatch.CUR_OBS],
                vars[0 if fw != "torch" else 2],
                vars[1 if fw != "torch" else 3],
                framework=fw,
            )
            expected_logits = fc(
                expected_shared_out,
                vars[2 if fw != "torch" else 0],
                vars[3 if fw != "torch" else 1],
                framework=fw,
            )
            expected_value_outs = fc(expected_shared_out,
                                     vars[4],
                                     vars[5],
                                     framework=fw)

            kl, entropy, pg_loss, vf_loss, overall_loss = self._ppo_loss_helper(
                policy,
                policy.model,
                Categorical if fw != "torch" else TorchCategorical,
                train_batch,
                expected_logits,
                expected_value_outs,
                sess=sess,
            )
            if sess:
                policy_sess = policy.get_session()
                k, e, pl, v, tl = policy_sess.run(
                    [
                        policy._mean_kl_loss,
                        policy._mean_entropy,
                        policy._mean_policy_loss,
                        policy._mean_vf_loss,
                        policy._total_loss,
                    ],
                    feed_dict=policy._get_loss_inputs_dict(train_batch,
                                                           shuffle=False),
                )
                check(k, kl)
                check(e, entropy)
                check(pl, np.mean(-pg_loss))
                check(v, np.mean(vf_loss), decimals=4)
                check(tl, overall_loss, decimals=4)
            elif fw == "torch":
                check(policy.model.tower_stats["mean_kl_loss"], kl)
                check(policy.model.tower_stats["mean_entropy"], entropy)
                check(policy.model.tower_stats["mean_policy_loss"],
                      np.mean(-pg_loss))
                check(
                    policy.model.tower_stats["mean_vf_loss"],
                    np.mean(vf_loss),
                    decimals=4,
                )
                check(policy.model.tower_stats["total_loss"],
                      overall_loss,
                      decimals=4)
            else:
                check(policy._mean_kl_loss, kl)
                check(policy._mean_entropy, entropy)
                check(policy._mean_policy_loss, np.mean(-pg_loss))
                check(policy._mean_vf_loss, np.mean(vf_loss), decimals=4)
                check(policy._total_loss, overall_loss, decimals=4)
            trainer.stop()
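
As a sanity check on the advantage comment in Example 6, the expected value targets are plain discounted returns G_t = r_t + gamma * G_{t+1}. The rewards [1.0, -1.0, 0.5] below are inferred from that comment rather than read from FAKE_BATCH:

# Recompute the non-GAE value targets quoted in Example 6's comment:
# G_t = r_t + gamma * G_{t+1}, with gamma = 0.99.
gamma = 0.99
rewards = [1.0, -1.0, 0.5]  # assumed; inferred from the expected targets, not read from FAKE_BATCH
returns, g = [], 0.0
for r in reversed(rewards):
    g = r + gamma * g
    returns.append(g)
print(list(reversed(returns)))  # ~ [0.50005, -0.505, 0.5]
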