def PPOJointLoss(x, **unused_kwargs):
    """Definition of the Proximal Policy Optimization loss.

    Combines three terms: the negated PPO clipped surrogate objective, an L2
    value-regression loss, and an entropy bonus. This function reads
    ``self._policy_dist``, ``self._value_loss_coeff``, ``self._epsilon`` and
    ``self._entropy_coeff`` from the enclosing scope.

    Args:
      x: Tuple ``(dist_inputs, values, returns, actions, old_log_probs, mask)``.
        ``dist_inputs`` parameterizes ``self._policy_dist`` (it is what
        ``log_prob`` is evaluated on). ``old_log_probs`` carries a trailing
        size-1 axis that is squeezed away below. ``mask`` is currently ignored.
      **unused_kwargs: Ignored.

    Returns:
      ``-mean(ppo_objective) + l2_value_loss - entropy_loss``, where
      ``l2_value_loss = sum((returns - values)**2) * value_loss_coeff`` and
      ``entropy_loss = mean(entropy(dist_inputs)) * entropy_coeff``.
    """
    dist_inputs, values, returns, actions, old_log_probs, mask = x
    del mask  # TODO(lukaszkaiser): make PPO work with Transformer
    new_log_probs = self._policy_dist.log_prob(dist_inputs, actions)

    advantages = returns - values
    # NOTE(review): `advantages` is not wrapped in jax.lax.stop_gradient, so
    # the policy objective below also back-propagates into `values`. Confirm
    # this is intended.
    l2_value_loss = jnp.sum(advantages**2) * self._value_loss_coeff

    # Old log probs have an undesirable extra dimension which we remove here.
    old_log_probs = jnp.array(old_log_probs.squeeze(axis=-1), dtype=jnp.float32)
    new_log_probs = jnp.array(new_log_probs.squeeze(axis=-1))

    # The ratio between new_probs and old_probs expressed
    # using log_probs and exponentiation.
    probs_ratio = jnp.exp(new_log_probs - old_log_probs)
    unclipped_objective = probs_ratio * advantages
    clipped_objective = jnp.clip(
        probs_ratio, 1 - self._epsilon, 1 + self._epsilon) * advantages
    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)

    # The entropy bonus must describe the whole action distribution, so it is
    # computed from `dist_inputs`. `new_log_probs` only holds the log-prob of
    # the action that was taken. Averaging makes the bonus a scalar, like the
    # other two terms of the returned loss.
    entropy_loss = (
        jnp.mean(self._policy_dist.entropy(dist_inputs)) * self._entropy_coeff)

    return -ppo_objective.mean() + l2_value_loss - entropy_loss
def ClippedObjective(probs_ratio, advantages, epsilon):
    """Clipped Objective from the PPO algorithm.

    Computes ``clip(probs_ratio, 1 - epsilon, 1 + epsilon) * advantages``
    elementwise.

    Args:
      probs_ratio: Per-element probability ratio between the new and the old
        policy.
      advantages: Advantage estimates. Must have exactly the same shape as
        ``probs_ratio``; broadcasting is rejected by the assertion below.
      epsilon: Clipping half-width. The ratio is clipped to
        ``[1 - epsilon, 1 + epsilon]``.

    Returns:
      The clipped objective, with the same shape as ``probs_ratio``.

    Raises:
      AssertionError: If the shapes of ``probs_ratio`` and ``advantages``
        differ. This check only runs when Python assertions are enabled.
    """
    # Each pair of adjacent f-strings concatenates, so the first fragment
    # needs a trailing space to keep the message readable.
    assert probs_ratio.shape == advantages.shape, (
        f'probs_ratio.shape was {probs_ratio.shape} and '
        f'advantages.shape was {advantages.shape}')
    clipped_objective = jnp.clip(probs_ratio, 1 - epsilon, 1 + epsilon) * advantages
    assert probs_ratio.shape == clipped_objective.shape, (
        f'probs_ratio.shape was {probs_ratio.shape} and '
        f'clipped_objective.shape was {clipped_objective.shape}')
    return clipped_objective
def f(new_log_probs, advantages, old_log_probs, mask):
    """Masked PPO clipped-surrogate loss, negated so it can be minimized.

    ``epsilon`` (the clipping half-width) is read from the enclosing scope.

    Args:
      new_log_probs: Log-probs of the taken actions under the current policy.
      advantages: Advantage estimates, combined elementwise with the ratio.
      old_log_probs: Log-probs under the data-collecting policy, with an
        extra trailing size-1 axis that is squeezed away below.
      mask: Per-element weights; the loss is the mask-weighted mean.

    Returns:
      ``-sum(ppo_objective * mask) / sum(mask)``, or 0 when ``mask`` sums
      to 0 (instead of NaN).
    """
    # Old log probs have an undesirable extra dimension which we remove here.
    old_log_probs = old_log_probs.squeeze(axis=-1)
    # The ratio between new_probs and old_probs expressed
    # using log_probs and exponentiation.
    probs_ratio = jnp.exp(new_log_probs - old_log_probs)
    unclipped_objective = probs_ratio * advantages
    clipped_objective = jnp.clip(probs_ratio, 1 - epsilon, 1 + epsilon) * advantages
    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    # Use jnp consistently (the inputs are JAX arrays). Guard the denominator
    # so an all-zero mask yields 0 rather than 0/0 = NaN, which would poison
    # gradients.
    mask_total = jnp.sum(mask)
    safe_total = jnp.where(mask_total == 0, 1, mask_total)
    return -jnp.sum(ppo_objective * mask) / safe_total
