Ejemplo n.º 1
0

def central_vf_stats(policy, train_batch, grads):
    """Report the explained variance of the central value function.

    Args:
        policy: Object exposing ``central_value_out``.
        train_batch: Mapping indexed with ``Postprocessing.VALUE_TARGETS``.
        grads: Accepted for the hook signature; not used here.

    Returns:
        A dict with the single key ``"vf_explained_var"``.
    """
    value_targets = train_batch[Postprocessing.VALUE_TARGETS]
    central_values = policy.central_value_out
    stats = {
        "vf_explained_var": explained_variance(value_targets, central_values),
    }
    return stats


# Policy derived from PPOTFPolicy via with_updates: the postprocessing, loss,
# before-loss-init and gradient-stats hooks are replaced by the functions
# named below, and CentralizedValueMixin is added to the mixin list.
CCPPO = PPOTFPolicy.with_updates(
    name="CCPPO",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic,
    before_loss_init=setup_mixins,
    grad_stats_fn=central_vf_stats,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        CentralizedValueMixin
    ])

# Trainer derived from PPOTrainer whose default policy is CCPPO.
CCTrainer = PPOTrainer.with_updates(name="CCPPOTrainer", default_policy=CCPPO)

if __name__ == "__main__":
    args = parser.parse_args()
    ModelCatalog.register_custom_model("cc_model", CentralizedCriticModel)
    tune.run(CCTrainer,
             stop={
                 "timesteps_total": args.stop,
                 "episode_reward_mean": 7.99,
             },
Ejemplo n.º 2
0
def setup_mixins_dice(policy, obs_space, action_space, config):
    """Initialise the base mixins plus the two diversity mixins on *policy*.

    Calls ``setup_mixins`` first, then ``DiversityValueNetworkMixin`` and
    finally ``ComputeDiversityMixin``, which is told whether the action
    space is a ``gym.spaces.Discrete`` instance.
    """
    setup_mixins(policy, obs_space, action_space, config)
    DiversityValueNetworkMixin.__init__(
        policy, obs_space, action_space, config)
    ComputeDiversityMixin.__init__(
        policy, isinstance(action_space, gym.spaces.Discrete))


def setup_late_mixins(policy, obs_space, action_space, config):
    """Initialise ``TargetNetworkMixin`` when ``config[DELAY_UPDATE]`` is set.

    Does nothing when that config entry is falsy.
    """
    if not config[DELAY_UPDATE]:
        return
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


# Policy derived from PPOTFPolicy via with_updates. It swaps in the DiCE
# default config, postprocessing, loss, stats, gradient and action-fetch
# functions named below. setup_mixins_dice is the before-loss-init hook and
# setup_late_mixins is the after-init hook.
DiCEPolicy = PPOTFPolicy.with_updates(
    name="DiCEPolicy",
    get_default_config=lambda: dice_default_config,
    postprocess_fn=postprocess_dice,
    loss_fn=dice_loss,
    stats_fn=kl_and_loss_stats_modified,
    gradients_fn=dice_gradient,
    grad_stats_fn=grad_stats_fn,
    extra_action_fetches_fn=additional_fetches,
    before_loss_init=setup_mixins_dice,
    after_init=setup_late_mixins,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        ValueNetworkMixin, DiversityValueNetworkMixin, ComputeDiversityMixin,
        TargetNetworkMixin
    ])
Ejemplo n.º 3
0
# Model config override selecting the custom model registered under the name
# "fc_with_mask", with an empty custom-options dict.
fc_with_mask_model_config = {
    "model": {
        "custom_model": "fc_with_mask",
        "custom_options": {}
    }
}

# DEFAULT_CONFIG merged with the override above via merge_dicts.
ppo_agent_default_config_with_mask = merge_dicts(DEFAULT_CONFIG,
                                                 fc_with_mask_model_config)

# Policy derived from PPOTFPolicy via with_updates. It uses the mask-enabled
# default config, a replacement extra-action-fetches function, and adds
# AddMaskInfoMixinForPolicy to the mixin list.
PPOTFPolicyWithMask = PPOTFPolicy.with_updates(
    name="PPOTFPolicyWithMask",
    get_default_config=lambda: ppo_agent_default_config_with_mask,
    extra_action_fetches_fn=vf_preds_and_logits_fetches_new,
    before_loss_init=setup_mixins,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        ValueNetworkMixin, AddMaskInfoMixinForPolicy
    ])


class AddMaskInfoMixin(object):
    def get_mask_info(self):
        """Return the result of ``self.get_mask()``."""
        return self.get_mask()

    def get_mask(self):
        """Return the mask reported by the object from ``self.get_policy()``."""
        return self.get_policy().get_mask()

    def set_mask(self, mask_dict):
        # Check the input is correct.
