def appo_surrogate_loss(policy: Policy, model: ModelV2, dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch) -> TensorType: """Constructs the loss for APPO. With IS modifications and V-trace for Advantage Estimation. Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. dist_class (Type[ActionDistribution]: The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ model_out, _ = model.from_batch(train_batch) action_dist = dist_class(model_out, model) if isinstance(policy.action_space, gym.spaces.Discrete): is_multidiscrete = False output_hidden_shape = [policy.action_space.n] elif isinstance(policy.action_space, gym.spaces.multi_discrete.MultiDiscrete): is_multidiscrete = True output_hidden_shape = policy.action_space.nvec.astype(np.int32) else: is_multidiscrete = False output_hidden_shape = 1 def _make_time_major(*args, **kw): return make_time_major(policy, train_batch.get("seq_lens"), *args, **kw) actions = train_batch[SampleBatch.ACTIONS] dones = train_batch[SampleBatch.DONES] rewards = train_batch[SampleBatch.REWARDS] behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS] target_model_out, _ = policy.target_model.from_batch(train_batch) prev_action_dist = dist_class(behaviour_logits, policy.model) values = policy.model.value_function() values_time_major = _make_time_major(values) policy.model_vars = policy.model.variables() policy.target_model_vars = policy.target_model.variables() if policy.is_recurrent(): max_seq_len = torch.max(train_batch["seq_lens"]) - 1 mask = sequence_mask(train_batch["seq_lens"], max_seq_len) mask = torch.reshape(mask, [-1]) num_valid = torch.sum(mask) def reduce_mean_valid(t): return torch.sum(t * mask) / num_valid else: reduce_mean_valid = torch.mean if policy.config["vtrace"]: logger.debug("Using V-Trace surrogate loss (vtrace=True)") old_policy_behaviour_logits = target_model_out.detach() old_policy_action_dist = dist_class(old_policy_behaviour_logits, model) if isinstance(output_hidden_shape, (list, tuple, np.ndarray)): unpacked_behaviour_logits = torch.split(behaviour_logits, list(output_hidden_shape), dim=1) unpacked_old_policy_behaviour_logits = torch.split( old_policy_behaviour_logits, list(output_hidden_shape), dim=1) else: unpacked_behaviour_logits = torch.chunk(behaviour_logits, output_hidden_shape, dim=1) unpacked_old_policy_behaviour_logits = torch.chunk( old_policy_behaviour_logits, output_hidden_shape, dim=1) # Prepare actions for loss loss_actions = actions if is_multidiscrete else torch.unsqueeze( actions, dim=1) # Prepare KL for Loss action_kl = _make_time_major(old_policy_action_dist.kl(action_dist), drop_last=True) # Compute vtrace on the CPU for better perf. 
vtrace_returns = vtrace.multi_from_logits( behaviour_policy_logits=_make_time_major(unpacked_behaviour_logits, drop_last=True), target_policy_logits=_make_time_major( unpacked_old_policy_behaviour_logits, drop_last=True), actions=torch.unbind(_make_time_major(loss_actions, drop_last=True), dim=2), discounts=(1.0 - _make_time_major(dones, drop_last=True).float()) * policy.config["gamma"], rewards=_make_time_major(rewards, drop_last=True), values=values_time_major[:-1], # drop-last=True bootstrap_value=values_time_major[-1], dist_class=TorchCategorical if is_multidiscrete else dist_class, model=model, clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"], clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"] ) actions_logp = _make_time_major(action_dist.logp(actions), drop_last=True) prev_actions_logp = _make_time_major(prev_action_dist.logp(actions), drop_last=True) old_policy_actions_logp = _make_time_major( old_policy_action_dist.logp(actions), drop_last=True) is_ratio = torch.clamp( torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0) logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp) policy._is_ratio = is_ratio advantages = vtrace_returns.pg_advantages surrogate_loss = torch.min( advantages * logp_ratio, advantages * torch.clamp(logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) mean_kl = reduce_mean_valid(action_kl) mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. delta = values_time_major[:-1] - vtrace_returns.vs value_targets = vtrace_returns.vs mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0)) # The entropy loss. mean_entropy = reduce_mean_valid( _make_time_major(action_dist.entropy(), drop_last=True)) else: logger.debug("Using PPO surrogate loss (vtrace=False)") # Prepare KL for Loss action_kl = _make_time_major(prev_action_dist.kl(action_dist)) actions_logp = _make_time_major(action_dist.logp(actions)) prev_actions_logp = _make_time_major(prev_action_dist.logp(actions)) logp_ratio = torch.exp(actions_logp - prev_actions_logp) advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES]) surrogate_loss = torch.min( advantages * logp_ratio, advantages * torch.clamp(logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) mean_kl = reduce_mean_valid(action_kl) mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. value_targets = _make_time_major( train_batch[Postprocessing.VALUE_TARGETS]) delta = values_time_major - value_targets mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0)) # The entropy loss. mean_entropy = reduce_mean_valid( _make_time_major(action_dist.entropy())) # The summed weighted loss total_loss = mean_policy_loss + \ mean_vf_loss * policy.config["vf_loss_coeff"] - \ mean_entropy * policy.config["entropy_coeff"] # Optional additional KL Loss if policy.config["use_kl_loss"]: total_loss += policy.kl_coeff * mean_kl policy._total_loss = total_loss policy._mean_policy_loss = mean_policy_loss policy._mean_kl = mean_kl policy._mean_vf_loss = mean_vf_loss policy._mean_entropy = mean_entropy policy._value_targets = value_targets return total_loss
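# --- Illustrative sketch (not part of RLlib): the clipped-surrogate term used in the
# APPO loss above. A toy, standalone PyTorch version of the (optionally IS-corrected)
# probability ratio, clipped to [1 - eps, 1 + eps] and multiplied by the advantages.
# All tensor/parameter names here are hypothetical.

import torch


def clipped_surrogate_loss(actions_logp, prev_actions_logp, advantages,
                           clip_param=0.3, old_policy_logp=None, is_clip=2.0):
    """Toy clipped-surrogate loss (negative objective, mean over the batch)."""
    logp_ratio = torch.exp(actions_logp - prev_actions_logp)
    if old_policy_logp is not None:
        # APPO-style importance correction between the behaviour policy and the
        # (lagging) target policy, clipped for stability.
        is_ratio = torch.clamp(
            torch.exp(prev_actions_logp - old_policy_logp), 0.0, is_clip)
        logp_ratio = is_ratio * logp_ratio
    surrogate = torch.min(
        advantages * logp_ratio,
        advantages * torch.clamp(logp_ratio, 1 - clip_param, 1 + clip_param))
    return -surrogate.mean()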
def ppo_surrogate_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for Proximal Policy Objective.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    logits, state = model.from_batch(train_batch, is_training=True)
    curr_action_dist = dist_class(logits, model)

    # RNN case: Mask away 0-padded chunks at end of time axis.
    if state:
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask = sequence_mask(
            train_batch["seq_lens"],
            max_seq_len,
            time_major=model.is_time_major())
        mask = torch.reshape(mask, [-1])
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            return torch.sum(t[mask]) / num_valid

    # non-RNN case: No masking.
    else:
        mask = None
        reduce_mean_valid = torch.mean

    prev_action_dist = dist_class(
        train_batch[SampleBatch.ACTION_DIST_INPUTS], model)

    logp_ratio = torch.exp(
        curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
        train_batch[SampleBatch.ACTION_LOGP])
    action_kl = prev_action_dist.kl(curr_action_dist)
    mean_kl = reduce_mean_valid(action_kl)

    curr_entropy = curr_action_dist.entropy()
    mean_entropy = reduce_mean_valid(curr_entropy)

    surrogate_loss = torch.min(
        train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
        train_batch[Postprocessing.ADVANTAGES] * torch.clamp(
            logp_ratio, 1 - policy.config["clip_param"],
            1 + policy.config["clip_param"]))
    mean_policy_loss = reduce_mean_valid(-surrogate_loss)

    if policy.config["use_gae"]:
        prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
        value_fn_out = model.value_function()
        vf_loss1 = torch.pow(
            value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_clipped = prev_value_fn_out + torch.clamp(
            value_fn_out - prev_value_fn_out,
            -policy.config["vf_clip_param"], policy.config["vf_clip_param"])
        vf_loss2 = torch.pow(
            vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_loss = torch.max(vf_loss1, vf_loss2)
        mean_vf_loss = reduce_mean_valid(vf_loss)
        total_loss = reduce_mean_valid(
            -surrogate_loss + policy.kl_coeff * action_kl +
            policy.config["vf_loss_coeff"] * vf_loss -
            policy.entropy_coeff * curr_entropy)
    else:
        mean_vf_loss = 0.0
        total_loss = reduce_mean_valid(
            -surrogate_loss + policy.kl_coeff * action_kl -
            policy.entropy_coeff * curr_entropy)

    # Store stats in policy for stats_fn.
    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_vf_loss = mean_vf_loss
    policy._mean_entropy = mean_entropy
    policy._mean_kl = mean_kl

    return total_loss
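# --- Illustrative sketch (not part of RLlib): the clipped value-function loss used in
# the `use_gae` branch above. The VF update is clipped relative to the value prediction
# made at sample-collection time, analogous to the policy ratio clip. Names are
# hypothetical.

import torch


def clipped_vf_loss(value_fn_out, prev_value_fn_out, value_targets,
                    vf_clip_param=10.0):
    vf_loss1 = torch.pow(value_fn_out - value_targets, 2.0)
    vf_clipped = prev_value_fn_out + torch.clamp(
        value_fn_out - prev_value_fn_out, -vf_clip_param, vf_clip_param)
    vf_loss2 = torch.pow(vf_clipped - value_targets, 2.0)
    # Pessimistic (elementwise max) combination of clipped and unclipped losses.
    return torch.max(vf_loss1, vf_loss2).mean()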
def cql_loss( policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution], train_batch: SampleBatch, ) -> Union[TensorType, List[TensorType]]: logger.info(f"Current iteration = {policy.cur_iter}") policy.cur_iter += 1 # For best performance, turn deterministic off deterministic = policy.config["_deterministic_loss"] assert not deterministic twin_q = policy.config["twin_q"] discount = policy.config["gamma"] # CQL Parameters bc_iters = policy.config["bc_iters"] cql_temp = policy.config["temperature"] num_actions = policy.config["num_actions"] min_q_weight = policy.config["min_q_weight"] use_lagrange = policy.config["lagrangian"] target_action_gap = policy.config["lagrangian_thresh"] obs = train_batch[SampleBatch.CUR_OBS] actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32) rewards = tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) next_obs = train_batch[SampleBatch.NEXT_OBS] terminals = train_batch[SampleBatch.DONES] model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None) model_out_tp1, _ = model(SampleBatch(obs=next_obs, _is_training=True), [], None) target_model_out_tp1, _ = policy.target_model( SampleBatch(obs=next_obs, _is_training=True), [], None) action_dist_class = _get_dist_class(policy, policy.config, policy.action_space) action_dist_inputs_t, _ = model.get_action_model_outputs(model_out_t) action_dist_t = action_dist_class(action_dist_inputs_t, model) policy_t, log_pis_t = action_dist_t.sample_logp() log_pis_t = tf.expand_dims(log_pis_t, -1) # Unlike original SAC, Alpha and Actor Loss are computed first. # Alpha Loss alpha_loss = -tf.reduce_mean( model.log_alpha * tf.stop_gradient(log_pis_t + model.target_entropy)) # Policy Loss (Either Behavior Clone Loss or SAC Loss) alpha = tf.math.exp(model.log_alpha) if policy.cur_iter >= bc_iters: min_q, _ = model.get_q_values(model_out_t, policy_t) if twin_q: twin_q_, _ = model.get_twin_q_values(model_out_t, policy_t) min_q = tf.math.minimum(min_q, twin_q_) actor_loss = tf.reduce_mean( tf.stop_gradient(alpha) * log_pis_t - min_q) else: bc_logp = action_dist_t.logp(actions) actor_loss = tf.reduce_mean( tf.stop_gradient(alpha) * log_pis_t - bc_logp) # actor_loss = -tf.reduce_mean(bc_logp) # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss) # SAC Loss: # Q-values for the batched actions. action_dist_inputs_tp1, _ = model.get_action_model_outputs(model_out_tp1) action_dist_tp1 = action_dist_class(action_dist_inputs_tp1, model) policy_tp1, _ = action_dist_tp1.sample_logp() q_t, _ = model.get_q_values(model_out_t, actions) q_t_selected = tf.squeeze(q_t, axis=-1) if twin_q: twin_q_t, _ = model.get_twin_q_values(model_out_t, actions) twin_q_t_selected = tf.squeeze(twin_q_t, axis=-1) # Target q network evaluation. q_tp1, _ = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1) if twin_q: twin_q_tp1, _ = policy.target_model.get_twin_q_values( target_model_out_tp1, policy_tp1) # Take min over both twin-NNs. 
q_tp1 = tf.math.minimum(q_tp1, twin_q_tp1) q_tp1_best = tf.squeeze(input=q_tp1, axis=-1) q_tp1_best_masked = (1.0 - tf.cast(terminals, tf.float32)) * q_tp1_best # compute RHS of bellman equation q_t_target = tf.stop_gradient(rewards + (discount**policy.config["n_step"]) * q_tp1_best_masked) # Compute the TD-error (potentially clipped), for priority replay buffer base_td_error = tf.math.abs(q_t_selected - q_t_target) if twin_q: twin_td_error = tf.math.abs(twin_q_t_selected - q_t_target) td_error = 0.5 * (base_td_error + twin_td_error) else: td_error = base_td_error critic_loss_1 = tf.keras.losses.MSE(q_t_selected, q_t_target) if twin_q: critic_loss_2 = tf.keras.losses.MSE(twin_q_t_selected, q_t_target) # CQL Loss (We are using Entropy version of CQL (the best version)) rand_actions, _ = policy._random_action_generator.get_exploration_action( action_distribution=action_dist_class( tf.tile(action_dist_tp1.inputs, (num_actions, 1)), model), timestep=0, explore=True, ) curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class, model_out_t, num_actions) next_actions, next_logp = policy_actions_repeat(model, action_dist_class, model_out_tp1, num_actions) q1_rand = q_values_repeat(model, model_out_t, rand_actions) q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions) q1_next_actions = q_values_repeat(model, model_out_t, next_actions) if twin_q: q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True) q2_curr_actions = q_values_repeat(model, model_out_t, curr_actions, twin=True) q2_next_actions = q_values_repeat(model, model_out_t, next_actions, twin=True) random_density = np.log(0.5**int(curr_actions.shape[-1])) cat_q1 = tf.concat( [ q1_rand - random_density, q1_next_actions - tf.stop_gradient(next_logp), q1_curr_actions - tf.stop_gradient(curr_logp), ], 1, ) if twin_q: cat_q2 = tf.concat( [ q2_rand - random_density, q2_next_actions - tf.stop_gradient(next_logp), q2_curr_actions - tf.stop_gradient(curr_logp), ], 1, ) min_qf1_loss_ = ( tf.reduce_mean(tf.reduce_logsumexp(cat_q1 / cql_temp, axis=1)) * min_q_weight * cql_temp) min_qf1_loss = min_qf1_loss_ - (tf.reduce_mean(q_t) * min_q_weight) if twin_q: min_qf2_loss_ = ( tf.reduce_mean(tf.reduce_logsumexp(cat_q2 / cql_temp, axis=1)) * min_q_weight * cql_temp) min_qf2_loss = min_qf2_loss_ - (tf.reduce_mean(twin_q_t) * min_q_weight) if use_lagrange: alpha_prime = tf.clip_by_value(model.log_alpha_prime.exp(), 0.0, 1000000.0)[0] min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap) if twin_q: min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap) alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss) else: alpha_prime_loss = -min_qf1_loss cql_loss = [min_qf1_loss] if twin_q: cql_loss.append(min_qf2_loss) critic_loss = [critic_loss_1 + min_qf1_loss] if twin_q: critic_loss.append(critic_loss_2 + min_qf2_loss) # Save for stats function. policy.q_t = q_t_selected policy.policy_t = policy_t policy.log_pis_t = log_pis_t policy.td_error = td_error policy.actor_loss = actor_loss policy.critic_loss = critic_loss policy.alpha_loss = alpha_loss policy.log_alpha_value = model.log_alpha policy.alpha_value = alpha policy.target_entropy = model.target_entropy # CQL Stats policy.cql_loss = cql_loss if use_lagrange: policy.log_alpha_prime_value = model.log_alpha_prime[0] policy.alpha_prime_value = alpha_prime policy.alpha_prime_loss = alpha_prime_loss # Return all loss terms corresponding to our optimizers. 
    if use_lagrange:
        return actor_loss + tf.math.add_n(
            critic_loss) + alpha_loss + alpha_prime_loss

    return actor_loss + tf.math.add_n(critic_loss) + alpha_loss
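# --- Illustrative sketch (not part of RLlib): the conservative (CQL) regularizer added
# to the critic losses above. It penalizes the gap between a log-sum-exp over Q-values
# of sampled actions and the mean Q of the dataset actions. A PyTorch toy version with
# hypothetical tensor names:

import torch


def cql_penalty(cat_q, q_data, min_q_weight=5.0, temperature=1.0):
    """cat_q: [B, N] Q-values of sampled (random/current/next-policy) actions,
    already corrected by their log-densities; q_data: [B, 1] Q of dataset actions."""
    gap = torch.logsumexp(cat_q / temperature, dim=1).mean() * temperature
    return min_q_weight * (gap - q_data.mean())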
def context(self):
    """Returns a contextmanager for the current TF graph."""
    if self.graph:
        return self.graph.as_default()
    else:
        return ModelV2.context(self)
def actor_critic_loss( policy: Policy, model: ModelV2, dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: """Constructs the loss for the Soft Actor Critic. Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. dist_class (Type[TorchDistributionWrapper]: The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ # Look up the target model (tower) using the model tower. target_model = policy.target_models[model] # Should be True only for debugging purposes (e.g. test cases)! deterministic = policy.config["_deterministic_loss"] model_out_t, _ = model( { "obs": train_batch[SampleBatch.CUR_OBS], "is_training": True, }, [], None) model_out_tp1, _ = model( { "obs": train_batch[SampleBatch.NEXT_OBS], "is_training": True, }, [], None) target_model_out_tp1, _ = target_model( { "obs": train_batch[SampleBatch.NEXT_OBS], "is_training": True, }, [], None) alpha = torch.exp(model.log_alpha) # Discrete case. if model.discrete: # Get all action probs directly from pi and form their logp. log_pis_t = F.log_softmax(model.get_policy_output(model_out_t), dim=-1) policy_t = torch.exp(log_pis_t) log_pis_tp1 = F.log_softmax(model.get_policy_output(model_out_tp1), -1) policy_tp1 = torch.exp(log_pis_tp1) # Q-values. q_t = model.get_q_values(model_out_t) # Target Q-values. q_tp1 = target_model.get_q_values(target_model_out_tp1) if policy.config["twin_q"]: twin_q_t = model.get_twin_q_values(model_out_t) twin_q_tp1 = target_model.get_twin_q_values(target_model_out_tp1) q_tp1 = torch.min(q_tp1, twin_q_tp1) q_tp1 -= alpha * log_pis_tp1 # Actually selected Q-values (from the actions batch). one_hot = F.one_hot(train_batch[SampleBatch.ACTIONS].long(), num_classes=q_t.size()[-1]) q_t_selected = torch.sum(q_t * one_hot, dim=-1) if policy.config["twin_q"]: twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1) # Discrete case: "Best" means weighted by the policy (prob) outputs. q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1) q_tp1_best_masked = \ (1.0 - train_batch[SampleBatch.DONES].float()) * \ q_tp1_best # Continuous actions case. else: # Sample single actions from distribution. action_dist_class = _get_dist_class(policy, policy.config, policy.action_space) action_dist_t = action_dist_class(model.get_policy_output(model_out_t), model) policy_t = action_dist_t.sample() if not deterministic else \ action_dist_t.deterministic_sample() log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1) action_dist_tp1 = action_dist_class( model.get_policy_output(model_out_tp1), model) policy_tp1 = action_dist_tp1.sample() if not deterministic else \ action_dist_tp1.deterministic_sample() log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1) # Q-values for the actually selected actions. q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) if policy.config["twin_q"]: twin_q_t = model.get_twin_q_values( model_out_t, train_batch[SampleBatch.ACTIONS]) # Q-values for current policy in given current state. q_t_det_policy = model.get_q_values(model_out_t, policy_t) if policy.config["twin_q"]: twin_q_t_det_policy = model.get_twin_q_values( model_out_t, policy_t) q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy) # Target q network evaluation. 
    q_tp1 = target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if policy.config["twin_q"]:
        twin_q_tp1 = target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_t_selected = torch.squeeze(q_t, dim=-1)
    if policy.config["twin_q"]:
        twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
    q_tp1 -= alpha * log_pis_tp1

    q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
    q_tp1_best_masked = (1.0 - train_batch[SampleBatch.DONES].float()) * \
        q_tp1_best

    # compute RHS of bellman equation
    q_t_selected_target = (
        train_batch[SampleBatch.REWARDS] +
        (policy.config["gamma"]**policy.config["n_step"]) * q_tp1_best_masked
    ).detach()

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [
        torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = torch.mean(torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = torch.mean(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach()),
                dim=-1))
    else:
        alpha_loss = -torch.mean(
            model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c it depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = torch.mean(alpha.detach() * log_pis_t - q_t_det_policy)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t
    model.tower_stats["policy_t"] = policy_t
    model.tower_stats["log_pis_t"] = log_pis_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    model.tower_stats["alpha_loss"] = alpha_loss

    # TD-error tensor in final stats
    # will be concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    # Return all loss terms corresponding to our optimizers.
    return tuple([actor_loss] + critic_loss + [alpha_loss])
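# --- Illustrative sketch (not part of RLlib): the soft Bellman target regressed on in
# the SAC critic loss above, i.e. r + gamma^n * (1 - done) * (min(Q1', Q2') - alpha * logp(a'|s')),
# with a' sampled from the current policy at s'. Hypothetical names, PyTorch only.

import torch


def soft_bellman_target(rewards, dones, q_tp1, twin_q_tp1, log_pis_tp1,
                        alpha=0.2, gamma=0.99, n_step=1):
    q_min = torch.min(q_tp1, twin_q_tp1) - alpha * log_pis_tp1
    masked = (1.0 - dones.float()) * q_min
    # Detach: the target must not propagate gradients into the target networks.
    return (rewards + gamma ** n_step * masked).detach()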
def cql_loss(policy: Policy, model: ModelV2, dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: logger.info(f"Current iteration = {policy.cur_iter}") policy.cur_iter += 1 # Look up the target model (tower) using the model tower. target_model = policy.target_models[model] # For best performance, turn deterministic off deterministic = policy.config["_deterministic_loss"] assert not deterministic twin_q = policy.config["twin_q"] discount = policy.config["gamma"] action_low = model.action_space.low[0] action_high = model.action_space.high[0] # CQL Parameters bc_iters = policy.config["bc_iters"] cql_temp = policy.config["temperature"] num_actions = policy.config["num_actions"] min_q_weight = policy.config["min_q_weight"] use_lagrange = policy.config["lagrangian"] target_action_gap = policy.config["lagrangian_thresh"] obs = train_batch[SampleBatch.CUR_OBS] actions = train_batch[SampleBatch.ACTIONS] rewards = train_batch[SampleBatch.REWARDS].float() next_obs = train_batch[SampleBatch.NEXT_OBS] terminals = train_batch[SampleBatch.DONES] model_out_t, _ = model({ "obs": obs, "is_training": True, }, [], None) model_out_tp1, _ = model({ "obs": next_obs, "is_training": True, }, [], None) target_model_out_tp1, _ = target_model({ "obs": next_obs, "is_training": True, }, [], None) action_dist_class = _get_dist_class(policy, policy.config, policy.action_space) action_dist_t = action_dist_class( model.get_policy_output(model_out_t), policy.model) policy_t, log_pis_t = action_dist_t.sample_logp() log_pis_t = torch.unsqueeze(log_pis_t, -1) # Unlike original SAC, Alpha and Actor Loss are computed first. # Alpha Loss alpha_loss = -(model.log_alpha * (log_pis_t + model.target_entropy).detach()).mean() if obs.shape[0] == policy.config["train_batch_size"]: policy.alpha_optim.zero_grad() alpha_loss.backward() policy.alpha_optim.step() # Policy Loss (Either Behavior Clone Loss or SAC Loss) alpha = torch.exp(model.log_alpha) if policy.cur_iter >= bc_iters: min_q = model.get_q_values(model_out_t, policy_t) if twin_q: twin_q_ = model.get_twin_q_values(model_out_t, policy_t) min_q = torch.min(min_q, twin_q_) actor_loss = (alpha.detach() * log_pis_t - min_q).mean() else: bc_logp = action_dist_t.logp(actions) actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean() # actor_loss = -bc_logp.mean() if obs.shape[0] == policy.config["train_batch_size"]: policy.actor_optim.zero_grad() actor_loss.backward(retain_graph=True) policy.actor_optim.step() # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss) # SAC Loss: # Q-values for the batched actions. action_dist_tp1 = action_dist_class( model.get_policy_output(model_out_tp1), policy.model) policy_tp1, _ = action_dist_tp1.sample_logp() q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) q_t_selected = torch.squeeze(q_t, dim=-1) if twin_q: twin_q_t = model.get_twin_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1) # Target q network evaluation. q_tp1 = target_model.get_q_values(target_model_out_tp1, policy_tp1) if twin_q: twin_q_tp1 = target_model.get_twin_q_values(target_model_out_tp1, policy_tp1) # Take min over both twin-NNs. 
q_tp1 = torch.min(q_tp1, twin_q_tp1) q_tp1_best = torch.squeeze(input=q_tp1, dim=-1) q_tp1_best_masked = (1.0 - terminals.float()) * q_tp1_best # compute RHS of bellman equation q_t_target = ( rewards + (discount**policy.config["n_step"]) * q_tp1_best_masked).detach() # Compute the TD-error (potentially clipped), for priority replay buffer base_td_error = torch.abs(q_t_selected - q_t_target) if twin_q: twin_td_error = torch.abs(twin_q_t_selected - q_t_target) td_error = 0.5 * (base_td_error + twin_td_error) else: td_error = base_td_error critic_loss_1 = nn.functional.mse_loss(q_t_selected, q_t_target) if twin_q: critic_loss_2 = nn.functional.mse_loss(twin_q_t_selected, q_t_target) # CQL Loss (We are using Entropy version of CQL (the best version)) rand_actions = convert_to_torch_tensor( torch.FloatTensor(actions.shape[0] * num_actions, actions.shape[-1]).uniform_(action_low, action_high), policy.device) curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class, model_out_t, num_actions) next_actions, next_logp = policy_actions_repeat(model, action_dist_class, model_out_tp1, num_actions) q1_rand = q_values_repeat(model, model_out_t, rand_actions) q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions) q1_next_actions = q_values_repeat(model, model_out_t, next_actions) if twin_q: q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True) q2_curr_actions = q_values_repeat( model, model_out_t, curr_actions, twin=True) q2_next_actions = q_values_repeat( model, model_out_t, next_actions, twin=True) random_density = np.log(0.5**curr_actions.shape[-1]) cat_q1 = torch.cat([ q1_rand - random_density, q1_next_actions - next_logp.detach(), q1_curr_actions - curr_logp.detach() ], 1) if twin_q: cat_q2 = torch.cat([ q2_rand - random_density, q2_next_actions - next_logp.detach(), q2_curr_actions - curr_logp.detach() ], 1) min_qf1_loss_ = torch.logsumexp( cat_q1 / cql_temp, dim=1).mean() * min_q_weight * cql_temp min_qf1_loss = min_qf1_loss_ - (q_t.mean() * min_q_weight) if twin_q: min_qf2_loss_ = torch.logsumexp( cat_q2 / cql_temp, dim=1).mean() * min_q_weight * cql_temp min_qf2_loss = min_qf2_loss_ - (twin_q_t.mean() * min_q_weight) if use_lagrange: alpha_prime = torch.clamp( model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[0] min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap) if twin_q: min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap) alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss) else: alpha_prime_loss = -min_qf1_loss cql_loss = [min_qf1_loss] if twin_q: cql_loss.append(min_qf2_loss) critic_loss = [critic_loss_1 + min_qf1_loss] if twin_q: critic_loss.append(critic_loss_2 + min_qf2_loss) if obs.shape[0] == policy.config["train_batch_size"]: policy.critic_optims[0].zero_grad() critic_loss[0].backward(retain_graph=True) policy.critic_optims[0].step() if twin_q: policy.critic_optims[1].zero_grad() critic_loss[1].backward(retain_graph=False) policy.critic_optims[1].step() # Save for stats function. policy.q_t = q_t_selected policy.policy_t = policy_t policy.log_pis_t = log_pis_t model.td_error = td_error policy.actor_loss = actor_loss policy.critic_loss = critic_loss policy.alpha_loss = alpha_loss policy.log_alpha_value = model.log_alpha policy.alpha_value = alpha policy.target_entropy = model.target_entropy # CQL Stats. 
    policy.cql_loss = cql_loss
    if use_lagrange:
        policy.log_alpha_prime_value = model.log_alpha_prime[0]
        policy.alpha_prime_value = alpha_prime
        policy.alpha_prime_loss = alpha_prime_loss

        if obs.shape[0] == policy.config["train_batch_size"]:
            policy.alpha_prime_optim.zero_grad()
            alpha_prime_loss.backward()
            policy.alpha_prime_optim.step()

    # Return all loss terms corresponding to our optimizers.
    if use_lagrange:
        return tuple([policy.actor_loss] + policy.critic_loss +
                     [policy.alpha_loss] + [policy.alpha_prime_loss])
    return tuple([policy.actor_loss] + policy.critic_loss +
                 [policy.alpha_loss])
def loss( self, model: ModelV2, dist_class: Type[ActionDistribution], train_batch: SampleBatch, ) -> Union[TensorType, List[TensorType]]: """Compute loss for Proximal Policy Objective. Args: model: The Model to calculate the loss for. dist_class: The action distr. class. train_batch: The training data. Returns: The PPO loss tensor given the input batch. """ logits, state = model(train_batch) curr_action_dist = dist_class(logits, model) # RNN case: Mask away 0-padded chunks at end of time axis. if state: B = len(train_batch[SampleBatch.SEQ_LENS]) max_seq_len = logits.shape[0] // B mask = sequence_mask( train_batch[SampleBatch.SEQ_LENS], max_seq_len, time_major=model.is_time_major(), ) mask = torch.reshape(mask, [-1]) num_valid = torch.sum(mask) def reduce_mean_valid(t): return torch.sum(t[mask]) / num_valid # non-RNN case: No masking. else: mask = None reduce_mean_valid = torch.mean prev_action_dist = dist_class( train_batch[SampleBatch.ACTION_DIST_INPUTS], model) logp_ratio = torch.exp( curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) - train_batch[SampleBatch.ACTION_LOGP]) # Only calculate kl loss if necessary (kl-coeff > 0.0). if self.config["kl_coeff"] > 0.0: action_kl = prev_action_dist.kl(curr_action_dist) mean_kl_loss = reduce_mean_valid(action_kl) else: mean_kl_loss = torch.tensor(0.0, device=logp_ratio.device) curr_entropy = curr_action_dist.entropy() mean_entropy = reduce_mean_valid(curr_entropy) surrogate_loss = torch.min( train_batch[Postprocessing.ADVANTAGES] * logp_ratio, train_batch[Postprocessing.ADVANTAGES] * torch.clamp(logp_ratio, 1 - self.config["clip_param"], 1 + self.config["clip_param"]), ) mean_policy_loss = reduce_mean_valid(-surrogate_loss) # Compute a value function loss. if self.config["use_critic"]: value_fn_out = model.value_function() vf_loss = torch.pow( value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0) vf_loss_clipped = torch.clamp(vf_loss, 0, self.config["vf_clip_param"]) mean_vf_loss = reduce_mean_valid(vf_loss_clipped) # Ignore the value function. else: value_fn_out = 0 vf_loss_clipped = mean_vf_loss = 0.0 total_loss = reduce_mean_valid(-surrogate_loss + self.config["vf_loss_coeff"] * vf_loss_clipped - self.entropy_coeff * curr_entropy) # Add mean_kl_loss (already processed through `reduce_mean_valid`), # if necessary. if self.config["kl_coeff"] > 0.0: total_loss += self.kl_coeff * mean_kl_loss # Store values for stats function in model (tower), such that for # multi-GPU, we do not override them during the parallel loss phase. model.tower_stats["total_loss"] = total_loss model.tower_stats["mean_policy_loss"] = mean_policy_loss model.tower_stats["mean_vf_loss"] = mean_vf_loss model.tower_stats["vf_explained_var"] = explained_variance( train_batch[Postprocessing.VALUE_TARGETS], value_fn_out) model.tower_stats["mean_entropy"] = mean_entropy model.tower_stats["mean_kl_loss"] = mean_kl_loss return total_loss
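# --- Illustrative sketch (not part of RLlib): the masked mean used by the RNN branch
# above. For recurrent models the batch is zero-padded along time, so a plain mean would
# dilute the loss; the reduction averages only over valid timesteps. A standalone toy
# version with hypothetical names:

import torch


def masked_mean(t, seq_lens, max_seq_len):
    """t: flat [B * T] loss tensor; seq_lens: [B] valid lengths per sequence."""
    time_range = torch.arange(max_seq_len, device=seq_lens.device).unsqueeze(0)
    # mask: [B, T] -> flattened to [B * T] to match the loss tensor.
    mask = (time_range < seq_lens.unsqueeze(1)).reshape(-1)
    return torch.sum(t[mask]) / torch.sum(mask)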
def loss(
    self,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss function.

    Args:
        model: The Model to calculate the loss for.
        dist_class: The action distr. class.
        train_batch: The training data.

    Returns:
        The A3C loss tensor given the input batch.
    """
    logits, _ = model(train_batch)
    values = model.value_function()

    if self.is_recurrent():
        B = len(train_batch[SampleBatch.SEQ_LENS])
        max_seq_len = logits.shape[0] // B
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                  max_seq_len)
        valid_mask = torch.reshape(mask_orig, [-1])
    else:
        valid_mask = torch.ones_like(values, dtype=torch.bool)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
    pi_err = -torch.sum(
        torch.masked_select(
            log_probs * train_batch[Postprocessing.ADVANTAGES], valid_mask))

    # Compute a value function loss.
    if self.config["use_critic"]:
        value_err = 0.5 * torch.sum(
            torch.pow(
                torch.masked_select(
                    values.reshape(-1) -
                    train_batch[Postprocessing.VALUE_TARGETS],
                    valid_mask,
                ),
                2.0,
            ))
    # Ignore the value function.
    else:
        value_err = 0.0

    entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

    total_loss = (pi_err + value_err * self.config["vf_loss_coeff"] -
                  entropy * self.entropy_coeff)

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["entropy"] = entropy
    model.tower_stats["pi_err"] = pi_err
    model.tower_stats["value_err"] = value_err

    return total_loss
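# --- Illustrative sketch (not part of RLlib): the policy-gradient term (`pi_err`) of the
# A3C loss above, i.e. a REINFORCE-with-baseline term -sum(logp(a|s) * advantage),
# optionally restricted to valid (non-padded) timesteps. Hypothetical names:

import torch


def a3c_policy_loss(log_probs, advantages, valid_mask=None):
    term = log_probs * advantages
    if valid_mask is not None:
        term = torch.masked_select(term, valid_mask)
    return -torch.sum(term)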
def loss( self, model: ModelV2, dist_class: Type[ActionDistribution], train_batch: SampleBatch, ) -> Union[TensorType, List[TensorType]]: model_out, _ = model(train_batch) action_dist = dist_class(model_out, model) if isinstance(self.action_space, gym.spaces.Discrete): is_multidiscrete = False output_hidden_shape = [self.action_space.n] elif isinstance(self.action_space, gym.spaces.MultiDiscrete): is_multidiscrete = True output_hidden_shape = self.action_space.nvec.astype(np.int32) else: is_multidiscrete = False output_hidden_shape = 1 def _make_time_major(*args, **kw): return make_time_major(self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kw) actions = train_batch[SampleBatch.ACTIONS] dones = train_batch[SampleBatch.DONES] rewards = train_batch[SampleBatch.REWARDS] behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP] behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS] if isinstance(output_hidden_shape, (list, tuple, np.ndarray)): unpacked_behaviour_logits = torch.split(behaviour_logits, list(output_hidden_shape), dim=1) unpacked_outputs = torch.split(model_out, list(output_hidden_shape), dim=1) else: unpacked_behaviour_logits = torch.chunk(behaviour_logits, output_hidden_shape, dim=1) unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1) values = model.value_function() if self.is_recurrent(): max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS]) mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) mask = torch.reshape(mask_orig, [-1]) else: mask = torch.ones_like(rewards) # Prepare actions for loss. loss_actions = actions if is_multidiscrete else torch.unsqueeze( actions, dim=1) # Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc. drop_last = self.config["vtrace_drop_last_ts"] loss = VTraceLoss( actions=_make_time_major(loss_actions, drop_last=drop_last), actions_logp=_make_time_major(action_dist.logp(actions), drop_last=drop_last), actions_entropy=_make_time_major(action_dist.entropy(), drop_last=drop_last), dones=_make_time_major(dones, drop_last=drop_last), behaviour_action_logp=_make_time_major(behaviour_action_logp, drop_last=drop_last), behaviour_logits=_make_time_major(unpacked_behaviour_logits, drop_last=drop_last), target_logits=_make_time_major(unpacked_outputs, drop_last=drop_last), discount=self.config["gamma"], rewards=_make_time_major(rewards, drop_last=drop_last), values=_make_time_major(values, drop_last=drop_last), bootstrap_value=_make_time_major(values)[-1], dist_class=TorchCategorical if is_multidiscrete else dist_class, model=model, valid_mask=_make_time_major(mask, drop_last=drop_last), config=self.config, vf_loss_coeff=self.config["vf_loss_coeff"], entropy_coeff=self.entropy_coeff, clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"], ) # Store values for stats function in model (tower), such that for # multi-GPU, we do not override them during the parallel loss phase. model.tower_stats["pi_loss"] = loss.pi_loss model.tower_stats["vf_loss"] = loss.vf_loss model.tower_stats["entropy"] = loss.entropy model.tower_stats["mean_entropy"] = loss.mean_entropy model.tower_stats["total_loss"] = loss.total_loss values_batched = make_time_major( self, train_batch.get(SampleBatch.SEQ_LENS), values, drop_last=self.config["vtrace"] and drop_last, ) model.tower_stats["vf_explained_var"] = explained_variance( torch.reshape(loss.value_targets, [-1]), torch.reshape(values_batched, [-1])) return loss.total_loss
def build_slateq_losses( policy: Policy, model: ModelV2, _: Type[TorchDistributionWrapper], train_batch: SampleBatch, ) -> TensorType: """Constructs the choice- and Q-value losses for the SlateQTorchPolicy. Args: policy: The Policy to calculate the loss for. model: The Model to calculate the loss for. train_batch: The training data. Returns: Tuple consisting of 1) the choice loss- and 2) the Q-value loss tensors. """ start = time.time() obs = restore_original_dimensions(train_batch[SampleBatch.OBS], policy.observation_space, tensorlib=torch) # user.shape: [batch_size, embedding_size] user = obs["user"] # doc.shape: [batch_size, num_docs, embedding_size] doc = torch.cat([val.unsqueeze(1) for val in obs["doc"].values()], 1) # action.shape: [batch_size, slate_size] actions = train_batch[SampleBatch.ACTIONS] next_obs = restore_original_dimensions(train_batch[SampleBatch.NEXT_OBS], policy.observation_space, tensorlib=torch) # Step 1: Build user choice model loss _, _, embedding_size = doc.shape # selected_doc.shape: [batch_size, slate_size, embedding_size] selected_doc = torch.gather( # input.shape: [batch_size, num_docs, embedding_size] input=doc, dim=1, # index.shape: [batch_size, slate_size, embedding_size] index=actions.unsqueeze(2).expand(-1, -1, embedding_size).long(), ) scores = model.choice_model(user, selected_doc) choice_loss_fn = nn.CrossEntropyLoss() # clicks.shape: [batch_size, slate_size] clicks = torch.stack( [resp["click"][:, 1] for resp in next_obs["response"]], dim=1) no_clicks = 1 - torch.sum(clicks, 1, keepdim=True) # clicks.shape: [batch_size, slate_size+1] targets = torch.cat([clicks, no_clicks], dim=1) choice_loss = choice_loss_fn(scores, torch.argmax(targets, dim=1)) # print(model.choice_model.a.item(), model.choice_model.b.item()) # Step 2: Build qvalue loss # Fields in available in train_batch: ['t', 'eps_id', 'agent_index', # 'next_actions', 'obs', 'actions', 'rewards', 'prev_actions', # 'prev_rewards', 'dones', 'infos', 'new_obs', 'unroll_id', 'weights', # 'batch_indexes'] learning_strategy = policy.config["slateq_strategy"] # Myopic agent: Don't care about value of next state. # Acts only based off immediate reward. if learning_strategy == "MYOP": next_q_values = torch.tensor(0.0, requires_grad=False) # Q-learning: Default setting for SlateQ -> Use DQN-style loss function. elif learning_strategy == "QL": # next_doc.shape: [batch_size, num_docs, embedding_size] next_doc = torch.cat( [val.unsqueeze(1) for val in next_obs["doc"].values()], 1) next_user = next_obs["user"] dones = train_batch[SampleBatch.DONES] with torch.no_grad(): if policy.config["double_q"]: next_target_per_slate_q_values = policy.target_models[ model].get_per_slate_q_values(next_user, next_doc) _, next_q_values, _ = model.choose_slate( next_user, next_doc, next_target_per_slate_q_values) else: _, next_q_values, _ = policy.target_models[model].choose_slate( next_user, next_doc) next_q_values = next_q_values.detach() next_q_values[dones.bool()] = 0.0 # SARS'A': Use on-policy sarsa loss. 
elif learning_strategy == "SARSA": # next_doc.shape: [batch_size, num_docs, embedding_size] next_doc = torch.cat( [val.unsqueeze(1) for val in next_obs["doc"].values()], 1) next_actions = train_batch["next_actions"] _, _, embedding_size = next_doc.shape # selected_doc.shape: [batch_size, slate_size, embedding_size] next_selected_doc = torch.gather( # input.shape: [batch_size, num_docs, embedding_size] input=next_doc, dim=1, # index.shape: [batch_size, slate_size, embedding_size] index=next_actions.unsqueeze(2).expand(-1, -1, embedding_size).long(), ) next_user = next_obs["user"] dones = train_batch[SampleBatch.DONES] with torch.no_grad(): # q_values.shape: [batch_size, slate_size+1] q_values = model.q_model(next_user, next_selected_doc) # raw_scores.shape: [batch_size, slate_size+1] raw_scores = model.choice_model(next_user, next_selected_doc) max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True) scores = torch.exp(raw_scores - max_raw_scores) # next_q_values.shape: [batch_size] next_q_values = torch.sum(q_values * scores, dim=1) / torch.sum( scores, dim=1) next_q_values[dones.bool()] = 0.0 else: raise ValueError(learning_strategy) # target_q_values.shape: [batch_size] target_q_values = (train_batch[SampleBatch.REWARDS] + policy.config["gamma"] * next_q_values) # q_values.shape: [batch_size, slate_size+1]. q_values = model.q_model(user, selected_doc) # raw_scores.shape: [batch_size, slate_size+1]. raw_scores = model.choice_model(user, selected_doc) max_raw_scores, _ = torch.max(raw_scores, dim=1, keepdim=True) scores = torch.exp(raw_scores - max_raw_scores) q_values = torch.sum(q_values * scores, dim=1) / torch.sum( scores, dim=1) # shape=[batch_size] td_error = torch.abs(q_values - target_q_values) q_value_loss = torch.mean(huber_loss(td_error)) # Store values for stats function in model (tower), such that for # multi-GPU, we do not override them during the parallel loss phase. model.tower_stats["q_loss"] = q_value_loss model.tower_stats["q_values"] = q_values model.tower_stats["next_q_values"] = next_q_values model.tower_stats["next_q_minus_q"] = next_q_values - q_values model.tower_stats["td_error"] = td_error model.tower_stats["target_q_values"] = target_q_values model.tower_stats["scores"] = scores model.tower_stats["raw_scores"] = raw_scores model.tower_stats["choice_loss"] = choice_loss model.tower_stats["choice_beta"] = model.choice_model.beta model.tower_stats[ "choice_score_no_click"] = model.choice_model.score_no_click logger.debug(f"loss calculation took {time.time()-start}s") return choice_loss, q_value_loss
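# --- Illustrative sketch (not part of RLlib): the choice-weighted Q expectation used by
# the SARSA branch above. The per-item Q-values of the next slate are weighted by the
# (exponentiated, normalized) user-choice scores, which is equivalent to a softmax
# weighting. Hypothetical names, PyTorch only:

import torch


def choice_weighted_q(q_values, raw_scores):
    """q_values, raw_scores: [B, slate_size + 1] (last column = "no click")."""
    scores = torch.softmax(raw_scores, dim=1)
    return torch.sum(q_values * scores, dim=1)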
def build_slateq_losses( policy: Policy, model: ModelV2, _, train_batch: SampleBatch, ) -> TensorType: """Constructs the choice- and Q-value losses for the SlateQTorchPolicy. Args: policy: The Policy to calculate the loss for. model: The Model to calculate the loss for. train_batch: The training data. Returns: The user-choice- and Q-value loss tensors. """ # B=batch size # S=slate size # C=num candidates # E=embedding size # A=number of all possible slates # Q-value computations. # --------------------- # action.shape: [B, S] actions = train_batch[SampleBatch.ACTIONS] observation = convert_to_torch_tensor( train_batch[SampleBatch.OBS], device=actions.device ) # user.shape: [B, E] user_obs = observation["user"] batch_size, embedding_size = user_obs.shape # doc.shape: [B, C, E] doc_obs = list(observation["doc"].values()) A, S = policy.slates.shape # click_indicator.shape: [B, S] click_indicator = torch.stack( [k["click"] for k in observation["response"]], 1 ).float() # item_reward.shape: [B, S] item_reward = torch.stack([k["watch_time"] for k in observation["response"]], 1) # q_values.shape: [B, C] q_values = model.get_q_values(user_obs, doc_obs) # slate_q_values.shape: [B, S] slate_q_values = torch.take_along_dim(q_values, actions.long(), dim=-1) # Only get the Q from the clicked document. # replay_click_q.shape: [B] replay_click_q = torch.sum(slate_q_values * click_indicator, dim=1) # Target computations. # -------------------- next_obs = convert_to_torch_tensor( train_batch[SampleBatch.NEXT_OBS], device=actions.device ) # user.shape: [B, E] user_next_obs = next_obs["user"] # doc.shape: [B, C, E] doc_next_obs = list(next_obs["doc"].values()) # Only compute the watch time reward of the clicked item. reward = torch.sum(item_reward * click_indicator, dim=1) # TODO: Find out, whether it's correct here to use obs, not next_obs! # Dopamine uses obs, then next_obs only for the score. # next_q_values = policy.target_model.get_q_values(user_next_obs, doc_next_obs) next_q_values = policy.target_models[model].get_q_values(user_obs, doc_obs) scores, score_no_click = score_documents(user_next_obs, doc_next_obs) # next_q_values_slate.shape: [B, A, S] indices = policy.slates_indices.to(next_q_values.device) next_q_values_slate = torch.take_along_dim(next_q_values, indices, dim=1).reshape( [-1, A, S] ) # scores_slate.shape [B, A, S] scores_slate = torch.take_along_dim(scores, indices, dim=1).reshape([-1, A, S]) # score_no_click_slate.shape: [B, A] score_no_click_slate = torch.reshape( torch.tile(score_no_click, policy.slates.shape[:1]), [batch_size, -1] ) # next_q_target_slate.shape: [B, A] next_q_target_slate = torch.sum(next_q_values_slate * scores_slate, dim=2) / ( torch.sum(scores_slate, dim=2) + score_no_click_slate ) next_q_target_max, _ = torch.max(next_q_target_slate, dim=1) target = reward + policy.config["gamma"] * next_q_target_max * ( 1.0 - train_batch["dones"].float() ) target = target.detach() clicked = torch.sum(click_indicator, dim=1) mask_clicked_slates = clicked > 0 clicked_indices = torch.arange(batch_size).to(mask_clicked_slates.device) clicked_indices = torch.masked_select(clicked_indices, mask_clicked_slates) # Clicked_indices is a vector and torch.gather selects the batch dimension. 
q_clicked = torch.gather(replay_click_q, 0, clicked_indices) target_clicked = torch.gather(target, 0, clicked_indices) td_error = torch.where( clicked.bool(), replay_click_q - target, torch.zeros_like(train_batch[SampleBatch.REWARDS]), ) if policy.config["use_huber"]: loss = huber_loss(td_error, delta=policy.config["huber_threshold"]) else: loss = torch.pow(td_error, 2.0) loss = torch.mean(loss) td_error = torch.abs(td_error) mean_td_error = torch.mean(td_error) # Store values for stats function in model (tower), such that for # multi-GPU, we do not override them during the parallel loss phase. model.tower_stats["q_values"] = torch.mean(q_values) model.tower_stats["q_clicked"] = torch.mean(q_clicked) model.tower_stats["scores"] = torch.mean(scores) model.tower_stats["score_no_click"] = torch.mean(score_no_click) model.tower_stats["slate_q_values"] = torch.mean(slate_q_values) model.tower_stats["replay_click_q"] = torch.mean(replay_click_q) model.tower_stats["bellman_reward"] = torch.mean(reward) model.tower_stats["next_q_values"] = torch.mean(next_q_values) model.tower_stats["target"] = torch.mean(target) model.tower_stats["next_q_target_slate"] = torch.mean(next_q_target_slate) model.tower_stats["next_q_target_max"] = torch.mean(next_q_target_max) model.tower_stats["target_clicked"] = torch.mean(target_clicked) model.tower_stats["q_loss"] = loss model.tower_stats["td_error"] = td_error model.tower_stats["mean_td_error"] = mean_td_error model.tower_stats["mean_actions"] = torch.mean(actions.float()) # selected_doc.shape: [batch_size, slate_size, embedding_size] selected_doc = torch.gather( # input.shape: [batch_size, num_docs, embedding_size] torch.stack(doc_obs, 1), 1, # index.shape: [batch_size, slate_size, embedding_size] actions.unsqueeze(2).expand(-1, -1, embedding_size).long(), ) scores = model.choice_model(user_obs, selected_doc) # click_indicator.shape: [batch_size, slate_size] # no_clicks.shape: [batch_size, 1] no_clicks = 1 - torch.sum(click_indicator, 1, keepdim=True) # targets.shape: [batch_size, slate_size+1] targets = torch.cat([click_indicator, no_clicks], dim=1) choice_loss = nn.functional.cross_entropy(scores, torch.argmax(targets, dim=1)) # print(model.choice_model.a.item(), model.choice_model.b.item()) model.tower_stats["choice_loss"] = choice_loss return choice_loss, loss
def _compute_critic_loss(
    self,
    model: ModelV2,
    dist_class: Type[TorchDistributionWrapper],
    train_batch: SampleBatch,
):
    discount = self.config["gamma"]

    # Compute bellman targets to regress on
    # target, use target model to compute the target
    target_model = cast(CRRModel, self.target_models[model])
    target_out_next, _ = target_model(
        {SampleBatch.OBS: train_batch[SampleBatch.NEXT_OBS]}
    )

    # compute target values with no gradient
    with torch.no_grad():
        # get the action of the current policy evaluated at the next state
        pi_s_next = dist_class(
            target_model.get_policy_output(target_out_next), target_model
        )
        target_a_next = pi_s_next.sample()
        if not self._is_action_discrete:
            target_a_next = target_a_next.clamp(
                torch.from_numpy(self.action_space.low).to(target_a_next),
                torch.from_numpy(self.action_space.high).to(target_a_next),
            )

        # q1_target = target_model.get_q_values(target_out_next, target_a_next)
        # q2_target = target_model.get_twin_q_values(target_out_next, target_a_next)
        # target_q_next = torch.minimum(q1_target, q2_target).squeeze(-1)
        target_q_next = self._get_q_value(
            target_model, target_out_next, target_a_next
        ).squeeze(-1)

        target = (
            train_batch[SampleBatch.REWARDS]
            + discount
            * (1.0 - train_batch[SampleBatch.DONES].float())
            * target_q_next
        )

    # compute the predicted output
    model = cast(CRRModel, model)
    model_out_t, _ = model({SampleBatch.OBS: train_batch[SampleBatch.OBS]})
    q1 = model.get_q_values(
        model_out_t, train_batch[SampleBatch.ACTIONS]
    ).squeeze(-1)
    q2 = model.get_twin_q_values(
        model_out_t, train_batch[SampleBatch.ACTIONS]
    ).squeeze(-1)

    # compute the MSE loss for all q-functions
    loss_q1 = (target - q1) ** 2
    loss_q2 = (target - q2) ** 2
    loss = 0.5 * (loss_q1 + loss_q2)
    loss = loss.mean(0)

    # logging
    self.log("loss_q1", loss_q1.mean())
    self.log("loss_q2", loss_q2.mean())
    self.log("targets_avg", target.mean())
    self.log("targets_max", target.max())
    self.log("targets_min", target.min())

    return loss
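# --- Illustrative sketch (not part of RLlib): the twin-Q regression above in isolation.
# Both critics regress (MSE) onto the same one-step target
# r + gamma * (1 - done) * Q_target(s', a'), with a' drawn from the target policy.
# Hypothetical names:

import torch


def twin_q_loss(q1, q2, target):
    """q1, q2, target: [B] tensors; returns the averaged MSE of both critics."""
    return 0.5 * ((target - q1) ** 2 + (target - q2) ** 2).mean()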
def _compute_adv_and_logps( self, model: ModelV2, dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch, ) -> None: # uses mean|max|expectation to compute estimate of advantages # continuous/discrete action spaces: # for max: # A(s_t, a_t) = Q(s_t, a_t) - max_{a^j} Q(s_t, a^j) # where a^j is m times sampled from the policy p(a | s_t) # for mean: # A(s_t, a_t) = Q(s_t, a_t) - avg( Q(s_t, a^j) ) # where a^j is m times sampled from the policy p(a | s_t) # discrete action space and adv_type=expectation: # A(s_t, a_t) = Q(s_t, a_t) - sum_j[Q(s_t, a^j) * pi(a^j)] advantage_type = self.config["advantage_type"] n_action_sample = self.config["n_action_sample"] batch_size = len(train_batch) out_t, _ = model(train_batch) # construct pi(s_t) and Q(s_t, a_t) for computing advantage actions pi_s_t = dist_class(model.get_policy_output(out_t), model) q_t = self._get_q_value(model, out_t, train_batch[SampleBatch.ACTIONS]) # compute the logp of the actions in the dataset (for computing actor's loss) action_logp = pi_s_t.dist.log_prob(train_batch[SampleBatch.ACTIONS]) # fix the shape if it's not canonical (i.e. shape[-1] != 1) if len(action_logp.shape) <= 1: action_logp.unsqueeze_(-1) train_batch[SampleBatch.ACTION_LOGP] = action_logp if advantage_type == "expectation": assert ( self._is_action_discrete ), "Action space should be discrete when advantage_type = expectation." assert hasattr( self.model, "q_model" ), "CRR's ModelV2 should have q_model neural network in discrete \ action spaces" assert isinstance( pi_s_t.dist, torch.distributions.Categorical ), "The output of the policy should be a torch Categorical \ distribution." q_vals = self.model.q_model(out_t) if hasattr(self.model, "twin_q_model"): q_twins = self.model.twin_q_model(out_t) q_vals = torch.minimum(q_vals, q_twins) probs = pi_s_t.dist.probs v_t = (q_vals * probs).sum(-1, keepdims=True) else: policy_actions = pi_s_t.dist.sample((n_action_sample,)) # samples if self._is_action_discrete: flat_actions = policy_actions.reshape(-1) else: flat_actions = policy_actions.reshape(-1, *self.action_space.shape) reshaped_s_t = train_batch[SampleBatch.OBS].view( 1, batch_size, *self.observation_space.shape ) reshaped_s_t = reshaped_s_t.expand( n_action_sample, batch_size, *self.observation_space.shape ) flat_s_t = reshaped_s_t.reshape(-1, *self.observation_space.shape) input_v_t = SampleBatch( **{SampleBatch.OBS: flat_s_t, SampleBatch.ACTIONS: flat_actions} ) out_v_t, _ = model(input_v_t) flat_q_st_pi = self._get_q_value(model, out_v_t, flat_actions) reshaped_q_st_pi = flat_q_st_pi.reshape(-1, batch_size, 1) if advantage_type == "mean": v_t = reshaped_q_st_pi.mean(dim=0) elif advantage_type == "max": v_t, _ = reshaped_q_st_pi.max(dim=0) else: raise ValueError(f"Invalid advantage type: {advantage_type}.") adv_t = q_t - v_t train_batch["advantages"] = adv_t # logging self.log("q_batch_avg", q_t.mean()) self.log("q_batch_max", q_t.max()) self.log("q_batch_min", q_t.min()) self.log("v_batch_avg", v_t.mean()) self.log("v_batch_max", v_t.max()) self.log("v_batch_min", v_t.min()) self.log("adv_batch_avg", adv_t.mean()) self.log("adv_batch_max", adv_t.max()) self.log("adv_batch_min", adv_t.min()) self.log("reward_batch_avg", train_batch[SampleBatch.REWARDS].mean())
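# --- Illustrative sketch (not part of RLlib): CRR's sampled advantage estimate used
# above, A(s, a) = Q(s, a) - V(s), where V(s) is approximated from m policy samples as
# either the mean or the max of Q(s, a^j). Hypothetical names, PyTorch only:

import torch


def sampled_advantage(q_t, q_sampled, advantage_type="mean"):
    """q_t: [B, 1] Q of the dataset actions; q_sampled: [m, B, 1] Q of m policy samples."""
    if advantage_type == "mean":
        v_t = q_sampled.mean(dim=0)
    elif advantage_type == "max":
        v_t, _ = q_sampled.max(dim=0)
    else:
        raise ValueError(f"Invalid advantage type: {advantage_type}.")
    return q_t - v_t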
def build_slateq_losses( policy: Policy, model: ModelV2, _, train_batch: SampleBatch, ) -> TensorType: """Constructs the choice- and Q-value losses for the SlateQTorchPolicy. Args: policy: The Policy to calculate the loss for. model: The Model to calculate the loss for. train_batch: The training data. Returns: The Q-value loss tensor. """ # B=batch size # S=slate size # C=num candidates # E=embedding size # A=number of all possible slates # Q-value computations. # --------------------- observation = train_batch[SampleBatch.OBS] # user.shape: [B, E] user_obs = observation["user"] batch_size = tf.shape(user_obs)[0] # doc.shape: [B, C, E] doc_obs = list(observation["doc"].values()) # action.shape: [B, S] actions = train_batch[SampleBatch.ACTIONS] # click_indicator.shape: [B, S] click_indicator = tf.cast( tf.stack([k["click"] for k in observation["response"]], 1), tf.float32) # item_reward.shape: [B, S] item_reward = tf.stack([k["watch_time"] for k in observation["response"]], 1) # q_values.shape: [B, C] q_values = model.get_q_values(user_obs, doc_obs) # slate_q_values.shape: [B, S] slate_q_values = tf.gather(q_values, tf.cast(actions, dtype=tf.int32), batch_dims=-1) # Only get the Q from the clicked document. # replay_click_q.shape: [B] replay_click_q = tf.reduce_sum(input_tensor=slate_q_values * click_indicator, axis=1, name="replay_click_q") # Target computations. # -------------------- next_obs = train_batch[SampleBatch.NEXT_OBS] # user.shape: [B, E] user_next_obs = next_obs["user"] # doc.shape: [B, C, E] doc_next_obs = list(next_obs["doc"].values()) # Only compute the watch time reward of the clicked item. reward = tf.reduce_sum(input_tensor=item_reward * click_indicator, axis=1) # TODO: Find out, whether it's correct here to use obs, not next_obs! # Dopamine uses obs, then next_obs only for the score. # next_q_values = policy.target_model.get_q_values(user_next_obs, doc_next_obs) next_q_values = policy.target_model.get_q_values(user_obs, doc_obs) scores, score_no_click = score_documents(user_next_obs, doc_next_obs) # next_q_values_slate.shape: [B, A, S] next_q_values_slate = tf.gather(next_q_values, policy.slates, axis=1) # scores_slate.shape [B, A, S] scores_slate = tf.gather(scores, policy.slates, axis=1) # score_no_click_slate.shape: [B, A] score_no_click_slate = tf.reshape( tf.tile(score_no_click, tf.shape(input=policy.slates)[:1]), [batch_size, -1]) # next_q_target_slate.shape: [B, A] next_q_target_slate = tf.reduce_sum( input_tensor=next_q_values_slate * scores_slate, axis=2) / (tf.reduce_sum(input_tensor=scores_slate, axis=2) + score_no_click_slate) next_q_target_max = tf.reduce_max(input_tensor=next_q_target_slate, axis=1) target = reward + policy.config["gamma"] * next_q_target_max * ( 1.0 - tf.cast(train_batch["dones"], tf.float32)) target = tf.stop_gradient(target) clicked = tf.reduce_sum(input_tensor=click_indicator, axis=1) clicked_indices = tf.squeeze(tf.where(tf.equal(clicked, 1)), axis=1) # Clicked_indices is a vector and tf.gather selects the batch dimension. 
    q_clicked = tf.gather(replay_click_q, clicked_indices)
    target_clicked = tf.gather(target, clicked_indices)

    td_error = tf.where(
        tf.cast(clicked, tf.bool),
        replay_click_q - target,
        tf.zeros_like(train_batch[SampleBatch.REWARDS]),
    )
    if policy.config["use_huber"]:
        loss = huber_loss(td_error, delta=policy.config["huber_threshold"])
    else:
        loss = tf.math.square(td_error)
    loss = tf.reduce_mean(loss)
    td_error = tf.abs(td_error)
    mean_td_error = tf.reduce_mean(td_error)

    policy._q_values = tf.reduce_mean(q_values)
    policy._q_clicked = tf.reduce_mean(q_clicked)
    policy._scores = tf.reduce_mean(scores)
    policy._score_no_click = tf.reduce_mean(score_no_click)
    policy._slate_q_values = tf.reduce_mean(slate_q_values)
    policy._replay_click_q = tf.reduce_mean(replay_click_q)
    policy._bellman_reward = tf.reduce_mean(reward)
    policy._next_q_values = tf.reduce_mean(next_q_values)
    policy._target = tf.reduce_mean(target)
    policy._next_q_target_slate = tf.reduce_mean(next_q_target_slate)
    policy._next_q_target_max = tf.reduce_mean(next_q_target_max)
    policy._target_clicked = tf.reduce_mean(target_clicked)
    policy._q_loss = loss
    policy._td_error = td_error
    policy._mean_td_error = mean_td_error
    policy._mean_actions = tf.reduce_mean(actions)

    return loss
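# --- Illustrative sketch (not part of RLlib): a Huber loss such as the one applied to
# the TD-error above (quadratic for |x| <= delta, linear beyond, which bounds the
# gradient of large TD-errors). A hypothetical PyTorch version:

import torch


def huber(td_error, delta=1.0):
    abs_err = torch.abs(td_error)
    quadratic = torch.clamp(abs_err, max=delta)
    # 0.5 * x^2 within delta, delta * (|x| - 0.5 * delta) outside.
    return 0.5 * quadratic ** 2 + delta * (abs_err - quadratic)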
def __init__( self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, config: TrainerConfigDict, *, model: ModelV2, loss: Callable[ [Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch], Union[TensorType, List[TensorType]]], action_distribution_class: Type[TorchDistributionWrapper], action_sampler_fn: Optional[Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]]] = None, action_distribution_fn: Optional[ Callable[[Policy, ModelV2, TensorType, TensorType, TensorType], Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]]]] = None, max_seq_len: int = 20, get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None, ): """Build a policy from policy and loss torch modules. Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES is set. Only single GPU is supported for now. Args: observation_space (gym.spaces.Space): observation space of the policy. action_space (gym.spaces.Space): action space of the policy. config (TrainerConfigDict): The Policy config dict. model (ModelV2): PyTorch policy module. Given observations as input, this module must return a list of outputs where the first item is action logits, and the rest can be any value. loss (Callable[[Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch], Union[TensorType, List[TensorType]]]): Callable that returns a single scalar loss or a list of loss terms. action_distribution_class (Type[TorchDistributionWrapper]): Class for a torch action distribution. action_sampler_fn (Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]]): A callable returning a sampled action and its log-likelihood given Policy, ModelV2, input_dict, explore, timestep, and is_training. action_distribution_fn (Optional[Callable[[Policy, ModelV2, Dict[str, TensorType], TensorType, TensorType], Tuple[TensorType, type, List[TensorType]]]]): A callable returning distribution inputs (parameters), a dist-class to generate an action distribution object from, and internal-state outputs (or an empty list if not applicable). Note: No Exploration hooks have to be called from within `action_distribution_fn`. It's should only perform a simple forward pass through some model. If None, pass inputs through `self.model()` to get distribution inputs. The callable takes as inputs: Policy, ModelV2, input_dict, explore, timestep, is_training. max_seq_len (int): Max sequence length for LSTM training. get_batch_divisibility_req (Optional[Callable[[Policy], int]]]): Optional callable that returns the divisibility requirement for sample batches given the Policy. """ self.framework = "torch" super().__init__(observation_space, action_space, config) if torch.cuda.is_available(): logger.info("TorchPolicy running on GPU.") self.device = torch.device("cuda") else: logger.info("TorchPolicy running on CPU.") self.device = torch.device("cpu") self.model = model.to(self.device) # Combine view_requirements for Model and Policy. self.view_requirements.update(self.model.inference_view_requirements) self.exploration = self._create_exploration() self.unwrapped_model = model # used to support DistributedDataParallel self._loss = loss self._optimizers = force_list(self.optimizer()) self.dist_class = action_distribution_class self.action_sampler_fn = action_sampler_fn self.action_distribution_fn = action_distribution_fn # If set, means we are using distributed allreduce during learning. 
    self.distributed_world_size = None

    self.max_seq_len = max_seq_len
    self.batch_divisibility_req = get_batch_divisibility_req(self) if \
        callable(get_batch_divisibility_req) else \
        (get_batch_divisibility_req or 1)
def action_distribution_fn( policy: Policy, model: ModelV2, input_dict: ModelInputDict, *, state_batches: Optional[List[TensorType]] = None, seq_lens: Optional[TensorType] = None, prev_action_batch: Optional[TensorType] = None, prev_reward_batch=None, explore: Optional[bool] = None, timestep: Optional[int] = None, is_training: Optional[bool] = None) -> \ Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]]: """The action distribution function to be used the algorithm. An action distribution function is used to customize the choice of action distribution class and the resulting action distribution inputs (to parameterize the distribution object). After parameterizing the distribution, a `sample()` call will be made on it to generate actions. Args: policy (Policy): The Policy being queried for actions and calling this function. model (TorchModelV2): The SAC specific Model to use to generate the distribution inputs (see sac_tf|torch_model.py). Must support the `get_policy_output` method. input_dict (ModelInputDict): The input-dict to be used for the model call. state_batches (Optional[List[TensorType]]): The list of internal state tensor batches. seq_lens (Optional[TensorType]): The tensor of sequence lengths used in RNNs. prev_action_batch (Optional[TensorType]): Optional batch of prev actions used by the model. prev_reward_batch (Optional[TensorType]): Optional batch of prev rewards used by the model. explore (Optional[bool]): Whether to activate exploration or not. If None, use value of `config.explore`. timestep (Optional[int]): An optional timestep. is_training (Optional[bool]): An optional is-training flag. Returns: Tuple[TensorType, Type[TorchDistributionWrapper], List[TensorType]]: The dist inputs, dist class, and a list of internal state outputs (in the RNN case). """ # Get base-model output (w/o the SAC specific parts of the network). model_out, state_in = model(input_dict, state_batches, seq_lens) # Use the base output to get the policy outputs from the SAC model's # policy components. states_in = model.select_state(state_in, ["policy", "q", "twin_q"]) distribution_inputs, policy_state_out = \ model.get_policy_output(model_out, states_in["policy"], seq_lens) _, q_state_out = model.get_q_values(model_out, states_in["q"], seq_lens) if model.twin_q_net: _, twin_q_state_out = \ model.get_twin_q_values(model_out, states_in["twin_q"], seq_lens) else: twin_q_state_out = [] # Get a distribution class to be used with the just calculated dist-inputs. action_dist_class = _get_dist_class(policy, policy.config, policy.action_space) states_out = policy_state_out + q_state_out + twin_q_state_out return distribution_inputs, action_dist_class, states_out
def __init__( self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, config: TrainerConfigDict, *, model: ModelV2, loss: Callable[[ Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch ], Union[TensorType, List[TensorType]]], action_distribution_class: Type[TorchDistributionWrapper], action_sampler_fn: Optional[Callable[[ TensorType, List[TensorType] ], Tuple[TensorType, TensorType]]] = None, action_distribution_fn: Optional[Callable[[ Policy, ModelV2, TensorType, TensorType, TensorType ], Tuple[TensorType, Type[TorchDistributionWrapper], List[ TensorType]]]] = None, max_seq_len: int = 20, get_batch_divisibility_req: Optional[Callable[[Policy], int]] = None, ): """Build a policy from policy and loss torch modules. Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES is set. Only single GPU is supported for now. Args: observation_space (gym.spaces.Space): observation space of the policy. action_space (gym.spaces.Space): action space of the policy. config (TrainerConfigDict): The Policy config dict. model (ModelV2): PyTorch policy module. Given observations as input, this module must return a list of outputs where the first item is action logits, and the rest can be any value. loss (Callable[[Policy, ModelV2, Type[TorchDistributionWrapper], SampleBatch], Union[TensorType, List[TensorType]]]): Callable that returns a single scalar loss or a list of loss terms. action_distribution_class (Type[TorchDistributionWrapper]): Class for a torch action distribution. action_sampler_fn (Callable[[TensorType, List[TensorType]], Tuple[TensorType, TensorType]]): A callable returning a sampled action and its log-likelihood given Policy, ModelV2, input_dict, explore, timestep, and is_training. action_distribution_fn (Optional[Callable[[Policy, ModelV2, ModelInputDict, TensorType, TensorType], Tuple[TensorType, type, List[TensorType]]]]): A callable returning distribution inputs (parameters), a dist-class to generate an action distribution object from, and internal-state outputs (or an empty list if not applicable). Note: No Exploration hooks have to be called from within `action_distribution_fn`. It's should only perform a simple forward pass through some model. If None, pass inputs through `self.model()` to get distribution inputs. The callable takes as inputs: Policy, ModelV2, ModelInputDict, explore, timestep, is_training. max_seq_len (int): Max sequence length for LSTM training. get_batch_divisibility_req (Optional[Callable[[Policy], int]]]): Optional callable that returns the divisibility requirement for sample batches given the Policy. """ self.framework = "torch" super().__init__(observation_space, action_space, config) # Log device and worker index. from ray.rllib.evaluation.rollout_worker import get_global_worker worker = get_global_worker() worker_idx = worker.worker_index if worker else 0 # Create multi-GPU model towers, if necessary. # - The central main model will be stored under self.model, residing on # self.device. # - Each GPU will have a copy of that model under # self.model_gpu_towers, matching the devices in self.devices. # - Parallelization is done by splitting the train batch and passing # it through the model copies in parallel, then averaging over the # resulting gradients, applying these averages on the main model and # updating all towers' weights from the main model. # - In case of just one device (1 (fake) GPU or 1 CPU), no # parallelization will be done. 
if config["_fake_gpus"] or config["num_gpus"] == 0 or \ not torch.cuda.is_available(): logger.info("TorchPolicy (worker={}) running on {}.".format( worker_idx if worker_idx > 0 else "local", "{} fake-GPUs".format(config["num_gpus"]) if config["_fake_gpus"] else "CPU")) self.device = torch.device("cpu") self.devices = [ self.device for _ in range(config["num_gpus"] or 1) ] self.model_gpu_towers = [ model if config["num_gpus"] == 0 else copy.deepcopy(model) for i in range(config["num_gpus"] or 1) ] else: logger.info("TorchPolicy (worker={}) running on {} GPU(s).".format( worker_idx if worker_idx > 0 else "local", config["num_gpus"])) self.device = torch.device("cuda") self.devices = [ torch.device("cuda:{}".format(id_)) for i, id_ in enumerate(ray.get_gpu_ids()) if i < config["num_gpus"] ] self.model_gpu_towers = nn.parallel.replicate.replicate( model, [ id_ for i, id_ in enumerate(ray.get_gpu_ids()) if i < config["num_gpus"] ]) # Move model to device. self.model = model.to(self.device) # Lock used for locking some methods on the object-level. # This prevents possible race conditions when calling the model # first, then its value function (e.g. in a loss function), in # between of which another model call is made (e.g. to compute an # action). self._lock = threading.RLock() self._state_inputs = self.model.get_initial_state() self._is_recurrent = len(self._state_inputs) > 0 # Auto-update model's inference view requirements, if recurrent. self._update_model_view_requirements_from_init_state() # Combine view_requirements for Model and Policy. self.view_requirements.update(self.model.view_requirements) self.exploration = self._create_exploration() self.unwrapped_model = model # used to support DistributedDataParallel self._loss = loss self._optimizers = force_list(self.optimizer()) # Store, which params (by index within the model's list of # parameters) should be updated per optimizer. # Maps optimizer idx to set or param indices. self.multi_gpu_param_groups: List[Set[int]] = [] main_params = {p: i for i, p in enumerate(self.model.parameters())} for o in self._optimizers: param_indices = [] for pg_idx, pg in enumerate(o.param_groups): for p in pg["params"]: param_indices.append(main_params[p]) self.multi_gpu_param_groups.append(set(param_indices)) self.dist_class = action_distribution_class self.action_sampler_fn = action_sampler_fn self.action_distribution_fn = action_distribution_fn # If set, means we are using distributed allreduce during learning. self.distributed_world_size = None self.max_seq_len = max_seq_len self.batch_divisibility_req = get_batch_divisibility_req(self) if \ callable(get_batch_divisibility_req) else \ (get_batch_divisibility_req or 1)
def actor_critic_loss( policy: Policy, model: ModelV2, dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: """Constructs the loss for the Soft Actor Critic. Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. dist_class (Type[TorchDistributionWrapper]: The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ target_model = policy.target_models[model] # Should be True only for debugging purposes (e.g. test cases)! deterministic = policy.config["_deterministic_loss"] i = 0 state_batches = [] while "state_in_{}".format(i) in train_batch: state_batches.append(train_batch["state_in_{}".format(i)]) i += 1 assert state_batches seq_lens = train_batch.get(SampleBatch.SEQ_LENS) model_out_t, state_in_t = model( { "obs": train_batch[SampleBatch.CUR_OBS], "prev_actions": train_batch[SampleBatch.PREV_ACTIONS], "prev_rewards": train_batch[SampleBatch.PREV_REWARDS], "is_training": True, }, state_batches, seq_lens) states_in_t = model.select_state(state_in_t, ["policy", "q", "twin_q"]) model_out_tp1, state_in_tp1 = model( { "obs": train_batch[SampleBatch.NEXT_OBS], "prev_actions": train_batch[SampleBatch.ACTIONS], "prev_rewards": train_batch[SampleBatch.REWARDS], "is_training": True, }, state_batches, seq_lens) states_in_tp1 = model.select_state(state_in_tp1, ["policy", "q", "twin_q"]) target_model_out_tp1, target_state_in_tp1 = target_model( { "obs": train_batch[SampleBatch.NEXT_OBS], "prev_actions": train_batch[SampleBatch.ACTIONS], "prev_rewards": train_batch[SampleBatch.REWARDS], "is_training": True, }, state_batches, seq_lens) target_states_in_tp1 = target_model.select_state(state_in_tp1, ["policy", "q", "twin_q"]) alpha = torch.exp(model.log_alpha) # Discrete case. if model.discrete: # Get all action probs directly from pi and form their logp. log_pis_t = F.log_softmax(model.get_policy_output( model_out_t, states_in_t["policy"], seq_lens)[0], dim=-1) policy_t = torch.exp(log_pis_t) log_pis_tp1 = F.log_softmax( model.get_policy_output(model_out_tp1, states_in_tp1["policy"], seq_lens)[0], -1) policy_tp1 = torch.exp(log_pis_tp1) # Q-values. q_t = model.get_q_values(model_out_t, states_in_t["q"], seq_lens)[0] # Target Q-values. q_tp1 = target_model.get_q_values(target_model_out_tp1, target_states_in_tp1["q"], seq_lens)[0] if policy.config["twin_q"]: twin_q_t = model.get_twin_q_values(model_out_t, states_in_t["twin_q"], seq_lens)[0] twin_q_tp1 = target_model.get_twin_q_values( target_model_out_tp1, target_states_in_tp1["twin_q"], seq_lens)[0] q_tp1 = torch.min(q_tp1, twin_q_tp1) q_tp1 -= alpha * log_pis_tp1 # Actually selected Q-values (from the actions batch). one_hot = F.one_hot(train_batch[SampleBatch.ACTIONS].long(), num_classes=q_t.size()[-1]) q_t_selected = torch.sum(q_t * one_hot, dim=-1) if policy.config["twin_q"]: twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1) # Discrete case: "Best" means weighted by the policy (prob) outputs. q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1) q_tp1_best_masked = \ (1.0 - train_batch[SampleBatch.DONES].float()) * \ q_tp1_best # Continuous actions case. else: # Sample single actions from distribution. 
action_dist_class = _get_dist_class(policy, policy.config, policy.action_space) action_dist_t = action_dist_class( model.get_policy_output(model_out_t, states_in_t["policy"], seq_lens)[0], model) policy_t = action_dist_t.sample() if not deterministic else \ action_dist_t.deterministic_sample() log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1) action_dist_tp1 = action_dist_class( model.get_policy_output(model_out_tp1, states_in_tp1["policy"], seq_lens)[0], model) policy_tp1 = action_dist_tp1.sample() if not deterministic else \ action_dist_tp1.deterministic_sample() log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1) # Q-values for the actually selected actions. q_t = model.get_q_values(model_out_t, states_in_t["q"], seq_lens, train_batch[SampleBatch.ACTIONS])[0] if policy.config["twin_q"]: twin_q_t = model.get_twin_q_values( model_out_t, states_in_t["twin_q"], seq_lens, train_batch[SampleBatch.ACTIONS])[0] # Q-values for current policy in given current state. q_t_det_policy = model.get_q_values(model_out_t, states_in_t["q"], seq_lens, policy_t)[0] if policy.config["twin_q"]: twin_q_t_det_policy = model.get_twin_q_values( model_out_t, states_in_t["twin_q"], seq_lens, policy_t)[0] q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy) # Target q network evaluation. q_tp1 = target_model.get_q_values(target_model_out_tp1, target_states_in_tp1["q"], seq_lens, policy_tp1)[0] if policy.config["twin_q"]: twin_q_tp1 = target_model.get_twin_q_values( target_model_out_tp1, target_states_in_tp1["twin_q"], seq_lens, policy_tp1)[0] # Take min over both twin-NNs. q_tp1 = torch.min(q_tp1, twin_q_tp1) q_t_selected = torch.squeeze(q_t, dim=-1) if policy.config["twin_q"]: twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1) q_tp1 -= alpha * log_pis_tp1 q_tp1_best = torch.squeeze(input=q_tp1, dim=-1) q_tp1_best_masked = \ (1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best # compute RHS of bellman equation q_t_selected_target = (train_batch[SampleBatch.REWARDS] + (policy.config["gamma"]**policy.config["n_step"]) * q_tp1_best_masked).detach() # BURNIN # B = state_batches[0].shape[0] T = q_t_selected.shape[0] // B seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T) # Mask away also the burn-in sequence at the beginning. burn_in = policy.config["burn_in"] if burn_in > 0 and burn_in < T: seq_mask[:, :burn_in] = False seq_mask = seq_mask.reshape(-1) num_valid = torch.sum(seq_mask) def reduce_mean_valid(t): return torch.sum(t[seq_mask]) / num_valid # Compute the TD-error (potentially clipped). base_td_error = torch.abs(q_t_selected - q_t_selected_target) if policy.config["twin_q"]: twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target) td_error = 0.5 * (base_td_error + twin_td_error) else: td_error = base_td_error critic_loss = [ reduce_mean_valid(train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error)) ] if policy.config["twin_q"]: critic_loss.append( reduce_mean_valid(train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error))) td_error = td_error * seq_mask # Alpha- and actor losses. # Note: In the papers, alpha is used directly, here we take the log. # Discrete case: Multiply the action probs as weights with the original # loss terms (no expectations needed). if model.discrete: weighted_log_alpha_loss = policy_t.detach() * ( -model.log_alpha * (log_pis_t + model.target_entropy).detach()) # Sum up weighted terms and mean over all batch items. alpha_loss = reduce_mean_valid( torch.sum(weighted_log_alpha_loss, dim=-1)) # Actor loss. 
actor_loss = reduce_mean_valid( torch.sum( torch.mul( # NOTE: No stop_grad around policy output here # (compare with q_t_det_policy for continuous case). policy_t, alpha.detach() * log_pis_t - q_t.detach()), dim=-1)) else: alpha_loss = -reduce_mean_valid( model.log_alpha * (log_pis_t + model.target_entropy).detach()) # Note: Do not detach q_t_det_policy here b/c is depends partly # on the policy vars (policy sample pushed through Q-net). # However, we must make sure `actor_loss` is not used to update # the Q-net(s)' variables. actor_loss = reduce_mean_valid(alpha.detach() * log_pis_t - q_t_det_policy) # Store values for stats function in model (tower), such that for # multi-GPU, we do not override them during the parallel loss phase. model.tower_stats["q_t"] = q_t * seq_mask[..., None] model.tower_stats["policy_t"] = policy_t * seq_mask[..., None] model.tower_stats["log_pis_t"] = log_pis_t * seq_mask[..., None] model.tower_stats["actor_loss"] = actor_loss model.tower_stats["critic_loss"] = critic_loss model.tower_stats["alpha_loss"] = alpha_loss # Store per time chunk (b/c we need only one mean # prioritized replay weight per stored sequence). model.tower_stats["td_error"] = torch.mean(td_error.reshape([-1, T]), dim=-1) # Return all loss terms corresponding to our optimizers. return tuple([actor_loss] + critic_loss + [alpha_loss])
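# Illustration only: a tiny numeric sketch of the entropy-regularized critic
# target assembled above, assuming the twin target Q-values and the
# next-action log-prob are already computed (all numbers are hypothetical).
import torch

def sac_critic_target(reward, gamma, n_step, done, q1_tp1, q2_tp1,
                      logp_tp1, alpha):
    # Soft value of s': min of the twin target Qs minus alpha * log-prob.
    q_tp1 = torch.min(q1_tp1, q2_tp1) - alpha * logp_tp1
    return reward + (gamma ** n_step) * (1.0 - done) * q_tp1

print(sac_critic_target(torch.tensor(1.0), 0.99, 1, torch.tensor(0.0),
                        torch.tensor(2.0), torch.tensor(1.8),
                        torch.tensor(-0.5), torch.tensor(0.2)))
# -> 1 + 0.99 * (1.8 + 0.1) = 2.881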
def ppo_surrogate_loss( policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: """Constructs the loss for Proximal Policy Objective. Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. dist_class (Type[ActionDistribution]: The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ logits, state = model.from_batch(train_batch) curr_action_dist = dist_class(logits, model) # RNN case: Mask away 0-padded chunks at end of time axis. if state: # Derive max_seq_len from the data itself, not from the seq_lens # tensor. This is in case e.g. seq_lens=[2, 3], but the data is still # 0-padded up to T=5 (as it's the case for attention nets). B = tf.shape(train_batch["seq_lens"])[0] max_seq_len = tf.shape(logits)[0] // B mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len) mask = tf.reshape(mask, [-1]) def reduce_mean_valid(t): return tf.reduce_mean(tf.boolean_mask(t, mask)) # non-RNN case: No masking. else: mask = None reduce_mean_valid = tf.reduce_mean prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS], model) logp_ratio = tf.exp( curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) - train_batch[SampleBatch.ACTION_LOGP]) action_kl = prev_action_dist.kl(curr_action_dist) mean_kl = reduce_mean_valid(action_kl) curr_entropy = curr_action_dist.entropy() mean_entropy = reduce_mean_valid(curr_entropy) surrogate_loss = tf.minimum( train_batch[Postprocessing.ADVANTAGES] * logp_ratio, train_batch[Postprocessing.ADVANTAGES] * tf.clip_by_value( logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) mean_policy_loss = reduce_mean_valid(-surrogate_loss) if policy.config["use_gae"]: prev_value_fn_out = train_batch[SampleBatch.VF_PREDS] value_fn_out = model.value_function() vf_loss1 = tf.math.square(value_fn_out - train_batch[Postprocessing.VALUE_TARGETS]) vf_clipped = prev_value_fn_out + tf.clip_by_value( value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"], policy.config["vf_clip_param"]) vf_loss2 = tf.math.square(vf_clipped - train_batch[Postprocessing.VALUE_TARGETS]) vf_loss = tf.maximum(vf_loss1, vf_loss2) mean_vf_loss = reduce_mean_valid(vf_loss) total_loss = reduce_mean_valid( -surrogate_loss + policy.kl_coeff * action_kl + policy.config["vf_loss_coeff"] * vf_loss - policy.entropy_coeff * curr_entropy) else: mean_vf_loss = tf.constant(0.0) total_loss = reduce_mean_valid(-surrogate_loss + policy.kl_coeff * action_kl - policy.entropy_coeff * curr_entropy) # Store stats in policy for stats_fn. policy._total_loss = total_loss policy._mean_policy_loss = mean_policy_loss policy._mean_vf_loss = mean_vf_loss policy._mean_entropy = mean_entropy policy._mean_kl = mean_kl return total_loss
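# Illustration only: a toy numeric check of the clipped surrogate objective
# used above, with hypothetical advantages, probability ratios, and
# clip_param = 0.3.
import torch

adv = torch.tensor([1.0, -1.0])
ratio = torch.tensor([1.5, 0.5])
clip_param = 0.3
surrogate = torch.min(adv * ratio,
                      adv * torch.clamp(ratio, 1 - clip_param, 1 + clip_param))
print(surrogate)           # tensor([ 1.3000, -0.7000])
print(-surrogate.mean())   # the (negated) mean policy loss term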
def model_value_predictions(
        policy: Policy, input_dict: Dict[str, TensorType], state_batches,
        model: ModelV2,
        action_dist: ActionDistribution) -> Dict[str, TensorType]:
    return {SampleBatch.VF_PREDS: model.value_function()}
def loss( self, model: ModelV2, dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch, ) -> List[TensorType]: target_model = self.target_models[model] twin_q = self.config["twin_q"] gamma = self.config["gamma"] n_step = self.config["n_step"] use_huber = self.config["use_huber"] huber_threshold = self.config["huber_threshold"] l2_reg = self.config["l2_reg"] input_dict = SampleBatch( obs=train_batch[SampleBatch.CUR_OBS], _is_training=True ) input_dict_next = SampleBatch( obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True ) model_out_t, _ = model(input_dict, [], None) model_out_tp1, _ = model(input_dict_next, [], None) target_model_out_tp1, _ = target_model(input_dict_next, [], None) # Policy network evaluation. # prev_update_ops = set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS)) policy_t = model.get_policy_output(model_out_t) # policy_batchnorm_update_ops = list( # set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops) policy_tp1 = target_model.get_policy_output(target_model_out_tp1) # Action outputs. if self.config["smooth_target_policy"]: target_noise_clip = self.config["target_noise_clip"] clipped_normal_sample = torch.clamp( torch.normal( mean=torch.zeros(policy_tp1.size()), std=self.config["target_noise"] ).to(policy_tp1.device), -target_noise_clip, target_noise_clip, ) policy_tp1_smoothed = torch.min( torch.max( policy_tp1 + clipped_normal_sample, torch.tensor( self.action_space.low, dtype=torch.float32, device=policy_tp1.device, ), ), torch.tensor( self.action_space.high, dtype=torch.float32, device=policy_tp1.device, ), ) else: # No smoothing, just use deterministic actions. policy_tp1_smoothed = policy_tp1 # Q-net(s) evaluation. # prev_update_ops = set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS)) # Q-values for given actions & observations in given current q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) # Q-values for current policy (no noise) in given current state q_t_det_policy = model.get_q_values(model_out_t, policy_t) actor_loss = -torch.mean(q_t_det_policy) if twin_q: twin_q_t = model.get_twin_q_values( model_out_t, train_batch[SampleBatch.ACTIONS] ) # q_batchnorm_update_ops = list( # set(tf1.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops) # Target q-net(s) evaluation. q_tp1 = target_model.get_q_values(target_model_out_tp1, policy_tp1_smoothed) if twin_q: twin_q_tp1 = target_model.get_twin_q_values( target_model_out_tp1, policy_tp1_smoothed ) q_t_selected = torch.squeeze(q_t, axis=len(q_t.shape) - 1) if twin_q: twin_q_t_selected = torch.squeeze(twin_q_t, axis=len(q_t.shape) - 1) q_tp1 = torch.min(q_tp1, twin_q_tp1) q_tp1_best = torch.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1) q_tp1_best_masked = (1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best # Compute RHS of bellman equation. q_t_selected_target = ( train_batch[SampleBatch.REWARDS] + gamma ** n_step * q_tp1_best_masked ).detach() # Compute the error (potentially clipped). if twin_q: td_error = q_t_selected - q_t_selected_target twin_td_error = twin_q_t_selected - q_t_selected_target if use_huber: errors = huber_loss(td_error, huber_threshold) + huber_loss( twin_td_error, huber_threshold ) else: errors = 0.5 * ( torch.pow(td_error, 2.0) + torch.pow(twin_td_error, 2.0) ) else: td_error = q_t_selected - q_t_selected_target if use_huber: errors = huber_loss(td_error, huber_threshold) else: errors = 0.5 * torch.pow(td_error, 2.0) critic_loss = torch.mean(train_batch[PRIO_WEIGHTS] * errors) # Add l2-regularization if required. 
    if l2_reg is not None:
        for name, var in model.policy_variables(as_dict=True).items():
            if "bias" not in name:
                actor_loss += l2_reg * l2_loss(var)
        for name, var in model.q_variables(as_dict=True).items():
            if "bias" not in name:
                critic_loss += l2_reg * l2_loss(var)

    # Model self-supervised losses.
    if self.config["use_state_preprocessor"]:
        # Expand input_dict in case the custom loss needs it.
        input_dict[SampleBatch.ACTIONS] = train_batch[SampleBatch.ACTIONS]
        input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS]
        input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES]
        input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS]
        [actor_loss, critic_loss] = model.custom_loss(
            [actor_loss, critic_loss], input_dict
        )

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["q_t"] = q_t
    model.tower_stats["actor_loss"] = actor_loss
    model.tower_stats["critic_loss"] = critic_loss
    # The TD-error tensor is stored in the final stats so it can be
    # concatenated and retrieved for each individual batch item.
    model.tower_stats["td_error"] = td_error

    # Return the two loss terms (corresponding to the two optimizers we
    # create).
    return [actor_loss, critic_loss]
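# Illustration only: a minimal sketch of the "smooth_target_policy" step used
# above (TD3-style target-policy smoothing), assuming a 1-D continuous action
# space with hypothetical bounds and noise settings.
import torch

def smooth_target_action(policy_tp1, noise_std, noise_clip, low, high):
    # Clipped Gaussian noise is added to the target policy's action, then the
    # result is clipped back into the valid action range.
    noise = torch.clamp(torch.randn_like(policy_tp1) * noise_std,
                        -noise_clip, noise_clip)
    return torch.clamp(policy_tp1 + noise, low, high)

print(smooth_target_action(torch.tensor([0.9]), noise_std=0.2,
                           noise_clip=0.5, low=-1.0, high=1.0))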
def _get_q_value(self, model: ModelV2, model_out: TensorType,
                 actions: TensorType) -> TensorType:
    # Helper function to compute the pessimistic Q-value.
    q1 = model.get_q_values(model_out, actions)
    q2 = model.get_twin_q_values(model_out, actions)
    return torch.minimum(q1, q2)
def appo_surrogate_loss( policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: """Constructs the loss for APPO. With IS modifications and V-trace for Advantage Estimation. Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. dist_class (Type[ActionDistribution]): The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ model_out, _ = model.from_batch(train_batch) action_dist = dist_class(model_out, model) if isinstance(policy.action_space, gym.spaces.Discrete): is_multidiscrete = False output_hidden_shape = [policy.action_space.n] elif isinstance(policy.action_space, gym.spaces.multi_discrete.MultiDiscrete): is_multidiscrete = True output_hidden_shape = policy.action_space.nvec.astype(np.int32) else: is_multidiscrete = False output_hidden_shape = 1 # TODO: (sven) deprecate this when trajectory view API gets activated. def make_time_major(*args, **kw): return _make_time_major(policy, train_batch.get("seq_lens"), *args, **kw) actions = train_batch[SampleBatch.ACTIONS] dones = train_batch[SampleBatch.DONES] rewards = train_batch[SampleBatch.REWARDS] behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS] target_model_out, _ = policy.target_model.from_batch(train_batch) prev_action_dist = dist_class(behaviour_logits, policy.model) values = policy.model.value_function() values_time_major = make_time_major(values) policy.model_vars = policy.model.variables() policy.target_model_vars = policy.target_model.variables() if policy.is_recurrent(): max_seq_len = tf.reduce_max(train_batch["seq_lens"]) - 1 mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len) mask = tf.reshape(mask, [-1]) mask = make_time_major(mask, drop_last=policy.config["vtrace"]) def reduce_mean_valid(t): return tf.reduce_mean(tf.boolean_mask(t, mask)) else: reduce_mean_valid = tf.reduce_mean if policy.config["vtrace"]: logger.debug("Using V-Trace surrogate loss (vtrace=True)") # Prepare actions for loss. loss_actions = actions if is_multidiscrete else tf.expand_dims(actions, axis=1) old_policy_behaviour_logits = tf.stop_gradient(target_model_out) old_policy_action_dist = dist_class(old_policy_behaviour_logits, model) # Prepare KL for Loss mean_kl = make_time_major(old_policy_action_dist.multi_kl(action_dist), drop_last=True) unpacked_behaviour_logits = tf.split(behaviour_logits, output_hidden_shape, axis=1) unpacked_old_policy_behaviour_logits = tf.split( old_policy_behaviour_logits, output_hidden_shape, axis=1) # Compute vtrace on the CPU for better perf. 
with tf.device("/cpu:0"): vtrace_returns = vtrace.multi_from_logits( behaviour_policy_logits=make_time_major( unpacked_behaviour_logits, drop_last=True), target_policy_logits=make_time_major( unpacked_old_policy_behaviour_logits, drop_last=True), actions=tf.unstack(make_time_major(loss_actions, drop_last=True), axis=2), discounts=tf.cast( ~make_time_major(tf.cast(dones, tf.bool), drop_last=True), tf.float32) * policy.config["gamma"], rewards=make_time_major(rewards, drop_last=True), values=values_time_major[:-1], # drop-last=True bootstrap_value=values_time_major[-1], dist_class=Categorical if is_multidiscrete else dist_class, model=model, clip_rho_threshold=tf.cast( policy.config["vtrace_clip_rho_threshold"], tf.float32), clip_pg_rho_threshold=tf.cast( policy.config["vtrace_clip_pg_rho_threshold"], tf.float32), ) actions_logp = make_time_major(action_dist.logp(actions), drop_last=True) prev_actions_logp = make_time_major(prev_action_dist.logp(actions), drop_last=True) old_policy_actions_logp = make_time_major( old_policy_action_dist.logp(actions), drop_last=True) is_ratio = tf.clip_by_value( tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0) logp_ratio = is_ratio * tf.exp(actions_logp - prev_actions_logp) policy._is_ratio = is_ratio advantages = vtrace_returns.pg_advantages surrogate_loss = tf.minimum( advantages * logp_ratio, advantages * tf.clip_by_value(logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) action_kl = tf.reduce_mean(mean_kl, axis=0) \ if is_multidiscrete else mean_kl mean_kl = reduce_mean_valid(action_kl) mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. delta = values_time_major[:-1] - vtrace_returns.vs value_targets = vtrace_returns.vs mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta)) # The entropy loss. actions_entropy = make_time_major(action_dist.multi_entropy(), drop_last=True) mean_entropy = reduce_mean_valid(actions_entropy) else: logger.debug("Using PPO surrogate loss (vtrace=False)") # Prepare KL for Loss mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist)) logp_ratio = tf.math.exp( make_time_major(action_dist.logp(actions)) - make_time_major(prev_action_dist.logp(actions))) advantages = make_time_major(train_batch[Postprocessing.ADVANTAGES]) surrogate_loss = tf.minimum( advantages * logp_ratio, advantages * tf.clip_by_value(logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) action_kl = tf.reduce_mean(mean_kl, axis=0) \ if is_multidiscrete else mean_kl mean_kl = reduce_mean_valid(action_kl) mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. value_targets = make_time_major( train_batch[Postprocessing.VALUE_TARGETS]) delta = values_time_major - value_targets mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta)) # The entropy loss. mean_entropy = reduce_mean_valid( make_time_major(action_dist.multi_entropy())) # The summed weighted loss total_loss = mean_policy_loss + \ mean_vf_loss * policy.config["vf_loss_coeff"] - \ mean_entropy * policy.config["entropy_coeff"] # Optional additional KL Loss if policy.config["use_kl_loss"]: total_loss += policy.kl_coeff * mean_kl policy._total_loss = total_loss policy._mean_policy_loss = mean_policy_loss policy._mean_kl = mean_kl policy._mean_vf_loss = mean_vf_loss policy._mean_entropy = mean_entropy policy._value_targets = value_targets # Store stats in policy for stats_fn. return total_loss
def _compute_adv_and_logps( self, model: ModelV2, dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch, ) -> None: # uses mean|max to compute estimate of advantages # continuous/discrete action spaces: # A(s_t, a_t) = Q(s_t, a_t) - max_{a^j} Q(s_t, a^j) # where a^j is m times sampled from the policy p(a | s_t) # A(s_t, a_t) = Q(s_t, a_t) - avg( Q(s_t, a^j) ) # where a^j is m times sampled from the policy p(a | s_t) # questions: Do we use pessimistic q approximate or the normal one? advantage_type = self.config["advantage_type"] n_action_sample = self.config["n_action_sample"] batch_size = len(train_batch) out_t, _ = model(train_batch) # construct pi(s_t) for sampling actions pi_s_t = dist_class(model.get_policy_output(out_t), model) policy_actions = pi_s_t.dist.sample((n_action_sample, )) # samples if self._is_action_discrete: flat_actions = policy_actions.reshape(-1) else: flat_actions = policy_actions.reshape(-1, *self.action_space.shape) # compute the logp of the actions in the dataset (for computing actor's loss) action_logp = pi_s_t.dist.log_prob(train_batch[SampleBatch.ACTIONS]) # fix the shape if it's not canonical (i.e. shape[-1] != 1) if len(action_logp.shape) <= 1: action_logp.unsqueeze_(-1) train_batch[SampleBatch.ACTION_LOGP] = action_logp reshaped_s_t = train_batch[SampleBatch.OBS].view( 1, batch_size, *self.observation_space.shape) reshaped_s_t = reshaped_s_t.expand(n_action_sample, batch_size, *self.observation_space.shape) flat_s_t = reshaped_s_t.reshape(-1, *self.observation_space.shape) input_v_t = SampleBatch(**{ SampleBatch.OBS: flat_s_t, SampleBatch.ACTIONS: flat_actions }) out_v_t, _ = model(input_v_t) flat_q_st_pi = self._get_q_value(model, out_v_t, flat_actions) reshaped_q_st_pi = flat_q_st_pi.reshape(-1, batch_size, 1) if advantage_type == "mean": v_t = reshaped_q_st_pi.mean(dim=0) elif advantage_type == "max": v_t, _ = reshaped_q_st_pi.max(dim=0) else: raise ValueError(f"Invalid advantage type: {advantage_type}.") q_t = self._get_q_value(model, out_t, train_batch[SampleBatch.ACTIONS]) adv_t = q_t - v_t train_batch["advantages"] = adv_t # logging self.log("q_batch_avg", q_t.mean()) self.log("q_batch_max", q_t.max()) self.log("q_batch_min", q_t.min()) self.log("v_batch_avg", v_t.mean()) self.log("v_batch_max", v_t.max()) self.log("v_batch_min", v_t.min()) self.log("adv_batch_avg", adv_t.mean()) self.log("adv_batch_max", adv_t.max()) self.log("adv_batch_min", adv_t.min()) self.log("reward_batch_avg", train_batch[SampleBatch.REWARDS].mean())
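# Illustration only: a toy version of the advantage estimate computed above,
# assuming Q(s, a) for the dataset action and Q(s, a^j) for the m sampled
# policy actions are already known (hypothetical numbers).
import torch

q_dataset_action = torch.tensor(2.0)
q_sampled = torch.tensor([1.0, 1.5, 2.5])        # m = 3 samples from pi(.|s)

adv_mean = q_dataset_action - q_sampled.mean()   # advantage_type == "mean"
adv_max = q_dataset_action - q_sampled.max()     # advantage_type == "max"
print(adv_mean, adv_max)  # tensor(0.3333) tensor(-0.5000)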
def sac_actor_critic_loss( policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: """Constructs the loss for the Soft Actor Critic. Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. dist_class (Type[ActionDistribution]: The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ # Should be True only for debugging purposes (e.g. test cases)! deterministic = policy.config["_deterministic_loss"] # Get the base model output from the train batch. model_out_t, _ = model({ "obs": train_batch[SampleBatch.CUR_OBS], "is_training": policy._get_is_training_placeholder(), }, [], None) # Get the base model output from the next observations in the train batch. model_out_tp1, _ = model({ "obs": train_batch[SampleBatch.NEXT_OBS], "is_training": policy._get_is_training_placeholder(), }, [], None) # Get the target model's base outputs from the next observations in the # train batch. target_model_out_tp1, _ = policy.target_model({ "obs": train_batch[SampleBatch.NEXT_OBS], "is_training": policy._get_is_training_placeholder(), }, [], None) # Discrete actions case. if model.discrete: # Get all action probs directly from pi and form their logp. log_pis_t = tf.nn.log_softmax(model.get_policy_output(model_out_t), -1) policy_t = tf.math.exp(log_pis_t) log_pis_tp1 = tf.nn.log_softmax( model.get_policy_output(model_out_tp1), -1) policy_tp1 = tf.math.exp(log_pis_tp1) # Q-values. q_t = model.get_q_values(model_out_t) # Target Q-values. q_tp1 = policy.target_model.get_q_values(target_model_out_tp1) if policy.config["twin_q"]: twin_q_t = model.get_twin_q_values(model_out_t) twin_q_tp1 = policy.target_model.get_twin_q_values( target_model_out_tp1) q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0) q_tp1 -= model.alpha * log_pis_tp1 # Actually selected Q-values (from the actions batch). one_hot = tf.one_hot( train_batch[SampleBatch.ACTIONS], depth=q_t.shape.as_list()[-1]) q_t_selected = tf.reduce_sum(q_t * one_hot, axis=-1) if policy.config["twin_q"]: twin_q_t_selected = tf.reduce_sum(twin_q_t * one_hot, axis=-1) # Discrete case: "Best" means weighted by the policy (prob) outputs. q_tp1_best = tf.reduce_sum(tf.multiply(policy_tp1, q_tp1), axis=-1) q_tp1_best_masked = \ (1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)) * \ q_tp1_best # Continuous actions case. else: # Sample simgle actions from distribution. action_dist_class = _get_dist_class(policy.config, policy.action_space) action_dist_t = action_dist_class( model.get_policy_output(model_out_t), policy.model) policy_t = action_dist_t.sample() if not deterministic else \ action_dist_t.deterministic_sample() log_pis_t = tf.expand_dims(action_dist_t.logp(policy_t), -1) action_dist_tp1 = action_dist_class( model.get_policy_output(model_out_tp1), policy.model) policy_tp1 = action_dist_tp1.sample() if not deterministic else \ action_dist_tp1.deterministic_sample() log_pis_tp1 = tf.expand_dims(action_dist_tp1.logp(policy_tp1), -1) # Q-values for the actually selected actions. q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) if policy.config["twin_q"]: twin_q_t = model.get_twin_q_values( model_out_t, train_batch[SampleBatch.ACTIONS]) # Q-values for current policy in given current state. 
q_t_det_policy = model.get_q_values(model_out_t, policy_t) if policy.config["twin_q"]: twin_q_t_det_policy = model.get_twin_q_values( model_out_t, policy_t) q_t_det_policy = tf.reduce_min( (q_t_det_policy, twin_q_t_det_policy), axis=0) # target q network evaluation q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1) if policy.config["twin_q"]: twin_q_tp1 = policy.target_model.get_twin_q_values( target_model_out_tp1, policy_tp1) # Take min over both twin-NNs. q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0) q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1) if policy.config["twin_q"]: twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1) q_tp1 -= model.alpha * log_pis_tp1 q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1) q_tp1_best_masked = (1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)) * q_tp1_best # Compute RHS of bellman equation for the Q-loss (critic(s)). q_t_selected_target = tf.stop_gradient( train_batch[SampleBatch.REWARDS] + policy.config["gamma"]**policy.config["n_step"] * q_tp1_best_masked) # Compute the TD-error (potentially clipped). base_td_error = tf.math.abs(q_t_selected - q_t_selected_target) if policy.config["twin_q"]: twin_td_error = tf.math.abs(twin_q_t_selected - q_t_selected_target) td_error = 0.5 * (base_td_error + twin_td_error) else: td_error = base_td_error # Calculate one or two critic losses (2 in the twin_q case). prio_weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32) critic_loss = [tf.reduce_mean(prio_weights * huber_loss(base_td_error))] if policy.config["twin_q"]: critic_loss.append( tf.reduce_mean(prio_weights * huber_loss(twin_td_error))) # Alpha- and actor losses. # Note: In the papers, alpha is used directly, here we take the log. # Discrete case: Multiply the action probs as weights with the original # loss terms (no expectations needed). if model.discrete: alpha_loss = tf.reduce_mean( tf.reduce_sum( tf.multiply( tf.stop_gradient(policy_t), -model.log_alpha * tf.stop_gradient(log_pis_t + model.target_entropy)), axis=-1)) actor_loss = tf.reduce_mean( tf.reduce_sum( tf.multiply( # NOTE: No stop_grad around policy output here # (compare with q_t_det_policy for continuous case). policy_t, model.alpha * log_pis_t - tf.stop_gradient(q_t)), axis=-1)) else: alpha_loss = -tf.reduce_mean( model.log_alpha * tf.stop_gradient(log_pis_t + model.target_entropy)) actor_loss = tf.reduce_mean(model.alpha * log_pis_t - q_t_det_policy) # Save for stats function. policy.policy_t = policy_t policy.q_t = q_t policy.td_error = td_error policy.actor_loss = actor_loss policy.critic_loss = critic_loss policy.alpha_loss = alpha_loss policy.alpha_value = model.alpha policy.target_entropy = model.target_entropy # In a custom apply op we handle the losses separately, but return them # combined in one loss here. return actor_loss + tf.math.add_n(critic_loss) + alpha_loss
def ddpg_actor_critic_loss(policy: Policy, model: ModelV2, _, train_batch: SampleBatch) -> TensorType: twin_q = policy.config["twin_q"] gamma = policy.config["gamma"] n_step = policy.config["n_step"] use_huber = policy.config["use_huber"] huber_threshold = policy.config["huber_threshold"] l2_reg = policy.config["l2_reg"] input_dict = SampleBatch(obs=train_batch[SampleBatch.CUR_OBS], _is_training=True) input_dict_next = SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True) model_out_t, _ = model(input_dict, [], None) model_out_tp1, _ = model(input_dict_next, [], None) target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None) policy.target_q_func_vars = policy.target_model.variables() # Policy network evaluation. policy_t = model.get_policy_output(model_out_t) policy_tp1 = policy.target_model.get_policy_output(target_model_out_tp1) # Action outputs. if policy.config["smooth_target_policy"]: target_noise_clip = policy.config["target_noise_clip"] clipped_normal_sample = tf.clip_by_value( tf.random.normal(tf.shape(policy_tp1), stddev=policy.config["target_noise"]), -target_noise_clip, target_noise_clip, ) policy_tp1_smoothed = tf.clip_by_value( policy_tp1 + clipped_normal_sample, policy.action_space.low * tf.ones_like(policy_tp1), policy.action_space.high * tf.ones_like(policy_tp1), ) else: # No smoothing, just use deterministic actions. policy_tp1_smoothed = policy_tp1 # Q-net(s) evaluation. # prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) # Q-values for given actions & observations in given current q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) # Q-values for current policy (no noise) in given current state q_t_det_policy = model.get_q_values(model_out_t, policy_t) if twin_q: twin_q_t = model.get_twin_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) # Target q-net(s) evaluation. q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1_smoothed) if twin_q: twin_q_tp1 = policy.target_model.get_twin_q_values( target_model_out_tp1, policy_tp1_smoothed) q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1) if twin_q: twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1) q_tp1 = tf.minimum(q_tp1, twin_q_tp1) q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1) q_tp1_best_masked = ( 1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)) * q_tp1_best # Compute RHS of bellman equation. q_t_selected_target = tf.stop_gradient( tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) + gamma**n_step * q_tp1_best_masked) # Compute the error (potentially clipped). if twin_q: td_error = q_t_selected - q_t_selected_target twin_td_error = twin_q_t_selected - q_t_selected_target if use_huber: errors = huber_loss(td_error, huber_threshold) + huber_loss( twin_td_error, huber_threshold) else: errors = 0.5 * tf.math.square(td_error) + 0.5 * tf.math.square( twin_td_error) else: td_error = q_t_selected - q_t_selected_target if use_huber: errors = huber_loss(td_error, huber_threshold) else: errors = 0.5 * tf.math.square(td_error) critic_loss = tf.reduce_mean( tf.cast(train_batch[PRIO_WEIGHTS], tf.float32) * errors) actor_loss = -tf.reduce_mean(q_t_det_policy) # Add l2-regularization if required. if l2_reg is not None: for var in policy.model.policy_variables(): if "bias" not in var.name: actor_loss += l2_reg * tf.nn.l2_loss(var) for var in policy.model.q_variables(): if "bias" not in var.name: critic_loss += l2_reg * tf.nn.l2_loss(var) # Model self-supervised losses. 
if policy.config["use_state_preprocessor"]: # Expand input_dict in case custom_loss' need them. input_dict[SampleBatch.ACTIONS] = train_batch[SampleBatch.ACTIONS] input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS] input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES] input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS] if log_once("ddpg_custom_loss"): logger.warning( "You are using a state-preprocessor with DDPG and " "therefore, `custom_loss` will be called on your Model! " "Please be aware that DDPG now uses the ModelV2 API, which " "merges all previously separate sub-models (policy_model, " "q_model, and twin_q_model) into one ModelV2, on which " "`custom_loss` is called, passing it " "[actor_loss, critic_loss] as 1st argument. " "You may have to change your custom loss function to handle " "this.") [actor_loss, critic_loss] = model.custom_loss([actor_loss, critic_loss], input_dict) # Store values for stats function. policy.actor_loss = actor_loss policy.critic_loss = critic_loss policy.td_error = td_error policy.q_t = q_t # Return one loss value (even though we treat them separately in our # 2 optimizers: actor and critic). return policy.critic_loss + policy.actor_loss
def context(self) -> contextlib.AbstractContextManager:
    """Returns a contextmanager for the current TF graph."""
    if self.graph:
        return self.graph.as_default()
    else:
        return ModelV2.context(self)
def loss( self, model: ModelV2, dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch, ) -> Union[TensorType, List[TensorType]]: """Constructs the loss function. Args: model: The Model to calculate the loss for. dist_class: The action distr. class. train_batch: The training data. Returns: The PPO loss tensor given the input batch. """ logits, state = model(train_batch) self.cur_lr = self.config["lr"] if self.config["worker_index"]: self.loss_obj = WorkerLoss( model=model, dist_class=dist_class, actions=train_batch[SampleBatch.ACTIONS], curr_logits=logits, behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS], advantages=train_batch[Postprocessing.ADVANTAGES], value_fn=model.value_function(), value_targets=train_batch[Postprocessing.VALUE_TARGETS], vf_preds=train_batch[SampleBatch.VF_PREDS], cur_kl_coeff=0.0, entropy_coeff=self.config["entropy_coeff"], clip_param=self.config["clip_param"], vf_clip_param=self.config["vf_clip_param"], vf_loss_coeff=self.config["vf_loss_coeff"], clip_loss=False, ) else: self.var_list = model.named_parameters() # `split` may not exist yet (during test-loss call), use a dummy value. # Cannot use get here due to train_batch being a TrackingDict. if "split" in train_batch: split = train_batch["split"] else: split_shape = ( self.config["inner_adaptation_steps"], self.config["num_workers"], ) split_const = int(train_batch["obs"].shape[0] // (split_shape[0] * split_shape[1])) split = torch.ones(split_shape, dtype=int) * split_const self.loss_obj = MAMLLoss( model=model, dist_class=dist_class, value_targets=train_batch[Postprocessing.VALUE_TARGETS], advantages=train_batch[Postprocessing.ADVANTAGES], actions=train_batch[SampleBatch.ACTIONS], behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS], vf_preds=train_batch[SampleBatch.VF_PREDS], cur_kl_coeff=self.kl_coeff_val, policy_vars=self.var_list, obs=train_batch[SampleBatch.CUR_OBS], num_tasks=self.config["num_workers"], split=split, config=self.config, inner_adaptation_steps=self.config["inner_adaptation_steps"], entropy_coeff=self.config["entropy_coeff"], clip_param=self.config["clip_param"], vf_clip_param=self.config["vf_clip_param"], vf_loss_coeff=self.config["vf_loss_coeff"], use_gae=self.config["use_gae"], meta_opt=self.meta_opt, ) return self.loss_obj.loss
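# Illustration only: a tiny sketch of the fallback "split" tensor built above
# when the batch carries no explicit split info, assuming hypothetical sizes
# (2 inner-adaptation steps, 3 workers, 60 observations in the batch).
import torch

inner_adaptation_steps, num_workers, num_obs = 2, 3, 60
split_const = num_obs // (inner_adaptation_steps * num_workers)
split = torch.ones((inner_adaptation_steps, num_workers),
                   dtype=int) * split_const
print(split)  # a [2, 3] tensor filled with 10 (samples per task per step)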
def loss( self, model: ModelV2, dist_class: Type[ActionDistribution], train_batch: SampleBatch, ) -> Union[TensorType, List[TensorType]]: """Constructs the loss for APPO. With IS modifications and V-trace for Advantage Estimation. Args: model (ModelV2): The Model to calculate the loss for. dist_class (Type[ActionDistribution]): The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ target_model = self.target_models[model] model_out, _ = model(train_batch) action_dist = dist_class(model_out, model) if isinstance(self.action_space, gym.spaces.Discrete): is_multidiscrete = False output_hidden_shape = [self.action_space.n] elif isinstance(self.action_space, gym.spaces.multi_discrete.MultiDiscrete): is_multidiscrete = True output_hidden_shape = self.action_space.nvec.astype(np.int32) else: is_multidiscrete = False output_hidden_shape = 1 def _make_time_major(*args, **kwargs): return make_time_major( self, train_batch.get(SampleBatch.SEQ_LENS), *args, **kwargs ) actions = train_batch[SampleBatch.ACTIONS] dones = train_batch[SampleBatch.DONES] rewards = train_batch[SampleBatch.REWARDS] behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS] target_model_out, _ = target_model(train_batch) prev_action_dist = dist_class(behaviour_logits, model) values = model.value_function() values_time_major = _make_time_major(values) drop_last = self.config["vtrace"] and self.config["vtrace_drop_last_ts"] if self.is_recurrent(): max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS]) mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) mask = torch.reshape(mask, [-1]) mask = _make_time_major(mask, drop_last=drop_last) num_valid = torch.sum(mask) def reduce_mean_valid(t): return torch.sum(t[mask]) / num_valid else: reduce_mean_valid = torch.mean if self.config["vtrace"]: logger.debug( "Using V-Trace surrogate loss (vtrace=True; " f"drop_last={drop_last})" ) old_policy_behaviour_logits = target_model_out.detach() old_policy_action_dist = dist_class(old_policy_behaviour_logits, model) if isinstance(output_hidden_shape, (list, tuple, np.ndarray)): unpacked_behaviour_logits = torch.split( behaviour_logits, list(output_hidden_shape), dim=1 ) unpacked_old_policy_behaviour_logits = torch.split( old_policy_behaviour_logits, list(output_hidden_shape), dim=1 ) else: unpacked_behaviour_logits = torch.chunk( behaviour_logits, output_hidden_shape, dim=1 ) unpacked_old_policy_behaviour_logits = torch.chunk( old_policy_behaviour_logits, output_hidden_shape, dim=1 ) # Prepare actions for loss. loss_actions = ( actions if is_multidiscrete else torch.unsqueeze(actions, dim=1) ) # Prepare KL for loss. action_kl = _make_time_major( old_policy_action_dist.kl(action_dist), drop_last=drop_last ) # Compute vtrace on the CPU for better perf. 
vtrace_returns = vtrace.multi_from_logits( behaviour_policy_logits=_make_time_major( unpacked_behaviour_logits, drop_last=drop_last ), target_policy_logits=_make_time_major( unpacked_old_policy_behaviour_logits, drop_last=drop_last ), actions=torch.unbind( _make_time_major(loss_actions, drop_last=drop_last), dim=2 ), discounts=(1.0 - _make_time_major(dones, drop_last=drop_last).float()) * self.config["gamma"], rewards=_make_time_major(rewards, drop_last=drop_last), values=values_time_major[:-1] if drop_last else values_time_major, bootstrap_value=values_time_major[-1], dist_class=TorchCategorical if is_multidiscrete else dist_class, model=model, clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"], ) actions_logp = _make_time_major( action_dist.logp(actions), drop_last=drop_last ) prev_actions_logp = _make_time_major( prev_action_dist.logp(actions), drop_last=drop_last ) old_policy_actions_logp = _make_time_major( old_policy_action_dist.logp(actions), drop_last=drop_last ) is_ratio = torch.clamp( torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0 ) logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp) self._is_ratio = is_ratio advantages = vtrace_returns.pg_advantages.to(logp_ratio.device) surrogate_loss = torch.min( advantages * logp_ratio, advantages * torch.clamp( logp_ratio, 1 - self.config["clip_param"], 1 + self.config["clip_param"], ), ) mean_kl_loss = reduce_mean_valid(action_kl) mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. value_targets = vtrace_returns.vs.to(values_time_major.device) if drop_last: delta = values_time_major[:-1] - value_targets else: delta = values_time_major - value_targets mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0)) # The entropy loss. mean_entropy = reduce_mean_valid( _make_time_major(action_dist.entropy(), drop_last=drop_last) ) else: logger.debug("Using PPO surrogate loss (vtrace=False)") # Prepare KL for Loss action_kl = _make_time_major(prev_action_dist.kl(action_dist)) actions_logp = _make_time_major(action_dist.logp(actions)) prev_actions_logp = _make_time_major(prev_action_dist.logp(actions)) logp_ratio = torch.exp(actions_logp - prev_actions_logp) advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES]) surrogate_loss = torch.min( advantages * logp_ratio, advantages * torch.clamp( logp_ratio, 1 - self.config["clip_param"], 1 + self.config["clip_param"], ), ) mean_kl_loss = reduce_mean_valid(action_kl) mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. value_targets = _make_time_major(train_batch[Postprocessing.VALUE_TARGETS]) delta = values_time_major - value_targets mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0)) # The entropy loss. mean_entropy = reduce_mean_valid(_make_time_major(action_dist.entropy())) # The summed weighted loss total_loss = ( mean_policy_loss + mean_vf_loss * self.config["vf_loss_coeff"] - mean_entropy * self.entropy_coeff ) # Optional additional KL Loss if self.config["use_kl_loss"]: total_loss += self.kl_coeff * mean_kl_loss # Store values for stats function in model (tower), such that for # multi-GPU, we do not override them during the parallel loss phase. 
model.tower_stats["total_loss"] = total_loss model.tower_stats["mean_policy_loss"] = mean_policy_loss model.tower_stats["mean_kl_loss"] = mean_kl_loss model.tower_stats["mean_vf_loss"] = mean_vf_loss model.tower_stats["mean_entropy"] = mean_entropy model.tower_stats["value_targets"] = value_targets model.tower_stats["vf_explained_var"] = explained_variance( torch.reshape(value_targets, [-1]), torch.reshape( values_time_major[:-1] if drop_last else values_time_major, [-1] ), ) return total_loss
def ppo_surrogate_loss( policy: Policy, model: ModelV2, dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: """Constructs the loss for Proximal Policy Objective. Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. dist_class (Type[ActionDistribution]: The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ logits, state = model(train_batch) curr_action_dist = dist_class(logits, model) # RNN case: Mask away 0-padded chunks at end of time axis. if state: B = len(train_batch[SampleBatch.SEQ_LENS]) max_seq_len = logits.shape[0] // B mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len, time_major=model.is_time_major()) mask = torch.reshape(mask, [-1]) num_valid = torch.sum(mask) def reduce_mean_valid(t): return torch.sum(t[mask]) / num_valid # non-RNN case: No masking. else: mask = None reduce_mean_valid = torch.mean prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS], model) logp_ratio = torch.exp( curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) - train_batch[SampleBatch.ACTION_LOGP]) action_kl = prev_action_dist.kl(curr_action_dist) mean_kl_loss = reduce_mean_valid(action_kl) curr_entropy = curr_action_dist.entropy() mean_entropy = reduce_mean_valid(curr_entropy) surrogate_loss = torch.min( train_batch[Postprocessing.ADVANTAGES] * logp_ratio, train_batch[Postprocessing.ADVANTAGES] * torch.clamp(logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) mean_policy_loss = reduce_mean_valid(-surrogate_loss) # Compute a value function loss. if policy.config["use_critic"]: prev_value_fn_out = train_batch[SampleBatch.VF_PREDS] value_fn_out = model.value_function() vf_loss1 = torch.pow( value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0) vf_clipped = prev_value_fn_out + torch.clamp( value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"], policy.config["vf_clip_param"]) vf_loss2 = torch.pow( vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0) vf_loss = torch.max(vf_loss1, vf_loss2) mean_vf_loss = reduce_mean_valid(vf_loss) # Ignore the value function. else: vf_loss = mean_vf_loss = 0.0 total_loss = reduce_mean_valid(-surrogate_loss + policy.kl_coeff * action_kl + policy.config["vf_loss_coeff"] * vf_loss - policy.entropy_coeff * curr_entropy) # Store values for stats function in model (tower), such that for # multi-GPU, we do not override them during the parallel loss phase. model.tower_stats["total_loss"] = total_loss model.tower_stats["mean_policy_loss"] = mean_policy_loss model.tower_stats["mean_vf_loss"] = mean_vf_loss model.tower_stats["vf_explained_var"] = explained_variance( train_batch[Postprocessing.VALUE_TARGETS], model.value_function()) model.tower_stats["mean_entropy"] = mean_entropy model.tower_stats["mean_kl_loss"] = mean_kl_loss return total_loss
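# Illustration only: a toy numeric check of the clipped value-function loss
# used above, with hypothetical predictions, targets, and vf_clip_param = 0.3.
import torch

vf_old = torch.tensor([1.0])    # value prediction at sampling time (VF_PREDS)
vf_new = torch.tensor([2.0])    # current value prediction
target = torch.tensor([1.2])    # Postprocessing.VALUE_TARGETS
vf_clip_param = 0.3

vf_loss1 = torch.pow(vf_new - target, 2.0)
vf_clipped = vf_old + torch.clamp(vf_new - vf_old,
                                  -vf_clip_param, vf_clip_param)
vf_loss2 = torch.pow(vf_clipped - target, 2.0)
print(torch.max(vf_loss1, vf_loss2))  # pessimistic (max) of the two: 0.64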