def pg_torch_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """The basic policy gradients loss function.

    Implements L = -E[log(pi(a|s)) * A] and caches the result on the
    policy object (``pi_err``) so the stats function can report it.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Forward pass: distribution parameters for the batch observations.
    logits, _ = model.from_batch(train_batch)
    # Wrap the raw model outputs in an action-distribution object.
    pi = dist_class(logits, model)
    # Log-likelihoods of the actions that were actually taken.
    action_logp = pi.logp(train_batch[SampleBatch.ACTIONS])
    # Vanilla PG loss: L = -E[log(pi(a|s)) * A].
    # Saved on the policy object for the stats_fn below.
    policy.pi_err = -torch.mean(
        action_logp * train_batch[Postprocessing.ADVANTAGES])
    return policy.pi_err
def spl_torch_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Supervised-learning loss over action and/or reward targets.

    Applies the configured loss function to the model distribution's
    probabilities against the concatenated targets selected via the
    ``learn_action`` / ``learn_reward`` config flags. The loss and the
    mean distribution entropy are cached on the policy for the stats fn.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Forward pass yields the distribution parameters for this batch.
    logits, _ = model.from_batch(train_batch)
    # Wrap them in a distribution object.
    predictions = dist_class(logits, model)

    # Gather the supervised targets selected in the config.
    target_parts = []
    if policy.config["learn_action"]:
        target_parts.append(train_batch[SampleBatch.ACTIONS])
    if policy.config["learn_reward"]:
        target_parts.append(train_batch[SampleBatch.REWARDS])
    assert len(target_parts) > 0
    target_tensor = torch.cat(target_parts, dim=0)

    # Save the loss in the policy object for the spl_stats below.
    policy.spl_loss = policy.config["loss_fn"](predictions.dist.probs,
                                               target_tensor)
    policy.entropy = predictions.dist.entropy().mean()
    return policy.spl_loss
def marwil_loss(policy: Policy, model: ModelV2,
                dist_class: ActionDistribution,
                train_batch: SampleBatch) -> TensorType:
    """MARWIL loss: exponentially-weighted policy loss plus value loss.

    Side effects: updates ``policy.ma_adv_norm`` in place and stores
    ``v_loss``, ``p_loss``, ``total_loss`` and ``explained_variance`` on
    the policy object for the stats function.
    """
    dist_params, _ = model.from_batch(train_batch)
    pi = dist_class(dist_params, model)
    values = model.value_function()
    cum_rewards = train_batch[Postprocessing.ADVANTAGES]
    taken_actions = train_batch[SampleBatch.ACTIONS]

    # Advantage estimation: cumulative reward minus value prediction.
    residual = cum_rewards - values
    sqd_residual_mean = torch.mean(torch.pow(residual, 2.0))

    # Value loss.
    policy.v_loss = 0.5 * sqd_residual_mean

    # Policy loss.
    # In-place update of the moving average of the squared-advantage norm
    # (must happen before the exponential weights are computed below).
    policy.ma_adv_norm.add_(1e-6 * (sqd_residual_mean - policy.ma_adv_norm))
    # Exponentially weighted advantages; 1e-8 guards against div-by-zero.
    weights = torch.exp(policy.config["beta"] *
                        (residual /
                         (1e-8 + torch.pow(policy.ma_adv_norm, 0.5))))
    # log\pi_\theta(a|s); weights are detached so only the log-prob term
    # receives policy gradients.
    action_logp = pi.logp(taken_actions)
    policy.p_loss = -1.0 * torch.mean(weights.detach() * action_logp)

    # Combine both losses.
    policy.total_loss = policy.p_loss + policy.config["vf_coeff"] * \
        policy.v_loss
    ev = explained_variance(cum_rewards, values)
    policy.explained_variance = torch.mean(ev)
    return policy.total_loss
def pg_tf_loss(
        policy: Policy, model: ModelV2, dist_class: Type[ActionDistribution],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """The basic policy gradients loss function.

    Implements L = -E[log(pi(a|s)) * A] for the TF backend.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Forward pass to obtain the distribution parameters.
    logits, _ = model.from_batch(train_batch)
    # Build the action-distribution object.
    pi = dist_class(logits, model)
    # log(pi(a|s)) for the actions stored in the batch.
    action_logp = pi.logp(train_batch[SampleBatch.ACTIONS])
    # Cast advantages to float32 to match the log-prob dtype.
    advantages = tf.cast(
        train_batch[Postprocessing.ADVANTAGES], dtype=tf.float32)
    # Vanilla PG loss: L = -E[log(pi(a|s)) * A].
    return -tf.reduce_mean(action_logp * advantages)
