def _step(self) -> Dict[str, tf.Tensor]:
  """Does an SGD step on a batch of sequences."""

  # Retrieve a batch of data from replay.
  inputs: reverb.ReplaySample = next(self._iterator)
  data = tf2_utils.batch_to_sequence(inputs.data)
  observations, actions, rewards, discounts, extra = (data.observation,
                                                      data.action,
                                                      data.reward,
                                                      data.discount,
                                                      data.extras)
  core_state = tree.map_structure(lambda s: s[0], extra['core_state'])

  actions = actions[:-1]  # [T-1]
  rewards = rewards[:-1]  # [T-1]
  discounts = discounts[:-1]  # [T-1]

  with tf.GradientTape() as tape:
    # Unroll current policy over observations.
    (logits, values), _ = snt.static_unroll(self._network, observations,
                                            core_state)

    # Compute importance sampling weights: current policy / behaviour policy.
    behaviour_logits = extra['logits']
    pi_behaviour = tfd.Categorical(logits=behaviour_logits[:-1])
    pi_target = tfd.Categorical(logits=logits[:-1])
    log_rhos = pi_target.log_prob(actions) - pi_behaviour.log_prob(actions)

    # Optionally clip rewards.
    rewards = tf.clip_by_value(rewards,
                               tf.cast(-self._max_abs_reward, rewards.dtype),
                               tf.cast(self._max_abs_reward, rewards.dtype))

    # Critic loss.
    vtrace_returns = trfl.vtrace_from_importance_weights(
        log_rhos=tf.cast(log_rhos, tf.float32),
        discounts=tf.cast(self._discount * discounts, tf.float32),
        rewards=tf.cast(rewards, tf.float32),
        values=tf.cast(values[:-1], tf.float32),
        bootstrap_value=values[-1],
    )
    critic_loss = tf.square(vtrace_returns.vs - values[:-1])

    # Policy-gradient loss.
    policy_gradient_loss = trfl.policy_gradient(
        policies=pi_target,
        actions=actions,
        action_values=vtrace_returns.pg_advantages,
    )

    # Entropy regulariser.
    entropy_loss = trfl.policy_entropy_loss(pi_target).loss

    # Combine weighted sum of actor & critic losses.
    loss = tf.reduce_mean(policy_gradient_loss +
                          self._baseline_cost * critic_loss +
                          self._entropy_cost * entropy_loss)

  # Compute gradients and optionally apply clipping.
  gradients = tape.gradient(loss, self._network.trainable_variables)
  gradients, _ = tf.clip_by_global_norm(gradients, self._max_gradient_norm)
  self._optimizer.apply(gradients, self._network.trainable_variables)

  metrics = {
      'loss': loss,
      'critic_loss': tf.reduce_mean(critic_loss),
      'entropy_loss': tf.reduce_mean(entropy_loss),
      'policy_gradient_loss': tf.reduce_mean(policy_gradient_loss),
  }

  return metrics
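For reference, the V-trace targets and policy-gradient advantages consumed above can be written as a short backward recursion over the trajectory. The NumPy sketch below mirrors that computation for a single trajectory of length T (all inputs shaped [T]) with the importance-weight clipping thresholds fixed at 1.0, the defaults from the IMPALA paper; the helper name and the clip_c_threshold argument are illustrative, not part of trfl's API.

import numpy as np

def vtrace_reference(log_rhos, discounts, rewards, values, bootstrap_value,
                     clip_rho_threshold=1.0, clip_c_threshold=1.0):
  """V-trace targets and advantages for one trajectory of length T (shapes [T])."""
  rhos = np.exp(log_rhos)
  clipped_rhos = np.minimum(clip_rho_threshold, rhos)
  cs = np.minimum(clip_c_threshold, rhos)

  # Temporal differences weighted by the clipped importance ratios.
  next_values = np.append(values[1:], bootstrap_value)
  deltas = clipped_rhos * (rewards + discounts * next_values - values)

  # Backward recursion: (vs_t - V_t) = delta_t + discount_t * c_t * (vs_{t+1} - V_{t+1}).
  vs_minus_v = np.zeros(len(values))
  acc = 0.0
  for t in reversed(range(len(values))):
    acc = deltas[t] + discounts[t] * cs[t] * acc
    vs_minus_v[t] = acc
  vs = vs_minus_v + values

  # Policy-gradient advantages use the one-step-ahead V-trace targets.
  next_vs = np.append(vs[1:], bootstrap_value)
  pg_advantages = clipped_rhos * (rewards + discounts * next_vs - values)
  return vs, pg_advantages

Here vs plays the role of vtrace_returns.vs in the squared critic error above, and pg_advantages is the quantity weighting the policy-gradient term.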
def _step(self, data: Step) -> Dict[str, tf.Tensor]:
  """Does an SGD step on a batch of sequences."""
  observations, actions, rewards, discounts, _, extra = data
  core_state = tree.map_structure(lambda s: s[0], extra['core_state'])

  actions = actions[:-1]  # [T-1]
  rewards = rewards[:-1]  # [T-1]
  discounts = discounts[:-1]  # [T-1]

  # Workaround for NO_OP actions:
  # In some environments, passing NO_OP (-1) actions would lead to a crash.
  # These actions (at episode boundaries) should be ignored anyway, so we
  # replace NO_OP actions with a valid action index (0).
  actions = (tf.zeros_like(actions) * tf.cast(actions == -1, tf.int32) +
             actions * tf.cast(actions != -1, tf.int32))

  with tf.GradientTape() as tape:
    # Unroll current policy over observations.
    (logits, values), _ = snt.static_unroll(self._network, observations,
                                            core_state)
    pi = tfd.Categorical(logits=logits[:-1])

    # Optionally clip rewards.
    rewards = tf.clip_by_value(rewards,
                               tf.cast(-self._max_abs_reward, rewards.dtype),
                               tf.cast(self._max_abs_reward, rewards.dtype))

    # Compute actor & critic losses.
    discounted_returns = trfl.generalized_lambda_returns(
        rewards=tf.cast(rewards, tf.float32),
        pcontinues=tf.cast(self._discount * discounts, tf.float32),
        values=tf.cast(values[:-1], tf.float32),
        bootstrap_value=tf.cast(values[-1], tf.float32),
    )
    advantages = discounted_returns - values[:-1]
    critic_loss = tf.square(advantages)
    policy_gradient_loss = trfl.policy_gradient(
        policies=pi,
        actions=actions,
        action_values=advantages,
    )
    entropy_loss = trfl.policy_entropy_loss(pi).loss

    loss = tf.reduce_mean(policy_gradient_loss +
                          self._baseline_cost * critic_loss +
                          self._entropy_cost * entropy_loss)

  # Compute gradients and optionally apply clipping.
  gradients = tape.gradient(loss, self._network.trainable_variables)
  gradients, _ = tf.clip_by_global_norm(gradients, self._max_gradient_norm)
  self._optimizer.apply(gradients, self._network.trainable_variables)

  metrics = {
      'loss': loss,
      'critic_loss': tf.reduce_mean(critic_loss),
      'entropy_loss': tf.reduce_mean(entropy_loss),
      'policy_gradient_loss': tf.reduce_mean(policy_gradient_loss),
  }

  return metrics
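Similarly, the discounted returns used in this second learner come from a bootstrapped lambda-return recursion. The NumPy sketch below shows that recursion for a single trajectory of length T (all inputs shaped [T]); the helper name is illustrative, and with the default lambda_=1 (as in the call above) it reduces to the bootstrapped discounted return G_t = r_t + d_t * G_{t+1} with G_T set to the final value estimate.

import numpy as np

def lambda_returns_reference(rewards, pcontinues, values, bootstrap_value,
                             lambda_=1.0):
  """Bootstrapped lambda-returns for one trajectory of length T (shapes [T])."""
  returns = np.zeros(len(rewards))
  next_return = bootstrap_value  # G_T is bootstrapped from the final value.
  next_value = bootstrap_value   # V(s_T) is likewise the bootstrap value.
  for t in reversed(range(len(rewards))):
    # Mix the one-step bootstrap with the longer lambda-return.
    returns[t] = rewards[t] + pcontinues[t] * (
        (1.0 - lambda_) * next_value + lambda_ * next_return)
    next_return = returns[t]
    next_value = values[t]
  return returns

Subtracting the value estimates from these returns gives the advantages that drive both the squared critic error and the policy-gradient term above.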