def action_distribution_fn(
    policy: Policy,
    model: ModelV2,
    input_dict: ModelInputDict,
    *,
    state_batches: Optional[List[TensorType]] = None,
    seq_lens: Optional[TensorType] = None,
    prev_action_batch: Optional[TensorType] = None,
    prev_reward_batch=None,
    explore: Optional[bool] = None,
    timestep: Optional[int] = None,
    is_training: Optional[bool] = None
) -> Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]]:
    """Builds the action-distribution inputs and class for SAC (torch).

    Runs the base model on `input_dict`, pushes the resulting features
    through the SAC policy head to obtain distribution parameters, and
    selects a distribution class matching the policy's action space. The
    returned inputs parameterize that distribution; `sample()` is then
    called on it to generate actions.

    Args:
        policy: The Policy being queried for actions (calls this function).
        model: The SAC-specific model; must support the
            `get_action_model_outputs` method.
        input_dict: The input-dict to be used for the model call.
        state_batches: Optional list of internal-state tensor batches.
        seq_lens: Optional tensor of RNN sequence lengths.
        prev_action_batch: Optional batch of previous actions.
        prev_reward_batch: Optional batch of previous rewards.
        explore: Whether to activate exploration. If None, use the value
            of `config.explore`.
        timestep: An optional timestep.
        is_training: An optional is-training flag.

    Returns:
        Tuple of (distribution inputs, distribution class, state-outs).
        The state-out list is always empty here (no RNN state).
    """
    # Base network forward pass; the SAC-specific heads are applied below.
    base_out, _ = model(input_dict, [], None)

    # Policy head of the SAC model -> distribution parameters.
    dist_inputs, _ = model.get_action_model_outputs(base_out)

    # Pick the distribution class fitting the action space.
    dist_cls = _get_dist_class(policy, policy.config, policy.action_space)

    return dist_inputs, dist_cls, []
def get_distribution_inputs_and_class(
    policy: Policy,
    model: ModelV2,
    obs_batch: TensorType,
    *,
    explore: bool = True,
    **kwargs
) -> Tuple[TensorType, Type[TFActionDistribution], List[TensorType]]:
    """Builds the action-distribution inputs and class for SAC (TF).

    Feeds the observation batch through the base model, then through the
    SAC policy head to obtain the distribution parameters, and selects a
    distribution class matching the policy's action space. The returned
    inputs parameterize that distribution; `sample()` is then called on
    it to generate actions.

    Args:
        policy: The Policy being queried for actions (calls this function).
        model: The SAC-specific model; must support the
            `get_action_model_outputs` method.
        obs_batch: The observations to be used as inputs to the model.
        explore: Whether to activate exploration or not.

    Returns:
        Tuple of (distribution inputs, distribution class, state-outs from
        the model forward pass).
    """
    batch = SampleBatch(
        obs=obs_batch, _is_training=policy._get_is_training_placeholder())

    # Base-model forward pass (a noop for the plain SAC network).
    base_out, state_out = model(batch, [], None)

    # Policy head of the SAC model -> distribution parameters.
    dist_inputs, _ = model.get_action_model_outputs(base_out)

    # Pick the distribution class fitting the action space.
    dist_cls = _get_dist_class(policy, policy.config, policy.action_space)

    return dist_inputs, dist_cls, state_out
def actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic (torch).

    Computes one actor loss, one (or two, with twin-Q) critic losses, and
    one entropy-coefficient (alpha) loss, each matched to its own
    optimizer via the returned tuple.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        dist_class: The action distribution class (only used implicitly;
            the continuous branch re-derives it via `_get_dist_class`).
        train_batch: The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A tuple of loss tensors:
            (actor_loss, critic_loss[, twin_critic_loss], alpha_loss).
    """
    # Look up the target model (tower) using the model tower.
    target_model = policy.target_models[model]

    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    # Base-model forward pass on the current observations (t).
    model_out_t, _ = model(
        SampleBatch(obs=train_batch[SampleBatch.CUR_OBS], _is_training=True),
        [], None)

    # Base-model forward pass on the next observations (t+1).
    model_out_tp1, _ = model(
        SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True),
        [], None)

    # Target network's features on t+1 obs (used for the Bellman target).
    target_model_out_tp1, _ = target_model(
        SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True),
        [], None)

    # Entropy coefficient; stored in log space for unconstrained updates.
    alpha = torch.exp(model.log_alpha)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
        log_pis_t = F.log_softmax(action_dist_inputs_t, dim=-1)
        policy_t = torch.exp(log_pis_t)
        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1)
        log_pis_tp1 = F.log_softmax(action_dist_inputs_tp1, -1)
        policy_tp1 = torch.exp(log_pis_tp1)
        # Q-values.
        q_t, _ = model.get_q_values(model_out_t)
        # Target Q-values.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1)
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(model_out_t)
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1)
            # Clipped double-Q: pessimistic (elementwise min) target.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        # Soft Q-target: subtract the entropy bonus term.
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(
            train_batch[SampleBatch.ACTIONS].long(),
            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        # Zero out targets on terminal transitions.
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
        action_dist_t = action_dist_class(action_dist_inputs_t, model)
        policy_t = (action_dist_t.sample() if not deterministic else
                    action_dist_t.deterministic_sample())
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)
        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1)
        action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
        policy_tp1 = (action_dist_tp1.sample() if not deterministic else
                      action_dist_tp1.deterministic_sample())
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t, _ = model.get_q_values(model_out_t,
                                    train_batch[SampleBatch.ACTIONS])
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(
                model_out_t, train_batch[SampleBatch.ACTIONS])

        # Q-values for current policy in given current state.
        q_t_det_policy, _ = model.get_q_values(model_out_t, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy, _ = model.get_twin_q_values(
                model_out_t, policy_t)
            # Pessimistic estimate for the actor objective.
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target q network evaluation.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1, policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1, policy_tp1)
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        # Drop the trailing singleton dim of the Q-net outputs.
        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        # Soft Q-target: subtract the entropy bonus term.
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        # Zero out targets on terminal transitions.
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # compute RHS of bellman equation
    # (detached: the target must not propagate gradients into the nets).
    q_t_selected_target = (
        train_batch[SampleBatch.REWARDS] +
        (policy.config["gamma"]**policy.config["n_step"]) *
        q_tp1_best_masked).detach()

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        # Average the two critics' TD errors for prioritized replay.
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    # Importance-weighted Huber loss per critic (prioritized replay).
    critic_loss = [
        torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = torch.mean(torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = torch.mean(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach(),
                ),
                dim=-1,
            ))
    else:
        alpha_loss = -torch.mean(model.log_alpha *
                                 (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c is depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = torch.mean(alpha.detach() * log_pis_t - q_t_det_policy)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t
    model.tower_stats["policy_t"] = policy_t
    model.tower_stats["log_pis_t"] = log_pis_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss

    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss])
def sac_actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TFActionDistribution],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic (TF).

    Computes the actor loss, one (or two, with twin-Q) critic losses, and
    the entropy-coefficient (alpha) loss, and returns them summed into a
    single tensor (a custom apply op handles them separately).

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        dist_class: The action distribution class (only used implicitly;
            the continuous branch re-derives it via `_get_dist_class`).
        train_batch: The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single combined loss tensor
            (actor + critic(s) + alpha).
    """
    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]
    _is_training = policy._get_is_training_placeholder()

    # Get the base model output from the train batch.
    model_out_t, _ = model(
        SampleBatch(obs=train_batch[SampleBatch.CUR_OBS],
                    _is_training=_is_training), [], None,
    )

    # Get the base model output from the next observations in the train batch.
    model_out_tp1, _ = model(
        SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS],
                    _is_training=_is_training), [], None,
    )

    # Get the target model's base outputs from the next observations in the
    # train batch.
    target_model_out_tp1, _ = policy.target_model(
        SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS],
                    _is_training=_is_training), [], None,
    )

    # Discrete actions case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
        log_pis_t = tf.nn.log_softmax(action_dist_inputs_t, -1)
        policy_t = tf.math.exp(log_pis_t)
        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1)
        log_pis_tp1 = tf.nn.log_softmax(action_dist_inputs_tp1, -1)
        policy_tp1 = tf.math.exp(log_pis_tp1)
        # Q-values.
        q_t, _ = model.get_q_values(model_out_t)
        # Target Q-values.
        q_tp1, _ = policy.target_model.get_q_values(target_model_out_tp1)
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(model_out_t)
            twin_q_tp1, _ = policy.target_model.get_twin_q_values(
                target_model_out_tp1)
            # Clipped double-Q: pessimistic (elementwise min) target.
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)
        # Soft Q-target: subtract the entropy bonus term.
        q_tp1 -= model.alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = tf.one_hot(
            train_batch[SampleBatch.ACTIONS],
            depth=q_t.shape.as_list()[-1])
        q_t_selected = tf.reduce_sum(q_t * one_hot, axis=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.reduce_sum(twin_q_t * one_hot, axis=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = tf.reduce_sum(tf.multiply(policy_tp1, q_tp1), axis=-1)
        # Zero out targets on terminal transitions.
        q_tp1_best_masked = (1.0 - tf.cast(train_batch[SampleBatch.DONES],
                                           tf.float32)) * q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
        action_dist_t = action_dist_class(action_dist_inputs_t, policy.model)
        policy_t = (action_dist_t.sample() if not deterministic else
                    action_dist_t.deterministic_sample())
        log_pis_t = tf.expand_dims(action_dist_t.logp(policy_t), -1)
        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1)
        action_dist_tp1 = action_dist_class(action_dist_inputs_tp1,
                                            policy.model)
        policy_tp1 = (action_dist_tp1.sample() if not deterministic else
                      action_dist_tp1.deterministic_sample())
        log_pis_tp1 = tf.expand_dims(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t, _ = model.get_q_values(
            model_out_t, tf.cast(train_batch[SampleBatch.ACTIONS],
                                 tf.float32))
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(
                model_out_t,
                tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32))

        # Q-values for current policy in given current state.
        q_t_det_policy, _ = model.get_q_values(model_out_t, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy, _ = model.get_twin_q_values(
                model_out_t, policy_t)
            # Pessimistic estimate for the actor objective.
            q_t_det_policy = tf.reduce_min(
                (q_t_det_policy, twin_q_t_det_policy), axis=0)

        # target q network evaluation
        q_tp1, _ = policy.target_model.get_q_values(target_model_out_tp1,
                                                    policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1, _ = policy.target_model.get_twin_q_values(
                target_model_out_tp1, policy_tp1)
            # Take min over both twin-NNs.
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)

        # Drop the trailing singleton dim of the Q-net outputs.
        # NOTE(review): the twin squeeze reuses `q_t.shape` (not
        # `twin_q_t.shape`); harmless since both Q-heads share the same
        # output shape, but worth confirming.
        q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
        # Soft Q-target: subtract the entropy bonus term.
        q_tp1 -= model.alpha * log_pis_tp1

        q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
        # Zero out targets on terminal transitions.
        q_tp1_best_masked = (1.0 - tf.cast(train_batch[SampleBatch.DONES],
                                           tf.float32)) * q_tp1_best

    # Compute RHS of bellman equation for the Q-loss (critic(s)).
    # (stop_gradient: the target must not train the nets.)
    q_t_selected_target = tf.stop_gradient(
        tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) +
        policy.config["gamma"]**policy.config["n_step"] * q_tp1_best_masked)

    # Compute the TD-error (potentially clipped).
    base_td_error = tf.math.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = tf.math.abs(twin_q_t_selected - q_t_selected_target)
        # Average the two critics' TD errors for prioritized replay.
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    # Calculate one or two critic losses (2 in the twin_q case).
    prio_weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32)
    critic_loss = [tf.reduce_mean(prio_weights * huber_loss(base_td_error))]
    if policy.config["twin_q"]:
        critic_loss.append(
            tf.reduce_mean(prio_weights * huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        alpha_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    tf.stop_gradient(policy_t), -model.log_alpha *
                    tf.stop_gradient(log_pis_t + model.target_entropy),
                ),
                axis=-1,
            ))
        actor_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    model.alpha * log_pis_t - tf.stop_gradient(q_t),
                ),
                axis=-1,
            ))
    else:
        alpha_loss = -tf.reduce_mean(
            model.log_alpha *
            tf.stop_gradient(log_pis_t + model.target_entropy))
        # q_t_det_policy is NOT stopped: the actor gradient flows through
        # the Q-net into the policy variables.
        actor_loss = tf.reduce_mean(model.alpha * log_pis_t - q_t_det_policy)

    # Save for stats function.
    policy.policy_t = policy_t
    policy.q_t = q_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.alpha_value = model.alpha
    policy.target_entropy = model.target_entropy

    # In a custom apply op we handle the losses separately, but return them
    # combined in one loss here.
    return actor_loss + tf.math.add_n(critic_loss) + alpha_loss
