def alpha_loss(self, time_steps, weights=None):
  """Computes the alpha_loss for EC-SAC training.

  Args:
    time_steps: A batch of timesteps.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    alpha_loss: A scalar alpha loss.
  """
  with tf.name_scope('alpha_loss'):
    tf.nest.assert_same_structure(time_steps, self.time_step_spec)
    unused_actions, log_pi = self._actions_and_log_probs(time_steps)
    alpha_loss = (
        self._log_alpha * tf.stop_gradient(-log_pi - self._target_entropy))
    if weights is not None:
      alpha_loss *= weights
    alpha_loss = tf.reduce_mean(input_tensor=alpha_loss)

    if self._debug_summaries:
      common.generate_tensor_summaries('alpha_loss', alpha_loss,
                                       self.train_step_counter)

    return alpha_loss
def actor_loss(self, time_steps, alphas, weights=None):
  """Computes the actor_loss for EC-SAC training.

  Args:
    time_steps: A batch of timesteps.
    alphas: A batch of confidence levels for the CVaR objective.
    weights: Optional scalar or element-wise (per-batch-entry) importance
      weights.

  Returns:
    actor_loss: A scalar actor loss.
  """
  with tf.name_scope('actor_loss'):
    actions, _ = self._actor_network((time_steps.observation, alphas),
                                     time_steps.step_type)
    with tf.GradientTape(watch_accessed_variables=False) as tape:
      tape.watch(actions)
      q, _ = self._critic_network(
          (time_steps.observation, actions, alphas), time_steps.step_type)
      q_means, q_vars = tf.reshape(q.loc, [-1]), tf.reshape(q.scale, [-1])
    cvar = self._compute_cvar(q_means, q_vars, alphas)
    actor_loss = -tf.reduce_mean(cvar)

    with tf.name_scope('Losses/'):
      tf.compat.v2.summary.scalar(
          name='actor_loss', data=actor_loss, step=self.train_step_counter)
    if self._debug_summaries:
      common.generate_tensor_summaries('cvar', cvar, self.train_step_counter)

    return actor_loss
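# `_compute_cvar` is not shown in this section. Below is a minimal sketch of
# one plausible implementation, offered as an assumption rather than the
# verified helper: it treats the critic's output as a Gaussian with mean
# `q_means` and scale (stddev) `q_stddevs`, and uses the closed form for the
# lower-tail conditional value at risk,
#   CVaR_a[N(mu, sigma^2)] = mu - sigma * pdf(z_a) / a,  z_a = Phi^{-1}(a).
# Assumes the usual `tfp` import (tensorflow_probability).
def _compute_cvar_sketch(q_means, q_stddevs, alphas):
  standard_normal = tfp.distributions.Normal(loc=0.0, scale=1.0)
  z_alphas = standard_normal.quantile(alphas)
  # Expected Q-value conditioned on the worst alpha-fraction of outcomes.
  return q_means - q_stddevs * standard_normal.prob(z_alphas) / alphas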
def alpha_loss(self, actor_time_steps, weights=None):
  """Computes the alpha_loss for EC-SAC training.

  Args:
    actor_time_steps: A batch of timesteps for the actor.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    alpha_loss: A scalar alpha loss.
  """
  with tf.name_scope('alpha_loss'):
    actions_distribution, _ = self._actor_network(
        actor_time_steps.observation, actor_time_steps.step_type)
    actions = actions_distribution.sample()
    log_pis = actions_distribution.log_prob(actions)
    alpha_loss = (
        self._log_alpha * tf.stop_gradient(-log_pis - self._target_entropy))
    if weights is not None:
      alpha_loss *= weights
    alpha_loss = tf.reduce_mean(input_tensor=alpha_loss)

    if self._debug_summaries:
      common.generate_tensor_summaries('alpha_loss', alpha_loss,
                                       self.train_step_counter)

    return alpha_loss
def _loss(self, experience, weights=None):
  """Computes loss for behavioral cloning.

  Args:
    experience: A `Trajectory` containing experience.
    weights: Optional scalar or element-wise (per-batch-entry) importance
      weights.

  Returns:
    loss: A `LossInfo` struct.

  Raises:
    ValueError: If the number of actions is greater than 1.
  """
  with tf.name_scope('loss'):
    actions = tf.nest.flatten(experience.action)[0]
    logits, _ = self._cloning_network(experience.observation,
                                      experience.step_type)

    boundary_weights = tf.cast(~experience.is_boundary(), logits.dtype)
    error = boundary_weights * self._loss_fn(logits, actions)

    if nest_utils.is_batched_nested_tensors(
        experience.action, self.action_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      error = tf.reduce_sum(input_tensor=error, axis=1)

    # Average across the elements of the batch.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N, where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution in the loss. Think about what would
    # happen as the number of boundary samples increases.
    if weights is not None:
      error *= weights
    loss = tf.reduce_mean(input_tensor=error)

    with tf.name_scope('Losses/'):
      tf.compat.v2.summary.scalar(
          name='loss', data=loss, step=self.train_step_counter)

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._cloning_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      common.generate_tensor_summaries('errors', error,
                                       self.train_step_counter)

    return tf_agent.LossInfo(loss, BehavioralCloningLossInfo(loss=error))
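# `self._loss_fn` is configured at agent construction and is not shown here.
# A minimal sketch of an element-wise loss for discrete actions; the
# per-entry (unreduced) shape matters, because the boundary mask and
# `weights` above are applied before the final mean. Assumes the usual `tf`
# import; the real default may differ.
def discrete_cloning_loss_sketch(logits, actions):
  # One cross-entropy value per batch entry, no reduction.
  return tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=tf.cast(actions, tf.int32), logits=logits)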
def _alpha_loss_debug_summaries(self, alpha_loss, entropy_diff):
  if self._debug_summaries:
    common.generate_tensor_summaries('alpha_loss', alpha_loss,
                                     self.train_step_counter)
    common.generate_tensor_summaries('entropy_diff', entropy_diff,
                                     self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='log_alpha', data=self._log_alpha,
        step=self.train_step_counter)
def _critic_loss_debug_summaries(self, td_targets, pred_td_targets1,
                                 pred_td_targets2, weights):
  if self._debug_summaries:
    td_errors1 = td_targets - pred_td_targets1
    td_errors2 = td_targets - pred_td_targets2
    td_errors = tf.concat([td_errors1, td_errors2], axis=0)
    common.generate_tensor_summaries('td_errors', td_errors,
                                     self.train_step_counter)
    common.generate_tensor_summaries('td_targets', td_targets,
                                     self.train_step_counter)
    common.generate_tensor_summaries('pred_td_targets1', pred_td_targets1,
                                     self.train_step_counter)
    common.generate_tensor_summaries('pred_td_targets2', pred_td_targets2,
                                     self.train_step_counter)
    common.generate_tensor_summaries('weights', weights,
                                     self.train_step_counter)
def critic_loss(
    self,
    time_steps,
    actions,
    next_time_steps,
    td_errors_loss_fn,
    gamma=1.0,
    reward_scale_factor=1.0,
    weights=None,
    training=False,
    delta_r_scale=1.0,
    delta_r_warmup=0,
):
  sas_input = tf.concat(
      [time_steps.observation, actions, next_time_steps.observation],
      axis=-1)
  # Set training=False so no input noise is added.
  sa_probs, sas_probs = self._classifier(sas_input, training=False)
  sas_log_probs = tf.math.log(sas_probs)
  sa_log_probs = tf.math.log(sa_probs)
  if self._unnormalized_delta_r:  # Option for ablation experiment.
    delta_r = sas_log_probs[:, 1] - sas_log_probs[:, 0]
  else:  # Default option (i.e., the correct version).
    delta_r = (sas_log_probs[:, 1] - sas_log_probs[:, 0]
               - sa_log_probs[:, 1] + sa_log_probs[:, 0])
  common.generate_tensor_summaries("delta_r", delta_r,
                                   self.train_step_counter)
  is_warmup = tf.cast(self.train_step_counter < delta_r_warmup, tf.float32)
  tf.compat.v2.summary.scalar(
      name="is_warmup", data=is_warmup, step=self.train_step_counter)
  next_time_steps = next_time_steps._replace(
      reward=next_time_steps.reward +
      delta_r_scale * (1 - is_warmup) * delta_r)
  return super(DarcAgent, self).critic_loss(
      time_steps,
      actions,
      next_time_steps,
      td_errors_loss_fn,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      weights=weights,
      training=training,
  )
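# A toy check of the default `delta_r` above, assuming the classifiers put
# p(source) in column 0 and p(target) in column 1 (an assumption about the
# column convention). By Bayes' rule the two log-ratios combine into the
# dynamics log-ratio log p_target(s'|s,a) - log p_source(s'|s,a):
sa_probs_demo = tf.constant([[0.7, 0.3], [0.4, 0.6]])
sas_probs_demo = tf.constant([[0.5, 0.5], [0.2, 0.8]])
sa_log_probs_demo = tf.math.log(sa_probs_demo)
sas_log_probs_demo = tf.math.log(sas_probs_demo)
delta_r_demo = (sas_log_probs_demo[:, 1] - sas_log_probs_demo[:, 0]
                - sa_log_probs_demo[:, 1] + sa_log_probs_demo[:, 0])
# -> [log(0.5/0.5) - log(0.3/0.7), log(0.8/0.2) - log(0.6/0.4)]
#    ~= [0.847, 0.981]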
def alpha_loss(self, time_steps, weights=None):
  """Computes the alpha_loss for EC-SAC training.

  Args:
    time_steps: A batch of timesteps.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    alpha_loss: A scalar alpha loss.
  """
  with tf.name_scope('alpha_loss'):
    tf.nest.assert_same_structure(time_steps, self.time_step_spec)
    unused_actions, log_pi = self._actions_and_log_probs(time_steps)
    entropy_diff = tf.stop_gradient(-log_pi - self._target_entropy)
    alpha_loss = (self._log_alpha * entropy_diff)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Sum over the time dimension.
      alpha_loss = tf.reduce_sum(input_tensor=alpha_loss, axis=1)

    if weights is not None:
      alpha_loss *= weights
    alpha_loss = tf.reduce_mean(input_tensor=alpha_loss)

    if self._debug_summaries:
      common.generate_tensor_summaries('alpha_loss', alpha_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('entropy_diff', entropy_diff,
                                       self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='log_alpha', data=self._log_alpha,
          step=self.train_step_counter)

    return alpha_loss
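# `self._target_entropy` is set at construction time and not shown here. A
# common heuristic for continuous control (and, as an assumption, the kind
# of default used when none is supplied) is the negative action
# dimensionality. Assumes numpy is imported as `np`.
def default_target_entropy_sketch(action_spec):
  # E.g. a 6-dimensional action space yields a target entropy of -6.
  return -np.prod(action_spec.shape.as_list())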
def critic_loss(self, time_steps, actions, next_time_steps, weights=None,
                training=False):
  """Computes the critic loss for DDPG training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    weights: Optional scalar or element-wise (per-batch-entry) importance
      weights.
    training: Whether this loss is being used for training.

  Returns:
    critic_loss: A scalar critic loss.
  """
  with tf.name_scope('critic_loss'):
    target_actions, _ = self._target_actor_network(
        next_time_steps.observation, next_time_steps.step_type,
        training=False)
    target_critic_net_input = (next_time_steps.observation, target_actions)
    target_q_values, _ = self._target_critic_network(
        target_critic_net_input, next_time_steps.step_type, training=False)

    td_targets = tf.stop_gradient(
        self._reward_scale_factor * next_time_steps.reward +
        self._gamma * next_time_steps.discount * target_q_values)

    critic_net_input = (time_steps.observation, actions)
    q_values, _ = self._critic_network(critic_net_input,
                                       time_steps.step_type,
                                       training=training)

    critic_loss = self._td_errors_loss_fn(td_targets, q_values)
    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      critic_loss = tf.reduce_sum(critic_loss, axis=1)
    if weights is not None:
      critic_loss *= weights
    critic_loss = tf.reduce_mean(critic_loss)

    with tf.name_scope('Losses/'):
      tf.compat.v2.summary.scalar(
          name='critic_loss', data=critic_loss,
          step=self.train_step_counter)

    if self._debug_summaries:
      td_errors = td_targets - q_values
      common.generate_tensor_summaries('td_errors', td_errors,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_targets', td_targets,
                                       self.train_step_counter)
      common.generate_tensor_summaries('q_values', q_values,
                                       self.train_step_counter)

    return critic_loss
def _critic_no_entropy_loss_debug_summaries(self, td_targets,
                                            pred_td_targets1,
                                            pred_td_targets2):
  if self._debug_summaries:
    td_errors1 = td_targets - pred_td_targets1
    td_errors2 = td_targets - pred_td_targets2
    td_errors = tf.concat([td_errors1, td_errors2], axis=0)
    common.generate_tensor_summaries('td_errors_no_entropy_critic',
                                     td_errors, self.train_step_counter)
    common.generate_tensor_summaries('td_targets_no_entropy_critic',
                                     td_targets, self.train_step_counter)
    common.generate_tensor_summaries('pred_td_targets1_no_entropy_critic',
                                     pred_td_targets1,
                                     self.train_step_counter)
    common.generate_tensor_summaries('pred_td_targets2_no_entropy_critic',
                                     pred_td_targets2,
                                     self.train_step_counter)
def critic_loss(self, time_steps, actions, next_time_steps):
  """Computes the critic loss for DDPG training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.

  Returns:
    critic_loss: A scalar critic loss.
  """
  with tf.name_scope('critic_loss'):
    target_actions, _ = self._target_actor_network(
        next_time_steps.observation, next_time_steps.step_type)
    target_q_values, _ = self._target_critic_network(
        next_time_steps.observation, target_actions,
        next_time_steps.step_type)
    td_targets = tf.stop_gradient(
        self._reward_scale_factor * next_time_steps.reward +
        self._gamma * next_time_steps.discount * target_q_values)

    q_values, _ = self._critic_network(time_steps.observation, actions,
                                       time_steps.step_type)

    critic_loss = self._td_errors_loss_fn(td_targets, q_values)
    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec(), num_outer_dims=2):
      # Do a sum over the time dimension.
      critic_loss = tf.reduce_sum(critic_loss, axis=1)
    critic_loss = tf.reduce_mean(critic_loss)

    with tf.name_scope('Losses/'):
      tf.contrib.summary.scalar('critic_loss', critic_loss)

    if self._debug_summaries:
      td_errors = td_targets - q_values
      common_utils.generate_tensor_summaries('td_errors', td_errors)
      common_utils.generate_tensor_summaries('td_targets', td_targets)
      common_utils.generate_tensor_summaries('q_values', q_values)

    return critic_loss
def actor_loss(self, time_steps, weights=None):
  """Computes the actor_loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    actor_loss: A scalar actor loss.
  """
  with tf.name_scope('actor_loss'):
    tf.nest.assert_same_structure(time_steps, self.time_step_spec)
    actions, log_pi = self._actions_and_log_probs(time_steps)
    target_input = (time_steps.observation, actions)
    target_q_values1, _ = self._critic_network_1(
        target_input, time_steps.step_type, training=False)
    target_q_values2, _ = self._critic_network_2(
        target_input, time_steps.step_type, training=False)
    target_q_values = tf.minimum(target_q_values1, target_q_values2)
    actor_loss = tf.exp(self._log_alpha) * log_pi - target_q_values

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Sum over the time dimension.
      actor_loss = tf.reduce_sum(input_tensor=actor_loss, axis=1)
    if weights is not None:
      actor_loss *= weights
    actor_loss = tf.reduce_mean(input_tensor=actor_loss)

    if self._debug_summaries:
      common.generate_tensor_summaries('actor_loss', actor_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('actions', actions,
                                       self.train_step_counter)
      common.generate_tensor_summaries('log_pi', log_pi,
                                       self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='entropy_avg',
          data=-tf.reduce_mean(input_tensor=log_pi),
          step=self.train_step_counter)
      common.generate_tensor_summaries('target_q_values', target_q_values,
                                       self.train_step_counter)
      batch_size = nest_utils.get_outer_shape(time_steps,
                                              self._time_step_spec)[0]
      policy_state = self._train_policy.get_initial_state(batch_size)
      action_distribution = self._train_policy.distribution(
          time_steps, policy_state).action
      if isinstance(action_distribution, tfp.distributions.Normal):
        common.generate_tensor_summaries('act_mean',
                                         action_distribution.loc,
                                         self.train_step_counter)
        common.generate_tensor_summaries('act_stddev',
                                         action_distribution.scale,
                                         self.train_step_counter)
      elif isinstance(action_distribution, tfp.distributions.Categorical):
        common.generate_tensor_summaries('act_mode',
                                         action_distribution.mode(),
                                         self.train_step_counter)
      try:
        common.generate_tensor_summaries('entropy_action',
                                         action_distribution.entropy(),
                                         self.train_step_counter)
      except NotImplementedError:
        pass  # Some distributions do not have an analytic entropy.

    return actor_loss
def critic_loss(self,
                time_steps,
                actions,
                next_time_steps,
                td_errors_loss_fn,
                gamma=1.0,
                reward_scale_factor=1.0,
                weights=None):
  """Computes the critic loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    td_errors_loss_fn: A function(td_targets, predictions) to compute
      elementwise (per-batch-entry) loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    critic_loss: A scalar critic loss.
  """
  with tf.name_scope('critic_loss'):
    tf.nest.assert_same_structure(actions, self.action_spec)
    tf.nest.assert_same_structure(time_steps, self.time_step_spec)
    tf.nest.assert_same_structure(next_time_steps, self.time_step_spec)

    next_actions, next_log_pis = self._actions_and_log_probs(
        next_time_steps)
    target_input = (next_time_steps.observation, next_actions)
    target_q_values1, unused_network_state1 = self._target_critic_network_1(
        target_input, next_time_steps.step_type, training=False)
    target_q_values2, unused_network_state2 = self._target_critic_network_2(
        target_input, next_time_steps.step_type, training=False)
    target_q_values = (
        tf.minimum(target_q_values1, target_q_values2) -
        tf.exp(self._log_alpha) * next_log_pis)

    td_targets = tf.stop_gradient(
        reward_scale_factor * next_time_steps.reward +
        gamma * next_time_steps.discount * target_q_values)

    pred_input = (time_steps.observation, actions)
    pred_td_targets1, _ = self._critic_network_1(
        pred_input, time_steps.step_type, training=True)
    pred_td_targets2, _ = self._critic_network_2(
        pred_input, time_steps.step_type, training=True)
    critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
    critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
    critic_loss = critic_loss1 + critic_loss2

    if weights is not None:
      critic_loss *= weights

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Sum over the time dimension.
      critic_loss = tf.reduce_sum(input_tensor=critic_loss, axis=1)

    # Take the mean across the batch.
    critic_loss = tf.reduce_mean(input_tensor=critic_loss)

    if self._debug_summaries:
      td_errors1 = td_targets - pred_td_targets1
      td_errors2 = td_targets - pred_td_targets2
      td_errors = tf.concat([td_errors1, td_errors2], axis=0)
      common.generate_tensor_summaries('td_errors', td_errors,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_targets', td_targets,
                                       self.train_step_counter)
      common.generate_tensor_summaries('pred_td_targets1', pred_td_targets1,
                                       self.train_step_counter)
      common.generate_tensor_summaries('pred_td_targets2', pred_td_targets2,
                                       self.train_step_counter)

    return critic_loss
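# `_actions_and_log_probs` is shared by several of the SAC losses above but
# is not shown in this section. A minimal sketch under the assumption that
# the actor network returns an action distribution (production versions
# typically route through the agent's policy instead):
def _actions_and_log_probs_sketch(self, time_steps):
  action_distribution, _ = self._actor_network(time_steps.observation,
                                               time_steps.step_type)
  # Sample actions from the current policy and score them under it.
  actions = action_distribution.sample()
  log_pi = action_distribution.log_prob(actions)
  return actions, log_pi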
def _loss_h(self,
            experience,
            td_errors_loss_fn=common.element_wise_huber_loss,
            gamma=1.0,
            reward_scale_factor=1.0,
            weights=None,
            training=False):
  transition = self._as_transition(experience)
  time_steps, policy_steps, next_time_steps = transition
  actions = policy_steps.action
  valid_mask = tf.cast(~time_steps.is_last(), tf.float32)

  with tf.name_scope('loss'):
    # h_values are already gathered by the taken actions.
    h_values = self._compute_h_values(time_steps, actions,
                                      training=training)

    multi_dim_actions = self._action_spec.shape.rank > 0

    next_q_all_values = self._compute_next_all_q_values(
        next_time_steps, policy_steps.info)
    next_h_all_values = self._compute_next_all_h_values(
        next_time_steps, policy_steps.info)
    next_h_actions = tf.argmax(next_h_all_values, axis=1)
    # next_h_values here is used only for logging.
    next_h_values = self._compute_next_h_values(next_time_steps,
                                                policy_steps.info)

    # next_q_values refers to Q(r, s') in Eqs. (4) and (5).
    next_q_values = common.index_with_actions(
        next_q_all_values, tf.cast(next_h_actions, dtype=tf.int32),
        multi_dim_actions)

    h_target_all_values = self._compute_next_all_h_values(
        time_steps, policy_steps.info)
    h_target_values = common.index_with_actions(
        h_target_all_values, tf.cast(actions, dtype=tf.int32),
        multi_dim_actions)

    td_targets = compute_momentum_td_targets(
        q_target_values=next_q_values,
        h_target_values=h_target_values,
        beta=self.beta())

    td_error = valid_mask * (td_targets - h_values)
    td_loss = valid_mask * td_errors_loss_fn(td_targets, h_values)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    # Aggregate across the elements of the batch and add regularization
    # loss.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N, where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution in the loss. Think about what would
    # happen as the number of boundary samples increases.
    agg_loss = common.aggregate_losses(
        per_example_loss=td_loss,
        sample_weight=weights,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    losses_dict = {
        'td_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss
    }
    common.summarize_scalar_dict(losses_dict,
                                 step=self.train_step_counter,
                                 name_scope='Losses/')

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._h_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      diff_h_values = h_values - next_h_values
      common.generate_tensor_summaries('td_error_h', td_error,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_loss_h', td_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('h_values', h_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('next_h_values', next_h_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('diff_h_values', diff_h_values,
                                       self.train_step_counter)

    return tf_agent.LossInfo(
        total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
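# `compute_momentum_td_targets` is not shown in this section, and the sketch
# below is a guess at its shape, not the verified definition from
# Eqs. (4)-(5) of the underlying paper: one natural reading is a convex
# combination of the bootstrapped Q target and the current h target,
# weighted by beta.
def compute_momentum_td_targets_sketch(q_target_values, h_target_values,
                                       beta):
  # Gradients are stopped so the target side is not trained through.
  return tf.stop_gradient(
      beta * h_target_values + (1.0 - beta) * q_target_values)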
def actor_loss(self, time_steps, weights=None):
  """Computes the actor_loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    actor_loss: A scalar actor loss.
  """
  with tf.name_scope('actor_loss'):
    tf.nest.assert_same_structure(time_steps, self.time_step_spec())
    actions, log_pi = self._actions_and_log_probs(time_steps)
    target_input_1 = (time_steps.observation, actions)
    target_q_values1, unused_network_state1 = self._critic_network1(
        target_input_1, time_steps.step_type)
    target_input_2 = (time_steps.observation, actions)
    target_q_values2, unused_network_state2 = self._critic_network2(
        target_input_2, time_steps.step_type)
    target_q_values = tf.minimum(target_q_values1, target_q_values2)
    actor_loss = tf.exp(self._log_alpha) * log_pi - target_q_values
    if weights is not None:
      actor_loss *= weights
    actor_loss = tf.reduce_mean(input_tensor=actor_loss)

    if self._debug_summaries:
      common_utils.generate_tensor_summaries('actor_loss', actor_loss)
      common_utils.generate_tensor_summaries('actions', actions)
      common_utils.generate_tensor_summaries('log_pi', log_pi)
      tf.contrib.summary.scalar('entropy_avg',
                                -tf.reduce_mean(input_tensor=log_pi))
      common_utils.generate_tensor_summaries('target_q_values',
                                             target_q_values)
      action_distribution = self.policy().distribution(time_steps).action
      common_utils.generate_tensor_summaries('act_mean',
                                             action_distribution.loc)
      common_utils.generate_tensor_summaries('act_stddev',
                                             action_distribution.scale)
      common_utils.generate_tensor_summaries('entropy_raw_action',
                                             action_distribution.entropy())

    return actor_loss
def safety_critic_loss(time_steps,
                       actions,
                       next_time_steps,
                       safety_rewards,
                       get_action,
                       global_step,
                       critic_network=None,
                       target_network=None,
                       target_safety=None,
                       safety_gamma=0.45,
                       loss_fn='bce',
                       metrics=None,
                       debug_summaries=False):
  """Computes the safety critic loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    safety_rewards: Task-agnostic rewards for safety. 1 is unsafe, 0 is
      safe.
    get_action: A function mapping timesteps to the policy's actions.
    global_step: The training step counter, used for summaries.
    critic_network: The safety critic network being trained.
    target_network: The target safety critic network used to bootstrap.
    target_safety: Safety threshold used by the classification metrics.
    safety_gamma: Discount applied to the bootstrapped safety estimate.
    loss_fn: 'bce', 'mse', or a callable(td_targets, predictions).
    metrics: Optional iterable of tf.keras metrics to update.
    debug_summaries: Whether to emit debug summaries.

  Returns:
    sc_loss: An elementwise safety critic loss.
  """
  with tf.name_scope('safety_critic_loss'):
    next_actions = get_action(next_time_steps)
    target_input = (next_time_steps.observation, next_actions)
    target_q_values, _ = target_network(target_input,
                                        next_time_steps.step_type)
    target_q_values = tf.nn.sigmoid(target_q_values)
    td_targets = tf.stop_gradient(
        safety_rewards + (1 - safety_rewards) * safety_gamma *
        next_time_steps.discount * target_q_values)
    if loss_fn == 'bce' or loss_fn == tf.keras.losses.binary_crossentropy:
      td_targets = tf.nn.sigmoid(td_targets)

    pred_input = (time_steps.observation, actions)
    pred_td_targets, _ = critic_network(pred_input, time_steps.step_type,
                                        training=True)
    pred_td_targets = tf.nn.sigmoid(pred_td_targets)

    # Loss fns: binary_crossentropy / squared_difference.
    if loss_fn == 'mse':
      sc_loss = tf.math.squared_difference(td_targets, pred_td_targets)
    elif loss_fn == 'bce' or loss_fn is None:
      sc_loss = tf.keras.losses.binary_crossentropy(td_targets,
                                                    pred_td_targets)
    else:
      sc_loss = loss_fn(td_targets, pred_td_targets)

    if metrics:
      for metric in metrics:
        if isinstance(metric, tf.keras.metrics.AUC):
          metric.update_state(safety_rewards, pred_td_targets)
        else:
          rew_pred = tf.greater_equal(pred_td_targets, target_safety)
          metric.update_state(safety_rewards, rew_pred)

    if debug_summaries:
      td_errors = td_targets - pred_td_targets
      common.generate_tensor_summaries('safety_td_errors', td_errors,
                                       global_step)
      common.generate_tensor_summaries('safety_td_targets', td_targets,
                                       global_step)
      common.generate_tensor_summaries('safety_pred_td_targets',
                                       pred_td_targets, global_step)

    return sc_loss
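# A toy check of the safety TD target above: once a transition is flagged
# unsafe (safety_reward = 1), the (1 - safety_rewards) factor zeroes the
# bootstrap term and the target saturates at 1 regardless of the critic's
# estimate.
safety_rewards_demo = tf.constant([0.0, 1.0])
discount_demo = tf.constant([1.0, 1.0])
target_q_demo = tf.constant([0.5, 0.9])  # already sigmoid-squashed
td_targets_demo = safety_rewards_demo + (
    (1 - safety_rewards_demo) * 0.45 * discount_demo * target_q_demo)
# -> [0.225, 1.0] with safety_gamma = 0.45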
def _actor_loss_debug_summaries(self, actor_loss, actions, log_pi,
                                target_q_values, time_steps):
  if self._debug_summaries:
    common.generate_tensor_summaries('actor_loss', actor_loss,
                                     self.train_step_counter)
    common.generate_tensor_summaries('actions', actions,
                                     self.train_step_counter)
    common.generate_tensor_summaries('log_pi', log_pi,
                                     self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='entropy_avg',
        data=-tf.reduce_mean(input_tensor=log_pi),
        step=self.train_step_counter)
    common.generate_tensor_summaries('target_q_values', target_q_values,
                                     self.train_step_counter)
    batch_size = nest_utils.get_outer_shape(time_steps,
                                            self._time_step_spec)[0]
    policy_state = self._train_policy.get_initial_state(batch_size)
    action_distribution = self._train_policy.distribution(
        time_steps, policy_state).action
    if isinstance(action_distribution, tfp.distributions.Normal):
      common.generate_tensor_summaries('act_mean', action_distribution.loc,
                                       self.train_step_counter)
      common.generate_tensor_summaries('act_stddev',
                                       action_distribution.scale,
                                       self.train_step_counter)
    elif isinstance(action_distribution, tfp.distributions.Categorical):
      common.generate_tensor_summaries('act_mode',
                                       action_distribution.mode(),
                                       self.train_step_counter)
    common.generate_tensor_summaries('entropy_action',
                                     action_distribution.entropy(),
                                     self.train_step_counter)
def _loss(self,
          experience,
          td_errors_loss_fn=common.element_wise_huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None,
          training=False):
  """Computes loss for DQN training.

  Args:
    experience: A batch of experience data in the form of a `Trajectory`.
      The structure of `experience` must match that of
      `self.policy.step_spec`. All tensors in `experience` must be shaped
      `[batch, time, ...]` where `time` must be equal to
      `self.train_sequence_length` if that property is not `None`.
    td_errors_loss_fn: A function(td_targets, predictions) to compute the
      element-wise loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights. The output td_loss will be scaled by these weights, and the
      final scalar loss is the mean of these values.
    training: Whether this loss is being used for training.

  Returns:
    loss: An instance of `DqnLossInfo`.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  # Check that `experience` includes two outer dimensions [B, T, ...]. This
  # method requires a time dimension to compute the loss properly.
  self._check_trajectory_dimensions(experience)

  squeeze_time_dim = not self._q_network.state_spec
  if self._n_step_update == 1:
    time_steps, policy_steps, next_time_steps = (
        trajectory.experience_to_transitions(experience, squeeze_time_dim))
    actions = policy_steps.action
  else:
    # To compute n-step returns, we need the first time steps, the first
    # actions, and the last time steps. Therefore we extract the first and
    # last transitions from our Trajectory.
    first_two_steps = tf.nest.map_structure(lambda x: x[:, :2], experience)
    last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:], experience)
    time_steps, policy_steps, _ = (
        trajectory.experience_to_transitions(first_two_steps,
                                             squeeze_time_dim))
    actions = policy_steps.action
    _, _, next_time_steps = (
        trajectory.experience_to_transitions(last_two_steps,
                                             squeeze_time_dim))

  with tf.name_scope('loss'):
    q_values = self._compute_q_values(time_steps, actions,
                                      training=training)
    next_q_values = self._compute_next_q_values(next_time_steps,
                                                policy_steps.info)

    if self._n_step_update == 1:
      # Special case for n = 1 to avoid a loss of performance.
      td_targets = compute_td_targets(
          next_q_values,
          rewards=reward_scale_factor * next_time_steps.reward,
          discounts=gamma * next_time_steps.discount)
    else:
      # When computing discounted return, we need to throw out the last
      # time index of both reward and discount, which are filled with
      # dummy values to match the dimensions of the observation.
      rewards = reward_scale_factor * experience.reward[:, :-1]
      discounts = gamma * experience.discount[:, :-1]
      # TODO(b/134618876): Properly handle Trajectories that include
      # episode boundaries with nonzero discount.
      td_targets = value_ops.discounted_return(
          rewards=rewards,
          discounts=discounts,
          final_value=next_q_values,
          time_major=False,
          provide_all_returns=False)

    valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
    td_error = valid_mask * (td_targets - q_values)
    td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    # Aggregate across the elements of the batch and add regularization
    # loss.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N, where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution in the loss. Think about what would
    # happen as the number of boundary samples increases.
    agg_loss = common.aggregate_losses(
        per_example_loss=td_loss,
        sample_weight=weights,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    losses_dict = {
        'td_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss
    }
    common.summarize_scalar_dict(losses_dict,
                                 step=self.train_step_counter,
                                 name_scope='Losses/')

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._q_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      diff_q_values = q_values - next_q_values
      common.generate_tensor_summaries('td_error', td_error,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_loss', td_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('q_values', q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('next_q_values', next_q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('diff_q_values', diff_q_values,
                                       self.train_step_counter)

    return tf_agent.LossInfo(
        total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
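# `compute_td_targets`, used in the n = 1 special case above, is defined
# elsewhere. A minimal sketch of the standard one-step target, with
# gradients stopped so the bootstrap side is not trained through (the real
# helper may differ in details):
def compute_td_targets_sketch(next_q_values, rewards, discounts):
  # One-step TD target: r + gamma * d * max_a' Q(s', a').
  return tf.stop_gradient(rewards + discounts * next_q_values)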
def _actor_loss_debug_summaries(self, actor_loss, actions, log_pi,
                                target_q_values, time_steps):
  if self._debug_summaries:
    common.generate_tensor_summaries('actor_loss', actor_loss,
                                     self.train_step_counter)
    try:
      common.generate_tensor_summaries('actions', actions,
                                       self.train_step_counter)
    except ValueError:
      # Guard against internal SAC variants that do not directly generate
      # actions.
      pass
    common.generate_tensor_summaries('log_pi', log_pi,
                                     self.train_step_counter)
    tf.compat.v2.summary.scalar(
        name='entropy_avg',
        data=-tf.reduce_mean(input_tensor=log_pi),
        step=self.train_step_counter)
    common.generate_tensor_summaries('target_q_values', target_q_values,
                                     self.train_step_counter)
    batch_size = nest_utils.get_outer_shape(time_steps,
                                            self._time_step_spec)[0]
    policy_state = self._train_policy.get_initial_state(batch_size)
    action_distribution = self._train_policy.distribution(
        time_steps, policy_state).action
    if isinstance(action_distribution, tfp.distributions.Normal):
      common.generate_tensor_summaries('act_mean', action_distribution.loc,
                                       self.train_step_counter)
      common.generate_tensor_summaries('act_stddev',
                                       action_distribution.scale,
                                       self.train_step_counter)
    elif isinstance(action_distribution, tfp.distributions.Categorical):
      common.generate_tensor_summaries('act_mode',
                                       action_distribution.mode(),
                                       self.train_step_counter)
    try:
      common.generate_tensor_summaries('entropy_action',
                                       action_distribution.entropy(),
                                       self.train_step_counter)
    except NotImplementedError:
      pass  # Some distributions do not have an analytic entropy.
def actor_loss(self, time_steps, actor_time_steps, weights=None):
  """Computes the actor_loss for SAC training.

  Args:
    time_steps: A batch of timesteps for the critic.
    actor_time_steps: A batch of timesteps for the actor.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    actor_loss: A scalar actor loss.
  """
  with tf.name_scope('actor_loss'):
    time_steps = tf.nest.map_structure(tf.stop_gradient, time_steps)
    if self._actor_input_stop_gradient:
      actor_time_steps = tf.nest.map_structure(tf.stop_gradient,
                                               actor_time_steps)

    actions_distribution, _ = self._actor_network(
        actor_time_steps.observation, actor_time_steps.step_type)
    actions = actions_distribution.sample()
    log_pis = actions_distribution.log_prob(actions)

    target_input_1 = (time_steps.observation, actions)
    target_q_values1, unused_network_state1 = self._critic_network1(
        target_input_1, time_steps.step_type)
    target_input_2 = (time_steps.observation, actions)
    target_q_values2, unused_network_state2 = self._critic_network2(
        target_input_2, time_steps.step_type)
    target_q_values = tf.minimum(target_q_values1, target_q_values2)
    actor_loss = tf.exp(self._log_alpha) * log_pis - target_q_values
    if weights is not None:
      actor_loss *= weights
    actor_loss = tf.reduce_mean(input_tensor=actor_loss)

    if self._debug_summaries:
      common.generate_tensor_summaries('actor_loss', actor_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('actions', actions,
                                       self.train_step_counter)
      common.generate_tensor_summaries('log_pis', log_pis,
                                       self.train_step_counter)
      tf.compat.v2.summary.scalar(
          name='entropy_avg',
          data=-tf.reduce_mean(input_tensor=log_pis),
          step=self.train_step_counter)
      common.generate_tensor_summaries('target_q_values', target_q_values,
                                       self.train_step_counter)
      batch_size = nest_utils.get_outer_shape(time_steps,
                                              self._time_step_spec)[0]
      policy_state = self.policy.get_initial_state(batch_size)
      action_distribution = self.policy.distribution(
          time_steps, policy_state).action
      if isinstance(action_distribution, tfp.distributions.Normal):
        common.generate_tensor_summaries('act_mean',
                                         action_distribution.loc,
                                         self.train_step_counter)
        common.generate_tensor_summaries('act_stddev',
                                         action_distribution.scale,
                                         self.train_step_counter)
      elif isinstance(action_distribution, tfp.distributions.Categorical):
        common.generate_tensor_summaries('act_mode',
                                         action_distribution.mode(),
                                         self.train_step_counter)
      try:
        common.generate_tensor_summaries('entropy_action',
                                         action_distribution.entropy(),
                                         self.train_step_counter)
      except NotImplementedError:
        pass  # Some distributions do not have an analytic entropy.

    return actor_loss
def critic_loss(self,
                time_steps,
                actions,
                next_time_steps,
                actor_next_time_steps,
                td_errors_loss_fn,
                gamma=1.0,
                reward_scale_factor=1.0,
                weights=None):
  """Computes the critic loss for SAC training.

  Args:
    time_steps: A batch of timesteps for the critic.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps for the critic.
    actor_next_time_steps: A batch of next timesteps for the actor.
    td_errors_loss_fn: A function(td_targets, predictions) to compute
      elementwise (per-batch-entry) loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    critic_loss: A scalar critic loss.
  """
  with tf.name_scope('critic_loss'):
    if self._critic_input_stop_gradient:
      time_steps = tf.nest.map_structure(tf.stop_gradient, time_steps)
      next_time_steps = tf.nest.map_structure(tf.stop_gradient,
                                              next_time_steps)
    # Not really necessary since there is a stop_gradient for the
    # td_targets.
    actor_next_time_steps = tf.nest.map_structure(tf.stop_gradient,
                                                  actor_next_time_steps)

    next_actions_distribution, _ = self._actor_network(
        actor_next_time_steps.observation, actor_next_time_steps.step_type)
    next_actions = next_actions_distribution.sample()
    next_log_pis = next_actions_distribution.log_prob(next_actions)

    target_input_1 = (next_time_steps.observation, next_actions)
    target_q_values1, unused_network_state1 = self._target_critic_network1(
        target_input_1, next_time_steps.step_type)
    target_input_2 = (next_time_steps.observation, next_actions)
    target_q_values2, unused_network_state2 = self._target_critic_network2(
        target_input_2, next_time_steps.step_type)
    target_q_values = (
        tf.minimum(target_q_values1, target_q_values2) -
        tf.exp(self._log_alpha) * next_log_pis)

    td_targets = tf.stop_gradient(
        reward_scale_factor * next_time_steps.reward +
        gamma * next_time_steps.discount * target_q_values)

    pred_input_1 = (time_steps.observation, actions)
    pred_td_targets1, unused_network_state1 = self._critic_network1(
        pred_input_1, time_steps.step_type)
    pred_input_2 = (time_steps.observation, actions)
    pred_td_targets2, unused_network_state2 = self._critic_network2(
        pred_input_2, time_steps.step_type)

    critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
    critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
    critic_loss = critic_loss1 + critic_loss2

    if weights is not None:
      critic_loss *= weights

    # Take the mean across the batch.
    critic_loss = tf.reduce_mean(input_tensor=critic_loss)

    if self._debug_summaries:
      td_errors1 = td_targets - pred_td_targets1
      td_errors2 = td_targets - pred_td_targets2
      td_errors = tf.concat([td_errors1, td_errors2], axis=0)
      common.generate_tensor_summaries('td_errors', td_errors,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_targets', td_targets,
                                       self.train_step_counter)
      common.generate_tensor_summaries('pred_td_targets1', pred_td_targets1,
                                       self.train_step_counter)
      common.generate_tensor_summaries('pred_td_targets2', pred_td_targets2,
                                       self.train_step_counter)

    return critic_loss
def critic_loss(self, experience, weights=None):
  # Check that `experience` includes two outer dimensions [B, T, ...]. This
  # method requires a time dimension to compute the loss properly.
  self._check_trajectory_dimensions(experience)

  if self._n_step_return == 1:
    time_steps, actions, next_time_steps = self._experience_to_transitions(
        experience)
  else:
    # To compute n-step returns, we need the first time steps, the first
    # actions, and the last time steps. Therefore we extract the first and
    # last transitions from our Trajectory.
    first_two_steps = tf.nest.map_structure(lambda x: x[:, :2], experience)
    last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:], experience)
    time_steps, actions, _ = self._experience_to_transitions(
        first_two_steps)
    _, _, next_time_steps = self._experience_to_transitions(last_two_steps)

  with tf.name_scope('critic_loss'):
    tf.nest.assert_same_structure(actions, self.action_spec)
    tf.nest.assert_same_structure(time_steps, self.time_step_spec)
    tf.nest.assert_same_structure(next_time_steps, self.time_step_spec)

    target_actions, _ = self._target_actor_network(
        next_time_steps.observation, next_time_steps.step_type)
    target_critic_network_input = (next_time_steps.observation,
                                   target_actions)
    _, next_distribution, _ = self._target_critic_network(
        target_critic_network_input, next_time_steps.step_type)

    batch_size = next_distribution.shape[0] or tf.shape(
        next_distribution)[0]
    tiled_support = tf.tile(self._support, [batch_size])
    tiled_support = tf.reshape(tiled_support,
                               [batch_size, self._num_atoms])

    if self._n_step_return == 1:
      discount = next_time_steps.discount
      if discount.shape.ndims == 1:
        # We expect discount to have a shape of [batch_size], while
        # tiled_support will have a shape of [batch_size, num_atoms]. To
        # multiply these, we add a second dimension of 1 to the discount.
        discount = discount[:, None]
      next_value_term = tf.multiply(discount, tiled_support,
                                    name='next_value_term')

      reward = next_time_steps.reward
      if reward.shape.ndims == 1:
        # See the explanation above.
        reward = reward[:, None]
      reward_term = tf.multiply(self._reward_scale_factor, reward,
                                name='reward_term')

      target_support = tf.add(reward_term, self._gamma * next_value_term,
                              name='target_support')
    else:
      # TODO: This is not correct when n > 2.
      # When computing discounted return, we need to throw out the last
      # time index of both reward and discount, which are filled with
      # dummy values to match the dimensions of the observation.
      rewards = self._reward_scale_factor * experience.reward[:, :-1]
      discounts = self._gamma * experience.discount[:, :-1]
      # TODO(b/134618876): Properly handle Trajectories that include
      # episode boundaries with nonzero discount.
      discounted_returns = value_ops.discounted_return(
          rewards=rewards,
          discounts=discounts,
          final_value=tf.zeros([batch_size], dtype=discounts.dtype),
          time_major=False,
          provide_all_returns=False)
      # Convert discounted_returns from [batch_size] to [batch_size, 1].
      discounted_returns = discounted_returns[:, None]

      final_value_discount = tf.reduce_prod(discounts, axis=1)
      final_value_discount = final_value_discount[:, None]

      # Save the values of discounted_returns and final_value_discount in
      # order to check them in unit tests.
      self._discounted_returns = discounted_returns
      self._final_value_discount = final_value_discount

      target_support = tf.add(discounted_returns,
                              final_value_discount * tiled_support,
                              name='target_support')

    target_distribution = tf.stop_gradient(
        self._project_distribution(target_support, next_distribution,
                                   self._support))

    logits, distribution, _ = self._critic_network(
        (time_steps.observation, actions), time_steps.step_type)

    cross_entropy_loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.stop_gradient(target_distribution), logits=logits))
    l2_reg_loss = tf.add_n([
        tf.nn.l2_loss(v)
        for v in self._critic_network.trainable_variables
        if 'kernel' in v.name
    ]) * self._critic_l2_lambda
    critic_loss = cross_entropy_loss + l2_reg_loss

    with tf.name_scope('Losses/'):
      tf.compat.v2.summary.scalar('critic_loss', critic_loss,
                                  step=self.train_step_counter)

    if self._debug_summaries:
      distribution_errors = target_distribution - distribution
      with tf.name_scope('distribution_errors'):
        common.generate_tensor_summaries(
            'distribution_errors', distribution_errors,
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'mean', tf.reduce_mean(distribution_errors),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'mean_abs', tf.reduce_mean(tf.abs(distribution_errors)),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'max', tf.reduce_max(distribution_errors),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'min', tf.reduce_min(distribution_errors),
            step=self.train_step_counter)
      with tf.name_scope('target_distribution'):
        common.generate_tensor_summaries(
            'target_distribution', target_distribution,
            step=self.train_step_counter)

    return critic_loss
def _loss(self,
          experience,
          td_errors_loss_fn=tf.losses.huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None):
  """Computes critic loss for CategoricalDQN training.

  See Algorithm 1 and the discussion immediately preceding it on page 6 of
  "A Distributional Perspective on Reinforcement Learning", Bellemare et
  al., 2017, https://arxiv.org/abs/1707.06887.

  Args:
    experience: A batch of experience data in the form of a `Trajectory`.
      The structure of `experience` must match that of
      `self.policy.step_spec`. All tensors in `experience` must be shaped
      `[batch, time, ...]` where `time` must be equal to
      `self.required_experience_time_steps` if that property is not
      `None`.
    td_errors_loss_fn: A function(td_targets, predictions) to compute
      loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional weights used for importance sampling.

  Returns:
    critic_loss: A scalar critic loss.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  # Check that `experience` includes two outer dimensions [B, T, ...]. This
  # method requires a time dimension to compute the loss properly.
  self._check_trajectory_dimensions(experience)

  if self._n_step_update == 1:
    time_steps, actions, next_time_steps = self._experience_to_transitions(
        experience)
  else:
    # To compute n-step returns, we need the first time steps, the first
    # actions, and the last time steps. Therefore we extract the first and
    # last transitions from our Trajectory.
    first_two_steps = tf.nest.map_structure(lambda x: x[:, :2], experience)
    last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:], experience)
    time_steps, actions, _ = self._experience_to_transitions(
        first_two_steps)
    _, _, next_time_steps = self._experience_to_transitions(last_two_steps)

  with tf.name_scope('critic_loss'):
    tf.nest.assert_same_structure(actions, self.action_spec)
    tf.nest.assert_same_structure(time_steps, self.time_step_spec)
    tf.nest.assert_same_structure(next_time_steps, self.time_step_spec)

    rank = nest_utils.get_outer_rank(time_steps.observation,
                                     self._time_step_spec.observation)

    # If inputs have a time dimension and the q_network is stateful,
    # combine the batch and time dimension.
    batch_squash = (None
                    if rank <= 1 or self._q_network.state_spec in ((), None)
                    else utils.BatchSquash(rank))

    # q_logits contains the Q-value logits for all actions.
    q_logits, _ = self._q_network(time_steps.observation,
                                  time_steps.step_type)
    next_q_distribution = self._next_q_distribution(next_time_steps,
                                                    batch_squash)

    if batch_squash is not None:
      # Squash the outer dimensions into a single dimension to facilitate
      # computing the loss below. Required for supporting temporal inputs,
      # for example.
      q_logits = batch_squash.flatten(q_logits)
      actions = batch_squash.flatten(actions)
      next_time_steps = tf.nest.map_structure(batch_squash.flatten,
                                              next_time_steps)

    actions = tf.nest.flatten(actions)[0]
    if actions.shape.ndims > 1:
      actions = tf.squeeze(actions, range(1, actions.shape.ndims))

    # Project the sample Bellman update \hat{T}Z_{\theta} onto the original
    # support of Z_{\theta} (see Figure 1 in paper).
    batch_size = tf.shape(q_logits)[0]
    tiled_support = tf.tile(self._support, [batch_size])
    tiled_support = tf.reshape(tiled_support,
                               [batch_size, self._num_atoms])

    if self._n_step_update == 1:
      discount = next_time_steps.discount
      if discount.shape.ndims == 1:
        # We expect discount to have a shape of [batch_size], while
        # tiled_support will have a shape of [batch_size, num_atoms]. To
        # multiply these, we add a second dimension of 1 to the discount.
        discount = discount[:, None]
      next_value_term = tf.multiply(discount, tiled_support,
                                    name='next_value_term')

      reward = next_time_steps.reward
      if reward.shape.ndims == 1:
        # See the explanation above.
        reward = reward[:, None]
      reward_term = tf.multiply(reward_scale_factor, reward,
                                name='reward_term')

      target_support = tf.add(reward_term, gamma * next_value_term,
                              name='target_support')
    else:
      # When computing discounted return, we need to throw out the last
      # time index of both reward and discount, which are filled with
      # dummy values to match the dimensions of the observation.
      rewards = reward_scale_factor * experience.reward[:, :-1]
      discounts = gamma * experience.discount[:, :-1]

      # TODO(b/134618876): Properly handle Trajectories that include
      # episode boundaries with nonzero discount.
      # TODO(b/131557265): Replace value_ops.discounted_return with a
      # method that only computes the single value needed.
      discounted_rewards = value_ops.discounted_return(
          rewards=rewards,
          discounts=discounts,
          final_value=tf.zeros([batch_size], dtype=discounts.dtype),
          time_major=False)

      # We only need the first value within the time dimension, which
      # corresponds to the full final return. The remaining values are
      # only partial returns.
      discounted_rewards = discounted_rewards[:, :1]

      final_value_discount = tf.reduce_prod(discounts, axis=1)
      final_value_discount = final_value_discount[:, None]

      # Save the values of discounted_rewards and final_value_discount in
      # order to check them in unit tests.
      self._discounted_rewards = discounted_rewards
      self._final_value_discount = final_value_discount

      target_support = tf.add(discounted_rewards,
                              final_value_discount * tiled_support,
                              name='target_support')

    target_distribution = tf.stop_gradient(
        project_distribution(target_support, next_q_distribution,
                             self._support))

    # Obtain the current Q-value logits for the selected actions.
    indices = tf.range(tf.shape(q_logits)[0])[:, None]
    indices = tf.cast(indices, actions.dtype)
    reshaped_actions = tf.concat([indices, actions[:, None]], 1)
    chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)

    # Compute the cross-entropy loss between the logits. If inputs have a
    # time dimension, compute the sum over the time dimension before
    # computing the mean over the batch dimension.
    if batch_squash is not None:
      target_distribution = batch_squash.unflatten(target_distribution)
      chosen_action_logits = batch_squash.unflatten(chosen_action_logits)
      critic_loss = tf.reduce_mean(
          tf.reduce_sum(
              tf.nn.softmax_cross_entropy_with_logits_v2(
                  labels=target_distribution,
                  logits=chosen_action_logits),
              axis=1))
    else:
      critic_loss = tf.reduce_mean(
          tf.nn.softmax_cross_entropy_with_logits_v2(
              labels=target_distribution, logits=chosen_action_logits))

    with tf.name_scope('Losses/'):
      tf.compat.v2.summary.scalar('critic_loss', critic_loss,
                                  step=self.train_step_counter)

    if self._debug_summaries:
      distribution_errors = target_distribution - chosen_action_logits
      with tf.name_scope('distribution_errors'):
        common.generate_tensor_summaries(
            'distribution_errors', distribution_errors,
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'mean', tf.reduce_mean(distribution_errors),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'mean_abs', tf.reduce_mean(tf.abs(distribution_errors)),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'max', tf.reduce_max(distribution_errors),
            step=self.train_step_counter)
        tf.compat.v2.summary.scalar(
            'min', tf.reduce_min(distribution_errors),
            step=self.train_step_counter)
      with tf.name_scope('target_distribution'):
        common.generate_tensor_summaries(
            'target_distribution', target_distribution,
            step=self.train_step_counter)

    # TODO(b/127318640): Give appropriate values for td_loss and td_error
    # for prioritized replay.
    return tf_agent.LossInfo(
        critic_loss, dqn_agent.DqnLossInfo(td_loss=(), td_error=()))
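# `project_distribution` implements the C51 categorical projection used
# above. A minimal dense sketch under stated assumptions: `supports` and
# `weights` are [batch_size, num_atoms], `target_support` is a uniformly
# spaced [num_atoms] vector with spacing delta_z, and each source atom
# spreads its mass linearly onto the two nearest target atoms. The real
# helper may be implemented differently.
def project_distribution_sketch(supports, weights, target_support):
  v_min, v_max = target_support[0], target_support[-1]
  delta_z = target_support[1] - target_support[0]
  clipped = tf.clip_by_value(supports, v_min, v_max)             # [B, n]
  # Triangular kernel: 1 at distance 0, 0 at distance >= delta_z.
  distances = tf.abs(
      clipped[:, None, :] - target_support[None, :, None])       # [B, m, n]
  quotients = tf.clip_by_value(1.0 - distances / delta_z, 0.0, 1.0)
  # Each target atom collects the mass of every nearby source atom.
  return tf.reduce_sum(quotients * weights[:, None, :], axis=-1)  # [B, m]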
def _loss(self,
          experience,
          td_errors_loss_fn=element_wise_huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None):
  """Computes loss for DQN training.

  Args:
    experience: A batch of experience data in the form of a `Trajectory`.
      The structure of `experience` must match that of
      `self.policy.step_spec`. All tensors in `experience` must be shaped
      `[batch, time, ...]` where `time` must be equal to
      `self.train_sequence_length` if that property is not `None`.
    td_errors_loss_fn: A function(td_targets, predictions) to compute the
      element-wise loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights. The output td_loss will be scaled by these weights, and the
      final scalar loss is the mean of these values.

  Returns:
    loss: An instance of `DqnLossInfo`.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  # Check that `experience` includes two outer dimensions [B, T, ...]. This
  # method requires `experience` to include the time dimension.
  self._check_trajectory_dimensions(experience)

  if self._n_step_update == 1:
    time_steps, actions, next_time_steps = self._experience_to_transitions(
        experience)
  else:
    # To compute n-step returns, we need the first time steps, the first
    # actions, and the last time steps. Therefore we extract the first and
    # last transitions from our Trajectory.
    first_two_steps = tf.nest.map_structure(lambda x: x[:, :2], experience)
    last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:], experience)
    time_steps, actions, _ = self._experience_to_transitions(
        first_two_steps)
    _, _, next_time_steps = self._experience_to_transitions(last_two_steps)

  with tf.name_scope('loss'):
    actions = tf.nest.flatten(actions)[0]
    q_values, _ = self._q_network(time_steps.observation,
                                  time_steps.step_type)

    # Handle action_spec.shape=() and shape=(1,) by using the
    # multi_dim_actions param.
    multi_dim_actions = tf.nest.flatten(
        self._action_spec)[0].shape.ndims > 0
    q_values = common.index_with_actions(
        q_values,
        tf.cast(actions, dtype=tf.int32),
        multi_dim_actions=multi_dim_actions)

    next_q_values = self._compute_next_q_values(next_time_steps)

    if self._n_step_update == 1:
      # Special case for n = 1 to avoid a loss of performance.
      td_targets = compute_td_targets(
          next_q_values,
          rewards=reward_scale_factor * next_time_steps.reward,
          discounts=gamma * next_time_steps.discount)
    else:
      # When computing discounted return, we need to throw out the last
      # time index of both reward and discount, which are filled with
      # dummy values to match the dimensions of the observation.
      # TODO(b/131557265): Replace value_ops.discounted_return with a
      # method that only computes the single value needed.
      n_step_return = value_ops.discounted_return(
          rewards=reward_scale_factor * experience.reward[:, :-1],
          discounts=gamma * experience.discount[:, :-1],
          final_value=next_q_values,
          time_major=False)

      # We only need the first value within the time dimension, which
      # corresponds to the full final return. The remaining values are
      # only partial returns.
      td_targets = n_step_return[:, 0]

    valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
    td_error = valid_mask * (td_targets - q_values)
    td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    if weights is not None:
      td_loss *= weights

    # Average across the elements of the batch.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N, where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution in the loss. Think about what would
    # happen as the number of boundary samples increases.
    loss = tf.reduce_mean(input_tensor=td_loss)

    with tf.name_scope('Losses/'):
      tf.compat.v1.summary.scalar(
          'loss_' + self.name, loss,
          collections=['train_' + self.name])  # family=self.name

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._q_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      diff_q_values = q_values - next_q_values
      common.generate_tensor_summaries('td_error', td_error,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_loss', td_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('q_values', q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('next_q_values', next_q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('diff_q_values', diff_q_values,
                                       self.train_step_counter)

    return tf_agent.LossInfo(
        loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
def loss(self,
         time_steps,
         actions,
         next_time_steps,
         td_errors_loss_fn=element_wise_huber_loss,
         gamma=1.0,
         reward_scale_factor=1.0,
         weights=None):
  """Computes loss for DQN training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    td_errors_loss_fn: A function(td_targets, predictions) to compute the
      element-wise loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights. The output td_loss will be scaled by these weights, and the
      final scalar loss is the mean of these values.

  Returns:
    loss: An instance of `DqnLossInfo`.

  Raises:
    ValueError: if the number of actions is greater than 1.
  """
  with tf.name_scope('loss'):
    actions = tf.nest.flatten(actions)[0]
    q_values, _ = self._q_network(time_steps.observation,
                                  time_steps.step_type)

    # Handle action_spec.shape=() and shape=(1,) by using the
    # multi_dim_actions param.
    multi_dim_actions = tf.nest.flatten(
        self._action_spec)[0].shape.ndims > 0
    q_values = common_utils.index_with_actions(
        q_values,
        tf.cast(actions, dtype=tf.int32),
        multi_dim_actions=multi_dim_actions)

    next_q_values = self._compute_next_q_values(next_time_steps)
    td_targets = compute_td_targets(
        next_q_values,
        rewards=reward_scale_factor * next_time_steps.reward,
        discounts=gamma * next_time_steps.discount)

    valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
    td_error = valid_mask * (td_targets - q_values)
    td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    if weights is not None:
      td_loss *= weights

    # Average across the elements of the batch.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N, where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution in the loss. Think about what would
    # happen as the number of boundary samples increases.
    loss = tf.reduce_mean(input_tensor=td_loss)

    with tf.name_scope('Losses/'):
      tf.compat.v2.summary.scalar(
          name='loss', data=loss, step=self.train_step_counter)

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._q_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      diff_q_values = q_values - next_q_values
      common_utils.generate_tensor_summaries('td_error', td_error,
                                             self.train_step_counter)
      common_utils.generate_tensor_summaries('td_loss', td_loss,
                                             self.train_step_counter)
      common_utils.generate_tensor_summaries('q_values', q_values,
                                             self.train_step_counter)
      common_utils.generate_tensor_summaries('next_q_values', next_q_values,
                                             self.train_step_counter)
      common_utils.generate_tensor_summaries('diff_q_values', diff_q_values,
                                             self.train_step_counter)

    return tf_agent.LossInfo(
        loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
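# `element_wise_huber_loss`, the default `td_errors_loss_fn` above, keeps
# the per-entry shape (no reduction) so that `valid_mask` and `weights` can
# be applied before the final mean. A minimal sketch of how such a helper
# can be written; the real implementation may differ in details:
def element_wise_huber_loss_sketch(y_true, y_pred):
  # Reduction.NONE yields one Huber value per element.
  return tf.compat.v1.losses.huber_loss(
      y_true, y_pred, reduction=tf.compat.v1.losses.Reduction.NONE)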
def critic_loss(self,
                time_steps,
                actions,
                next_time_steps,
                td_errors_loss_fn,
                gamma=1.0,
                reward_scale_factor=1.0,
                weights=None):
  """Computes the critic loss for SAC training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    next_time_steps: A batch of next timesteps.
    td_errors_loss_fn: A function(td_targets, predictions) to compute
      elementwise (per-batch-entry) loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.

  Returns:
    critic_loss: A scalar critic loss.
  """
  with tf.name_scope('critic_loss'):
    tf.nest.assert_same_structure(actions, self.action_spec)
    tf.nest.assert_same_structure(time_steps, self.time_step_spec)
    tf.nest.assert_same_structure(next_time_steps, self.time_step_spec)

    next_actions, next_log_pis = self._actions_and_log_probs(next_time_steps)
    target_input = (next_time_steps.observation, next_actions)
    target_q_values = []
    for tcn in self._target_critic_networks:
      target_q_values1, _ = tcn(
          target_input, next_time_steps.step_type, training=False)
      target_q_values.append(target_q_values1)

    # Reduce over the critic ensemble with a percentile of the target values
    # (the commented-out min/entropy variant is kept for reference).
    target_q_values = tfp.stats.percentile(target_q_values,
                                           self._percentile, axis=0)
    # target_q_values = tf.reduce_min(target_q_values)
    # - tf.exp(self._log_alpha) * next_log_pis

    td_targets = tf.stop_gradient(
        reward_scale_factor * next_time_steps.reward +
        gamma * next_time_steps.discount * target_q_values)

    pred_input = (time_steps.observation, actions)
    pred_td_targets = []
    for cn in self._critic_networks:
      pred_td_targets1, _ = cn(pred_input, time_steps.step_type,
                               training=True)
      pred_td_targets.append(pred_td_targets1)

    critic_loss = tf.reduce_mean(
        [td_errors_loss_fn(td_targets, pred_td_target)
         for pred_td_target in pred_td_targets],
        axis=0)

    if weights is not None:
      critic_loss *= weights

    # Take the mean across the batch.
    critic_loss = tf.reduce_mean(input_tensor=critic_loss)

    if self._debug_summaries:
      td_errors = [td_targets - pred_td_target
                   for pred_td_target in pred_td_targets]
      td_errors = tf.concat(td_errors, axis=0)
      common.generate_tensor_summaries('td_errors', td_errors,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_targets', td_targets,
                                       self.train_step_counter)

    return critic_loss
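# --- Illustrative sketch (not part of the agent) ---
# How tfp.stats.percentile reduces the target-critic ensemble above: with
# self._percentile = 0 it recovers the usual min over target critics (as in
# SAC/TD3), while 50 gives the median. Values below are made up.
import tensorflow as tf
import tensorflow_probability as tfp

ensemble_q = tf.constant([[1.0, 3.0, 2.0],    # target critic 1, batch of 3
                          [2.0, 1.0, 4.0]])   # target critic 2
min_like = tfp.stats.percentile(ensemble_q, 0.0, axis=0)   # [1., 1., 2.]
median = tfp.stats.percentile(ensemble_q, 50.0, axis=0,
                              interpolation='linear')      # [1.5, 2., 3.]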
def _loss(self,
          experience,
          td_errors_loss_fn=common.element_wise_huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None,
          training=False):
  """Computes loss for DQN training.

  Args:
    experience: A batch of experience data in the form of a `Trajectory` or
      `Transition`. The structure of `experience` must match that of
      `self.collect_policy.step_spec`. If a `Trajectory`, all tensors in
      `experience` must be shaped `[B, T, ...]` where `T` must be equal to
      `self.train_sequence_length` if that property is not `None`.
    td_errors_loss_fn: A function(td_targets, predictions) to compute the
      element-wise loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights. The output td_loss will be scaled by these weights, and the
      final scalar loss is the mean of these values.
    training: Whether this loss is being used for training.

  Returns:
    loss: An instance of `DqnLossInfo`.

  Raises:
    ValueError: If the number of actions is greater than 1.
  """
  transition = self._as_transition(experience)
  time_steps, policy_steps, next_time_steps = transition
  actions = policy_steps.action

  with tf.name_scope('loss'):
    q_values = self._compute_q_values(time_steps, actions, training=training)

    next_q_values = self._compute_next_q_values(
        next_time_steps, policy_steps.info)

    # This applies to any value of n_step_update and also in the RNN-DQN
    # case. In the RNN-DQN case, inputs and outputs contain a time dimension.
    td_targets = compute_td_targets(
        next_q_values,
        rewards=reward_scale_factor * next_time_steps.reward,
        discounts=gamma * next_time_steps.discount)

    valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
    td_error = valid_mask * (td_targets - q_values)
    td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

    if nest_utils.is_batched_nested_tensors(
        time_steps, self.time_step_spec, num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    # Aggregate across the elements of the batch and add regularization loss.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution to the loss. Think about what would happen
    # as the number of boundary samples increases.
    agg_loss = common.aggregate_losses(
        per_example_loss=td_loss,
        sample_weight=weights,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    losses_dict = {'td_loss': agg_loss.weighted,
                   'reg_loss': agg_loss.regularization,
                   'total_loss': total_loss}

    common.summarize_scalar_dict(losses_dict,
                                 step=self.train_step_counter,
                                 name_scope='Losses/')

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._q_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      diff_q_values = q_values - next_q_values
      common.generate_tensor_summaries('td_error', td_error,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_loss', td_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('q_values', q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('next_q_values', next_q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('diff_q_values', diff_q_values,
                                       self.train_step_counter)

    return tf_agent.LossInfo(total_loss, DqnLossInfo(td_loss=td_loss,
                                                     td_error=td_error))
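# --- Illustrative sketch (not part of the agent) ---
# What the aggregation step above is assumed to compute: a (weighted) mean of
# the per-example TD losses plus the summed regularization losses. This only
# mirrors common.aggregate_losses under that assumption (it ignores the
# distribution-strategy batch handling); values are made up.
import tensorflow as tf

per_example_loss = tf.constant([0.2, 0.4, 0.6, 0.8])
sample_weight = tf.constant([1.0, 0.0, 1.0, 1.0])  # e.g. PER importance weights
regularization_losses = [tf.constant(0.01), tf.constant(0.02)]

weighted = tf.reduce_mean(per_example_loss * sample_weight)  # 0.4
regularization = tf.add_n(regularization_losses)             # 0.03
total_loss = weighted + regularization                       # 0.43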
def critic_loss(self,
                time_steps,
                actions,
                alphas,
                next_time_steps,
                weights=None):
  """Computes the critic loss for DDPG training.

  Args:
    time_steps: A batch of timesteps.
    actions: A batch of actions.
    alphas: A batch of CVaR confidence levels fed to the actor and critic
      networks.
    next_time_steps: A batch of next timesteps.
    weights: Optional scalar or element-wise (per-batch-entry) importance
      weights.

  Returns:
    critic_loss: A scalar critic loss.
  """
  with tf.name_scope('critic_loss'):
    target_actions, _ = self._target_actor_network(
        (next_time_steps.observation, alphas), next_time_steps.step_type)

    next_target_critic_net_input = (next_time_steps.observation,
                                    target_actions, alphas)
    next_target_Z, _ = self._target_critic_network(
        next_target_critic_net_input, next_time_steps.step_type)
    next_target_means = tf.reshape(next_target_Z.loc, [-1])
    next_target_vars = tf.reshape(next_target_Z.scale, [-1])

    target_critic_net_input = (time_steps.observation, actions, alphas)
    target_Z, _ = self._target_critic_network(
        target_critic_net_input, next_time_steps.step_type)
    target_means = tf.reshape(target_Z.loc, [-1])

    if len(next_target_means.shape) != 1:
      raise ValueError(
          'Q-network should output a tensor of shape (batch,) '
          'but shape {} was returned.'.format(
              next_target_means.shape.as_list()))
    if len(target_means.shape) != 1:
      raise ValueError(
          'Q-network should output a tensor of shape (batch,) '
          'but shape {} was returned.'.format(
              target_means.shape.as_list()))

    td_mean_target = tf.stop_gradient(
        self._reward_scale_factor * next_time_steps.reward +
        self._gamma * next_time_steps.discount * next_target_means)

    # Refer to Eq. 6 in WCPG.
    td_var_target = tf.stop_gradient(
        (self._reward_scale_factor * next_time_steps.reward)**2 +
        2 * self._gamma * next_time_steps.discount *
        next_time_steps.reward * next_target_means +
        next_time_steps.discount * self._gamma**2 * next_target_vars +
        self._gamma**2 * next_target_means**2 - target_means**2)

    tf.debugging.check_numerics(target_means, 'target means is inf or nan.')
    tf.debugging.check_numerics(next_target_means,
                                'next target means is inf or nan.')
    tf.debugging.check_numerics(td_var_target, 'target var is inf or nan.')

    critic_net_input = (time_steps.observation, actions, alphas)
    Z, _ = self._critic_network(critic_net_input, time_steps.step_type)
    q_means = tf.reshape(Z.loc, [-1])
    q_vars = tf.reshape(Z.scale, [-1])
    # tf.print('q_mean:', q_means, 'target q_mean:', next_target_means,
    #          output_stream=tf.logging.info)
    # tf.print('q_var:', q_vars, 'target q_var:', next_target_vars,
    #          output_stream=tf.logging.info)

    mean_td_error = self._td_errors_loss_fn(td_mean_target, q_means)
    # var_td_error = tf.sqrt(self._td_errors_loss_fn(td_var_target, q_vars))
    var_td_error = td_var_target + q_vars - 2 * tf.sqrt(
        tf.abs(td_var_target * q_vars))
    critic_loss = mean_td_error + var_td_error

    if nest_utils.is_batched_nested_tensors(time_steps, self.time_step_spec,
                                            num_outer_dims=2):
      # Do a sum over the time dimension.
      critic_loss = tf.reduce_sum(critic_loss, axis=1)

    if weights is not None:
      critic_loss *= weights
    critic_loss = tf.reduce_mean(critic_loss)

    with tf.name_scope('Losses/'):
      tf.compat.v2.summary.scalar(name='critic_loss',
                                  data=critic_loss,
                                  step=self.train_step_counter)

    if self._debug_summaries:
      mean_td_errors = td_mean_target - q_means
      var_td_errors = td_var_target - q_vars
      common.generate_tensor_summaries('target_means', target_means,
                                       self.train_step_counter)
      common.generate_tensor_summaries('next_target_vars', next_target_vars,
                                       self.train_step_counter)
      common.generate_tensor_summaries('next_target_means',
                                       next_target_means,
                                       self.train_step_counter)
      common.generate_tensor_summaries('mean_td_errors', mean_td_errors,
                                       self.train_step_counter)
      common.generate_tensor_summaries('var_td_errors', var_td_errors,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_mean_targets', td_mean_target,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_var_targets', td_var_target,
                                       self.train_step_counter)
      common.generate_tensor_summaries('q_mean', q_means,
                                       self.train_step_counter)
      common.generate_tensor_summaries('q_var', q_vars,
                                       self.train_step_counter)

    return critic_loss, tf.reduce_mean(mean_td_error), tf.reduce_mean(
        var_td_error)
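# --- Illustrative sketch (not part of the agent) ---
# Scalar version of the variance target (Eq. 6 in WCPG) used above, with
# made-up numbers: reward r, discount gamma, episode discount d, next-state
# target mean m_next and variance v_next, and current target mean m.
r, gamma, d = 1.0, 0.99, 1.0
m_next, v_next, m = 2.0, 0.5, 2.9

td_var_target = (r**2
                 + 2.0 * gamma * d * r * m_next
                 + d * gamma**2 * v_next
                 + gamma**2 * m_next**2
                 - m**2)  # ~0.960

# The variance TD error above is td_var + q_var - 2*sqrt(|td_var * q_var|);
# for non-negative inputs this equals (sqrt(td_var) - sqrt(q_var))**2, i.e. a
# squared distance between standard deviations.
q_var = 0.8
var_td_error = td_var_target + q_var - 2.0 * abs(td_var_target * q_var)**0.5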
def _loss(self,
          experience,
          td_errors_loss_fn=common.element_wise_huber_loss,
          gamma=1.0,
          reward_scale_factor=1.0,
          weights=None,
          training=False):
  """Computes loss for Munchausen-DQN training.

  Args:
    experience: A batch of experience data in the form of a `Trajectory` or
      `Transition`.
    td_errors_loss_fn: A function(td_targets, predictions) to compute the
      element-wise loss.
    gamma: Discount for future rewards.
    reward_scale_factor: Multiplicative factor to scale rewards.
    weights: Optional scalar or elementwise (per-batch-entry) importance
      weights.
    training: Whether this loss is being used for training.

  Returns:
    loss: An instance of `DqnLossInfo`.
  """
  transition = self._as_transition(experience)
  time_steps, policy_steps, next_time_steps = transition
  actions = policy_steps.action
  valid_mask = tf.cast(~time_steps.is_last(), tf.float32)

  with tf.name_scope('loss'):
    # q_values is already gathered by actions.
    q_values = self._compute_q_values(time_steps, actions, training=training)

    next_q_values = self._compute_next_all_q_values(
        next_time_steps, policy_steps.info)
    q_target_values = self._compute_next_all_q_values(
        time_steps, policy_steps.info)

    # This applies to any value of n_step_update and also in the RNN-DQN
    # case. In the RNN-DQN case, inputs and outputs contain a time dimension.
    # td_targets = compute_td_targets(
    #     next_q_values,
    #     rewards=reward_scale_factor * next_time_steps.reward,
    #     discounts=gamma * next_time_steps.discount)
    td_targets = compute_munchausen_td_targets(
        next_q_values=next_q_values,
        q_target_values=q_target_values,
        actions=actions,
        rewards=reward_scale_factor * next_time_steps.reward,
        discounts=gamma * next_time_steps.discount,
        multi_dim_actions=self._action_spec.shape.rank > 0,
        alpha=self.alpha,
        entropy_tau=self.entropy_tau)

    td_error = valid_mask * (td_targets - q_values)
    td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

    if nest_utils.is_batched_nested_tensors(time_steps, self.time_step_spec,
                                            num_outer_dims=2):
      # Do a sum over the time dimension.
      td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

    # Aggregate across the elements of the batch and add regularization loss.
    # Note: We use an element-wise loss above to ensure each element is
    # always weighted by 1/N where N is the batch size, even when some of
    # the weights are zero due to boundary transitions. Weighting by 1/K,
    # where K is the actual number of non-zero weights, would artificially
    # increase their contribution to the loss. Think about what would happen
    # as the number of boundary samples increases.
    agg_loss = common.aggregate_losses(
        per_example_loss=td_loss,
        sample_weight=weights,
        regularization_loss=self._q_network.losses)
    total_loss = agg_loss.total_loss

    losses_dict = {
        'td_loss': agg_loss.weighted,
        'reg_loss': agg_loss.regularization,
        'total_loss': total_loss
    }

    common.summarize_scalar_dict(losses_dict,
                                 step=self.train_step_counter,
                                 name_scope='Losses/')

    if self._summarize_grads_and_vars:
      with tf.name_scope('Variables/'):
        for var in self._q_network.trainable_weights:
          tf.compat.v2.summary.histogram(
              name=var.name.replace(':', '_'),
              data=var,
              step=self.train_step_counter)

    if self._debug_summaries:
      diff_q_values = q_values - next_q_values
      common.generate_tensor_summaries('td_error', td_error,
                                       self.train_step_counter)
      common.generate_tensor_summaries('td_loss', td_loss,
                                       self.train_step_counter)
      common.generate_tensor_summaries('q_values', q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('next_q_values', next_q_values,
                                       self.train_step_counter)
      common.generate_tensor_summaries('diff_q_values', diff_q_values,
                                       self.train_step_counter)

    return tf_agent.LossInfo(
        total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
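# --- Illustrative sketch (not part of the agent) ---
# What compute_munchausen_td_targets is assumed to compute, following the
# Munchausen-DQN paper: a scaled log-policy bonus on the taken action plus a
# soft (entropy-regularized) value of the next state under the target
# network. The helper below is hypothetical, not the actual implementation
# (e.g. it omits the paper's clipping of the log-policy bonus), and assumes
# scalar integer actions.
import tensorflow as tf

def munchausen_td_targets_sketch(next_q_values, q_target_values, actions,
                                 rewards, discounts, alpha, entropy_tau):
  # Soft-max policy of the target network at the current state: [B, A].
  log_pi = tf.nn.log_softmax(q_target_values / entropy_tau, axis=-1)
  # Log-policy bonus for the actions actually taken: [B].
  bonus = alpha * entropy_tau * tf.gather(log_pi, actions, batch_dims=1)

  # Soft value of the next state: E_pi'[q' - tau * log pi'].
  next_log_pi = tf.nn.log_softmax(next_q_values / entropy_tau, axis=-1)
  soft_next_v = tf.reduce_sum(
      tf.exp(next_log_pi) * (next_q_values - entropy_tau * next_log_pi),
      axis=-1)

  return tf.stop_gradient(rewards + bonus + discounts * soft_next_v)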
def model_loss(self,
               images,
               actions,
               step_types,
               rewards,
               discounts,
               latent_posterior_samples_and_dists=None,
               weights=None):
  """Computes the latent-variable model loss for SLAC training."""
  with tf.name_scope('model_loss'):
    if self._model_batch_size is not None:
      # Allow model batch size to be smaller than the batch size of the
      # other losses. This is because the model loss already gets a lot of
      # supervision from having a loss over all time steps.
      images, actions, step_types, rewards, discounts = tf.nest.map_structure(
          lambda x: x[:self._model_batch_size],
          (images, actions, step_types, rewards, discounts))
      if latent_posterior_samples_and_dists is not None:
        (latent_posterior_samples,
         latent_posterior_dists) = latent_posterior_samples_and_dists
        latent_posterior_samples = tf.nest.map_structure(
            lambda x: x[:self._model_batch_size], latent_posterior_samples)
        latent_posterior_dists = slac_nest_utils.map_distribution_structure(
            lambda x: x[:self._model_batch_size], latent_posterior_dists)
        latent_posterior_samples_and_dists = (latent_posterior_samples,
                                              latent_posterior_dists)

    model_loss, outputs = self._model_network.compute_loss(
        images,
        actions,
        step_types,
        rewards=rewards,
        discounts=discounts,
        latent_posterior_samples_and_dists=latent_posterior_samples_and_dists)

    for name, output in outputs.items():
      if output.shape.ndims == 0:
        # Use the TF2 summary API (tf.contrib is no longer available).
        tf.compat.v2.summary.scalar(
            name=name, data=output, step=self.train_step_counter)
      elif output.shape.ndims == 5:
        fps = 10 if self._control_timestep is None else int(
            np.round(1.0 / self._control_timestep))
        if self._debug_summaries:
          _gif_summary(name + '/original',
                       output[:self._num_images_per_summary],
                       fps,
                       step=self.train_step_counter)
        _gif_summary(name,
                     output[:self._num_images_per_summary],
                     fps,
                     saturate=True,
                     step=self.train_step_counter)
      else:
        raise NotImplementedError

    if weights is not None:
      model_loss *= weights

    model_loss = tf.reduce_mean(input_tensor=model_loss)

    if self._debug_summaries:
      common.generate_tensor_summaries('model_loss', model_loss,
                                       self.train_step_counter)

    return model_loss
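# --- Illustrative sketch (not part of the agent) ---
# The batch-slicing trick above: the model loss can run on a smaller leading
# batch than the actor/critic losses, since it is supervised at every time
# step. Shapes here are hypothetical.
import tensorflow as tf

model_batch_size = 32
images = tf.zeros([256, 10, 64, 64, 3])  # [B, T, H, W, C]
actions = tf.zeros([256, 10, 4])         # [B, T, action_dim]

images, actions = tf.nest.map_structure(
    lambda x: x[:model_batch_size], (images, actions))
# images.shape -> [32, 10, 64, 64, 3]; actions.shape -> [32, 10, 4]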