def _step(self, transitions: Sequence[tf.Tensor]):
  """Do a batch of SGD on the actor + critic loss."""
  observations, actions, rewards, discounts, final_observation = transitions

  with tf.GradientTape() as tape:
    # Build actor and critic losses.
    logits, values = snt.BatchApply(self._network)(observations)
    _, bootstrap_value = self._network(final_observation)

    critic_loss, (advantages, _) = trfl.td_lambda(
        state_values=values,
        rewards=rewards,
        pcontinues=self._discount * discounts,
        bootstrap_value=bootstrap_value,
        lambda_=self._td_lambda)
    actor_loss = trfl.discrete_policy_gradient_loss(
        logits, actions, advantages)
    loss = tf.reduce_mean(actor_loss + critic_loss)

  gradients = tape.gradient(loss, self._network.trainable_variables)
  self._optimizer.apply(gradients, self._network.trainable_variables)
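A minimal sketch, assuming the `_step` above belongs to an agent whose `self._network` is a feedforward Sonnet module returning `(logits, value)` per observation. The class name `PolicyValueNet`, the MLP torso, and the layer sizes below are illustrative assumptions, not the source's implementation.

# Illustrative sketch (assumption): a feedforward policy-value network with
# the (logits, value) interface that the `_step` above expects from
# self._network when wrapped in snt.BatchApply.
import sonnet as snt
import tensorflow as tf


class PolicyValueNet(snt.Module):
  """Shared MLP torso with separate policy and value heads (hypothetical)."""

  def __init__(self, num_actions: int, hidden_sizes=(64, 64)):
    super().__init__(name='policy_value_net')
    self._flatten = snt.Flatten()
    self._torso = snt.nets.MLP(list(hidden_sizes), activate_final=True)
    self._policy_head = snt.Linear(num_actions)
    self._value_head = snt.Linear(1)

  def __call__(self, observations: tf.Tensor):
    hiddens = self._torso(self._flatten(observations))
    logits = self._policy_head(hiddens)                     # [B, num_actions]
    value = tf.squeeze(self._value_head(hiddens), axis=-1)  # [B]
    return logits, value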
def _step(self, trajectory: sequence.Trajectory):
  """Do a batch of SGD on actor + critic loss on a sequence of experience."""
  observations, actions, rewards, discounts = trajectory

  # Add dummy batch dimensions.
  actions = tf.expand_dims(actions, axis=-1)  # [T, 1]
  rewards = tf.expand_dims(rewards, axis=-1)  # [T, 1]
  discounts = tf.expand_dims(discounts, axis=-1)  # [T, 1]
  observations = tf.expand_dims(observations, axis=1)  # [T+1, 1, ...]

  # Extract final observation for bootstrapping.
  observations, final_observation = observations[:-1], observations[-1]

  with tf.GradientTape() as tape:
    # Build actor and critic losses.
    (logits, values), state = snt.dynamic_unroll(
        self._network, observations, self._rollout_initial_state)
    (_, bootstrap_value), state = self._network(final_observation, state)
    values = tf.squeeze(values, axis=-1)
    bootstrap_value = tf.squeeze(bootstrap_value, axis=-1)

    critic_loss, (advantages, _) = trfl.td_lambda(
        state_values=values,
        rewards=rewards,
        pcontinues=self._discount * discounts,
        bootstrap_value=bootstrap_value,
        lambda_=self._td_lambda)
    actor_loss = trfl.discrete_policy_gradient_loss(
        logits, actions, advantages)
    # Entropy regularisation term, weighted by the entropy cost.
    entropy_loss = trfl.discrete_policy_entropy_loss(logits).loss
    loss = actor_loss + critic_loss + self._entropy_cost * entropy_loss
    loss = tf.reduce_mean(loss)

  # Clip gradients by global norm before applying them.
  gradients = tape.gradient(loss, self._network.trainable_variables)
  gradients, _ = tf.clip_by_global_norm(gradients, 5.)
  self._optimizer.apply(gradients, self._network.trainable_variables)

  return state
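A hedged sketch of the recurrent core the sequence-based `_step` above assumes: an `snt.RNNCore` whose call maps `(observation, prev_state)` to `((logits, value), next_state)`, so it can be driven by `snt.dynamic_unroll`. The class name, the LSTM torso, and the hidden size are assumptions for illustration, not the source's network.

# Illustrative sketch (assumption): a recurrent policy-value core compatible
# with snt.dynamic_unroll as used in the `_step` above. The value head keeps
# a trailing dimension of 1, which the caller squeezes.
import sonnet as snt
import tensorflow as tf


class RecurrentPolicyValueNet(snt.RNNCore):
  """LSTM torso with policy and value heads (hypothetical)."""

  def __init__(self, num_actions: int, hidden_size: int = 64):
    super().__init__(name='recurrent_policy_value_net')
    self._flatten = snt.Flatten()
    self._core = snt.LSTM(hidden_size)
    self._policy_head = snt.Linear(num_actions)
    self._value_head = snt.Linear(1)

  def __call__(self, observations, state):
    hiddens, next_state = self._core(self._flatten(observations), state)
    logits = self._policy_head(hiddens)  # [B, num_actions]
    value = self._value_head(hiddens)    # [B, 1]
    return (logits, value), next_state

  def initial_state(self, batch_size: int):
    return self._core.initial_state(batch_size)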
def _step(self, sequence: Sequence[tf.Tensor]):
  """Do a batch of SGD on actor + critic loss on a sequence of experience."""
  (observations, actions, rewards, discounts, masks,
   final_obs, final_mask) = sequence
  masks = tf.expand_dims(masks, axis=-1)

  with tf.GradientTape() as tape:
    # Build actor and critic losses by manually unrolling the network over
    # the sequence, feeding the mask alongside each observation.
    state = self._rollout_initial_state
    logits_sequence = []
    values = []
    for t in range(self._sequence_length):
      (logits, value), state = self._network(
          (observations[t], masks[t]), state)
      logits_sequence.append(logits)
      values.append(value)
    (_, bootstrap_value), _ = self._network((final_obs, final_mask), state)

    # Stack the per-step outputs into [T, ...] tensors.
    values = tf.squeeze(tf.stack(values, axis=0), axis=-1)
    logits = tf.stack(logits_sequence, axis=0)
    bootstrap_value = tf.squeeze(bootstrap_value, axis=-1)

    critic_loss, (advantages, _) = trfl.td_lambda(
        state_values=values,
        rewards=rewards,
        pcontinues=self._discount * discounts,
        bootstrap_value=bootstrap_value,
        lambda_=self._td_lambda)
    actor_loss = trfl.discrete_policy_gradient_loss(
        logits, actions, advantages)
    loss = tf.reduce_mean(actor_loss + critic_loss)

  # Clip gradients by global norm before applying them.
  gradients = tape.gradient(loss, self._network.trainable_variables)
  gradients, _ = tf.clip_by_global_norm(gradients, 5.)
  self._optimizer.apply(gradients, self._network.trainable_variables)

  return state
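A hedged sketch of a network matching the `(observation, mask)` call signature in the loop above, under one plausible reading of the mask: the recurrent state is zeroed wherever the mask is zero (e.g. at episode boundaries). The class name, the state-reset behaviour, and the layer sizes are assumptions, not the source's implementation.

# Illustrative sketch (assumption): a masked recurrent policy-value core.
# It accepts a (observation, mask) tuple plus the previous state and resets
# the state where the mask is zero before stepping the LSTM.
import sonnet as snt
import tensorflow as tf
import tree  # dm-tree, to map over the (possibly nested) recurrent state.


class MaskedRecurrentPolicyValueNet(snt.RNNCore):
  """LSTM core whose state is reset where the mask is zero (hypothetical)."""

  def __init__(self, num_actions: int, hidden_size: int = 64):
    super().__init__(name='masked_recurrent_policy_value_net')
    self._flatten = snt.Flatten()
    self._core = snt.LSTM(hidden_size)
    self._policy_head = snt.Linear(num_actions)
    self._value_head = snt.Linear(1)

  def __call__(self, inputs, state):
    observation, mask = inputs  # mask: 0. at episode starts, broadcastable
                                # against the recurrent state.
    # Reset the recurrent state at episode boundaries (assumed semantics).
    state = tree.map_structure(lambda s: s * mask, state)
    hiddens, next_state = self._core(self._flatten(observation), state)
    logits = self._policy_head(hiddens)  # [B, num_actions]
    value = self._value_head(hiddens)    # [B, 1]
    return (logits, value), next_state

  def initial_state(self, batch_size: int):
    return self._core.initial_state(batch_size)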