def pg_torch_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """The basic policy gradients loss function.

    Implements L = -E[log(pi(a|s)) * A]; the result is also written to
    ``model.tower_stats`` so each multi-GPU tower keeps its own copy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Forward pass: distribution parameters for the batch observations.
    logits, _ = model.from_batch(train_batch)
    # Wrap the raw model outputs in an action-distribution object.
    pi = dist_class(logits, model)
    # Log-likelihoods of the actions that were actually taken.
    action_logp = pi.logp(train_batch[SampleBatch.ACTIONS])
    # Vanilla PG loss: L = -E[log(pi(a|s)) * A].
    loss = -torch.mean(action_logp * train_batch[Postprocessing.ADVANTAGES])
    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["policy_loss"] = loss
    return loss
def marwil_loss(policy: Policy, model: ModelV2,
                dist_class: ActionDistribution,
                train_batch: SampleBatch) -> TensorType:
    """Builds the MARWIL loss by delegating to the MARWILLoss helper.

    The helper instance is kept on ``policy.loss`` so the stats function
    can read its components later.
    """
    # Forward pass and action distribution for this batch.
    dist_params, _ = model.from_batch(train_batch)
    pi = dist_class(dist_params, model)
    # Current value-function estimates.
    values = model.value_function()
    # Construct the loss object (stored for stats) and return its total.
    policy.loss = MARWILLoss(policy, values, pi, train_batch,
                             policy.config["vf_coeff"],
                             policy.config["beta"])
    return policy.loss.total_loss
def spl_torch_loss( policy: Policy, model: ModelV2, dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch, ) -> Union[TensorType, List[TensorType]]: """The basic policy gradients loss function. Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. dist_class (Type[ActionDistribution]: The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ # Pass the training data through our model to get distribution parameters. dist_inputs, _ = model.from_batch(train_batch) # Create an action distribution object. action_dist = dist_class(dist_inputs, model) if policy.config["explore"]: # Adding that because of a bug in TorchCategorical # which modify dist_inputs through action_dist: _, _ = policy.exploration.get_exploration_action( action_distribution=action_dist, timestep=policy.global_timestep, explore=policy.config["explore"], ) action_dist = dist_class(dist_inputs, policy.model) targets = [] if policy.config["learn_action"]: targets.append(train_batch[SampleBatch.ACTIONS]) if policy.config["learn_reward"]: targets.append(train_batch[SampleBatch.REWARDS]) assert len(targets) > 0, (f"In config, use learn_action=True and/or " f"learn_reward=True to specify which target to " f"use in supervised learning") targets = torch.cat(targets, dim=0) # Save the loss in the policy object for the spl_stats below. policy.spl_loss = policy.config["loss_fn"](action_dist.dist.probs, targets) policy.entropy = action_dist.dist.entropy().mean() return policy.spl_loss
def actor_critic_loss(policy: Policy, model: ModelV2,
                      dist_class: ActionDistribution,
                      train_batch: SampleBatch) -> TensorType:
    """A3C actor-critic loss (PyTorch, multi-GPU-tower variant).

    total = pi_err + vf_loss_coeff * value_err - entropy_coeff * entropy,
    with the 0-padded timesteps of recurrent sequences masked out.
    Components are written to ``model.tower_stats`` so parallel towers
    do not overwrite each other's values.
    """
    logits, _ = model.from_batch(train_batch)
    values = model.value_function()

    # Build a flat bool mask over valid (non-padding) timesteps.
    if policy.is_recurrent():
        num_seqs = len(train_batch[SampleBatch.SEQ_LENS])
        padded_len = logits.shape[0] // num_seqs
        seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                 padded_len)
        valid_mask = torch.reshape(seq_mask, [-1])
    else:
        valid_mask = torch.ones_like(values, dtype=torch.bool)

    dist = dist_class(logits, model)
    logp = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)

    # Policy term: -sum over valid steps of log pi(a|s) * advantage.
    weighted_logp = logp * train_batch[Postprocessing.ADVANTAGES]
    pi_err = -torch.sum(torch.masked_select(weighted_logp, valid_mask))

    # Critic term (skipped entirely when the critic is disabled).
    if policy.config["use_critic"]:
        td = values.reshape(-1) - train_batch[Postprocessing.VALUE_TARGETS]
        value_err = 0.5 * torch.sum(
            torch.pow(torch.masked_select(td, valid_mask), 2.0))
    else:
        value_err = 0.0

    # Entropy bonus over the valid steps only.
    entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

    total_loss = (pi_err + value_err * policy.config["vf_loss_coeff"] -
                  entropy * policy.config["entropy_coeff"])

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["entropy"] = entropy
    model.tower_stats["pi_err"] = pi_err
    model.tower_stats["value_err"] = value_err

    return total_loss
