def _get_q_value(self, model: ModelV2, model_out: TensorType,
                 actions: TensorType) -> TensorType:
    """Return the pessimistic Q-value: element-wise min of both Q-heads.

    Args:
        model: Model exposing `get_q_values` and `get_twin_q_values`.
        model_out: Base-model output the Q-heads are evaluated on.
        actions: Actions to score.

    Returns:
        `torch.minimum` of the first and twin Q-value tensors.
    """
    # Arguments are evaluated left to right: first Q-head, then twin head.
    return torch.minimum(
        model.get_q_values(model_out, actions),
        model.get_twin_q_values(model_out, actions),
    )
def _compute_critic_loss(
    self,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
):
    """Compute the twin-Q critic (TD) loss for one train batch.

    Bellman targets come from the target model. An action is sampled from
    the target model's policy at the next observation; for continuous action
    spaces it is clamped to the action-space bounds. That action is scored
    with the pessimistic (min over twin heads) target Q-value. Both online
    Q-heads are then regressed onto the shared target with a squared error.

    Args:
        model: The online model (cast to `CRRModel`).
        dist_class: Action distribution class wrapping the policy output.
        train_batch: Batch providing OBS, NEXT_OBS, ACTIONS, REWARDS, DONES.

    Returns:
        Scalar loss: batch mean of `0.5 * (loss_q1 + loss_q2)`.
    """
    discount = self.config["gamma"]

    # Compute Bellman targets to regress on, using the target model.
    target_model = cast(CRRModel, self.target_models[model])
    # Everything on the target side, including the target network's forward
    # pass, is built without autograd tracking. None of it is differentiated,
    # so building a graph for it would only waste memory and compute.
    with torch.no_grad():
        target_out_next, _ = target_model(
            {SampleBatch.OBS: train_batch[SampleBatch.NEXT_OBS]})
        # Action of the current (target) policy evaluated at the next state.
        pi_s_next = dist_class(
            target_model.get_policy_output(target_out_next), target_model)
        target_a_next = pi_s_next.sample()
        if not self._is_action_discrete:
            target_a_next = target_a_next.clamp(
                torch.from_numpy(self.action_space.low).to(target_a_next),
                torch.from_numpy(self.action_space.high).to(target_a_next),
            )
        # Pessimistic next-state value: min over both target Q-heads.
        target_q_next = self._get_q_value(
            target_model, target_out_next, target_a_next).squeeze(-1)
        # r + gamma * (1 - done) * Q_target(s', a'); terminal steps do not
        # bootstrap.
        target = (
            train_batch[SampleBatch.REWARDS] +
            discount * (1.0 - train_batch[SampleBatch.DONES].float()) *
            target_q_next)

    # Predicted Q-values of the online model for the logged actions.
    model = cast(CRRModel, model)
    model_out_t, _ = model({SampleBatch.OBS: train_batch[SampleBatch.OBS]})
    q1 = model.get_q_values(
        model_out_t, train_batch[SampleBatch.ACTIONS]).squeeze(-1)
    q2 = model.get_twin_q_values(
        model_out_t, train_batch[SampleBatch.ACTIONS]).squeeze(-1)

    # Squared error of both Q-heads against the shared target.
    loss_q1 = (target - q1)**2
    loss_q2 = (target - q2)**2
    loss = 0.5 * (loss_q1 + loss_q2)
    loss = loss.mean(0)

    # Logging.
    self.log("loss_q1", loss_q1.mean())
    self.log("loss_q2", loss_q2.mean())
    self.log("targets_avg", target.mean())
    self.log("targets_max", target.max())
    self.log("targets_min", target.min())

    return loss