Ejemplo n.º 4
0
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy, postprocess_ppo_gae


def my_postprocess_ppo_gae(policy, sample_batch, *args, **kwargs):
    """Run ``postprocess_ppo_gae`` per segment and concatenate the results.

    Rows whose ``infos`` entry has a truthy ``'done'`` value are used as
    segment boundaries. Each segment has its last ``dones`` entry forced to
    ``True`` before being postprocessed on its own. When the batch has no
    ``infos`` column, or no row is flagged, the whole batch is postprocessed
    in a single call.

    Args:
        policy: Forwarded to ``postprocess_ppo_gae``.
        sample_batch: Batch providing ``get``, ``count``, ``slice`` and
            item access for the ``'infos'`` and ``'dones'`` columns.
        *args: Forwarded to ``postprocess_ppo_gae``.
        **kwargs: Forwarded to ``postprocess_ppo_gae``.

    Returns:
        The postprocessed batch covering every segment.
    """
    if sample_batch.get('infos') is not None:
        idx = [i for i, x in enumerate(sample_batch['infos']) if x['done']]
        if idx:
            # Sentinel so the rows after the last flagged row form a segment.
            idx.append(sample_batch.count)
            sbatch = sample_batch.slice(0, idx[0] + 1)
            sbatch['dones'][-1] = True
            batch = postprocess_ppo_gae(policy, sbatch, *args, **kwargs)
            # NOTE(review): each later segment starts at the flagged row `s`,
            # which the previous segment already included; confirm that this
            # one-row overlap is intended.
            for s, t in zip(idx[:-1], idx[1:]):
                sbatch = sample_batch.slice(s, t + 1)
                sbatch['dones'][-1] = True
                # SampleBatch.concat returns a new batch instead of mutating
                # in place, so the result must be kept; otherwise only the
                # first segment would be returned.
                batch = batch.concat(
                    postprocess_ppo_gae(policy, sbatch, *args, **kwargs))
            return batch
    return postprocess_ppo_gae(policy, sample_batch, *args, **kwargs)


# Policy derived from PPOTFPolicy whose postprocessing hook is
# my_postprocess_ppo_gae.
MyPpoPolicy = PPOTFPolicy.with_updates(name="MyPpoTFPolicy",
                                       postprocess_fn=my_postprocess_ppo_gae)

# Trainer derived from PPOTrainer whose default policy is MyPpoPolicy.
MyPpoTrainer = PPOTrainer.with_updates(name="MyPpoTrainer",
                                       default_policy=MyPpoPolicy)
Ejemplo n.º 5
0
        sample_batch_size=config["sample_batch_size"],
        num_envs_per_worker=config["num_envs_per_worker"],
        train_batch_size=config["train_batch_size"],
        standardize_fields=["advantages"],
        shuffle_sequences=config["shuffle_sequences"])


def setup_mixins_modified(policy, obs_space, action_space, config):
    """Initialise ``AddLossMixin`` on *policy*, then run ``setup_mixins``."""
    AddLossMixin.__init__(policy, config)
    setup_mixins(policy, obs_space, action_space, config)


# Policy derived from PPOTFPolicy via with_updates. It uses the extra-loss
# default config, the modified postprocessing, stats and loss functions, and
# appends AddLossMixin to mixin_list.
ExtraLossPPOTFPolicy = PPOTFPolicy.with_updates(
    name="ExtraLossPPOTFPolicy",
    get_default_config=lambda: extra_loss_ppo_default_config,
    postprocess_fn=postprocess_ppo_gae_modified,
    stats_fn=kl_and_loss_stats_modified,
    loss_fn=extra_loss_ppo_loss,
    before_loss_init=setup_mixins_modified,
    mixins=mixin_list + [AddLossMixin])

# Trainer derived from PPOTrainer via with_updates. It uses the extra-loss
# default config, a replacement config validator, ExtraLossPPOTFPolicy as the
# default policy, and choose_policy_optimizer to build the policy optimizer.
ExtraLossPPOTrainer = PPOTrainer.with_updates(
    name="ExtraLossPPO",
    default_config=extra_loss_ppo_default_config,
    validate_config=validate_config_modified,
    default_policy=ExtraLossPPOTFPolicy,
    make_policy_optimizer=choose_policy_optimizer)

# When run as a script, import the trainer test and call it with True.
if __name__ == '__main__':
    # Imported here rather than at the top of the file.
    from toolbox.marl.test_extra_loss import test_extra_loss_ppo_trainer1

    test_extra_loss_ppo_trainer1(True)
Ejemplo n.º 6
0
                self._alpha_val = 0.5
        else:
            if running_mean > 1.5 * self._novelty_target:
                self._alpha_val *= (1 - self._alpha_coefficient)
            elif running_mean < 0.5 * self._novelty_target:
                self._alpha_val = min(
                    (1 + self._alpha_coefficient) * self._alpha_val, 0.5)

        self._alpha.load(self._alpha_val, session=self.get_session())
        return self._alpha_val


