CCPPOTFPolicy = PPOTFPolicy.with_updates(
    name="CCPPOTFPolicy",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic,
    before_loss_init=setup_mixins,
    grad_stats_fn=central_vf_stats,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        CentralizedValueMixin
    ])

CCPPOTorchPolicy = PPOTorchPolicy.with_updates(
    name="CCPPOTorchPolicy",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic,
    before_init=setup_mixins,
    mixins=[
        TorchLR, TorchEntropyCoeffSchedule, TorchKLCoeffMixin,
        CentralizedValueMixin
    ])


def get_policy_class(config):
    return CCPPOTorchPolicy if config["use_pytorch"] else CCPPOTFPolicy


CCTrainer = PPOTrainer.with_updates(
    name="CCPPOTrainer",
    default_policy=CCPPOTFPolicy,
    get_policy_class=get_policy_class,
)
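# A minimal, hypothetical launch sketch for the trainer defined above. It
# assumes `ray`/`ray.tune` are available and that a suitable multi-agent
# environment has been registered under the illustrative name
# "my_multiagent_env"; the stop criterion and worker count are placeholders,
# not settings taken from the original experiment.
import ray
from ray import tune

if __name__ == "__main__":
    ray.init()
    tune.run(
        CCTrainer,
        stop={"training_iteration": 100},  # Illustrative stop criterion.
        config={
            "env": "my_multiagent_env",  # Hypothetical registered env name.
            "use_pytorch": False,  # Selects CCPPOTFPolicy via get_policy_class.
            "num_workers": 0,
        })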
def test_ppo_loss_function(self):
    """Tests the PPO loss function math."""
    config = copy.deepcopy(ppo.DEFAULT_CONFIG)
    config["num_workers"] = 0  # Run locally.
    config["gamma"] = 0.99
    config["model"]["fcnet_hiddens"] = [10]
    config["model"]["fcnet_activation"] = "linear"
    config["model"]["vf_share_layers"] = True

    for fw, sess in framework_iterator(config, session=True):
        trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
        policy = trainer.get_policy()

        # Check no free log std var by default.
        if fw == "torch":
            matching = [
                v for (n, v) in policy.model.named_parameters()
                if "log_std" in n
            ]
        else:
            matching = [
                v for v in policy.model.trainable_variables()
                if "log_std" in str(v)
            ]
        assert len(matching) == 0, matching

        # Post-process (calculate simple (non-GAE) advantages) and attach
        # to train_batch dict.
        # A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
        # [0.50005, -0.505, 0.5]
        train_batch = compute_gae_for_sample_batch(policy, FAKE_BATCH.copy())
        if fw == "torch":
            train_batch = policy._lazy_tensor_dict(train_batch)

        # Check Advantage values.
        check(train_batch[Postprocessing.VALUE_TARGETS],
              [0.50005, -0.505, 0.5])

        # Calculate actual PPO loss.
        if fw in ["tf2", "tfe"]:
            ppo_surrogate_loss_tf(policy, policy.model, Categorical,
                                  train_batch)
        elif fw == "torch":
            PPOTorchPolicy.loss(policy, policy.model, policy.dist_class,
                                train_batch)

        vars = policy.model.variables() if fw != "torch" else \
            list(policy.model.parameters())
        if fw == "tf":
            vars = policy.get_session().run(vars)
        expected_shared_out = fc(
            train_batch[SampleBatch.CUR_OBS],
            vars[0 if fw != "torch" else 2],
            vars[1 if fw != "torch" else 3],
            framework=fw)
        expected_logits = fc(
            expected_shared_out,
            vars[2 if fw != "torch" else 0],
            vars[3 if fw != "torch" else 1],
            framework=fw)
        expected_value_outs = fc(
            expected_shared_out, vars[4], vars[5], framework=fw)

        kl, entropy, pg_loss, vf_loss, overall_loss = \
            self._ppo_loss_helper(
                policy,
                policy.model,
                Categorical if fw != "torch" else TorchCategorical,
                train_batch,
                expected_logits,
                expected_value_outs,
                sess=sess)

        if sess:
            policy_sess = policy.get_session()
            k, e, pl, v, tl = policy_sess.run(
                [
                    policy._mean_kl_loss,
                    policy._mean_entropy,
                    policy._mean_policy_loss,
                    policy._mean_vf_loss,
                    policy._total_loss,
                ],
                feed_dict=policy._get_loss_inputs_dict(
                    train_batch, shuffle=False))
            check(k, kl)
            check(e, entropy)
            check(pl, np.mean(-pg_loss))
            check(v, np.mean(vf_loss), decimals=4)
            check(tl, overall_loss, decimals=4)
        elif fw == "torch":
            check(policy.model.tower_stats["mean_kl_loss"], kl)
            check(policy.model.tower_stats["mean_entropy"], entropy)
            check(policy.model.tower_stats["mean_policy_loss"],
                  np.mean(-pg_loss))
            check(policy.model.tower_stats["mean_vf_loss"],
                  np.mean(vf_loss), decimals=4)
            check(policy.model.tower_stats["total_loss"],
                  overall_loss, decimals=4)
        else:
            check(policy._mean_kl_loss, kl)
            check(policy._mean_entropy, entropy)
            check(policy._mean_policy_loss, np.mean(-pg_loss))
            check(policy._mean_vf_loss, np.mean(vf_loss), decimals=4)
            check(policy._total_loss, overall_loss, decimals=4)
        trainer.stop()
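# Standalone numeric check of the advantage comment in the test above: with
# the reward sequence implied by that comment ([1.0, -1.0, 0.5]), a terminal
# step at t=2 and gamma=0.99, plain discounted reward-to-go reproduces
# [0.50005, -0.505, 0.5]. This is a sketch for verification only, separate
# from the test class.
import numpy as np


def discounted_returns(rewards, gamma):
    """Discounted reward-to-go for a single terminated episode."""
    returns = np.zeros(len(rewards), dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


np.testing.assert_allclose(
    discounted_returns(np.array([1.0, -1.0, 0.5]), gamma=0.99),
    [0.50005, -0.505, 0.5])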
def __init__(self, observation_space, action_space, config):
    PPOTorchPolicy.__init__(self, observation_space, action_space, config)
    CentralizedValueMixin.__init__(self)
embed_dim=config["embed_dim"], encoder_type=config["encoder_type"]) # make action output distrib action_dist_class = get_dist_class(config, action_space) return policy.model, action_dist_class def get_dist_class(config, action_space): if isinstance(action_space, Discrete): return TorchCategorical else: if config["normalize_actions"]: return TorchSquashedGaussian if \ not config["_use_beta_distribution"] else TorchBeta else: return TorchDiagGaussian ####################################################################################################### ##################################### Policy ##################################################### ####################################################################################################### # hack to avoid cycle imports import algorithms.baselines.ppo.ppo_trainer BaselinePPOTorchPolicy = PPOTorchPolicy.with_updates( name="BaselinePPOTorchPolicy", make_model_and_action_dist=build_ppo_model_and_action_dist, # get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, get_default_config=lambda: algorithms.baselines.ppo.ppo_trainer.PPO_CONFIG)
    # Add the policy regularization losses to the report.
    if policy.config["symmetric_policy_reg"] > 0.0:
        stats_dict["symmetry"] = policy._mean_symmetric_policy_loss
    if policy.config["caps_temporal_reg"] > 0.0:
        stats_dict["temporal_smoothness"] = policy._mean_temporal_caps_loss
    if policy.config["caps_spatial_reg"] > 0.0:
        stats_dict["spatial_smoothness"] = policy._mean_spatial_caps_loss
    if policy.config["caps_global_reg"] > 0.0:
        stats_dict["global_smoothness"] = policy._mean_global_caps_loss

    return stats_dict


PPOTorchPolicy = PPOTorchPolicy.with_updates(
    before_loss_init=ppo_init,
    loss_fn=ppo_loss,
    stats_fn=ppo_stats,
    get_default_config=lambda: DEFAULT_CONFIG,
)


def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    """TODO: Write documentation."""
    if config["framework"] == "torch":
        return PPOTorchPolicy
    return None


PPOTrainer = PPOTrainer.with_updates(
    default_config=DEFAULT_CONFIG,
    get_policy_class=get_policy_class)
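# Hypothetical config sketch: the smoothness/symmetry stats reported above
# only show up when the corresponding regularization coefficients are
# non-zero. The coefficient values below are illustrative, not recommended
# settings; the keys match the ones read from `policy.config` in `ppo_stats`.
caps_config = {
    "symmetric_policy_reg": 0.1,   # -> "symmetry"
    "caps_temporal_reg": 0.01,     # -> "temporal_smoothness"
    "caps_spatial_reg": 0.01,      # -> "spatial_smoothness"
    "caps_global_reg": 0.001,      # -> "global_smoothness"
}
config = dict(DEFAULT_CONFIG, framework="torch", **caps_config)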