def actor_critic_loss(policy: Policy, model: ModelV2,
                      dist_class: ActionDistribution,
                      train_batch: SampleBatch) -> TensorType:
    """A3C actor-critic loss (TF): delegates the math to A3CLoss.

    The A3CLoss instance is kept on ``policy.loss`` so the stats
    function can read its components later.
    """
    dist_params, _ = model.from_batch(train_batch)
    pi = dist_class(dist_params, model)

    # Mask out the 0-padded timesteps of recurrent sequences; in the
    # non-recurrent case every timestep is valid.
    if policy.is_recurrent():
        max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
        mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                max_seq_len)
        mask = tf.reshape(mask, [-1])
    else:
        mask = tf.ones_like(train_batch[SampleBatch.REWARDS])

    policy.loss = A3CLoss(pi, train_batch[SampleBatch.ACTIONS],
                          train_batch[Postprocessing.ADVANTAGES],
                          train_batch[Postprocessing.VALUE_TARGETS],
                          model.value_function(), mask,
                          policy.config["vf_loss_coeff"],
                          policy.config["entropy_coeff"],
                          policy.config.get("use_critic", True))
    return policy.loss.total_loss
def marwil_loss(policy: Policy, model: ModelV2,
                dist_class: ActionDistribution,
                train_batch: SampleBatch) -> TensorType:
    """MARWIL loss: exponentially-weighted policy loss plus value loss.

    With beta == 0.0 this reduces to a plain behavioral-cloning (BC) loss
    and the value branch is skipped. Side effects: updates
    ``policy._moving_average_sqd_adv_norm`` in place and stores
    ``p_loss``, ``v_loss``, ``total_loss`` (and, when beta != 0,
    ``explained_variance``) on the policy for the stats function.
    """
    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)
    actions = train_batch[SampleBatch.ACTIONS]
    # log\pi_\theta(a|s)
    logprobs = action_dist.logp(actions)

    # Advantage estimation.
    if policy.config["beta"] != 0.0:
        # ADVANTAGES here hold the cumulative rewards computed in
        # postprocessing; advantage = cumulative reward - value estimate.
        cumulative_rewards = train_batch[Postprocessing.ADVANTAGES]
        state_values = model.value_function()
        adv = cumulative_rewards - state_values
        adv_squared_mean = torch.mean(torch.pow(adv, 2.0))

        explained_var = explained_variance(cumulative_rewards, state_values)
        policy.explained_variance = torch.mean(explained_var)

        # Policy loss.
        # Update averaged advantage norm (in place, before the weights
        # below are computed).
        rate = policy.config["moving_average_sqd_adv_norm_update_rate"]
        policy._moving_average_sqd_adv_norm.add_(
            rate * (adv_squared_mean - policy._moving_average_sqd_adv_norm))
        # Exponentially weighted advantages; 1e-8 guards against
        # division by zero.
        exp_advs = torch.exp(
            policy.config["beta"] *
            (adv /
             (1e-8 + torch.pow(policy._moving_average_sqd_adv_norm, 0.5))))
        # Weights are detached so only the log-prob term receives
        # policy gradients.
        policy.p_loss = -torch.mean(exp_advs.detach() * logprobs)

        # Value loss.
        policy.v_loss = 0.5 * adv_squared_mean
    else:
        # Policy loss (simple BC loss term).
        policy.p_loss = -1.0 * torch.mean(logprobs)

        # Value loss.
        policy.v_loss = 0.0

    # Combine both losses.
    policy.total_loss = policy.p_loss + policy.config["vf_coeff"] * \
        policy.v_loss
    return policy.total_loss
