def _compute_critic_loss(sampled_q_t_all: tf.Tensor, q_tm1_all: tf.Tensor,
                         r_t_all: tf.Tensor, d_t: tf.Tensor, discount: float,
                         num_samples: int, num_critic_heads: int):
  """Compute loss and sampled Q-values for (non-distributional) critics."""
  # Reshape Q-value samples back to original batch dimensions and average
  # them to compute the TD-learning bootstrap target.
  batch_size = r_t_all.get_shape()[0]
  sampled_q_t = tf.reshape(
      sampled_q_t_all,
      (num_samples, batch_size, num_critic_heads))  # [N, B, C]
  q_t = tf.reduce_mean(sampled_q_t, axis=0)  # [B, C]

  # Flatten q_t and q_tm1; necessary for trfl.td_learning.
  q_t = tf.reshape(q_t, [-1])  # [B*C]
  q_tm1 = tf.reshape(q_tm1_all, [-1])  # [B*C]

  # Flatten r_t_all; necessary for trfl.td_learning.
  r_t_all = tf.reshape(r_t_all, [-1])  # [B*C]

  # Repeat each discount num_critic_heads times so that it lines up with the
  # row-major flattening of the [B, C] tensors above.
  d_t = tf.repeat(d_t, num_critic_heads)  # [B*C]

  # Cast the additional discount to match the environment discount dtype.
  discount = tf.cast(discount, dtype=d_t.dtype)

  # Critic loss.
  critic_loss = trfl.td_learning(q_tm1, r_t_all, discount * d_t, q_t).loss
  critic_loss = tf.reduce_mean(critic_loss)
  return critic_loss, sampled_q_t


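# ---------------------------------------------------------------------------
# Minimal illustrative usage sketch for _compute_critic_loss: runs the helper
# on randomly generated tensors to show the expected layouts. All shapes
# (N samples, B batch, C critic heads) and values are toy assumptions made
# for illustration only; the module-level tf/trfl imports used above are
# assumed to be available.
# ---------------------------------------------------------------------------
def _example_compute_critic_loss():
  """Calls _compute_critic_loss on dummy data and returns its outputs."""
  import tensorflow as tf

  num_samples, batch_size, num_heads = 4, 3, 2  # N, B, C (arbitrary).
  sampled_q_t_all = tf.random.normal([num_samples * batch_size, num_heads])
  q_tm1_all = tf.random.normal([batch_size, num_heads])
  r_t_all = tf.random.normal([batch_size, num_heads])
  d_t = tf.ones([batch_size])  # Environment discounts (no terminations).

  critic_loss, sampled_q_t = _compute_critic_loss(
      sampled_q_t_all, q_tm1_all, r_t_all, d_t,
      discount=0.99, num_samples=num_samples, num_critic_heads=num_heads)
  # critic_loss is a scalar; sampled_q_t has shape [N, B, C].
  return critic_loss, sampled_q_t

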
def dev_critic_loss(self, dev_dataset=None):
  critic_loss_sum = 0.
  count = 0.
  for sample in dev_dataset:
    o_tm1, a_tm1, r_t, d_t, o_t = sample.data[:5]

    # Cast the additional discount to match the environment discount dtype.
    discount = tf.cast(self._discount, dtype=d_t.dtype)

    q_t = self._critic_network(o_t, self._policy_network(o_t))
    q_tm1 = self._critic_network(o_tm1, a_tm1)

    # Critic loss.
    if self._distributional:
      critic_loss = losses.categorical(q_tm1, r_t, discount * d_t, q_t)
    else:
      # Squeeze into the shape expected by the td_learning implementation.
      q_tm1 = tf.squeeze(q_tm1, axis=-1)  # [B]
      q_t = tf.squeeze(q_t, axis=-1)  # [B]
      critic_loss = trfl.td_learning(q_tm1, r_t, discount * d_t, q_t).loss

    critic_loss_sum += tf.reduce_mean(critic_loss, axis=[0])
    count += 1.

  return critic_loss_sum / count


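# ---------------------------------------------------------------------------
# Minimal illustrative sketch: dev_critic_loss only needs an iterable whose
# elements expose a `.data` field laid out like a replay sample, so a plain
# namedtuple over random tensors is enough for a smoke test. The DevSample
# name and all shapes are toy assumptions for illustration; `learner` in the
# commented call stands for an instance of the class above.
# ---------------------------------------------------------------------------
def _example_dev_dataset(batch_size=8, obs_dim=5, act_dim=3, num_batches=4):
  """Builds a toy dev dataset compatible with dev_critic_loss."""
  import collections
  import tensorflow as tf

  DevSample = collections.namedtuple('DevSample', ['data'])
  return [
      DevSample(data=(tf.random.normal([batch_size, obs_dim]),   # o_tm1
                      tf.random.normal([batch_size, act_dim]),   # a_tm1
                      tf.random.normal([batch_size]),            # r_t
                      tf.ones([batch_size]),                     # d_t
                      tf.random.normal([batch_size, obs_dim])))  # o_t
      for _ in range(num_batches)
  ]

# Example (assuming `learner` is constructed elsewhere):
#   mean_dev_loss = learner.dev_critic_loss(_example_dev_dataset())

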
def _step(self):
  # Update target network.
  online_variables = (
      *self._observation_network.variables,
      *self._critic_network.variables,
      *self._policy_network.variables,
  )
  target_variables = (
      *self._target_observation_network.variables,
      *self._target_critic_network.variables,
      *self._target_policy_network.variables,
  )

  # Make online -> target network update ops.
  if tf.math.mod(self._num_steps, self._target_update_period) == 0:
    for src, dest in zip(online_variables, target_variables):
      dest.assign(src)
  self._num_steps.assign_add(1)

  # Get data from replay (dropping extras if any). Note there is no
  # extra data here because we do not insert any into Reverb.
  inputs = next(self._iterator)
  transitions: types.Transition = inputs.data

  # Cast the additional discount to match the environment discount dtype.
  discount = tf.cast(self._discount, dtype=transitions.discount.dtype)

  with tf.GradientTape(persistent=True) as tape:
    # Maybe transform the observation before feeding into policy and critic.
    # Transforming the observations this way at the start of the learning
    # step effectively means that the policy and critic share observation
    # network weights.
    o_tm1 = self._observation_network(transitions.observation)
    o_t = self._target_observation_network(transitions.next_observation)

    # This stop_gradient prevents gradients from propagating into the target
    # observation network. In addition, since the online policy network is
    # evaluated at o_t, this also means the policy loss does not influence
    # the observation network training.
    o_t = tree.map_structure(tf.stop_gradient, o_t)

    # Critic learning.
    q_tm1 = self._critic_network(o_tm1, transitions.action)
    q_t = self._target_critic_network(o_t, self._target_policy_network(o_t))

    # Squeeze into the shape expected by the td_learning implementation.
    q_tm1 = tf.squeeze(q_tm1, axis=-1)  # [B]
    q_t = tf.squeeze(q_t, axis=-1)  # [B]

    # Critic loss.
    critic_loss = trfl.td_learning(q_tm1, transitions.reward,
                                   discount * transitions.discount, q_t).loss
    critic_loss = tf.reduce_mean(critic_loss, axis=0)

    # Actor learning.
    dpg_a_t = self._policy_network(o_t)
    dpg_q_t = self._critic_network(o_t, dpg_a_t)

    # Actor loss. If clipping is enabled, use dqda clipping and clip the norm.
    dqda_clipping = 1.0 if self._clipping else None
    policy_loss = losses.dpg(
        dpg_q_t,
        dpg_a_t,
        tape=tape,
        dqda_clipping=dqda_clipping,
        clip_norm=self._clipping)
    policy_loss = tf.reduce_mean(policy_loss, axis=0)

  # Get trainable variables.
  policy_variables = self._policy_network.trainable_variables
  critic_variables = (
      # In this agent, the critic loss trains the observation network.
      self._observation_network.trainable_variables +
      self._critic_network.trainable_variables)

  # Compute gradients.
  policy_gradients = tape.gradient(policy_loss, policy_variables)
  critic_gradients = tape.gradient(critic_loss, critic_variables)

  # Delete the tape manually because of the persistent=True flag.
  del tape

  # Maybe clip gradients.
  if self._clipping:
    policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0]
    critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]

  # Apply gradients.
  self._policy_optimizer.apply(policy_gradients, policy_variables)
  self._critic_optimizer.apply(critic_gradients, critic_variables)

  # Losses to track.
  return {
      'critic_loss': critic_loss,
      'policy_loss': policy_loss,
  }


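# ---------------------------------------------------------------------------
# Minimal illustrative sketch of the deterministic-policy-gradient surrogate
# used for the actor update above via losses.dpg: roughly, a stop-gradient
# construction whose gradient w.r.t. the action equals -dQ/da, so minimizing
# it performs gradient ascent on the critic (optional dqda clipping omitted).
# The quadratic "critic" and all shapes here are toy assumptions.
# ---------------------------------------------------------------------------
def _example_dpg_surrogate(batch_size=16, act_dim=4):
  """Builds a DPG-style surrogate loss on a toy critic."""
  import tensorflow as tf

  a_t = tf.random.normal([batch_size, act_dim])  # Stand-in for policy(o_t).
  with tf.GradientTape() as tape:
    tape.watch(a_t)
    q_t = -tf.reduce_sum(a_t ** 2, axis=-1)      # Stand-in for critic(o_t, a_t).
  dqda = tape.gradient(q_t, a_t)                 # [B, act_dim]

  # Surrogate: 0.5 * || stop_grad(a + dQ/da) - a ||^2 per batch element;
  # its gradient w.r.t. a_t is exactly -dQ/da.
  target_a = tf.stop_gradient(dqda + a_t)
  loss = 0.5 * tf.reduce_sum(tf.square(target_a - a_t), axis=-1)
  return tf.reduce_mean(loss)

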
def _step(self) -> types.NestedTensor:
  # Update target network.
  online_policy_variables = self._policy_network.variables
  target_policy_variables = self._target_policy_network.variables
  online_critic_variables = (
      *self._observation_network.variables,
      *self._critic_network.variables,
  )
  target_critic_variables = (
      *self._target_observation_network.variables,
      *self._target_critic_network.variables,
  )

  # Make online policy -> target policy network update ops.
  if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0:
    for src, dest in zip(online_policy_variables, target_policy_variables):
      dest.assign(src)
  # Make online critic -> target critic network update ops.
  if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0:
    for src, dest in zip(online_critic_variables, target_critic_variables):
      dest.assign(src)

  self._num_steps.assign_add(1)

  # Get data from replay (dropping extras if any). Note there is no
  # extra data here because we do not insert any into Reverb.
  inputs = next(self._iterator)
  o_tm1, a_tm1, r_t, d_t, o_t = inputs.data

  # Get batch size.
  batch_size = r_t.shape[0]

  # Cast the additional discount to match the environment discount dtype.
  discount = tf.cast(self._discount, dtype=d_t.dtype)

  with tf.GradientTape(persistent=True) as tape:
    # Maybe transform the observation before feeding into policy and critic.
    # Transforming the observations this way at the start of the learning
    # step effectively means that the policy and critic share observation
    # network weights.
    o_tm1 = self._observation_network(o_tm1)

    # This stop_gradient prevents gradients from propagating into the target
    # observation network. In addition, since the online policy network is
    # evaluated at o_t, this also means the policy loss does not influence
    # the observation network training.
    o_t = tf.stop_gradient(self._target_observation_network(o_t))

    # Get online and target action distributions from policy networks.
    online_action_distribution = self._policy_network(o_t)
    target_action_distribution = self._target_policy_network(o_t)

    # Sample actions to evaluate policy; of size [N, B, ...].
    sampled_actions = target_action_distribution.sample(self._num_samples)

    # Tile embedded observations to feed into the target critic network.
    # Note: this is more efficient than tiling before the embedding layer.
    tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples)  # [N, B, ...]

    # Compute target-estimated distributional value of sampled actions at o_t.
    sampled_q_t_all = self._target_critic_network(
        # Merge batch dimensions; to shape [N*B, ...].
        snt.merge_leading_dims(tiled_o_t, num_dims=2),
        snt.merge_leading_dims(sampled_actions, num_dims=2))

    # Compute online critic value distribution of a_tm1 in state o_tm1.
    q_tm1_all = self._critic_network(o_tm1, a_tm1)

    # Compute rewards for objectives with defined reward_fn.
    reward_stats = {}
    r_t_all = []
    for objective in self._objectives:
      if hasattr(objective, 'reward_fn'):
        r = objective.reward_fn(o_tm1, a_tm1, r_t)
        reward_stats['{}_reward'.format(objective.name)] = tf.reduce_mean(r)
        r_t_all.append(r)
    r_t_all = tf.stack(r_t_all, axis=-1)
    r_t_all.get_shape().assert_has_rank(2)  # [B, C]

    if isinstance(sampled_q_t_all, list):  # Distributional critics.
      # Compute average logits by first reshaping them and normalizing them
      # across atoms.
      critic_losses = []
      sampled_q_ts = []
      for idx, (sampled_q_t_distributions, q_tm1_distribution) in enumerate(
          zip(sampled_q_t_all, q_tm1_all)):
        # Compute loss for the distributional critic of objective idx.
        sampled_logits = tf.reshape(
            sampled_q_t_distributions.logits,
            [self._num_samples, batch_size, -1])  # [N, B, A]
        sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
        averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

        # Construct the expected distributional value for bootstrapping.
        q_t_distribution = networks.DiscreteValuedDistribution(
            values=sampled_q_t_distributions.values, logits=averaged_logits)

        # Compute critic distributional loss.
        critic_loss = losses.categorical(q_tm1_distribution, r_t_all[:, idx],
                                         discount * d_t, q_t_distribution)
        critic_losses.append(tf.reduce_mean(critic_loss))

        # Compute Q-values of sampled actions and reshape to [N, B].
        sampled_q_ts.append(
            tf.reshape(sampled_q_t_distributions.mean(),
                       (self._num_samples, -1)))

      critic_loss = tf.reduce_mean(critic_losses)
      sampled_q_t = tf.stack(sampled_q_ts, axis=-1)  # [N, B, C]
    else:
      # Reshape Q-value samples back to original batch dimensions and average
      # them to compute the TD-learning bootstrap target.
      sampled_q_t = tf.reshape(
          sampled_q_t_all,
          (self._num_samples, batch_size, self._num_critic_heads))  # [N, B, C]
      q_t = tf.reduce_mean(sampled_q_t, axis=0)  # [B, C]

      # Flatten q_t and q_tm1; necessary for trfl.td_learning.
      q_t = tf.reshape(q_t, [-1])  # [B*C]
      q_tm1 = tf.reshape(q_tm1_all, [-1])  # [B*C]

      # Flatten r_t_all; necessary for trfl.td_learning.
      r_t_all = tf.reshape(r_t_all, [-1])  # [B*C]

      # Repeat each discount num_critic_heads times so that it lines up with
      # the row-major flattening of the [B, C] tensors above.
      d_t = tf.repeat(d_t, self._num_critic_heads)  # [B*C]

      # Critic loss.
      critic_loss = trfl.td_learning(q_tm1, r_t_all, discount * d_t, q_t).loss
      critic_loss = tf.reduce_mean(critic_loss)

    # Add sampled Q-values for objectives with defined objective_fn.
    sampled_q_idx = 0
    sampled_q_t_k = []
    for objective in self._objectives:
      if hasattr(objective, 'reward_fn'):
        sampled_q_t_k.append(tf.stop_gradient(sampled_q_t[..., sampled_q_idx]))
        sampled_q_idx += 1
      if hasattr(objective, 'objective_fn'):
        sampled_q_t_k.append(
            tf.stop_gradient(
                objective.objective_fn(sampled_actions, sampled_q_t)))
    sampled_q_t_k = tf.stack(sampled_q_t_k, axis=-1)  # [N, B, K]

    # Compute MPO policy loss.
    policy_loss, policy_stats = self._policy_loss_module(
        online_action_distribution=online_action_distribution,
        target_action_distribution=target_action_distribution,
        actions=sampled_actions,
        q_values=sampled_q_t_k)

  # For clarity, explicitly define which variables are trained by which loss.
  critic_trainable_variables = (
      # In this agent, the critic loss trains the observation network.
      self._observation_network.trainable_variables +
      self._critic_network.trainable_variables)
  policy_trainable_variables = self._policy_network.trainable_variables
  # The following are the MPO dual variables, stored in the loss module.
  dual_trainable_variables = self._policy_loss_module.trainable_variables

  # Compute gradients.
  critic_gradients = tape.gradient(critic_loss, critic_trainable_variables)
  policy_gradients, dual_gradients = tape.gradient(
      policy_loss, (policy_trainable_variables, dual_trainable_variables))

  # Delete the tape manually because of the persistent=True flag.
  del tape

  # Maybe clip gradients.
  if self._clipping:
    policy_gradients = tuple(tf.clip_by_global_norm(policy_gradients, 40.)[0])
    critic_gradients = tuple(tf.clip_by_global_norm(critic_gradients, 40.)[0])

  # Apply gradients.
  self._critic_optimizer.apply(critic_gradients, critic_trainable_variables)
  self._policy_optimizer.apply(policy_gradients, policy_trainable_variables)
  self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

  # Losses to track.
  fetches = {
      'critic_loss': critic_loss,
      'policy_loss': policy_loss,
  }
  fetches.update(policy_stats)  # Log MPO stats.
  fetches.update(reward_stats)  # Log reward stats.

  return fetches


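# ---------------------------------------------------------------------------
# Minimal numeric sketch of the logit-averaging trick used for the
# distributional critics above: averaging the N sampled categorical
# distributions in probability space matches softmax(logsumexp(log-probs))
# up to normalization, which is why the bootstrap distribution can be built
# from averaged_logits. Shapes below are arbitrary toy values.
# ---------------------------------------------------------------------------
def _example_average_categorical_logits(num_samples=5, batch_size=2, atoms=11):
  """Checks softmax(logsumexp(log-probs over samples)) equals mean of probs."""
  import tensorflow as tf

  logits = tf.random.normal([num_samples, batch_size, atoms])  # [N, B, A]
  logprobs = tf.math.log_softmax(logits, axis=-1)
  averaged_logits = tf.reduce_logsumexp(logprobs, axis=0)      # [B, A]

  mean_probs = tf.reduce_mean(tf.nn.softmax(logits, axis=-1), axis=0)
  recovered = tf.nn.softmax(averaged_logits, axis=-1)
  return tf.reduce_max(tf.abs(mean_probs - recovered))  # ~0 up to float error.

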
def _step(self) -> types.Nest:
  # Update target network.
  online_policy_variables = self._policy_network.variables
  target_policy_variables = self._target_policy_network.variables
  online_critic_variables = (
      *self._observation_network.variables,
      *self._critic_network.variables,
  )
  target_critic_variables = (
      *self._target_observation_network.variables,
      *self._target_critic_network.variables,
  )

  # Make online policy -> target policy network update ops.
  if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0:
    for src, dest in zip(online_policy_variables, target_policy_variables):
      dest.assign(src)
  # Make online critic -> target critic network update ops.
  if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0:
    for src, dest in zip(online_critic_variables, target_critic_variables):
      dest.assign(src)

  # Increment number of learner steps for periodic update bookkeeping.
  self._num_steps.assign_add(1)

  # Get next batch of data.
  inputs = next(self._iterator)

  # Get data from replay (dropping extras if any). Note there is no
  # extra data here because we do not insert any into Reverb.
  transitions: types.Transition = inputs.data

  # Cast the additional discount to match the environment discount dtype.
  discount = tf.cast(self._discount, dtype=transitions.discount.dtype)

  with tf.GradientTape(persistent=True) as tape:
    # Maybe transform the observation before feeding into policy and critic.
    # Transforming the observations this way at the start of the learning
    # step effectively means that the policy and critic share observation
    # network weights.
    o_tm1 = self._observation_network(transitions.observation)

    # This stop_gradient prevents gradients from propagating into the target
    # observation network. In addition, since the online policy network is
    # evaluated at o_t, this also means the policy loss does not influence
    # the observation network training.
    o_t = tf.stop_gradient(
        self._target_observation_network(transitions.next_observation))

    # Get action distributions from policy networks.
    online_action_distribution = self._policy_network(o_t)
    target_action_distribution = self._target_policy_network(o_t)

    # Get sampled actions to evaluate policy; of size [N, B, ...].
    sampled_actions = target_action_distribution.sample(self._num_samples)
    tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples)  # [N, B, ...]

    # Compute the target critic's Q-value of the sampled actions in state o_t.
    sampled_q_t = self._target_critic_network(
        # Merge batch dimensions; to shape [N*B, ...].
        snt.merge_leading_dims(tiled_o_t, num_dims=2),
        snt.merge_leading_dims(sampled_actions, num_dims=2))

    # Reshape Q-value samples back to original batch dimensions and average
    # them to compute the TD-learning bootstrap target.
    sampled_q_t = tf.reshape(sampled_q_t, (self._num_samples, -1))  # [N, B]
    q_t = tf.reduce_mean(sampled_q_t, axis=0)  # [B]

    # Compute online critic value of a_tm1 in state o_tm1.
    q_tm1 = self._critic_network(o_tm1, transitions.action)  # [B, 1]
    q_tm1 = tf.squeeze(q_tm1, axis=-1)  # [B]; necessary for trfl.td_learning.

    # Critic loss.
    critic_loss = trfl.td_learning(q_tm1, transitions.reward,
                                   discount * transitions.discount, q_t).loss
    critic_loss = tf.reduce_mean(critic_loss)

    # Actor learning.
    policy_loss, policy_stats = self._policy_loss_module(
        online_action_distribution=online_action_distribution,
        target_action_distribution=target_action_distribution,
        actions=sampled_actions,
        q_values=sampled_q_t)

  # For clarity, explicitly define which variables are trained by which loss.
  critic_trainable_variables = (
      # In this agent, the critic loss trains the observation network.
      self._observation_network.trainable_variables +
      self._critic_network.trainable_variables)
  policy_trainable_variables = self._policy_network.trainable_variables
  # The following are the MPO dual variables, stored in the loss module.
  dual_trainable_variables = self._policy_loss_module.trainable_variables

  # Compute gradients.
  critic_gradients = tape.gradient(critic_loss, critic_trainable_variables)
  policy_gradients, dual_gradients = tape.gradient(
      policy_loss, (policy_trainable_variables, dual_trainable_variables))

  # Delete the tape manually because of the persistent=True flag.
  del tape

  # Maybe clip gradients.
  if self._clipping:
    policy_gradients = tuple(tf.clip_by_global_norm(policy_gradients, 40.)[0])
    critic_gradients = tuple(tf.clip_by_global_norm(critic_gradients, 40.)[0])

  # Apply gradients.
  self._critic_optimizer.apply(critic_gradients, critic_trainable_variables)
  self._policy_optimizer.apply(policy_gradients, policy_trainable_variables)
  self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

  # Losses to track.
  fetches = {
      'critic_loss': critic_loss,
      'policy_loss': policy_loss,
  }
  fetches.update(policy_stats)  # Log MPO stats.

  return fetches


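# ---------------------------------------------------------------------------
# Minimal pure-TF sketch of the tile -> merge-leading-dims -> critic ->
# reshape round trip performed above with tf2_utils.tile_tensor and
# snt.merge_leading_dims. The dot-product "critic" and all sizes are toy
# assumptions made for this example.
# ---------------------------------------------------------------------------
def _example_sampled_q_bootstrap(num_samples=10, batch_size=4, embed_dim=6,
                                 act_dim=3):
  """Evaluates a toy critic on N sampled actions per state and averages."""
  import tensorflow as tf

  o_t = tf.random.normal([batch_size, embed_dim])
  sampled_actions = tf.random.normal([num_samples, batch_size, act_dim])

  tiled_o_t = tf.tile(o_t[None], [num_samples, 1, 1])            # [N, B, D]
  merged_o = tf.reshape(tiled_o_t, [num_samples * batch_size, embed_dim])
  merged_a = tf.reshape(sampled_actions, [num_samples * batch_size, act_dim])

  # Toy critic: one scalar per (observation, action) pair.
  q = tf.reduce_sum(merged_o[:, :act_dim] * merged_a, axis=-1)   # [N*B]

  sampled_q_t = tf.reshape(q, [num_samples, batch_size])         # [N, B]
  q_t = tf.reduce_mean(sampled_q_t, axis=0)                      # [B] bootstrap.
  return q_t

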
def _step(self) -> Dict[str, tf.Tensor]:
  # Get data from replay (dropping extras if any). Note there is no
  # extra data here because we do not insert any into Reverb.
  sample = next(self._iterator)
  o_tm1, a_tm1, r_t, d_t, o_t = sample.data[:5]

  # Cast the additional discount to match the environment discount dtype.
  discount = tf.cast(self._discount, dtype=d_t.dtype)

  q_t = self._target_critic_network(o_t, self._policy_network(o_t))
  if not self._distributional and self._vmin is not None:
    q_t = tf.clip_by_value(q_t, self._vmin, self._vmax)
    logging.info('Clip target critic network output with [%f, %f]',
                 self._vmin, self._vmax)

  with tf.GradientTape() as tape:
    # Critic learning.
    q_tm1 = self._critic_network(o_tm1, a_tm1)

    # Critic loss.
    if self._distributional:
      critic_loss = losses.categorical(q_tm1, r_t, discount * d_t, q_t)
    else:
      # Squeeze into the shape expected by the td_learning implementation.
      q_tm1 = tf.squeeze(q_tm1, axis=-1)  # [B]
      q_t = tf.squeeze(q_t, axis=-1)  # [B]
      critic_loss = trfl.td_learning(q_tm1, r_t, discount * d_t, q_t).loss
    critic_loss = tf.reduce_mean(critic_loss, axis=[0])

  # Get trainable variables.
  critic_variables = self._critic_network.trainable_variables

  # Compute gradients.
  critic_gradients = tape.gradient(critic_loss, critic_variables)

  # Maybe clip gradients.
  if self._clipping:
    critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]

  # Apply gradients.
  self._critic_optimizer.apply(critic_gradients, critic_variables)

  source_variables = self._critic_network.variables
  target_variables = self._target_critic_network.variables

  # Make online -> target network update ops.
  if tf.math.mod(self._num_steps, self._target_update_period) == 0:
    for src, dest in zip(source_variables, target_variables):
      dest.assign(src)

  if self._init_observations is not None:
    if tf.math.mod(self._num_steps, 100) == 0:
      init_obs = tree.map_structure(tf.convert_to_tensor,
                                    self._init_observations)
      init_actions = self._policy_network(init_obs)
      init_critic = tf.reduce_mean(self._critic_mean(init_obs, init_actions))
    else:
      init_critic = tf.constant(0.)
  else:
    init_critic = tf.constant(0.)

  self._num_steps.assign_add(1)

  # Losses to track.
  return {
      'critic_loss': critic_loss,
      'q_s0': init_critic,
  }


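# ---------------------------------------------------------------------------
# Minimal illustrative sketch of the periodic online -> target variable copy
# used by the learners above, isolated on toy variables. The variable shapes
# and the update period are arbitrary assumptions for this example.
# ---------------------------------------------------------------------------
def _example_periodic_target_update(target_update_period=100):
  """Copies online weights into target weights every few learner steps."""
  import tensorflow as tf

  num_steps = tf.Variable(0, dtype=tf.int32)
  online_vars = (tf.Variable(tf.random.normal([3, 3])),
                 tf.Variable(tf.ones([3])))
  target_vars = (tf.Variable(tf.zeros([3, 3])),
                 tf.Variable(tf.zeros([3])))

  def step():
    # Hard update only every `target_update_period` learner steps.
    if tf.math.mod(num_steps, target_update_period) == 0:
      for src, dest in zip(online_vars, target_vars):
        dest.assign(src)
    num_steps.assign_add(1)

  step()  # Step 0: copies online weights into the target.
  step()  # Step 1: leaves the target untouched.
  return target_vars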