def build_q_losses(policy, model, dist_class, train_batch):
    # ... (loss construction elided in this excerpt; only the tail of the
    # function is shown). Record the TD loss in the policy's stats and give
    # the model a chance to add auxiliary losses.
    policy.q_loss.stats.update({"q_loss": policy.q_loss.loss})

    loss = policy.q_model.extra_loss(policy.q_loss.loss, train_batch,
                                     policy.q_loss.stats)

    return loss


def _compute_q_values(policy, model, obs, obs_space, action_space):
    # Run a forward pass; the custom model caches its Q-head outputs
    # internally, so the direct return value is not needed here.
    model({
        "obs": obs,
        "is_training": policy._get_is_training_placeholder(),
    }, [], None)

    # Read the cached Q-outputs back off the custom model.
    q_out = model.get_q_out()

    return q_out["value"], q_out["logits"], q_out["dist"]
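Note that get_q_out() is not part of RLlib's stock ModelV2 interface, so the model in use must define it itself. A minimal sketch of what such a model could look like, assuming RLlib's pre-1.0 TFModelV2 API; the class name, layer sizes, and the softmax "dist" entry are illustrative assumptions, while the "value"/"logits"/"dist" keys come from the snippet above:

import tensorflow as tf
from ray.rllib.models.tf.tf_modelv2 import TFModelV2


class LegalActionQModel(TFModelV2):
    """Hypothetical Q-model exposing the get_q_out() hook used above."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        inputs = tf.keras.layers.Input(shape=obs_space.shape)
        hidden = tf.keras.layers.Dense(256, activation="relu")(inputs)
        q_values = tf.keras.layers.Dense(action_space.n)(hidden)
        self.base_model = tf.keras.Model(inputs, q_values)
        self.register_variables(self.base_model.variables)
        self._q_out = None

    def forward(self, input_dict, state, seq_lens):
        q_values = self.base_model(input_dict["obs"])
        # Cache everything _compute_q_values() reads back after the call.
        self._q_out = {"value": q_values, "logits": q_values,
                       "dist": tf.nn.softmax(q_values)}
        return q_values, state

    def get_q_out(self):
        return self._q_out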


LegalActionDQNPolicy = DQNTFPolicy.with_updates(
    name="LegalActionDQNPolicy",
    action_sampler_fn=build_q_networks,
    loss_fn=build_q_losses)

LegalActionDQNTrainer = DQNTrainer.with_updates(
    name="LegalActionDQN", default_policy=LegalActionDQNPolicy)

LegalActionApexTrainer = LegalActionDQNTrainer.with_updates(
    name="LegalActionAPEX",
    default_config=APEX_DEFAULT_CONFIG,
    **APEX_TRAINER_PROPERTIES)
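A trainer built this way is launched like any stock RLlib trainer of the pre-1.0 with_updates era. A minimal sketch; the environment name is a placeholder, since in practice the env must expose the legal-action information the custom model expects:

import ray

ray.init()
trainer = LegalActionDQNTrainer(
    env="CartPole-v0",  # placeholder; a legal-action env is required in practice
    config={"num_workers": 0},
)
for _ in range(10):
    result = trainer.train()
    print(result["episode_reward_mean"])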
Example #2
def _compute_previous_round_actions(policy, restored, target_q_model,
                                    action_space):
    # NOTE: this example also begins mid-function; the name, signature, and
    # the left-hand side of this reshape are reconstructed assumptions. The
    # reshape flattens [batch, rounds, dim] into [batch * rounds, dim] so the
    # target Q-model can score every round in a single forward pass.
    previous_round_obs = tf.reshape(
        restored["previous_round_legal_actions"], [
            tf.shape(restored["previous_round_legal_actions"])[0] *
            restored["previous_round_legal_actions"].shape[1],
            restored["previous_round_legal_actions"].shape[2]
        ])
    # Score the flattened previous-round observations with the target network.
    target_q_model.forward(
        {
            "obs": previous_round_obs,
            "is_training": policy._get_is_training_placeholder()
        }, [], None)
    q_out = target_q_model.get_q_out()
    # Greedy (argmax) actions, encoded as one-hot vectors.
    previous_round = tf.one_hot(tf.argmax(q_out["value"], 1),
                                policy.action_space.n)
    # Restore the [batch, rounds, n_actions] layout.
    previous_round = tf.reshape(previous_round, [
        tf.shape(restored["previous_round"])[0],
        restored["previous_round"].shape[1], action_space.n
    ])
    return previous_round
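To make the tail of this helper concrete: the argmax/one_hot pair converts flattened per-round Q-values into greedy one-hot actions, and the final reshape restores the [batch, rounds, n_actions] layout. A standalone toy run, in TF1 graph mode to match the placeholder-based code above; all numbers are illustrative:

import tensorflow as tf

q_values = tf.constant([[0.1, 0.9, 0.0],   # 4 flattened rows = batch(2) * rounds(2)
                        [0.7, 0.2, 0.1],
                        [0.3, 0.3, 0.4],
                        [0.0, 1.0, 0.0]])  # n_actions = 3
greedy = tf.one_hot(tf.argmax(q_values, 1), 3)  # [[0,1,0],[1,0,0],[0,0,1],[0,1,0]]
per_round = tf.reshape(greedy, [2, 2, 3])       # back to [batch, rounds, n_actions]
with tf.Session() as sess:
    print(sess.run(per_round))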


TargetPolicyInferenceDQNPolicy = LegalActionDQNPolicy.with_updates(
    name="TargetPolicyInferenceDQNPolicy", loss_fn=build_q_losses)

TargetPolicyInferenceDQNTrainer = DQNTrainer.with_updates(
    name="TargetPolicyInferenceDQN",
    default_policy=TargetPolicyInferenceDQNPolicy)

TargetPolicyInferenceApexTrainer = TargetPolicyInferenceDQNTrainer.with_updates(
    name="TargetPolicyInferenceAPEX",
    default_config=APEX_DEFAULT_CONFIG,
    **APEX_TRAINER_PROPERTIES)
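The Ape-X variant is typically handed to Tune rather than stepped manually, e.g. as below; the env is again a placeholder, and num_workers and the stop criterion are illustrative:

import ray
from ray import tune

ray.init()
tune.run(
    TargetPolicyInferenceApexTrainer,
    stop={"training_iteration": 50},
    config={
        "env": "CartPole-v0",  # placeholder; see the note above
        "num_workers": 4,      # Ape-X relies on many parallel sampling workers
    },
)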