def f(preds, values, returns, actions, mask):
  advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
  logps = self._policy_dist.log_prob(preds, actions)
  awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
      (logps, advantages, jnp.zeros_like(logps), mask))
  l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
  return awr_loss + l2_value_loss
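# A minimal, self-contained sketch (not part of the original code) of why the
# baseline above is wrapped in stop_gradient: the actor term should not
# backpropagate into the value head, while the separate L2 term still trains it.
import jax
import jax.numpy as jnp

def actor_term(values, returns):
  # Advantages treat the value estimate as a constant baseline.
  advantages = returns - jax.lax.stop_gradient(values)
  return jnp.mean(advantages)

def critic_term(values, returns):
  # The value head is trained only through the L2 regression term.
  return jnp.mean((returns - values) ** 2)

values = jnp.array([1.0, 2.0])
returns = jnp.array([1.5, 1.0])
print(jax.grad(actor_term)(values, returns))   # zeros: no gradient via baseline
print(jax.grad(critic_term)(values, returns))  # nonzero regression gradient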
def _do_custom_gradients(self, x, weights, state, rng):
  """Calls this layer for a forward pass, but with custom gradients."""

  def _do_forward(y, weights):
    old_weights, old_state, old_rng = self._weights, self._state, self._rng
    self._weights = weights
    res = self.forward(y)
    s = self._state
    self._weights, self._state, self._rng = old_weights, old_state, old_rng
    return res, s

  def do_forward_vjp(y, weights):
    """Custom gradient (vjp) function."""
    old_weights, old_state, old_rng = self._weights, self._state, self._rng
    self._weights = weights
    output = self.forward(y)
    new_state = self._state
    self._weights, self._state, self._rng = old_weights, old_state, old_rng
    def vjpfun(grad):
      grad = grad[0]  # Ignore dummy gradient wrt state.
      res = self.backward(y, output, grad, weights, state, new_state, rng)
      return res
    return (output, new_state), vjpfun

  do_forward = math.custom_grad(do_forward_vjp, _do_forward)
  output, state = do_forward(x, weights)
  # TODO(lukaszkaiser): Investigate why we need this stop_gradient
  state = math.stop_gradient(state)
  return output, state
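# A minimal sketch of the custom-gradient pattern used above, assuming
# math.custom_grad is backed by something like jax.custom_vjp. The names
# layer_forward/layer_forward_fwd/layer_forward_bwd are illustrative only,
# not Trax APIs.
import jax
import jax.numpy as jnp

@jax.custom_vjp
def layer_forward(x, w):
  return x * w

def layer_forward_fwd(x, w):
  # Run the forward pass and save residuals needed by the backward pass.
  return layer_forward(x, w), (x, w)

def layer_forward_bwd(residuals, grad_out):
  # Hand-written VJP: return one cotangent per primal input.
  x, w = residuals
  return (grad_out * w, grad_out * x)

layer_forward.defvjp(layer_forward_fwd, layer_forward_bwd)

x, w = jnp.array(2.0), jnp.array(3.0)
print(jax.grad(layer_forward, argnums=(0, 1))(x, w))  # (3.0, 2.0)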
def f(log_probs, advantages, old_log_probs, mask):
  if reweight:  # Use new policy weights for sampled actions instead.
    mask *= jnp.exp(math.stop_gradient(log_probs) - old_log_probs)
  if sampled_all_discrete:  # Actions were sampled uniformly; weight them.
    mask *= jnp.exp(old_log_probs)
  weights = jnp.minimum(awr_weights(advantages, beta), w_max)
  return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask)
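# awr_weights is not shown above; a minimal sketch, assuming the standard
# advantage-weighted regression weighting exp(advantage / beta). The clipping
# at w_max mirrors the jnp.minimum call in f above.
import jax.numpy as jnp

def awr_weights(advantages, beta):
  # Larger advantages get exponentially larger regression weights;
  # beta controls how peaked the weighting is.
  return jnp.exp(advantages / beta)

advantages = jnp.array([-1.0, 0.0, 2.0])
print(jnp.minimum(awr_weights(advantages, beta=1.0), 20.0))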
def f(dist_inputs, values, returns, actions, old_log_probs, mask):
  """Definition of the Proximal Policy Optimization loss."""
  del mask  # TODO(lukaszkaiser): make PPO work with Transformer
  ppo_objective = rl_layers.PPOObjective(
      dist_inputs, stop_gradient(values), returns, actions, old_log_probs,
      log_prob_fun=self._policy_dist.log_prob,
      epsilon=self._epsilon,
      normalize_advantages=self._normalize_advantages)
  entropy_loss = rl_layers.EntropyLoss(
      dist_inputs, actions,
      log_prob_fun=self._policy_dist.log_prob,
      entropy_coeff=self._entropy_coeff,
      entropy_fun=self._policy_dist.entropy)
  l2_value_loss = rl_layers.ValueLoss(
      values, returns, value_loss_coeff=self._value_loss_coeff)
  return -ppo_objective.mean() + l2_value_loss - entropy_loss
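# A minimal sketch of the clipped surrogate that PPOObjective is expected to
# compute, written directly in jnp. The function name, argument names and
# shapes are illustrative, not the rl_layers implementation; only the standard
# PPO clipping formula is assumed.
import jax.numpy as jnp

def ppo_objective_sketch(new_log_probs, old_log_probs, advantages, epsilon):
  ratio = jnp.exp(new_log_probs - old_log_probs)
  unclipped = ratio * advantages
  clipped = jnp.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
  # PPO takes the pessimistic (elementwise minimum) of the two surrogates.
  return jnp.minimum(unclipped, clipped)

adv = jnp.array([1.0, -1.0])
print(ppo_objective_sketch(jnp.log(jnp.array([0.9, 0.9])),
                           jnp.log(jnp.array([0.5, 0.5])),
                           adv, epsilon=0.2))  # [1.2, -1.8]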
def f(dist_inputs, values, returns, dones, rewards, actions, old_log_probs,
      mask):
  """Definition of the A2C loss."""
  del dones, rewards, old_log_probs

  # Typically we have dist_inputs of the shape float32[128,1,18].
  assert len(dist_inputs.shape) == 3, (
      f'dist_inputs.shape was {dist_inputs.shape} '
      f'but expected length of the tensor shape is 3')
  # values and returns are of the shape float32[128,1,1].
  assert values.shape == returns.shape, (
      f'values.shape was {values.shape} and '
      f'returns.shape was {returns.shape}')
  # actions are of the shape int32[128,1] in the case of discrete actions
  # and float32[128,1,6] in the case of half-cheetah; they agree with
  # returns/values on the first two coordinates.
  assert actions.shape[0:2] == returns.shape[0:2], (
      f'actions.shape was {actions.shape} and '
      f'returns.shape was {returns.shape}')
  # mask is of the shape float32[128,1] ...
  assert len(mask.shape) == 2, f'mask.shape was {mask.shape}'
  # ... and agrees with returns/values/actions on the first two coordinates.
  assert mask.shape[0:2] == returns.shape[0:2], (
      f'mask.shape was {mask.shape} and '
      f'returns.shape was {returns.shape}')

  a2c_objective = rl_layers.A2CObjective(
      dist_inputs, stop_gradient(values), returns, actions, mask,
      log_prob_fun=self._policy_dist.log_prob,
      normalize_advantages=self._normalize_advantages)
  # We insist that a2c_objective is a scalar.
  assert jnp.ndim(a2c_objective) == 0, f'a2c_objective was {a2c_objective}'

  entropy_loss = rl_layers.EntropyLoss(
      dist_inputs, actions,
      log_prob_fun=self._policy_dist.log_prob,
      entropy_coeff=self._entropy_coeff,
      entropy_fun=self._policy_dist.entropy)
  assert jnp.ndim(entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

  l2_value_loss = rl_layers.ValueLoss(
      values, returns, value_loss_coeff=self._value_loss_coeff)
  assert jnp.ndim(l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'

  combined_loss = a2c_objective + l2_value_loss - entropy_loss
  return combined_loss
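# A minimal sketch of what A2CObjective is expected to compute, assuming the
# usual advantage actor-critic form: a masked, advantage-weighted
# log-probability. Names and shapes are illustrative, not the rl_layers
# implementation.
import jax.numpy as jnp

def a2c_objective_sketch(log_probs, values, returns, mask):
  # Advantages use the value estimates as a baseline (values would already be
  # wrapped in stop_gradient by the caller, as in the loss above).
  advantages = returns - values
  return -jnp.sum(log_probs * advantages * mask) / jnp.sum(mask)

log_probs = jnp.array([[-0.1], [-2.0]])
values = jnp.array([[0.5], [0.5]])
returns = jnp.array([[1.0], [0.0]])
mask = jnp.array([[1.0], [1.0]])
print(a2c_objective_sketch(log_probs, values, returns, mask))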
def f(dist_inputs, values, returns, actions, old_log_probs, mask):
  """Definition of the Proximal Policy Optimization loss."""
  del mask  # TODO(lukaszkaiser): make PPO work with Transformer

  # We have dist_inputs of the shape float32[128,1,18].
  assert len(dist_inputs.shape) == 3, (
      f'dist_inputs.shape was {dist_inputs.shape} '
      f'but expected length of the tensor shape is 3')
  # values and returns are of the shape float32[128,1,1]
  # and old_log_probs of the shape float32[128,1].
  assert values.shape == returns.shape, (
      f'values.shape was {values.shape} and '
      f'returns.shape was {returns.shape}')
  assert returns.shape[0:2] == old_log_probs.shape, (
      f'returns.shape was {returns.shape} and '
      f'old_log_probs.shape was {old_log_probs.shape}')
  # actions is a tensor of the shape int32[128,1] ...
  assert len(actions.shape) == 2, f'actions.shape was {actions.shape}'
  # ... which agrees with returns/values on the first two coordinates.
  assert actions.shape[0:2] == returns.shape[0:2], (
      f'actions.shape was {actions.shape} and '
      f'returns.shape was {returns.shape}')

  ppo_objective = rl_layers.PPOObjective(
      dist_inputs, stop_gradient(values), returns, actions, old_log_probs,
      log_prob_fun=self._policy_dist.log_prob,
      epsilon=self._epsilon,
      normalize_advantages=self._normalize_advantages)
  # We insist that ppo_objective is a vector of shape [128,1] ...
  assert len(ppo_objective.shape) == 2, (
      f'ppo_objective was {ppo_objective}')
  # ... which agrees with returns/values/actions on the first two coordinates.
  assert ppo_objective.shape[0:2] == values.shape[0:2], (
      f'ppo_objective.shape was {ppo_objective.shape} and '
      f'values.shape was {values.shape}')

  entropy_loss = rl_layers.EntropyLoss(
      dist_inputs, actions,
      log_prob_fun=self._policy_dist.log_prob,
      entropy_coeff=self._entropy_coeff,
      entropy_fun=self._policy_dist.entropy)
  assert jnp.ndim(entropy_loss) == 0, f'entropy_loss was {entropy_loss}'

  l2_value_loss = rl_layers.ValueLoss(
      values, returns, value_loss_coeff=self._value_loss_coeff)
  assert jnp.ndim(l2_value_loss) == 0, f'l2_value_loss was {l2_value_loss}'

  return -ppo_objective.mean() + l2_value_loss - entropy_loss
def AWRJointLoss(x, **unused_kwargs):  # pylint: disable=invalid-name
  preds, values, returns, actions, mask = x
  advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
  logps = self._policy_dist.log_prob(preds, actions)
  awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
      (logps, advantages, jnp.zeros_like(logps), mask))
  l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
  return awr_loss + l2_value_loss
def f(dist_inputs, values, returns, actions, old_log_probs, mask):
  """Definition of the A2C loss."""
  del old_log_probs
  a2c_objective = rl_layers.A2CObjective(
      dist_inputs, stop_gradient(values), returns, actions, mask,
      log_prob_fun=self._policy_dist.log_prob,
      normalize_advantages=self._normalize_advantages)
  entropy_loss = rl_layers.EntropyLoss(
      dist_inputs, actions,
      log_prob_fun=self._policy_dist.log_prob,
      entropy_coeff=self._entropy_coeff,
      entropy_fun=self._policy_dist.entropy)
  l2_value_loss = rl_layers.ValueLoss(
      values, returns, value_loss_coeff=self._value_loss_coeff)
  return a2c_objective.mean() + l2_value_loss - entropy_loss
def f(log_probs, advantages, old_log_probs, mask):
  if reweight:  # Use new policy weights for sampled actions instead.
    mask *= jnp.exp(math.stop_gradient(log_probs) - old_log_probs)
  weights = jnp.minimum(awr_weights(advantages, beta), w_max)
  return -jnp.sum(log_probs * weights * mask) / jnp.sum(mask)
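# A small illustration (not part of the original code) of the reweighting
# factor used above: exp(log_probs - old_log_probs) is the importance ratio
# pi_new(a|s) / pi_old(a|s). It multiplies the mask under stop_gradient, so
# the ratio itself is not differentiated through.
import jax.numpy as jnp

new_log_probs = jnp.log(jnp.array([0.6, 0.1]))
old_log_probs = jnp.log(jnp.array([0.3, 0.2]))
print(jnp.exp(new_log_probs - old_log_probs))  # [2.0, 0.5]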