def UnclippedObjectiveMean(
    dist_inputs, values, returns, actions, old_log_probs):
    """Unclipped objective Mean from the PPO algorithm.

    Forms the advantages as ``returns - values``, obtains the probability
    ratio from ``rl_layers.ProbsRatio`` and returns the mean of
    ``rl_layers.UnclippedObjective(probs_ratio, advantages)``.

    Note: ``self`` is not a parameter of this function; it is resolved from
    the enclosing scope and supplies ``_policy_dist.log_prob``.

    Args:
      dist_inputs: Forwarded unchanged to ``rl_layers.ProbsRatio``.
      values: Subtracted from ``returns`` to form the advantages.
      returns: Minuend of the advantage computation.
      actions: Forwarded unchanged to ``rl_layers.ProbsRatio``.
      old_log_probs: Forwarded unchanged to ``rl_layers.ProbsRatio``.

    Returns:
      ``jnp.mean`` of the unclipped objective.
    """
    advantages = returns - values
    probs_ratio = rl_layers.ProbsRatio(
        dist_inputs, actions, old_log_probs,
        log_prob_fun=self._policy_dist.log_prob)

    # Advantages may carry one extra trailing singleton axis compared with
    # the ratio (e.g. [B, T, 1] against [B, T]).  Left as is, the two would
    # not line up element-wise: the product would either broadcast into
    # cross terms or be rejected downstream.  Drop that axis only when the
    # shapes differ by exactly a trailing 1, so already-aligned inputs are
    # passed through untouched.
    if tuple(advantages.shape) == tuple(probs_ratio.shape) + (1,):
        advantages = advantages.squeeze(axis=-1)

    unclipped_objective = rl_layers.UnclippedObjective(
        probs_ratio, advantages)
    return jnp.mean(unclipped_objective)
def f(dist_inputs, values, returns, actions, old_log_probs):
    """Return the mean of the unclipped PPO objective.

    The probability ratio comes from ``rl_layers.ProbsRatio`` (using the
    log-prob function of ``self._policy_dist``, where ``self`` is taken from
    the enclosing scope), the advantages are ``returns - values`` with axis 2
    squeezed out, and the two are combined by
    ``rl_layers.UnclippedObjective`` before averaging.

    Args:
      dist_inputs: Passed through to ``rl_layers.ProbsRatio``.
      values: Subtracted from ``returns`` to obtain the advantages.
      returns: Minuend of the advantage computation.
      actions: Passed through to ``rl_layers.ProbsRatio``.
      old_log_probs: Passed through to ``rl_layers.ProbsRatio``.

    Returns:
      ``jnp.mean`` over the unclipped objective.
    """
    ratio = rl_layers.ProbsRatio(
        dist_inputs,
        actions,
        old_log_probs,
        log_prob_fun=self._policy_dist.log_prob,
    )
    # Per the original author's note the raw advantages arrive as
    # [128, 1, 1] while the ratio is [128, 1]; axis 2 is removed so both
    # have the same rank.  (128 is presumably the batch size -- confirm.)
    squeezed_advantages = (returns - values).squeeze(axis=2)
    return jnp.mean(rl_layers.UnclippedObjective(ratio, squeezed_advantages))