# NOTE: The import paths below follow RLlib's layout; exact module locations
# may vary slightly across Ray versions. `PRIO_WEIGHTS` and `_get_dist_class`
# are assumed to be defined elsewhere in the surrounding module(s).
from typing import List, Tuple, Type, Union

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.tf_utils import huber_loss, reduce_mean_ignore_inf
from ray.rllib.utils.typing import TensorType

tf1, tf, tfv = try_import_tf()


def compute_policy_logits(policy: Policy,
                          model: ModelV2,
                          obs: TensorType,
                          is_training=None) -> TensorType:
    """Runs a forward pass of `model` on `obs` and returns the action logits."""
    model_out, _ = model(
        {
            SampleBatch.CUR_OBS: obs,
            "is_training": is_training
            if is_training is not None else
            policy._get_is_training_placeholder(),
        }, [], None)

    return model_out
def get_distribution_inputs_and_class(
        policy: Policy,
        model: ModelV2,
        obs_batch: TensorType,
        *,
        explore: bool = True,
        **kwargs
) -> Tuple[TensorType, Type[TFActionDistribution], List[TensorType]]:
    """The action distribution function to be used by the algorithm.

    An action distribution function is used to customize the choice of action
    distribution class and the resulting action distribution inputs (to
    parameterize the distribution object).
    After parameterizing the distribution, a `sample()` call will be made on
    it to generate actions.

    Args:
        policy (Policy): The Policy being queried for actions and calling this
            function.
        model (SACTFModel): The SAC-specific model to use to generate the
            distribution inputs (see sac_tf|torch_model.py). Must support the
            `get_action_model_outputs` method.
        obs_batch (TensorType): The observations to be used as inputs to the
            model.
        explore (bool): Whether to activate exploration or not.

    Returns:
        Tuple[TensorType, Type[TFActionDistribution], List[TensorType]]: The
            dist inputs, dist class, and a list of internal state outputs
            (in the RNN case).
    """
    # Get base-model (forward) output (this should be a noop call).
    forward_out, state_out = model(
        SampleBatch(
            obs=obs_batch,
            _is_training=policy._get_is_training_placeholder()),
        [],
        None,
    )
    # Use the base output to get the policy outputs from the SAC model's
    # policy components.
    distribution_inputs, _ = model.get_action_model_outputs(forward_out)
    # Get a distribution class to be used with the just calculated dist-inputs.
    action_dist_class = _get_dist_class(policy, policy.config,
                                        policy.action_space)

    return distribution_inputs, action_dist_class, state_out
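
# --- Illustrative sketch (not RLlib source) --------------------------------
# A minimal sketch of the kind of mapping the `_get_dist_class` helper above
# performs: discrete action spaces get a Categorical distribution, continuous
# (Box) spaces a squashed Gaussian. The real helper also consults the policy
# config (e.g. for normalized actions); the name `_sketch_get_dist_class` is
# hypothetical, and `gym` is used here (newer Ray versions use `gymnasium`).
import gym

from ray.rllib.models.tf.tf_action_dist import Categorical, SquashedGaussian


def _sketch_get_dist_class(
        action_space: gym.Space) -> Type[TFActionDistribution]:
    if isinstance(action_space, gym.spaces.Discrete):
        return Categorical
    return SquashedGaussian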
def compute_q_values(policy: Policy,
                     model: ModelV2,
                     input_dict,
                     state_batches=None,
                     seq_lens=None,
                     explore=None,
                     is_training: bool = False):
    config = policy.config

    input_dict["is_training"] = policy._get_is_training_placeholder()
    model_out, state = model(input_dict, state_batches or [], seq_lens)

    # Distributional Q-learning: the model outputs per-atom logits over a
    # fixed support z in addition to the (expected) action scores.
    if config["num_atoms"] > 1:
        (action_scores, z, support_logits_per_action, logits,
         dist) = model.get_q_value_distributions(model_out)
    else:
        (action_scores, logits,
         dist) = model.get_q_value_distributions(model_out)

    # Dueling architecture: combine the state value with centered advantages.
    if config["dueling"]:
        state_score = model.get_state_value(model_out)
        if config["num_atoms"] > 1:
            support_logits_per_action_mean = tf.reduce_mean(
                support_logits_per_action, 1)
            support_logits_per_action_centered = (
                support_logits_per_action -
                tf.expand_dims(support_logits_per_action_mean, 1))
            support_logits_per_action = tf.expand_dims(
                state_score, 1) + support_logits_per_action_centered
            support_prob_per_action = tf.nn.softmax(
                logits=support_logits_per_action)
            value = tf.reduce_sum(
                input_tensor=z * support_prob_per_action, axis=-1)
            logits = support_logits_per_action
            dist = support_prob_per_action
        else:
            action_scores_mean = reduce_mean_ignore_inf(action_scores, 1)
            action_scores_centered = action_scores - tf.expand_dims(
                action_scores_mean, 1)
            value = state_score + action_scores_centered
    else:
        value = action_scores

    return value, logits, dist, state
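
# --- Illustrative sketch (not RLlib source) --------------------------------
# A self-contained sketch of the dueling aggregation in the
# non-distributional branch above: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).
# The real code uses `reduce_mean_ignore_inf` so that -inf scores from action
# masking do not distort the mean; plain `tf.reduce_mean` is used here for
# simplicity, and all shapes/values are made up for illustration.
def _demo_dueling_aggregation():
    import tensorflow as tf

    action_scores = tf.random.normal((2, 4))  # A(s, a), shape [B, num_actions].
    state_score = tf.random.normal((2, 1))    # V(s), shape [B, 1].

    # Center the advantages so that V(s) carries the state value.
    action_scores_mean = tf.reduce_mean(action_scores, axis=1, keepdims=True)
    q_values = state_score + (action_scores - action_scores_mean)
    return q_values  # Shape [B, num_actions].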
def compute_q_values(policy: Policy, model: ModelV2, obs: TensorType,
                     explore):
    config = policy.config

    model_out, state = model(
        {
            SampleBatch.CUR_OBS: obs,
            "is_training": policy._get_is_training_placeholder(),
        }, [], None)

    # Distributional vs. plain Q-value outputs (see the variant above).
    if config["num_atoms"] > 1:
        (action_scores, z, support_logits_per_action, logits,
         dist) = model.get_q_value_distributions(model_out)
    else:
        (action_scores, logits,
         dist) = model.get_q_value_distributions(model_out)

    if config["dueling"]:
        state_score = model.get_state_value(model_out)
        if config["num_atoms"] > 1:
            support_logits_per_action_mean = tf.reduce_mean(
                support_logits_per_action, 1)
            support_logits_per_action_centered = (
                support_logits_per_action -
                tf.expand_dims(support_logits_per_action_mean, 1))
            support_logits_per_action = tf.expand_dims(
                state_score, 1) + support_logits_per_action_centered
            support_prob_per_action = tf.nn.softmax(
                logits=support_logits_per_action)
            value = tf.reduce_sum(
                input_tensor=z * support_prob_per_action, axis=-1)
            logits = support_logits_per_action
            dist = support_prob_per_action
        else:
            action_scores_mean = reduce_mean_ignore_inf(action_scores, 1)
            action_scores_centered = action_scores - tf.expand_dims(
                action_scores_mean, 1)
            value = state_score + action_scores_centered
    else:
        value = action_scores

    return value, logits, dist
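
# --- Illustrative sketch (not RLlib source) --------------------------------
# A self-contained sketch of the distributional ("num_atoms" > 1, C51-style)
# branch above: per-action logits over a fixed support z are softmaxed into
# probabilities, and the Q-value is the expectation sum_j z_j * p_j(s, a).
# Shapes and support bounds here are made up for illustration.
def _demo_distributional_q():
    import tensorflow as tf

    num_atoms = 5
    z = tf.linspace(-10.0, 10.0, num_atoms)  # Support atoms, shape [atoms].
    # Per-action atom logits, shape [B, num_actions, atoms].
    support_logits_per_action = tf.random.normal((2, 3, num_atoms))

    support_prob_per_action = tf.nn.softmax(
        support_logits_per_action, axis=-1)
    # Expected value over the support -> Q-values, shape [B, num_actions].
    q_values = tf.reduce_sum(z * support_prob_per_action, axis=-1)
    return q_values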
def sac_actor_critic_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TFActionDistribution],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    # Get the base model output from the train batch.
    model_out_t, _ = model(
        {
            "obs": train_batch[SampleBatch.CUR_OBS],
            "is_training": policy._get_is_training_placeholder(),
        }, [], None)

    # Get the base model output from the next observations in the train batch.
    model_out_tp1, _ = model(
        {
            "obs": train_batch[SampleBatch.NEXT_OBS],
            "is_training": policy._get_is_training_placeholder(),
        }, [], None)

    # Get the target model's base outputs from the next observations in the
    # train batch.
    target_model_out_tp1, _ = policy.target_model(
        {
            "obs": train_batch[SampleBatch.NEXT_OBS],
            "is_training": policy._get_is_training_placeholder(),
        }, [], None)

    # Discrete actions case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        log_pis_t = tf.nn.log_softmax(
            model.get_policy_output(model_out_t), -1)
        policy_t = tf.math.exp(log_pis_t)
        log_pis_tp1 = tf.nn.log_softmax(
            model.get_policy_output(model_out_tp1), -1)
        policy_tp1 = tf.math.exp(log_pis_tp1)
        # Q-values.
        q_t = model.get_q_values(model_out_t)
        # Target Q-values.
        q_tp1 = policy.target_model.get_q_values(target_model_out_tp1)
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(model_out_t)
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1)
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)
        q_tp1 -= model.alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = tf.one_hot(
            train_batch[SampleBatch.ACTIONS], depth=q_t.shape.as_list()[-1])
        q_t_selected = tf.reduce_sum(q_t * one_hot, axis=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.reduce_sum(twin_q_t * one_hot, axis=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = tf.reduce_sum(tf.multiply(policy_tp1, q_tp1), axis=-1)
        q_tp1_best_masked = \
            (1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)) * \
            q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from the distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_t = action_dist_class(
            model.get_policy_output(model_out_t), policy.model)
        policy_t = action_dist_t.sample() if not deterministic else \
            action_dist_t.deterministic_sample()
        log_pis_t = tf.expand_dims(action_dist_t.logp(policy_t), -1)
        action_dist_tp1 = action_dist_class(
            model.get_policy_output(model_out_tp1), policy.model)
        policy_tp1 = action_dist_tp1.sample() if not deterministic else \
            action_dist_tp1.deterministic_sample()
        log_pis_tp1 = tf.expand_dims(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t = model.get_q_values(
            model_out_t,
            tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32))
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(
                model_out_t,
                tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32))

        # Q-values for the current policy in the given current state.
        q_t_det_policy = model.get_q_values(model_out_t, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy = model.get_twin_q_values(
                model_out_t, policy_t)
            q_t_det_policy = tf.reduce_min(
                (q_t_det_policy, twin_q_t_det_policy), axis=0)

        # Target Q-network evaluation.
        q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
                                                 policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1, policy_tp1)
            # Take min over both twin-NNs.
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)

        q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
        q_tp1 -= model.alpha * log_pis_tp1

        q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
        q_tp1_best_masked = (1.0 - tf.cast(train_batch[SampleBatch.DONES],
                                           tf.float32)) * q_tp1_best

    # Compute RHS of the Bellman equation for the Q-loss (critic(s)).
    q_t_selected_target = tf.stop_gradient(
        tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) +
        policy.config["gamma"]**policy.config["n_step"] * q_tp1_best_masked)

    # Compute the TD-error (potentially clipped).
    base_td_error = tf.math.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = tf.math.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    # Calculate one or two critic losses (2 in the twin_q case).
    prio_weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32)
    critic_loss = [tf.reduce_mean(prio_weights * huber_loss(base_td_error))]
    if policy.config["twin_q"]:
        critic_loss.append(
            tf.reduce_mean(prio_weights * huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly; here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        alpha_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    tf.stop_gradient(policy_t), -model.log_alpha *
                    tf.stop_gradient(log_pis_t + model.target_entropy)),
                axis=-1))
        actor_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    model.alpha * log_pis_t - tf.stop_gradient(q_t)),
                axis=-1))
    else:
        alpha_loss = -tf.reduce_mean(
            model.log_alpha *
            tf.stop_gradient(log_pis_t + model.target_entropy))
        actor_loss = tf.reduce_mean(model.alpha * log_pis_t - q_t_det_policy)

    # Save for stats function.
    policy.policy_t = policy_t
    policy.q_t = q_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.alpha_value = model.alpha
    policy.target_entropy = model.target_entropy

    # In a custom apply op we handle the losses separately, but return them
    # combined in one loss here.
    return actor_loss + tf.math.add_n(critic_loss) + alpha_loss
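
# --- Illustrative sketch (not RLlib source) --------------------------------
# A self-contained sketch of the soft Bellman target built in the loss above:
#   y = r + gamma^n_step * (1 - done) * (min(Q1', Q2') - alpha * log pi(a'|s'))
# All tensor values and hyperparameters here are made up for illustration.
def _demo_soft_bellman_target():
    import tensorflow as tf

    gamma, n_step, alpha = 0.99, 1, 0.2
    rewards = tf.constant([1.0, 0.0])
    dones = tf.constant([0.0, 1.0])
    q_tp1 = tf.constant([2.0, 3.0])        # Q-net on (s', a' ~ pi).
    twin_q_tp1 = tf.constant([2.5, 2.8])   # Twin Q-net on (s', a' ~ pi).
    log_pis_tp1 = tf.constant([-1.2, -0.7])

    # Min over the twin critics, then subtract the entropy term.
    q_tp1_soft = tf.minimum(q_tp1, twin_q_tp1) - alpha * log_pis_tp1
    # Mask out terminal transitions and stop gradients through the target.
    target = tf.stop_gradient(
        rewards + gamma**n_step * (1.0 - dones) * q_tp1_soft)
    return target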