Example #1
def loss_with_central_critic(policy, model, dist_class, train_batch):
    """Copied from PPO but optimizing the central value function"""
    CentralizedValueMixin.__init__(policy)

    # Forward pass to get the action distribution for the current batch.
    logits, state = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    # Centralized critic: the value estimate also conditions on the other
    # agents' information carried in the batch.
    policy.central_value_out = policy.model.central_value_function(
        train_batch[SampleBatch.CUR_OBS], train_batch[OTHER_AGENTS])

    # Standard PPO loss, except the centralized value output is used in
    # place of the per-agent value predictions.
    policy.loss_obj = PPOLoss(policy.action_space,
                              dist_class,
                              model,
                              train_batch[Postprocessing.VALUE_TARGETS],
                              train_batch[Postprocessing.ADVANTAGES],
                              train_batch[SampleBatch.ACTIONS],
                              train_batch[BEHAVIOUR_LOGITS],
                              train_batch[ACTION_LOGP],
                              train_batch[SampleBatch.VF_PREDS],
                              action_dist,
                              policy.central_value_out,
                              policy.kl_coeff,
                              tf.ones_like(
                                  train_batch[Postprocessing.ADVANTAGES],
                                  dtype=tf.bool),
                              entropy_coeff=policy.entropy_coeff,
                              clip_param=policy.config["clip_param"],
                              vf_clip_param=policy.config["vf_clip_param"],
                              vf_loss_coeff=policy.config["vf_loss_coeff"],
                              use_gae=policy.config["use_gae"],
                              model_config=policy.config["model"])

    return policy.loss_obj.loss
Example #2
def loss_with_central_critic(policy, batch_tensors):
    """Copied from PPO but optimizing the central value function"""
    CentralizedValueMixin.__init__(policy)

    # Same PPO loss as above, but the policy's centralized value function
    # replaces the per-agent value predictions.
    policy.loss_obj = PPOLoss(
        policy.action_space,
        batch_tensors[Postprocessing.VALUE_TARGETS],
        batch_tensors[Postprocessing.ADVANTAGES],
        batch_tensors[SampleBatch.ACTIONS],
        batch_tensors[BEHAVIOUR_LOGITS],
        batch_tensors[SampleBatch.VF_PREDS],
        policy.action_dist,
        policy.central_value_function,
        policy.kl_coeff,
        tf.ones_like(batch_tensors[Postprocessing.ADVANTAGES], dtype=tf.bool),
        entropy_coeff=policy.entropy_coeff,
        clip_param=policy.config["clip_param"],
        vf_clip_param=policy.config["vf_clip_param"],
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        use_gae=policy.config["use_gae"],
        model_config=policy.config["model"])

    return policy.loss_obj.loss
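
Neither snippet shows how the loss is attached to a policy. The sketch below illustrates the usual wiring from the RLlib centralized-critic example, assuming the RLlib 0.8-era policy-template API; the helpers centralized_critic_postprocessing, setup_mixins, and central_vf_stats are names from that example and are assumed to be defined in the same module as the loss function, and exact import paths may differ between RLlib versions.

from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy, KLCoeffMixin
from ray.rllib.policy.tf_policy import LearningRateSchedule, EntropyCoeffSchedule

# Sketch: derive a PPO variant whose loss is loss_with_central_critic.
# Assumes centralized_critic_postprocessing, setup_mixins, central_vf_stats
# and CentralizedValueMixin are defined alongside the loss (as in the RLlib
# centralized-critic example).
CCPPOTFPolicy = PPOTFPolicy.with_updates(
    name="CCPPOTFPolicy",
    postprocess_fn=centralized_critic_postprocessing,  # injects other agents' data into the batch
    loss_fn=loss_with_central_critic,                   # the loss shown above
    before_loss_init=setup_mixins,                      # initializes the mixins listed below
    grad_stats_fn=central_vf_stats,                     # reports central value-function stats
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        CentralizedValueMixin,
    ])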