def sac_actor_critic_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TFActionDistribution],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Builds the critic loss(es) (one per Q-head when `twin_q`), the actor
    loss and the entropy-temperature (alpha) loss. Intermediate tensors are
    stored on `policy` for the stats function.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor: the sum
            `actor_loss + sum(critic_loss) + alpha_loss`.
    """
    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    # Get the base model output from the train batch.
    model_out_t, _ = model(
        {
            "obs": train_batch[SampleBatch.CUR_OBS],
            "is_training": policy._get_is_training_placeholder(),
        }, [], None)

    # Get the base model output from the next observations in the train batch.
    model_out_tp1, _ = model(
        {
            "obs": train_batch[SampleBatch.NEXT_OBS],
            "is_training": policy._get_is_training_placeholder(),
        }, [], None)

    # Get the target model's base outputs from the next observations in the
    # train batch.
    target_model_out_tp1, _ = policy.target_model(
        {
            "obs": train_batch[SampleBatch.NEXT_OBS],
            "is_training": policy._get_is_training_placeholder(),
        }, [], None)

    # Discrete actions case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        log_pis_t = tf.nn.log_softmax(model.get_policy_output(model_out_t), -1)
        policy_t = tf.math.exp(log_pis_t)
        log_pis_tp1 = tf.nn.log_softmax(
            model.get_policy_output(model_out_tp1), -1)
        policy_tp1 = tf.math.exp(log_pis_tp1)
        # Q-values.
        q_t = model.get_q_values(model_out_t)
        # Target Q-values.
        q_tp1 = policy.target_model.get_q_values(target_model_out_tp1)
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(model_out_t)
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1)
            # Pessimistic target: element-wise min over both target heads.
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)
        # Soft value: subtract the entropy term alpha * log pi(a'|s').
        q_tp1 -= model.alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = tf.one_hot(
            train_batch[SampleBatch.ACTIONS], depth=q_t.shape.as_list()[-1])
        q_t_selected = tf.reduce_sum(q_t * one_hot, axis=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.reduce_sum(twin_q_t * one_hot, axis=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = tf.reduce_sum(tf.multiply(policy_tp1, q_tp1), axis=-1)
        # Zero out bootstrapped values for terminal transitions.
        q_tp1_best_masked = \
            (1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)) * \
            q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_t = action_dist_class(
            model.get_policy_output(model_out_t), policy.model)
        policy_t = action_dist_t.sample() if not deterministic else \
            action_dist_t.deterministic_sample()
        log_pis_t = tf.expand_dims(action_dist_t.logp(policy_t), -1)
        action_dist_tp1 = action_dist_class(
            model.get_policy_output(model_out_tp1), policy.model)
        policy_tp1 = action_dist_tp1.sample() if not deterministic else \
            action_dist_tp1.deterministic_sample()
        log_pis_tp1 = tf.expand_dims(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t = model.get_q_values(
            model_out_t, tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32))
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(
                model_out_t,
                tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32))

        # Q-values for current policy in given current state.
        # (Used by the actor loss; min over heads when twin_q.)
        q_t_det_policy = model.get_q_values(model_out_t, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy = model.get_twin_q_values(
                model_out_t, policy_t)
            q_t_det_policy = tf.reduce_min(
                (q_t_det_policy, twin_q_t_det_policy), axis=0)

        # target q network evaluation
        q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
                                                 policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1, policy_tp1)
            # Take min over both twin-NNs.
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)

        # Drop the trailing singleton Q-value axis.
        q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
        # Soft value: subtract the entropy term alpha * log pi(a'|s').
        q_tp1 -= model.alpha * log_pis_tp1

        q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
        # Zero out bootstrapped values for terminal transitions.
        q_tp1_best_masked = (1.0 - tf.cast(train_batch[SampleBatch.DONES],
                                           tf.float32)) * q_tp1_best

    # Compute RHS of bellman equation for the Q-loss (critic(s)):
    # r + gamma**n_step * masked soft next-state value. stop_gradient keeps
    # the critic loss from backpropagating into the target side.
    q_t_selected_target = tf.stop_gradient(
        tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) +
        policy.config["gamma"]**policy.config["n_step"] * q_tp1_best_masked)

    # Compute the TD-error (potentially clipped).
    base_td_error = tf.math.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = tf.math.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    # Calculate one or two critic losses (2 in the twin_q case).
    # Each is a prioritized-replay-weighted mean Huber loss of the TD-error.
    prio_weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32)
    critic_loss = [tf.reduce_mean(prio_weights * huber_loss(base_td_error))]
    if policy.config["twin_q"]:
        critic_loss.append(
            tf.reduce_mean(prio_weights * huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        alpha_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    tf.stop_gradient(policy_t), -model.log_alpha *
                    tf.stop_gradient(log_pis_t + model.target_entropy)),
                axis=-1))
        # NOTE(review): `model.alpha` is not wrapped in tf.stop_gradient in
        # either actor loss, so the term has a gradient w.r.t. log_alpha;
        # presumably only policy variables are updated with actor_loss —
        # confirm against the optimizer setup.
        actor_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    model.alpha * log_pis_t - tf.stop_gradient(q_t)),
                axis=-1))
    else:
        alpha_loss = -tf.reduce_mean(
            model.log_alpha *
            tf.stop_gradient(log_pis_t + model.target_entropy))
        actor_loss = tf.reduce_mean(model.alpha * log_pis_t - q_t_det_policy)

    # Save for stats function.
    policy.policy_t = policy_t
    policy.q_t = q_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.alpha_value = model.alpha
    policy.target_entropy = model.target_entropy

    # In a custom apply op we handle the losses separately, but return them
    # combined in one loss here.
    return actor_loss + tf.math.add_n(critic_loss) + alpha_loss
def ddpg_actor_critic_loss(policy: Policy, model: ModelV2, _,
                           train_batch: SampleBatch) -> TensorType:
    """Build the (TD3-style optional) DDPG actor and critic losses.

    The critic regresses Q(s, a) onto
    r + gamma**n_step * (1 - done) * Q_target(s', pi_target(s')).
    The target action is optionally smoothed with clipped Gaussian noise, and
    the target value is the min over twin heads when `twin_q`. The actor
    maximizes Q(s, pi(s)). Optional L2 regularization and the model's
    `custom_loss` hook are applied, and stats are stored on `policy`.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        _: Unused (action distribution class slot).
        train_batch: The training data.

    Returns:
        `critic_loss + actor_loss` (a single tensor).
    """
    twin_q = policy.config["twin_q"]
    gamma = policy.config["gamma"]
    n_step = policy.config["n_step"]
    use_huber = policy.config["use_huber"]
    huber_threshold = policy.config["huber_threshold"]
    l2_reg = policy.config["l2_reg"]

    input_dict = SampleBatch(
        obs=train_batch[SampleBatch.CUR_OBS], _is_training=True)
    input_dict_next = SampleBatch(
        obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True)

    model_out_t, _ = model(input_dict, [], None)
    # NOTE(review): `model_out_tp1` is not used anywhere below in this
    # function; the forward pass may only matter for its side effects (if
    # any) — confirm before removing.
    model_out_tp1, _ = model(input_dict_next, [], None)
    target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None)

    policy.target_q_func_vars = policy.target_model.variables()

    # Policy network evaluation.
    policy_t = model.get_policy_output(model_out_t)
    policy_tp1 = policy.target_model.get_policy_output(target_model_out_tp1)

    # Action outputs.
    if policy.config["smooth_target_policy"]:
        # Target policy smoothing: add Gaussian noise clipped to
        # [-target_noise_clip, target_noise_clip], then clip the result to
        # the action-space bounds.
        target_noise_clip = policy.config["target_noise_clip"]
        clipped_normal_sample = tf.clip_by_value(
            tf.random.normal(
                tf.shape(policy_tp1), stddev=policy.config["target_noise"]),
            -target_noise_clip,
            target_noise_clip,
        )
        policy_tp1_smoothed = tf.clip_by_value(
            policy_tp1 + clipped_normal_sample,
            policy.action_space.low * tf.ones_like(policy_tp1),
            policy.action_space.high * tf.ones_like(policy_tp1),
        )
    else:
        # No smoothing, just use deterministic actions.
        policy_tp1_smoothed = policy_tp1

    # Q-net(s) evaluation.
    # Q-values for given actions & observations in given current state.
    q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])

    # Q-values for current policy (no noise) in given current state
    q_t_det_policy = model.get_q_values(model_out_t, policy_t)

    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t,
                                           train_batch[SampleBatch.ACTIONS])

    # Target q-net(s) evaluation.
    q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
                                             policy_tp1_smoothed)

    if twin_q:
        twin_q_tp1 = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1_smoothed)

    # Drop the trailing singleton Q-value axis.
    q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
    if twin_q:
        twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
        # Clipped double-Q: pessimistic min over both target heads.
        q_tp1 = tf.minimum(q_tp1, twin_q_tp1)

    q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
    # Zero out bootstrapped values for terminal transitions.
    q_tp1_best_masked = (
        1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)) * q_tp1_best

    # Compute RHS of bellman equation (no gradient into the target side).
    q_t_selected_target = tf.stop_gradient(
        tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) +
        gamma**n_step * q_tp1_best_masked)

    # Compute the error (potentially clipped).
    if twin_q:
        td_error = q_t_selected - q_t_selected_target
        twin_td_error = twin_q_t_selected - q_t_selected_target
        if use_huber:
            errors = huber_loss(td_error, huber_threshold) + huber_loss(
                twin_td_error, huber_threshold)
        else:
            errors = 0.5 * tf.math.square(td_error) + 0.5 * tf.math.square(
                twin_td_error)
    else:
        td_error = q_t_selected - q_t_selected_target
        if use_huber:
            errors = huber_loss(td_error, huber_threshold)
        else:
            errors = 0.5 * tf.math.square(td_error)

    # Prioritized-replay-weighted mean critic error.
    critic_loss = tf.reduce_mean(
        tf.cast(train_batch[PRIO_WEIGHTS], tf.float32) * errors)
    # Actor maximizes Q(s, pi(s)).
    actor_loss = -tf.reduce_mean(q_t_det_policy)

    # Add l2-regularization if required (bias variables are excluded).
    if l2_reg is not None:
        for var in policy.model.policy_variables():
            if "bias" not in var.name:
                actor_loss += l2_reg * tf.nn.l2_loss(var)
        for var in policy.model.q_variables():
            if "bias" not in var.name:
                critic_loss += l2_reg * tf.nn.l2_loss(var)

    # Model self-supervised losses.
    if policy.config["use_state_preprocessor"]:
        # Expand input_dict in case custom_loss' need them.
        input_dict[SampleBatch.ACTIONS] = train_batch[SampleBatch.ACTIONS]
        input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS]
        input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES]
        input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS]
        if log_once("ddpg_custom_loss"):
            logger.warning(
                "You are using a state-preprocessor with DDPG and "
                "therefore, `custom_loss` will be called on your Model! "
                "Please be aware that DDPG now uses the ModelV2 API, which "
                "merges all previously separate sub-models (policy_model, "
                "q_model, and twin_q_model) into one ModelV2, on which "
                "`custom_loss` is called, passing it "
                "[actor_loss, critic_loss] as 1st argument. "
                "You may have to change your custom loss function to handle "
                "this.")
        [actor_loss, critic_loss] = model.custom_loss(
            [actor_loss, critic_loss], input_dict)

    # Store values for stats function.
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.td_error = td_error
    policy.q_t = q_t

    # Return one loss value (even though we treat them separately in our
    # 2 optimizers: actor and critic).
    return policy.critic_loss + policy.actor_loss
def actor_critic_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy: The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch: The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A tuple of loss tensors, one
            per optimizer: (actor_loss, critic_loss[, twin_critic_loss],
            alpha_loss).
    """
    # Look up the target model (tower) using the model tower.
    target_model = policy.target_models[model]

    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    model_out_t, _ = model(
        SampleBatch(obs=train_batch[SampleBatch.CUR_OBS], _is_training=True),
        [], None)

    model_out_tp1, _ = model(
        SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True),
        [], None)

    target_model_out_tp1, _ = target_model(
        SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True),
        [], None)

    # Entropy temperature; the model stores its log.
    alpha = torch.exp(model.log_alpha)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
        log_pis_t = F.log_softmax(action_dist_inputs_t, dim=-1)
        policy_t = torch.exp(log_pis_t)
        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1)
        log_pis_tp1 = F.log_softmax(action_dist_inputs_tp1, -1)
        policy_tp1 = torch.exp(log_pis_tp1)
        # Q-values.
        q_t, _ = model.get_q_values(model_out_t)
        # Target Q-values.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1)
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(model_out_t)
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1)
            # Pessimistic target: element-wise min over both target heads.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        # Soft value: subtract the entropy term alpha * log pi(a'|s').
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(
            train_batch[SampleBatch.ACTIONS].long(),
            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        # Zero out bootstrapped values for terminal transitions.
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t)
        action_dist_t = action_dist_class(action_dist_inputs_t, model)
        policy_t = (action_dist_t.sample() if not deterministic else
                    action_dist_t.deterministic_sample())
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)
        action_dist_inputs_tp1, _ = model.get_action_model_outputs(
            model_out_tp1)
        action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model)
        policy_tp1 = (action_dist_tp1.sample() if not deterministic else
                      action_dist_tp1.deterministic_sample())
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t, _ = model.get_q_values(model_out_t,
                                    train_batch[SampleBatch.ACTIONS])
        if policy.config["twin_q"]:
            twin_q_t, _ = model.get_twin_q_values(
                model_out_t, train_batch[SampleBatch.ACTIONS])

        # Q-values for current policy in given current state.
        # (Used by the actor loss; min over heads when twin_q.)
        q_t_det_policy, _ = model.get_q_values(model_out_t, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy, _ = model.get_twin_q_values(
                model_out_t, policy_t)
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target q network evaluation.
        q_tp1, _ = target_model.get_q_values(target_model_out_tp1, policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1, _ = target_model.get_twin_q_values(
                target_model_out_tp1, policy_tp1)
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        # Drop the trailing singleton Q-value axis.
        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        # Soft value: subtract the entropy term alpha * log pi(a'|s').
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        # Zero out bootstrapped values for terminal transitions.
        q_tp1_best_masked = (
            1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # compute RHS of bellman equation; detached so the critic loss does not
    # backpropagate into the target side.
    q_t_selected_target = (
        train_batch[SampleBatch.REWARDS] +
        (policy.config["gamma"]**policy.config["n_step"]) *
        q_tp1_best_masked).detach()

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    # One prioritized-replay-weighted Huber loss per Q-head.
    critic_loss = [
        torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = torch.mean(torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = torch.mean(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach(),
                ),
                dim=-1,
            ))
    else:
        alpha_loss = -torch.mean(model.log_alpha *
                                 (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here because it depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = torch.mean(alpha.detach() * log_pis_t - q_t_det_policy)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t
    model.tower_stats["policy_t"] = policy_t
    model.tower_stats["log_pis_t"] = log_pis_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss

    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss])
def cql_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Compute CQL losses and step the optimizers for full-size batches.

    Unlike plain SAC, the alpha, actor, critic(s) and, with `lagrangian`,
    alpha-prime optimizers are stepped inside this function. This happens
    whenever the batch size equals `config["train_batch_size"]`.

    - Actor loss: behavior cloning (`-logp(a_batch)`) while `cur_iter <
      bc_iters`, the SAC objective afterwards.
    - Critic loss: SAC MSE TD loss plus the CQL logsumexp regularizer.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        dist_class: The action distribution class (unused here; the class is
            re-derived via `_get_dist_class`).
        train_batch: The training data.

    Returns:
        Tuple (actor_loss, critic_loss[, twin_critic_loss], alpha_loss
        [, alpha_prime_loss]).
    """
    logger.info("Current iteration = %s", policy.cur_iter)
    policy.cur_iter += 1

    # Look up the target model (tower) using the model tower.
    target_model = policy.target_models[model]

    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]
    action_low = model.action_space.low[0]
    action_high = model.action_space.high[0]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = train_batch[SampleBatch.ACTIONS]
    rewards = train_batch[SampleBatch.REWARDS].float()
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None)

    model_out_tp1, _ = model(
        SampleBatch(obs=next_obs, _is_training=True), [], None)

    target_model_out_tp1, _ = target_model(
        SampleBatch(obs=next_obs, _is_training=True), [], None)

    action_dist_class = _get_dist_class(policy, policy.config,
                                        policy.action_space)
    action_dist_t = action_dist_class(
        model.get_policy_output(model_out_t), policy.model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = torch.unsqueeze(log_pis_t, -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -(model.log_alpha *
                   (log_pis_t + model.target_entropy).detach()).mean()

    batch_size = tree.flatten(obs)[0].shape[0]
    if batch_size == policy.config["train_batch_size"]:
        policy.alpha_optim.zero_grad()
        alpha_loss.backward()
        policy.alpha_optim.step()

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = torch.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        min_q = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = torch.min(min_q, twin_q_)
        actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
    else:
        # Unsqueeze exactly like `log_pis_t` above (same `logp` output shape)
        # so the subtraction is element-wise. Subtracting a [B] tensor from
        # a [B, 1] tensor would broadcast to a [B, B] matrix: same mean, but
        # O(B^2) memory and compute.
        bc_logp = torch.unsqueeze(action_dist_t.logp(actions), -1)
        actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean()

    if batch_size == policy.config["train_batch_size"]:
        policy.actor_optim.zero_grad()
        actor_loss.backward(retain_graph=True)
        policy.actor_optim.step()

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss:
    # Q-values for the batched actions.
    action_dist_tp1 = action_dist_class(
        model.get_policy_output(model_out_tp1), policy.model)
    policy_tp1, _ = action_dist_tp1.sample_logp()

    q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
    q_t_selected = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t,
                                           train_batch[SampleBatch.ACTIONS])
        twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)

    # Target q network evaluation.
    q_tp1 = target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1 = target_model.get_twin_q_values(target_model_out_tp1,
                                                    policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
    # Zero out bootstrapped values for terminal transitions.
    q_tp1_best_masked = (1.0 - terminals.float()) * q_tp1_best

    # compute RHS of bellman equation (detached: no gradient into targets).
    q_t_target = (
        rewards +
        (discount**policy.config["n_step"]) * q_tp1_best_masked).detach()

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = torch.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = nn.functional.mse_loss(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = nn.functional.mse_loss(twin_q_t_selected, q_t_target)

    # CQL Loss (We are using Entropy version of CQL (the best version))
    # NOTE(review): random actions are drawn with the bounds of action
    # dimension 0 only. `random_density` below is log(0.5**dim), i.e. the
    # log-density of a uniform over a width-2 interval per dimension. Both
    # presumably assume a [-1, 1] action box — confirm for other spaces.
    rand_actions = convert_to_torch_tensor(
        torch.FloatTensor(actions.shape[0] * num_actions,
                          actions.shape[-1]).uniform_(action_low, action_high),
        policy.device,
    )
    curr_actions, curr_logp = policy_actions_repeat(
        model, action_dist_class, model_out_t, num_actions)
    next_actions, next_logp = policy_actions_repeat(
        model, action_dist_class, model_out_tp1, num_actions)

    # All Q-values are evaluated at the current observation (model_out_t),
    # including those for actions sampled at the next observation.
    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(
            model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(
            model, model_out_t, next_actions, twin=True)

    random_density = np.log(0.5**curr_actions.shape[-1])
    # Importance-weighted Q-values (Q - log-density) over the three action
    # sources, concatenated along dim 1 for the logsumexp below.
    cat_q1 = torch.cat(
        [
            q1_rand - random_density,
            q1_next_actions - next_logp.detach(),
            q1_curr_actions - curr_logp.detach(),
        ],
        1,
    )
    if twin_q:
        cat_q2 = torch.cat(
            [
                q2_rand - random_density,
                q2_next_actions - next_logp.detach(),
                q2_curr_actions - curr_logp.detach(),
            ],
            1,
        )

    # CQL regularizer: temperature-scaled logsumexp minus the mean Q-value
    # of the dataset actions, weighted by min_q_weight.
    min_qf1_loss_ = (torch.logsumexp(cat_q1 / cql_temp, dim=1).mean() *
                     min_q_weight * cql_temp)
    min_qf1_loss = min_qf1_loss_ - (q_t.mean() * min_q_weight)
    if twin_q:
        min_qf2_loss_ = (torch.logsumexp(cat_q2 / cql_temp, dim=1).mean() *
                         min_q_weight * cql_temp)
        min_qf2_loss = min_qf2_loss_ - (twin_q_t.mean() * min_q_weight)

    if use_lagrange:
        alpha_prime = torch.clamp(
            model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    if batch_size == policy.config["train_batch_size"]:
        policy.critic_optims[0].zero_grad()
        critic_loss[0].backward(retain_graph=True)
        policy.critic_optims[0].step()

        if twin_q:
            # NOTE(review): this frees the graph (retain_graph=False) after
            # the critic optimizers have stepped. `alpha_prime_loss.backward()`
            # further below goes through the same Q-value graph. Verify this
            # does not raise (freed buffers / in-place parameter
            # modification) when both twin_q and lagrangian are enabled.
            policy.critic_optims[1].zero_grad()
            critic_loss[1].backward(retain_graph=False)
            policy.critic_optims[1].step()

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    # SAC stats.
    model.tower_stats["q_t"] = q_t_selected
    model.tower_stats["policy_t"] = policy_t
    model.tower_stats["log_pis_t"] = log_pis_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss
    model.tower_stats["log_alpha_value"] = model.log_alpha
    model.tower_stats["alpha_value"] = alpha
    model.tower_stats["target_entropy"] = model.target_entropy
    # CQL stats.
    model.tower_stats["cql_loss"] = cql_loss

    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    if use_lagrange:
        model.tower_stats["log_alpha_prime_value"] = model.log_alpha_prime[0]
        model.tower_stats["alpha_prime_value"] = alpha_prime
        model.tower_stats["alpha_prime_loss"] = alpha_prime_loss

        if batch_size == policy.config["train_batch_size"]:
            policy.alpha_prime_optim.zero_grad()
            alpha_prime_loss.backward()
            policy.alpha_prime_optim.step()

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss] +
                 ([alpha_prime_loss] if use_lagrange else []))
def cql_loss(policy: Policy, model: ModelV2,
             dist_class: Type[TorchDistributionWrapper],
             train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Compute the CQL (SAC + conservative Q) losses.

    - Actor loss: behavior cloning (`-logp(a_batch)`) while `cur_iter <
      bc_iters`, the SAC objective afterwards.
    - Critic losses: SAC MSE TD losses plus the CQL logsumexp regularizer
      (optionally scaled by a Lagrange multiplier).

    Intermediate values are stored on `policy` for the stats function.

    Args:
        policy: The Policy to calculate the loss for.
        model: The Model to calculate the loss for.
        dist_class: The action distribution class (unused here; the class is
            re-derived via `_get_dist_class`).
        train_batch: The training data.

    Returns:
        Tuple (actor_loss, critic_loss[, twin_critic_loss], alpha_loss
        [, alpha_prime_loss]).
    """
    import logging

    # Progress output goes through logging instead of a bare debug print.
    logging.getLogger(__name__).info("Current iteration = %s",
                                     policy.cur_iter)
    policy.cur_iter += 1
    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]
    action_low = model.action_space.low[0]
    action_high = model.action_space.high[0]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = train_batch[SampleBatch.ACTIONS]
    rewards = train_batch[SampleBatch.REWARDS]
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    model_out_t, _ = model({
        "obs": obs,
        "is_training": True,
    }, [], None)

    model_out_tp1, _ = model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    target_model_out_tp1, _ = policy.target_model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    action_dist_class = _get_dist_class(policy.config, policy.action_space)
    action_dist_t = action_dist_class(
        model.get_policy_output(model_out_t), policy.model)
    policy_t = action_dist_t.sample() if not deterministic else \
        action_dist_t.deterministic_sample()
    log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -(model.log_alpha *
                   (log_pis_t + model.target_entropy).detach()).mean()

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = torch.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        min_q = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = torch.min(min_q, twin_q_)
        actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
    else:
        # Same shape as `log_pis_t` (both from `action_dist_t.logp` +
        # unsqueeze), so the subtraction is element-wise rather than a
        # [B, B] broadcast. alpha is detached, as in the SAC branch, so the
        # actor loss never produces gradients for log_alpha.
        bc_logp = torch.unsqueeze(action_dist_t.logp(actions), -1)
        actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean()

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss
    action_dist_tp1 = action_dist_class(
        model.get_policy_output(model_out_tp1), policy.model)
    policy_tp1 = action_dist_tp1.sample() if not deterministic else \
        action_dist_tp1.deterministic_sample()

    # Q-values for the batched actions.
    q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
    q_t = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t,
                                           train_batch[SampleBatch.ACTIONS])
        twin_q_t = torch.squeeze(twin_q_t, dim=-1)

    # Target q network evaluation.
    q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1 = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_tp1 = torch.squeeze(input=q_tp1, dim=-1)
    # Zero out bootstrapped values for terminal transitions.
    q_tp1 = (1.0 - terminals.float()) * q_tp1

    # compute RHS of bellman equation (detached: no gradient into targets).
    q_t_target = (
        rewards + (discount**policy.config["n_step"]) * q_tp1).detach()

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = torch.abs(q_t - q_t_target)
    if twin_q:
        twin_td_error = torch.abs(twin_q_t - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error
    critic_loss = [nn.MSELoss()(q_t, q_t_target)]
    if twin_q:
        critic_loss.append(nn.MSELoss()(twin_q_t, q_t_target))

    # CQL Loss (We are using Entropy version of CQL (the best version))
    # NOTE(review): random actions use only dimension 0's bounds, and
    # `random_density` is log(0.5**dim) (uniform over a width-2 interval per
    # dimension); both presumably assume a [-1, 1] action box — confirm.
    rand_actions = convert_to_torch_tensor(
        torch.FloatTensor(actions.shape[0] * num_actions,
                          actions.shape[-1]).uniform_(action_low, action_high),
        policy.device)
    curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class,
                                                    obs, num_actions)
    next_actions, next_logp = policy_actions_repeat(model, action_dist_class,
                                                    next_obs, num_actions)
    curr_logp = curr_logp.view(actions.shape[0], num_actions, 1)
    next_logp = next_logp.view(actions.shape[0], num_actions, 1)

    # All Q-values are evaluated at the current observation (model_out_t).
    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(
            model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(
            model, model_out_t, next_actions, twin=True)

    random_density = np.log(0.5**curr_actions.shape[-1])
    cat_q1 = torch.cat([
        q1_rand - random_density, q1_next_actions - next_logp.detach(),
        q1_curr_actions - curr_logp.detach()
    ], 1)
    if twin_q:
        cat_q2 = torch.cat([
            q2_rand - random_density, q2_next_actions - next_logp.detach(),
            q2_curr_actions - curr_logp.detach()
        ], 1)

    # CQL regularizer: temperature-scaled logsumexp minus the mean Q-value
    # of the dataset actions, weighted by min_q_weight.
    min_qf1_loss = torch.logsumexp(
        cat_q1 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
    min_qf1_loss = min_qf1_loss - q_t.mean() * min_q_weight
    if twin_q:
        min_qf2_loss = torch.logsumexp(
            cat_q2 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
        min_qf2_loss = min_qf2_loss - twin_q_t.mean() * min_q_weight

    if use_lagrange:
        alpha_prime = torch.clamp(
            model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    # One CQL term per Q-head. The first entry must be the first head's term:
    # `min_qf2_loss` is undefined without twin_q.
    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss[0] += min_qf1_loss
    if twin_q:
        critic_loss[1] += min_qf2_loss

    # Save for stats function.
    policy.q_t = q_t
    policy.policy_t = policy_t
    policy.log_pis_t = log_pis_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy
    # CQL Stats
    policy.cql_loss = cql_loss
    if use_lagrange:
        policy.log_alpha_prime_value = model.log_alpha_prime[0]
        policy.alpha_prime_value = alpha_prime
        policy.alpha_prime_loss = alpha_prime_loss

    # Return all loss terms corresponding to our optimizers.
    if use_lagrange:
        return tuple([policy.actor_loss] + policy.critic_loss +
                     [policy.alpha_loss] + [policy.alpha_prime_loss])
    return tuple([policy.actor_loss] + policy.critic_loss +
                 [policy.alpha_loss])
def ddpg_actor_critic_loss(policy: Policy, model: ModelV2, _,
                           train_batch: SampleBatch) -> TensorType:
    """Build the DDPG/TD3 actor and critic losses for one training batch.

    Args:
        policy: Policy whose ``config`` drives the loss and whose
            ``target_models`` maps ``model`` to its target network.
        model: The online actor-critic model.
        _: Unused slot of the loss-function signature.
        train_batch: Batch supplying observations, next observations,
            actions, rewards, dones and prioritized-replay weights.

    Returns:
        ``(actor_loss, critic_loss)`` -- one loss per optimizer.
    """
    cfg = policy.config
    target_model = policy.target_models[model]

    twin_q = cfg["twin_q"]
    gamma = cfg["gamma"]
    n_step = cfg["n_step"]
    use_huber = cfg["use_huber"]
    huber_threshold = cfg["huber_threshold"]
    l2_reg = cfg["l2_reg"]

    obs_batch = SampleBatch(
        obs=train_batch[SampleBatch.CUR_OBS], _is_training=True)
    next_obs_batch = SampleBatch(
        obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True)

    # Forward passes. The online pass on s_{t+1} is not consumed below; it is
    # kept so the model receives the same sequence of forward calls.
    out_t, _ = model(obs_batch, [], None)
    online_out_tp1_unused, _ = model(next_obs_batch, [], None)
    target_out_tp1, _ = target_model(next_obs_batch, [], None)

    # Deterministic actions: online policy at s_t, target policy at s_{t+1}.
    pi_t = model.get_policy_output(out_t)
    pi_tp1 = target_model.get_policy_output(target_out_tp1)

    # Optional target-policy smoothing: clipped Gaussian noise is added to
    # the target action, and the result is clipped to the action bounds.
    if not cfg["smooth_target_policy"]:
        pi_tp1_target = pi_tp1
    else:
        noise_clip = cfg["target_noise_clip"]
        noise = torch.normal(
            mean=torch.zeros(pi_tp1.size()),
            std=cfg["target_noise"]).to(pi_tp1.device)
        noise = torch.clamp(noise, -noise_clip, noise_clip)
        shifted = pi_tp1 + noise
        low = torch.tensor(
            policy.action_space.low,
            dtype=torch.float32,
            device=pi_tp1.device)
        high = torch.tensor(
            policy.action_space.high,
            dtype=torch.float32,
            device=pi_tp1.device)
        pi_tp1_target = torch.min(torch.max(shifted, low), high)

    # Q(s_t, a_t) for replayed actions; the actor maximizes Q(s_t, pi(s_t)).
    batch_actions = train_batch[SampleBatch.ACTIONS]
    q_t = model.get_q_values(out_t, batch_actions)
    actor_loss = -torch.mean(model.get_q_values(out_t, pi_t))
    twin_q_t = (model.get_twin_q_values(out_t, batch_actions)
                if twin_q else None)

    # Target Q at s_{t+1} under the (possibly smoothed) target action.
    q_tp1 = target_model.get_q_values(target_out_tp1, pi_tp1_target)
    twin_q_tp1 = (target_model.get_twin_q_values(target_out_tp1,
                                                 pi_tp1_target)
                  if twin_q else None)

    # Both online heads are squeezed along q_t's last axis.
    last_axis = len(q_t.shape) - 1
    q_t_selected = torch.squeeze(q_t, dim=last_axis)
    if twin_q:
        twin_q_t_selected = torch.squeeze(twin_q_t, dim=last_axis)
        # Clipped double-Q: take the smaller of the two target estimates.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)
    q_tp1_best = torch.squeeze(q_tp1, dim=len(q_tp1.shape) - 1)

    # n-step Bellman target; terminal transitions drop the bootstrap term.
    not_done = 1.0 - train_batch[SampleBatch.DONES].float()
    q_t_selected_target = (
        train_batch[SampleBatch.REWARDS] +
        gamma**n_step * (not_done * q_tp1_best)).detach()

    td_error = q_t_selected - q_t_selected_target
    if twin_q:
        twin_td_error = twin_q_t_selected - q_t_selected_target
        if use_huber:
            errors = huber_loss(td_error, huber_threshold) + huber_loss(
                twin_td_error, huber_threshold)
        else:
            errors = 0.5 * (torch.pow(td_error, 2.0) +
                            torch.pow(twin_td_error, 2.0))
    elif use_huber:
        errors = huber_loss(td_error, huber_threshold)
    else:
        errors = 0.5 * torch.pow(td_error, 2.0)

    # Importance-sampling weights from prioritized replay scale each error.
    critic_loss = torch.mean(train_batch[PRIO_WEIGHTS] * errors)

    if l2_reg is not None:
        # L2 penalty on non-bias weights, actor and critic separately.
        for var_name, weight in model.policy_variables(as_dict=True).items():
            if "bias" in var_name:
                continue
            actor_loss += l2_reg * l2_loss(weight)
        for var_name, weight in model.q_variables(as_dict=True).items():
            if "bias" in var_name:
                continue
            critic_loss += l2_reg * l2_loss(weight)

    if cfg["use_state_preprocessor"]:
        # Give custom_loss() access to the rest of the transition.
        for key in (SampleBatch.ACTIONS, SampleBatch.REWARDS,
                    SampleBatch.DONES, SampleBatch.NEXT_OBS):
            obs_batch[key] = train_batch[key]
        actor_loss, critic_loss = model.custom_loss(
            [actor_loss, critic_loss], obs_batch)

    # Stats are stored per model (tower) so that multi-GPU loss passes do
    # not overwrite each other; td_error is kept per batch item.
    stats = model.tower_stats
    stats["q_t"] = q_t
    stats["actor_loss"] = actor_loss
    stats["critic_loss"] = critic_loss
    stats["td_error"] = td_error

    return actor_loss, critic_loss
def cql_loss(
    policy: Policy,
    model: ModelV2,
    dist_class: Type[TFActionDistribution],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Compute the CQL loss (SAC losses plus a conservative Q penalty) in TF.

    During the first ``bc_iters`` calls the actor is trained with a
    behavior-cloning term; afterwards it uses the regular SAC actor loss.

    Args:
        policy: The policy; ``policy.cur_iter`` is logged and incremented.
        model: The SAC/CQL model (policy, Q and optional twin-Q heads).
        dist_class: Unused here; the distribution class is re-derived via
            ``_get_dist_class``.
        train_batch: Batch with obs, next obs, actions, rewards and dones.

    Returns:
        A single tensor: actor + sum(critic losses) + alpha loss, plus the
        alpha-prime loss when ``config["lagrangian"]`` is set.
    """
    logger.info(f"Current iteration = {policy.cur_iter}")
    policy.cur_iter += 1

    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32)
    rewards = tf.cast(train_batch[SampleBatch.REWARDS], tf.float32)
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None)
    model_out_tp1, _ = model(
        SampleBatch(obs=next_obs, _is_training=True), [], None)
    target_model_out_tp1, _ = policy.target_model(
        SampleBatch(obs=next_obs, _is_training=True), [], None)

    action_dist_class = _get_dist_class(policy, policy.config,
                                        policy.action_space)
    action_dist_t = action_dist_class(
        model.get_policy_output(model_out_t), model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = tf.expand_dims(log_pis_t, -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -tf.reduce_mean(
        model.log_alpha * tf.stop_gradient(log_pis_t + model.target_entropy))

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = tf.math.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        min_q = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = tf.math.minimum(min_q, twin_q_)
        actor_loss = tf.reduce_mean(
            tf.stop_gradient(alpha) * log_pis_t - min_q)
    else:
        # Behavior cloning: maximize the log-likelihood of dataset actions.
        bc_logp = action_dist_t.logp(actions)
        actor_loss = tf.reduce_mean(
            tf.stop_gradient(alpha) * log_pis_t - bc_logp)

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss:
    action_dist_tp1 = action_dist_class(
        model.get_policy_output(model_out_tp1), model)
    policy_tp1, _ = action_dist_tp1.sample_logp()

    # Q-values for the batched actions.
    q_t = model.get_q_values(model_out_t, actions)
    q_t_selected = tf.squeeze(q_t, axis=-1)
    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t, actions)
        twin_q_t_selected = tf.squeeze(twin_q_t, axis=-1)

    # Target q network evaluation.
    q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1 = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = tf.math.minimum(q_tp1, twin_q_tp1)

    q_tp1_best = tf.squeeze(input=q_tp1, axis=-1)
    q_tp1_best_masked = (1.0 - tf.cast(terminals, tf.float32)) * q_tp1_best

    # compute RHS of bellman equation
    q_t_target = tf.stop_gradient(
        rewards + (discount**policy.config["n_step"]) * q_tp1_best_masked)

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = tf.math.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = tf.math.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = tf.keras.losses.MSE(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = tf.keras.losses.MSE(twin_q_t_selected, q_t_target)

    # CQL Loss (We are using Entropy version of CQL (the best version))
    rand_actions, _ = policy._random_action_generator.get_exploration_action(
        action_distribution=action_dist_class(
            tf.tile(action_dist_tp1.inputs, (num_actions, 1)), model),
        timestep=0,
        explore=True,
    )
    curr_actions, curr_logp = policy_actions_repeat(
        model, action_dist_class, model_out_t, num_actions)
    next_actions, next_logp = policy_actions_repeat(
        model, action_dist_class, model_out_tp1, num_actions)

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(
            model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(
            model, model_out_t, next_actions, twin=True)

    # log(0.5**d) is the log-density of a uniform distribution on [-1, 1]^d.
    # NOTE(review): assumes normalized action bounds -- confirm.
    random_density = np.log(0.5**int(curr_actions.shape[-1]))
    cat_q1 = tf.concat(
        [
            q1_rand - random_density,
            q1_next_actions - tf.stop_gradient(next_logp),
            q1_curr_actions - tf.stop_gradient(curr_logp),
        ],
        1,
    )
    if twin_q:
        cat_q2 = tf.concat(
            [
                q2_rand - random_density,
                q2_next_actions - tf.stop_gradient(next_logp),
                q2_curr_actions - tf.stop_gradient(curr_logp),
            ],
            1,
        )

    min_qf1_loss_ = (
        tf.reduce_mean(tf.reduce_logsumexp(cat_q1 / cql_temp, axis=1)) *
        min_q_weight * cql_temp)
    min_qf1_loss = min_qf1_loss_ - (tf.reduce_mean(q_t) * min_q_weight)
    if twin_q:
        min_qf2_loss_ = (
            tf.reduce_mean(tf.reduce_logsumexp(cat_q2 / cql_temp, axis=1)) *
            min_q_weight * cql_temp)
        min_qf2_loss = min_qf2_loss_ - (
            tf.reduce_mean(twin_q_t) * min_q_weight)

    if use_lagrange:
        # TF variables/tensors have no `.exp()` method (that is torch API);
        # use tf.math.exp so the Lagrangian branch actually runs.
        alpha_prime = tf.clip_by_value(
            tf.math.exp(model.log_alpha_prime), 0.0, 1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    # Save for stats function.
    policy.q_t = q_t_selected
    policy.policy_t = policy_t
    policy.log_pis_t = log_pis_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy
    # CQL Stats
    policy.cql_loss = cql_loss
    if use_lagrange:
        policy.log_alpha_prime_value = model.log_alpha_prime[0]
        policy.alpha_prime_value = alpha_prime
        policy.alpha_prime_loss = alpha_prime_loss

    # Return all loss terms corresponding to our optimizers.
    if use_lagrange:
        return actor_loss + tf.math.add_n(
            critic_loss) + alpha_loss + alpha_prime_loss
    return actor_loss + tf.math.add_n(critic_loss) + alpha_loss
def actor_critic_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the recurrent Soft Actor Critic.

    Recurrent state inputs (``state_in_0``, ``state_in_1``, ...) are read
    from ``train_batch``. Losses are averaged only over valid time steps:
    padding beyond ``seq_lens`` and the first ``burn_in`` steps are masked.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class
            (unused; the class is re-derived via ``_get_dist_class``).
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: Tuple of
        (actor_loss, critic_loss(es)..., alpha_loss).
    """
    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    # Collect state_in_0, state_in_1, ... until the first missing index.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches
    seq_lens = train_batch.get("seq_lens")

    model_out_t, state_in_t = model(
        {
            "obs": train_batch[SampleBatch.CUR_OBS],
            "prev_actions": train_batch[SampleBatch.PREV_ACTIONS],
            "prev_rewards": train_batch[SampleBatch.PREV_REWARDS],
            "is_training": True,
        }, state_batches, seq_lens)
    states_in_t = model.select_state(state_in_t, ["policy", "q", "twin_q"])

    model_out_tp1, state_in_tp1 = model(
        {
            "obs": train_batch[SampleBatch.NEXT_OBS],
            "prev_actions": train_batch[SampleBatch.ACTIONS],
            "prev_rewards": train_batch[SampleBatch.REWARDS],
            "is_training": True,
        }, state_batches, seq_lens)
    states_in_tp1 = model.select_state(state_in_tp1,
                                       ["policy", "q", "twin_q"])

    target_model_out_tp1, target_state_in_tp1 = policy.target_model(
        {
            "obs": train_batch[SampleBatch.NEXT_OBS],
            "prev_actions": train_batch[SampleBatch.ACTIONS],
            "prev_rewards": train_batch[SampleBatch.REWARDS],
            "is_training": True,
        }, state_batches, seq_lens)
    # Use the target model's own state output (previously the online model's
    # `state_in_tp1` was passed here, leaving `target_state_in_tp1` unused).
    target_states_in_tp1 = policy.target_model.select_state(
        target_state_in_tp1, ["policy", "q", "twin_q"])

    alpha = torch.exp(model.log_alpha)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        log_pis_t = F.log_softmax(
            model.get_policy_output(model_out_t, states_in_t["policy"],
                                    seq_lens)[0],
            dim=-1)
        policy_t = torch.exp(log_pis_t)
        log_pis_tp1 = F.log_softmax(
            model.get_policy_output(model_out_tp1, states_in_tp1["policy"],
                                    seq_lens)[0], -1)
        policy_tp1 = torch.exp(log_pis_tp1)
        # Q-values.
        q_t = model.get_q_values(model_out_t, states_in_t["q"], seq_lens)[0]
        # Target Q-values.
        q_tp1 = policy.target_model.get_q_values(
            target_model_out_tp1, target_states_in_tp1["q"], seq_lens)[0]
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens)[0]
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"],
                seq_lens)[0]
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(
            train_batch[SampleBatch.ACTIONS].long(),
            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        q_tp1_best_masked = \
            (1.0 - train_batch[SampleBatch.DONES].float()) * \
            q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_t = action_dist_class(
            model.get_policy_output(model_out_t, states_in_t["policy"],
                                    seq_lens)[0], policy.model)
        policy_t = action_dist_t.sample() if not deterministic else \
            action_dist_t.deterministic_sample()
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)
        action_dist_tp1 = action_dist_class(
            model.get_policy_output(model_out_tp1, states_in_tp1["policy"],
                                    seq_lens)[0], policy.model)
        policy_tp1 = action_dist_tp1.sample() if not deterministic else \
            action_dist_tp1.deterministic_sample()
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t = model.get_q_values(model_out_t, states_in_t["q"], seq_lens,
                                 train_batch[SampleBatch.ACTIONS])[0]
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens,
                train_batch[SampleBatch.ACTIONS])[0]

        # Q-values for current policy in given current state.
        q_t_det_policy = model.get_q_values(model_out_t, states_in_t["q"],
                                            seq_lens, policy_t)[0]
        if policy.config["twin_q"]:
            twin_q_t_det_policy = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens, policy_t)[0]
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target q network evaluation.
        q_tp1 = policy.target_model.get_q_values(
            target_model_out_tp1, target_states_in_tp1["q"], seq_lens,
            policy_tp1)[0]
        if policy.config["twin_q"]:
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"],
                seq_lens, policy_tp1)[0]
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        q_tp1_best_masked = \
            (1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # compute RHS of bellman equation
    q_t_selected_target = (
        train_batch[SampleBatch.REWARDS] +
        (policy.config["gamma"]**policy.config["n_step"]) *
        q_tp1_best_masked).detach()

    # BURNIN #
    # Rows are laid out as B sequences of length T; B comes from the first
    # state batch.
    B = state_batches[0].shape[0]
    T = q_t_selected.shape[0] // B
    seq_mask = sequence_mask(train_batch["seq_lens"], T)
    # Mask away also the burn-in sequence at the beginning.
    burn_in = policy.config["burn_in"]
    if burn_in > 0 and burn_in < T:
        seq_mask[:, :burn_in] = False

    seq_mask = seq_mask.reshape(-1)
    num_valid = torch.sum(seq_mask)

    def reduce_mean_valid(t):
        # Mean over the unmasked (valid, post-burn-in) time steps only.
        return torch.sum(t[seq_mask]) / num_valid

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [
        reduce_mean_valid(
            train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            reduce_mean_valid(
                train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = reduce_mean_valid(
            torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = reduce_mean_valid(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach()),
                dim=-1))
    else:
        alpha_loss = -reduce_mean_valid(
            model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c is depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = reduce_mean_valid(alpha.detach() * log_pis_t -
                                       q_t_det_policy)

    # Save for stats function.
    policy.q_t = q_t * seq_mask[..., None]
    policy.policy_t = policy_t * seq_mask[..., None]
    policy.log_pis_t = log_pis_t * seq_mask[..., None]

    # Store td-error in model, such that for multi-GPU, we do not override
    # them during the parallel loss phase. TD-error tensor in final stats
    # can then be concatenated and retrieved for each individual batch item.
    model.td_error = td_error * seq_mask

    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy

    # Return all loss terms corresponding to our optimizers.
    return tuple([policy.actor_loss] + policy.critic_loss +
                 [policy.alpha_loss])
def action_distribution_fn(
        policy: Policy,
        model: ModelV2,
        input_dict: ModelInputDict,
        *,
        state_batches: Optional[List[TensorType]] = None,
        seq_lens: Optional[TensorType] = None,
        prev_action_batch: Optional[TensorType] = None,
        prev_reward_batch=None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        is_training: Optional[bool] = None) -> \
        Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]]:
    """Compute action-distribution inputs and recurrent state outputs.

    The base model is run once. Its output feeds the policy head, which
    produces the distribution inputs. The Q head (and the twin-Q head, when
    ``model.twin_q_net`` is set) is also run so that its recurrent state
    output can be returned alongside the policy's.

    Args:
        policy (Policy): The Policy being queried for actions.
        model (TorchModelV2): The SAC model; must support
            `get_policy_output`, `get_q_values`, `get_twin_q_values` and
            `select_state`.
        input_dict (ModelInputDict): The input-dict for the model call.
        state_batches (Optional[List[TensorType]]): Internal state tensor
            batches.
        seq_lens (Optional[TensorType]): Sequence lengths used in RNNs.
        prev_action_batch (Optional[TensorType]): Unused.
        prev_reward_batch: Unused.
        explore (Optional[bool]): Unused.
        timestep (Optional[int]): Unused.
        is_training (Optional[bool]): Unused.

    Returns:
        Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]]:
            The distribution inputs, the distribution class, and the
            concatenated state outputs (policy, then q, then twin-q).
    """
    base_out, raw_states = model(input_dict, state_batches, seq_lens)
    # Split the incoming state list into per-head pieces.
    head_states = model.select_state(raw_states, ["policy", "q", "twin_q"])

    dist_inputs, policy_states_out = model.get_policy_output(
        base_out, head_states["policy"], seq_lens)
    _, q_states_out = model.get_q_values(base_out, head_states["q"],
                                         seq_lens)

    twin_q_states_out = []
    if model.twin_q_net:
        _, twin_q_states_out = model.get_twin_q_values(
            base_out, head_states["twin_q"], seq_lens)

    dist_class = _get_dist_class(policy, policy.config, policy.action_space)
    all_states_out = policy_states_out + q_states_out + twin_q_states_out
    return dist_inputs, dist_class, all_states_out
def cql_loss(policy: Policy, model: ModelV2,
             dist_class: Type[TorchDistributionWrapper],
             train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Compute the CQL loss (SAC losses plus a conservative Q penalty) in torch.

    This function also steps the optimizers itself, in this order: alpha,
    actor, critic 1, then critic 2 when ``twin_q`` is set. It only does so
    when the batch size equals ``config["train_batch_size"]``. During the
    first ``bc_iters`` calls the actor uses a behavior-cloning likelihood
    instead of the SAC actor loss.

    Args:
        policy: The policy; ``policy.cur_iter`` is logged and incremented and
            ``policy._optimizers`` is indexed in the same order as the
            returned loss tuple.
        model: The SAC/CQL model (policy, Q and optional twin-Q heads).
        dist_class: Unused here; the distribution class is re-derived via
            ``_get_dist_class``.
        train_batch: Batch with obs, next obs, actions, rewards and dones.

    Returns:
        Tuple (actor_loss, critic_loss_1[, critic_loss_2], alpha_loss
        [, alpha_prime_loss]) -- one entry per optimizer.
    """
    logger.info(f"Current iteration = {policy.cur_iter}")
    policy.cur_iter += 1

    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]
    # Only the first dimension's bounds are used for uniform random actions.
    action_low = model.action_space.low[0]
    action_high = model.action_space.high[0]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = train_batch[SampleBatch.ACTIONS]
    rewards = train_batch[SampleBatch.REWARDS]
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    # Optimizers follow the layout of the returned loss tuple:
    # [actor, critic_1, (critic_2 if twin_q), alpha, (alpha_prime)].
    # Indexing [2]/[3] unconditionally crashed (IndexError) without twin_q.
    num_critics = 2 if twin_q else 1
    policy_optimizer = policy._optimizers[0]
    critic1_optimizer = policy._optimizers[1]
    critic2_optimizer = policy._optimizers[2] if twin_q else None
    alpha_optimizer = policy._optimizers[1 + num_critics]

    # Only step the optimizers for real training batches.
    do_update = obs.shape[0] == policy.config["train_batch_size"]

    model_out_t, _ = model({
        "obs": obs,
        "is_training": True,
    }, [], None)

    model_out_tp1, _ = model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    target_model_out_tp1, _ = policy.target_model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    # NOTE(review): other loss functions in this chunk call _get_dist_class
    # with `policy` as the first argument; confirm which signature applies.
    action_dist_class = _get_dist_class(policy.config, policy.action_space)
    action_dist_t = action_dist_class(
        model.get_policy_output(model_out_t), policy.model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = torch.unsqueeze(log_pis_t, -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -(model.log_alpha *
                   (log_pis_t + model.target_entropy).detach()).mean()
    if do_update:
        alpha_optimizer.zero_grad()
        alpha_loss.backward()
        alpha_optimizer.step()

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = torch.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        min_q = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = torch.min(min_q, twin_q_)
        actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
    else:

        def bc_log(model, obs, actions):
            # Log-likelihood of dataset actions under a tanh-squashed
            # Gaussian built from the policy head's (mean, log_std) output.
            z = atanh(actions)
            logits = model.get_policy_output(obs)
            mean, log_std = torch.chunk(logits, 2, dim=-1)
            # Mean Clamping for Stability
            mean = torch.clamp(mean, MEAN_MIN, MEAN_MAX)
            log_std = torch.clamp(log_std, MIN_LOG_NN_OUTPUT,
                                  MAX_LOG_NN_OUTPUT)
            std = torch.exp(log_std)
            normal_dist = torch.distributions.Normal(mean, std)
            return torch.sum(
                normal_dist.log_prob(z) -
                torch.log(1 - actions * actions + SMALL_NUMBER),
                dim=-1)

        bc_logp = bc_log(model, model_out_t, actions)
        actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean()

    if do_update:
        policy_optimizer.zero_grad()
        actor_loss.backward(retain_graph=True)
        policy_optimizer.step()

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss:
    action_dist_tp1 = action_dist_class(
        model.get_policy_output(model_out_tp1), policy.model)
    policy_tp1, log_pis_tp1 = action_dist_tp1.sample_logp()

    log_pis_tp1 = torch.unsqueeze(log_pis_tp1, -1)
    # Q-values for the batched actions.
    q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
    q_t_selected = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t,
                                           train_batch[SampleBatch.ACTIONS])
        twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)

    # Target q network evaluation.
    q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1 = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
    q_tp1_best_masked = (1.0 - terminals.float()) * q_tp1_best

    # compute RHS of bellman equation
    q_t_target = (
        rewards +
        (discount**policy.config["n_step"]) * q_tp1_best_masked).detach()

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = torch.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = nn.functional.mse_loss(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = nn.functional.mse_loss(twin_q_t_selected, q_t_target)

    # CQL Loss (We are using Entropy version of CQL (the best version))
    rand_actions = convert_to_torch_tensor(
        torch.FloatTensor(actions.shape[0] * num_actions,
                          actions.shape[-1]).uniform_(action_low, action_high),
        policy.device)
    curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_t, num_actions)
    next_actions, next_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_tp1,
                                                    num_actions)

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(
            model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(
            model, model_out_t, next_actions, twin=True)

    # log(0.5**d) is the log-density of a uniform distribution on [-1, 1]^d.
    # NOTE(review): assumes normalized action bounds -- confirm.
    random_density = np.log(0.5**curr_actions.shape[-1])
    cat_q1 = torch.cat([
        q1_rand - random_density, q1_next_actions - next_logp.detach(),
        q1_curr_actions - curr_logp.detach()
    ], 1)
    if twin_q:
        cat_q2 = torch.cat([
            q2_rand - random_density, q2_next_actions - next_logp.detach(),
            q2_curr_actions - curr_logp.detach()
        ], 1)

    min_qf1_loss_ = torch.logsumexp(
        cat_q1 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
    min_qf1_loss = min_qf1_loss_ - (q_t.mean() * min_q_weight)
    if twin_q:
        min_qf2_loss_ = torch.logsumexp(
            cat_q2 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
        min_qf2_loss = min_qf2_loss_ - (twin_q_t.mean() * min_q_weight)

    if use_lagrange:
        alpha_prime = torch.clamp(
            model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_loss = [min_qf1_loss]
    if twin_q:
        cql_loss.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    if do_update:
        critic1_optimizer.zero_grad()
        critic_loss[0].backward(retain_graph=True)
        critic1_optimizer.step()

        # Only a twin-Q setup has a second critic loss/optimizer.
        if twin_q:
            critic2_optimizer.zero_grad()
            critic_loss[1].backward(retain_graph=False)
            critic2_optimizer.step()

    # Save for stats function.
    policy.q_t = q_t_selected
    policy.policy_t = policy_t
    policy.log_pis_t = log_pis_t
    model.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy
    # CQL Stats
    policy.cql_loss = cql_loss
    if use_lagrange:
        policy.log_alpha_prime_value = model.log_alpha_prime[0]
        policy.alpha_prime_value = alpha_prime
        policy.alpha_prime_loss = alpha_prime_loss

    # Return all loss terms corresponding to our optimizers.
    if use_lagrange:
        return tuple([policy.actor_loss] + policy.critic_loss +
                     [policy.alpha_loss] + [policy.alpha_prime_loss])
    return tuple([policy.actor_loss] + policy.critic_loss +
                 [policy.alpha_loss])