Example #1
# Imports assumed by this snippet (exact module paths vary between Ray
# versions; reduce_mean_ignore_inf lives in ray.rllib.utils.tf_ops in older
# releases and ray.rllib.utils.tf_utils in newer ones).
import tensorflow as tf

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.tf_ops import reduce_mean_ignore_inf


def compute_q_values(policy: Policy,
                     model: ModelV2,
                     input_dict,
                     state_batches=None,
                     seq_lens=None,
                     explore=None,
                     is_training: bool = False):

    config = policy.config

    input_dict["is_training"] = policy._get_is_training_placeholder()
    model_out, state = model(input_dict, state_batches or [], seq_lens)

    if config["num_atoms"] > 1:
        (action_scores, z, support_logits_per_action, logits,
         dist) = model.get_q_value_distributions(model_out)
    else:
        (action_scores, logits,
         dist) = model.get_q_value_distributions(model_out)

    if config["dueling"]:
        state_score = model.get_state_value(model_out)
        if config["num_atoms"] > 1:
            support_logits_per_action_mean = tf.reduce_mean(
                support_logits_per_action, 1)
            support_logits_per_action_centered = (
                support_logits_per_action -
                tf.expand_dims(support_logits_per_action_mean, 1))
            support_logits_per_action = tf.expand_dims(
                state_score, 1) + support_logits_per_action_centered
            support_prob_per_action = tf.nn.softmax(
                logits=support_logits_per_action)
            value = tf.reduce_sum(input_tensor=z * support_prob_per_action,
                                  axis=-1)
            logits = support_logits_per_action
            dist = support_prob_per_action
        else:
            # reduce_mean_ignore_inf skips -inf entries (e.g. masked-out
            # actions) when averaging the advantages.
            action_scores_mean = reduce_mean_ignore_inf(action_scores, 1)
            action_scores_centered = action_scores - tf.expand_dims(
                action_scores_mean, 1)
            value = state_score + action_scores_centered
    else:
        value = action_scores

    return value, logits, dist, state
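
The non-distributional dueling branch above reduces to
Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)). The standalone sketch below
(toy numbers, and plain tf.reduce_mean in place of RLlib's
reduce_mean_ignore_inf) shows that aggregation in isolation:

import tensorflow as tf

# Toy batch of 2 states with 3 actions each (illustrative values only).
advantages = tf.constant([[1.0, 2.0, 3.0],
                          [0.5, 0.5, 2.0]])     # A(s, a)
state_value = tf.constant([[10.0], [5.0]])      # V(s), shape [batch, 1]

# Center the advantages so their per-state mean is zero, then add V(s).
centered = advantages - tf.reduce_mean(advantages, axis=1, keepdims=True)
q_values = state_value + centered

print(q_values.numpy())
# [[ 9.  10.  11. ]
#  [ 4.5  4.5  6. ]]

Centering makes the decomposition identifiable: adding a constant to every
advantage no longer changes the resulting Q-values.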
Example #2
def _compute_q_values(policy, model, obs, neighbor_obs, obs_space,
                      action_space):
    config = policy.config
    # Same forward pass as in the previous example, except the model also
    # receives the observations of neighboring agents.
    model_out, state = model(
        {
            "obs": obs,
            "neighbor_obs": neighbor_obs,
            "is_training": policy._get_is_training_placeholder(),
        }, [], None)

    if config["num_atoms"] > 1:
        (action_scores, z, support_logits_per_action, logits,
         dist) = model.get_q_value_distributions(model_out)
    else:
        (action_scores, logits,
         dist) = model.get_q_value_distributions(model_out)

    if config["dueling"]:
        state_score = model.get_state_value(model_out)
        if config["num_atoms"] > 1:
            support_logits_per_action_mean = tf.reduce_mean(
                support_logits_per_action, 1)
            support_logits_per_action_centered = (
                support_logits_per_action -
                tf.expand_dims(support_logits_per_action_mean, 1))
            support_logits_per_action = tf.expand_dims(
                state_score, 1) + support_logits_per_action_centered
            support_prob_per_action = tf.nn.softmax(
                logits=support_logits_per_action)
            value = tf.reduce_sum(input_tensor=z * support_prob_per_action,
                                  axis=-1)
            logits = support_logits_per_action
            dist = support_prob_per_action
        else:
            action_scores_mean = reduce_mean_ignore_inf(action_scores, 1)
            action_scores_centered = action_scores - tf.expand_dims(
                action_scores_mean, 1)
            value = state_score + action_scores_centered
    else:
        value = action_scores

    return value, logits, dist
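
A hypothetical call site for the function above (the batch variables and the
greedy selection below are illustrative assumptions, not taken from the
original code):

# obs_batch / neighbor_obs_batch are placeholder names for the input batches.
q_values, q_logits, q_dist = _compute_q_values(
    policy, policy.model, obs_batch, neighbor_obs_batch,
    policy.observation_space, policy.action_space)

# Greedy action selection: pick the action with the highest Q-value per row.
greedy_actions = tf.argmax(q_values, axis=1)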
Example #3
def calculate_and_store_q(self, input_dict, model_out):
    if self.q_config["num_atoms"] > 1:
        (action_scores, z, support_logits_per_action, logits,
         dist) = self.get_q_value_distributions(model_out)
    else:
        (action_scores, logits,
         dist) = self.get_q_value_distributions(model_out)
    if self.q_config["dueling"]:
        state_score = self.get_state_value(model_out)
        if self.q_config["num_atoms"] > 1:
            support_logits_per_action_mean = tf.reduce_mean(
                support_logits_per_action, 1)
            support_logits_per_action_centered = (
                support_logits_per_action -
                tf.expand_dims(support_logits_per_action_mean, 1))
            support_logits_per_action = tf.expand_dims(
                state_score, 1) + support_logits_per_action_centered
            support_prob_per_action = tf.nn.softmax(
                logits=support_logits_per_action)
            value = tf.reduce_sum(input_tensor=z * support_prob_per_action,
                                  axis=-1)
            logits = support_logits_per_action
            dist = support_prob_per_action
        else:
            action_scores_mean = reduce_mean_ignore_inf(action_scores, 1)
            action_scores_centered = action_scores - tf.expand_dims(
                action_scores_mean, 1)
            value = state_score + action_scores_centered
    else:
        value = action_scores
    # Mask out invalid actions (use tf.float32.min for stability).
    # "legal_actions" is a 0/1 mask: log(1) = 0 leaves legal actions
    # untouched, while log(0) = -inf is clamped to tf.float32.min so masked
    # actions can never be selected. Note: tf.log is the TF 1.x spelling;
    # under TF 2.x use tf.math.log.
    inf_mask = tf.cast(tf.maximum(
        tf.log(input_dict["obs"]["legal_actions"]), tf.float32.min),
                       dtype=tf.float32)
    value = value + inf_mask
    self.q_out = {"value": value, "logits": logits, "dist": dist}
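
The legal-action masking at the end of the last example can be exercised on
its own: the log of a 0/1 mask is 0 for legal actions and -inf for illegal
ones, and clamping at tf.float32.min keeps the arithmetic finite while still
making masked actions unselectable. A minimal sketch under TF 2.x (toy
values; tf.math.log is the TF 2.x spelling of the tf.log call above):

import tensorflow as tf

q_values = tf.constant([[1.0, 5.0, 3.0]])       # toy Q-values for 3 actions
legal_actions = tf.constant([[1.0, 0.0, 1.0]])  # action 1 is illegal

# log(1) = 0 leaves legal actions unchanged; log(0) = -inf is clamped to
# tf.float32.min, so the sum stays finite but can never win an argmax.
inf_mask = tf.maximum(tf.math.log(legal_actions), tf.float32.min)
masked_q = q_values + inf_mask

print(tf.argmax(masked_q, axis=1).numpy())  # -> [2]; the high raw Q-value of
                                            #    illegal action 1 is masked out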