def ClippedObjectiveMean(dist_inputs, values, returns, actions, old_log_probs):
    """Clipped objective from the PPO algorithm, reduced to its mean.

    ``self`` is not a parameter. The function reads ``self._policy_dist`` and
    ``self._epsilon`` from an enclosing scope, so it must be defined inside a
    method of the object that owns the policy distribution and clip epsilon.

    Args:
        dist_inputs: Policy-distribution inputs, forwarded to
            ``rl_layers.ProbsRatio``.
        values: Value estimates, subtracted from ``returns`` to form the
            advantages.
        returns: Return targets.
        actions: Actions whose probabilities enter the ratio.
        old_log_probs: Log-probabilities of ``actions`` from the data-collecting
            policy, forwarded to ``rl_layers.ProbsRatio``.

    Returns:
        ``jnp.mean`` over all elements of the result of
        ``rl_layers.ClippedObjective``.
    """
    advantages = returns - values
    # Presumably the per-element ratio pi(a|s) / pi_old(a|s); computed by the
    # project helper using the policy distribution's log_prob.
    probs_ratio = rl_layers.ProbsRatio(
        dist_inputs, actions, old_log_probs,
        log_prob_fun=self._policy_dist.log_prob)
    # The advantages can carry one extra trailing singleton axis relative to
    # the ratio (e.g. [batch, 1, 1] vs [batch, 1]). Left alone, the two would
    # broadcast to [batch, batch, 1], and the mean below would silently mix
    # every sample's ratio with every other sample's advantage. Drop that axis
    # so the two line up element-wise. Matching shapes pass through untouched.
    if (jnp.ndim(advantages) == jnp.ndim(probs_ratio) + 1
            and jnp.shape(advantages)[-1] == 1):
        advantages = jnp.squeeze(advantages, axis=-1)
    clipped_objective = rl_layers.ClippedObjective(
        probs_ratio, advantages, epsilon=self._epsilon)
    return jnp.mean(clipped_objective)
def f(dist_inputs, values, returns, actions, old_log_probs):
    """Clipped objective from the PPO algorithm, reduced to its mean.

    ``self`` is not a parameter. The function reads ``self._policy_dist`` and
    ``self._epsilon`` from an enclosing scope, so it must be defined inside a
    method of the object that owns the policy distribution and clip epsilon.

    Args:
        dist_inputs: Policy-distribution inputs, forwarded to
            ``rl_layers.ProbsRatio``.
        values: Value estimates, subtracted from ``returns`` to form the
            advantages.
        returns: Return targets.
        actions: Actions whose probabilities enter the ratio.
        old_log_probs: Log-probabilities of ``actions`` from the data-collecting
            policy, forwarded to ``rl_layers.ProbsRatio``.

    Returns:
        ``jnp.mean`` over all elements of the result of
        ``rl_layers.ClippedObjective``.
    """
    advantages = returns - values
    probs_ratio = rl_layers.ProbsRatio(
        dist_inputs, actions, old_log_probs,
        log_prob_fun=self._policy_dist.log_prob)
    # Advantages have been observed with one more trailing singleton axis than
    # probs_ratio (e.g. [128, 1, 1] vs [128, 1]). Without alignment they would
    # broadcast to [128, 128, 1]. Squeeze that axis only when it is actually
    # present; an unconditional squeeze(axis=2) raises when the advantages
    # already match probs_ratio's rank.
    if (jnp.ndim(advantages) == jnp.ndim(probs_ratio) + 1
            and jnp.shape(advantages)[-1] == 1):
        advantages = jnp.squeeze(advantages, axis=-1)
    clipped_objective = rl_layers.ClippedObjective(
        probs_ratio, advantages, epsilon=self._epsilon)
    return jnp.mean(clipped_objective)