Example #1
0
# TF variant of PPO customized (per the helper names) for a centralized
# critic: trajectories are post-processed by
# `centralized_critic_postprocessing`, the loss is `loss_with_central_critic`,
# gradient stats come from `central_vf_stats`, and `setup_mixins` runs as the
# `before_loss_init` hook.
CCPPOTFPolicy = PPOTFPolicy.with_updates(
    name="CCPPOTFPolicy",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic,
    before_loss_init=setup_mixins,
    grad_stats_fn=central_vf_stats,
    # NOTE(review): presumably these become base classes of the generated
    # policy, so their order may matter — keep as-is unless verified.
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        CentralizedValueMixin
    ])

# Torch counterpart of CCPPOTFPolicy: same post-processing and loss functions,
# but with the torch versions of the learning-rate, entropy-coefficient and
# KL-coefficient mixins.
# NOTE(review): `setup_mixins` is wired to `before_init` here but to
# `before_loss_init` in the TF policy, and no `grad_stats_fn` is passed, so
# `central_vf_stats` is not wired in for torch — confirm both are intended.
CCPPOTorchPolicy = PPOTorchPolicy.with_updates(
    name="CCPPOTorchPolicy",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic,
    before_init=setup_mixins,
    mixins=[
        TorchLR, TorchEntropyCoeffSchedule, TorchKLCoeffMixin,
        CentralizedValueMixin
    ])


def get_policy_class(config):
    """Select the centralized-critic PPO policy class for ``config``.

    Args:
        config: Trainer config dict; only its ``"use_pytorch"`` entry is read.

    Returns:
        ``CCPPOTorchPolicy`` when ``config["use_pytorch"]`` is truthy,
        otherwise ``CCPPOTFPolicy``.
    """
    if config["use_pytorch"]:
        return CCPPOTorchPolicy
    return CCPPOTFPolicy


