def _apply_loss(self, aggregated_losses, variables_to_train, tape, optimizer):
  total_loss = aggregated_losses.total_loss
  tf.debugging.check_numerics(total_loss, "Loss is inf or nan")
  assert list(variables_to_train), "No variables in the agent's network."

  grads = tape.gradient(total_loss, variables_to_train)
  grads_and_vars = list(zip(grads, variables_to_train))
  if self._gradient_clipping is not None:
    grads_and_vars = eager_utils.clip_gradient_norms(
        grads_and_vars, self._gradient_clipping)

  if self.summarize_grads_and_vars:
    eager_utils.add_variables_summaries(grads_and_vars,
                                        self.train_step_counter)

  optimizer.apply_gradients(grads_and_vars)

  if self.summaries_enabled:
    dict_losses = {
        "loss": aggregated_losses.weighted,
        "reg_loss": aggregated_losses.regularization,
        "total_loss": total_loss,
    }
    common.summarize_scalar_dict(
        dict_losses, step=self.train_step_counter, name_scope="Losses/")
def _loss(self,
          experience,
          td_errors_loss_fn=common.element_wise_huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None,
          training=False):
  transition = self._as_transition(experience)
  time_steps, policy_steps, next_time_steps = transition
  actions = policy_steps.action
  valid_mask = tf.cast(~time_steps.is_last(), tf.float32)

  with tf.name_scope('loss'):
    # q_values is already gathered by actions.
    q_values = self._compute_q_values(time_steps, actions, training=training)

    next_q_values = self._compute_next_all_q_values(
        next_time_steps, policy_steps.info)
    q_target_values = self._compute_next_all_q_values(
        time_steps, policy_steps.info)

    # This applies to any value of n_step_update and also in the RNN-DQN
    # case. In the RNN-DQN case, inputs and outputs contain a time dimension.
    # td_targets = compute_td_targets(
    #     next_q_values,
    #     rewards=reward_scale_factor * next_time_steps.reward,
    #     discounts=gamma * next_time_steps.discount)
    td_targets = compute_munchausen_td_targets(
        next_q_values=next_q_values,
        q_target_values=q_target_values,
        actions=actions,
        rewards=reward_scale_factor * next_time_steps.reward,
        discounts=gamma * next_time_steps.discount,
        multi_dim_actions=self._action_spec.shape.rank > 0,
        alpha=self.alpha,
        entropy_tau=self.entropy_tau)

    td_error = valid_mask * (td_targets - q_values)
    td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    # Aggregate across the elements of the batch and add regularization loss.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N, where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution in the loss. Think about what would happen
    # as the number of boundary samples increases.
    agg_loss = common.aggregate_losses(
        per_example_loss=td_loss,
        sample_weight=weights,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    losses_dict = {
        'td_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss,
    }
    common.summarize_scalar_dict(
        losses_dict, step=self.train_step_counter, name_scope='Losses/')

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._q_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      diff_q_values = q_values - next_q_values
      common.generate_tensor_summaries('td_error', td_error,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_loss', td_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('q_values', q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('next_q_values', next_q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('diff_q_values', diff_q_values,
                                       self.train_step_counter)

    return tf_agent.LossInfo(
        total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
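# `compute_munchausen_td_targets` is referenced above but not defined in this
# section. The sketch below is a hedged illustration of how such a helper
# could compute the Munchausen-DQN target of Vieillard et al. (2020); it is
# not the implementation used above. It assumes `next_q_values` holds the
# target-network Q-values for all actions at s', `q_target_values` holds the
# target-network Q-values for all actions at s, and `clip_value_min` is an
# assumed hyperparameter bounding the log-policy bonus. It relies on the
# module's existing `tf` and `common` imports.
def compute_munchausen_td_targets_sketch(next_q_values,
                                         q_target_values,
                                         actions,
                                         rewards,
                                         discounts,
                                         multi_dim_actions,
                                         alpha,
                                         entropy_tau,
                                         clip_value_min=-1.0):
  """Sketch: r + alpha * clip(tau * log pi(a|s)) + gamma * d * soft V(s')."""
  # Soft value of s': sum_a' pi(a'|s') * (Q(s', a') - tau * log pi(a'|s')).
  next_log_policy = tf.nn.log_softmax(next_q_values / entropy_tau, axis=-1)
  next_policy = tf.nn.softmax(next_q_values / entropy_tau, axis=-1)
  soft_next_value = tf.reduce_sum(
      next_policy * (next_q_values - entropy_tau * next_log_policy), axis=-1)

  # Munchausen bonus: alpha * tau * log pi(a|s) at the taken action, clipped
  # from below for numerical stability.
  log_policy = tf.nn.log_softmax(q_target_values / entropy_tau, axis=-1)
  taken_action_log_policy = common.index_with_actions(
      log_policy, tf.cast(actions, dtype=tf.int32), multi_dim_actions)
  munchausen_reward = alpha * tf.clip_by_value(
      entropy_tau * taken_action_log_policy, clip_value_min, 0.0)

  return tf.stop_gradient(rewards + munchausen_reward +
                          discounts * soft_next_value)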
def _loss(self,
          experience,
          td_errors_loss_fn=common.element_wise_huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None,
          training=False):
  """Computes loss for DQN training.

  Args:
    experience: A batch of experience data in the form of a `Trajectory` or
      `Transition`. The structure of `experience` must match that of
      `self.collect_policy.step_spec`. If a `Trajectory`, all tensors in
      `experience` must be shaped `[B, T, ...]` where `T` must be equal to
      `self.train_sequence_length` if that property is not `None`.
    td_errors_loss_fn: A function(td_targets, predictions) to compute the
      element-wise loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights. The output td_loss will be scaled by these weights, and the
      final scalar loss is the mean of these values.
    training: Whether this loss is being used for training.

  Returns:
    loss: An instance of `DqnLossInfo`.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  transition = self._as_transition(experience)
  time_steps, policy_steps, next_time_steps = transition
  actions = policy_steps.action

  with tf.name_scope('loss'):
    q_values = self._compute_q_values(time_steps, actions, training=training)

    next_q_values = self._compute_next_q_values(
        next_time_steps, policy_steps.info)

    # This applies to any value of n_step_update and also in the RNN-DQN
    # case. In the RNN-DQN case, inputs and outputs contain a time dimension.
    td_targets = compute_td_targets(
        next_q_values,
        rewards=reward_scale_factor * next_time_steps.reward,
        discounts=gamma * next_time_steps.discount)

    valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
    td_error = valid_mask * (td_targets - q_values)
    td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    # Aggregate across the elements of the batch and add regularization loss.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N, where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution in the loss. Think about what would happen
    # as the number of boundary samples increases.
    agg_loss = common.aggregate_losses(
        per_example_loss=td_loss,
        sample_weight=weights,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    losses_dict = {
        'td_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss,
    }
    common.summarize_scalar_dict(
        losses_dict, step=self.train_step_counter, name_scope='Losses/')

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._q_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      diff_q_values = q_values - next_q_values
      common.generate_tensor_summaries('td_error', td_error,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_loss', td_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('q_values', q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('next_q_values', next_q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('diff_q_values', diff_q_values,
                                       self.train_step_counter)

    return tf_agent.LossInfo(
        total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
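# `compute_td_targets` is referenced in the DQN losses above but not defined
# in this section. For reference, the one-step Bellman target it produces
# amounts to the following minimal sketch (gradients are stopped so the
# target side of the TD error is treated as a constant):
def compute_td_targets(next_q_values, rewards, discounts):
  """One-step TD target: r + gamma * d * Q_target(s', a')."""
  return tf.stop_gradient(rewards + discounts * next_q_values)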
def _loss(self,
          experience,
          td_errors_loss_fn=common.element_wise_huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None,
          training=False):
  """Computes loss for DQN training.

  Args:
    experience: A batch of experience data in the form of a `Trajectory`. The
      structure of `experience` must match that of `self.policy.step_spec`.
      All tensors in `experience` must be shaped `[batch, time, ...]` where
      `time` must be equal to `self.train_sequence_length` if that property
      is not `None`.
    td_errors_loss_fn: A function(td_targets, predictions) to compute the
      element-wise loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights. The output td_loss will be scaled by these weights, and the
      final scalar loss is the mean of these values.
    training: Whether this loss is being used for training.

  Returns:
    loss: An instance of `DqnLossInfo`.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  # Check that `experience` includes two outer dimensions [B, T, ...]. This
  # method requires a time dimension to compute the loss properly.
  self._check_trajectory_dimensions(experience)

  squeeze_time_dim = not self._q_network.state_spec
  if self._n_step_update == 1:
    time_steps, policy_steps, next_time_steps = (
        trajectory.experience_to_transitions(experience, squeeze_time_dim))
    actions = policy_steps.action
  else:
    # To compute n-step returns, we need the first time steps, the first
    # actions, and the last time steps. Therefore we extract the first and
    # last transitions from our Trajectory.
    first_two_steps = tf.nest.map_structure(lambda x: x[:, :2], experience)
    last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:], experience)
    time_steps, policy_steps, _ = (
        trajectory.experience_to_transitions(first_two_steps,
                                             squeeze_time_dim))
    actions = policy_steps.action
    _, _, next_time_steps = (
        trajectory.experience_to_transitions(last_two_steps,
                                             squeeze_time_dim))

  with tf.name_scope('loss'):
    q_values = self._compute_q_values(time_steps, actions, training=training)

    next_q_values = self._compute_next_q_values(
        next_time_steps, policy_steps.info)

    if self._n_step_update == 1:
      # Special case for n = 1 to avoid a loss of performance.
      td_targets = compute_td_targets(
          next_q_values,
          rewards=reward_scale_factor * next_time_steps.reward,
          discounts=gamma * next_time_steps.discount)
    else:
      # When computing the discounted return, we need to throw out the last
      # time index of both reward and discount, which are filled with dummy
      # values to match the dimensions of the observation.
      rewards = reward_scale_factor * experience.reward[:, :-1]
      discounts = gamma * experience.discount[:, :-1]

      # TODO(b/134618876): Properly handle Trajectories that include episode
      # boundaries with nonzero discount.
      td_targets = value_ops.discounted_return(
          rewards=rewards,
          discounts=discounts,
          final_value=next_q_values,
          time_major=False,
          provide_all_returns=False)

    valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
    td_error = valid_mask * (td_targets - q_values)
    td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    # Aggregate across the elements of the batch and add regularization loss.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N, where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution in the loss. Think about what would happen
    # as the number of boundary samples increases.
    agg_loss = common.aggregate_losses(
        per_example_loss=td_loss,
        sample_weight=weights,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    losses_dict = {
        'td_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss,
    }
    common.summarize_scalar_dict(
        losses_dict, step=self.train_step_counter, name_scope='Losses/')

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._q_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      diff_q_values = q_values - next_q_values
      common.generate_tensor_summaries('td_error', td_error,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_loss', td_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('q_values', q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('next_q_values', next_q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('diff_q_values', diff_q_values,
                                       self.train_step_counter)

    return tf_agent.LossInfo(
        total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
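# A hedged illustration of the n-step branch above (all numbers are made up):
# with n_step_update = 2, rewards r_0, r_1 and discounts d_0, d_1, the call to
# value_ops.discounted_return(..., final_value=next_q_values,
# time_major=False, provide_all_returns=False) reduces to the bootstrapped
# return r_0 + d_0 * (r_1 + d_1 * max_a' Q_target(s_2, a')).
def _two_step_target_example():
  rewards = tf.constant([[1.0, 0.5]])       # [batch=1, n=2]
  discounts = tf.constant([[0.99, 0.99]])   # already multiplied by gamma
  final_value = tf.constant([2.0])          # stand-in for max_a' Q_target(s_2, a')
  # 1.0 + 0.99 * (0.5 + 0.99 * 2.0) == 3.4552
  return rewards[:, 0] + discounts[:, 0] * (
      rewards[:, 1] + discounts[:, 1] * final_value)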
def _loss(self,
          experience,
          td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None,
          training=False):
  """Computes critic loss for CategoricalDQN training.

  See Algorithm 1 and the discussion immediately preceding it on page 6 of
  "A Distributional Perspective on Reinforcement Learning"
    Bellemare et al., 2017
    https://arxiv.org/abs/1707.06887

  Args:
    experience: A batch of experience data in the form of a `Trajectory`. The
      structure of `experience` must match that of `self.policy.step_spec`.
      All tensors in `experience` must be shaped `[batch, time, ...]` where
      `time` must be equal to `self.required_experience_time_steps` if that
      property is not `None`.
    td_errors_loss_fn: A function(td_targets, predictions) to compute loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional weights used for importance sampling.
    training: Whether the loss is being used for training.

  Returns:
    critic_loss: A scalar critic loss.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  # Check that `experience` includes two outer dimensions [B, T, ...]. This
  # method requires a time dimension to compute the loss properly.
  self._check_trajectory_dimensions(experience)

  squeeze_time_dim = not self._q_network.state_spec
  if self._n_step_update == 1:
    time_steps, policy_steps, next_time_steps = (
        trajectory.experience_to_transitions(experience, squeeze_time_dim))
    actions = policy_steps.action
  else:
    # To compute n-step returns, we need the first time steps, the first
    # actions, and the last time steps. Therefore we extract the first and
    # last transitions from our Trajectory.
    first_two_steps = tf.nest.map_structure(lambda x: x[:, :2], experience)
    last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:], experience)
    time_steps, policy_steps, _ = (
        trajectory.experience_to_transitions(first_two_steps,
                                             squeeze_time_dim))
    actions = policy_steps.action
    _, _, next_time_steps = (
        trajectory.experience_to_transitions(last_two_steps,
                                             squeeze_time_dim))

  with tf.name_scope('critic_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    rank = nest_utils.get_outer_rank(time_steps.observation,
                                     self._time_step_spec.observation)

    # If inputs have a time dimension and the q_network is stateful,
    # combine the batch and time dimension.
    batch_squash = (None
                    if rank <= 1 or self._q_network.state_spec in ((), None)
                    else network_utils.BatchSquash(rank))

    network_observation = time_steps.observation
    if self._observation_and_action_constraint_splitter is not None:
      network_observation, _ = (
          self._observation_and_action_constraint_splitter(
              network_observation))

    # q_logits contains the Q-value logits for all actions.
    q_logits, _ = self._q_network(network_observation,
                                  time_steps.step_type,
                                  training=training)

    if batch_squash is not None:
      # Squash the outer dimensions into a single dimension so the loss
      # below can be computed over a flat batch. Required to support
      # temporal inputs, for example.
      q_logits = batch_squash.flatten(q_logits)
      actions = batch_squash.flatten(actions)
      next_time_steps = tf.nest.map_structure(batch_squash.flatten,
                                              next_time_steps)

    next_q_distribution = self._next_q_distribution(next_time_steps)

    if actions.shape.rank > 1:
      actions = tf.squeeze(actions, list(range(1, actions.shape.rank)))

    # Project the sample Bellman update \hat{T}Z_{\theta} onto the original
    # support of Z_{\theta} (see Figure 1 in paper).
    batch_size = q_logits.shape[0] or tf.shape(q_logits)[0]
    tiled_support = tf.tile(self._support, [batch_size])
    tiled_support = tf.reshape(tiled_support, [batch_size, self._num_atoms])

    if self._n_step_update == 1:
      discount = next_time_steps.discount
      if discount.shape.rank == 1:
        # We expect discount to have a shape of [batch_size], while
        # tiled_support will have a shape of [batch_size, num_atoms]. To
        # multiply these, we add a second dimension of 1 to the discount.
        discount = tf.expand_dims(discount, -1)
      next_value_term = tf.multiply(discount,
                                    tiled_support,
                                    name='next_value_term')

      reward = next_time_steps.reward
      if reward.shape.rank == 1:
        # See the explanation above.
        reward = tf.expand_dims(reward, -1)
      reward_term = tf.multiply(reward_scale_factor,
                                reward,
                                name='reward_term')

      target_support = tf.add(reward_term,
                              gamma * next_value_term,
                              name='target_support')
    else:
      # When computing the discounted return, we need to throw out the last
      # time index of both reward and discount, which are filled with dummy
      # values to match the dimensions of the observation.
      rewards = reward_scale_factor * experience.reward[:, :-1]
      discounts = gamma * experience.discount[:, :-1]

      # TODO(b/134618876): Properly handle Trajectories that include episode
      # boundaries with nonzero discount.
      discounted_returns = value_ops.discounted_return(
          rewards=rewards,
          discounts=discounts,
          final_value=tf.zeros([batch_size], dtype=discounts.dtype),
          time_major=False,
          provide_all_returns=False)

      # Convert discounted_returns from [batch_size] to [batch_size, 1].
      discounted_returns = tf.expand_dims(discounted_returns, -1)

      final_value_discount = tf.reduce_prod(discounts, axis=1)
      final_value_discount = tf.expand_dims(final_value_discount, -1)

      # Save the values of discounted_returns and final_value_discount in
      # order to check them in unit tests.
      self._discounted_returns = discounted_returns
      self._final_value_discount = final_value_discount

      target_support = tf.add(discounted_returns,
                              final_value_discount * tiled_support,
                              name='target_support')

    target_distribution = tf.stop_gradient(
        project_distribution(target_support, next_q_distribution,
                             self._support))

    # Obtain the current Q-value logits for the selected actions.
    indices = tf.range(batch_size)
    indices = tf.cast(indices, actions.dtype)
    reshaped_actions = tf.stack([indices, actions], axis=-1)
    chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)

    # Compute the cross-entropy loss between the logits. If inputs have
    # a time dimension, compute the sum over the time dimension before
    # computing the mean over the batch dimension.
    if batch_squash is not None:
      target_distribution = batch_squash.unflatten(target_distribution)
      chosen_action_logits = batch_squash.unflatten(chosen_action_logits)
      critic_loss = tf.reduce_sum(
          tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
              labels=target_distribution, logits=chosen_action_logits),
          axis=1)
    else:
      critic_loss = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
          labels=target_distribution, logits=chosen_action_logits)

    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    dict_losses = {
        'critic_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss,
    }
    common.summarize_scalar_dict(
        dict_losses, step=self.train_step_counter, name_scope='Losses/')

    if self._debug_summaries:
      distribution_errors = target_distribution - chosen_action_logits
      with tf.name_scope('distribution_errors'):
        common.generate_tensor_summaries(
            'distribution_errors', distribution_errors,
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'mean',
            tf.reduce_mean(distribution_errors),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'mean_abs',
            tf.reduce_mean(tf.abs(distribution_errors)),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'max',
            tf.reduce_max(distribution_errors),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'min',
            tf.reduce_min(distribution_errors),
            step=self.train_step_counter)
      with tf.name_scope('target_distribution'):
        common.generate_tensor_summaries(
            'target_distribution', target_distribution,
            step=self.train_step_counter)

    # TODO(b/127318640): Give appropriate values for td_loss and td_error for
    # prioritized replay.
    return tf_agent.LossInfo(
        total_loss, dqn_agent.DqnLossInfo(td_loss=(), td_error=()))
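# `project_distribution` is referenced above but not defined in this section.
# The sketch below is a hedged, simplified stand-in for the C51 projection
# step (Bellemare et al., 2017) that it performs: every atom of the shifted
# support is clipped to [v_min, v_max] and its probability mass is split
# linearly between the two nearest atoms of the fixed support. It assumes a
# uniformly spaced `target_support`; it is not TF-Agents' implementation.
def project_distribution_sketch(supports, weights, target_support):
  """Projects (supports, weights) onto the fixed `target_support`.

  Args:
    supports: [batch_size, num_atoms] shifted support, e.g. r + gamma * d * z.
    weights: [batch_size, num_atoms] probabilities on `supports`.
    target_support: [num_atoms] fixed support z_0 < ... < z_{N-1}.

  Returns:
    [batch_size, num_atoms] projected probabilities on `target_support`.
  """
  v_min, v_max = target_support[0], target_support[-1]
  num_atoms = tf.shape(target_support)[0]
  delta_z = (v_max - v_min) / tf.cast(num_atoms - 1, tf.float32)
  clipped = tf.clip_by_value(supports, v_min, v_max)
  # quotient[b, i, j] in [0, 1] is the share of weights[b, j] assigned to
  # target atom i, decreasing linearly with the distance between the atoms.
  distance = tf.abs(clipped[:, None, :] - target_support[None, :, None])
  quotient = tf.clip_by_value(1.0 - distance / delta_z, 0.0, 1.0)
  return tf.reduce_sum(quotient * weights[:, None, :], axis=-1)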
def _loss(self, experience, weights=None):
  """Computes loss for behavioral cloning.

  Args:
    experience: A `Trajectory` containing experience.
    weights: Optional scalar or element-wise (per-batch-entry) importance
      weights.

  Returns:
    loss: A `LossInfo` struct.

  Raises:
    ValueError: If the number of actions is greater than 1.
  """
  with tf.name_scope('loss'):
    if self._nested_actions:
      actions = experience.action
    else:
      actions = tf.nest.flatten(experience.action)[0]

    batch_size = (
        tf.compat.dimension_value(experience.step_type.shape[0]) or
        tf.shape(experience.step_type)[0])
    logits, _ = self._cloning_network(
        experience.observation,
        experience.step_type,
        training=True,
        network_state=self._cloning_network.get_initial_state(batch_size))

    error = self._loss_fn(logits, actions)
    error_dtype = tf.nest.flatten(error)[0].dtype
    boundary_weights = tf.cast(~experience.is_boundary(), error_dtype)
    error *= boundary_weights

    if nest_utils.is_batched_nested_tensors(
        experience.action, self.action_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      error = tf.reduce_sum(input_tensor=error, axis=1)

    # Average across the elements of the batch.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N, where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution in the loss. Think about what would happen
    # as the number of boundary samples increases.
    agg_loss = common.aggregate_losses(
        per_example_loss=error,
        sample_weight=weights,
        regularization_loss=self._cloning_network.losses)
    total_loss = agg_loss.total_loss

    dict_losses = {
        'loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss,
    }
    common.summarize_scalar_dict(
        dict_losses, step=self.train_step_counter, name_scope='Losses/')

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._cloning_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      common.generate_tensor_summaries('errors', error,
                                       self.train_step_counter)

    return tf_agent.LossInfo(total_loss, BehavioralCloningLossInfo(loss=error))
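# `self._loss_fn` above is supplied at construction time and is not shown in
# this section. For a discrete-action cloning network that emits logits, one
# plausible per-example choice is sparse softmax cross-entropy against the
# demonstrated actions; this is a hedged example, not necessarily the loss
# used here.
def example_cloning_loss_fn(logits, actions):
  """Per-example cross-entropy between network logits and demonstrated actions."""
  return tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=tf.cast(actions, tf.int32), logits=logits)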
def total_loss(self,
               experience: traj.Trajectory,
               returns: types.Tensor,
               weights: types.Tensor,
               training: bool = False) -> tf_agent.LossInfo:
  # Ensure we see at least one full episode.
  time_steps = ts.TimeStep(experience.step_type,
                           tf.zeros_like(experience.reward),
                           tf.zeros_like(experience.discount),
                           experience.observation)
  is_last = experience.is_last()
  num_episodes = tf.reduce_sum(tf.cast(is_last, tf.float32))
  tf.debugging.assert_greater(
      num_episodes,
      0.0,
      message='No complete episode found. REINFORCE requires full episodes '
      'to compute losses.')

  # Mask out partial episodes at the end of each batch of time_steps.
  # NOTE: We use is_last rather than is_boundary because the last transition
  # is the transition with the last valid reward. In other words, the
  # rewards on the boundary transitions are not valid. Since REINFORCE
  # computes a loss w.r.t. the returns (and does not bootstrap), keeping the
  # boundary transitions is irrelevant.
  valid_mask = tf.cast(experience.is_last(), dtype=tf.float32)
  valid_mask = tf.math.cumsum(valid_mask, axis=1, reverse=True)
  valid_mask = tf.cast(valid_mask > 0, dtype=tf.float32)
  if weights is not None:
    weights *= valid_mask
  else:
    weights = valid_mask

  advantages = returns
  value_preds = None

  if self._baseline:
    value_preds, _ = self._value_network(time_steps.observation,
                                         time_steps.step_type,
                                         training=True)
    if self._debug_summaries:
      tf.compat.v2.summary.histogram(
          name='value_preds', data=value_preds, step=self.train_step_counter)
    advantages = self._advantage_fn(returns, value_preds)

  if self._debug_summaries:
    tf.compat.v2.summary.histogram(
        name='advantages', data=advantages, step=self.train_step_counter)

  # TODO(b/126592060): replace with tensor normalizer.
  if self._normalize_returns:
    advantages = _standard_normalize(advantages, axes=(0, 1))
    if self._debug_summaries:
      tf.compat.v2.summary.histogram(
          name='normalized_%s' %
          ('advantages' if self._baseline else 'returns'),
          data=advantages,
          step=self.train_step_counter)

  nest_utils.assert_same_structure(time_steps, self.time_step_spec)
  policy_state = _get_initial_policy_state(self.collect_policy, time_steps)
  actions_distribution = self.collect_policy.distribution(
      time_steps, policy_state=policy_state).action

  policy_gradient_loss = self.policy_gradient_loss(
      actions_distribution,
      experience.action,
      experience.is_boundary(),
      advantages,
      num_episodes,
      weights,
  )

  entropy_regularization_loss = self.entropy_regularization_loss(
      actions_distribution, weights)

  network_regularization_loss = tf.nn.scale_regularization_loss(
      self._actor_network.losses)

  total_loss = (policy_gradient_loss + network_regularization_loss +
                entropy_regularization_loss)

  losses_dict = {
      'policy_gradient_loss': policy_gradient_loss,
      'policy_network_regularization_loss': network_regularization_loss,
      'entropy_regularization_loss': entropy_regularization_loss,
      'value_estimation_loss': 0.0,
      'value_network_regularization_loss': 0.0,
  }

  value_estimation_loss = None
  if self._baseline:
    value_estimation_loss = self.value_estimation_loss(
        value_preds, returns, num_episodes, weights)
    value_network_regularization_loss = tf.nn.scale_regularization_loss(
        self._value_network.losses)
    total_loss += value_estimation_loss + value_network_regularization_loss
    losses_dict['value_estimation_loss'] = value_estimation_loss
    losses_dict['value_network_regularization_loss'] = (
        value_network_regularization_loss)

  loss_info_extra = ReinforceAgentLossInfo(**losses_dict)

  losses_dict['total_loss'] = total_loss  # Total loss not in loss_info_extra.

  common.summarize_scalar_dict(
      losses_dict, self.train_step_counter, name_scope='Losses/')

  return tf_agent.LossInfo(total_loss, loss_info_extra)
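# `_standard_normalize` is referenced above but not defined in this section.
# It standardizes the advantages (or returns) along the given axes; a minimal
# sketch, assuming a small epsilon for numerical stability:
def _standard_normalize_sketch(values, axes=(0,), epsilon=1e-8):
  """Standardizes `values` to zero mean and unit variance along `axes`."""
  mean, variance = tf.nn.moments(x=values, axes=list(axes), keepdims=True)
  return (values - mean) / (tf.sqrt(variance) + epsilon)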
def _loss_h(self,
            experience,
            td_errors_loss_fn=common.element_wise_huber_loss,
            gamma=1.0,
            reward_scale_factor=1.0,
            weights=None,
            training=False):
  transition = self._as_transition(experience)
  time_steps, policy_steps, next_time_steps = transition
  actions = policy_steps.action
  valid_mask = tf.cast(~time_steps.is_last(), tf.float32)

  with tf.name_scope('loss'):
    # h_values is already gathered by actions.
    h_values = self._compute_h_values(time_steps, actions, training=training)

    multi_dim_actions = self._action_spec.shape.rank > 0

    next_q_all_values = self._compute_next_all_q_values(
        next_time_steps, policy_steps.info)
    next_h_all_values = self._compute_next_all_h_values(
        next_time_steps, policy_steps.info)
    next_h_actions = tf.argmax(next_h_all_values, axis=1)

    # next_h_values here is used only for logging.
    next_h_values = self._compute_next_h_values(next_time_steps,
                                                policy_steps.info)

    # next_q_values refers to Q(r, s') in Eqs. (4), (5).
    next_q_values = common.index_with_actions(
        next_q_all_values, tf.cast(next_h_actions, dtype=tf.int32),
        multi_dim_actions)

    h_target_all_values = self._compute_next_all_h_values(
        time_steps, policy_steps.info)
    h_target_values = common.index_with_actions(
        h_target_all_values, tf.cast(actions, dtype=tf.int32),
        multi_dim_actions)

    td_targets = compute_momentum_td_targets(
        q_target_values=next_q_values,
        h_target_values=h_target_values,
        beta=self.beta())

    td_error = valid_mask * (td_targets - h_values)
    td_loss = valid_mask * td_errors_loss_fn(td_targets, h_values)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    # Aggregate across the elements of the batch and add regularization loss.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N, where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution in the loss. Think about what would happen
    # as the number of boundary samples increases.
    agg_loss = common.aggregate_losses(
        per_example_loss=td_loss,
        sample_weight=weights,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    losses_dict = {
        'td_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss,
    }
    common.summarize_scalar_dict(
        losses_dict, step=self.train_step_counter, name_scope='Losses/')

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._h_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      diff_h_values = h_values - next_h_values
      common.generate_tensor_summaries('td_error_h', td_error,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_loss_h', td_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('h_values', h_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('next_h_values', next_h_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('diff_h_values', diff_h_values,
                                       self.train_step_counter)

    return tf_agent.LossInfo(
        total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
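# `compute_momentum_td_targets` is referenced above but not defined in this
# section. A plausible form, consistent with a momentum-style update that
# blends the Q-network target with the previous H-network estimate via the
# coefficient beta, is the convex combination below. This is an assumption
# for illustration only, not necessarily the authors' Eqs. (4), (5).
def compute_momentum_td_targets_sketch(q_target_values, h_target_values, beta):
  """Sketch: blend Q and H targets with momentum coefficient beta."""
  return tf.stop_gradient(
      beta * q_target_values + (1.0 - beta) * h_target_values)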
def _loss(self,
          net,
          individual_iql_time_step,
          individual_iql_next_time_step,
          time_steps,
          actions,
          next_time_steps,
          i,
          t,
          td_errors_loss_fn,
          gamma=1.0,
          weights=None,
          training=False):
  with tf.name_scope('loss'):
    individual_target = tf.reshape(
        net._compute_next_q_values(next_time_steps, index=i, time=t), [-1, 1])
    individual_main = tf.reshape(
        net._compute_q_values(time_steps, actions, index=i, time=t,
                              training=True), [-1, 1])

    reward = tf.reshape(individual_iql_next_time_step.reward, [-1, 1])
    discount = tf.reshape(individual_iql_next_time_step.discount, [-1, 1])
    td_targets = tf.stop_gradient(
        reward + gamma * discount * individual_target)

    valid_mask = tf.reshape(
        tf.cast(~individual_iql_time_step.is_last(), tf.float32), [-1, 1])
    td_error = valid_mask * (td_targets - individual_main)
    td_loss = valid_mask * tf.compat.v1.losses.absolute_difference(
        td_targets,
        individual_main,
        reduction=tf.compat.v1.losses.Reduction.NONE)
    # td_loss = valid_mask * td_errors_loss_fn(td_targets, q_total)

    if nest_utils.is_batched_nested_tensors(
        individual_iql_time_step, net.agent.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    # Aggregate across the elements of the batch and add regularization loss.
    agg_loss = common.aggregate_losses(
        per_example_loss=td_loss, sample_weight=weights)
    total_loss = agg_loss.total_loss

    losses_dict = {
        'td_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss,
    }
    common.summarize_scalar_dict(
        losses_dict, step=self.train_step_counter, name_scope='Losses/')

    if self.summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in net.agent.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self.debug_summaries:
      diff_q_values = individual_main - individual_target
      common.generate_tensor_summaries('td_error', td_error,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_loss', td_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('q_total', individual_main,
                                       self.train_step_counter)
      common.generate_tensor_summaries('target_q_total', individual_target,
                                       self.train_step_counter)
      common.generate_tensor_summaries('diff_q_values', diff_q_values,
                                       self.train_step_counter)

    return tf_agent.LossInfo(
        total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
def _loss(self,
          time_steps,
          policy_steps,
          next_time_steps,
          agents,
          nameDict,
          networkDict,
          td_errors_loss_fn,
          gamma=1.0,
          weights=None,
          training=False):
  with tf.name_scope('loss'):
    total_agents_target = []
    total_agents_main = []
    for i, flexAgent in enumerate(agents):
      for node in nameDict:
        target = None
        for type, names in nameDict[node].items():
          if flexAgent.id in names:
            target = []
            main = []
            for net in networkDict[node][type]:
              action_index = -1
              for t in range(24):
                action_index += 1
                actions = tf.gather(
                    policy_steps.action[i], indices=action_index, axis=-1)
                individual_target = net._compute_next_q_values(
                    next_time_steps, index=i, time=t)
                individual_main = net._compute_q_values(
                    time_steps, actions, index=i, time=t, training=True)
                target.append(tf.reshape(individual_target, [-1, 1]))
                main.append(tf.reshape(individual_main, [-1, 1]))
            break
        if target is not None:
          break
      total_agents_target.append(tf.concat(target, -1))
      total_agents_main.append(tf.concat(main, -1))
    total_agents_target = tf.concat(total_agents_target, -1)
    total_agents_main = tf.concat(total_agents_main, -1)

    q_total, _ = self.QMIXNet(
        total_agents_main, time_steps.observation, training=training)
    q_total = tf.squeeze(q_total)
    target_q_total, _ = self.TargetQMIXNet(
        total_agents_target, next_time_steps.observation, training=False)
    target_q_total = tf.squeeze(target_q_total)

    # Using the mean reward over all the agents.
    mean_reward = tf.reduce_mean(next_time_steps.reward, axis=1)
    td_targets = tf.stop_gradient(
        mean_reward + gamma * next_time_steps.discount * target_q_total)

    valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
    td_error = valid_mask * (td_targets - q_total)
    td_loss = valid_mask * tf.compat.v1.losses.absolute_difference(
        td_targets, q_total, reduction=tf.compat.v1.losses.Reduction.NONE)
    # td_loss = valid_mask * td_errors_loss_fn(td_targets, q_total)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    # Aggregate across the elements of the batch and add regularization loss.
    agg_loss = common.aggregate_losses(
        per_example_loss=td_loss, sample_weight=weights)
    total_loss = agg_loss.total_loss

    if self.summary_writer is not None:
      with self.summary_writer.as_default():
        tf.summary.scalar('loss', total_loss, step=self.train_step_counter)

    losses_dict = {
        'td_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss,
    }
    common.summarize_scalar_dict(
        losses_dict, step=self.train_step_counter, name_scope='Losses/')

    if self.summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self.QMIXNet.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self.debug_summaries:
      diff_q_values = q_total - target_q_total
      common.generate_tensor_summaries('td_error', td_error,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_loss', td_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('q_total', q_total,
                                       self.train_step_counter)
      common.generate_tensor_summaries('target_q_total', target_q_total,
                                       self.train_step_counter)
      common.generate_tensor_summaries('diff_q_values', diff_q_values,
                                       self.train_step_counter)

    return tf_agent.LossInfo(
        total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))