コード例 #1
0
ファイル: actor_critic_joint.py プロジェクト: hugochan/trax
 def UnclippedObjectiveMean(dist_inputs, values,
                            returns, actions, old_log_probs):
   """Unclipped objective Mean from the PPO algorithm.

   Forms advantages as ``returns - values``. Computes the probability ratio of
   ``actions`` under the current policy (``self._policy_dist.log_prob`` applied
   via ``rl_layers.ProbsRatio``) relative to ``old_log_probs``. Returns the mean
   of ``rl_layers.UnclippedObjective`` over the ratio and the advantages.

   Args:
     dist_inputs: Policy distribution inputs forwarded to ``ProbsRatio``.
     values: Value estimates subtracted from ``returns``.
     returns: Returns used to form the advantages.
     actions: Actions whose log-probabilities are evaluated.
     old_log_probs: Log-probabilities of ``actions`` under the old policy.

   Returns:
     The mean (over all elements) of the unclipped objective.
   """
   advantages = returns - values
   probs_ratio = rl_layers.ProbsRatio(
       dist_inputs, actions, old_log_probs,
       log_prob_fun=self._policy_dist.log_prob)
   # Advantages can carry an extra trailing singleton axis relative to
   # probs_ratio (e.g. [128, 1, 1] vs. [128, 1]). Left in place, combining the
   # two would broadcast to [128, 128, 1] and silently corrupt the mean, so
   # drop that axis. This is a no-op when the shapes already line up.
   if (advantages.ndim == probs_ratio.ndim + 1
       and advantages.shape[-1] == 1):
     advantages = advantages.squeeze(axis=-1)
   unclipped_objective = rl_layers.UnclippedObjective(
       probs_ratio, advantages)
   return jnp.mean(unclipped_objective)
コード例 #2
0
ファイル: actor_critic_joint.py プロジェクト: srush/trax
 def f(dist_inputs, values, returns, actions, old_log_probs):
   """Unclipped objective Mean from the PPO algorithm.

   Forms advantages as ``returns - values``. Computes the probability ratio of
   ``actions`` under the current policy (``self._policy_dist.log_prob`` applied
   via ``rl_layers.ProbsRatio``) relative to ``old_log_probs``. Returns the mean
   of ``rl_layers.UnclippedObjective`` over the ratio and the advantages.

   Args:
     dist_inputs: Policy distribution inputs forwarded to ``ProbsRatio``.
     values: Value estimates subtracted from ``returns``.
     returns: Returns used to form the advantages.
     actions: Actions whose log-probabilities are evaluated.
     old_log_probs: Log-probabilities of ``actions`` under the old policy.

   Returns:
     The mean (over all elements) of the unclipped objective.
   """
   advantages = returns - values
   probs_ratio = rl_layers.ProbsRatio(
       dist_inputs, actions, old_log_probs,
       log_prob_fun=self._policy_dist.log_prob)
   # Advantages have been observed with shape [128, 1, 1] while probs_ratio is
   # [128, 1]. Drop the extra trailing singleton axis so they combine
   # elementwise instead of broadcasting to [128, 128, 1]. Guarding on rank
   # (rather than hard-coding axis=2) keeps this a no-op when the shapes
   # already agree, instead of raising.
   if (advantages.ndim == probs_ratio.ndim + 1
       and advantages.shape[-1] == 1):
     advantages = advantages.squeeze(axis=-1)
   unclipped_objective = rl_layers.UnclippedObjective(
       probs_ratio, advantages)
   return jnp.mean(unclipped_objective)