Пример #1
0
    def test_ppo_compilation_and_schedule_mixins(self):
        """Test whether a PPOTrainer can be built with all frameworks.

        For each framework yielded by `framework_iterator` (with eager
        tracing), each of two environments, and both LSTM settings, a
        trainer is built from one shared PPOConfig. The test then checks
        that the policy's initial entropy coefficient is 0.1 (the first
        value of the entropy schedule, not the deliberately bogus 100.0)
        and that its current learning rate equals `config.lr`, runs
        `num_iterations` training iterations, and finally exercises
        single-action computation.
        """

        # Build a PPOConfig object.
        config = (
            ppo.PPOConfig()
            .training(
                num_sgd_iter=2,
                # Setup lr schedule for testing.
                lr_schedule=[[0, 5e-5], [128, 0.0]],
                # Set entropy_coeff to a faulty value to prove that it'll get
                # overridden by the schedule below (which is expected).
                entropy_coeff=100.0,
                entropy_coeff_schedule=[[0, 0.1], [256, 0.0]],
            )
            .rollouts(
                num_rollout_workers=1,
                # Test with compression.
                compress_observations=True,
            )
            .training(
                train_batch_size=128,
                model=dict(
                    # Settings in case we use an LSTM.
                    lstm_cell_size=10,
                    max_seq_len=20,
                ),
            )
            # For checking lr-schedule correctness.
            .callbacks(MyCallbacks)
        )

        num_iterations = 2

        for fw in framework_iterator(config, with_eager_tracing=True):
            for env in ["FrozenLake-v1", "MsPacmanNoFrameskip-v4"]:
                print("Env={}".format(env))
                for lstm in [True, False]:
                    print("LSTM={}".format(lstm))
                    # The shared config object is updated in place for each
                    # LSTM setting before a new trainer is built from it.
                    config.training(
                        model=dict(
                            use_lstm=lstm,
                            lstm_use_prev_action=lstm,
                            lstm_use_prev_reward=lstm,
                        ))

                    trainer = config.build(env=env)
                    # Always stop the trainer, even if a check below fails,
                    # so a failing combination does not leak its workers
                    # into the remaining ones.
                    try:
                        policy = trainer.get_policy()
                        entropy_coeff = policy.entropy_coeff
                        lr = policy.cur_lr
                        if fw == "tf":
                            # In static-graph mode these are graph objects
                            # that must be evaluated through the session.
                            entropy_coeff, lr = policy.get_session().run(
                                [entropy_coeff, lr])
                        check(entropy_coeff, 0.1)
                        check(lr, config.lr)

                        for _ in range(num_iterations):
                            results = trainer.train()
                            check_train_results(results)
                            print(results)

                        check_compute_single_action(
                            trainer,
                            include_prev_action_reward=True,
                            include_state=lstm)
                    finally:
                        trainer.stop()
Пример #2
0
    def test_ppo_free_log_std(self):
        """Tests the free log std option works.

        For every framework, builds a PPOTrainer on CartPole-v0 with
        `free_log_std=True` and asserts that exactly one model variable
        matches "log_std" (by parameter name for torch, by string
        representation otherwise). The first entry of that variable must be
        0.0 initially and must differ from 0.0 after one `learn_on_batch`
        call on the postprocessed FAKE_BATCH.
        """
        config = (
            ppo.PPOConfig()
            .rollouts(num_rollout_workers=0)
            .training(
                gamma=0.99,
                model=dict(
                    fcnet_hiddens=[10],
                    fcnet_activation="linear",
                    free_log_std=True,
                    vf_share_layers=True,
                ),
            )
        )

        # The session yielded by the iterator is not needed here; the
        # policy's own session is used for static-graph tf.
        for fw, _sess in framework_iterator(config, session=True):
            trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
            # Always stop the trainer, even if an assertion below fails.
            try:
                policy = trainer.get_policy()

                # Check the free log std var is created.
                if fw == "torch":
                    matching = [
                        param
                        for name, param in policy.model.named_parameters()
                        if "log_std" in name
                    ]
                else:
                    matching = [
                        var for var in policy.model.trainable_variables()
                        if "log_std" in str(var)
                    ]
                assert len(matching) == 1, matching
                log_std_var = matching[0]

                def get_value():
                    """Return the first entry of the log-std variable."""
                    if fw == "tf":
                        return policy.get_session().run(log_std_var)[0]
                    if fw == "torch":
                        return log_std_var.detach().cpu().numpy()[0]
                    return log_std_var.numpy()[0]

                # Check the variable is initially zero.
                init_std = get_value()
                assert init_std == 0.0, init_std

                batch = compute_gae_for_sample_batch(policy,
                                                     FAKE_BATCH.copy())
                if fw == "torch":
                    batch = policy._lazy_tensor_dict(batch)
                policy.learn_on_batch(batch)

                # Check the variable is updated.
                post_std = get_value()
                assert post_std != 0.0, post_std
            finally:
                trainer.stop()
