def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs,
      mask):
    """Definition of the A2C loss.

    Checks the shapes of the inputs, then combines three terms into one
    scalar: the A2C objective, the L2 value loss and the entropy loss.
    `self`, `rl_layers`, `stop_gradient` and `jnp` are not defined here and
    are taken from the enclosing scope.

    Args:
        dist_inputs: rank-3 tensor of policy distribution inputs, typically
            float32[128,1,18].
        values: value predictions, typically float32[128,1,1].
        returns: returns; must have exactly the same shape as `values`.
        dones: unused (deleted immediately).
        rewards: unused (deleted immediately).
        actions: typically int32[128,1] for discrete actions and
            float32[128,1,6] for half-cheetah; must agree with `returns` on
            the first two dimensions.
        old_log_probs: unused (deleted immediately).
        mask: rank-2 tensor, typically float32[128,1]; must agree with
            `returns` on the first two dimensions.

    Returns:
        The scalar `a2c_objective + l2_value_loss - entropy_loss`.

    Raises:
        AssertionError: if any shape check fails or if one of the three loss
            terms is not a scalar.
    """
    del dones, rewards, old_log_probs

    # Typically we have dist_inputs of the shape float32[128,1,18]
    assert len(dist_inputs.shape) == 3, (
        f'dist_inputs.shape was {dist_inputs.shape} '
        f'but expected length of the tensor shape is 3')
    # values of the shape float32[128,1,1]
    # returns of the shape float32[128,1,1]
    assert values.shape == returns.shape, (
        f'values.shape was {values.shape} and '
        f'returns.shape was {returns.shape}')
    # actions of the shape int32[128,1] in the case of discrete actions
    # and float32[128,1,6] in the case of half-cheetah
    # actions agree with returns/values on the first two coordinates
    assert actions.shape[0:2] == returns.shape[0:2], (
        f'actions.shape was {actions.shape} and '
        f'returns.shape was {returns.shape}')
    # and mask of the shape float32[128,1]
    assert len(mask.shape) == 2, f'mask.shape was {mask.shape}'
    # which agrees with returns/values/actions on the first two coordinates
    assert mask.shape[0:2] == returns.shape[0:2], (
        f'mask.shape was {mask.shape} and '
        f'returns.shape was {returns.shape}')

    # `values` is wrapped in stop_gradient before it reaches the objective;
    # the raw `values` is used only by the value loss below.
    a2c_objective = rl_layers.A2CObjective(
        dist_inputs,
        stop_gradient(values),
        returns,
        actions,
        mask,
        log_prob_fun=self._policy_dist.log_prob,
        normalize_advantages=self._normalize_advantages)
    # we insist that a2c_objective is a scalar
    assert jnp.ndim(a2c_objective) == 0, (
        f'a2c_objective was {a2c_objective}')

    entropy_loss = rl_layers.EntropyLoss(
        dist_inputs,
        actions,
        log_prob_fun=self._policy_dist.log_prob,
        entropy_coeff=self._entropy_coeff,
        entropy_fun=self._policy_dist.entropy)
    assert jnp.ndim(entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

    l2_value_loss = rl_layers.ValueLoss(
        values, returns, value_loss_coeff=self._value_loss_coeff)
    assert jnp.ndim(l2_value_loss) == 0, (
        f'l2_value_loss was {l2_value_loss}')

    combined_loss = a2c_objective + l2_value_loss - entropy_loss
    return combined_loss
def f(dist_inputs, values, returns, actions, old_log_probs, mask):
    """Definition of the Proximal Policy Optimization loss.

    Checks the shapes of the inputs, then combines three terms: the negated
    mean of the PPO objective, the L2 value loss and the entropy loss.
    `self`, `rl_layers`, `stop_gradient` and `jnp` are not defined here and
    are taken from the enclosing scope.

    Args:
        dist_inputs: rank-3 tensor of policy distribution inputs, e.g.
            float32[128,1,18].
        values: value predictions, e.g. float32[128,1,1].
        returns: returns; must have exactly the same shape as `values`.
        actions: rank-2 tensor, e.g. int32[128,1]; must agree with `returns`
            on the first two dimensions.
        old_log_probs: e.g. float32[128,1]; its shape must equal the first
            two dimensions of `returns`.
        mask: unused (deleted immediately).

    Returns:
        `-ppo_objective.mean() + l2_value_loss - entropy_loss`.

    Raises:
        AssertionError: if any shape check fails, if the PPO objective is not
            rank 2, or if the entropy / value loss is not a scalar.
    """
    del mask  # TODO(lukaszkaiser): make PPO work with Transformer

    # We have dist_inputs of the shape float32[128,1,18]
    assert len(dist_inputs.shape) == 3, (
        f'dist_inputs.shape was {dist_inputs.shape} '
        f'but expected length of the tensor shape is 3')
    # values of the shape float32[128,1,1]
    # returns of the shape float32[128,1,1]
    # and old_log_probs of the shape float32[128,1]
    assert values.shape == returns.shape, (
        f'values.shape was {values.shape} and '
        f'returns.shape was {returns.shape}')
    assert returns.shape[0:2] == old_log_probs.shape, (
        f'returns.shape was {returns.shape} and '
        f'old_log_probs.shape was {old_log_probs.shape}')
    # actions is a tensor of the shape int32[128,1]
    assert len(actions.shape) == 2, f'actions.shape was {actions.shape}'
    # which agrees with returns/values on the first two coordinates
    assert actions.shape[0:2] == returns.shape[0:2], (
        f'actions.shape was {actions.shape} and '
        f'returns.shape was {returns.shape}')

    # `values` is wrapped in stop_gradient before it reaches the objective;
    # the raw `values` is used only by the value loss below.
    ppo_objective = rl_layers.PPOObjective(
        dist_inputs,
        stop_gradient(values),
        returns,
        actions,
        old_log_probs,
        log_prob_fun=self._policy_dist.log_prob,
        epsilon=self._epsilon,
        normalize_advantages=self._normalize_advantages)
    # we insist that ppo_objective is a vector of shape [128,1]
    assert len(ppo_objective.shape) == 2, (
        f'ppo_objective was {ppo_objective}')
    # which agrees with returns/values/actions on the first two coordinates
    assert ppo_objective.shape[0:2] == values.shape[0:2], (
        f'ppo_objective.shape was {ppo_objective.shape} and '
        f'values.shape was {values.shape}')

    entropy_loss = rl_layers.EntropyLoss(
        dist_inputs,
        actions,
        log_prob_fun=self._policy_dist.log_prob,
        entropy_coeff=self._entropy_coeff,
        entropy_fun=self._policy_dist.entropy)
    assert jnp.ndim(entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

    l2_value_loss = rl_layers.ValueLoss(
        values, returns, value_loss_coeff=self._value_loss_coeff)
    assert jnp.ndim(l2_value_loss) == 0, (
        f'l2_value_loss was {l2_value_loss}')

    # The objective is reduced to a scalar here, not inside PPOObjective.
    return -ppo_objective.mean() + l2_value_loss - entropy_loss
def _beta_gamma_with_correct_axes(self, x, weights):
    """Index `beta` and `gamma` so that they gain the axes listed in `self._axis`.

    An index tuple with one entry per dimension of `x` is built: `None`
    (a new length-1 axis) at every position contained in `self._axis`, and a
    full slice everywhere else. Both weights are indexed with that tuple.

    Args:
        x: input whose number of dimensions (`np.ndim(x)`) sets the length of
            the index tuple.
        weights: a pair `(beta, gamma)`.

    Returns:
        The pair `(beta, gamma)` after indexing.
    """
    shift, scale = weights

    # TODO(phawkins): np.expand_dims should accept an axis tuple.
    # (https://github.com/numpy/numpy/issues/12290)
    index = []
    for position in range(np.ndim(x)):
        if position in self._axis:
            # Insert a new axis at this position.
            index.append(None)
        else:
            # Keep the weight's existing axis at this position.
            index.append(slice(None))
    index = tuple(index)

    return shift[index], scale[index]