def loss_with_central_critic(policy, model, dist_class, train_batch):
    """Copied from PPO but optimizing the central value function."""
    CentralizedValueMixin.__init__(policy)

    logits, state = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    policy.central_value_out = policy.model.central_value_function(
        train_batch[SampleBatch.CUR_OBS], train_batch[OTHER_AGENTS])

    policy.loss_obj = PPOLoss(
        policy.action_space,
        dist_class,
        model,
        train_batch[Postprocessing.VALUE_TARGETS],
        train_batch[Postprocessing.ADVANTAGES],
        train_batch[SampleBatch.ACTIONS],
        train_batch[BEHAVIOUR_LOGITS],
        train_batch[ACTION_LOGP],
        train_batch[SampleBatch.VF_PREDS],
        action_dist,
        policy.central_value_out,
        policy.kl_coeff,
        tf.ones_like(train_batch[Postprocessing.ADVANTAGES], dtype=tf.bool),
        entropy_coeff=policy.entropy_coeff,
        clip_param=policy.config["clip_param"],
        vf_clip_param=policy.config["vf_clip_param"],
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        use_gae=policy.config["use_gae"],
        model_config=policy.config["model"])

    return policy.loss_obj.loss
def loss_with_central_critic(policy, batch_tensors):
    """Variant of the above loss written against the batch_tensors-based
    loss_fn signature."""
    CentralizedValueMixin.__init__(policy)

    policy.loss_obj = PPOLoss(
        policy.action_space,
        batch_tensors[Postprocessing.VALUE_TARGETS],
        batch_tensors[Postprocessing.ADVANTAGES],
        batch_tensors[SampleBatch.ACTIONS],
        batch_tensors[BEHAVIOUR_LOGITS],
        batch_tensors[SampleBatch.VF_PREDS],
        policy.action_dist,
        policy.central_value_function,
        policy.kl_coeff,
        tf.ones_like(batch_tensors[Postprocessing.ADVANTAGES], dtype=tf.bool),
        entropy_coeff=policy.entropy_coeff,
        clip_param=policy.config["clip_param"],
        vf_clip_param=policy.config["vf_clip_param"],
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        use_gae=policy.config["use_gae"],
        model_config=policy.config["model"])

    return policy.loss_obj.loss
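# --- Sketch (not from the original): wiring the loss into a policy class. ---
# Shows one plausible way a loss_fn like loss_with_central_critic is attached
# to a PPO policy via RLlib's policy-template API. Import paths below match
# older Ray releases and may differ in other versions.
# `centralized_critic_postprocessing` and `setup_mixins` are hypothetical
# helpers assumed to fill train_batch[OTHER_AGENTS] and to initialize the
# mixins before the loss is built; they are not defined in this section.
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy, KLCoeffMixin
from ray.rllib.policy.tf_policy import LearningRateSchedule, EntropyCoeffSchedule

CCPPOTFPolicy = PPOTFPolicy.with_updates(
    name="CCPPOTFPolicy",
    postprocess_fn=centralized_critic_postprocessing,  # hypothetical helper
    loss_fn=loss_with_central_critic,
    before_loss_init=setup_mixins,  # hypothetical helper
    mixins=[
        LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
        CentralizedValueMixin
    ])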