def _train(self, experience, weights, episode_data=None, augmented_obs=None,
           augmented_next_obs=None):
  """Returns a train op to update the agent's networks.

  This method trains with the provided batched experience.

  Args:
    experience: A time-stacked trajectory object. If augmentations > 1 then
      a tuple of the form:
      ```
      (trajectory, [augmentation_1, ... , augmentation_{K-1}])
      ```
      is expected.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    episode_data: Tuple of (episode, episode, metric) for contrastive loss.
    augmented_obs: List of length num_augmentations - 1 of random crops of
      the trajectory's observation.
    augmented_next_obs: List of length num_augmentations - 1 of random crops
      of the trajectory's next_observation.

  Returns:
    A train_op.

  Raises:
    ValueError: If optimizers are None and no default value was provided to
      the constructor.
  """
  squeeze_time_dim = not self._critic_network_1.state_spec
  time_steps, policy_steps, next_time_steps = (
      trajectory.experience_to_transitions(experience, squeeze_time_dim))
  actions = policy_steps.action

  trainable_critic_variables = (
      self._critic_network_1.trainable_variables +
      self._critic_network_2.trainable_variables)

  with tf.GradientTape(watch_accessed_variables=False) as tape:
    assert trainable_critic_variables, ('No trainable critic variables to '
                                        'optimize.')
    tape.watch(trainable_critic_variables)
    critic_loss = self._critic_loss_weight * self.critic_loss(
        time_steps,
        actions,
        next_time_steps,
        augmented_obs,
        augmented_next_obs,
        td_errors_loss_fn=self._td_errors_loss_fn,
        gamma=self._gamma,
        reward_scale_factor=self._reward_scale_factor,
        weights=weights,
        training=True)

  tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
  critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
  self._apply_gradients(critic_grads, trainable_critic_variables,
                        self._critic_optimizer)

  total_loss = critic_loss
  actor_loss = tf.constant(0.0, tf.float32)
  alpha_loss = tf.constant(0.0, tf.float32)

  with tf.name_scope('Losses'):
    tf.compat.v2.summary.scalar(
        name='critic_loss', data=critic_loss, step=self.train_step_counter)

  # Only perform actor and alpha updates periodically.
  if self.train_step_counter % self._actor_update_frequency == 0:
    trainable_actor_variables = self._actor_network.trainable_variables
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert trainable_actor_variables, ('No trainable actor variables to '
                                         'optimize.')
      tape.watch(trainable_actor_variables)
      actor_loss = self._actor_loss_weight * self.actor_loss(
          time_steps, weights=weights)
    tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
    actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
    self._apply_gradients(actor_grads, trainable_actor_variables,
                          self._actor_optimizer)

    alpha_variable = [self._log_alpha]
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      assert alpha_variable, 'No alpha variable to optimize.'
      tape.watch(alpha_variable)
      alpha_loss = self._alpha_loss_weight * self.alpha_loss(
          time_steps, weights=weights)
    tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.')
    alpha_grads = tape.gradient(alpha_loss, alpha_variable)
    self._apply_gradients(alpha_grads, alpha_variable, self._alpha_optimizer)

    with tf.name_scope('Losses'):
      tf.compat.v2.summary.scalar(
          name='actor_loss', data=actor_loss, step=self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='alpha_loss', data=alpha_loss, step=self.train_step_counter)

    total_loss = critic_loss + actor_loss + alpha_loss

  # Contrastive loss for PSEs.
  contrastive_loss = 0.0
  if self._contrastive_loss_weight > 0:
    contrastive_vars = self._actor_network.encoder_variables
    with tf.GradientTape(
        watch_accessed_variables=True, persistent=True) as tape:
      contrastive_loss = (
          self._contrastive_loss_weight *
          self.contrastive_metric_loss(episode_data))
    total_loss = total_loss + contrastive_loss
    tf.debugging.check_numerics(contrastive_loss,
                                'Contrastive loss is inf or nan.')
    contrastive_grads = tape.gradient(contrastive_loss, contrastive_vars)
    self._apply_gradients(contrastive_grads, contrastive_vars,
                          self._contrastive_optimizer)
    del tape

  self.train_step_counter.assign_add(1)
  self._update_target()

  # NOTE: Consider keeping track of previous actor/alpha loss.
  extra = SacContrastiveLossInfo(
      critic_loss=critic_loss,
      actor_loss=actor_loss,
      alpha_loss=alpha_loss,
      contrastive_loss=contrastive_loss)

  return tf_agent.LossInfo(loss=total_loss, extra=extra)
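# A minimal, standalone sketch (not part of the agent above) of how the
# augmented inputs described in the docstring might be produced as random
# crops. `make_augmented_inputs`, `num_augmentations`, and `crop_size` are
# hypothetical names; the actual data pipeline that builds `augmented_obs`
# may differ.
import tensorflow as tf


def make_augmented_inputs(observation, num_augmentations=2, crop_size=84):
  """Returns num_augmentations - 1 random crops of a [B, H, W, C] batch."""
  batch, _, _, channels = observation.shape
  return [
      tf.image.random_crop(observation, [batch, crop_size, crop_size, channels])
      for _ in range(num_augmentations - 1)
  ]


# Example usage with an assumed 100x100x3 observation batch.
obs = tf.random.uniform([32, 100, 100, 3])
augmented_obs = make_augmented_inputs(obs)  # list of length num_augmentations - 1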
def _loss(self, experience, weights=None):
  """Computes loss for behavioral cloning.

  Args:
    experience: A `Trajectory` containing experience.
    weights: Optional scalar or element-wise (per-batch-entry) importance
      weights.

  Returns:
    loss: A `LossInfo` struct.

  Raises:
    ValueError: If the number of actions is greater than 1.
  """
  with tf.name_scope('loss'):
    if self._nested_actions:
      actions = experience.action
    else:
      actions = tf.nest.flatten(experience.action)[0]

    logits, _ = self._cloning_network(
        experience.observation, experience.step_type, training=True)

    error = self._loss_fn(logits, actions)
    error_dtype = tf.nest.flatten(error)[0].dtype
    boundary_weights = tf.cast(~experience.is_boundary(), error_dtype)
    error *= boundary_weights

    if nest_utils.is_batched_nested_tensors(
        experience.action, self.action_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      error = tf.reduce_sum(input_tensor=error, axis=1)

    # Average across the elements of the batch.
    # Note: We use an element wise loss above to ensure each element is
    # always weighted by 1/N where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K
    # where K is the actual number of non-zero weight would artificially
    # increase their contribution in the loss. Think about what would happen
    # as the number of boundary samples increases.
    if weights is not None:
      error *= weights
    loss = tf.reduce_mean(input_tensor=error)

    with tf.name_scope('Losses/'):
      tf.compat.v2.summary.scalar(
          name='loss', data=loss, step=self.train_step_counter)

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._cloning_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      common.generate_tensor_summaries('errors', error,
                                       self.train_step_counter)

    return tf_agent.LossInfo(loss, BehavioralCloningLossInfo(loss=error))
def total_loss(self, time_steps, actions, returns, weights):
  # Ensure we see at least one full episode.
  is_last = time_steps.is_last()
  num_episodes = tf.reduce_sum(tf.cast(is_last, tf.float32))
  tf.debugging.assert_greater(
      num_episodes,
      0.0,
      message='No complete episode found. REINFORCE requires full episodes '
      'to compute losses.')

  # Mask out partial episodes at the end of each batch of time_steps.
  valid_mask = tf.cast(is_last, dtype=tf.float32)
  valid_mask = tf.math.cumsum(valid_mask, axis=1, reverse=True)
  valid_mask = tf.cast(valid_mask > 0, dtype=tf.float32)
  if weights is not None:
    weights *= valid_mask
  else:
    weights = valid_mask

  advantages = returns
  if self._baseline:
    value_preds, _ = self._value_network(time_steps.observation,
                                         time_steps.step_type)
    advantages = returns - value_preds
    if self._debug_summaries:
      tf.compat.v2.summary.histogram(
          name='value_preds', data=value_preds, step=self.train_step_counter)
      tf.compat.v2.summary.histogram(
          name='advantages', data=advantages, step=self.train_step_counter)

  # TODO(b/126592060): replace with tensor normalizer.
  if self._normalize_returns:
    advantages = _standard_normalize(advantages, axes=(0, 1))
    if self._debug_summaries:
      tf.compat.v2.summary.histogram(
          name='normalized_%s' %
          ('advantages' if self._baseline else 'returns'),
          data=advantages,
          step=self.train_step_counter)

  tf.nest.assert_same_structure(time_steps, self.time_step_spec)
  policy_state = _get_initial_policy_state(self.collect_policy, time_steps)
  actions_distribution = self.collect_policy.distribution(
      time_steps, policy_state=policy_state).action

  policy_gradient_loss = self.policy_gradient_loss(
      actions_distribution, actions, is_last, advantages, num_episodes,
      weights)
  entropy_regularization_loss = self.entropy_regularization_loss(
      actions_distribution, weights)

  total_loss = policy_gradient_loss + entropy_regularization_loss

  if self._baseline:
    value_estimation_loss = self.value_estimation_loss(
        value_preds, returns, num_episodes, weights)
    total_loss += value_estimation_loss

  with tf.name_scope('Losses/'):
    tf.compat.v2.summary.scalar(
        name='policy_gradient_loss',
        data=policy_gradient_loss,
        step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='entropy_regularization_loss',
        data=entropy_regularization_loss,
        step=self.train_step_counter)
    if self._baseline:
      tf.compat.v2.summary.scalar(
          name='value_estimation_loss',
          data=value_estimation_loss,
          step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='total_loss', data=total_loss, step=self.train_step_counter)

  return tf_agent.LossInfo(total_loss, ())
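# Standalone toy illustration (not part of the agent above) of the masking
# used in total_loss: a reverse cumulative sum over `is_last` zeroes out any
# timesteps that occur after the final episode boundary in each row.
import tensorflow as tf

is_last = tf.constant([[0., 0., 1., 0., 0.],
                       [0., 1., 0., 0., 1.]])
valid_mask = tf.cast(tf.math.cumsum(is_last, axis=1, reverse=True) > 0,
                     tf.float32)
# valid_mask == [[1., 1., 1., 0., 0.],
#                [1., 1., 1., 1., 1.]]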
def _train(self, experience, weights):
  """Returns a train op to update the agent's networks.

  This method trains with the provided batched experience.

  Args:
    experience: A time-stacked trajectory object.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    A train_op.

  Raises:
    ValueError: If optimizers are None and no default value was provided to
      the constructor.
  """
  transition = self._as_transition(experience)
  time_steps, policy_steps, next_time_steps = transition
  actions = policy_steps.action

  trainable_critic_variables = list(
      object_identity.ObjectIdentitySet(
          self._critic_network_1.trainable_variables +
          self._critic_network_2.trainable_variables))

  with tf.GradientTape(watch_accessed_variables=False) as tape:
    assert trainable_critic_variables, ('No trainable critic variables to '
                                        'optimize.')
    tape.watch(trainable_critic_variables)
    critic_loss = self._critic_loss_with_optional_entropy_term(
        time_steps,
        actions,
        next_time_steps,
        td_errors_loss_fn=self._td_errors_loss_fn,
        gamma=self._gamma,
        reward_scale_factor=self._reward_scale_factor,
        weights=weights,
        training=True)
    critic_loss *= self._critic_loss_weight

    cql_alpha = self._get_cql_alpha()
    cql_loss = self._cql_loss(time_steps, actions, training=True)

    if self._bc_debug_mode:
      cql_critic_loss = cql_loss * cql_alpha
    else:
      cql_critic_loss = critic_loss + (cql_loss * cql_alpha)

  tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
  tf.debugging.check_numerics(cql_loss, 'CQL loss is inf or nan.')
  critic_grads = tape.gradient(cql_critic_loss, trainable_critic_variables)
  self._apply_gradients(critic_grads, trainable_critic_variables,
                        self._critic_optimizer)

  trainable_actor_variables = self._actor_network.trainable_variables
  with tf.GradientTape(watch_accessed_variables=False) as tape:
    assert trainable_actor_variables, ('No trainable actor variables to '
                                       'optimize.')
    tape.watch(trainable_actor_variables)
    actor_loss = self._actor_loss_weight * self.actor_loss(
        time_steps, actions=actions, weights=weights)
  tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
  actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
  self._apply_gradients(actor_grads, trainable_actor_variables,
                        self._actor_optimizer)

  alpha_variable = [self._log_alpha]
  with tf.GradientTape(watch_accessed_variables=False) as tape:
    assert alpha_variable, 'No alpha variable to optimize.'
    tape.watch(alpha_variable)
    alpha_loss = self._alpha_loss_weight * self.alpha_loss(
        time_steps, weights=weights)
  tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.')
  alpha_grads = tape.gradient(alpha_loss, alpha_variable)
  self._apply_gradients(alpha_grads, alpha_variable, self._alpha_optimizer)

  # Based on the equation (24), which automates CQL alpha with the "budget"
  # parameter tau. CQL(H) is now CQL-Lagrange(H):
  # ```
  # min_Q max_{alpha >= 0} alpha * (log_sum_exp(Q(s, a')) - Q(s, a) - tau)
  # ```
  # If the expected difference in Q-values is less than tau, alpha
  # will adjust to be closer to 0. If the difference is higher than tau,
  # alpha is likely to take on high values and more aggressively penalize
  # Q-values.
  cql_alpha_loss = tf.constant(0.)
  if self._use_lagrange_cql_alpha:
    cql_alpha_variable = [self._log_cql_alpha]
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      tape.watch(cql_alpha_variable)
      cql_alpha_loss = -self._get_cql_alpha() * (cql_loss - self._cql_tau)
    tf.debugging.check_numerics(cql_alpha_loss,
                                'CQL alpha loss is inf or nan.')
    cql_alpha_gradients = tape.gradient(cql_alpha_loss, cql_alpha_variable)
    self._apply_gradients(cql_alpha_gradients, cql_alpha_variable,
                          self._cql_alpha_optimizer)

  with tf.name_scope('Losses'):
    tf.compat.v2.summary.scalar(
        name='critic_loss', data=critic_loss, step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='actor_loss', data=actor_loss, step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='alpha_loss', data=alpha_loss, step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='cql_loss', data=cql_loss, step=self.train_step_counter)
    if self._use_lagrange_cql_alpha:
      tf.compat.v2.summary.scalar(
          name='cql_alpha_loss',
          data=cql_alpha_loss,
          step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='cql_alpha', data=cql_alpha, step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='sac_alpha',
        data=tf.exp(self._log_alpha),
        step=self.train_step_counter)

  self.train_step_counter.assign_add(1)
  self._update_target()

  total_loss = cql_critic_loss + actor_loss + alpha_loss

  extra = CqlSacLossInfo(
      critic_loss=critic_loss,
      actor_loss=actor_loss,
      alpha_loss=alpha_loss,
      cql_loss=cql_loss,
      cql_alpha=cql_alpha,
      cql_alpha_loss=cql_alpha_loss)

  return tf_agent.LossInfo(loss=total_loss, extra=extra)
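# Standalone sketch (not the agent's implementation) of the CQL-Lagrange
# alpha objective referenced in the comment above: alpha is trained on
# -alpha * (cql_loss - tau), so it grows when cql_loss exceeds the budget tau
# and shrinks toward 0 otherwise. Variable and function names here are
# illustrative only, and a plain SGD optimizer is assumed.
import tensorflow as tf

log_cql_alpha = tf.Variable(0.0)  # cql_alpha = exp(log_cql_alpha) >= 0
cql_tau = 10.0
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)


def cql_alpha_update(cql_loss):
  with tf.GradientTape() as tape:
    cql_alpha = tf.exp(log_cql_alpha)
    cql_alpha_loss = -cql_alpha * (cql_loss - cql_tau)
  grads = tape.gradient(cql_alpha_loss, [log_cql_alpha])
  optimizer.apply_gradients(zip(grads, [log_cql_alpha]))
  return tf.exp(log_cql_alpha)


cql_alpha_update(tf.constant(20.0))  # cql_loss > tau: alpha increases
cql_alpha_update(tf.constant(2.0))   # cql_loss < tau: alpha decreases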
def loss(self, observations, actions, rewards, weights=None, training=False):
  """Computes loss for reward prediction training.

  Args:
    observations: A batch of observations.
    actions: A batch of actions.
    rewards: A batch of rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights. The output batch loss will be scaled by these weights, and
      the final scalar loss is the mean of these values.
    training: Whether the loss is being used for training.

  Returns:
    loss: A `LossInfo` containing the loss for the training step.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  with tf.name_scope('loss'):
    sample_weights = weights if weights is not None else 1
    if self._heteroscedastic:
      predictions, _ = self._reward_network(observations, training=training)
      predicted_values = predictions.q_value_logits
      predicted_log_variance = predictions.log_variance
      action_predicted_log_variance = common.index_with_actions(
          predicted_log_variance, tf.cast(actions, dtype=tf.int32))
      sample_weights = sample_weights * 0.5 * tf.exp(
          -action_predicted_log_variance)
      loss = 0.5 * tf.reduce_mean(action_predicted_log_variance)
      # loss = 1/(2 * var(x)) * (y - f(x))^2 + 1/2 * log var(x)
      # Kendall, Alex, and Yarin Gal. "What Uncertainties Do We Need in
      # Bayesian Deep Learning for Computer Vision?." Advances in Neural
      # Information Processing Systems. 2017.
      # https://arxiv.org/abs/1703.04977
    else:
      predicted_values, _ = self._reward_network(observations,
                                                 training=training)
      loss = tf.constant(0.0)

    action_predicted_values = common.index_with_actions(
        predicted_values, tf.cast(actions, dtype=tf.int32))

    # Apply Laplacian smoothing on the estimated rewards, if applicable.
    if self._laplacian_matrix is not None:
      smoothness_batched = tf.matmul(
          predicted_values,
          tf.matmul(self._laplacian_matrix, predicted_values,
                    transpose_b=True))
      loss += (self._laplacian_smoothing_weight * tf.reduce_mean(
          tf.linalg.tensor_diag_part(smoothness_batched) * sample_weights))

    loss += self._error_loss_fn(
        rewards,
        action_predicted_values,
        sample_weights,
        reduction=tf.compat.v1.losses.Reduction.MEAN)

  return tf_agent.LossInfo(loss, extra=())
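# Standalone sketch of the heteroscedastic loss cited above (Kendall & Gal,
# 2017): each squared error is weighted by 1/(2*var) and a 0.5*log(var)
# penalty is added, mirroring the `sample_weights` scaling and the extra
# `loss` term in the heteroscedastic branch. Tensor values are illustrative.
import tensorflow as tf

rewards = tf.constant([1.0, 0.0, 2.0])
predicted = tf.constant([0.5, 0.2, 1.5])
log_var = tf.constant([0.0, 1.0, -1.0])  # predicted log-variance per example

per_example = 0.5 * tf.exp(-log_var) * tf.square(rewards - predicted)
loss = tf.reduce_mean(per_example) + 0.5 * tf.reduce_mean(log_var)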
def get_epoch_loss(self, time_steps, actions, act_log_probs, returns,
                   normalized_advantages, action_distribution_parameters,
                   weights, train_step, debug_summaries):
  """Compute the loss and create optimization op for one training epoch.

  All tensors should have a single batch dimension.

  Args:
    time_steps: A minibatch of TimeStep tuples.
    actions: A minibatch of actions.
    act_log_probs: A minibatch of action probabilities (probability under
      the sampling policy).
    returns: A minibatch of per-timestep returns.
    normalized_advantages: A minibatch of normalized per-timestep
      advantages.
    action_distribution_parameters: Parameters of data-collecting action
      distribution. Needed for KL computation.
    weights: Optional scalar or element-wise (per-batch-entry) importance
      weights. Includes a mask for invalid timesteps.
    train_step: A train_step variable to increment for each train step.
      Typically the global_step.
    debug_summaries: True if debug summaries should be created.

  Returns:
    A tf_agent.LossInfo named tuple with the total_loss and all
    intermediate losses in the extra field contained in a PPOLossInfo named
    tuple.
  """
  # Evaluate the current policy on timesteps.

  # batch_size from time_steps
  batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
  policy_state = self._collect_policy.get_initial_state(batch_size)

  distribution_step = self._collect_policy.distribution(
      time_steps, policy_state)
  # TODO(eholly): Rename policy distributions to something clear and uniform.
  current_policy_distribution = distribution_step.action

  # Call all loss functions and add all loss values.
  value_estimation_loss = self.value_estimation_loss(time_steps, returns,
                                                     weights, debug_summaries)
  policy_gradient_loss = self.policy_gradient_loss(
      time_steps,
      actions,
      tf.stop_gradient(act_log_probs),
      tf.stop_gradient(normalized_advantages),
      current_policy_distribution,
      weights,
      debug_summaries=debug_summaries)

  if self._policy_l2_reg > 0.0 or self._value_function_l2_reg > 0.0:
    l2_regularization_loss = self.l2_regularization_loss(debug_summaries)
  else:
    l2_regularization_loss = tf.zeros_like(policy_gradient_loss)

  if self._entropy_regularization > 0.0:
    entropy_regularization_loss = self.entropy_regularization_loss(
        time_steps, current_policy_distribution, weights, debug_summaries)
  else:
    entropy_regularization_loss = tf.zeros_like(policy_gradient_loss)

  kl_penalty_loss = self.kl_penalty_loss(time_steps,
                                         action_distribution_parameters,
                                         current_policy_distribution,
                                         weights, debug_summaries)

  total_loss = (
      policy_gradient_loss + value_estimation_loss + l2_regularization_loss +
      entropy_regularization_loss + kl_penalty_loss)

  return tf_agent.LossInfo(
      total_loss,
      PPOLossInfo(
          policy_gradient_loss=policy_gradient_loss,
          value_estimation_loss=value_estimation_loss,
          l2_regularization_loss=l2_regularization_loss,
          entropy_regularization_loss=entropy_regularization_loss,
          kl_penalty_loss=kl_penalty_loss,
      ))
def _train(self, experience, weights):
  """Returns a train op to update the agent's networks.

  This method trains with the provided batched experience.

  Args:
    experience: A time-stacked trajectory object.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    A train_op.

  Raises:
    ValueError: If optimizers are None and no default value was provided to
      the constructor.
  """
  transition = self._as_transition(experience)
  time_steps, policy_steps, next_time_steps = transition
  actions = policy_steps.action

  trainable_critic_variables = list(
      object_identity.ObjectIdentitySet(
          self._critic_network_1.trainable_variables +
          self._critic_network_2.trainable_variables))

  with tf.GradientTape(watch_accessed_variables=False) as tape:
    assert trainable_critic_variables, ('No trainable critic variables to '
                                        'optimize.')
    tape.watch(trainable_critic_variables)
    critic_loss = self._critic_loss_weight * self.critic_loss(
        time_steps,
        actions,
        next_time_steps,
        td_errors_loss_fn=self._td_errors_loss_fn,
        gamma=self._gamma,
        reward_scale_factor=self._reward_scale_factor,
        weights=weights,
        training=True)

  tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
  critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
  self._apply_gradients(critic_grads, trainable_critic_variables,
                        self._critic_optimizer)

  trainable_actor_variables = self._actor_network.trainable_variables
  with tf.GradientTape(watch_accessed_variables=False) as tape:
    assert trainable_actor_variables, ('No trainable actor variables to '
                                       'optimize.')
    tape.watch(trainable_actor_variables)
    actor_loss = self._actor_loss_weight * self.actor_loss(
        time_steps, weights=weights)
  tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
  actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
  self._apply_gradients(actor_grads, trainable_actor_variables,
                        self._actor_optimizer)

  alpha_variable = [self._log_alpha]
  with tf.GradientTape(watch_accessed_variables=False) as tape:
    assert alpha_variable, 'No alpha variable to optimize.'
    tape.watch(alpha_variable)
    alpha_loss = self._alpha_loss_weight * self.alpha_loss(
        time_steps, weights=weights)
  tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.')
  alpha_grads = tape.gradient(alpha_loss, alpha_variable)
  self._apply_gradients(alpha_grads, alpha_variable, self._alpha_optimizer)

  with tf.name_scope('Losses'):
    tf.compat.v2.summary.scalar(
        name='critic_loss', data=critic_loss, step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='actor_loss', data=actor_loss, step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='alpha_loss', data=alpha_loss, step=self.train_step_counter)

  self.train_step_counter.assign_add(1)
  self._update_target()

  total_loss = critic_loss + actor_loss + alpha_loss

  extra = SacLossInfo(
      critic_loss=critic_loss, actor_loss=actor_loss, alpha_loss=alpha_loss)

  return tf_agent.LossInfo(loss=total_loss, extra=extra)
def _train(self, experience, weights=None):
  # TODO(b/120034503): Move the conversion to transitions to the base class.
  time_steps, actions, next_time_steps = self._experience_to_transitions(
      experience)

  # TODO(kbanoop): Apply a loss mask or filter boundary transitions.
  critic_loss = self.critic_loss(time_steps, actions, next_time_steps,
                                 weights=weights)

  actor_loss = self.actor_loss(time_steps, weights=weights)

  def clip_and_summarize_gradients(grads_and_vars):
    """Clips gradients, and summarizes gradients and variables."""
    if self._gradient_clipping is not None:
      grads_and_vars = eager_utils.clip_gradient_norms_fn(
          self._gradient_clipping)(grads_and_vars)

    if self._summarize_grads_and_vars:
      # TODO(kbanoop): Move gradient summaries to train_op after we switch
      # to eager train op, and move variable summaries to critic_loss.
      for grad, var in grads_and_vars:
        with tf.name_scope('Gradients/'):
          if grad is not None:
            tf.compat.v2.summary.histogram(
                name=grad.op.name, data=grad, step=self.train_step_counter)
        with tf.name_scope('Variables/'):
          if var is not None:
            tf.compat.v2.summary.histogram(
                name=var.op.name, data=var, step=self.train_step_counter)
    return grads_and_vars

  critic_train_op = eager_utils.create_train_op(
      critic_loss,
      self._critic_optimizer,
      global_step=self.train_step_counter,
      transform_grads_fn=clip_and_summarize_gradients,
      variables_to_train=self._critic_network_1.trainable_weights +
      self._critic_network_2.trainable_weights,
  )

  actor_train_op = eager_utils.create_train_op(
      actor_loss,
      self._actor_optimizer,
      global_step=None,
      transform_grads_fn=clip_and_summarize_gradients,
      variables_to_train=self._actor_network.trainable_weights,
  )

  with tf.control_dependencies([critic_train_op, actor_train_op]):
    update_targets_op = self._update_targets(self._target_update_tau,
                                             self._target_update_period)

  with tf.control_dependencies([update_targets_op]):
    total_loss = actor_loss + critic_loss

  # TODO(kbanoop): Compute per element TD loss and return in loss_info.
  return tf_agent.LossInfo(total_loss, Td3Info(actor_loss, critic_loss))
def _loss(self,
          experience,
          td_errors_loss_fn=common.element_wise_huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None):
  """Computes loss for DQN training.

  Args:
    experience: A batch of experience data in the form of a `Trajectory`.
      The structure of `experience` must match that of
      `self.policy.step_spec`. All tensors in `experience` must be shaped
      `[batch, time, ...]` where `time` must be equal to
      `self.train_sequence_length` if that property is not `None`.
    td_errors_loss_fn: A function(td_targets, predictions) to compute the
      element wise loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights. The output td_loss will be scaled by these weights, and the
      final scalar loss is the mean of these values.

  Returns:
    loss: An instance of `DqnLossInfo`.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  # Check that `experience` includes two outer dimensions [B, T, ...]. This
  # method requires a time dimension to compute the loss properly.
  self._check_trajectory_dimensions(experience)

  if self._n_step_update == 1:
    time_steps, actions, next_time_steps = self._experience_to_transitions(
        experience)
  else:
    # To compute n-step returns, we need the first time steps, the first
    # actions, and the last time steps. Therefore we extract the first and
    # last transitions from our Trajectory.
    first_two_steps = tf.nest.map_structure(lambda x: x[:, :2], experience)
    last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:], experience)
    time_steps, actions, _ = self._experience_to_transitions(first_two_steps)
    _, _, next_time_steps = self._experience_to_transitions(last_two_steps)

  with tf.name_scope('loss'):
    q_values = self._compute_q_values(time_steps, actions)

    next_q_values = self._compute_next_q_values(next_time_steps)

    if self._n_step_update == 1:
      # Special case for n = 1 to avoid a loss of performance.
      td_targets = compute_td_targets(
          next_q_values,
          rewards=reward_scale_factor * next_time_steps.reward,
          discounts=gamma * next_time_steps.discount)
    else:
      # When computing discounted return, we need to throw out the last
      # time index of both reward and discount, which are filled with dummy
      # values to match the dimensions of the observation.
      rewards = reward_scale_factor * experience.reward[:, :-1]
      discounts = gamma * experience.discount[:, :-1]

      # TODO(b/134618876): Properly handle Trajectories that include
      # episode boundaries with nonzero discount.

      td_targets = value_ops.discounted_return(
          rewards=rewards,
          discounts=discounts,
          final_value=next_q_values,
          time_major=False,
          provide_all_returns=False)

    valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
    td_error = valid_mask * (td_targets - q_values)

    td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    if weights is not None:
      td_loss *= weights

    # Average across the elements of the batch.
    # Note: We use an element wise loss above to ensure each element is
    # always weighted by 1/N where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K
    # where K is the actual number of non-zero weight would artificially
    # increase their contribution in the loss. Think about what would
    # happen as the number of boundary samples increases.
    loss = tf.reduce_mean(input_tensor=td_loss)

    # Add network loss (such as regularization loss).
    if self._q_network.losses:
      loss = loss + tf.reduce_mean(self._q_network.losses)

    with tf.name_scope('Losses/'):
      tf.compat.v2.summary.scalar(
          name='loss', data=loss, step=self.train_step_counter)

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._q_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      diff_q_values = q_values - next_q_values
      common.generate_tensor_summaries('td_error', td_error,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_loss', td_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('q_values', q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('next_q_values', next_q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('diff_q_values', diff_q_values,
                                       self.train_step_counter)

    return tf_agent.LossInfo(
        loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
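# Standalone sketch of the n-step target assembled above: the last time
# index of reward/discount is dropped and the remaining rewards are
# discounted back to t=0, bootstrapping from the next-state Q-value. The
# agent uses value_ops.discounted_return for this; the explicit loop below
# is only an illustration with made-up values.
import tensorflow as tf

rewards = tf.constant([[1.0, 1.0, 1.0]])    # [batch, n]
discounts = tf.constant([[0.9, 0.9, 0.9]])  # gamma already folded in
final_q = tf.constant([5.0])                # bootstrap value Q(s_{t+n}, a*)

td_target = final_q
for t in reversed(range(rewards.shape[1])):
  td_target = rewards[:, t] + discounts[:, t] * td_target
# td_target == 1 + 0.9 * (1 + 0.9 * (1 + 0.9 * 5)) = 6.355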
def testBaseLossInfo(self):
  loss_info = tf_agent.LossInfo(0.0, ())
  self.assertEqual(loss_info.loss, 0.0)
  self.assertIsInstance(loss_info, tf_agent.LossInfo)
def _train(self, experience, weights=None):
  """Updates the policy based on the data in `experience`.

  Note that `experience` should only contain data points that this agent has
  not previously seen. If `experience` comes from a replay buffer, this
  buffer should be cleared between each call to `train`.

  Args:
    experience: A batch of experience data in the form of a `Trajectory`.
    weights: Unused.

  Returns:
    A `LossInfo` containing the loss *before* the training step is taken.
    In most cases, if `weights` is provided, the entries of this tuple will
    have been calculated with the weights. Note that each Agent chooses its
    own method of applying weights.
  """
  del weights  # unused

  # If the experience comes from a replay buffer, the reward has shape:
  #     [batch_size, time_steps]
  # where `time_steps` is the number of driver steps executed in each
  # training loop.
  # We flatten the tensors below in order to reflect the effective batch
  # size.
  reward, _ = nest_utils.flatten_multi_batched_nested_tensors(
      experience.reward, self._time_step_spec.reward)
  action, _ = nest_utils.flatten_multi_batched_nested_tensors(
      experience.action, self._action_spec)
  observation, _ = nest_utils.flatten_multi_batched_nested_tensors(
      experience.observation, self._time_step_spec.observation)
  if self._observation_and_action_constraint_splitter is not None:
    observation, _ = self._observation_and_action_constraint_splitter(
        observation)
  observation = tf.reshape(observation, [-1, self._context_dim])
  observation = tf.cast(observation, self._dtype)
  reward = tf.cast(reward, self._dtype)

  for k in range(self._num_actions):
    diag_mask = tf.linalg.tensor_diag(
        tf.cast(tf.equal(action, k), self._dtype))
    observations_for_arm = tf.matmul(diag_mask, observation)
    rewards_for_arm = tf.matmul(diag_mask, tf.reshape(reward, [-1, 1]))

    num_samples_for_arm_current = tf.reduce_sum(diag_mask)
    tf.compat.v1.assign_add(self._num_samples_list[k],
                            num_samples_for_arm_current)
    num_samples_for_arm_total = self._num_samples_list[k].read_value()

    # Update the matrix A and b.
    # pylint: disable=cell-var-from-loop,g-long-lambda
    def update(cov_matrix, data_vector):
      return update_a_and_b_with_forgetting(cov_matrix, data_vector,
                                            rewards_for_arm,
                                            observations_for_arm, self._gamma,
                                            self._use_eigendecomp)

    a_new, b_new, eig_vals, eig_matrix = tf.cond(
        tf.squeeze(num_samples_for_arm_total) > 0,
        lambda: update(self._cov_matrix_list[k], self._data_vector_list[k]),
        lambda: (self._cov_matrix_list[k], self._data_vector_list[k],
                 self._eig_vals_list[k], self._eig_matrix_list[k]))

    tf.compat.v1.assign(self._cov_matrix_list[k], a_new)
    tf.compat.v1.assign(self._data_vector_list[k], b_new)
    tf.compat.v1.assign(self._eig_vals_list[k], eig_vals)
    tf.compat.v1.assign(self._eig_matrix_list[k], eig_matrix)

  loss = -1. * tf.reduce_sum(experience.reward)
  self.compute_summaries(loss)

  batch_size = tf.cast(
      tf.compat.dimension_value(tf.shape(reward)[0]), dtype=tf.int64)
  self._train_step_counter.assign_add(batch_size)

  return tf_agent.LossInfo(loss=loss, extra=())
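# Standalone sketch of the per-arm statistics maintained above, ignoring the
# forgetting factor and eigendecomposition handled by
# update_a_and_b_with_forgetting: for each arm the agent accumulates
# A += X^T X and b += X^T r over the observations and rewards routed to that
# arm. Shapes and values below are illustrative only.
import tensorflow as tf

context_dim = 3
observations_for_arm = tf.random.uniform([4, context_dim])  # rows for this arm
rewards_for_arm = tf.random.uniform([4, 1])

cov_matrix = tf.zeros([context_dim, context_dim])
data_vector = tf.zeros([context_dim])

cov_matrix += tf.matmul(
    observations_for_arm, observations_for_arm, transpose_a=True)
data_vector += tf.squeeze(
    tf.matmul(observations_for_arm, rewards_for_arm, transpose_a=True),
    axis=-1)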
def _loss(self,
          experience,
          td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None,
          training=False):
  """Computes critic loss for CategoricalDQN training.

  See Algorithm 1 and the discussion immediately preceding it on page 6 of
  "A Distributional Perspective on Reinforcement Learning"
  Bellemare et al., 2017
  https://arxiv.org/abs/1707.06887

  Args:
    experience: A batch of experience data in the form of a `Trajectory`.
      The structure of `experience` must match that of
      `self.policy.step_spec`. All tensors in `experience` must be shaped
      `[batch, time, ...]` where `time` must be equal to
      `self.required_experience_time_steps` if that property is not `None`.
    td_errors_loss_fn: A function(td_targets, predictions) to compute loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional weights used for importance sampling.
    training: Whether the loss is being used for training.

  Returns:
    critic_loss: A scalar critic loss.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  # Check that `experience` includes two outer dimensions [B, T, ...]. This
  # method requires a time dimension to compute the loss properly.
  self._check_trajectory_dimensions(experience)

  squeeze_time_dim = not self._q_network.state_spec
  if self._n_step_update == 1:
    time_steps, policy_steps, next_time_steps = (
        trajectory.experience_to_transitions(experience, squeeze_time_dim))
    actions = policy_steps.action
  else:
    # To compute n-step returns, we need the first time steps, the first
    # actions, and the last time steps. Therefore we extract the first and
    # last transitions from our Trajectory.
    first_two_steps = tf.nest.map_structure(lambda x: x[:, :2], experience)
    last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:], experience)
    time_steps, policy_steps, _ = (
        trajectory.experience_to_transitions(first_two_steps,
                                             squeeze_time_dim))
    actions = policy_steps.action
    _, _, next_time_steps = (
        trajectory.experience_to_transitions(last_two_steps,
                                             squeeze_time_dim))

  with tf.name_scope('critic_loss'):
    nest_utils.assert_same_structure(actions, self.action_spec)
    nest_utils.assert_same_structure(time_steps, self.time_step_spec)
    nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

    rank = nest_utils.get_outer_rank(time_steps.observation,
                                     self._time_step_spec.observation)

    # If inputs have a time dimension and the q_network is stateful,
    # combine the batch and time dimension.
    batch_squash = (None
                    if rank <= 1 or self._q_network.state_spec in ((), None)
                    else utils.BatchSquash(rank))

    network_observation = time_steps.observation

    if self._observation_and_action_constraint_splitter is not None:
      network_observation, _ = (
          self._observation_and_action_constraint_splitter(
              network_observation))

    # q_logits contains the Q-value logits for all actions.
    q_logits, _ = self._q_network(network_observation,
                                  time_steps.step_type,
                                  training=training)

    if batch_squash is not None:
      # Squash the outer dimensions into a single dimension to facilitate
      # computing the loss below. Required for supporting temporal inputs,
      # for example.
      q_logits = batch_squash.flatten(q_logits)
      actions = batch_squash.flatten(actions)
      next_time_steps = tf.nest.map_structure(batch_squash.flatten,
                                              next_time_steps)

    next_q_distribution = self._next_q_distribution(next_time_steps)

    if actions.shape.rank > 1:
      actions = tf.squeeze(actions, list(range(1, actions.shape.rank)))

    # Project the sample Bellman update \hat{T}Z_{\theta} onto the original
    # support of Z_{\theta} (see Figure 1 in paper).
    batch_size = q_logits.shape[0] or tf.shape(q_logits)[0]
    tiled_support = tf.tile(self._support, [batch_size])
    tiled_support = tf.reshape(tiled_support, [batch_size, self._num_atoms])

    if self._n_step_update == 1:
      discount = next_time_steps.discount
      if discount.shape.rank == 1:
        # We expect discount to have a shape of [batch_size], while
        # tiled_support will have a shape of [batch_size, num_atoms]. To
        # multiply these, we add a second dimension of 1 to the discount.
        discount = tf.expand_dims(discount, -1)
      next_value_term = tf.multiply(discount, tiled_support,
                                    name='next_value_term')

      reward = next_time_steps.reward
      if reward.shape.rank == 1:
        # See the explanation above.
        reward = tf.expand_dims(reward, -1)
      reward_term = tf.multiply(reward_scale_factor, reward,
                                name='reward_term')

      target_support = tf.add(reward_term, gamma * next_value_term,
                              name='target_support')
    else:
      # When computing discounted return, we need to throw out the last
      # time index of both reward and discount, which are filled with dummy
      # values to match the dimensions of the observation.
      rewards = reward_scale_factor * experience.reward[:, :-1]
      discounts = gamma * experience.discount[:, :-1]

      # TODO(b/134618876): Properly handle Trajectories that include
      # episode boundaries with nonzero discount.

      discounted_returns = value_ops.discounted_return(
          rewards=rewards,
          discounts=discounts,
          final_value=tf.zeros([batch_size], dtype=discounts.dtype),
          time_major=False,
          provide_all_returns=False)

      # Convert discounted_returns from [batch_size] to [batch_size, 1].
      discounted_returns = tf.expand_dims(discounted_returns, -1)

      final_value_discount = tf.reduce_prod(discounts, axis=1)
      final_value_discount = tf.expand_dims(final_value_discount, -1)

      # Save the values of discounted_returns and final_value_discount in
      # order to check them in unit tests.
      self._discounted_returns = discounted_returns
      self._final_value_discount = final_value_discount

      target_support = tf.add(discounted_returns,
                              final_value_discount * tiled_support,
                              name='target_support')

    target_distribution = tf.stop_gradient(
        project_distribution(target_support, next_q_distribution,
                             self._support))

    # Obtain the current Q-value logits for the selected actions.
    indices = tf.range(batch_size)
    indices = tf.cast(indices, actions.dtype)
    reshaped_actions = tf.stack([indices, actions], axis=-1)
    chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)

    # Compute the cross-entropy loss between the logits. If inputs have
    # a time dimension, compute the sum over the time dimension before
    # computing the mean over the batch dimension.
    if batch_squash is not None:
      target_distribution = batch_squash.unflatten(target_distribution)
      chosen_action_logits = batch_squash.unflatten(chosen_action_logits)
      critic_loss = tf.reduce_sum(
          tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
              labels=target_distribution, logits=chosen_action_logits),
          axis=1)
    else:
      critic_loss = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
          labels=target_distribution, logits=chosen_action_logits)

    agg_loss = common.aggregate_losses(
        per_example_loss=critic_loss,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    dict_losses = {
        'critic_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss
    }

    common.summarize_scalar_dict(
        dict_losses, step=self.train_step_counter, name_scope='Losses/')

    if self._debug_summaries:
      distribution_errors = target_distribution - chosen_action_logits
      with tf.name_scope('distribution_errors'):
        common.generate_tensor_summaries(
            'distribution_errors',
            distribution_errors,
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'mean',
            tf.reduce_mean(distribution_errors),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'mean_abs',
            tf.reduce_mean(tf.abs(distribution_errors)),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'max',
            tf.reduce_max(distribution_errors),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'min',
            tf.reduce_min(distribution_errors),
            step=self.train_step_counter)
      with tf.name_scope('target_distribution'):
        common.generate_tensor_summaries(
            'target_distribution',
            target_distribution,
            step=self.train_step_counter)

    # TODO(b/127318640): Give appropriate values for td_loss and td_error
    # for prioritized replay.
    return tf_agent.LossInfo(
        total_loss, dqn_agent.DqnLossInfo(td_loss=(), td_error=()))
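# Standalone sketch of the categorical projection referenced above (Figure 1
# of Bellemare et al., 2017) for a single batch entry: each target atom's
# probability mass is split between the two nearest atoms of the fixed
# support. This only illustrates the idea behind project_distribution; the
# support, target atoms, and probabilities below are made up.
import tensorflow as tf

support = tf.linspace(-1.0, 1.0, 5)   # fixed atoms z_i: [-1, -0.5, 0, 0.5, 1]
delta_z = support[1] - support[0]
target_support = tf.constant([0.3, 0.3, 0.9, 1.4, -2.0])      # \hat{T}z_j
next_q_distribution = tf.constant([0.2, 0.2, 0.2, 0.2, 0.2])  # p_j

clipped = tf.clip_by_value(target_support, support[0], support[-1])
# Weight of target atom j on support atom i: linear interpolation kernel.
w = tf.clip_by_value(
    1.0 - tf.abs(clipped[None, :] - support[:, None]) / delta_z, 0.0, 1.0)
projected = tf.linalg.matvec(w, next_q_distribution)  # sums to 1 over atoms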
def _train(self, experience, weights):
  """Returns a train op to update the agent's networks.

  This method trains with the provided batched experience.

  Args:
    experience: A time-stacked trajectory object.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    A train_op.

  Raises:
    ValueError: If optimizers are None and no default value was provided to
      the constructor.
  """
  time_steps, actions, next_time_steps = (
      self._experience_to_transitions(experience))

  trainable_critic_variables = (
      self._critic_network_1.trainable_variables +
      self._critic_network_2.trainable_variables)
  with tf.GradientTape(watch_accessed_variables=False) as tape:
    assert trainable_critic_variables, ('No trainable critic variables to '
                                        'optimize.')
    tape.watch(trainable_critic_variables)
    critic_loss = self.critic_loss(
        time_steps,
        actions,
        next_time_steps,
        td_errors_loss_fn=self._td_errors_loss_fn,
        gamma=self._gamma,
        reward_scale_factor=self._reward_scale_factor,
        weights=weights)
  tf.debugging.check_numerics(critic_loss, 'Critic loss is inf or nan.')
  critic_grads = tape.gradient(critic_loss, trainable_critic_variables)
  self._apply_gradients(critic_grads, trainable_critic_variables,
                        self._critic_optimizer)

  trainable_actor_variables = self._actor_network.trainable_variables
  with tf.GradientTape(watch_accessed_variables=False) as tape:
    assert trainable_actor_variables, ('No trainable actor variables to '
                                       'optimize.')
    tape.watch(trainable_actor_variables)
    actor_loss = self.actor_loss(time_steps, weights=weights)
  tf.debugging.check_numerics(actor_loss, 'Actor loss is inf or nan.')
  actor_grads = tape.gradient(actor_loss, trainable_actor_variables)
  self._apply_gradients(actor_grads, trainable_actor_variables,
                        self._actor_optimizer)

  alpha_variable = [self._log_alpha]
  with tf.GradientTape(watch_accessed_variables=False) as tape:
    assert alpha_variable, 'No alpha variable to optimize.'
    tape.watch(alpha_variable)
    alpha_loss = self.alpha_loss(time_steps, weights=weights)
  tf.debugging.check_numerics(alpha_loss, 'Alpha loss is inf or nan.')
  alpha_grads = tape.gradient(alpha_loss, alpha_variable)
  self._apply_gradients(alpha_grads, alpha_variable, self._alpha_optimizer)

  # updates safety critic if not training online
  safe_rew = next_time_steps.observation['task_agn_rew']
  sc_weight = None
  if self._fail_weight:
    sc_weight = tf.where(
        tf.cast(safe_rew, tf.bool), self._fail_weight / 0.5,
        (1 - self._fail_weight) / 0.5)
  safety_critic_loss, lambda_loss = self.train_sc(
      experience,
      safe_rew,
      sc_weight,
      training=(not self._train_critic_online))

  with tf.name_scope('Losses'):
    tf.compat.v2.summary.scalar(
        name='critic_loss', data=critic_loss, step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='actor_loss', data=actor_loss, step=self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='alpha_loss', data=alpha_loss, step=self.train_step_counter)
    if lambda_loss is not None:
      tf.compat.v2.summary.scalar(
          name='lambda_loss', data=lambda_loss, step=self.train_step_counter)
    if safety_critic_loss is not None:
      tf.compat.v2.summary.scalar(
          name='safety_critic_loss',
          data=safety_critic_loss,
          step=self.train_step_counter)

  self.train_step_counter.assign_add(1)
  self._update_target()

  total_loss = critic_loss + actor_loss + alpha_loss

  extra = SafeSacLossInfo(
      critic_loss=critic_loss,
      actor_loss=actor_loss,
      alpha_loss=alpha_loss,
      safety_critic_loss=safety_critic_loss,
      lambda_loss=lambda_loss)

  return tf_agent.LossInfo(loss=total_loss, extra=extra)
def testTrain(self, num_epochs, use_td_lambda_return):
  agent = ppo_agent.PPOAgent(
      self._time_step_spec,
      self._action_spec,
      tf.train.AdamOptimizer(),
      actor_net=DummyActorNet(self._action_spec,),
      value_net=DummyValueNet(outer_rank=2),
      normalize_observations=False,
      num_epochs=num_epochs,
      use_gae=use_td_lambda_return,
      use_td_lambda_return=use_td_lambda_return)
  observations = tf.constant(
      [
          [[1, 2], [3, 4], [5, 6]],
          [[1, 2], [3, 4], [5, 6]],
      ],
      dtype=tf.float32)

  time_steps = ts.TimeStep(
      step_type=tf.constant([[1] * 3] * 2, dtype=tf.int32),
      reward=tf.constant([[1] * 3] * 2, dtype=tf.float32),
      discount=tf.constant([[1] * 3] * 2, dtype=tf.float32),
      observation=observations)
  actions = tf.constant([[[0], [1], [1]], [[0], [1], [1]]], dtype=tf.float32)

  action_distribution_parameters = {
      'loc': tf.constant([[[0.0]] * 3] * 2, dtype=tf.float32),
      'scale': tf.constant([[[1.0]] * 3] * 2, dtype=tf.float32),
  }
  policy_info = action_distribution_parameters

  experience = trajectory.Trajectory(time_steps.step_type, observations,
                                     actions, policy_info,
                                     time_steps.step_type, time_steps.reward,
                                     time_steps.discount)

  # Mock the build_train_op to return an op for incrementing this counter.
  counter = tf.train.get_or_create_global_step()
  zero = tf.constant(0, dtype=tf.float32)
  agent.build_train_op = (
      lambda *_, **__: tf_agent.LossInfo(
          counter.assign_add(1),  # pylint: disable=g-long-lambda
          ppo_agent.PPOLossInfo(*[zero] * 5)))

  train_op = agent.train(experience)

  with self.test_session() as sess:
    sess.run(tf.global_variables_initializer())

    # Assert that counter starts out at zero.
    counter_ = sess.run(counter)
    self.assertEqual(0, counter_)

    sess.run(train_op)

    # Assert that train_op ran increment_counter num_epochs times.
    counter_ = sess.run(counter)
    self.assertEqual(num_epochs, counter_)