Example #1
        def f(dist_inputs, values, returns, actions, old_log_probs, mask):
            """Definition of the Proximal Policy Optimization loss."""
            del mask  # TODO(lukaszkaiser): make PPO work with Transformer

            ppo_objective = rl_layers.PPOObjective(
                dist_inputs,
                stop_gradient(values),
                returns,
                actions,
                old_log_probs,
                log_prob_fun=self._policy_dist.log_prob,
                epsilon=self._epsilon,
                normalize_advantages=self._normalize_advantages)

            entropy_loss = rl_layers.EntropyLoss(
                dist_inputs,
                actions,
                log_prob_fun=self._policy_dist.log_prob,
                entropy_coeff=self._entropy_coeff,
                entropy_fun=self._policy_dist.entropy)

            l2_value_loss = rl_layers.ValueLoss(
                values, returns, value_loss_coeff=self._value_loss_coeff)

            return -ppo_objective.mean() + l2_value_loss - entropy_loss
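For orientation, rl_layers.PPOObjective is named after PPO's clipped surrogate objective. The sketch below is the standard formula written in plain jax.numpy, not the library's actual implementation; the function name and the assumption that the advantages and the new log-probs are already available are illustrative only.

    import jax.numpy as jnp

    def clipped_surrogate(new_log_probs, old_log_probs, advantages, epsilon):
        """Standard per-timestep PPO objective: min(r * A, clip(r, 1-eps, 1+eps) * A)."""
        ratios = jnp.exp(new_log_probs - old_log_probs)  # pi_new(a|s) / pi_old(a|s)
        clipped = jnp.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
        return jnp.minimum(ratios * advantages, clipped * advantages)

The return statement above then negates the mean of the PPO objective (so that minimizing the loss maximizes the objective), adds the L2 value loss and subtracts the entropy bonus.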
Example #2
        def f(dist_inputs, values, returns, actions, old_log_probs, mask):
            """Definition of the Proximal Policy Optimization loss."""
            del mask  # TODO(lukaszkaiser): make PPO work with Transformer
            # We have dist_inputs of the shape float32[128,1,18]
            assert len(dist_inputs.shape) == 3, (
                f'dist_inputs.shape was {dist_inputs.shape} '
                f'but expected a tensor shape of length 3')
            # values of the shape float32[128,1,1]
            # returns of the shape float32[128,1,1]
            # and old_log_probs of the shape float32[128,1]
            assert values.shape == returns.shape, (
                f'values.shape was {values.shape} but '
                f'returns.shape was {returns.shape}')
            assert returns.shape[0:2] == old_log_probs.shape, (
                f'returns.shape was {returns.shape} but '
                f'old_log_probs.shape was {old_log_probs.shape}')
            # actions is a tensor of the shape int32[128,1]
            assert len(
                actions.shape) == 2, f'actions.shape was {actions.shape}'
            # which agrees with returns/values on the first two coordinates
            assert actions.shape[0:2] == returns.shape[0:2], (
                f'actions.shape was {actions.shape} and '
                f'returns.shape was {returns.shape}')

            ppo_objective = rl_layers.PPOObjective(
                dist_inputs,
                stop_gradient(values),
                returns,
                actions,
                old_log_probs,
                log_prob_fun=self._policy_dist.log_prob,
                epsilon=self._epsilon,
                normalize_advantages=self._normalize_advantages)

            # we insist that ppo_objective is a vector of shape [128,1]
            assert len(ppo_objective.shape) == 2, (
                f'ppo_objective was {ppo_objective}')
            # which agrees with returns/values/actions on the first two coordinates
            assert ppo_objective.shape[0:2] == values.shape[0:2], (
                f'ppo_objective.shape was {ppo_objective.shape} and '
                f'values.shape was {values.shape}')

            entropy_loss = rl_layers.EntropyLoss(
                dist_inputs,
                actions,
                log_prob_fun=self._policy_dist.log_prob,
                entropy_coeff=self._entropy_coeff,
                entropy_fun=self._policy_dist.entropy)

            assert jnp.ndim(
                entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

            l2_value_loss = rl_layers.ValueLoss(
                values, returns, value_loss_coeff=self._value_loss_coeff)

            assert jnp.ndim(
                l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'

            return -ppo_objective.mean() + l2_value_loss - entropy_loss
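For reference, the shape comments and assertions in this example describe a batch like the following minimal sketch; the concrete sizes (batch 128, one timestep, 18 discrete actions) are taken from the comments above and are just one possible configuration.

    import numpy as np

    dist_inputs = np.zeros((128, 1, 18), dtype=np.float32)  # per-action logits
    values = np.zeros((128, 1, 1), dtype=np.float32)        # value-head outputs
    returns = np.zeros((128, 1, 1), dtype=np.float32)       # empirical returns
    actions = np.zeros((128, 1), dtype=np.int32)            # sampled actions
    old_log_probs = np.zeros((128, 1), dtype=np.float32)    # log-probs under the old policy
    batch = (dist_inputs, values, returns, actions, old_log_probs)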
Example #3
 def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs):
   """Clipped objective from the PPO algorithm."""
   ppo_objective = rl_layers.PPOObjective(
       dist_inputs, values, returns, dones, rewards, actions, old_log_probs,
       log_prob_fun=self._policy_dist.log_prob,
       epsilon=self._epsilon,
       normalize_advantages=self._normalize_advantages)
   return jnp.mean(ppo_objective)
Example #4
 def f(dist_inputs, values, returns, actions, old_log_probs):
     return rl_layers.PPOObjective(
         dist_inputs,
         values,
         returns,
         actions,
         old_log_probs,
         log_prob_fun=self._policy_dist.log_prob,
         epsilon=self._epsilon,
         normalize_advantages=self._normalize_advantages)
Example #5
 def ppo_objective(self):
   """PPO objective with local parameters."""
   return tl.Fn(
       lambda dist_inputs, values, returns, actions, old_log_probs:
       rl_layers.PPOObjective(
           dist_inputs, values, returns, actions, old_log_probs,
           log_prob_fun=self._policy_dist.log_prob,
           epsilon=self._epsilon,
           normalize_advantages=self._normalize_advantages),
       n_out=1)
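A hypothetical way to apply the returned layer, assuming agent is an instance of the class these methods belong to and batch is the (dist_inputs, values, returns, actions, old_log_probs) tuple sketched after Example #2:

    # Both 'agent' and 'batch' are assumptions, not names taken from the examples above.
    loss_layer = agent.ppo_objective()          # a tl.Fn layer: five inputs, one output
    per_timestep_objective = loss_layer(batch)  # Trax layers take a tuple for multiple inputs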