def actor_critic_loss(policy: Policy, model: ModelV2,
                      dist_class: ActionDistribution,
                      train_batch: SampleBatch) -> TensorType:
    """A3C actor-critic loss (PyTorch).

    total = pi_err + vf_loss_coeff * value_err - entropy_coeff * entropy,
    with the 0-padded timesteps of recurrent sequences masked out.
    Stores ``entropy``, ``pi_err`` and ``value_err`` on the policy for
    the stats function.
    """
    logits, _ = model.from_batch(train_batch)
    values = model.value_function()

    if policy.is_recurrent():
        # NOTE(review): max_seq_len is derived from seq_lens itself; this
        # assumes the batch is padded exactly to the longest sequence —
        # confirm against the sampler's padding behavior.
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask_orig = sequence_mask(train_batch["seq_lens"], max_seq_len)
        valid_mask = torch.reshape(mask_orig, [-1])
    else:
        # Non-recurrent case: every timestep is valid.
        valid_mask = torch.ones_like(values, dtype=torch.bool)

    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1)
    # Policy term: -sum over valid steps of log pi(a|s) * advantage.
    pi_err = -torch.sum(
        torch.masked_select(log_probs * train_batch[Postprocessing.ADVANTAGES],
                            valid_mask))

    # Compute a value function loss.
    if policy.config["use_critic"]:
        value_err = 0.5 * torch.sum(
            torch.pow(
                torch.masked_select(
                    values.reshape(-1) -
                    train_batch[Postprocessing.VALUE_TARGETS], valid_mask),
                2.0))
    # Ignore the value function.
    else:
        value_err = 0.0

    # Entropy bonus over the valid steps only.
    entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask))

    total_loss = (pi_err + value_err * policy.config["vf_loss_coeff"] -
                  entropy * policy.config["entropy_coeff"])

    # Expose the loss components for the stats function.
    policy.entropy = entropy
    policy.pi_err = pi_err
    policy.value_err = value_err

    return total_loss
