def compute_q_values(
    policy: Policy,
    model: ModelV2,
    input_dict,
    state_batches=None,
    seq_lens=None,
    explore=None,
    is_training: bool = False,
):
    """PyTorch variant: compute Q-values from an input dict (plus optional RNN state/seq-lens)."""
    config = policy.config

    # Forward pass through the model to get the shared embedding (and any RNN state-outs).
    model_out, state = model(input_dict, state_batches or [], seq_lens)

    if config["num_atoms"] > 1:
        # Distributional Q-learning: also returns the support z and per-atom logits.
        (
            action_scores,
            z,
            support_logits_per_action,
            logits,
            probs_or_logits,
        ) = model.get_q_value_distributions(model_out)
    else:
        (action_scores, logits, probs_or_logits) = model.get_q_value_distributions(
            model_out
        )

    if config["dueling"]:
        state_score = model.get_state_value(model_out)
        if policy.config["num_atoms"] > 1:
            # Dueling + distributional: center the per-action atom logits over actions,
            # add the state value, then take the expectation over the support z.
            support_logits_per_action_mean = torch.mean(
                support_logits_per_action, dim=1
            )
            support_logits_per_action_centered = (
                support_logits_per_action
                - torch.unsqueeze(support_logits_per_action_mean, dim=1)
            )
            support_logits_per_action = (
                torch.unsqueeze(state_score, dim=1)
                + support_logits_per_action_centered
            )
            support_prob_per_action = nn.functional.softmax(
                support_logits_per_action, dim=-1
            )
            value = torch.sum(z * support_prob_per_action, dim=-1)
            logits = support_logits_per_action
            probs_or_logits = support_prob_per_action
        else:
            # Dueling: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).
            advantages_mean = reduce_mean_ignore_inf(action_scores, 1)
            advantages_centered = action_scores - torch.unsqueeze(advantages_mean, 1)
            value = state_score + advantages_centered
    else:
        value = action_scores

    return value, logits, probs_or_logits, state
def compute_q_values(
    policy: Policy,
    model: ModelV2,
    input_batch: SampleBatch,
    state_batches=None,
    seq_lens=None,
    explore=None,
    is_training: bool = False,
):
    """TF variant: compute Q-values from a SampleBatch (plus optional RNN state/seq-lens)."""
    config = policy.config

    # Forward pass through the model to get the shared embedding (and any RNN state-outs).
    model_out, state = model(input_batch, state_batches or [], seq_lens)

    if config["num_atoms"] > 1:
        # Distributional Q-learning: also returns the support z and per-atom logits.
        (
            action_scores,
            z,
            support_logits_per_action,
            logits,
            dist,
        ) = model.get_q_value_distributions(model_out)
    else:
        (action_scores, logits, dist) = model.get_q_value_distributions(model_out)

    if config["dueling"]:
        state_score = model.get_state_value(model_out)
        if config["num_atoms"] > 1:
            # Dueling + distributional: center the per-action atom logits over actions,
            # add the state value, then take the expectation over the support z.
            support_logits_per_action_mean = tf.reduce_mean(
                support_logits_per_action, 1
            )
            support_logits_per_action_centered = (
                support_logits_per_action
                - tf.expand_dims(support_logits_per_action_mean, 1)
            )
            support_logits_per_action = (
                tf.expand_dims(state_score, 1) + support_logits_per_action_centered
            )
            support_prob_per_action = tf.nn.softmax(logits=support_logits_per_action)
            value = tf.reduce_sum(input_tensor=z * support_prob_per_action, axis=-1)
            logits = support_logits_per_action
            dist = support_prob_per_action
        else:
            # Dueling: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).
            action_scores_mean = reduce_mean_ignore_inf(action_scores, 1)
            action_scores_centered = action_scores - tf.expand_dims(
                action_scores_mean, 1
            )
            value = state_score + action_scores_centered
    else:
        value = action_scores

    return value, logits, dist, state
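# The non-distributional dueling branch above relies on `reduce_mean_ignore_inf`, an
# RLlib helper. Assuming (from its name and how it is used here) that it averages only
# the finite entries along an axis -- so actions masked out with -inf do not distort
# the advantage mean -- a standalone PyTorch sketch of that aggregation could look like
# the following. `mean_ignore_neg_inf` is a hypothetical stand-in, not RLlib code.
import torch


def mean_ignore_neg_inf(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Mean over `dim`, counting only finite entries (masked actions are -inf)."""
    mask = torch.isfinite(x)
    safe = torch.where(mask, x, torch.zeros_like(x))
    return safe.sum(dim=dim) / mask.sum(dim=dim).clamp(min=1)


# Toy inputs: one sample, three actions, the third action masked out with -inf.
action_scores = torch.tensor([[1.0, 3.0, float("-inf")]])
state_score = torch.tensor([0.5])

advantages_mean = mean_ignore_neg_inf(action_scores, dim=1)  # tensor([2.])
q_values = state_score + (action_scores - advantages_mean.unsqueeze(1))
print(q_values)  # tensor([[-0.5000, 1.5000, -inf]])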
def compute_q_values(
    policy: Policy,
    model: ModelV2,
    obs: TensorType,
    explore,
    is_training: bool = False,
):
    """PyTorch variant (older API): compute Q-values directly from an observation tensor."""
    config = policy.config

    # Forward pass; the observation is wrapped into a minimal input dict.
    model_out, state = model(
        {
            SampleBatch.CUR_OBS: obs,
            "is_training": is_training,
        },
        [],
        None,
    )

    if config["num_atoms"] > 1:
        # Distributional Q-learning: also returns the support z and per-atom logits.
        (
            action_scores,
            z,
            support_logits_per_action,
            logits,
            probs_or_logits,
        ) = model.get_q_value_distributions(model_out)
    else:
        (action_scores, logits, probs_or_logits) = model.get_q_value_distributions(
            model_out
        )

    if config["dueling"]:
        state_score = model.get_state_value(model_out)
        if policy.config["num_atoms"] > 1:
            # Dueling + distributional: center the per-action atom logits over actions,
            # add the state value, then take the expectation over the support z.
            support_logits_per_action_mean = torch.mean(
                support_logits_per_action, dim=1
            )
            support_logits_per_action_centered = (
                support_logits_per_action
                - torch.unsqueeze(support_logits_per_action_mean, dim=1)
            )
            support_logits_per_action = (
                torch.unsqueeze(state_score, dim=1)
                + support_logits_per_action_centered
            )
            support_prob_per_action = nn.functional.softmax(
                support_logits_per_action, dim=-1
            )
            value = torch.sum(z * support_prob_per_action, dim=-1)
            logits = support_logits_per_action
            probs_or_logits = support_prob_per_action
        else:
            # Dueling: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).
            advantages_mean = reduce_mean_ignore_inf(action_scores, 1)
            advantages_centered = action_scores - torch.unsqueeze(advantages_mean, 1)
            value = state_score + advantages_centered
    else:
        value = action_scores

    return value, logits, probs_or_logits
def compute_q_values(policy: Policy, model: ModelV2, obs: TensorType, explore):
    """TF variant (older API): compute Q-values directly from an observation tensor."""
    config = policy.config

    # Forward pass; the observation is wrapped into a minimal input dict.
    model_out, state = model(
        {
            SampleBatch.CUR_OBS: obs,
            "is_training": policy._get_is_training_placeholder(),
        },
        [],
        None,
    )

    if config["num_atoms"] > 1:
        # Distributional Q-learning: also returns the support z and per-atom logits.
        (
            action_scores,
            z,
            support_logits_per_action,
            logits,
            dist,
        ) = model.get_q_value_distributions(model_out)
    else:
        (action_scores, logits, dist) = model.get_q_value_distributions(model_out)

    if config["dueling"]:
        state_score = model.get_state_value(model_out)
        if config["num_atoms"] > 1:
            # Dueling + distributional: center the per-action atom logits over actions,
            # add the state value, then take the expectation over the support z.
            support_logits_per_action_mean = tf.reduce_mean(
                support_logits_per_action, 1
            )
            support_logits_per_action_centered = (
                support_logits_per_action
                - tf.expand_dims(support_logits_per_action_mean, 1)
            )
            support_logits_per_action = (
                tf.expand_dims(state_score, 1) + support_logits_per_action_centered
            )
            support_prob_per_action = tf.nn.softmax(logits=support_logits_per_action)
            value = tf.reduce_sum(input_tensor=z * support_prob_per_action, axis=-1)
            logits = support_logits_per_action
            dist = support_prob_per_action
        else:
            # Dueling: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).
            action_scores_mean = reduce_mean_ignore_inf(action_scores, 1)
            action_scores_centered = action_scores - tf.expand_dims(
                action_scores_mean, 1
            )
            value = state_score + action_scores_centered
    else:
        value = action_scores

    return value, logits, dist
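# All four variants share the same distributional dueling aggregation: center the
# per-action atom logits over actions, add the state value, softmax over atoms, and
# take the expectation over the support z. A self-contained PyTorch sketch of just
# that step, with made-up toy shapes -- an illustration of the math, not RLlib code.
import torch
import torch.nn.functional as F

batch, num_actions, num_atoms = 2, 4, 51

# Stand-ins for the model outputs (get_q_value_distributions / get_state_value).
support_logits_per_action = torch.randn(batch, num_actions, num_atoms)
state_score = torch.randn(batch, num_atoms)
z = torch.linspace(-10.0, 10.0, num_atoms)  # support of the return distribution

centered = support_logits_per_action - support_logits_per_action.mean(
    dim=1, keepdim=True
)
combined_logits = state_score.unsqueeze(1) + centered
probs = F.softmax(combined_logits, dim=-1)  # per-atom probabilities per action
q_values = torch.sum(z * probs, dim=-1)     # expected Q-values, shape (batch, num_actions)
print(q_values.shape)  # torch.Size([2, 4])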