# Assumes file-level `import tensorflow as tf` and the repo's `tf_utils` helper module.

def _train(self, batch, weights=None):
    """Critic-only variant: runs one Q-function gradient step on a sampled batch."""
    critic_variables = self._qf.trainable_variables
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        assert critic_variables, 'No qf variables to optimize.'
        tape.watch(critic_variables)
        critic_loss = self.critic_loss(batch['observations'],
                                       batch['actions'],
                                       batch['opponent_actions'],
                                       batch['target_actions'],
                                       batch['rewards'],
                                       batch['next_observations'],
                                       weights=weights)
    tf.debugging.check_numerics(critic_loss, 'qf loss is inf or nan.')
    critic_grads = tape.gradient(critic_loss, critic_variables)
    tf_utils.apply_gradients(critic_grads, critic_variables,
                             self._qf_optimizer, self._gradient_clipping)

    self._train_step += 1
    losses = {
        'critic_loss': critic_loss.numpy(),
    }
    return losses
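# For context, a minimal sketch of what the repo's tf_utils.apply_gradients
# helper presumably does, inferred from the call sites above. This is an
# assumption, not the repo's actual implementation: clip gradients by global
# norm when a clipping threshold is set, then apply them with the optimizer.
def _apply_gradients_sketch(grads, variables, optimizer, gradient_clipping):
    # gradient_clipping is the max global norm; None/0 means no clipping.
    if gradient_clipping:
        grads, _ = tf.clip_by_global_norm(grads, gradient_clipping)
    optimizer.apply_gradients(zip(grads, variables))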
def _train(self, batch, env, agent_id, weights=None):
    """Actor-critic variant: runs one critic and one actor gradient step."""
    # Critic update.
    critic_variables = self._qf.trainable_variables
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        assert critic_variables, 'No qf variables to optimize.'
        tape.watch(critic_variables)
        critic_loss = self.critic_loss(env,
                                       agent_id,
                                       batch['observations'],
                                       batch['actions'],
                                       batch['opponent_actions'],
                                       batch['target_actions'],
                                       batch['rewards'],
                                       batch['next_observations'],
                                       batch['terminals'],
                                       weights=weights)
    tf.debugging.check_numerics(critic_loss, 'qf loss is inf or nan.')
    critic_grads = tape.gradient(critic_loss, critic_variables)
    tf_utils.apply_gradients(critic_grads, critic_variables,
                             self._qf_optimizer, self._gradient_clipping)

    # Actor update.
    actor_variables = self._policy.trainable_variables
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        assert actor_variables, 'No actor variables to optimize.'
        tape.watch(actor_variables)
        actor_loss = self.actor_loss(env,
                                     agent_id,
                                     batch['observations'],
                                     batch['opponent_actions'],
                                     weights=weights)
    tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
    actor_grads = tape.gradient(actor_loss, actor_variables)
    tf_utils.apply_gradients(actor_grads, actor_variables,
                             self._policy_optimizer, self._gradient_clipping)

    self._train_step += 1
    # Periodically refresh the target networks.
    if self._train_step % self._target_update_period == 0:
        self._update_target()

    losses = {
        'pg_loss': actor_loss.numpy(),
        'critic_loss': critic_loss.numpy(),
    }
    return losses
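# The _update_target call above is not defined in this section. A minimal
# sketch of one common implementation, assuming Polyak (soft) averaging at
# rate tau; with a target_update_period > 1 a hard copy (tau = 1) is also
# plausible. The _target_qf / _target_policy attributes and self._tau are
# illustrative assumptions, not confirmed by the source.
def _update_target_sketch(self):
    for src, dst in zip(self._qf.trainable_variables,
                        self._target_qf.trainable_variables):
        dst.assign(self._tau * src + (1.0 - self._tau) * dst)
    for src, dst in zip(self._policy.trainable_variables,
                        self._target_policy.trainable_variables):
        dst.assign(self._tau * src + (1.0 - self._tau) * dst)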