Example #1
        embed_dim=config["embed_dim"],
        encoder_type=config["encoder_type"])
    # Pick the action distribution class matching the action space and config.
    action_dist_class = get_dist_class(config, action_space)
    return policy.model, action_dist_class


def get_dist_class(config, action_space):
    if isinstance(action_space, Discrete):
        return TorchCategorical
    if config["normalize_actions"]:
        # Bounded continuous actions: Beta distribution if requested,
        # otherwise a tanh-squashed Gaussian.
        return TorchBeta if config["_use_beta_distribution"] \
            else TorchSquashedGaussian
    return TorchDiagGaussian


#######################################################################################################
#####################################   Policy   #####################################################
#######################################################################################################

# Imported here rather than at module top to avoid a circular import with the trainer module.
import algorithms.baselines.ppo.ppo_trainer

BaselinePPOTorchPolicy = PPOTorchPolicy.with_updates(
    name="BaselinePPOTorchPolicy",
    make_model_and_action_dist=build_ppo_model_and_action_dist,
    # get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
    get_default_config=lambda: algorithms.baselines.ppo.ppo_trainer.PPO_CONFIG)
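
For reference, a minimal usage sketch (not part of the original snippet) of the get_dist_class helper above. The toy config and action spaces are made up for illustration, and the import path corresponds to older RLlib releases.

import numpy as np
from gym.spaces import Box, Discrete
from ray.rllib.models.torch.torch_action_dist import (
    TorchCategorical, TorchSquashedGaussian)

# Hypothetical config: only the keys read by get_dist_class matter here.
toy_config = {"normalize_actions": True, "_use_beta_distribution": False}

# Discrete action spaces always map to a categorical distribution.
assert get_dist_class(toy_config, Discrete(4)) is TorchCategorical

# Continuous spaces map to a squashed Gaussian when actions are normalized
# (or to TorchBeta when "_use_beta_distribution" is set).
box = Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
assert get_dist_class(toy_config, box) is TorchSquashedGaussian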
Example #2
CCPPOTFPolicy = PPOTFPolicy.with_updates(
    name="CCPPOTFPolicy",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic,
    before_loss_init=setup_mixins,
    grad_stats_fn=central_vf_stats,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        CentralizedValueMixin
    ])

CCPPOTorchPolicy = PPOTorchPolicy.with_updates(
    name="CCPPOTorchPolicy",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic,
    before_init=setup_mixins,
    mixins=[
        TorchLR, TorchEntropyCoeffSchedule, TorchKLCoeffMixin,
        CentralizedValueMixin
    ])


def get_policy_class(config):
    return CCPPOTorchPolicy if config["use_pytorch"] else CCPPOTFPolicy


CCTrainer = PPOTrainer.with_updates(
    name="CCPPOTrainer",
    default_policy=CCPPOTFPolicy,
    get_policy_class=get_policy_class,
)
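
A minimal launch sketch (not from the original example) for the CCTrainer defined above; "my_multiagent_env" and the config values are placeholders for whatever multi-agent environment the surrounding project registers.

import ray
from ray import tune

ray.init()
tune.run(
    CCTrainer,
    stop={"training_iteration": 10},
    config={
        "env": "my_multiagent_env",  # placeholder: any registered multi-agent env
        "use_pytorch": False,        # False -> CCPPOTFPolicy, True -> CCPPOTorchPolicy
        "num_workers": 0,
    },
)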
Example #3
    # Report each enabled regularization loss (policy symmetry and CAPS smoothness terms).
    if policy.config["symmetric_policy_reg"] > 0.0:
        stats_dict["symmetry"] = policy._mean_symmetric_policy_loss
    if policy.config["caps_temporal_reg"] > 0.0:
        stats_dict["temporal_smoothness"] = policy._mean_temporal_caps_loss
    if policy.config["caps_spatial_reg"] > 0.0:
        stats_dict["spatial_smoothness"] = policy._mean_spatial_caps_loss
    if policy.config["caps_global_reg"] > 0.0:
        stats_dict["global_smoothness"] = policy._mean_global_caps_loss

    return stats_dict


PPOTorchPolicy = PPOTorchPolicy.with_updates(
    before_loss_init=ppo_init,
    loss_fn=ppo_loss,
    stats_fn=ppo_stats,
    get_default_config=lambda: DEFAULT_CONFIG,
)


def get_policy_class(config: TrainerConfigDict) -> Optional[Type[Policy]]:
    """ TODO: Write documentation.
    """
    if config["framework"] == "torch":
        return PPOTorchPolicy
    return None


PPOTrainer = PPOTrainer.with_updates(default_config=DEFAULT_CONFIG,
                                     get_policy_class=get_policy_class)
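
A minimal, hypothetical usage sketch for the patched PPOTrainer above; "Pendulum-v0" and the regularization weight are placeholders, and it assumes DEFAULT_CONFIG (defined elsewhere in this project) is a plain dict that already carries the caps_*/symmetric_policy_reg keys read by ppo_stats.

config = dict(
    DEFAULT_CONFIG,
    framework="torch",        # required: get_policy_class only handles "torch"
    caps_temporal_reg=0.01,   # > 0.0 enables the "temporal_smoothness" metric
    num_workers=0,
)
trainer = PPOTrainer(config=config, env="Pendulum-v0")
print(trainer.train()["episode_reward_mean"])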