def appo_surrogate_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TFActionDistribution],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for APPO.

    With IS modifications and V-trace for Advantage Estimation.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)

    # Determine the action-space layout: single Discrete, MultiDiscrete
    # (several categorical heads), or anything else (single head).
    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space,
                    gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    # TODO: (sven) deprecate this when trajectory view API gets activated.
    # Closure binding policy and seq_lens for [B*T] -> [T, B] reshapes.
    def make_time_major(*args, **kw):
        return _make_time_major(policy, train_batch.get("seq_lens"), *args,
                                **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    # Logits produced by the behaviour policy at sampling time.
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    # Forward pass through the (frozen) target network.
    target_model_out, _ = policy.target_model.from_batch(train_batch)
    prev_action_dist = dist_class(behaviour_logits, policy.model)
    values = policy.model.value_function()
    values_time_major = make_time_major(values)

    policy.model_vars = policy.model.variables()
    policy.target_model_vars = policy.target_model.variables()

    if policy.is_recurrent():
        # NOTE(review): the -1 shortens the mask by one timestep; this
        # looks tied to the drop_last behavior below — confirm.
        max_seq_len = tf.reduce_max(train_batch["seq_lens"]) - 1
        mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = tf.reshape(mask, [-1])
        mask = make_time_major(mask, drop_last=policy.config["vtrace"])

        def reduce_mean_valid(t):
            # Mean over valid (non-padded) timesteps only.
            return tf.reduce_mean(tf.boolean_mask(t, mask))

    else:
        reduce_mean_valid = tf.reduce_mean

    if policy.config["vtrace"]:
        logger.debug("Using V-Trace surrogate loss (vtrace=True)")

        # Prepare actions for loss.
        loss_actions = actions if is_multidiscrete else tf.expand_dims(
            actions, axis=1)

        # Target-network logits serve as the "old" policy; gradients must
        # not flow into the target network.
        old_policy_behaviour_logits = tf.stop_gradient(target_model_out)
        old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

        # Prepare KL for Loss
        mean_kl = make_time_major(
            old_policy_action_dist.multi_kl(action_dist), drop_last=True)

        # Split the flat logits into one tensor per categorical head.
        unpacked_behaviour_logits = tf.split(
            behaviour_logits, output_hidden_shape, axis=1)
        unpacked_old_policy_behaviour_logits = tf.split(
            old_policy_behaviour_logits, output_hidden_shape, axis=1)

        # Compute vtrace on the CPU for better perf.
        with tf.device("/cpu:0"):
            vtrace_returns = vtrace.multi_from_logits(
                behaviour_policy_logits=make_time_major(
                    unpacked_behaviour_logits, drop_last=True),
                target_policy_logits=make_time_major(
                    unpacked_old_policy_behaviour_logits, drop_last=True),
                actions=tf.unstack(
                    make_time_major(loss_actions, drop_last=True), axis=2),
                # Discount is zeroed at episode boundaries (~dones).
                discounts=tf.cast(
                    ~make_time_major(tf.cast(dones, tf.bool), drop_last=True),
                    tf.float32) * policy.config["gamma"],
                rewards=make_time_major(rewards, drop_last=True),
                values=values_time_major[:-1],  # drop-last=True
                bootstrap_value=values_time_major[-1],
                dist_class=Categorical if is_multidiscrete else dist_class,
                model=model,
                clip_rho_threshold=tf.cast(
                    policy.config["vtrace_clip_rho_threshold"], tf.float32),
                clip_pg_rho_threshold=tf.cast(
                    policy.config["vtrace_clip_pg_rho_threshold"],
                    tf.float32),
            )

        actions_logp = make_time_major(
            action_dist.logp(actions), drop_last=True)
        prev_actions_logp = make_time_major(
            prev_action_dist.logp(actions), drop_last=True)
        old_policy_actions_logp = make_time_major(
            old_policy_action_dist.logp(actions), drop_last=True)

        # Clipped importance-sampling ratio between behaviour policy and
        # the old (target-network) policy.
        is_ratio = tf.clip_by_value(
            tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0,
            2.0)
        logp_ratio = is_ratio * tf.exp(actions_logp - prev_actions_logp)
        policy._is_ratio = is_ratio

        advantages = vtrace_returns.pg_advantages
        # PPO-style clipped surrogate objective.
        surrogate_loss = tf.minimum(
            advantages * logp_ratio,
            advantages *
            tf.clip_by_value(logp_ratio, 1 - policy.config["clip_param"],
                             1 + policy.config["clip_param"]))

        action_kl = tf.reduce_mean(mean_kl, axis=0) \
            if is_multidiscrete else mean_kl
        mean_kl = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        delta = values_time_major[:-1] - vtrace_returns.vs
        value_targets = vtrace_returns.vs
        mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

        # The entropy loss.
        actions_entropy = make_time_major(
            action_dist.multi_entropy(), drop_last=True)
        mean_entropy = reduce_mean_valid(actions_entropy)
    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")

        # Prepare KL for Loss
        mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist))

        logp_ratio = tf.math.exp(
            make_time_major(action_dist.logp(actions)) -
            make_time_major(prev_action_dist.logp(actions)))

        advantages = make_time_major(train_batch[Postprocessing.ADVANTAGES])
        # PPO-style clipped surrogate objective.
        surrogate_loss = tf.minimum(
            advantages * logp_ratio,
            advantages *
            tf.clip_by_value(logp_ratio, 1 - policy.config["clip_param"],
                             1 + policy.config["clip_param"]))

        action_kl = tf.reduce_mean(mean_kl, axis=0) \
            if is_multidiscrete else mean_kl
        mean_kl = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = make_time_major(
            train_batch[Postprocessing.VALUE_TARGETS])
        delta = values_time_major - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            make_time_major(action_dist.multi_entropy()))

    # The summed weighted loss
    total_loss = mean_policy_loss + \
        mean_vf_loss * policy.config["vf_loss_coeff"] - \
        mean_entropy * policy.config["entropy_coeff"]

    # Optional additional KL Loss
    if policy.config["use_kl_loss"]:
        total_loss += policy.kl_coeff * mean_kl

    # Store stats in policy for stats_fn.
    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_kl = mean_kl
    policy._mean_vf_loss = mean_vf_loss
    policy._mean_entropy = mean_entropy
    policy._value_targets = value_targets

    return total_loss
