Example #1
def loss_with_central_critic(policy, model, dist_class, train_batch):
    """Compute the PPO loss, but with the value branch driven by the
    centralized value function, which also sees the opponent's observation
    and action."""
    CentralizedValueMixin.__init__(policy)

    logits, state = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    # Value predictions come from the centralized critic, conditioned on this
    # agent's observation plus the opponent's observation and action.
    policy.central_value_out = policy.model.central_value_function(
        train_batch[SampleBatch.CUR_OBS], train_batch[OPPONENT_OBS],
        train_batch[OPPONENT_ACTION])

    policy.loss_obj = PPOLoss(policy.action_space,
                              dist_class,
                              model,
                              train_batch[Postprocessing.VALUE_TARGETS],
                              train_batch[Postprocessing.ADVANTAGES],
                              train_batch[SampleBatch.ACTIONS],
                              train_batch[SampleBatch.ACTION_DIST_INPUTS],
                              train_batch[SampleBatch.ACTION_LOGP],
                              train_batch[SampleBatch.VF_PREDS],
                              action_dist,
                              policy.central_value_out,
                              policy.kl_coeff,
                              tf.ones_like(
                                  train_batch[Postprocessing.ADVANTAGES],
                                  dtype=tf.bool),
                              entropy_coeff=policy.entropy_coeff,
                              clip_param=policy.config["clip_param"],
                              vf_clip_param=policy.config["vf_clip_param"],
                              vf_loss_coeff=policy.config["vf_loss_coeff"],
                              use_gae=policy.config["use_gae"],
                              model_config=policy.config["model"])

    return policy.loss_obj.loss
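The snippet below is not part of the original example; it is only a minimal sketch of how a custom loss like this is typically wired into PPO, assuming an older (pre-2.0) RLlib release where PPOTFPolicy is built from the TF policy template and therefore exposes with_updates(). The postprocessing step that fills OPPONENT_OBS / OPPONENT_ACTION and the centralized value targets is assumed to exist elsewhere and is omitted here.

from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy

# Keep PPO's default mixins and stats, but swap in the centralized-critic
# loss defined above. A matching postprocess_fn (not shown) must add the
# opponent fields and value targets to the train batch.
CCPPOTFPolicy = PPOTFPolicy.with_updates(
    name="CCPPOTFPolicy",
    loss_fn=loss_with_central_critic,
)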
Example #2
def loss_with_moa(policy, model, dist_class, train_batch):
    """
    Calculate PPO loss with MOA loss
    :return: Combined PPO+MOA loss
    """
    # you need to override this bit to pull out the right bits from train_batch
    logits, state = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)

    # Compute the MOA loss and record it on the policy so it can be reported.
    moa_loss = setup_moa_loss(logits, policy, train_batch)
    policy.moa_loss = moa_loss.total_loss

    # For RNN policies, mask out the padded timesteps; otherwise use all steps.
    if state:
        max_seq_len = tf.reduce_max(train_batch["seq_lens"])
        mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = tf.reshape(mask, [-1])
    else:
        mask = tf.ones_like(train_batch[Postprocessing.ADVANTAGES],
                            dtype=tf.bool)

    policy.loss_obj = PPOLoss(
        dist_class,
        model,
        train_batch[Postprocessing.VALUE_TARGETS],
        train_batch[Postprocessing.ADVANTAGES],
        train_batch[SampleBatch.ACTIONS],
        train_batch[SampleBatch.ACTION_DIST_INPUTS],
        train_batch[SampleBatch.ACTION_LOGP],
        train_batch[SampleBatch.VF_PREDS],
        action_dist,
        model.value_function(),
        policy.kl_coeff,
        mask,
        policy.entropy_coeff,
        policy.config["clip_param"],
        policy.config["vf_clip_param"],
        policy.config["vf_loss_coeff"],
        policy.config["use_gae"],
    )

    policy.loss_obj.loss += moa_loss.total_loss
    return policy.loss_obj.loss
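As another illustrative sketch only (same pre-2.0 RLlib assumption as above; moa_stats and MOAPPOTFPolicy are made-up names, not part of the original code): the MOA loss stored on the policy can be surfaced in the learner stats by merging it into PPO's default stats function.

from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy, kl_and_loss_stats

def moa_stats(policy, train_batch):
    # Keep the standard PPO stats (including "kl", which the KL-coefficient
    # update reads) and add the MOA loss recorded by loss_with_moa.
    stats = kl_and_loss_stats(policy, train_batch)
    stats["moa_loss"] = policy.moa_loss
    return stats

MOAPPOTFPolicy = PPOTFPolicy.with_updates(
    name="MOAPPOTFPolicy",
    loss_fn=loss_with_moa,
    stats_fn=moa_stats,
)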
Example #3
def tnb_loss(policy, model, dist_class, train_batch):
    """Add novelty loss with original ppo loss using TNB method"""
    logits, state = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)

    if state:
        max_seq_len = tf.reduce_max(train_batch["seq_lens"])
        mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = tf.reshape(mask, [-1])
    else:
        mask = tf.ones_like(
            train_batch[Postprocessing.ADVANTAGES], dtype=tf.bool
        )

    policy.loss_obj = PPOLoss(
        dist_class,
        model,
        train_batch[Postprocessing.VALUE_TARGETS],
        train_batch[Postprocessing.ADVANTAGES],
        train_batch[SampleBatch.ACTIONS],
        train_batch[BEHAVIOUR_LOGITS],
        train_batch["action_logp"],
        train_batch[SampleBatch.VF_PREDS],
        action_dist,
        model.value_function(),
        policy.kl_coeff,
        mask,
        entropy_coeff=policy.entropy_coeff,
        clip_param=policy.config["clip_param"],
        vf_clip_param=policy.config["vf_clip_param"],
        vf_loss_coeff=policy.config["vf_loss_coeff"],
        use_gae=policy.config["use_gae"],
    )

    # If a separate novelty value network is configured, fit it with a full
    # PPO-style loss on the novelty targets; otherwise use the policy-only
    # novelty loss.
    if policy.config["use_novelty_value_network"]:
        policy.novelty_loss_obj = PPOLoss(
            dist_class,
            model,
            train_batch[NOVELTY_VALUE_TARGETS],
            train_batch[NOVELTY_ADVANTAGES],
            train_batch[SampleBatch.ACTIONS],
            train_batch[BEHAVIOUR_LOGITS],
            train_batch["action_logp"],
            train_batch[NOVELTY_VALUES],
            action_dist,
            model.novelty_value_function(),
            policy.kl_coeff,
            mask,
            entropy_coeff=policy.entropy_coeff,
            clip_param=policy.config["clip_param"],
            vf_clip_param=policy.config["vf_clip_param"],
            vf_loss_coeff=policy.config["vf_loss_coeff"],
            use_gae=policy.config["use_gae"],
        )
    else:
        policy.novelty_loss_obj = PPOLossNovelty(
            dist_class,
            model,
            train_batch[NOVELTY_ADVANTAGES],
            train_batch[SampleBatch.ACTIONS],
            train_batch[BEHAVIOUR_LOGITS],
            train_batch["action_logp"],
            action_dist,
            policy.kl_coeff,
            mask,
            entropy_coeff=policy.entropy_coeff,
            clip_param=policy.config["clip_param"]
        )

    # Diagnostics: the mean novelty reward and the fraction of timesteps whose
    # novelty reward exceeds the TNB+ threshold.
    policy.novelty_reward_mean = tf.reduce_mean(train_batch[NOVELTY_REWARDS])
    policy.novelty_reward_ratio = tf.reduce_mean(
        tf.cast(
            train_batch[NOVELTY_REWARDS] > policy.config["tnb_plus_threshold"],
            tf.float32
        )
    )

    # Return the loss terms individually; the TNB method combines them
    # outside this function (e.g., when the gradients are computed).
    return [
        policy.loss_obj.loss, policy.novelty_loss_obj.loss,
        policy.novelty_reward_mean
    ]
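Since tnb_loss returns a list of terms rather than a single scalar, the surrounding policy presumably supplies a gradients function that fuses the per-loss gradients (the TNB step); that function is not shown in this excerpt. The sketch below only illustrates the registration and how the novelty diagnostics could be exposed, again assuming the pre-2.0 RLlib template API; tnb_stats, TNBPPOTFPolicy, and tnb_gradients are placeholder names.

from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy

def tnb_stats(policy, train_batch):
    # Surface the novelty diagnostics computed inside tnb_loss (merge with
    # kl_and_loss_stats, as in the previous sketch, to keep the default
    # PPO stats as well).
    return {
        "novelty_reward_mean": policy.novelty_reward_mean,
        "novelty_reward_ratio": policy.novelty_reward_ratio,
    }

TNBPPOTFPolicy = PPOTFPolicy.with_updates(
    name="TNBPPOTFPolicy",
    loss_fn=tnb_loss,
    stats_fn=tnb_stats,
    # gradients_fn=tnb_gradients,  # placeholder: the TNB gradient fusion
    # would be hooked in here.
)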