Пример #1
0
 def f(preds, values, returns, actions, mask):
   """Computes the AWR policy loss plus a scaled L2 value loss.

   Args:
     preds: policy-head outputs, scored against `actions` via
       `self._policy_dist.log_prob`.
     values: value-head predictions; must broadcast against `returns`.
     returns: regression targets for `values`.
     actions: actions whose log-probabilities are evaluated.
     mask: forwarded unchanged to `actor_critic.AWRLoss`.

   Returns:
     `AWRLoss(beta=self._beta, w_max=self._w_max)` evaluated on the
     advantages, plus `mean((returns - values) ** 2) * self._value_loss_coeff`.
   """
   # Detach `values` so the policy term cannot backpropagate into the value
   # head. `jnp.squeeze(..., axis=-1)` requires the trailing axis to be size 1.
   adv = jnp.squeeze(returns - stop_gradient(values), axis=-1)
   log_probs = self._policy_dist.log_prob(preds, actions)
   policy_loss_fn = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)
   # NOTE(review): the zeros fill AWRLoss's third input slot; presumably old
   # log-probs, which AWR does not use. Confirm against actor_critic.AWRLoss.
   placeholder = jnp.zeros_like(log_probs)
   policy_loss = policy_loss_fn((log_probs, adv, placeholder, mask))
   # The value head is trained only by this (unmasked) regression term.
   squared_err = (returns - values) ** 2
   value_loss = jnp.mean(squared_err) * self._value_loss_coeff
   return policy_loss + value_loss
Пример #2
0
 def AWRJointLoss(x, **unused_kwargs):  # pylint: disable=invalid-name
   """Computes the joint AWR policy loss plus a scaled L2 value loss.

   Args:
     x: tuple `(preds, values, returns, actions, mask)`. `values` must
       broadcast against `returns`, and `returns - values` must have a
       trailing axis of size 1 (it is squeezed away to form the advantages).
     **unused_kwargs: ignored.

   Returns:
     `AWRLoss(beta=self._beta, w_max=self._w_max)` evaluated on the
     advantages, plus `mean((returns - values) ** 2) * self._value_loss_coeff`.
   """
   preds, values, returns, actions, mask = x
   # Advantages must be constants for the policy term. Without stop_gradient,
   # gradients of the AWR loss would flow through the advantages into
   # `values`, so the policy objective would also move the value head away
   # from its regression target. This matches the detached form used by `f`.
   advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
   logps = self._policy_dist.log_prob(preds, actions)
   # NOTE(review): the zeros fill AWRLoss's third input slot; presumably old
   # log-probs, which AWR does not use. Confirm against actor_critic.AWRLoss.
   awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
       (logps, advantages, jnp.zeros_like(logps), mask))
   # The value head is trained only by this (unmasked) regression term.
   l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
   return awr_loss + l2_value_loss