示例#1
0
 def ClippedObjectiveMean(
     dist_inputs, values, returns, actions, old_log_probs):
   """Clipped objective from the PPO algorithm, averaged to a single value.

   Args:
     dist_inputs: Policy distribution inputs; forwarded to
       `rl_layers.ProbsRatio` together with `actions` and `old_log_probs`.
     values: Value estimates; subtracted from `returns` to form advantages.
     returns: Returns; `returns - values` gives the advantages.
     actions: Actions; forwarded to `rl_layers.ProbsRatio`.
     old_log_probs: Presumably log-probabilities of `actions` under the
       previous policy (TODO confirm); forwarded to `rl_layers.ProbsRatio`.

   Returns:
     `jnp.mean` of the result of `rl_layers.ClippedObjective` applied to the
     probability ratio and the advantages with `epsilon=self._epsilon`.
   """
   advantages = returns - values
   probs_ratio = rl_layers.ProbsRatio(
       dist_inputs, actions, old_log_probs,
       log_prob_fun=self._policy_dist.log_prob)
   # Advantages can carry extra trailing singleton axes relative to
   # probs_ratio (e.g. advantages [128, 1, 1] vs. probs_ratio [128, 1]).
   # Broadcasting those shapes gives [128, 128, 1]; assuming
   # ClippedObjective combines its inputs elementwise, that pairs every
   # ratio with every advantage and skews the mean. Drop trailing size-1
   # axes until the ranks agree (a no-op when they already match).
   while advantages.ndim > probs_ratio.ndim and advantages.shape[-1] == 1:
     advantages = advantages.squeeze(axis=-1)
   clipped_objective = rl_layers.ClippedObjective(
       probs_ratio, advantages, epsilon=self._epsilon)
   return jnp.mean(clipped_objective)
示例#2
0
 def f(dist_inputs, values, returns, actions, old_log_probs):
   """Clipped objective from the PPO algorithm, averaged to a single value.

   Args:
     dist_inputs: Policy distribution inputs; forwarded to
       `rl_layers.ProbsRatio` together with `actions` and `old_log_probs`.
     values: Value estimates; subtracted from `returns` to form advantages.
     returns: Returns; `returns - values` gives the advantages.
     actions: Actions; forwarded to `rl_layers.ProbsRatio`.
     old_log_probs: Presumably log-probabilities of `actions` under the
       previous policy (TODO confirm); forwarded to `rl_layers.ProbsRatio`.

   Returns:
     `jnp.mean` of the result of `rl_layers.ClippedObjective` applied to the
     probability ratio and the advantages with `epsilon=self._epsilon`.
   """
   advantages = returns - values
   probs_ratio = rl_layers.ProbsRatio(
       dist_inputs, actions, old_log_probs,
       log_prob_fun=self._policy_dist.log_prob)
   # Advantages have been observed with shape [128, 1, 1] while probs_ratio
   # is [128, 1]; broadcasting those would produce [128, 128, 1]. Rather than
   # hard-coding `squeeze(axis=2)` (which raises for rank-2 advantages and
   # still mis-broadcasts when probs_ratio is rank 3), drop trailing size-1
   # axes until the ranks agree. For [128, 1, 1] vs. [128, 1] this yields
   # [128, 1], exactly as before.
   while advantages.ndim > probs_ratio.ndim and advantages.shape[-1] == 1:
     advantages = advantages.squeeze(axis=-1)
   clipped_objective = rl_layers.ClippedObjective(
       probs_ratio, advantages, epsilon=self._epsilon)
   return jnp.mean(clipped_objective)