Пример #3
0
    def test_ppo_exploration_setup(self):
        """Tests whether PPO runs with different exploration setups.

        For every framework, builds a PPOTrainer on FrozenLake-v1 and checks
        that:
          * with `explore=False`, the action for a fixed observation is the
            same on every call (and, outside static-graph tf, equals the
            argmax over the model's last output);
          * with the default exploration, the mean of 300 sampled actions is
            within 0.2 of 1.5.
        """
        config = (
            ppo.PPOConfig()
            .environment(env_config={
                "is_slippery": False,
                "map_name": "4x4",
            })
            # Run locally.
            .rollouts(num_rollout_workers=0)
        )
        obs = np.array(0)
        # Previous action/reward passed identically to every
        # `compute_single_action` call below.
        prev_kwargs = dict(
            prev_action=np.array(2),
            prev_reward=np.array(1.0),
        )

        # Test against all frameworks.
        for fw in framework_iterator(config):
            # Default Agent should be setup with StochasticSampling.
            trainer = ppo.PPOTrainer(config=config, env="FrozenLake-v1")
            # Always stop the trainer, even if a check below fails.
            try:
                # explore=False, always expect the same (deterministic)
                # action.
                a_ = trainer.compute_single_action(
                    obs, explore=False, **prev_kwargs)
                # Test whether this is really the argmax action over the
                # logits.
                if fw != "tf":
                    last_out = trainer.get_policy().model.last_output()
                    if fw == "torch":
                        last_out = last_out.detach().cpu()
                    check(a_, np.argmax(last_out.numpy(), 1)[0])
                for _ in range(50):
                    a = trainer.compute_single_action(
                        obs, explore=False, **prev_kwargs)
                    check(a, a_)

                # With explore=True (default), expect stochastic actions.
                actions = [
                    trainer.compute_single_action(obs, **prev_kwargs)
                    for _ in range(300)
                ]
                check(np.mean(actions), 1.5, atol=0.2)
            finally:
                trainer.stop()
