def f(preds, values, returns, actions, mask):
  # Block gradients through the critic when forming advantages, so the AWR
  # (policy) term does not update the value head; the value head is trained
  # only by the L2 term below.
  advantages = jnp.squeeze(returns - stop_gradient(values), axis=-1)
  logps = self._policy_dist.log_prob(preds, actions)
  awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
      (logps, advantages, jnp.zeros_like(logps), mask))
  l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
  return awr_loss + l2_value_loss
def AWRJointLoss(x, **unused_kwargs):  # pylint: disable=invalid-name
  preds, values, returns, actions, mask = x
  # No stop_gradient here: gradients of the AWR term also flow into the
  # value predictions through the advantages.
  advantages = jnp.squeeze(returns - values, axis=-1)
  logps = self._policy_dist.log_prob(preds, actions)
  awr_loss = actor_critic.AWRLoss(beta=self._beta, w_max=self._w_max)(
      (logps, advantages, jnp.zeros_like(logps), mask))
  l2_value_loss = jnp.mean((returns - values)**2) * self._value_loss_coeff
  return awr_loss + l2_value_loss
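

# The only difference between the two losses above is the stop_gradient on
# `values` when forming the advantages. A minimal, self-contained sketch in
# plain JAX (hypothetical shapes and helper names, not part of the original
# code) showing the effect on the gradient that reaches the value predictions:
import jax
import jax.numpy as jnp


def advantage_with_stop(values, returns):
  # As in `f` above: the critic is treated as a constant in this term.
  return jnp.squeeze(returns - jax.lax.stop_gradient(values), axis=-1)


def advantage_without_stop(values, returns):
  # As in `AWRJointLoss`: gradients flow into `values` as well.
  return jnp.squeeze(returns - values, axis=-1)


values = jnp.ones((4, 1))
returns = jnp.full((4, 1), 2.0)

# d(sum of advantages)/d(values): all zeros with stop_gradient, all -1 without.
print(jax.grad(lambda v: advantage_with_stop(v, returns).sum())(values))
print(jax.grad(lambda v: advantage_without_stop(v, returns).sum())(values))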