def PPOLoss(x, epsilon, **unused_kwargs):
    """Definition of the Proximal Policy Optimization loss.

    Computes the mask-weighted mean of the PPO clipped surrogate objective,
    negated so that it can be minimized.

    Args:
      x: Tuple ``(new_log_probs, advantages, old_log_probs, mask)``.
        ``old_log_probs`` has an extra trailing size-1 axis that is squeezed
        away. The other entries are combined with it elementwise.
      epsilon: Clipping half-width. The probability ratio is clipped to
        ``[1 - epsilon, 1 + epsilon]``.
      **unused_kwargs: Ignored.

    Returns:
      ``-sum(ppo_objective * mask) / sum(mask)``, or 0 when ``mask`` sums
      to 0 (instead of NaN).
    """
    (new_log_probs, advantages, old_log_probs, mask) = x
    # Old log probs have an undesirable extra dimension which we remove here.
    old_log_probs = old_log_probs.squeeze(axis=-1)
    # The ratio between new_probs and old_probs expressed
    # using log_probs and exponentiation.
    probs_ratio = jnp.exp(new_log_probs - old_log_probs)
    unclipped_objective = probs_ratio * advantages
    clipped_objective = jnp.clip(probs_ratio, 1 - epsilon, 1 + epsilon) * advantages
    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    # Use jnp consistently (the inputs are JAX arrays). Guard the denominator
    # so an all-zero mask yields 0 rather than 0/0 = NaN.
    mask_total = jnp.sum(mask)
    safe_total = jnp.where(mask_total == 0, 1, mask_total)
    return -jnp.sum(ppo_objective * mask) / safe_total
def f(new_log_probs, advantages, old_log_probs, mask):
    """Masked PPO clipped-surrogate loss minus an entropy bonus.

    Reads ``self._epsilon``, ``self._policy_dist`` and ``self._entropy_coeff``
    from the enclosing scope.

    Args:
      new_log_probs: Log-probs of the taken actions under the current policy.
      advantages: Advantage estimates.
      old_log_probs: Log-probs under the data-collecting policy.
      mask: Per-element weights for the policy objective.

    All four inputs must have exactly the same shape (checked below).
    NOTE(review): earlier comments described the shapes as float32[128, 1]
    for new_log_probs and int32[128, 1] for the others; int32 log-probs and
    advantages look suspicious — confirm the dtypes with the caller.

    Returns:
      ``ppo_loss - mean(entropy(new_log_probs) * entropy_coeff)``, where
      ``ppo_loss`` is the negated mask-weighted mean of the PPO objective
      (0 when ``mask`` sums to 0).

    Raises:
      ValueError: If any pair of shapes checked below differs.
    """
    if new_log_probs.shape != advantages.shape:
        raise ValueError('New log-probs and advantages shapes '
                         'should be the same, %s != %s' % (new_log_probs.shape,
                                                           advantages.shape))
    if new_log_probs.shape != old_log_probs.shape:
        raise ValueError('New log-probs and old log-probs shapes '
                         'should be the same, %s != %s' % (new_log_probs.shape,
                                                           old_log_probs.shape))
    if new_log_probs.shape != mask.shape:
        raise ValueError('New log-probs and mask shapes should be the same'
                         ', %s != %s' % (new_log_probs.shape, mask.shape))

    # The ratio between new_probs and old_probs expressed
    # using log_probs and exponentiation.
    probs_ratio = jnp.exp(new_log_probs - old_log_probs)
    if advantages.shape != probs_ratio.shape:
        # Message previously named the wrong pair of arrays.
        raise ValueError('Advantages and probs_ratio shapes '
                         'should be the same, %s != %s' % (advantages.shape,
                                                           probs_ratio.shape))
    unclipped_objective = probs_ratio * advantages
    clipped_objective = jnp.clip(probs_ratio,
                                 1 - self._epsilon,
                                 1 + self._epsilon) * advantages
    # Compare the two objectives the error message talks about. The old
    # condition compared against probs_ratio instead of clipped_objective.
    if unclipped_objective.shape != clipped_objective.shape:
        raise ValueError('unclipped_objective and clipped_objective shapes '
                         'should be the same, %s != %s' % (
                             unclipped_objective.shape,
                             clipped_objective.shape))

    ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
    if ppo_objective.shape != mask.shape:
        raise ValueError('ppo_objective and mask shapes '
                         'should be the same, %s != %s' % (
                             ppo_objective.shape, mask.shape))

    # Guard against an all-zero mask, which would otherwise give 0/0 = NaN.
    mask_total = jnp.sum(mask)
    safe_total = jnp.where(mask_total == 0, 1, mask_total)
    ppo_loss = -jnp.sum(ppo_objective * mask) / safe_total

    # NOTE(review): entropy is evaluated on `new_log_probs`, the log-prob of
    # the taken action only, not on the full action distribution. The
    # distribution inputs are not passed to this function, so this cannot be
    # fixed locally. Confirm against the `entropy` contract.
    entropy_vec = self._policy_dist.entropy(
        new_log_probs) * self._entropy_coeff
    entropy_loss = jnp.mean(entropy_vec)

    combined_loss = ppo_loss - entropy_loss
    return combined_loss
def ClippedObjective(probs_ratio, advantages, epsilon):
    """Clipped Objective from the PPO algorithm.

    Returns ``clip(probs_ratio, 1 - epsilon, 1 + epsilon) * advantages``.
    Unlike a shape-asserting variant, no shape check is performed, so the two
    arrays may broadcast against each other.

    Args:
      probs_ratio: Probability ratio between the new and the old policy.
      advantages: Advantage estimates, multiplied in after clipping.
      epsilon: Clipping half-width.

    Returns:
      The elementwise product of the clipped ratio and the advantages.
    """
    lower_bound = 1 - epsilon
    upper_bound = 1 + epsilon
    bounded_ratio = jnp.clip(probs_ratio, lower_bound, upper_bound)
    return bounded_ratio * advantages