def ppo_surrogate_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for Proximal Policy Objective.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    logits, state = model.from_batch(train_batch, is_training=True)
    curr_action_dist = dist_class(logits, model)

    # RNN case: Mask away 0-padded chunks at end of time axis.
    if state:
        B = len(train_batch["seq_lens"])
        # Padded sequence length is inferred from the data itself.
        max_seq_len = logits.shape[0] // B
        mask = sequence_mask(train_batch["seq_lens"],
                             max_seq_len,
                             time_major=model.is_time_major())
        mask = torch.reshape(mask, [-1])
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            # Mean over valid (non-padded) timesteps only.
            return torch.sum(t[mask]) / num_valid

    # non-RNN case: No masking.
    else:
        mask = None
        reduce_mean_valid = torch.mean

    # Distribution under which the batch's actions were sampled.
    prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
                                  model)

    # pi_new(a|s) / pi_old(a|s), computed in log space.
    logp_ratio = torch.exp(
        curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
        train_batch[SampleBatch.ACTION_LOGP])
    action_kl = prev_action_dist.kl(curr_action_dist)
    mean_kl = reduce_mean_valid(action_kl)

    curr_entropy = curr_action_dist.entropy()
    mean_entropy = reduce_mean_valid(curr_entropy)

    # PPO's clipped surrogate objective.
    surrogate_loss = torch.min(
        train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
        train_batch[Postprocessing.ADVANTAGES] * torch.clamp(
            logp_ratio, 1 - policy.config["clip_param"],
            1 + policy.config["clip_param"]))
    mean_policy_loss = reduce_mean_valid(-surrogate_loss)

    if policy.config["use_gae"]:
        prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
        value_fn_out = model.value_function()
        # Unclipped squared-error value loss ...
        vf_loss1 = torch.pow(
            value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        # ... vs. the loss of a value prediction clipped to stay within
        # vf_clip_param of the old prediction; the max of both is used.
        vf_clipped = prev_value_fn_out + torch.clamp(
            value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"],
            policy.config["vf_clip_param"])
        vf_loss2 = torch.pow(
            vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0)
        vf_loss = torch.max(vf_loss1, vf_loss2)
        mean_vf_loss = reduce_mean_valid(vf_loss)
        total_loss = reduce_mean_valid(-surrogate_loss +
                                       policy.kl_coeff * action_kl +
                                       policy.config["vf_loss_coeff"] *
                                       vf_loss -
                                       policy.entropy_coeff * curr_entropy)
    else:
        mean_vf_loss = 0.0
        total_loss = reduce_mean_valid(-surrogate_loss +
                                       policy.kl_coeff * action_kl -
                                       policy.entropy_coeff * curr_entropy)

    # Store stats in policy for stats_fn.
    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_vf_loss = mean_vf_loss
    policy._vf_explained_var = explained_variance(
        train_batch[Postprocessing.VALUE_TARGETS],
        policy.model.value_function())
    policy._mean_entropy = mean_entropy
    policy._mean_kl = mean_kl

    return total_loss
def appo_surrogate_loss(policy: Policy, model: ModelV2,
                        dist_class: Type[TorchDistributionWrapper],
                        train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for APPO.

    With IS modifications and V-trace for Advantage Estimation.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Target network that corresponds to this tower's model.
    target_model = policy.target_models[model]

    model_out, _ = model.from_batch(train_batch)
    action_dist = dist_class(model_out, model)

    # Determine the action-space layout: single Discrete, MultiDiscrete
    # (several categorical heads), or anything else (single head).
    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space,
                    gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    # Closure binding policy and seq_lens for [B*T] -> [T, B] reshapes.
    def _make_time_major(*args, **kw):
        return make_time_major(policy, train_batch.get("seq_lens"), *args,
                               **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    # Logits produced by the behaviour policy at sampling time.
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    # Forward pass through the target network.
    target_model_out, _ = target_model.from_batch(train_batch)

    prev_action_dist = dist_class(behaviour_logits, model)
    values = model.value_function()
    values_time_major = _make_time_major(values)

    if policy.is_recurrent():
        # NOTE(review): assumes the batch is padded exactly to the
        # longest sequence in seq_lens — confirm.
        max_seq_len = torch.max(train_batch["seq_lens"])
        mask = sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = torch.reshape(mask, [-1])
        mask = _make_time_major(mask, drop_last=policy.config["vtrace"])
        num_valid = torch.sum(mask)

        def reduce_mean_valid(t):
            # Mean over valid (non-padded) timesteps only.
            return torch.sum(t[mask]) / num_valid

    else:
        reduce_mean_valid = torch.mean

    if policy.config["vtrace"]:
        logger.debug("Using V-Trace surrogate loss (vtrace=True)")

        # Target-network logits serve as the "old" policy; detach so no
        # gradients flow into the target network.
        old_policy_behaviour_logits = target_model_out.detach()
        old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

        # Split the flat logits into one tensor per categorical head.
        if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
            unpacked_behaviour_logits = torch.split(
                behaviour_logits, list(output_hidden_shape), dim=1)
            unpacked_old_policy_behaviour_logits = torch.split(
                old_policy_behaviour_logits, list(output_hidden_shape), dim=1)
        else:
            unpacked_behaviour_logits = torch.chunk(
                behaviour_logits, output_hidden_shape, dim=1)
            unpacked_old_policy_behaviour_logits = torch.chunk(
                old_policy_behaviour_logits, output_hidden_shape, dim=1)

        # Prepare actions for loss.
        loss_actions = actions if is_multidiscrete else torch.unsqueeze(
            actions, dim=1)

        # Prepare KL for loss.
        action_kl = _make_time_major(
            old_policy_action_dist.kl(action_dist), drop_last=True)

        # Compute the V-trace return / advantage estimates.
        vtrace_returns = vtrace.multi_from_logits(
            behaviour_policy_logits=_make_time_major(
                unpacked_behaviour_logits, drop_last=True),
            target_policy_logits=_make_time_major(
                unpacked_old_policy_behaviour_logits, drop_last=True),
            actions=torch.unbind(
                _make_time_major(loss_actions, drop_last=True), dim=2),
            # Discount is zeroed at episode boundaries (1 - dones).
            discounts=(1.0 - _make_time_major(
                dones, drop_last=True).float()) * policy.config["gamma"],
            rewards=_make_time_major(rewards, drop_last=True),
            values=values_time_major[:-1],  # drop-last=True
            bootstrap_value=values_time_major[-1],
            dist_class=TorchCategorical if is_multidiscrete else dist_class,
            model=model,
            clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=policy.config[
                "vtrace_clip_pg_rho_threshold"])

        actions_logp = _make_time_major(
            action_dist.logp(actions), drop_last=True)
        prev_actions_logp = _make_time_major(
            prev_action_dist.logp(actions), drop_last=True)
        old_policy_actions_logp = _make_time_major(
            old_policy_action_dist.logp(actions), drop_last=True)
        # Clipped importance-sampling ratio between behaviour policy and
        # the old (target-network) policy.
        is_ratio = torch.clamp(
            torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
        logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp)
        policy._is_ratio = is_ratio

        advantages = vtrace_returns.pg_advantages.to(logp_ratio.device)
        # PPO-style clipped surrogate objective.
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages *
            torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                        1 + policy.config["clip_param"]))

        mean_kl = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = vtrace_returns.vs.to(values_time_major.device)
        delta = values_time_major[:-1] - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            _make_time_major(action_dist.entropy(), drop_last=True))
    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")

        # Prepare KL for Loss
        action_kl = _make_time_major(prev_action_dist.kl(action_dist))

        actions_logp = _make_time_major(action_dist.logp(actions))
        prev_actions_logp = _make_time_major(prev_action_dist.logp(actions))
        logp_ratio = torch.exp(actions_logp - prev_actions_logp)

        advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES])
        # PPO-style clipped surrogate objective.
        surrogate_loss = torch.min(
            advantages * logp_ratio,
            advantages *
            torch.clamp(logp_ratio, 1 - policy.config["clip_param"],
                        1 + policy.config["clip_param"]))

        mean_kl = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = _make_time_major(
            train_batch[Postprocessing.VALUE_TARGETS])
        delta = values_time_major - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            _make_time_major(action_dist.entropy()))

    # The summed weighted loss
    total_loss = mean_policy_loss + \
        mean_vf_loss * policy.config["vf_loss_coeff"] - \
        mean_entropy * policy.config["entropy_coeff"]

    # Optional additional KL Loss
    if policy.config["use_kl_loss"]:
        total_loss += policy.kl_coeff * mean_kl

    # Store stats in policy for stats_fn.
    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_kl = mean_kl
    policy._mean_vf_loss = mean_vf_loss
    policy._mean_entropy = mean_entropy
    policy._value_targets = value_targets
    policy._vf_explained_var = explained_variance(
        torch.reshape(value_targets, [-1]),
        torch.reshape(
            values_time_major[:-1]
            if policy.config["vtrace"] else values_time_major, [-1]),
    )

    return total_loss
