def _compute_distributional_critic_loss(
    sampled_q_t_all: List[tf.Tensor],
    q_tm1_all: List[tf.Tensor],
    r_t_all: tf.Tensor,
    d_t: tf.Tensor,
    discount: float,
    num_samples: int):
  """Compute loss and sampled Q-values for distributional critics."""
  # Compute average logits by first reshaping them and normalizing them
  # across atoms.
  batch_size = r_t_all.get_shape()[0]
  # Cast the additional discount to match the environment discount dtype.
  discount = tf.cast(discount, dtype=d_t.dtype)
  critic_losses = []
  sampled_q_ts = []
  for idx, (sampled_q_t_distributions, q_tm1_distribution) in enumerate(
      zip(sampled_q_t_all, q_tm1_all)):
    # Compute loss for distributional critic for objective c.
    sampled_logits = tf.reshape(
        sampled_q_t_distributions.logits,
        [num_samples, batch_size, -1])  # [N, B, A]
    sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
    averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

    # Construct the expected distributional value for bootstrapping.
    q_t_distribution = networks.DiscreteValuedDistribution(
        values=sampled_q_t_distributions.values, logits=averaged_logits)

    # Compute critic distributional loss.
    critic_loss = losses.categorical(q_tm1_distribution, r_t_all[:, idx],
                                     discount * d_t, q_t_distribution)
    critic_losses.append(tf.reduce_mean(critic_loss))

    # Compute Q-values of sampled actions and reshape to [N, B].
    sampled_q_ts.append(tf.reshape(
        sampled_q_t_distributions.mean(), (num_samples, -1)))

  critic_loss = tf.reduce_mean(critic_losses)
  sampled_q_t = tf.stack(sampled_q_ts, axis=-1)  # [N, B, C]
  return critic_loss, sampled_q_t

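# Illustrative sketch (not part of the learner code): the logit-averaging trick
# used in `_compute_distributional_critic_loss`. Averaging N categorical
# distributions that share the same support is done in log-space by normalizing
# each set of logits and taking a logsumexp over the sample axis. The shapes
# and values below are made up for the example.
import tensorflow as tf


def _example_average_categorical_logits():
  num_samples, batch_size, num_atoms = 4, 2, 51
  logits = tf.random.normal([num_samples, batch_size, num_atoms])  # [N, B, A]

  # Normalize each sampled distribution across atoms, then combine in log-space.
  logprobs = tf.math.log_softmax(logits, axis=-1)          # [N, B, A]
  averaged_logits = tf.reduce_logsumexp(logprobs, axis=0)  # [B, A]

  # Up to normalization, this equals the arithmetic mean of the N
  # distributions: softmax(averaged_logits) == mean_i softmax(logits_i).
  mean_probs = tf.reduce_mean(tf.nn.softmax(logits, axis=-1), axis=0)
  tf.debugging.assert_near(
      tf.nn.softmax(averaged_logits, axis=-1), mean_probs, atol=1e-5)
  return averaged_logits
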
def _step(self, sample: reverb.ReplaySample) -> Dict[str, tf.Tensor]:
  # Transpose batch and sequence axes, i.e. [B, T, ...] to [T, B, ...].
  sample = tf2_utils.batch_to_sequence(sample)
  observations = sample.observation
  actions = sample.action
  rewards = sample.reward
  discounts = sample.discount

  dtype = rewards.dtype

  # Cast the additional discount to match the environment discount dtype.
  discount = tf.cast(self._discount, dtype=discounts.dtype)

  # Loss cumulants across time. These cannot be python mutable objects.
  critic_loss = 0.
  policy_loss = 0.

  # Each transition induces a policy loss, which we then weight using
  # the `policy_loss_coef_t`; shape [B], see https://arxiv.org/abs/2006.15134.
  # `policy_loss_coef` is a scalar average of these coefficients across
  # the batch and sequence length dimensions.
  policy_loss_coef = 0.

  per_device_batch_size = actions.shape[1]

  # Initialize recurrent states.
  critic_state = self._critic_network.initial_state(per_device_batch_size)
  target_critic_state = critic_state
  policy_state = self._policy_network.initial_state(per_device_batch_size)
  target_policy_state = policy_state

  with tf.GradientTape(persistent=True) as tape:
    for t in range(1, self._sequence_length):
      o_tm1 = tree.map_structure(operator.itemgetter(t - 1), observations)
      a_tm1 = tree.map_structure(operator.itemgetter(t - 1), actions)
      r_t = tree.map_structure(operator.itemgetter(t - 1), rewards)
      d_t = tree.map_structure(operator.itemgetter(t - 1), discounts)
      o_t = tree.map_structure(operator.itemgetter(t), observations)

      if t != 1:
        # By only updating the target critic state here we are forcing
        # the target critic to ignore observations[0]. Otherwise, the
        # target_critic will be unrolled for one more timestep than critic.
        # The smaller the sequence length, the more problematic this is: if
        # you use RNN on sequences of length 2, you would expect the code to
        # never use recurrent connections. But if you don't skip updating the
        # target_critic_state on observation[0] here, it won't be the case.
        _, target_critic_state = self._target_critic_network(
            o_tm1, a_tm1, target_critic_state)

      # ========================= Critic learning ============================
      q_tm1, next_critic_state = self._critic_network(o_tm1, a_tm1,
                                                      critic_state)
      target_action_distribution, target_policy_state = self._target_policy_network(
          o_t, target_policy_state)

      sampled_actions_t = target_action_distribution.sample(
          self._num_action_samples_td_learning)  # [N, B, ...]
      tiled_o_t = tf2_utils.tile_nested(
          o_t, self._num_action_samples_td_learning)
      tiled_target_critic_state = tf2_utils.tile_nested(
          target_critic_state, self._num_action_samples_td_learning)

      # Compute the target critic's Q-value of the sampled actions.
      sampled_q_t, _ = snt.BatchApply(self._target_critic_network)(
          tiled_o_t, sampled_actions_t, tiled_target_critic_state)

      # Compute average logits by first reshaping them to [N, B, A] and then
      # normalizing them across atoms.
      new_shape = [self._num_action_samples_td_learning, r_t.shape[0], -1]
      sampled_logits = tf.reshape(sampled_q_t.logits, new_shape)
      sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
      averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

      # Construct the expected distributional value for bootstrapping.
      q_t = networks.DiscreteValuedDistribution(
          values=sampled_q_t.values, logits=averaged_logits)

      critic_loss_t = losses.categorical(q_tm1, r_t, discount * d_t, q_t)
      critic_loss_t = tf.reduce_mean(critic_loss_t)

      # ========================= Actor learning =============================
      action_distribution_tm1, policy_state = self._policy_network(
          o_tm1, policy_state)
      q_tm1_mean = q_tm1.mean()

      # Compute the estimate of the value function based on
      # self._num_action_samples_policy_weight samples from the policy.
      tiled_o_tm1 = tf2_utils.tile_nested(
          o_tm1, self._num_action_samples_policy_weight)
      tiled_critic_state = tf2_utils.tile_nested(
          critic_state, self._num_action_samples_policy_weight)
      action_tm1 = action_distribution_tm1.sample(
          self._num_action_samples_policy_weight)
      tiled_z_tm1, _ = snt.BatchApply(self._critic_network)(
          tiled_o_tm1, action_tm1, tiled_critic_state)
      tiled_v_tm1 = tf.reshape(tiled_z_tm1.mean(),
                               [self._num_action_samples_policy_weight, -1])

      # Use mean, min, or max to aggregate Q(s, a_i), a_i ~ pi(s) into the
      # final estimate of the value function.
      if self._baseline_reduce_function == 'mean':
        v_tm1_estimate = tf.reduce_mean(tiled_v_tm1, axis=0)
      elif self._baseline_reduce_function == 'max':
        v_tm1_estimate = tf.reduce_max(tiled_v_tm1, axis=0)
      elif self._baseline_reduce_function == 'min':
        v_tm1_estimate = tf.reduce_min(tiled_v_tm1, axis=0)

      # Assert that action_distribution_tm1 is a batch of multivariate
      # distributions (in contrast to e.g. a [batch, action_size] collection
      # of 1d distributions).
      assert len(action_distribution_tm1.batch_shape) == 1
      policy_loss_batch = -action_distribution_tm1.log_prob(a_tm1)

      advantage = q_tm1_mean - v_tm1_estimate
      if self._policy_improvement_modes == 'exp':
        policy_loss_coef_t = tf.math.minimum(
            tf.math.exp(advantage / self._beta), self._ratio_upper_bound)
      elif self._policy_improvement_modes == 'binary':
        policy_loss_coef_t = tf.cast(advantage > 0, dtype=dtype)
      elif self._policy_improvement_modes == 'all':
        # Regress against all actions (effectively pure BC).
        policy_loss_coef_t = 1.
      policy_loss_coef_t = tf.stop_gradient(policy_loss_coef_t)

      policy_loss_batch *= policy_loss_coef_t
      policy_loss_t = tf.reduce_mean(policy_loss_batch)

      critic_state = next_critic_state
      critic_loss += critic_loss_t
      policy_loss += policy_loss_t
      policy_loss_coef += tf.reduce_mean(policy_loss_coef_t)  # For logging.

    # Divide by sequence length to get mean losses.
    critic_loss /= tf.cast(self._sequence_length, dtype=dtype)
    policy_loss /= tf.cast(self._sequence_length, dtype=dtype)
    policy_loss_coef /= tf.cast(self._sequence_length, dtype=dtype)

  # Compute gradients.
  critic_gradients = tape.gradient(critic_loss,
                                   self._critic_network.trainable_variables)
  policy_gradients = tape.gradient(policy_loss,
                                   self._policy_network.trainable_variables)

  # Delete the tape manually because of the persistent=True flag.
  del tape

  # Sync gradients across GPUs or TPUs.
  ctx = tf.distribute.get_replica_context()
  critic_gradients = ctx.all_reduce('mean', critic_gradients)
  policy_gradients = ctx.all_reduce('mean', policy_gradients)

  # Maybe clip gradients.
  if self._clipping:
    policy_gradients = tf.clip_by_global_norm(policy_gradients, 40.)[0]
    critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]

  # Apply gradients.
  self._critic_optimizer.apply(critic_gradients,
                               self._critic_network.trainable_variables)
  self._policy_optimizer.apply(policy_gradients,
                               self._policy_network.trainable_variables)

  source_variables = (self._critic_network.variables +
                      self._policy_network.variables)
  target_variables = (self._target_critic_network.variables +
                      self._target_policy_network.variables)

  # Make online -> target network update ops.
  if tf.math.mod(self._num_steps, self._target_update_period) == 0:
    for src, dest in zip(source_variables, target_variables):
      dest.assign(src)
  self._num_steps.assign_add(1)

  return {
      'critic_loss': critic_loss,
      'policy_loss': policy_loss,
      'policy_loss_coef': policy_loss_coef,
  }

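# Illustrative sketch, not the learner's API: the CRR-style per-transition
# policy-loss coefficients computed above (https://arxiv.org/abs/2006.15134).
# The function name and the default `beta` / `ratio_upper_bound` values are
# assumptions made for this example only.
import tensorflow as tf


def crr_policy_loss_coef(advantage: tf.Tensor,
                         mode: str = 'exp',
                         beta: float = 1.0,
                         ratio_upper_bound: float = 20.0) -> tf.Tensor:
  """Maps advantages A(s, a) = Q(s, a) - V(s) to policy-loss weights."""
  if mode == 'exp':
    # Exponentiated advantage, clipped to avoid exploding weights.
    coef = tf.minimum(tf.exp(advantage / beta), ratio_upper_bound)
  elif mode == 'binary':
    # Only imitate actions that beat the value baseline.
    coef = tf.cast(advantage > 0, advantage.dtype)
  elif mode == 'all':
    # Weight every action equally (pure behaviour cloning).
    coef = tf.ones_like(advantage)
  else:
    raise ValueError(f'Unknown mode: {mode}')
  # The coefficient is a constant weight; it is not differentiated through.
  return tf.stop_gradient(coef)


def _example_weighted_policy_loss():
  # Made-up advantages and negative log-probabilities of replayed actions.
  advantage = tf.constant([0.5, -1.0, 2.0])
  neg_logprob = tf.constant([1.2, 0.7, 2.3])
  return tf.reduce_mean(neg_logprob * crr_policy_loss_coef(advantage))
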
def _step(self) -> types.NestedTensor:
  # Update target network.
  online_policy_variables = self._policy_network.variables
  target_policy_variables = self._target_policy_network.variables
  online_critic_variables = (
      *self._observation_network.variables,
      *self._critic_network.variables,
  )
  target_critic_variables = (
      *self._target_observation_network.variables,
      *self._target_critic_network.variables,
  )

  # Make online policy -> target policy network update ops.
  if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0:
    for src, dest in zip(online_policy_variables, target_policy_variables):
      dest.assign(src)
  # Make online critic -> target critic network update ops.
  if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0:
    for src, dest in zip(online_critic_variables, target_critic_variables):
      dest.assign(src)

  self._num_steps.assign_add(1)

  # Get data from replay (dropping extras if any). Note there is no
  # extra data here because we do not insert any into Reverb.
  inputs = next(self._iterator)
  o_tm1, a_tm1, r_t, d_t, o_t = inputs.data

  # Get batch size and scalar dtype.
  batch_size = r_t.shape[0]

  # Cast the additional discount to match the environment discount dtype.
  discount = tf.cast(self._discount, dtype=d_t.dtype)

  with tf.GradientTape(persistent=True) as tape:
    # Maybe transform the observation before feeding into policy and critic.
    # Transforming the observations this way at the start of the learning
    # step effectively means that the policy and critic share observation
    # network weights.
    o_tm1 = self._observation_network(o_tm1)

    # This stop_gradient prevents gradients from propagating into the target
    # observation network. In addition, since the online policy network is
    # evaluated at o_t, this also means the policy loss does not influence
    # the observation network training.
    o_t = tf.stop_gradient(self._target_observation_network(o_t))

    # Get online and target action distributions from policy networks.
    online_action_distribution = self._policy_network(o_t)
    target_action_distribution = self._target_policy_network(o_t)

    # Sample actions to evaluate policy; of size [N, B, ...].
    sampled_actions = target_action_distribution.sample(self._num_samples)

    # Tile embedded observations to feed into the target critic network.
    # Note: this is more efficient than tiling before the embedding layer.
    tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples)  # [N, B, ...]

    # Compute target-estimated distributional value of sampled actions at o_t.
    sampled_q_t_distributions = self._target_critic_network(
        # Merge batch dimensions; to shape [N*B, ...].
        snt.merge_leading_dims(tiled_o_t, num_dims=2),
        snt.merge_leading_dims(sampled_actions, num_dims=2))

    # Compute average logits by first reshaping them and normalizing them
    # across atoms.
    new_shape = [self._num_samples, batch_size, -1]  # [N, B, A]
    sampled_logits = tf.reshape(sampled_q_t_distributions.logits, new_shape)
    sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
    averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

    # Construct the expected distributional value for bootstrapping.
    q_t_distribution = networks.DiscreteValuedDistribution(
        values=sampled_q_t_distributions.values, logits=averaged_logits)

    # Compute online critic value distribution of a_tm1 in state o_tm1.
    q_tm1_distribution = self._critic_network(o_tm1, a_tm1)

    # Compute critic distributional loss.
    critic_loss = losses.categorical(q_tm1_distribution, r_t, discount * d_t,
                                     q_t_distribution)
    critic_loss = tf.reduce_mean(critic_loss)

    # Compute Q-values of sampled actions and reshape to [N, B].
    sampled_q_values = sampled_q_t_distributions.mean()
    sampled_q_values = tf.reshape(sampled_q_values, (self._num_samples, -1))

    # Compute MPO policy loss.
    policy_loss, policy_stats = self._policy_loss_module(
        online_action_distribution=online_action_distribution,
        target_action_distribution=target_action_distribution,
        actions=sampled_actions,
        q_values=sampled_q_values)

  # For clarity, explicitly define which variables are trained by which loss.
  critic_trainable_variables = (
      # In this agent, the critic loss trains the observation network.
      self._observation_network.trainable_variables +
      self._critic_network.trainable_variables)
  policy_trainable_variables = self._policy_network.trainable_variables
  # The following are the MPO dual variables, stored in the loss module.
  dual_trainable_variables = self._policy_loss_module.trainable_variables

  # Compute gradients.
  critic_gradients = tape.gradient(critic_loss, critic_trainable_variables)
  policy_gradients, dual_gradients = tape.gradient(
      policy_loss, (policy_trainable_variables, dual_trainable_variables))

  # Delete the tape manually because of the persistent=True flag.
  del tape

  # Maybe clip gradients.
  if self._clipping:
    policy_gradients = tuple(tf.clip_by_global_norm(policy_gradients, 40.)[0])
    critic_gradients = tuple(tf.clip_by_global_norm(critic_gradients, 40.)[0])

  # Apply gradients.
  self._critic_optimizer.apply(critic_gradients, critic_trainable_variables)
  self._policy_optimizer.apply(policy_gradients, policy_trainable_variables)
  self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

  # Losses to track.
  fetches = {
      'critic_loss': critic_loss,
      'policy_loss': policy_loss,
  }
  fetches.update(policy_stats)  # Log MPO stats.

  return fetches

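# Illustrative sketch of the tile-and-merge pattern used above to evaluate a
# critic on N sampled actions per state in a single batched call. `toy_critic`
# is a stand-in network and all shapes are made up for the example.
import sonnet as snt
import tensorflow as tf


def _example_tile_and_merge():
  num_samples, batch_size, obs_dim, act_dim = 4, 3, 8, 2
  o_t = tf.random.normal([batch_size, obs_dim])
  sampled_actions = tf.random.normal([num_samples, batch_size, act_dim])
  toy_critic = snt.nets.MLP([16, 1])  # Maps concat(observation, action) -> Q.

  # Tile observations to [N, B, ...] so each sampled action is paired with its
  # originating state.
  tiled_o_t = tf.tile(o_t[None], [num_samples, 1, 1])

  # Merge the leading [N, B] dims to [N*B, ...], run the network once, then
  # reshape the Q-values back to [N, B].
  merged = tf.concat([snt.merge_leading_dims(tiled_o_t, num_dims=2),
                      snt.merge_leading_dims(sampled_actions, num_dims=2)],
                     axis=-1)
  return tf.reshape(toy_critic(merged), [num_samples, batch_size])
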
def _step(self) -> types.NestedTensor:
  # Update target network.
  online_policy_variables = self._policy_network.variables
  target_policy_variables = self._target_policy_network.variables
  online_critic_variables = (
      *self._observation_network.variables,
      *self._critic_network.variables,
  )
  target_critic_variables = (
      *self._target_observation_network.variables,
      *self._target_critic_network.variables,
  )

  # Make online policy -> target policy network update ops.
  if tf.math.mod(self._num_steps, self._target_policy_update_period) == 0:
    for src, dest in zip(online_policy_variables, target_policy_variables):
      dest.assign(src)
  # Make online critic -> target critic network update ops.
  if tf.math.mod(self._num_steps, self._target_critic_update_period) == 0:
    for src, dest in zip(online_critic_variables, target_critic_variables):
      dest.assign(src)

  self._num_steps.assign_add(1)

  # Get data from replay (dropping extras if any). Note there is no
  # extra data here because we do not insert any into Reverb.
  inputs = next(self._iterator)
  o_tm1, a_tm1, r_t, d_t, o_t = inputs.data

  # Get batch size and scalar dtype.
  batch_size = r_t.shape[0]

  # Cast the additional discount to match the environment discount dtype.
  discount = tf.cast(self._discount, dtype=d_t.dtype)

  with tf.GradientTape(persistent=True) as tape:
    # Maybe transform the observation before feeding into policy and critic.
    # Transforming the observations this way at the start of the learning
    # step effectively means that the policy and critic share observation
    # network weights.
    o_tm1 = self._observation_network(o_tm1)

    # This stop_gradient prevents gradients from propagating into the target
    # observation network. In addition, since the online policy network is
    # evaluated at o_t, this also means the policy loss does not influence
    # the observation network training.
    o_t = tf.stop_gradient(self._target_observation_network(o_t))

    # Get online and target action distributions from policy networks.
    online_action_distribution = self._policy_network(o_t)
    target_action_distribution = self._target_policy_network(o_t)

    # Sample actions to evaluate policy; of size [N, B, ...].
    sampled_actions = target_action_distribution.sample(self._num_samples)

    # Tile embedded observations to feed into the target critic network.
    # Note: this is more efficient than tiling before the embedding layer.
    tiled_o_t = tf2_utils.tile_tensor(o_t, self._num_samples)  # [N, B, ...]

    # Compute target-estimated distributional value of sampled actions at o_t.
    sampled_q_t_all = self._target_critic_network(
        # Merge batch dimensions; to shape [N*B, ...].
        snt.merge_leading_dims(tiled_o_t, num_dims=2),
        snt.merge_leading_dims(sampled_actions, num_dims=2))

    # Compute online critic value distribution of a_tm1 in state o_tm1.
    q_tm1_all = self._critic_network(o_tm1, a_tm1)

    # Compute rewards for objectives with a defined reward_fn.
    reward_stats = {}
    r_t_all = []
    for objective in self._objectives:
      if hasattr(objective, 'reward_fn'):
        r = objective.reward_fn(o_tm1, a_tm1, r_t)
        reward_stats['{}_reward'.format(objective.name)] = tf.reduce_mean(r)
        r_t_all.append(r)
    r_t_all = tf.stack(r_t_all, axis=-1)
    r_t_all.get_shape().assert_has_rank(2)  # [B, C]

    if isinstance(sampled_q_t_all, list):  # Distributional critics.
      # Compute average logits by first reshaping them and normalizing them
      # across atoms.
      critic_losses = []
      sampled_q_ts = []
      for idx, (sampled_q_t_distributions, q_tm1_distribution) in enumerate(
          zip(sampled_q_t_all, q_tm1_all)):
        # Compute loss for distributional critic for objective c.
        sampled_logits = tf.reshape(
            sampled_q_t_distributions.logits,
            [self._num_samples, batch_size, -1])  # [N, B, A]
        sampled_logprobs = tf.math.log_softmax(sampled_logits, axis=-1)
        averaged_logits = tf.reduce_logsumexp(sampled_logprobs, axis=0)

        # Construct the expected distributional value for bootstrapping.
        q_t_distribution = networks.DiscreteValuedDistribution(
            values=sampled_q_t_distributions.values, logits=averaged_logits)

        # Compute critic distributional loss.
        critic_loss = losses.categorical(q_tm1_distribution, r_t_all[:, idx],
                                         discount * d_t, q_t_distribution)
        critic_losses.append(tf.reduce_mean(critic_loss))

        # Compute Q-values of sampled actions and reshape to [N, B].
        sampled_q_ts.append(
            tf.reshape(sampled_q_t_distributions.mean(),
                       (self._num_samples, -1)))

      critic_loss = tf.reduce_mean(critic_losses)
      sampled_q_t = tf.stack(sampled_q_ts, axis=-1)  # [N, B, C]
    else:
      # Reshape Q-value samples back to original batch dimensions and average
      # them to compute the TD-learning bootstrap target.
      sampled_q_t = tf.reshape(
          sampled_q_t_all,
          (self._num_samples, batch_size, self._num_critic_heads))  # [N,B,C]
      q_t = tf.reduce_mean(sampled_q_t, axis=0)  # [B, C]

      # Flatten q_t and q_tm1; necessary for trfl.td_learning.
      q_t = tf.reshape(q_t, [-1])  # [B*C]
      q_tm1 = tf.reshape(q_tm1_all, [-1])  # [B*C]

      # Flatten r_t_all; necessary for trfl.td_learning.
      r_t_all = tf.reshape(r_t_all, [-1])  # [B*C]

      # Broadcast and then flatten d_t, to match shape of q_t and q_tm1.
      d_t = tf.tile(d_t, [self._num_critic_heads])  # [B*C]

      # Critic loss.
      critic_loss = trfl.td_learning(q_tm1, r_t_all, discount * d_t, q_t).loss
      critic_loss = tf.reduce_mean(critic_loss)

    # Add sampled Q-values for objectives with a defined objective_fn.
    sampled_q_idx = 0
    sampled_q_t_k = []
    for objective in self._objectives:
      if hasattr(objective, 'reward_fn'):
        sampled_q_t_k.append(tf.stop_gradient(sampled_q_t[..., sampled_q_idx]))
        sampled_q_idx += 1
      if hasattr(objective, 'objective_fn'):
        sampled_q_t_k.append(
            tf.stop_gradient(
                objective.objective_fn(sampled_actions, sampled_q_t)))
    sampled_q_t_k = tf.stack(sampled_q_t_k, axis=-1)  # [N, B, K]

    # Compute MPO policy loss.
    policy_loss, policy_stats = self._policy_loss_module(
        online_action_distribution=online_action_distribution,
        target_action_distribution=target_action_distribution,
        actions=sampled_actions,
        q_values=sampled_q_t_k)

  # For clarity, explicitly define which variables are trained by which loss.
  critic_trainable_variables = (
      # In this agent, the critic loss trains the observation network.
      self._observation_network.trainable_variables +
      self._critic_network.trainable_variables)
  policy_trainable_variables = self._policy_network.trainable_variables
  # The following are the MPO dual variables, stored in the loss module.
  dual_trainable_variables = self._policy_loss_module.trainable_variables

  # Compute gradients.
  critic_gradients = tape.gradient(critic_loss, critic_trainable_variables)
  policy_gradients, dual_gradients = tape.gradient(
      policy_loss, (policy_trainable_variables, dual_trainable_variables))

  # Delete the tape manually because of the persistent=True flag.
  del tape

  # Maybe clip gradients.
  if self._clipping:
    policy_gradients = tuple(tf.clip_by_global_norm(policy_gradients, 40.)[0])
    critic_gradients = tuple(tf.clip_by_global_norm(critic_gradients, 40.)[0])

  # Apply gradients.
  self._critic_optimizer.apply(critic_gradients, critic_trainable_variables)
  self._policy_optimizer.apply(policy_gradients, policy_trainable_variables)
  self._dual_optimizer.apply(dual_gradients, dual_trainable_variables)

  # Losses to track.
  fetches = {
      'critic_loss': critic_loss,
      'policy_loss': policy_loss,
  }
  fetches.update(policy_stats)  # Log MPO stats.
  fetches.update(reward_stats)  # Log reward stats.

  return fetches

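# Purely illustrative sketch: `ToyObjective` and its fields are hypothetical
# stand-ins, not the library's objective interface. It shows how per-objective
# rewards could be produced by `reward_fn`s and stacked into the [B, C] tensor
# consumed by the multi-objective critic update above.
import dataclasses
from typing import Callable

import tensorflow as tf


@dataclasses.dataclass
class ToyObjective:
  name: str
  reward_fn: Callable[[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor]


def _example_stack_objective_rewards():
  objectives = [
      # The task objective passes the environment reward through unchanged.
      ToyObjective('task', lambda o, a, r: r),
      # A second objective penalizes large actions.
      ToyObjective('action_norm', lambda o, a, r: -tf.norm(a, axis=-1)),
  ]

  batch_size, obs_dim, act_dim = 4, 8, 2
  o_tm1 = tf.random.normal([batch_size, obs_dim])
  a_tm1 = tf.random.normal([batch_size, act_dim])
  r_t = tf.random.normal([batch_size])

  # One reward column per objective, mirroring the construction of r_t_all.
  r_t_all = tf.stack(
      [obj.reward_fn(o_tm1, a_tm1, r_t) for obj in objectives], axis=-1)
  r_t_all.get_shape().assert_has_rank(2)  # [B, C]
  return r_t_all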