def f(dist_inputs, values, returns, actions, old_log_probs, mask):
  """Definition of the Proximal Policy Optimization loss."""
  del mask  # TODO(lukaszkaiser): make PPO work with Transformer
  ppo_objective = rl_layers.PPOObjective(
      dist_inputs, stop_gradient(values), returns, actions, old_log_probs,
      log_prob_fun=self._policy_dist.log_prob,
      epsilon=self._epsilon,
      normalize_advantages=self._normalize_advantages)
  entropy_loss = rl_layers.EntropyLoss(
      dist_inputs, actions,
      log_prob_fun=self._policy_dist.log_prob,
      entropy_coeff=self._entropy_coeff,
      entropy_fun=self._policy_dist.entropy)
  l2_value_loss = rl_layers.ValueLoss(
      values, returns, value_loss_coeff=self._value_loss_coeff)
  return -ppo_objective.mean() + l2_value_loss - entropy_loss
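# The return value above is the standard PPO loss: the negated mean of the
# clipped surrogate objective, plus an L2 value loss, minus an entropy bonus
# (Schulman et al., 2017). The sketch below spells those three terms out
# directly in jax.numpy. It is illustrative only and is NOT the rl_layers
# implementation: the name `ppo_loss_sketch`, its arguments and the default
# coefficients are assumptions made here, and all inputs are taken to be
# arrays of identical shape (one entry per batch element and time step).
import jax
import jax.numpy as jnp


def ppo_loss_sketch(new_log_probs, old_log_probs, values, returns, entropy,
                    epsilon=0.1, value_loss_coeff=0.5, entropy_coeff=0.01):
  """Illustrative PPO loss: -clipped objective + value loss - entropy bonus."""
  # Stop-gradient on values so the policy term does not train the value head.
  advantages = returns - jax.lax.stop_gradient(values)
  # Probability ratio between the current policy and the behaviour policy.
  ratios = jnp.exp(new_log_probs - old_log_probs)
  # Clipped surrogate objective: take the pessimistic (minimum) branch.
  clipped_ratios = jnp.clip(ratios, 1.0 - epsilon, 1.0 + epsilon)
  ppo_objective = jnp.minimum(ratios * advantages, clipped_ratios * advantages)
  # L2 regression of the value estimates towards the empirical returns.
  l2_value_loss = value_loss_coeff * jnp.mean((values - returns) ** 2)
  # Entropy bonus; subtracting it rewards more exploratory policies.
  entropy_loss = entropy_coeff * jnp.mean(entropy)
  return -jnp.mean(ppo_objective) + l2_value_loss - entropy_loss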
def f(dist_inputs, values, returns, actions, old_log_probs, mask):
  """Definition of the Proximal Policy Optimization loss."""
  del mask  # TODO(lukaszkaiser): make PPO work with Transformer
  # We have dist_inputs of the shape float32[128,1,18].
  assert len(dist_inputs.shape) == 3, (
      f'dist_inputs.shape was {dist_inputs.shape} '
      f'but the expected length of the tensor shape is 3')
  # values of the shape float32[128,1,1],
  # returns of the shape float32[128,1,1]
  # and old_log_probs of the shape float32[128,1].
  assert values.shape == returns.shape, (
      f'values.shape was {values.shape} and '
      f'returns.shape was {returns.shape}')
  assert returns.shape[0:2] == old_log_probs.shape, (
      f'returns.shape was {returns.shape} and '
      f'old_log_probs.shape was {old_log_probs.shape}')
  # actions is a tensor of the shape int32[128,1],
  assert len(actions.shape) == 2, f'actions.shape was {actions.shape}'
  # which agrees with returns/values on the first two coordinates.
  assert actions.shape[0:2] == returns.shape[0:2], (
      f'actions.shape was {actions.shape} and '
      f'returns.shape was {returns.shape}')

  ppo_objective = rl_layers.PPOObjective(
      dist_inputs, stop_gradient(values), returns, actions, old_log_probs,
      log_prob_fun=self._policy_dist.log_prob,
      epsilon=self._epsilon,
      normalize_advantages=self._normalize_advantages)

  # We insist that ppo_objective is a vector of shape [128,1],
  assert len(ppo_objective.shape) == 2, (
      f'ppo_objective was {ppo_objective}')
  # which agrees with returns/values/actions on the first two coordinates.
  assert ppo_objective.shape[0:2] == values.shape[0:2], (
      f'ppo_objective.shape was {ppo_objective.shape} and '
      f'values.shape was {values.shape}')

  entropy_loss = rl_layers.EntropyLoss(
      dist_inputs, actions,
      log_prob_fun=self._policy_dist.log_prob,
      entropy_coeff=self._entropy_coeff,
      entropy_fun=self._policy_dist.entropy)
  assert jnp.ndim(entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

  l2_value_loss = rl_layers.ValueLoss(
      values, returns, value_loss_coeff=self._value_loss_coeff)
  assert jnp.ndim(l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'

  return -ppo_objective.mean() + l2_value_loss - entropy_loss
def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs):
  """Clipped objective from the PPO algorithm."""
  ppo_objective = rl_layers.PPOObjective(
      dist_inputs, values, returns, dones, rewards, actions, old_log_probs,
      log_prob_fun=self._policy_dist.log_prob,
      epsilon=self._epsilon,
      normalize_advantages=self._normalize_advantages)
  return jnp.mean(ppo_objective)
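# Every variant here passes `normalize_advantages` through to
# rl_layers.PPOObjective. The snippet below is a minimal sketch of what that
# flag typically means, assuming the usual mean/std standardization; the
# helper name `normalize_advantages_sketch` is hypothetical and the exact
# behaviour lives inside rl_layers.
import jax.numpy as jnp


def normalize_advantages_sketch(advantages, epsilon=1e-8):
  """Standardizes advantages to zero mean and unit variance."""
  return (advantages - jnp.mean(advantages)) / (jnp.std(advantages) + epsilon)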
def f(dist_inputs, values, returns, actions, old_log_probs):
  """Returns the PPO objective computed by rl_layers for these inputs."""
  return rl_layers.PPOObjective(
      dist_inputs, values, returns, actions, old_log_probs,
      log_prob_fun=self._policy_dist.log_prob,
      epsilon=self._epsilon,
      normalize_advantages=self._normalize_advantages)
def ppo_objective(self):
  """PPO objective with local parameters."""
  return tl.Fn(
      lambda dist_inputs, values, returns, actions, old_log_probs:
      rl_layers.PPOObjective(
          dist_inputs, values, returns, actions, old_log_probs,
          log_prob_fun=self._policy_dist.log_prob,
          epsilon=self._epsilon,
          normalize_advantages=self._normalize_advantages),
      n_out=1)
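# A quick usage sketch for the layer returned above, assuming the usual Trax
# convention that a weightless layer built with tl.Fn can be applied directly
# to a tuple of input arrays. The `agent` instance and the concrete shapes
# (batch of 128, one time step, 18-way categorical policy, as in the
# assertion-heavy loss above) are assumptions for illustration only.
import numpy as np

dist_inputs = np.zeros((128, 1, 18), dtype=np.float32)
values = np.zeros((128, 1, 1), dtype=np.float32)
returns = np.zeros((128, 1, 1), dtype=np.float32)
actions = np.zeros((128, 1), dtype=np.int32)
old_log_probs = np.zeros((128, 1), dtype=np.float32)

objective_layer = agent.ppo_objective()  # layer defined by the method above
per_timestep_objective = objective_layer(
    (dist_inputs, values, returns, actions, old_log_probs))
# per_timestep_objective has shape [128, 1], as asserted in the loss above;
# the scalar loss is obtained by averaging over it.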