def cql_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TFActionDistribution],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the Conservative Q-Learning (CQL) loss (TF).

    Builds on the SAC losses (alpha, actor, critic) and adds the CQL
    conservative penalty (entropy/logsumexp variant) to the critic
    loss(es), optionally with a Lagrangian alpha-prime multiplier.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        dist_class: The action distribution class (only used implicitly;
            re-derived via `_get_dist_class`).
        train_batch: The (offline) training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single combined loss tensor
            (actor + critic(s) + alpha [+ alpha_prime]).
    """
    logger.info(f"Current iteration = {policy.cur_iter}")
    policy.cur_iter += 1

    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32)
    rewards = tf.cast(train_batch[SampleBatch.REWARDS], tf.float32)
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    # Base-model forward passes for t, t+1, and the target net at t+1.
    model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None)

    model_out_tp1, _ = model(
        SampleBatch(obs=next_obs, _is_training=True), [], None)

    target_model_out_tp1, _ = policy.target_model(
        SampleBatch(obs=next_obs, _is_training=True), [], None)

    # Current-policy action distribution and a sampled action + logp.
    action_dist_class = _get_dist_class(policy, policy.config,
                                        policy.action_space)
    action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
    action_dist_t = action_dist_class(action_dist_inputs_t, model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = tf.expand_dims(log_pis_t, -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -tf.reduce_mean(
        model.log_alpha * tf.stop_gradient(log_pis_t + model.target_entropy))

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = tf.math.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        # Past the BC warm-up: standard SAC actor loss against min-Q.
        min_q, _ = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_, _ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = tf.math.minimum(min_q, twin_q_)
        actor_loss = tf.reduce_mean(
            tf.stop_gradient(alpha) * log_pis_t - min_q)
    else:
        # Behavior-cloning warm-up: maximize logp of the batch actions.
        bc_logp = action_dist_t.logp(actions)
        actor_loss = tf.reduce_mean(
            tf.stop_gradient(alpha) * log_pis_t - bc_logp)
        # actor_loss = -tf.reduce_mean(bc_logp)

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss:
    # Q-values for the batched actions.
    action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1)
    action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
    policy_tp1, _ = action_dist_tp1.sample_logp()

    q_t, _ = model.get_q_values(model_out_t, actions)
    q_t_selected = tf.squeeze(q_t, axis=-1)
    if twin_q:
        twin_q_t, _ = model.get_twin_q_values(model_out_t, actions)
        twin_q_t_selected = tf.squeeze(twin_q_t, axis=-1)

    # Target q network evaluation.
    q_tp1, _ = policy.target_model.get_q_values(target_model_out_tp1,
                                                policy_tp1)
    if twin_q:
        twin_q_tp1, _ = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = tf.math.minimum(q_tp1, twin_q_tp1)

    q_tp1_best = tf.squeeze(input=q_tp1, axis=-1)
    # Zero out targets on terminal transitions.
    q_tp1_best_masked = (1.0 - tf.cast(terminals, tf.float32)) * q_tp1_best

    # compute RHS of bellman equation
    q_t_target = tf.stop_gradient(
        rewards + (discount**policy.config["n_step"]) * q_tp1_best_masked)

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = tf.math.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = tf.math.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = tf.keras.losses.MSE(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = tf.keras.losses.MSE(twin_q_t_selected, q_t_target)

    # CQL Loss (We are using Entropy version of CQL (the best version))
    # Uniformly random actions, repeated `num_actions` times per obs.
    rand_actions, _ = policy._random_action_generator.get_exploration_action(
        action_distribution=action_dist_class(
            tf.tile(action_dist_tp1.inputs, (num_actions, 1)), model),
        timestep=0,
        explore=True,
    )
    # Policy actions (and logps) repeated for t and t+1 features.
    curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_t, num_actions)
    next_actions, next_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_tp1,
                                                    num_actions)

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(
            model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(
            model, model_out_t, next_actions, twin=True)

    # Log-density of the uniform proposal over [-1, 1]^dim actions.
    random_density = np.log(0.5**int(curr_actions.shape[-1]))
    cat_q1 = tf.concat(
        [
            q1_rand - random_density,
            q1_next_actions - tf.stop_gradient(next_logp),
            q1_curr_actions - tf.stop_gradient(curr_logp),
        ],
        1,
    )
    if twin_q:
        cat_q2 = tf.concat(
            [
                q2_rand - random_density,
                q2_next_actions - tf.stop_gradient(next_logp),
                q2_curr_actions - tf.stop_gradient(curr_logp),
            ],
            1,
        )

    # Conservative gap: logsumexp over sampled Qs minus data Qs.
    min_qf1_loss_ = (
        tf.reduce_mean(tf.reduce_logsumexp(cat_q1 / cql_temp, axis=1)) *
        min_q_weight * cql_temp)
    min_qf1_loss = min_qf1_loss_ - (tf.reduce_mean(q_t) * min_q_weight)
    if twin_q:
        min_qf2_loss_ = (
            tf.reduce_mean(tf.reduce_logsumexp(cat_q2 / cql_temp, axis=1)) *
            min_q_weight * cql_temp)
        min_qf2_loss = min_qf2_loss_ - (
            tf.reduce_mean(twin_q_t) * min_q_weight)

    if use_lagrange:
        # NOTE(review): `.exp()` is a torch-style method call; for a TF
        # tensor/variable this would need `tf.math.exp(...)` — confirm
        # `model.log_alpha_prime`'s type.
        alpha_prime = tf.clip_by_value(model.log_alpha_prime.exp(), 0.0,
                                       1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    # Save for stats function.
    policy.q_t = q_t_selected
    policy.policy_t = policy_t
    policy.log_pis_t = log_pis_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy
    # CQL Stats
    policy.cql_loss = cql_loss
    if use_lagrange:
        policy.log_alpha_prime_value = model.log_alpha_prime[0]
        policy.alpha_prime_value = alpha_prime
        policy.alpha_prime_loss = alpha_prime_loss

    # Return all loss terms corresponding to our optimizers.
    if use_lagrange:
        return actor_loss + tf.math.add_n(
            critic_loss) + alpha_loss + alpha_prime_loss
    return actor_loss + tf.math.add_n(critic_loss) + alpha_loss
def actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the recurrent (RNN) Soft Actor Critic.

    Same structure as the feed-forward SAC loss (actor, critic(s), alpha),
    but forwards RNN state through the model/target-model calls and masks
    out invalid (padded and burn-in) timesteps when reducing the losses.

    Fixes vs. the previous revision:
    - `target_states_in_tp1` is now selected from the *target* model's
      state-outs (`target_state_in_tp1`); it previously reused the online
      model's `state_in_tp1`, leaving `target_state_in_tp1` dead.
    - The continuous-branch target-Q call now unpacks the
      (q-values, state-out) tuple returned by `get_q_values` (the missing
      `, _` made `q_tp1` a tuple and would crash `torch.min`).

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        dist_class: The action distribution class (only used implicitly;
            the continuous branch re-derives it via `_get_dist_class`).
        train_batch: The training data (time-major padded sequences).

    Returns:
        Union[TensorType, List[TensorType]]: A tuple of loss tensors:
            (actor_loss, critic_loss[, twin_critic_loss], alpha_loss).
    """
    # Look up the target model (tower) using the model tower.
    target_model = policy.target_models[model]

    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    # Collect the RNN input-state columns from the train batch.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches
    seq_lens = train_batch.get(SampleBatch.SEQ_LENS)

    # Base-model forward pass on the current observations (t).
    model_out_t, state_in_t = model(
        SampleBatch(
            obs=train_batch[SampleBatch.CUR_OBS],
            prev_actions=train_batch[SampleBatch.PREV_ACTIONS],
            prev_rewards=train_batch[SampleBatch.PREV_REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    states_in_t = model.select_state(state_in_t, ["policy", "q", "twin_q"])

    # Base-model forward pass on the next observations (t+1).
    model_out_tp1, state_in_tp1 = model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS],
            prev_actions=train_batch[SampleBatch.ACTIONS],
            prev_rewards=train_batch[SampleBatch.REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    states_in_tp1 = model.select_state(state_in_tp1,
                                       ["policy", "q", "twin_q"])

    # Target network's forward pass on the next observations.
    target_model_out_tp1, target_state_in_tp1 = target_model(
        SampleBatch(
            obs=train_batch[SampleBatch.NEXT_OBS],
            prev_actions=train_batch[SampleBatch.ACTIONS],
            prev_rewards=train_batch[SampleBatch.REWARDS],
            _is_training=True,
        ),
        state_batches,
        seq_lens,
    )
    # Fixed: select from the target model's own state-outs (was
    # `state_in_tp1`, i.e. the online model's states).
    target_states_in_tp1 = target_model.select_state(
        target_state_in_tp1, ["policy", "q", "twin_q"])

    # Entropy coefficient; stored in log space for unconstrained updates.
    alpha = torch.exp(model.log_alpha)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        action_dist_inputs_t, _ = model.get_action_model_outputs(
            model_out_t, states_in_t["policy"], seq_lens)
        log_pis_t = F.log_softmax(action_dist_inputs_t, dim=-1)
        policy_t = torch.exp(log_pis_t)
        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1, states_in_tp1["policy"], seq_lens)
        log_pis_tp1 = F.log_softmax(action_dist_inputs_tp1, -1)
        policy_tp1 = torch.exp(log_pis_tp1)
        # Q-values.
        q_t, _ = model.get_q_values(model_out_t, states_in_t["q"], seq_lens)
        # Target Q-values.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1,
                                             target_states_in_tp1["q"],
                                             seq_lens)
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(model_out_t,
                                                  states_in_t["twin_q"],
                                                  seq_lens)
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"],
                seq_lens)
            # Clipped double-Q: pessimistic (elementwise min) target.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        # Soft Q-target: subtract the entropy bonus term.
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(
            train_batch[SampleBatch.ACTIONS].long(),
            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        # Zero out targets on terminal transitions.
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_inputs_t, _ = model.get_action_model_outputs(
            model_out_t, states_in_t["policy"], seq_lens)
        action_dist_t = action_dist_class(action_dist_inputs_t, model)
        policy_t = (action_dist_t.sample() if not deterministic else
                    action_dist_t.deterministic_sample())
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)
        # Renamed from `action_dist_inputs_t` (was shadowing the t-inputs).
        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1, states_in_tp1["policy"], seq_lens)
        action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
        policy_tp1 = (action_dist_tp1.sample() if not deterministic else
                      action_dist_tp1.deterministic_sample())
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t, _ = model.get_q_values(model_out_t, states_in_t["q"], seq_lens,
                                    train_batch[SampleBatch.ACTIONS])
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(
                model_out_t,
                states_in_t["twin_q"],
                seq_lens,
                train_batch[SampleBatch.ACTIONS],
            )

        # Q-values for current policy in given current state.
        q_t_det_policy, _ = model.get_q_values(model_out_t, states_in_t["q"],
                                               seq_lens, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy, _ = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens, policy_t)
            # Pessimistic estimate for the actor objective.
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target q network evaluation.
        # Fixed: unpack the (q-values, state-out) tuple (the `, _` was
        # missing, making `q_tp1` a tuple).
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1,
                                             target_states_in_tp1["q"],
                                             seq_lens, policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1,
                target_states_in_tp1["twin_q"],
                seq_lens,
                policy_tp1,
            )
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        # Drop the trailing singleton dim of the Q-net outputs.
        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        # Soft Q-target: subtract the entropy bonus term.
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        # Zero out targets on terminal transitions.
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # compute RHS of bellman equation
    # (detached: the target must not propagate gradients into the nets).
    q_t_selected_target = (
        train_batch[SampleBatch.REWARDS] +
        (policy.config["gamma"]**policy.config["n_step"]) *
        q_tp1_best_masked).detach()

    # BURNIN #
    # Batch dim B and (padded) time dim T of the flattened sequences.
    B = state_batches[0].shape[0]
    T = q_t_selected.shape[0] // B
    seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)
    # Mask away also the burn-in sequence at the beginning.
    burn_in = policy.config["burn_in"]
    if burn_in > 0 and burn_in < T:
        seq_mask[:, :burn_in] = False

    seq_mask = seq_mask.reshape(-1)
    num_valid = torch.sum(seq_mask)

    def reduce_mean_valid(t):
        # Mean over valid (non-padded, non-burn-in) timesteps only.
        return torch.sum(t[seq_mask]) / num_valid

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        # Average the two critics' TD errors for prioritized replay.
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    # Importance-weighted Huber loss per critic (prioritized replay).
    critic_loss = [
        reduce_mean_valid(train_batch[PRIO_WEIGHTS] *
                          huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            reduce_mean_valid(train_batch[PRIO_WEIGHTS] *
                              huber_loss(twin_td_error)))
    # Zero out TD errors on invalid timesteps for the stats below.
    td_error = td_error * seq_mask

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = reduce_mean_valid(
            torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = reduce_mean_valid(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach(),
                ),
                dim=-1,
            ))
    else:
        alpha_loss = -reduce_mean_valid(
            model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c is depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = reduce_mean_valid(alpha.detach() * log_pis_t -
                                       q_t_det_policy)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t * seq_mask[..., None]
    model.tower_stats["policy_t"] = policy_t * seq_mask[..., None]
    model.tower_stats["log_pis_t"] = log_pis_t * seq_mask[..., None]
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss
    # Store per time chunk (b/c we need only one mean
    # prioritized replay weight per stored sequence).
    model.tower_stats["td_error"] = torch.mean(
        td_error.reshape([-1, T]), dim=-1)

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss])
def cql_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Computes the CQL loss terms and (side effect!) steps the optimizers.

    Unlike most RLlib loss functions, this one does not only *build* the loss
    tensors: whenever the incoming batch has exactly `train_batch_size` rows,
    it also runs `zero_grad()/backward()/step()` on the policy's alpha, actor,
    critic (and, with `lagrangian=True`, alpha-prime) optimizers in-place.
    The returned tuple is therefore mainly for stats/multi-GPU tower handling.

    Loss structure (continuous actions):
    - Alpha loss: standard SAC entropy-temperature loss.
    - Actor loss: behavior cloning (`-logp(batch actions)` style term) for the
      first `bc_iters` iterations, then the regular SAC actor loss.
    - Critic loss: SAC MSE Bellman error plus the CQL regularizer
      `logsumexp(Q over sampled actions) - mean(Q on batch actions)`.

    Args:
        policy: The Policy owning the optimizers/config; `policy.cur_iter` is
            incremented here as a side effect.
        model: The SAC-style ModelV2 (must support `get_q_values`,
            `get_twin_q_values`, `get_action_model_outputs`, `log_alpha`, ...).
        dist_class: Unused here; the action dist class is re-derived via
            `_get_dist_class` instead.
        train_batch: The SampleBatch of offline data to compute losses on.

    Returns:
        Tuple of loss tensors: (actor_loss, critic_loss(es)..., alpha_loss
        [, alpha_prime_loss if lagrangian]), matching the optimizer order.
    """
    logger.info(f"Current iteration = {policy.cur_iter}")
    policy.cur_iter += 1

    # Look up the target model (tower) using the model tower.
    target_model = policy.target_models[model]

    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]
    # NOTE(review): only the first dim's bounds are used for the uniform
    # random-action sampling below — assumes a homogeneous Box action space
    # (same low/high on every dim); verify against supported action spaces.
    action_low = model.action_space.low[0]
    action_high = model.action_space.high[0]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]  # Iterations of pure behavior cloning.
    cql_temp = policy.config["temperature"]  # Temperature in the logsumexp.
    num_actions = policy.config["num_actions"]  # Sampled actions per state.
    min_q_weight = policy.config["min_q_weight"]  # CQL regularizer weight.
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = train_batch[SampleBatch.ACTIONS]
    rewards = train_batch[SampleBatch.REWARDS].float()
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    # Base-model forward passes on t and t+1 observations (no RNN state).
    model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None)

    model_out_tp1, _ = model(
        SampleBatch(obs=next_obs, _is_training=True), [], None)

    target_model_out_tp1, _ = target_model(
        SampleBatch(obs=next_obs, _is_training=True), [], None)

    # Sample current-policy actions (and their logps) for the t observations.
    action_dist_class = _get_dist_class(policy, policy.config,
                                        policy.action_space)
    action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
    action_dist_t = action_dist_class(action_dist_inputs_t, model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = torch.unsqueeze(log_pis_t, -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -(model.log_alpha *
                   (log_pis_t + model.target_entropy).detach()).mean()

    # Optimizers are only stepped on full-sized batches (e.g. not on the
    # dummy initialization pass); smaller batches still build the losses.
    batch_size = tree.flatten(obs)[0].shape[0]
    if batch_size == policy.config["train_batch_size"]:
        policy.alpha_optim.zero_grad()
        alpha_loss.backward()
        policy.alpha_optim.step()

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = torch.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        # Regular SAC actor loss using the (min over twin) Q estimate.
        min_q, _ = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_, _ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = torch.min(min_q, twin_q_)
        actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
    else:
        # Behavior-clone warmup: push the policy toward the batch actions.
        bc_logp = action_dist_t.logp(actions)
        actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean()
        # actor_loss = -bc_logp.mean()

    if batch_size == policy.config["train_batch_size"]:
        policy.actor_optim.zero_grad()
        # retain_graph=True: the critic/CQL losses below backprop through
        # parts of the same graph (e.g. model_out_t).
        actor_loss.backward(retain_graph=True)
        policy.actor_optim.step()

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss:
    # Q-values for the batched actions.
    action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1)
    action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
    policy_tp1, _ = action_dist_tp1.sample_logp()

    q_t, _ = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
    q_t_selected = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t, _ = model.get_twin_q_values(
            model_out_t, train_batch[SampleBatch.ACTIONS])
        twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)

    # Target q network evaluation.
    q_tp1, _ = target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1, _ = target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
    # Zero out the bootstrap term on terminal transitions.
    q_tp1_best_masked = (1.0 - terminals.float()) * q_tp1_best

    # compute RHS of bellman equation
    q_t_target = (
        rewards +
        (discount**policy.config["n_step"]) * q_tp1_best_masked).detach()

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = torch.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = nn.functional.mse_loss(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = nn.functional.mse_loss(twin_q_t_selected, q_t_target)

    # CQL Loss (We are using Entropy version of CQL (the best version))
    # Uniform random actions, `num_actions` per batch row.
    rand_actions = convert_to_torch_tensor(
        torch.FloatTensor(actions.shape[0] * num_actions,
                          actions.shape[-1]).uniform_(action_low, action_high),
        policy.device,
    )
    # Actions (and logps) sampled from the current policy at obs_t / obs_tp1.
    curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_t, num_actions)
    next_actions, next_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_tp1, num_actions)

    # Q-values for all sampled action sets under the online Q-net(s).
    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(
            model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(
            model, model_out_t, next_actions, twin=True)

    # Importance-sampling correction: log-density of the uniform proposal.
    # NOTE(review): 0.5**dim is the uniform density only for a [-1, 1] range
    # per action dim — TODO confirm this matches `action_low/high` above.
    random_density = np.log(0.5**curr_actions.shape[-1])
    cat_q1 = torch.cat(
        [
            q1_rand - random_density,
            q1_next_actions - next_logp.detach(),
            q1_curr_actions - curr_logp.detach(),
        ],
        1,
    )
    if twin_q:
        cat_q2 = torch.cat(
            [
                q2_rand - random_density,
                q2_next_actions - next_logp.detach(),
                q2_curr_actions - curr_logp.detach(),
            ],
            1,
        )

    # CQL regularizer: soft-max over sampled-action Q-values minus the mean
    # Q-value of the dataset actions.
    min_qf1_loss_ = (torch.logsumexp(cat_q1 / cql_temp, dim=1).mean() *
                     min_q_weight * cql_temp)
    min_qf1_loss = min_qf1_loss_ - (q_t.mean() * min_q_weight)
    if twin_q:
        min_qf2_loss_ = (torch.logsumexp(cat_q2 / cql_temp, dim=1).mean() *
                         min_q_weight * cql_temp)
        min_qf2_loss = min_qf2_loss_ - (twin_q_t.mean() * min_q_weight)

    if use_lagrange:
        # Auto-tuned regularizer weight (CQL-Lagrange variant).
        alpha_prime = torch.clamp(
            model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    if batch_size == policy.config["train_batch_size"]:
        policy.critic_optims[0].zero_grad()
        # retain_graph=True: the twin critic (and possibly alpha-prime)
        # backward passes below reuse parts of the same graph.
        critic_loss[0].backward(retain_graph=True)
        policy.critic_optims[0].step()
        if twin_q:
            policy.critic_optims[1].zero_grad()
            critic_loss[1].backward(retain_graph=False)
            policy.critic_optims[1].step()

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    # SAC stats.
    model.tower_stats["q_t"] = q_t_selected
    model.tower_stats["policy_t"] = policy_t
    model.tower_stats["log_pis_t"] = log_pis_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss
    model.tower_stats["log_alpha_value"] = model.log_alpha
    model.tower_stats["alpha_value"] = alpha
    model.tower_stats["target_entropy"] = model.target_entropy
    # CQL stats.
    model.tower_stats["cql_loss"] = cql_loss
    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    if use_lagrange:
        model.tower_stats["log_alpha_prime_value"] = model.log_alpha_prime[0]
        model.tower_stats["alpha_prime_value"] = alpha_prime
        model.tower_stats["alpha_prime_loss"] = alpha_prime_loss

        if batch_size == policy.config["train_batch_size"]:
            policy.alpha_prime_optim.zero_grad()
            alpha_prime_loss.backward()
            policy.alpha_prime_optim.step()

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss] +
                 ([alpha_prime_loss] if use_lagrange else []))