# PPO trainer wired to the centralized-critic policies: `default_policy` is the
# TF policy, while `get_policy_class` returns CCPPOTorchPolicy instead when
# `config["use_pytorch"]` is truthy.
CCTrainer = PPOTrainer.with_updates(
    name="CCPPOTrainer",
    default_policy=CCPPOTFPolicy,
    get_policy_class=get_policy_class,
)
Example #2
0
    def test_ppo_loss_function(self):
        """Tests the PPO loss function math.

        For every framework yielded by ``framework_iterator``, builds a PPO
        trainer on CartPole-v0 whose model has one linear hidden layer of
        size 10 shared between the policy and value heads. The framework's
        PPO loss runs on a fixed fake batch. The policy's recorded loss
        terms (KL, entropy, policy loss, value-function loss, total loss) are
        then compared against values recomputed from the model's weights via
        ``self._ppo_loss_helper``.
        """
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
        config["num_workers"] = 0  # Run locally.
        config["gamma"] = 0.99
        config["model"]["fcnet_hiddens"] = [10]
        config["model"]["fcnet_activation"] = "linear"
        config["model"]["vf_share_layers"] = True

        for fw, sess in framework_iterator(config, session=True):
            trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
            # Stop the trainer even if one of the checks below fails, so a
            # failing framework iteration does not leak its resources.
            try:
                policy = trainer.get_policy()

                # Check no free log std var by default.
                if fw == "torch":
                    matching = [
                        v for (n, v) in policy.model.named_parameters()
                        if "log_std" in n
                    ]
                else:
                    matching = [
                        v for v in policy.model.trainable_variables()
                        if "log_std" in str(v)
                    ]
                assert not matching, matching

                # Post-process (calculate simple (non-GAE) advantages) and
                # attach to train_batch dict.
                # A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0,
                #      0.5] = [0.50005, -0.505, 0.5]
                train_batch = compute_gae_for_sample_batch(
                    policy, FAKE_BATCH.copy())
                if fw == "torch":
                    train_batch = policy._lazy_tensor_dict(train_batch)

                # Check the value targets; they are expected to equal the
                # A values worked out above.
                check(train_batch[Postprocessing.VALUE_TARGETS],
                      [0.50005, -0.505, 0.5])

                # Calculate actual PPO loss. Static-graph "tf" is not called
                # here; its loss tensors are read from the session below
                # (presumably built when the policy was created).
                if fw in ["tf2", "tfe"]:
                    ppo_surrogate_loss_tf(policy, policy.model, Categorical,
                                          train_batch)
                elif fw == "torch":
                    PPOTorchPolicy.loss(policy, policy.model,
                                        policy.dist_class, train_batch)

                # `model_vars` (not `vars`, which would shadow the builtin).
                model_vars = policy.model.variables() if fw != "torch" else \
                    list(policy.model.parameters())
                if fw == "tf":
                    # Fetch concrete values from the session.
                    model_vars = policy.get_session().run(model_vars)
                # Torch orders the parameters differently: the shared layer
                # is at indices 2/3 and the logits layer at 0/1 (the reverse
                # of TF); the value head is at 4/5 in both.
                expected_shared_out = fc(
                    train_batch[SampleBatch.CUR_OBS],
                    model_vars[0 if fw != "torch" else 2],
                    model_vars[1 if fw != "torch" else 3],
                    framework=fw)
                expected_logits = fc(
                    expected_shared_out,
                    model_vars[2 if fw != "torch" else 0],
                    model_vars[3 if fw != "torch" else 1],
                    framework=fw)
                expected_value_outs = fc(
                    expected_shared_out,
                    model_vars[4],
                    model_vars[5],
                    framework=fw)

                kl, entropy, pg_loss, vf_loss, overall_loss = \
                    self._ppo_loss_helper(
                        policy, policy.model,
                        Categorical if fw != "torch" else TorchCategorical,
                        train_batch,
                        expected_logits, expected_value_outs,
                        sess=sess
                    )
                if sess:
                    policy_sess = policy.get_session()
                    k, e, pl, v, tl = policy_sess.run(
                        [
                            policy._mean_kl_loss,
                            policy._mean_entropy,
                            policy._mean_policy_loss,
                            policy._mean_vf_loss,
                            policy._total_loss,
                        ],
                        feed_dict=policy._get_loss_inputs_dict(
                            train_batch, shuffle=False))
                    check(k, kl)
                    check(e, entropy)
                    check(pl, np.mean(-pg_loss))
                    check(v, np.mean(vf_loss), decimals=4)
                    check(tl, overall_loss, decimals=4)
                elif fw == "torch":
                    check(policy.model.tower_stats["mean_kl_loss"], kl)
                    check(policy.model.tower_stats["mean_entropy"], entropy)
                    check(policy.model.tower_stats["mean_policy_loss"],
                          np.mean(-pg_loss))
                    check(policy.model.tower_stats["mean_vf_loss"],
                          np.mean(vf_loss),
                          decimals=4)
                    check(policy.model.tower_stats["total_loss"],
                          overall_loss,
                          decimals=4)
                else:
                    check(policy._mean_kl_loss, kl)
                    check(policy._mean_entropy, entropy)
                    check(policy._mean_policy_loss, np.mean(-pg_loss))
                    check(policy._mean_vf_loss, np.mean(vf_loss),
                          decimals=4)
                    check(policy._total_loss, overall_loss, decimals=4)
            finally:
                trainer.stop()
Example #3
0
 def __init__(self, observation_space, action_space, config):
     """Initialize the PPO torch policy first, then the centralized-value mixin.

     Both base initializers are called explicitly, in this order.
     ``CentralizedValueMixin.__init__`` is called with no arguments besides
     ``self``; presumably it relies on state set up by the policy
     initializer — confirm before reordering.
     """
     PPOTorchPolicy.__init__(self, observation_space, action_space, config)
     CentralizedValueMixin.__init__(self)