# Policy derived from PPOTFPolicy via with_updates. It swaps in the DECE
# default config, postprocessing, loss, stats, gradient and action-fetch
# functions named below. setup_mixins_dece is the before-loss-init hook and
# setup_late_mixins is the after-init hook. A batch-divisibility requirement
# function is also supplied.
DECEPolicy = PPOTFPolicy.with_updates(
    name="DECEPolicy",
    get_default_config=lambda: dece_default_config,
    postprocess_fn=postprocess_dece,
    loss_fn=loss_dece,
    stats_fn=kl_and_loss_stats_modified,
    gradients_fn=tnb_gradients,
    grad_stats_fn=grad_stats_fn,
    extra_action_fetches_fn=additional_fetches,
    before_loss_init=setup_mixins_dece,
    after_init=setup_late_mixins,
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        ValueNetworkMixin, NoveltyValueNetworkMixin, ComputeNoveltyMixin,
        TargetNetworkMixin, ConstrainNoveltyMixin
    ],
    get_batch_divisibility_req=get_batch_divisibility_req,
)
Ejemplo n.º 7
0
def setup_mixins(policy, obs_space, action_space, config):
    """Initialise the value, KL, entropy and transformer LR mixins on *policy*.

    The transformer learning-rate schedule is given
    ``config["model"]["custom_options"]["transformer"]["num_heads"]`` and a
    warm-up step count taken from the ``"warmup_steps"`` custom option,
    which defaults to 100000 when absent.
    """
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    KLCoeffMixin.__init__(policy, config)
    EntropyCoeffSchedule.__init__(
        policy, config["entropy_coeff"], config["entropy_coeff_schedule"])

    # Both remaining settings live under the model's custom options.
    custom_options = config["model"]["custom_options"]
    warmup_steps = custom_options.get("warmup_steps", 100000)
    num_heads = custom_options["transformer"]["num_heads"]
    TransformerLearningRateSchedule.__init__(policy, num_heads, warmup_steps)


# Policy derived from PPOTFPolicy whose mixin list uses
# TransformerLearningRateSchedule; setup_mixins is the before-loss-init hook.
TTFPPOPolicy = PPOTFPolicy.with_updates(name="TTFPPOPolicy",
                                        before_loss_init=setup_mixins,
                                        mixins=[
                                            TransformerLearningRateSchedule,
                                            EntropyCoeffSchedule, KLCoeffMixin,
                                            ValueNetworkMixin
                                        ])

# Variant of the policy whose mixin list uses LearningRateSchedule.
# NOTE(review): before_loss_init is still setup_mixins, which calls
# TransformerLearningRateSchedule.__init__ on the policy even though that
# class is not in this mixin list. Confirm this is intended.
TTFPPOPolicyInfer = PPOTFPolicy.with_updates(name="TTFPPOPolicyInfer",
                                             before_loss_init=setup_mixins,
                                             mixins=[
                                                 LearningRateSchedule,
                                                 EntropyCoeffSchedule,
                                                 KLCoeffMixin,
                                                 ValueNetworkMixin
                                             ])

register_trainable(
    "TTFPPO",
Ejemplo n.º 8
0
            return self.get_session().run(
                fim_embedding,
                feed_dict={self._input_dict[SampleBatch.CUR_OBS]: ob})

        self.get_fim_embedding = get_fim_embedding


def before_loss_init(policy, obs_space, action_space, config):
    """Run ``setup_mixins`` on *policy*, then initialise ``FIMEmbeddingMixin``."""
    setup_mixins(policy, obs_space, action_space, config)
    FIMEmbeddingMixin.__init__(policy)


# Policy derived from PPOTFPolicy that adds FIMEmbeddingMixin to the mixin
# list and uses before_loss_init as the before-loss-init hook.
PPOFIMTFPolicy = PPOTFPolicy.with_updates(name="PPOFIMTFPolicy",
                                          before_loss_init=before_loss_init,
                                          mixins=[
                                              LearningRateSchedule,
                                              EntropyCoeffSchedule,
                                              KLCoeffMixin, ValueNetworkMixin,
                                              FIMEmbeddingMixin
                                          ])


def get_policy_class(config):
    """Return the policy class to use for *config*.

    Raises:
        NotImplementedError: If ``config.get("use_pytorch")`` is exactly
            ``True``.
    """
    wants_pytorch = config.get("use_pytorch") is True
    if wants_pytorch:
        raise NotImplementedError()
    return PPOFIMTFPolicy


PPOFIMTrainer = PPOTrainer.with_updates(
    name="PPOFIM",
    default_policy=PPOFIMTFPolicy,