def actor_loss(self,
               time_steps,
               rb_actions=None,
               weights=None,
               q_combinator='min',
               entropy_coef=1e-4):
  """Computes the actor_loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    rb_actions: Actions from the replay buffer. While not used in the main
      RCE method, we used these actions to train a behavior policy for the
      ablation experiment studying how to sample actions for the success
      examples.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    q_combinator: Whether to combine the two Q-functions by taking the 'min'
      (as in TD3) or the 'max'.
    entropy_coef: Coefficient for the entropy regularization term. We found
      that 1e-4 worked well for all environments.

  Returns:
    actor_loss: A scalar actor loss.
  """
  with tf.name_scope('actor_loss'):
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)

    actions, log_pi = self._actions_and_log_probs(time_steps)
    target_input = (time_steps.observation, actions)
    target_q_values1, _ = self._critic_network_1(
        target_input, time_steps.step_type, training=False)
    target_q_values2, _ = self._critic_network_2(
        target_input, time_steps.step_type, training=False)
    if q_combinator == 'min':
      target_q_values = tf.minimum(target_q_values1, target_q_values2)
    else:
      assert q_combinator == 'max'
      target_q_values = tf.maximum(target_q_values1, target_q_values2)

    if entropy_coef == 0:
      actor_loss = -target_q_values
    else:
      actor_loss = entropy_coef * log_pi - target_q_values

    if actor_loss.shape.rank > 1:
      # Sum over the time dimension.
      actor_loss = tf.reduce_sum(
          actor_loss, axis=range(1, actor_loss.shape.rank))
    reg_loss = self._actor_network.losses if self._actor_network else None
    agg_loss = common.aggregate_losses(
        per_example_loss=actor_loss,
        sample_weight=weights,
        regularization_loss=reg_loss)
    actor_loss = agg_loss.total_loss
    self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                     target_q_values, time_steps)
    return actor_loss
def actor_loss(self,
               time_steps: ts.TimeStep,
               actions: types.Tensor,
               weights: Optional[types.Tensor] = None,
               training: Optional[bool] = True) -> types.Tensor:
  """Computes actor_loss equivalent to the SAC actor_loss.

  Uses behavioral cloning for the first `self._num_bc_steps` steps of
  training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    training: Whether training should be applied.

  Returns:
    actor_loss: A scalar actor loss.
  """
  with tf.name_scope('actor_loss'):
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    sampled_actions, sampled_log_pi = self._actions_and_log_probs(
        time_steps, training=training)

    # Behavioral cloning: train the policy to reproduce actions from the
    # dataset.
    if self.train_step_counter < self._num_bc_steps:
      distribution, _ = self._actor_network(time_steps.observation,
                                            time_steps.step_type, ())
      actor_log_prob = distribution.log_prob(actions)
      actor_loss = tf.exp(self._log_alpha) * sampled_log_pi - actor_log_prob
      target_q_values = tf.zeros(tf.shape(sampled_log_pi))
    else:
      target_input = (time_steps.observation, sampled_actions)
      target_q_values1, _ = self._critic_network_1(
          target_input, time_steps.step_type, training=False)
      target_q_values2, _ = self._critic_network_2(
          target_input, time_steps.step_type, training=False)
      target_q_values = tf.minimum(target_q_values1, target_q_values2)
      actor_loss = tf.exp(self._log_alpha) * sampled_log_pi - target_q_values

    if actor_loss.shape.rank > 1:
      # Sum over the time dimension.
      actor_loss = tf.reduce_sum(
          actor_loss, axis=range(1, actor_loss.shape.rank))
    reg_loss = self._actor_network.losses if self._actor_network else None
    agg_loss = common.aggregate_losses(
        per_example_loss=actor_loss,
        sample_weight=weights,
        regularization_loss=reg_loss)
    actor_loss = agg_loss.total_loss
    self._actor_loss_debug_summaries(actor_loss, sampled_actions,
                                     sampled_log_pi, target_q_values,
                                     time_steps)
    return actor_loss
def alpha_loss(self, time_steps, weights=None):
  """Computes the alpha_loss for EC-SAC training.

  Args:
    time_steps: A batch of timesteps.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    alpha_loss: A scalar alpha loss.
  """
  with tf.name_scope('alpha_loss'):
    tf.nest.assert_same_structure(time_steps, self.time_step_spec)
    unused_actions, log_pi = self._actions_and_log_probs(time_steps)
    entropy_diff = tf.stop_gradient(-log_pi - self._target_entropy)
    alpha_loss = (self._log_alpha * entropy_diff)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Sum over the time dimension.
      alpha_loss = tf.reduce_sum(input_tensor=alpha_loss, axis=1)
    else:
      alpha_loss = tf.expand_dims(alpha_loss, 0)

    agg_loss = common.aggregate_losses(
        per_example_loss=alpha_loss, sample_weight=weights)
    alpha_loss = agg_loss.total_loss

    self._alpha_loss_debug_summaries(alpha_loss, entropy_diff)
    return alpha_loss
def test_aggregate_losses_with_time_dim_and_float_weights(self):
  per_example_loss = tf.constant([[4., 2., 3.], [1, 1, 1]])
  sample_weights = 0.5
  aggregated_losses = common.aggregate_losses(per_example_loss,
                                              sample_weights)
  expected_per_example_loss = 0.5 * (4 + 2 + 3 + 1 + 1 + 1) / 6
  self.assertAlmostEqual(self.evaluate(aggregated_losses.total_loss),
                         expected_per_example_loss)
def test_aggregate_losses_three_dimensions(self):
  per_example_loss = tf.constant([[[4., 2., 3.], [1, 1, 1]],
                                  [[8., 4., 6.], [2, 2, 2]]])
  aggregated_losses = common.aggregate_losses(per_example_loss)
  expected_per_example_loss = (4 + 2 + 3 + 1 + 1 + 1 + 8 + 4 + 6 + 2 + 2 +
                               2) / 12
  self.assertAlmostEqual(self.evaluate(aggregated_losses.total_loss),
                         expected_per_example_loss)
def test_aggregate_losses_with_time_dim_and_weights_with_batch_dim(self):
  per_example_loss = tf.constant([[4., 2., 3.], [1, 1, 1]])
  sample_weights = tf.constant([1., 0.])
  aggregated_losses = common.aggregate_losses(per_example_loss,
                                              sample_weights)
  expected_per_example_loss = (4 + 2 + 3) / 6
  self.assertAlmostEqual(self.evaluate(aggregated_losses.total_loss),
                         expected_per_example_loss)
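# A minimal sketch of the aggregation contract the tests above pin down:
# every element of per_example_loss is weighted by 1/N, where N is the TOTAL
# number of elements, no matter how many weights are zero. This is an
# illustration of the semantics, not the actual
# tf_agents.utils.common.aggregate_losses implementation (which also handles
# regularization losses and returns weighted/regularization/total losses).
import tensorflow as tf

def aggregate_losses_sketch(per_example_loss, sample_weight=None):
  loss = tf.convert_to_tensor(per_example_loss, dtype=tf.float32)
  if sample_weight is not None:
    weight = tf.cast(sample_weight, tf.float32)
    # Broadcast batch-dim weights ([B]) against losses shaped [B, T, ...].
    while weight.shape.rank < loss.shape.rank:
      weight = tf.expand_dims(weight, -1)
    loss = loss * weight
  # Mean over ALL elements: masked entries still dilute the average.
  return tf.reduce_mean(loss)

# Reproduces the expectation above: (4 + 2 + 3) / 6 == 1.5.
# aggregate_losses_sketch([[4., 2., 3.], [1., 1., 1.]], [1., 0.])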
def _loss(self, experience, weights=None, training: bool = False):
  experience = self._as_trajectory(experience)
  per_example_loss = self._bc_loss_fn(experience, training=training)
  aggregated_losses = common.aggregate_losses(
      per_example_loss=per_example_loss,
      sample_weight=weights,
      regularization_loss=self._cloning_network.losses)
  return tf_agent.LossInfo(
      loss=aggregated_losses.total_loss,
      extra=BehavioralCloningLossInfo(per_example_loss))
def actor_loss(self, time_steps, actions, weights=None, ce_loss=False):
  """Computes the actor_loss for C-learning training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    ce_loss: (bool) Whether to update the actor using the cross-entropy loss,
      which corresponds to using the log C-value. The default actor loss
      differs by not including the log. Empirically we observed no
      difference.

  Returns:
    actor_loss: A scalar actor loss.
  """
  with tf.name_scope('actor_loss'):
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)

    sampled_actions, log_pi = self._actions_and_log_probs(time_steps)
    target_input = (time_steps.observation, sampled_actions)
    target_q_values1, _ = self._critic_network_1(
        target_input, time_steps.step_type, training=False)
    target_q_values2, _ = self._critic_network_2(
        target_input, time_steps.step_type, training=False)
    target_q_values = tf.minimum(target_q_values1, target_q_values2)

    if ce_loss:
      actor_loss = tf.keras.losses.binary_crossentropy(
          tf.ones_like(target_q_values), target_q_values)
    else:
      actor_loss = -1.0 * target_q_values

    if actor_loss.shape.rank > 1:
      # Sum over the time dimension.
      actor_loss = tf.reduce_sum(
          actor_loss, axis=range(1, actor_loss.shape.rank))
    reg_loss = self._actor_network.losses if self._actor_network else None
    agg_loss = common.aggregate_losses(
        per_example_loss=actor_loss,
        sample_weight=weights,
        regularization_loss=reg_loss)
    actor_loss = agg_loss.total_loss
    self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                     target_q_values, time_steps)
    return actor_loss
def _train(self, experience, weights=None):
  with tf.GradientTape() as tape:
    per_example_loss = self._loss_fn(experience)
    aggregated_losses = common.aggregate_losses(
        per_example_loss=per_example_loss,
        sample_weight=weights,
        regularization_loss=self._cloning_network.losses)
  self._apply_loss(aggregated_losses, self._cloning_network.trainable_weights,
                   tape, self._optimizer)
  self.train_step_counter.assign_add(1)
  return tf_agent.LossInfo(aggregated_losses.total_loss,
                           BehavioralCloningLossInfo(per_example_loss))
def test_aggregate_4d_losses_and_2d_weights(self):
  per_example_loss = tf.constant([[[[4., 2., 3.], [1, 1, 1]],
                                   [[8., 4., 6.], [2, 2, 2]]],
                                  [[[4., 2., 3.], [1, 1, 1]],
                                   [[8., 4., 6.], [2, 2, 2]]]])  # 2x2x2x3
  sample_weights = tf.constant([[1., 0.], [0., 0.]])
  aggregated_losses = common.aggregate_losses(per_example_loss,
                                              sample_weights)
  expected_per_example_loss = (4 + 2 + 3 + 1 + 1 + 1) / 24
  self.assertAlmostEqual(
      self.evaluate(aggregated_losses.total_loss), expected_per_example_loss)
def actor_loss(self, time_steps, weights=None):
  """Computes the actor_loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    actor_loss: A scalar actor loss.
  """
  with tf.name_scope('actor_loss'):
    tf.nest.assert_same_structure(time_steps, self.time_step_spec)
    actions, log_pi = self._actions_and_log_probs(time_steps)
    target_input = (time_steps.observation, actions)
    target_q_values1, _ = self._critic_network_1(
        target_input, time_steps.step_type, training=False)
    target_q_values2, _ = self._critic_network_2(
        target_input, time_steps.step_type, training=False)
    target_q_values = tf.minimum(target_q_values1, target_q_values2)
    # Stop gradients to avoid updates to shared layers between the critic
    # and the actor. They could still be updated through the actor if
    # desired, but we do not want gradients to flow to shared variables
    # through the critic.
    target_q_values = tf.stop_gradient(target_q_values)
    actor_loss = tf.exp(self._log_alpha) * log_pi - target_q_values
    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Sum over the time dimension.
      actor_loss = tf.reduce_sum(input_tensor=actor_loss, axis=1)
    reg_loss = self._actor_network.losses if self._actor_network else None
    agg_loss = common.aggregate_losses(
        per_example_loss=actor_loss,
        sample_weight=weights,
        regularization_loss=reg_loss)
    actor_loss = agg_loss.total_loss
    self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                     target_q_values, time_steps)
    return actor_loss
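# Toy illustration (hypothetical numbers) of the actor objective used above:
# with temperature alpha = exp(log_alpha), the per-sample loss is
# alpha * log_pi - Q(s, a). With equal Q-values, the sample with the lower
# log-probability (higher entropy) attains the lower loss, which is exactly
# the entropy bonus SAC trades off against Q-value maximization.
import tensorflow as tf

alpha = tf.exp(tf.constant(0.0))        # log_alpha = 0, so alpha = 1
log_pi = tf.constant([-0.1, -2.3])      # confident vs. diffuse action sample
q_values = tf.constant([1.0, 1.0])      # equal critic values
actor_loss = alpha * log_pi - q_values  # [-1.1, -3.3]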
def _add_auxiliary_losses(self, transition, weights, losses_dict):
  """Computes auxiliary losses, updating losses_dict in place."""
  total_auxiliary_loss = 0
  if self._auxiliary_loss_fns is not None:
    for auxiliary_loss_fn in self._auxiliary_loss_fns:
      auxiliary_loss, auxiliary_reg_loss = auxiliary_loss_fn(
          network=self._q_network, transition=transition)
      agg_auxiliary_loss = common.aggregate_losses(
          per_example_loss=auxiliary_loss,
          sample_weight=weights,
          regularization_loss=auxiliary_reg_loss)
      total_auxiliary_loss += agg_auxiliary_loss.total_loss
      losses_dict.update({
          'auxiliary_loss_{}'.format(auxiliary_loss_fn.__name__):
              agg_auxiliary_loss.weighted,
          'auxiliary_reg_loss_{}'.format(auxiliary_loss_fn.__name__):
              agg_auxiliary_loss.regularization,
      })
  return total_auxiliary_loss
def actor_loss(self,
               time_steps: ts.TimeStep,
               weights: Optional[types.Tensor] = None,
               training: Optional[bool] = True) -> types.Tensor:
  """Computes the actor_loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    training: Whether training should be applied.

  Returns:
    actor_loss: A scalar actor loss.
  """
  with tf.name_scope('actor_loss'):
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    actions, log_pi = self._actions_and_log_probs(
        time_steps, training=training)
    target_input = (time_steps.observation, actions)
    # We do not update the critic during the actor loss.
    target_q_values1, _ = self._critic_network_1(
        target_input, step_type=time_steps.step_type, training=False)
    target_q_values2, _ = self._critic_network_2(
        target_input, step_type=time_steps.step_type, training=False)
    target_q_values = tf.minimum(target_q_values1, target_q_values2)
    actor_loss = tf.exp(self._log_alpha) * log_pi - target_q_values
    if actor_loss.shape.rank > 1:
      # Sum over the time dimension.
      actor_loss = tf.reduce_sum(
          actor_loss, axis=range(1, actor_loss.shape.rank))
    reg_loss = self._actor_network.losses if self._actor_network else None
    agg_loss = common.aggregate_losses(
        per_example_loss=actor_loss,
        sample_weight=weights,
        regularization_loss=reg_loss)
    actor_loss = agg_loss.total_loss
    self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                     target_q_values, time_steps)
    return actor_loss
def alpha_loss(self,
               time_steps: ts.TimeStep,
               weights: Optional[types.Tensor] = None,
               training: bool = False) -> types.Tensor:
  """Computes the alpha_loss for EC-SAC training.

  Args:
    time_steps: A batch of timesteps.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    training: Whether this loss is being used during training.

  Returns:
    alpha_loss: A scalar alpha loss.
  """
  with tf.name_scope('alpha_loss'):
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    # We do not update the actor during the alpha loss.
    unused_actions, log_pi = self._actions_and_log_probs(
        time_steps, training=False)
    entropy_diff = tf.stop_gradient(-log_pi - self._target_entropy)
    if self._use_log_alpha_in_alpha_loss:
      alpha_loss = (self._log_alpha * entropy_diff)
    else:
      alpha_loss = (tf.exp(self._log_alpha) * entropy_diff)

    if alpha_loss.shape.rank > 1:
      # Sum over the time dimension.
      alpha_loss = tf.reduce_sum(
          alpha_loss, axis=range(1, alpha_loss.shape.rank))

    agg_loss = common.aggregate_losses(
        per_example_loss=alpha_loss, sample_weight=weights)
    alpha_loss = agg_loss.total_loss

    self._alpha_loss_debug_summaries(alpha_loss, entropy_diff)
    return alpha_loss
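# Sketch of the temperature dynamics implied by alpha_loss above. Because
# entropy_diff is wrapped in stop_gradient, the gradient of
# log_alpha * entropy_diff w.r.t. log_alpha is just entropy_diff.
# Hypothetical numbers: target entropy -4 (e.g. -dim(action_space)) and a
# sampled log_pi of 2.
import tensorflow as tf

target_entropy = -4.0
log_alpha = tf.Variable(0.0)
log_pi = tf.constant(2.0)
with tf.GradientTape() as tape:
  entropy_diff = tf.stop_gradient(-log_pi - target_entropy)  # 2.0
  alpha_loss = log_alpha * entropy_diff
grad = tape.gradient(alpha_loss, log_alpha)  # == entropy_diff == 2.0
# The entropy estimate (-log_pi = -2) exceeds the target (-4), so a descent
# step lowers log_alpha and shrinks the entropy bonus; the update is
# self-correcting around the entropy target.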
def actor_loss(self,
               time_steps: ts.TimeStep,
               weights: Optional[types.Tensor] = None) -> types.Tensor:
  """Computes the actor_loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    actor_loss: A scalar actor loss.
  """
  with tf.name_scope('actor_loss'):
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    actions, log_pi = self._actions_and_log_probs(time_steps)
    target_input = (time_steps.observation, actions)
    target_q_values1, _ = self._critic_network_1(
        target_input, time_steps.step_type, training=False)
    target_q_values2, _ = self._critic_network_2(
        target_input, time_steps.step_type, training=False)
    target_q_values = tf.minimum(target_q_values1, target_q_values2)
    actor_loss = tf.exp(self._log_alpha) * log_pi - target_q_values
    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Sum over the time dimension.
      actor_loss = tf.reduce_sum(input_tensor=actor_loss, axis=1)
    reg_loss = self._actor_network.losses if self._actor_network else None
    agg_loss = common.aggregate_losses(
        per_example_loss=actor_loss,
        sample_weight=weights,
        regularization_loss=reg_loss)
    actor_loss = agg_loss.total_loss
    self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                     target_q_values, time_steps)
    return actor_loss
def actor_loss(self,
               time_steps: ts.TimeStep,
               weights: Optional[types.Tensor] = None,
               training: bool = False) -> types.Tensor:
  """Computes the actor_loss for TD3 training.

  Args:
    time_steps: A batch of timesteps.
    weights: Optional scalar or element-wise (per-batch-entry) importance
      weights.
    training: Whether this loss is being used for training.

  Returns:
    actor_loss: A scalar actor loss.
  """
  # TODO(b/124383618): Add an action norm regularizer.
  with tf.name_scope('actor_loss'):
    actions, _ = self._actor_network(time_steps.observation,
                                     time_steps.step_type,
                                     training=training)
    q_values, _ = self._critic_network_1((time_steps.observation, actions),
                                         time_steps.step_type,
                                         training=False)
    actor_loss = -q_values
    # Sum over the time dimension.
    if actor_loss.shape.rank > 1:
      actor_loss = tf.reduce_sum(
          actor_loss, axis=range(1, actor_loss.shape.rank))
    actor_loss = common.aggregate_losses(
        per_example_loss=actor_loss, sample_weight=weights).total_loss

    with tf.name_scope('Losses/'):
      tf.compat.v2.summary.scalar(
          name='actor_loss', data=actor_loss, step=self.train_step_counter)

    return actor_loss
def _loss(self,
          experience,
          td_errors_loss_fn=common.element_wise_huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None,
          training=False):
  """Computes loss for DQN training.

  Args:
    experience: A batch of experience data in the form of a `Trajectory` or
      `Transition`. The structure of `experience` must match that of
      `self.collect_policy.step_spec`. If a `Trajectory`, all tensors in
      `experience` must be shaped `[B, T, ...]` where `T` must be equal to
      `self.train_sequence_length` if that property is not `None`.
    td_errors_loss_fn: A function(td_targets, predictions) to compute the
      element wise loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights. The output td_loss will be scaled by these weights, and the
      final scalar loss is the mean of these values.
    training: Whether this loss is being used for training.

  Returns:
    loss: An instance of `DqnLossInfo`.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  transition = self._as_transition(experience)
  time_steps, policy_steps, next_time_steps = transition
  actions = policy_steps.action

  with tf.name_scope('loss'):
    q_values = self._compute_q_values(time_steps, actions, training=training)
    next_q_values = self._compute_next_q_values(
        next_time_steps, policy_steps.info)

    # This applies to any value of n_step_update and also in the RNN-DQN
    # case. In the RNN-DQN case, inputs and outputs contain a time dimension.
    td_targets = compute_td_targets(
        next_q_values,
        rewards=reward_scale_factor * next_time_steps.reward,
        discounts=gamma * next_time_steps.discount)

    valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
    td_error = valid_mask * (td_targets - q_values)
    td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    # Aggregate across the elements of the batch and add regularization loss.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N, where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution in the loss. Think about what would happen
    # as the number of boundary samples increases.
    agg_loss = common.aggregate_losses(
        per_example_loss=td_loss,
        sample_weight=weights,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    losses_dict = {
        'td_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss
    }

    common.summarize_scalar_dict(
        losses_dict, step=self.train_step_counter, name_scope='Losses/')

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._q_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      diff_q_values = q_values - next_q_values
      common.generate_tensor_summaries('td_error', td_error,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_loss', td_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('q_values', q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('next_q_values', next_q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('diff_q_values', diff_q_values,
                                       self.train_step_counter)

    return tf_agent.LossInfo(total_loss,
                             DqnLossInfo(td_loss=td_loss, td_error=td_error))
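# Numeric illustration (toy values) of the 1/N-vs-1/K note above: with two of
# four transitions masked by boundary weights, dividing by N = 4 keeps the
# masked entries diluting the batch mean, whereas dividing by the K = 2
# non-zero entries would double the survivors' contribution.
import tensorflow as tf

td_loss = tf.constant([4.0, 2.0, 6.0, 8.0])
weights = tf.constant([1.0, 1.0, 0.0, 0.0])      # boundary transitions masked
n_mean = tf.reduce_sum(td_loss * weights) / 4.0  # 1.5, what aggregation uses
k_mean = tf.reduce_sum(td_loss * weights) / 2.0  # 3.0, artificially inflated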
def _loss_h(self,
            experience,
            td_errors_loss_fn=common.element_wise_huber_loss,
            gamma=1.0,
            reward_scale_factor=1.0,
            weights=None,
            training=False):
  transition = self._as_transition(experience)
  time_steps, policy_steps, next_time_steps = transition
  actions = policy_steps.action
  valid_mask = tf.cast(~time_steps.is_last(), tf.float32)

  with tf.name_scope('loss'):
    # h_values is already gathered by actions.
    h_values = self._compute_h_values(time_steps, actions, training=training)

    multi_dim_actions = self._action_spec.shape.rank > 0

    next_q_all_values = self._compute_next_all_q_values(
        next_time_steps, policy_steps.info)
    next_h_all_values = self._compute_next_all_h_values(
        next_time_steps, policy_steps.info)
    next_h_actions = tf.argmax(next_h_all_values, axis=1)

    # next_h_values here is used only for logging.
    next_h_values = self._compute_next_h_values(next_time_steps,
                                                policy_steps.info)

    # next_q_values refers to Q(r, s') in Eqs. (4) and (5).
    next_q_values = common.index_with_actions(
        next_q_all_values, tf.cast(next_h_actions, dtype=tf.int32),
        multi_dim_actions)

    h_target_all_values = self._compute_next_all_h_values(
        time_steps, policy_steps.info)
    h_target_values = common.index_with_actions(
        h_target_all_values, tf.cast(actions, dtype=tf.int32),
        multi_dim_actions)

    td_targets = compute_momentum_td_targets(
        q_target_values=next_q_values,
        h_target_values=h_target_values,
        beta=self.beta())

    td_error = valid_mask * (td_targets - h_values)
    td_loss = valid_mask * td_errors_loss_fn(td_targets, h_values)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    # Aggregate across the elements of the batch and add regularization loss.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N, where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution in the loss. Think about what would happen
    # as the number of boundary samples increases.
    agg_loss = common.aggregate_losses(
        per_example_loss=td_loss,
        sample_weight=weights,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    losses_dict = {
        'td_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss
    }

    common.summarize_scalar_dict(
        losses_dict, step=self.train_step_counter, name_scope='Losses/')

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._h_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      diff_h_values = h_values - next_h_values
      common.generate_tensor_summaries('td_error_h', td_error,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_loss_h', td_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('h_values', h_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('next_h_values', next_h_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('diff_h_values', diff_h_values,
                                       self.train_step_counter)

    return tf_agent.LossInfo(total_loss,
                             DqnLossInfo(td_loss=td_loss, td_error=td_error))
def _loss(self,
          experience,
          td_errors_loss_fn=common.element_wise_huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None,
          training=False):
  """Computes loss for DQN training.

  Args:
    experience: A batch of experience data in the form of a `Trajectory`. The
      structure of `experience` must match that of `self.policy.step_spec`.
      All tensors in `experience` must be shaped `[batch, time, ...]` where
      `time` must be equal to `self.train_sequence_length` if that property
      is not `None`.
    td_errors_loss_fn: A function(td_targets, predictions) to compute the
      element wise loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights. The output td_loss will be scaled by these weights, and the
      final scalar loss is the mean of these values.
    training: Whether this loss is being used for training.

  Returns:
    loss: An instance of `DqnLossInfo`.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  # Check that `experience` includes two outer dimensions [B, T, ...]. This
  # method requires a time dimension to compute the loss properly.
  self._check_trajectory_dimensions(experience)

  squeeze_time_dim = not self._q_network.state_spec
  if self._n_step_update == 1:
    time_steps, policy_steps, next_time_steps = (
        trajectory.experience_to_transitions(experience, squeeze_time_dim))
    actions = policy_steps.action
  else:
    # To compute n-step returns, we need the first time steps, the first
    # actions, and the last time steps. Therefore we extract the first and
    # last transitions from our Trajectory.
    first_two_steps = tf.nest.map_structure(lambda x: x[:, :2], experience)
    last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:], experience)
    time_steps, policy_steps, _ = (
        trajectory.experience_to_transitions(first_two_steps,
                                             squeeze_time_dim))
    actions = policy_steps.action
    _, _, next_time_steps = (
        trajectory.experience_to_transitions(last_two_steps,
                                             squeeze_time_dim))

  with tf.name_scope('loss'):
    q_values = self._compute_q_values(time_steps, actions, training=training)
    next_q_values = self._compute_next_q_values(
        next_time_steps, policy_steps.info)

    if self._n_step_update == 1:
      # Special case for n = 1 to avoid a loss of performance.
      td_targets = compute_td_targets(
          next_q_values,
          rewards=reward_scale_factor * next_time_steps.reward,
          discounts=gamma * next_time_steps.discount)
    else:
      # When computing the discounted return, we need to throw out the last
      # time index of both reward and discount, which are filled with dummy
      # values to match the dimensions of the observation.
      rewards = reward_scale_factor * experience.reward[:, :-1]
      discounts = gamma * experience.discount[:, :-1]

      # TODO(b/134618876): Properly handle Trajectories that include episode
      # boundaries with nonzero discount.

      td_targets = value_ops.discounted_return(
          rewards=rewards,
          discounts=discounts,
          final_value=next_q_values,
          time_major=False,
          provide_all_returns=False)

    valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
    td_error = valid_mask * (td_targets - q_values)
    td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    # Aggregate across the elements of the batch and add regularization loss.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N, where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution in the loss. Think about what would happen
    # as the number of boundary samples increases.
    agg_loss = common.aggregate_losses(
        per_example_loss=td_loss,
        sample_weight=weights,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    losses_dict = {
        'td_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss
    }

    common.summarize_scalar_dict(
        losses_dict, step=self.train_step_counter, name_scope='Losses/')

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._q_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      diff_q_values = q_values - next_q_values
      common.generate_tensor_summaries('td_error', td_error,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_loss', td_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('q_values', q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('next_q_values', next_q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('diff_q_values', diff_q_values,
                                       self.train_step_counter)

    return tf_agent.LossInfo(total_loss,
                             DqnLossInfo(td_loss=td_loss, td_error=td_error))
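# Toy walk-through (hypothetical numbers) of the n-step target built in the
# n_step_update > 1 branch above, assuming no episode boundary so every
# intermediate discount equals gamma. For n = 2:
#   G = r_0 + gamma * r_1 + gamma^2 * Q(s_2, a*),
# which mirrors what value_ops.discounted_return produces when bootstrapped
# with final_value=next_q_values.
import tensorflow as tf

gamma = 0.9
rewards = tf.constant([1.0, 0.5])  # experience.reward[:, :-1] for one sample
final_q = tf.constant(2.0)         # next_q_values at the n-th step
td_target = rewards[0] + gamma * rewards[1] + gamma**2 * final_q  # 3.07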
def critic_loss(self,
                time_steps,
                actions,
                next_time_steps,
                augmented_obs,
                augmented_next_obs,
                td_errors_loss_fn,
                gamma=1.0,
                reward_scale_factor=1.0,
                weights=None,
                training=False):
  """Computes the critic loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    augmented_obs: List of augmented observations.
    augmented_next_obs: List of augmented next observations.
    td_errors_loss_fn: A function(td_targets, predictions) to compute
      elementwise (per-batch-entry) loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    training: Whether this loss is being used for training.

  Returns:
    critic_loss: A scalar critic loss.
  """
  with tf.name_scope('critic_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    td_targets = self._compute_td_targets(next_time_steps,
                                          reward_scale_factor, gamma)

    # Compute td_targets with augmentations.
    for i in range(self._num_augmentations - 1):
      augmented_next_time_steps = next_time_steps._replace(
          observation=augmented_next_obs[i])
      augmented_td_targets = self._compute_td_targets(
          augmented_next_time_steps, reward_scale_factor, gamma)
      td_targets = td_targets + augmented_td_targets

    # Average the td_target estimate over the augmentations.
    if self._num_augmentations > 1:
      td_targets = td_targets / self._num_augmentations

    pred_td_targets1, pred_td_targets2, critic_loss = (
        self._compute_prediction_critic_loss(
            (time_steps.observation, actions), td_targets, time_steps,
            training, td_errors_loss_fn))

    # Add Q augmentations to the critic loss.
    for i in range(self._num_augmentations - 1):
      augmented_time_steps = time_steps._replace(observation=augmented_obs[i])
      _, _, loss = self._compute_prediction_critic_loss(
          (augmented_time_steps.observation, actions), td_targets,
          augmented_time_steps, training, td_errors_loss_fn)
      critic_loss = critic_loss + loss

    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        sample_weight=weights,
        regularization_loss=(self._critic_network_1.losses +
                             self._critic_network_2.losses))
    critic_loss = agg_loss.total_loss

    self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                      pred_td_targets2)

    return critic_loss
def _loss(self,
          experience,
          td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None,
          training=False):
  """Computes critic loss for CategoricalDQN training.

  See Algorithm 1 and the discussion immediately preceding it on page 6 of
  "A Distributional Perspective on Reinforcement Learning"
    Bellemare et al., 2017
    https://arxiv.org/abs/1707.06887

  Args:
    experience: A batch of experience data in the form of a `Trajectory`. The
      structure of `experience` must match that of `self.policy.step_spec`.
      All tensors in `experience` must be shaped `[batch, time, ...]` where
      `time` must be equal to `self.required_experience_time_steps` if that
      property is not `None`.
    td_errors_loss_fn: A function(td_targets, predictions) to compute loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional weights used for importance sampling.
    training: Whether the loss is being used for training.

  Returns:
    critic_loss: A scalar critic loss.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  # Check that `experience` includes two outer dimensions [B, T, ...]. This
  # method requires a time dimension to compute the loss properly.
  self._check_trajectory_dimensions(experience)

  squeeze_time_dim = not self._q_network.state_spec
  if self._n_step_update == 1:
    time_steps, policy_steps, next_time_steps = (
        trajectory.experience_to_transitions(experience, squeeze_time_dim))
    actions = policy_steps.action
  else:
    # To compute n-step returns, we need the first time steps, the first
    # actions, and the last time steps. Therefore we extract the first and
    # last transitions from our Trajectory.
    first_two_steps = tf.nest.map_structure(lambda x: x[:, :2], experience)
    last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:], experience)
    time_steps, policy_steps, _ = (
        trajectory.experience_to_transitions(first_two_steps,
                                             squeeze_time_dim))
    actions = policy_steps.action
    _, _, next_time_steps = (
        trajectory.experience_to_transitions(last_two_steps,
                                             squeeze_time_dim))

  with tf.name_scope('critic_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    rank = nest_utils.get_outer_rank(time_steps.observation,
                                     self._time_step_spec.observation)

    # If inputs have a time dimension and the q_network is stateful,
    # combine the batch and time dimensions.
    batch_squash = (None
                    if rank <= 1 or self._q_network.state_spec in ((), None)
                    else network_utils.BatchSquash(rank))

    network_observation = time_steps.observation

    if self._observation_and_action_constraint_splitter is not None:
      network_observation, _ = (
          self._observation_and_action_constraint_splitter(
              network_observation))

    # q_logits contains the Q-value logits for all actions.
    q_logits, _ = self._q_network(network_observation,
                                  time_steps.step_type,
                                  training=training)

    if batch_squash is not None:
      # Squash the outer dimensions into a single dimension to facilitate
      # computing the loss below. This is required to support temporal
      # inputs, for example.
      q_logits = batch_squash.flatten(q_logits)
      actions = batch_squash.flatten(actions)
      next_time_steps = tf.nest.map_structure(batch_squash.flatten,
                                              next_time_steps)

    next_q_distribution = self._next_q_distribution(next_time_steps)

    if actions.shape.rank > 1:
      actions = tf.squeeze(actions, list(range(1, actions.shape.rank)))

    # Project the sample Bellman update \hat{T}Z_{\theta} onto the original
    # support of Z_{\theta} (see Figure 1 in paper).
    batch_size = q_logits.shape[0] or tf.shape(q_logits)[0]
    tiled_support = tf.tile(self._support, [batch_size])
    tiled_support = tf.reshape(tiled_support, [batch_size, self._num_atoms])

    if self._n_step_update == 1:
      discount = next_time_steps.discount
      if discount.shape.rank == 1:
        # We expect discount to have a shape of [batch_size], while
        # tiled_support will have a shape of [batch_size, num_atoms]. To
        # multiply these, we add a second dimension of 1 to the discount.
        discount = tf.expand_dims(discount, -1)
      next_value_term = tf.multiply(discount,
                                    tiled_support,
                                    name='next_value_term')

      reward = next_time_steps.reward
      if reward.shape.rank == 1:
        # See the explanation above.
        reward = tf.expand_dims(reward, -1)
      reward_term = tf.multiply(reward_scale_factor,
                                reward,
                                name='reward_term')

      target_support = tf.add(reward_term,
                              gamma * next_value_term,
                              name='target_support')
    else:
      # When computing the discounted return, we need to throw out the last
      # time index of both reward and discount, which are filled with dummy
      # values to match the dimensions of the observation.
      rewards = reward_scale_factor * experience.reward[:, :-1]
      discounts = gamma * experience.discount[:, :-1]

      # TODO(b/134618876): Properly handle Trajectories that include episode
      # boundaries with nonzero discount.

      discounted_returns = value_ops.discounted_return(
          rewards=rewards,
          discounts=discounts,
          final_value=tf.zeros([batch_size], dtype=discounts.dtype),
          time_major=False,
          provide_all_returns=False)

      # Convert discounted_returns from [batch_size] to [batch_size, 1].
      discounted_returns = tf.expand_dims(discounted_returns, -1)

      final_value_discount = tf.reduce_prod(discounts, axis=1)
      final_value_discount = tf.expand_dims(final_value_discount, -1)

      # Save the values of discounted_returns and final_value_discount in
      # order to check them in unit tests.
      self._discounted_returns = discounted_returns
      self._final_value_discount = final_value_discount

      target_support = tf.add(discounted_returns,
                              final_value_discount * tiled_support,
                              name='target_support')

    target_distribution = tf.stop_gradient(
        project_distribution(target_support, next_q_distribution,
                             self._support))

    # Obtain the current Q-value logits for the selected actions.
    indices = tf.range(batch_size)
    indices = tf.cast(indices, actions.dtype)
    reshaped_actions = tf.stack([indices, actions], axis=-1)
    chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)

    # Compute the cross-entropy loss between the logits. If inputs have a
    # time dimension, compute the sum over the time dimension before
    # computing the mean over the batch dimension.
    if batch_squash is not None:
      target_distribution = batch_squash.unflatten(target_distribution)
      chosen_action_logits = batch_squash.unflatten(chosen_action_logits)
      critic_loss = tf.reduce_sum(
          tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
              labels=target_distribution, logits=chosen_action_logits),
          axis=1)
    else:
      critic_loss = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
          labels=target_distribution, logits=chosen_action_logits)

    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    dict_losses = {
        'critic_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss
    }

    common.summarize_scalar_dict(
        dict_losses, step=self.train_step_counter, name_scope='Losses/')

    if self._debug_summaries:
      distribution_errors = target_distribution - chosen_action_logits
      with tf.name_scope('distribution_errors'):
        common.generate_tensor_summaries(
            'distribution_errors',
            distribution_errors,
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'mean',
            tf.reduce_mean(distribution_errors),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'mean_abs',
            tf.reduce_mean(tf.abs(distribution_errors)),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'max',
            tf.reduce_max(distribution_errors),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'min',
            tf.reduce_min(distribution_errors),
            step=self.train_step_counter)
      with tf.name_scope('target_distribution'):
        common.generate_tensor_summaries(
            'target_distribution',
            target_distribution,
            step=self.train_step_counter)

    # TODO(b/127318640): Give appropriate values for td_loss and td_error
    # for prioritized replay.
    return tf_agent.LossInfo(
        total_loss, dqn_agent.DqnLossInfo(td_loss=(), td_error=()))
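# Toy walk-through of the categorical (C51) target construction above for
# n_step_update == 1 (hypothetical numbers): each support atom z_i is shifted
# to r + gamma * d * z_i, and project_distribution then redistributes the
# probability mass of the shifted atoms back onto the fixed support.
import tensorflow as tf

support = tf.constant([-1.0, 0.0, 1.0])               # num_atoms = 3
reward, gamma, discount = 0.5, 0.9, 1.0
target_support = reward + gamma * discount * support  # [-0.4, 0.5, 1.4]
# An atom landing at 0.5 sits halfway between the fixed atoms 0.0 and 1.0,
# so the projection splits its probability mass equally between those bins;
# atoms falling outside the support are clipped to the nearest edge atom.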
def _critic_loss_with_optional_entropy_term(
    self,
    time_steps: ts.TimeStep,
    actions: types.Tensor,
    next_time_steps: ts.TimeStep,
    td_errors_loss_fn: types.LossFn,
    gamma: types.Float = 1.0,
    reward_scale_factor: types.Float = 1.0,
    weights: Optional[types.Tensor] = None,
    training: bool = False) -> types.Tensor:
  r"""Computes the critic loss for CQL-SAC training.

  The original SAC critic loss is:
  ```
  (q(s, a) - (r(s, a) + \gamma q(s', a') - \gamma \alpha \log \pi(a'|s')))^2
  ```
  The CQL-SAC critic loss makes the entropy term optional. CQL may value
  unseen actions higher, since it lower-bounds the value of seen actions.
  This makes the policy entropy potentially redundant in the target term,
  since it will further enhance unseen actions' effects.

  If self._include_critic_entropy_term is False, this loss equation becomes:
  ```
  (q(s, a) - (r(s, a) + \gamma q(s', a')))^2
  ```

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    td_errors_loss_fn: A function(td_targets, predictions) to compute
      elementwise (per-batch-entry) loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    training: Whether this loss is being used for training.

  Returns:
    critic_loss: A scalar critic loss.
  """
  with tf.name_scope('critic_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    # We do not update the actor or target networks in the critic loss.
    next_actions, next_log_pis = self._actions_and_log_probs(
        next_time_steps, training=False)
    target_input = (next_time_steps.observation, next_actions)
    target_q_values1, unused_network_state1 = self._target_critic_network_1(
        target_input, next_time_steps.step_type, training=False)
    target_q_values2, unused_network_state2 = self._target_critic_network_2(
        target_input, next_time_steps.step_type, training=False)
    target_q_values = tf.minimum(target_q_values1, target_q_values2)

    if self._include_critic_entropy_term:
      target_q_values -= (tf.exp(self._log_alpha) * next_log_pis)

    reward = next_time_steps.reward
    if self._reward_noise_variance > 0:
      reward_noise = tf.random.normal(
          tf.shape(reward),
          0.0,
          self._reward_noise_variance,
          seed=self._reward_seed_stream())
      reward += reward_noise

    td_targets = tf.stop_gradient(
        reward_scale_factor * reward +
        gamma * next_time_steps.discount * target_q_values)

    pred_input = (time_steps.observation, actions)
    pred_td_targets1, _ = self._critic_network_1(
        pred_input, time_steps.step_type, training=training)
    pred_td_targets2, _ = self._critic_network_2(
        pred_input, time_steps.step_type, training=training)
    critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
    critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
    critic_loss = critic_loss1 + critic_loss2

    if critic_loss.shape.rank > 1:
      # Sum over the time dimension.
      critic_loss = tf.reduce_sum(
          critic_loss, axis=range(1, critic_loss.shape.rank))

    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        sample_weight=weights,
        regularization_loss=(self._critic_network_1.losses +
                             self._critic_network_2.losses))
    critic_loss = agg_loss.total_loss

    self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                      pred_td_targets2)

    return critic_loss
def critic_loss(self,
                time_steps,
                expert_experience,
                actions,
                next_time_steps,
                future_time_steps,
                td_errors_loss_fn,
                gamma=1.0,
                reward_scale_factor=1.0,
                weights=None,
                training=False,
                loss_name='c',
                use_done=False,
                q_combinator='min'):
  """Computes the critic loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    expert_experience: An array of success examples.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    future_time_steps: A batch of future timesteps, used for n-step returns.
    td_errors_loss_fn: A function(td_targets, predictions) to compute
      elementwise (per-batch-entry) loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    training: Whether this loss is being used for training.
    loss_name: Which loss function to use. Use 'c' for RCE and 'q' for SQIL.
    use_done: Whether to use the terminal flag from the environment in the
      Bellman backup. We found that omitting it led to better results.
    q_combinator: Whether to combine the two Q-functions by taking the 'min'
      (as in TD3) or the 'max'.

  Returns:
    critic_loss: A scalar critic loss.
  """
  assert weights is None
  with tf.name_scope('critic_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    next_actions, _ = self._actions_and_log_probs(next_time_steps)
    target_input = (next_time_steps.observation, next_actions)
    target_q_values1, unused_network_state1 = self._target_critic_network_1(
        target_input, next_time_steps.step_type, training=False)
    target_q_values2, unused_network_state2 = self._target_critic_network_2(
        target_input, next_time_steps.step_type, training=False)
    if self._n_step is not None:
      future_actions, _ = self._actions_and_log_probs(future_time_steps)
      future_input = (future_time_steps.observation, future_actions)
      future_q_values1, _ = self._target_critic_network_1(
          future_input, future_time_steps.step_type, training=False)
      future_q_values2, _ = self._target_critic_network_2(
          future_input, future_time_steps.step_type, training=False)

      gamma_n = gamma**self._n_step  # Discount for n-step returns.
      target_q_values1 = (target_q_values1 + gamma_n * future_q_values1) / 2.0
      target_q_values2 = (target_q_values2 + gamma_n * future_q_values2) / 2.0

    if q_combinator == 'min':
      target_q_values = tf.minimum(target_q_values1, target_q_values2)
    else:
      assert q_combinator == 'max'
      target_q_values = tf.maximum(target_q_values1, target_q_values2)

    batch_size = time_steps.observation.shape[0]
    if loss_name == 'q':
      if use_done:
        td_targets = gamma * next_time_steps.discount * target_q_values
      else:
        td_targets = gamma * target_q_values
    else:
      assert loss_name == 'c'
      w = target_q_values / (1 - target_q_values)
      td_targets = gamma * w / (gamma * w + 1)
      if use_done:
        td_targets = next_time_steps.discount * td_targets
      weights = tf.concat([1 + gamma * w, (1 - gamma) * tf.ones(batch_size)],
                          axis=0)

    td_targets = tf.stop_gradient(td_targets)
    td_targets = tf.concat([td_targets, tf.ones(batch_size)], axis=0)

    # Note that the actions only depend on the observations. We create the
    # expert_time_steps object simply to make this look like a time step
    # object.
    expert_time_steps = time_steps._replace(observation=expert_experience)
    if self._use_behavior_policy:
      policy_state = self._train_policy.get_initial_state(batch_size)
      # Condition the behavior policy on the expert (success) observations.
      action_distribution = self._behavior_policy.distribution(
          expert_time_steps, policy_state=policy_state).action
      # Sample actions and log_pis from the transformed distribution.
      expert_actions = tf.nest.map_structure(lambda d: d.sample(),
                                             action_distribution)
    else:
      expert_actions, _ = self._actions_and_log_probs(expert_time_steps)

    observation = time_steps.observation
    pred_input = (tf.concat([observation, expert_experience], axis=0),
                  tf.concat([actions, expert_actions], axis=0))
    pred_td_targets1, _ = self._critic_network_1(
        pred_input, time_steps.step_type, training=training)
    pred_td_targets2, _ = self._critic_network_2(
        pred_input, time_steps.step_type, training=training)

    critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
    critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
    critic_loss = critic_loss1 + critic_loss2

    if critic_loss.shape.rank > 1:
      # Sum over the time dimension.
      critic_loss = tf.reduce_sum(
          critic_loss, axis=range(1, critic_loss.shape.rank))

    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        sample_weight=weights,
        regularization_loss=(self._critic_network_1.losses +
                             self._critic_network_2.losses))
    critic_loss = agg_loss.total_loss

    self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                      pred_td_targets2)

    return critic_loss
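# Toy check (hypothetical numbers) of the RCE ('c' loss) targets above: the
# classifier output C = target_q_values is converted to an odds ratio
# w = C / (1 - C), the bootstrapped label is gamma * w / (gamma * w + 1), and
# the success examples concatenated onto the batch get a label of 1.
import tensorflow as tf

gamma = 0.99
c = tf.constant([0.2, 0.8])              # classifier outputs at (s', a')
w = c / (1 - c)                          # [0.25, 4.0]
td_target = gamma * w / (gamma * w + 1)  # ~[0.198, 0.798]
# Each replay example is also reweighted by 1 + gamma * w, so transitions the
# classifier already rates as near-success contribute more to the loss.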
def critic_loss(self,
                time_steps,
                actions,
                next_time_steps,
                td_errors_loss_fn,
                gamma=1.0,
                weights=None,
                training=False,
                w_clipping=None,
                self_normalized=False,
                lambda_fix=False):
  """Computes the critic loss for C-learning training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    td_errors_loss_fn: A function(td_targets, predictions) to compute
      elementwise (per-batch-entry) loss.
    gamma: Discount for future rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    training: Whether this loss is being used for training.
    w_clipping: Maximum value used for clipping the weights. Use -1 to do no
      clipping; use None to use the recommended value of 1 / (1 - gamma).
    self_normalized: Whether to normalize the weights so that their average
      is 1. Empirically this usually hurts performance.
    lambda_fix: Whether to include the adjustment when using future
      positives. Empirically this has little effect.

  Returns:
    critic_loss: A scalar critic loss.
  """
  del weights
  if w_clipping is None:
    w_clipping = 1 / (1 - gamma)
  rfp = gin.query_parameter('goal_fn.relabel_future_prob')
  rnp = gin.query_parameter('goal_fn.relabel_next_prob')
  assert rfp + rnp == 0.5
  with tf.name_scope('critic_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    next_actions, _ = self._actions_and_log_probs(next_time_steps)
    target_input = (next_time_steps.observation, next_actions)
    target_q_values1, unused_network_state1 = self._target_critic_network_1(
        target_input, next_time_steps.step_type, training=False)
    target_q_values2, unused_network_state2 = self._target_critic_network_2(
        target_input, next_time_steps.step_type, training=False)
    target_q_values = tf.minimum(target_q_values1, target_q_values2)

    w = tf.stop_gradient(target_q_values / (1 - target_q_values))
    if w_clipping >= 0:
      w = tf.clip_by_value(w, 0, w_clipping)
    tf.debugging.assert_all_finite(w, 'Not all elements of w are finite')
    if self_normalized:
      w = w / tf.reduce_mean(w)

    batch_size = nest_utils.get_outer_shape(time_steps,
                                            self._time_step_spec)[0]
    half_batch = batch_size // 2
    float_batch_size = tf.cast(batch_size, float)
    num_next = tf.cast(tf.round(float_batch_size * rnp), tf.int32)
    num_future = tf.cast(tf.round(float_batch_size * rfp), tf.int32)
    if lambda_fix:
      lambda_coef = 2 * rnp
      weights = tf.concat([
          tf.fill((num_next,), (1 - gamma)),
          tf.fill((num_future,), 1.0),
          (1 + lambda_coef * gamma * w)[half_batch:]
      ], axis=0)
    else:
      weights = tf.concat([
          tf.fill((num_next,), (1 - gamma)),
          tf.fill((num_future,), 1.0),
          (1 + gamma * w)[half_batch:]
      ], axis=0)

    # Note that we assume that episodes never terminate. If they do, then we
    # need to include next_time_steps.discount in the (negative) TD target.
    # We exclude the termination here so that we can use termination to
    # indicate task success during evaluation. In the evaluation setting,
    # task success depends on the task, but we don't want the termination
    # here to depend on the task. Hence, we ignore it.
    if lambda_fix:
      lambda_coef = 2 * rnp
      y = lambda_coef * gamma * w / (1 + lambda_coef * gamma * w)
    else:
      y = gamma * w / (1 + gamma * w)
    td_targets = tf.stop_gradient(next_time_steps.reward +
                                  (1 - next_time_steps.reward) * y)
    if rfp > 0:
      td_targets = tf.concat([tf.ones(half_batch), td_targets[half_batch:]],
                             axis=0)

    observation = time_steps.observation
    pred_input = (observation, actions)
    pred_td_targets1, _ = self._critic_network_1(
        pred_input, time_steps.step_type, training=training)
    pred_td_targets2, _ = self._critic_network_2(
        pred_input, time_steps.step_type, training=training)

    critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
    critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
    critic_loss = critic_loss1 + critic_loss2

    if critic_loss.shape.rank > 1:
      # Sum over the time dimension.
      critic_loss = tf.reduce_sum(
          critic_loss, axis=range(1, critic_loss.shape.rank))

    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        sample_weight=weights,
        regularization_loss=(self._critic_network_1.losses +
                             self._critic_network_2.losses))
    critic_loss = agg_loss.total_loss

    self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                      pred_td_targets2, weights)

    return critic_loss
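# Toy check (hypothetical numbers) of the C-learning target above: with
# w = C / (1 - C) clipped to the recommended maximum of 1 / (1 - gamma), the
# bootstrapped label is y = gamma * w / (1 + gamma * w), and the TD target
# interpolates between y and 1 according to the binary reward.
import tensorflow as tf

gamma = 0.9
w_clip = 1 / (1 - gamma)                      # 10.0
c = tf.constant([0.5, 0.999])
w = tf.clip_by_value(c / (1 - c), 0, w_clip)  # [1.0, 10.0]; second is clipped
y = gamma * w / (1 + gamma * w)               # ~[0.474, 0.9]
reward = tf.constant([0.0, 1.0])
td_target = reward + (1 - reward) * y         # ~[0.474, 1.0]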
def critic_loss(self,
                time_steps: ts.TimeStep,
                actions: types.Tensor,
                next_time_steps: ts.TimeStep,
                td_errors_loss_fn: types.LossFn,
                gamma: types.Float = 1.0,
                reward_scale_factor: types.Float = 1.0,
                weights: Optional[types.Tensor] = None,
                training: bool = False) -> types.Tensor:
  """Computes the critic loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    td_errors_loss_fn: A function(td_targets, predictions) to compute
      elementwise (per-batch-entry) loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    training: Whether this loss is being used for training.

  Returns:
    critic_loss: A scalar critic loss.
  """
  with tf.name_scope('critic_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    next_actions, next_log_pis = self._actions_and_log_probs(next_time_steps)
    target_input = (next_time_steps.observation, next_actions)
    target_q_values1, unused_network_state1 = self._target_critic_network_1(
        target_input, next_time_steps.step_type, training=False)
    target_q_values2, unused_network_state2 = self._target_critic_network_2(
        target_input, next_time_steps.step_type, training=False)
    target_q_values = (
        tf.minimum(target_q_values1, target_q_values2) -
        tf.exp(self._log_alpha) * next_log_pis)

    td_targets = tf.stop_gradient(
        reward_scale_factor * next_time_steps.reward +
        gamma * next_time_steps.discount * target_q_values)

    pred_input = (time_steps.observation, actions)
    pred_td_targets1, _ = self._critic_network_1(
        pred_input, time_steps.step_type, training=training)
    pred_td_targets2, _ = self._critic_network_2(
        pred_input, time_steps.step_type, training=training)
    critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
    critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
    critic_loss = critic_loss1 + critic_loss2

    if critic_loss.shape.rank > 1:
      # Sum over the time dimension.
      critic_loss = tf.reduce_sum(
          critic_loss, axis=range(1, critic_loss.shape.rank))

    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        sample_weight=weights,
        regularization_loss=(self._critic_network_1.losses +
                             self._critic_network_2.losses))
    critic_loss = agg_loss.total_loss

    self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                      pred_td_targets2)

    return critic_loss
def test_aggregate_losses_with_time_dimension(self):
  per_example_loss = tf.constant([[4., 2., 3.], [1, 1, 1]])
  aggregated_losses = common.aggregate_losses(per_example_loss)
  expected_per_example_loss = (4 + 2 + 3 + 1 + 1 + 1) / 6
  self.assertAlmostEqual(self.evaluate(aggregated_losses.total_loss),
                         expected_per_example_loss)
def test_aggregate_losses_without_time_dimension_with_weights(self):
  per_example_loss = tf.constant([4., 2., 3.])
  sample_weights = tf.constant([1., 1., 0.])
  aggregated_losses = common.aggregate_losses(per_example_loss,
                                              sample_weights)
  self.assertAlmostEqual(self.evaluate(aggregated_losses.total_loss), 2)
def critic_no_entropy_loss(self,
                           time_steps,
                           actions,
                           next_time_steps,
                           td_errors_loss_fn,
                           gamma=1.0,
                           reward_scale_factor=1.0,
                           weights=None,
                           training=False):
  """Computes the critic loss for SAC training, without the entropy term.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    td_errors_loss_fn: A function(td_targets, predictions) to compute
      elementwise (per-batch-entry) loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    training: Whether this loss is being used for training.

  Returns:
    critic_loss: A scalar critic loss.
  """
  with tf.name_scope('critic_no_entropy_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    next_actions, _ = self._actions_and_log_probs(next_time_steps)
    target_input = (next_time_steps.observation, next_actions)
    target_q_values1, unused_network_state1 = (
        self._target_critic_network_no_entropy_1(
            target_input, next_time_steps.step_type, training=False))
    target_q_values2, unused_network_state2 = (
        self._target_critic_network_no_entropy_2(
            target_input, next_time_steps.step_type, training=False))
    # The entropy term has been removed from this target critic function.
    target_q_values = tf.minimum(target_q_values1, target_q_values2)

    td_targets = tf.stop_gradient(
        reward_scale_factor * next_time_steps.reward +
        gamma * next_time_steps.discount * target_q_values)

    pred_input = (time_steps.observation, actions)
    pred_td_targets1, _ = self._critic_network_no_entropy_1(
        pred_input, time_steps.step_type, training=training)
    pred_td_targets2, _ = self._critic_network_no_entropy_2(
        pred_input, time_steps.step_type, training=training)
    critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
    critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
    critic_loss = critic_loss1 + critic_loss2

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Sum over the time dimension.
      critic_loss = tf.reduce_sum(input_tensor=critic_loss, axis=1)

    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        sample_weight=weights,
        regularization_loss=(self._critic_network_no_entropy_1.losses +
                             self._critic_network_no_entropy_2.losses))
    critic_no_entropy_loss = agg_loss.total_loss

    self._critic_no_entropy_loss_debug_summaries(td_targets, pred_td_targets1,
                                                 pred_td_targets2)

    return critic_no_entropy_loss
def _loss(self,
          experience,
          td_errors_loss_fn=common.element_wise_huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None,
          training=False):
  transition = self._as_transition(experience)
  time_steps, policy_steps, next_time_steps = transition
  actions = policy_steps.action
  valid_mask = tf.cast(~time_steps.is_last(), tf.float32)

  with tf.name_scope('loss'):
    # q_values is already gathered by actions.
    q_values = self._compute_q_values(time_steps, actions, training=training)

    next_q_values = self._compute_next_all_q_values(
        next_time_steps, policy_steps.info)
    q_target_values = self._compute_next_all_q_values(
        time_steps, policy_steps.info)

    # This applies to any value of n_step_update and also in the RNN-DQN
    # case. In the RNN-DQN case, inputs and outputs contain a time dimension.
    # The standard (non-Munchausen) targets would be:
    #   td_targets = compute_td_targets(
    #       next_q_values,
    #       rewards=reward_scale_factor * next_time_steps.reward,
    #       discounts=gamma * next_time_steps.discount)
    td_targets = compute_munchausen_td_targets(
        next_q_values=next_q_values,
        q_target_values=q_target_values,
        actions=actions,
        rewards=reward_scale_factor * next_time_steps.reward,
        discounts=gamma * next_time_steps.discount,
        multi_dim_actions=self._action_spec.shape.rank > 0,
        alpha=self.alpha,
        entropy_tau=self.entropy_tau)

    td_error = valid_mask * (td_targets - q_values)
    td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    # Aggregate across the elements of the batch and add regularization loss.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N, where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution in the loss. Think about what would happen
    # as the number of boundary samples increases.
    agg_loss = common.aggregate_losses(
        per_example_loss=td_loss,
        sample_weight=weights,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    losses_dict = {
        'td_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss
    }

    common.summarize_scalar_dict(
        losses_dict, step=self.train_step_counter, name_scope='Losses/')

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._q_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      diff_q_values = q_values - next_q_values
      common.generate_tensor_summaries('td_error', td_error,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_loss', td_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('q_values', q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('next_q_values', next_q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('diff_q_values', diff_q_values,
                                       self.train_step_counter)

    return tf_agent.LossInfo(total_loss,
                             DqnLossInfo(td_loss=td_loss, td_error=td_error))
def _loss(self, experience, weights=None):
  """Computes loss for behavioral cloning.

  Args:
    experience: A `Trajectory` containing experience.
    weights: Optional scalar or element-wise (per-batch-entry) importance
      weights.

  Returns:
    loss: A `LossInfo` struct.

  Raises:
    ValueError: If the number of actions is greater than 1.
  """
  with tf.name_scope('loss'):
    if self._nested_actions:
      actions = experience.action
    else:
      actions = tf.nest.flatten(experience.action)[0]

    batch_size = (
        tf.compat.dimension_value(experience.step_type.shape[0]) or
        tf.shape(experience.step_type)[0])
    logits, _ = self._cloning_network(
        experience.observation,
        experience.step_type,
        training=True,
        network_state=self._cloning_network.get_initial_state(batch_size))

    error = self._loss_fn(logits, actions)
    error_dtype = tf.nest.flatten(error)[0].dtype
    boundary_weights = tf.cast(~experience.is_boundary(), error_dtype)
    error *= boundary_weights

    if nest_utils.is_batched_nested_tensors(
        experience.action, self.action_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      error = tf.reduce_sum(input_tensor=error, axis=1)

    # Average across the elements of the batch.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N, where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution in the loss. Think about what would happen
    # as the number of boundary samples increases.
    agg_loss = common.aggregate_losses(
        per_example_loss=error,
        sample_weight=weights,
        regularization_loss=self._cloning_network.losses)
    total_loss = agg_loss.total_loss

    dict_losses = {
        'loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss
    }

    common.summarize_scalar_dict(
        dict_losses, step=self.train_step_counter, name_scope='Losses/')

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._cloning_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      common.generate_tensor_summaries('errors', error,
                                       self.train_step_counter)

    return tf_agent.LossInfo(total_loss, BehavioralCloningLossInfo(loss=error))