Example #4
0
        embed_dim=config["embed_dim"],
        encoder_type=config["encoder_type"])
    # make action output distrib
    action_dist_class = get_dist_class(config, action_space)
    return policy.model, action_dist_class


def get_dist_class(config, action_space):
    """Choose the torch action-distribution class for ``action_space``.

    Args:
        config: Policy config dict. ``"normalize_actions"`` is read for
            non-``Discrete`` spaces, and ``"_use_beta_distribution"`` only
            when ``"normalize_actions"`` is truthy.
        action_space: The policy's action space.

    Returns:
        ``TorchCategorical`` for ``Discrete`` spaces. Otherwise:
        ``TorchDiagGaussian`` if actions are not normalized, ``TorchBeta`` if
        they are normalized and the beta distribution is requested, and
        ``TorchSquashedGaussian`` if they are normalized and it is not.
    """
    if isinstance(action_space, Discrete):
        return TorchCategorical
    if not config["normalize_actions"]:
        return TorchDiagGaussian
    if config["_use_beta_distribution"]:
        return TorchBeta
    return TorchSquashedGaussian


#######################################################################################################
#####################################   Policy   #####################################################
#######################################################################################################

# hack to avoid cycle imports
import algorithms.baselines.ppo.ppo_trainer

# PPO torch policy whose model and action distribution come from
# `build_ppo_model_and_action_dist`, and whose default config is the
# project's PPO_CONFIG.
BaselinePPOTorchPolicy = PPOTorchPolicy.with_updates(
    name="BaselinePPOTorchPolicy",
    make_model_and_action_dist=build_ppo_model_and_action_dist,
    # Uses the project's PPO_CONFIG rather than RLlib's stock
    # ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG. The lambda defers the
    # PPO_CONFIG attribute lookup until the default config is requested,
    # presumably as part of the cycle-import workaround.
    get_default_config=lambda: algorithms.baselines.ppo.ppo_trainer.PPO_CONFIG)
Example #5
0
    # Add spatial CAPS loss to the report
    if policy.config["symmetric_policy_reg"] > 0.0:
        stats_dict["symmetry"] = policy._mean_symmetric_policy_loss
    if policy.config["caps_temporal_reg"] > 0.0:
        stats_dict["temporal_smoothness"] = policy._mean_temporal_caps_loss
    if policy.config["caps_spatial_reg"] > 0.0:
        stats_dict["spatial_smoothness"] = policy._mean_spatial_caps_loss
    if policy.config["caps_global_reg"] > 0.0:
        stats_dict["global_smoothness"] = policy._mean_global_caps_loss

    return stats_dict


# Rebind PPOTorchPolicy to a customized version:
# - `ppo_init` runs before the loss is initialized;
# - `ppo_loss` replaces the loss;
# - `ppo_stats` supplies the stats;
# - DEFAULT_CONFIG (looked up when the lambda is called) becomes the default.
# NOTE(review): this shadows the imported PPOTorchPolicy name for the rest of
# the module, and no `name=` is passed — confirm the inherited name is fine.
PPOTorchPolicy = PPOTorchPolicy.with_updates(
    before_loss_init=ppo_init,
    loss_fn=ppo_loss,
    stats_fn=ppo_stats,
    get_default_config=lambda: DEFAULT_CONFIG,
)


def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    """Return the policy class to use for the given trainer config.

    Args:
        config: Trainer config; only its ``"framework"`` entry is read.

    Returns:
        The customized ``PPOTorchPolicy`` when ``config["framework"]`` is
        ``"torch"``, otherwise ``None``.
    """
    # Only a torch policy is provided here; for other frameworks None is
    # returned, presumably so the trainer falls back to its default policy.
    if config["framework"] != "torch":
        return None
    return PPOTorchPolicy


# Rebind PPOTrainer so it defaults to DEFAULT_CONFIG and selects its policy
# through `get_policy_class` (which only provides a torch policy).
PPOTrainer = PPOTrainer.with_updates(default_config=DEFAULT_CONFIG,
                                     get_policy_class=get_policy_class)