import jax.numpy as jnp


def ClippedObjective(probs_ratio, advantages, epsilon):
  """Clipped Objective from the PPO algorithm."""
  assert probs_ratio.shape == advantages.shape, (
      f'probs_ratio.shape was {probs_ratio.shape} and '
      f'advantages.shape was {advantages.shape}')
  clipped_objective = jnp.clip(probs_ratio, 1 - epsilon,
                               1 + epsilon) * advantages
  assert probs_ratio.shape == clipped_objective.shape, (
      f'probs_ratio.shape was {probs_ratio.shape} and '
      f'clipped_objective.shape was {clipped_objective.shape}')
  return clipped_objective
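# A minimal usage sketch of ClippedObjective (not part of the original
# source): the array values and epsilon=0.2 below are made-up
# illustrations. With epsilon=0.2, a ratio of 1.5 is clipped to 1.2
# before being multiplied by the advantage, while 0.5 is clipped to 0.8.
example_probs_ratio = jnp.array([[0.5], [1.0], [1.5]])
example_advantages = jnp.array([[1.0], [2.0], [3.0]])
print(ClippedObjective(example_probs_ratio, example_advantages,
                       epsilon=0.2))
# Expected output (approximately): [[0.8], [2.0], [3.6]]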
def f(new_log_probs, advantages, old_log_probs, mask):
  # new_log_probs of the shape float32[128,1]
  # advantages of the shape int32[128,1]
  # old_log_probs of the shape int32[128,1]
  # mask of the shape int32[128,1]
  if new_log_probs.shape != advantages.shape:
    raise ValueError('New log-probs and advantages shapes '
                     'should be the same, %s != %s' % (
                         new_log_probs.shape, advantages.shape))
  if new_log_probs.shape != old_log_probs.shape:
    raise ValueError('New log-probs and old log-probs shapes '
                     'should be the same, %s != %s' % (
                         new_log_probs.shape, old_log_probs.shape))
  if new_log_probs.shape != mask.shape:
    raise ValueError('New log-probs and mask shapes should be the same'
                     ', %s != %s' % (new_log_probs.shape, mask.shape))

  # The ratio between new_probs and old_probs expressed
  # using log_probs and exponentiation.
  probs_ratio = jnp.exp(new_log_probs - old_log_probs)
  if advantages.shape != probs_ratio.shape:
    raise ValueError('Advantages and probs_ratio shapes '
                     'should be the same, %s != %s' % (
                         advantages.shape, probs_ratio.shape))

  unclipped_objective = probs_ratio * advantages
  clipped_objective = jnp.clip(probs_ratio,
                               1 - self._epsilon,
                               1 + self._epsilon) * advantages
  if unclipped_objective.shape != clipped_objective.shape:
    raise ValueError('unclipped_objective and clipped_objective shapes '
                     'should be the same, %s != %s' % (
                         unclipped_objective.shape,
                         clipped_objective.shape))

  # The PPO objective takes the elementwise minimum of the unclipped
  # and clipped terms.
  ppo_objective = jnp.minimum(unclipped_objective, clipped_objective)
  if ppo_objective.shape != mask.shape:
    raise ValueError('ppo_objective and mask shapes '
                     'should be the same, %s != %s' % (
                         ppo_objective.shape, mask.shape))

  # Masked mean of the objective, negated because we minimize the loss.
  ppo_loss = -jnp.sum(ppo_objective * mask) / jnp.sum(mask)

  # Entropy bonus to encourage exploration. Note that `self._epsilon`,
  # `self._policy_dist` and `self._entropy_coeff` belong to the
  # enclosing PPO training object, so `f` is defined as a closure
  # inside its class rather than as a standalone function.
  entropy_vec = self._policy_dist.entropy(
      new_log_probs) * self._entropy_coeff
  entropy_loss = jnp.mean(entropy_vec)
  combined_loss = ppo_loss - entropy_loss

  return combined_loss
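# A minimal self-contained sketch (not the original code) of the same
# clipped PPO loss computation, with epsilon passed as an argument and
# the entropy bonus omitted so that it does not depend on `self`. The
# helper name `ppo_loss_sketch` and the default epsilon=0.2 are
# assumptions made for illustration only.
def ppo_loss_sketch(new_log_probs, old_log_probs, advantages, mask,
                    epsilon=0.2):
  # Probability ratio recovered from log-probabilities.
  probs_ratio = jnp.exp(new_log_probs - old_log_probs)
  unclipped = probs_ratio * advantages
  clipped = jnp.clip(probs_ratio, 1 - epsilon, 1 + epsilon) * advantages
  ppo_objective = jnp.minimum(unclipped, clipped)
  # Negate because optimizers minimize; the mask averages over valid
  # (non-padding) time steps only.
  return -jnp.sum(ppo_objective * mask) / jnp.sum(mask)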