def ppo_surrogate_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TFActionDistribution],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for Proximal Policy Objective.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TFActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    logits, state = model.from_batch(train_batch)
    curr_action_dist = dist_class(logits, model)

    # RNN case: Mask away 0-padded chunks at end of time axis.
    if state:
        # Derive max_seq_len from the data itself, not from the seq_lens
        # tensor. This is in case e.g. seq_lens=[2, 3], but the data is still
        # 0-padded up to T=5 (as it's the case for attention nets).
        B = tf.shape(train_batch["seq_lens"])[0]
        max_seq_len = tf.shape(logits)[0] // B
        mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
        mask = tf.reshape(mask, [-1])

        def reduce_mean_valid(t):
            # Mean over valid (non-padded) timesteps only.
            return tf.reduce_mean(tf.boolean_mask(t, mask))

    # non-RNN case: No masking.
    else:
        mask = None
        reduce_mean_valid = tf.reduce_mean

    # Distribution under which the batch's actions were sampled.
    prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS],
                                  model)

    # pi_new(a|s) / pi_old(a|s), computed in log space.
    logp_ratio = tf.exp(
        curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) -
        train_batch[SampleBatch.ACTION_LOGP])
    action_kl = prev_action_dist.kl(curr_action_dist)
    mean_kl = reduce_mean_valid(action_kl)

    curr_entropy = curr_action_dist.entropy()
    mean_entropy = reduce_mean_valid(curr_entropy)

    # PPO's clipped surrogate objective.
    surrogate_loss = tf.minimum(
        train_batch[Postprocessing.ADVANTAGES] * logp_ratio,
        train_batch[Postprocessing.ADVANTAGES] * tf.clip_by_value(
            logp_ratio, 1 - policy.config["clip_param"],
            1 + policy.config["clip_param"]))
    mean_policy_loss = reduce_mean_valid(-surrogate_loss)

    if policy.config["use_gae"]:
        prev_value_fn_out = train_batch[SampleBatch.VF_PREDS]
        value_fn_out = model.value_function()
        # Unclipped squared-error value loss ...
        vf_loss1 = tf.math.square(value_fn_out -
                                  train_batch[Postprocessing.VALUE_TARGETS])
        # ... vs. the loss of a value prediction clipped to stay within
        # vf_clip_param of the old prediction; the max of both is used.
        vf_clipped = prev_value_fn_out + tf.clip_by_value(
            value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"],
            policy.config["vf_clip_param"])
        vf_loss2 = tf.math.square(vf_clipped -
                                  train_batch[Postprocessing.VALUE_TARGETS])
        vf_loss = tf.maximum(vf_loss1, vf_loss2)
        mean_vf_loss = reduce_mean_valid(vf_loss)
        total_loss = reduce_mean_valid(
            -surrogate_loss + policy.kl_coeff * action_kl +
            policy.config["vf_loss_coeff"] * vf_loss -
            policy.entropy_coeff * curr_entropy)
    else:
        mean_vf_loss = tf.constant(0.0)
        total_loss = reduce_mean_valid(-surrogate_loss +
                                       policy.kl_coeff * action_kl -
                                       policy.entropy_coeff * curr_entropy)

    # Store stats in policy for stats_fn.
    policy._total_loss = total_loss
    policy._mean_policy_loss = mean_policy_loss
    policy._mean_vf_loss = mean_vf_loss
    policy._mean_entropy = mean_entropy
    policy._mean_kl = mean_kl

    return total_loss