import tensorflow as tf

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.policy.policy import Policy
# Helper that averages over an axis while ignoring -inf entries; its module
# has moved between RLlib releases (ray.rllib.utils.tf_ops in older ones).
from ray.rllib.utils.tf_ops import reduce_mean_ignore_inf


def compute_q_values(policy: Policy,
                     model: ModelV2,
                     input_dict,
                     state_batches=None,
                     seq_lens=None,
                     explore=None,
                     is_training: bool = False):
    config = policy.config

    input_dict["is_training"] = policy._get_is_training_placeholder()
    model_out, state = model(input_dict, state_batches or [], seq_lens)

    if config["num_atoms"] > 1:
        # Distributional head: per-action logits over the `num_atoms`
        # support values `z`, plus the usual action scores.
        (action_scores, z, support_logits_per_action, logits,
         dist) = model.get_q_value_distributions(model_out)
    else:
        (action_scores, logits,
         dist) = model.get_q_value_distributions(model_out)

    if config["dueling"]:
        state_score = model.get_state_value(model_out)
        if config["num_atoms"] > 1:
            # Dueling + distributional: center the per-action support
            # logits around their mean, then add the state value.
            support_logits_per_action_mean = tf.reduce_mean(
                support_logits_per_action, 1)
            support_logits_per_action_centered = (
                support_logits_per_action -
                tf.expand_dims(support_logits_per_action_mean, 1))
            support_logits_per_action = tf.expand_dims(
                state_score, 1) + support_logits_per_action_centered
            support_prob_per_action = tf.nn.softmax(
                logits=support_logits_per_action)
            # Expected Q-value per action: sum over atoms of z * p(z).
            value = tf.reduce_sum(
                input_tensor=z * support_prob_per_action, axis=-1)
            logits = support_logits_per_action
            dist = support_prob_per_action
        else:
            # Plain dueling: Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a').
            action_scores_mean = reduce_mean_ignore_inf(action_scores, 1)
            action_scores_centered = action_scores - tf.expand_dims(
                action_scores_mean, 1)
            value = state_score + action_scores_centered
    else:
        value = action_scores

    return value, logits, dist, state
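
# A minimal, self-contained sketch of the two aggregation paths in
# compute_q_values above, on toy tensors. The shapes (batch=2,
# num_actions=3, num_atoms=5) and the support range are made-up
# assumptions for illustration; in the real code the support `z` comes
# from the model's distributional head.
def _dueling_aggregation_sketch():
    batch, num_actions, num_atoms = 2, 3, 5

    # Plain dueling: Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a').
    advantages = tf.random.normal([batch, num_actions])
    state_value = tf.random.normal([batch])
    q_values = tf.expand_dims(state_value, 1) + (
        advantages - tf.reduce_mean(advantages, 1, keepdims=True))

    # Distributional (num_atoms > 1): the expected Q-value per action is
    # the probability-weighted sum over the support atoms z.
    support_logits = tf.random.normal([batch, num_actions, num_atoms])
    z = tf.linspace(-10.0, 10.0, num_atoms)
    probs = tf.nn.softmax(support_logits, axis=-1)
    expected_q = tf.reduce_sum(z * probs, axis=-1)

    return q_values, expected_q  # both [batch, num_actions]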
def _compute_q_values(policy, model, obs, neighbor_obs, obs_space,
                      action_space):
    config = policy.config

    model_out, state = model(
        {
            "obs": obs,
            "neighbor_obs": neighbor_obs,
            "is_training": policy._get_is_training_placeholder(),
        }, [], None)

    if config["num_atoms"] > 1:
        (action_scores, z, support_logits_per_action, logits,
         dist) = model.get_q_value_distributions(model_out)
    else:
        (action_scores, logits,
         dist) = model.get_q_value_distributions(model_out)

    if config["dueling"]:
        state_score = model.get_state_value(model_out)
        if config["num_atoms"] > 1:
            support_logits_per_action_mean = tf.reduce_mean(
                support_logits_per_action, 1)
            support_logits_per_action_centered = (
                support_logits_per_action -
                tf.expand_dims(support_logits_per_action_mean, 1))
            support_logits_per_action = tf.expand_dims(
                state_score, 1) + support_logits_per_action_centered
            support_prob_per_action = tf.nn.softmax(
                logits=support_logits_per_action)
            value = tf.reduce_sum(
                input_tensor=z * support_prob_per_action, axis=-1)
            logits = support_logits_per_action
            dist = support_prob_per_action
        else:
            action_scores_mean = reduce_mean_ignore_inf(action_scores, 1)
            action_scores_centered = action_scores - tf.expand_dims(
                action_scores_mean, 1)
            value = state_score + action_scores_centered
    else:
        value = action_scores

    return value, logits, dist
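
# Hypothetical call site for _compute_q_values above; the tensor shapes
# below are assumptions, not part of the original code. The point is that
# this variant feeds the neighbors' observations alongside the agent's
# own and runs the model statelessly ([] state, no seq_lens).
def _neighbor_q_sketch(policy, model):
    # Made-up shapes: batch of 32 agents, each seeing 4 neighbors.
    obs = tf.zeros([32, 16])
    neighbor_obs = tf.zeros([32, 4, 16])
    q_t, q_logits_t, q_dist_t = _compute_q_values(
        policy, model, obs, neighbor_obs,
        policy.observation_space, policy.action_space)
    # Greedy action selection over the aggregated Q-values.
    return tf.argmax(q_t, axis=1)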
def calculate_and_store_q(self, input_dict, model_out):
    if self.q_config["num_atoms"] > 1:
        (action_scores, z, support_logits_per_action, logits,
         dist) = self.get_q_value_distributions(model_out)
    else:
        (action_scores, logits,
         dist) = self.get_q_value_distributions(model_out)

    if self.q_config["dueling"]:
        state_score = self.get_state_value(model_out)
        if self.q_config["num_atoms"] > 1:
            support_logits_per_action_mean = tf.reduce_mean(
                support_logits_per_action, 1)
            support_logits_per_action_centered = (
                support_logits_per_action -
                tf.expand_dims(support_logits_per_action_mean, 1))
            support_logits_per_action = tf.expand_dims(
                state_score, 1) + support_logits_per_action_centered
            support_prob_per_action = tf.nn.softmax(
                logits=support_logits_per_action)
            value = tf.reduce_sum(
                input_tensor=z * support_prob_per_action, axis=-1)
            logits = support_logits_per_action
            dist = support_prob_per_action
        else:
            action_scores_mean = reduce_mean_ignore_inf(action_scores, 1)
            action_scores_centered = action_scores - tf.expand_dims(
                action_scores_mean, 1)
            value = state_score + action_scores_centered
    else:
        value = action_scores

    # Mask out invalid actions: log(1) = 0 leaves legal actions unchanged,
    # while log(0) = -inf is clamped to tf.float32.min for stability.
    inf_mask = tf.cast(
        tf.maximum(
            tf.math.log(input_dict["obs"]["legal_actions"]),
            tf.float32.min),
        dtype=tf.float32)
    value = value + inf_mask

    self.q_out = {"value": value, "logits": logits, "dist": dist}
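
# Why the log-based mask in calculate_and_store_q works: with a 0/1
# legal-actions vector, log(1) = 0 leaves legal Q-values untouched, and
# log(0) = -inf, clamped to tf.float32.min to avoid NaNs in downstream
# arithmetic, pushes illegal actions below any attainable Q-value. The
# numbers below are made up for illustration.
def _legal_action_mask_sketch():
    q = tf.constant([[1.0, 2.0, 3.0]])
    legal = tf.constant([[1.0, 0.0, 1.0]])  # action 1 is illegal
    mask = tf.maximum(tf.math.log(legal), tf.float32.min)
    masked_q = q + mask
    # argmax can now never select the illegal action.
    return tf.argmax(masked_q, axis=1)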