Пример #4
0
    def test_ppo_loss_function(self):
        """Tests the PPO loss function math.

        For every framework, builds a PPOTrainer on CartPole-v0 with a single
        linear hidden layer shared between policy and value function,
        postprocesses FAKE_BATCH, recomputes the network outputs by hand from
        the model's weights via `fc`, and compares the loss terms returned
        by `self._ppo_loss_helper` against the ones the policy itself
        reports.
        """
        config = (
            ppo.PPOConfig()
            .rollouts(num_rollout_workers=0)
            .training(
                gamma=0.99,
                model=dict(
                    fcnet_hiddens=[10],
                    fcnet_activation="linear",
                    vf_share_layers=True,
                ),
            )
        )

        for fw, sess in framework_iterator(config, session=True):
            trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
            # Always stop the trainer, even if a check below fails.
            try:
                policy = trainer.get_policy()
                is_torch = fw == "torch"

                # Check no free log std var by default.
                if is_torch:
                    matching = [
                        param
                        for name, param in policy.model.named_parameters()
                        if "log_std" in name
                    ]
                else:
                    matching = [
                        var for var in policy.model.trainable_variables()
                        if "log_std" in str(var)
                    ]
                assert len(matching) == 0, matching

                # Post-process and attach the results to the train_batch
                # dict. Expected discounted values with gamma=0.99:
                # [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
                # [0.50005, -0.505, 0.5]
                # NOTE(review): presumably the rewards in FAKE_BATCH are
                # [1.0, -1.0, 0.5]; FAKE_BATCH is defined outside this block.
                train_batch = compute_gae_for_sample_batch(
                    policy, FAKE_BATCH.copy())
                if is_torch:
                    train_batch = policy._lazy_tensor_dict(train_batch)

                # Check the value targets produced by postprocessing.
                check(train_batch[Postprocessing.VALUE_TARGETS],
                      [0.50005, -0.505, 0.5])

                # Calculate actual PPO loss. Nothing is called here for
                # static-graph tf; its loss tensors are fetched through the
                # session further below.
                if fw in ["tf2", "tfe"]:
                    ppo_surrogate_loss_tf(policy, policy.model, Categorical,
                                          train_batch)
                elif is_torch:
                    PPOTorchPolicy.loss(policy, policy.model,
                                        policy.dist_class, train_batch)

                # `weights` (not `vars`) to avoid shadowing the builtin.
                weights = (list(policy.model.parameters())
                           if is_torch else policy.model.variables())
                if fw == "tf":
                    weights = policy.get_session().run(weights)

                # The shared-layer and logits-layer weights sit at swapped
                # positions in the torch parameter list vs. the tf variable
                # list; the value-layer weights are at 4/5 in both.
                if is_torch:
                    shared_w, shared_b, logits_w, logits_b = 2, 3, 0, 1
                else:
                    shared_w, shared_b, logits_w, logits_b = 0, 1, 2, 3

                expected_shared_out = fc(
                    train_batch[SampleBatch.CUR_OBS],
                    weights[shared_w],
                    weights[shared_b],
                    framework=fw,
                )
                expected_logits = fc(
                    expected_shared_out,
                    weights[logits_w],
                    weights[logits_b],
                    framework=fw,
                )
                expected_value_outs = fc(
                    expected_shared_out,
                    weights[4],
                    weights[5],
                    framework=fw,
                )

                (kl, entropy, pg_loss, vf_loss,
                 overall_loss) = self._ppo_loss_helper(
                     policy,
                     policy.model,
                     TorchCategorical if is_torch else Categorical,
                     train_batch,
                     expected_logits,
                     expected_value_outs,
                     sess=sess,
                 )

                if sess:
                    # Session available: evaluate the policy's loss tensors
                    # on the (unshuffled) train batch.
                    policy_sess = policy.get_session()
                    k, e, pl, v, tl = policy_sess.run(
                        [
                            policy._mean_kl_loss,
                            policy._mean_entropy,
                            policy._mean_policy_loss,
                            policy._mean_vf_loss,
                            policy._total_loss,
                        ],
                        feed_dict=policy._get_loss_inputs_dict(
                            train_batch, shuffle=False),
                    )
                    check(k, kl)
                    check(e, entropy)
                    check(pl, np.mean(-pg_loss))
                    check(v, np.mean(vf_loss), decimals=4)
                    check(tl, overall_loss, decimals=4)
                elif is_torch:
                    # Torch keeps the loss terms in the model's tower_stats.
                    stats = policy.model.tower_stats
                    check(stats["mean_kl_loss"], kl)
                    check(stats["mean_entropy"], entropy)
                    check(stats["mean_policy_loss"], np.mean(-pg_loss))
                    check(stats["mean_vf_loss"], np.mean(vf_loss),
                          decimals=4)
                    check(stats["total_loss"], overall_loss, decimals=4)
                else:
                    # Remaining frameworks: loss terms are policy attributes.
                    check(policy._mean_kl_loss, kl)
                    check(policy._mean_entropy, entropy)
                    check(policy._mean_policy_loss, np.mean(-pg_loss))
                    check(policy._mean_vf_loss, np.mean(vf_loss),
                          decimals=4)
                    check(policy._total_loss, overall_loss, decimals=4)
            finally:
                trainer.stop()