Exemple #1
0
  def alpha_loss(self, time_steps, weights=None):
    """Returns the scalar temperature (alpha) loss used by EC-SAC.

    The per-entry loss is `log_alpha * stop_gradient(-log_pi - target_entropy)`,
    so only `log_alpha` receives gradients. It is optionally scaled by
    `weights` and then averaged to a scalar.

    Args:
      time_steps: A batch of timesteps; must match `self.time_step_spec`.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights multiplied into the per-entry loss before averaging.

    Returns:
      alpha_loss: A scalar alpha loss.
    """
    with tf.name_scope('alpha_loss'):
      tf.nest.assert_same_structure(time_steps, self.time_step_spec)

      _, log_pi = self._actions_and_log_probs(time_steps)
      # The entropy gap is treated as a constant w.r.t. the policy.
      entropy_gap = tf.stop_gradient(-log_pi - self._target_entropy)
      per_entry_loss = self._log_alpha * entropy_gap
      if weights is not None:
        per_entry_loss = per_entry_loss * weights
      loss = tf.reduce_mean(input_tensor=per_entry_loss)

      if self._debug_summaries:
        common.generate_tensor_summaries(
            'alpha_loss', loss, self.train_step_counter)
      return loss
Exemple #2
0
    def actor_loss(self, time_steps, alphas, weights=None):
        """Computes the CVaR-based actor loss.

        The actor maps `(observation, alphas)` to actions. The critic scores
        `(observation, actions, alphas)`, and the `loc`/`scale` of its output
        are flattened and passed to `self._compute_cvar`. The loss is the
        negated mean of the (optionally weighted) CVaR values.

        Args:
          time_steps: A batch of timesteps.
          alphas: Risk levels fed to the actor network, the critic network
            and `self._compute_cvar`.
          weights: Optional scalar or element-wise (per-batch-entry) importance
            weights, multiplied into the CVaR values before averaging.

        Returns:
          actor_loss: A scalar actor loss.
        """
        with tf.name_scope('actor_loss'):
            actions, _ = self._actor_network((time_steps.observation, alphas),
                                             time_steps.step_type)
            q, _ = self._critic_network(
                (time_steps.observation, actions, alphas),
                time_steps.step_type)
            # NOTE(review): assumes the critic output exposes `loc`/`scale`
            # (a location-scale distribution); both are flattened to 1-D.
            q_means = tf.reshape(q.loc, [-1])
            q_vars = tf.reshape(q.scale, [-1])

            cvar = self._compute_cvar(q_means, q_vars, alphas)
            # Previously `weights` was accepted but ignored. Apply it like the
            # other losses do. NOTE(review): assumes `cvar` is per-batch-entry
            # so per-entry weights broadcast against it — confirm.
            weighted_cvar = cvar if weights is None else cvar * weights
            actor_loss = -tf.reduce_mean(weighted_cvar)

            with tf.name_scope('Losses/'):
                tf.compat.v2.summary.scalar(name='actor_loss',
                                            data=actor_loss,
                                            step=self.train_step_counter)

            if self._debug_summaries:
                common.generate_tensor_summaries('cvar', cvar,
                                                 self.train_step_counter)

        return actor_loss
Exemple #3
0
    def alpha_loss(self, actor_time_steps, weights=None):
        """Returns the EC-SAC temperature (alpha) loss.

        An action is sampled from the actor's output distribution and scored
        under that same distribution. The resulting log-probabilities drive
        `log_alpha * stop_gradient(-log_pi - target_entropy)`.

        Args:
          actor_time_steps: A batch of timesteps for the actor.
          weights: Optional scalar or elementwise (per-batch-entry) importance
            weights.

        Returns:
          alpha_loss: A scalar alpha loss.
        """
        with tf.name_scope('alpha_loss'):
            policy_dist, _ = self._actor_network(
                actor_time_steps.observation, actor_time_steps.step_type)
            sampled = policy_dist.sample()
            log_pis = policy_dist.log_prob(sampled)
            entropy_gap = tf.stop_gradient(-log_pis - self._target_entropy)
            per_entry = self._log_alpha * entropy_gap
            if weights is not None:
                per_entry = per_entry * weights
            result = tf.reduce_mean(input_tensor=per_entry)

            if self._debug_summaries:
                common.generate_tensor_summaries('alpha_loss', result,
                                                 self.train_step_counter)
            return result
Exemple #4
0
    def _loss(self, experience, weights=None):
        """Computes loss for behavioral cloning.

        The cloning network's output and the recorded action go through
        `self._loss_fn`. Entries at episode boundaries are zeroed, and a time
        axis (if present) is summed out. Optional `weights` are then applied
        and the result is averaged over the batch.

        Args:
          experience: A `Trajectory` containing experience.
          weights: Optional scalar or element-wise (per-batch-entry) importance
            weights.

        Returns:
          loss: A `LossInfo` struct.

        Raises:
          ValueError:
            If `experience.action` does not contain exactly one action.
        """
        # Enforce the documented single-action contract instead of silently
        # using only the first flattened action (or failing with IndexError).
        flat_actions = tf.nest.flatten(experience.action)
        if len(flat_actions) != 1:
            raise ValueError(
                'Behavioral cloning supports exactly one action; got {}.'
                .format(len(flat_actions)))

        with tf.name_scope('loss'):
            actions = flat_actions[0]
            logits, _ = self._cloning_network(experience.observation,
                                              experience.step_type)

            # Zero out the loss on boundary transitions.
            boundary_weights = tf.cast(~experience.is_boundary(), logits.dtype)
            error = boundary_weights * self._loss_fn(logits, actions)

            if nest_utils.is_batched_nested_tensors(experience.action,
                                                    self.action_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                error = tf.reduce_sum(input_tensor=error, axis=1)

            # Average across the elements of the batch.
            # Note: We use an element wise loss above to ensure each element is always
            #   weighted by 1/N where N is the batch size, even when some of the
            #   weights are zero due to boundary transitions. Weighting by 1/K where K
            #   is the actual number of non-zero weight would artificially increase
            #   their contribution in the loss. Think about what would happen as
            #   the number of boundary samples increases.
            if weights is not None:
                error *= weights
            loss = tf.reduce_mean(input_tensor=error)

            with tf.name_scope('Losses/'):
                tf.compat.v2.summary.scalar(name='loss',
                                            data=loss,
                                            step=self.train_step_counter)

            if self._summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in self._cloning_network.trainable_weights:
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self._debug_summaries:
                common.generate_tensor_summaries('errors', error,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(loss,
                                     BehavioralCloningLossInfo(loss=error))
Exemple #5
0
  def _alpha_loss_debug_summaries(self, alpha_loss, entropy_diff):
    """Writes alpha-loss debug summaries when `self._debug_summaries` is set.

    Args:
      alpha_loss: The alpha loss tensor to summarize.
      entropy_diff: The entropy-difference tensor to summarize.
    """
    if not self._debug_summaries:
      return
    step = self.train_step_counter
    common.generate_tensor_summaries('alpha_loss', alpha_loss, step)
    common.generate_tensor_summaries('entropy_diff', entropy_diff, step)
    tf.compat.v2.summary.scalar(
        name='log_alpha', data=self._log_alpha, step=step)
 def _critic_loss_debug_summaries(self, td_targets, pred_td_targets1,
                                  pred_td_targets2, weights):
   """Writes critic-loss debug summaries when `self._debug_summaries` is set.

   The TD errors of both critics are concatenated along axis 0 and
   summarized together with the inputs.

   Args:
     td_targets: TD target tensor.
     pred_td_targets1: Predictions of the first critic.
     pred_td_targets2: Predictions of the second critic.
     weights: The `weights` tensor to summarize.
   """
   if not self._debug_summaries:
     return
   step = self.train_step_counter
   td_errors = tf.concat(
       [td_targets - pred_td_targets1, td_targets - pred_td_targets2], axis=0)
   named_tensors = (
       ('td_errors', td_errors),
       ('td_targets', td_targets),
       ('pred_td_targets1', pred_td_targets1),
       ('pred_td_targets2', pred_td_targets2),
       ('weights', weights),
   )
   for tag, tensor in named_tensors:
     common.generate_tensor_summaries(tag, tensor, step)
Exemple #7
0
 def critic_loss(
     self,
     time_steps,
     actions,
     next_time_steps,
     td_errors_loss_fn,
     gamma=1.0,
     reward_scale_factor=1.0,
     weights=None,
     training=False,
     delta_r_scale=1.0,
     delta_r_warmup=0,
 ):
     """Adds a classifier-based reward correction, then defers to the parent.

     `self._classifier` is run on the concatenation of `(observation, action,
     next observation)` and returns two probability tensors (`sa_probs`,
     `sas_probs`). `delta_r` is the log-odds of column 1 vs column 0 under
     `sas_probs`. Unless `self._unnormalized_delta_r` is set, the same
     quantity under `sa_probs` is subtracted. `delta_r_scale * delta_r` is
     added to the next-step reward, except while
     `train_step_counter < delta_r_warmup`. Then the parent `critic_loss`
     is called.

     Args:
       time_steps: A batch of timesteps.
       actions: A batch of actions.
       next_time_steps: A batch of next timesteps.
       td_errors_loss_fn: Forwarded to the parent critic loss.
       gamma: Forwarded to the parent critic loss.
       reward_scale_factor: Forwarded to the parent critic loss.
       weights: Forwarded to the parent critic loss.
       training: Forwarded to the parent critic loss. The classifier is
         always evaluated with `training=False`.
       delta_r_scale: Multiplier applied to `delta_r` once warmup is over.
       delta_r_warmup: Train step before which `delta_r` is not applied.

     Returns:
       Whatever the parent `critic_loss` returns.
     """
     classifier_input = tf.concat(
         [time_steps.observation, actions, next_time_steps.observation],
         axis=-1)
     # Set training=False so no input noise is added.
     sa_probs, sas_probs = self._classifier(classifier_input, training=False)
     # NOTE(review): no epsilon here; a zero probability yields -inf.
     sas_log_probs = tf.math.log(sas_probs)
     sa_log_probs = tf.math.log(sa_probs)
     sas_log_odds = sas_log_probs[:, 1] - sas_log_probs[:, 0]
     if self._unnormalized_delta_r:
         # Ablation: skip subtracting the (s, a) classifier's log-odds.
         delta_r = sas_log_odds
     else:
         # Default (correct) version: subtract the (s, a) log-odds.
         delta_r = sas_log_odds - sa_log_probs[:, 1] + sa_log_probs[:, 0]
     common.generate_tensor_summaries("delta_r", delta_r,
                                      self.train_step_counter)

     # 1.0 while still warming up (step < delta_r_warmup), else 0.0.
     is_warmup = tf.cast(self.train_step_counter < delta_r_warmup,
                         tf.float32)
     tf.compat.v2.summary.scalar(name="is_warmup",
                                 data=is_warmup,
                                 step=self.train_step_counter)
     reward_bonus = delta_r_scale * (1 - is_warmup) * delta_r
     adjusted_next_time_steps = next_time_steps._replace(
         reward=next_time_steps.reward + reward_bonus)
     return super(DarcAgent, self).critic_loss(
         time_steps,
         actions,
         adjusted_next_time_steps,
         td_errors_loss_fn,
         gamma=gamma,
         reward_scale_factor=reward_scale_factor,
         weights=weights,
         training=training,
     )
Exemple #8
0
    def alpha_loss(self, time_steps, weights=None):
        """Returns the EC-SAC temperature (alpha) loss.

        Per step the loss is `log_alpha * stop_gradient(-log_pi -
        target_entropy)`. If `time_steps` carries a time axis (two outer
        dims), it is summed out before weighting and batch averaging.

        Args:
          time_steps: A batch of timesteps.
          weights: Optional scalar or elementwise (per-batch-entry) importance
            weights.

        Returns:
          alpha_loss: A scalar alpha loss.
        """
        with tf.name_scope('alpha_loss'):
            tf.nest.assert_same_structure(time_steps, self.time_step_spec)

            _, log_pi = self._actions_and_log_probs(time_steps)
            entropy_diff = tf.stop_gradient(-log_pi - self._target_entropy)
            loss_terms = self._log_alpha * entropy_diff

            has_time_axis = nest_utils.is_batched_nested_tensors(
                time_steps, self.time_step_spec, num_outer_dims=2)
            if has_time_axis:
                loss_terms = tf.reduce_sum(input_tensor=loss_terms, axis=1)
            if weights is not None:
                loss_terms = loss_terms * weights
            loss = tf.reduce_mean(input_tensor=loss_terms)

            if self._debug_summaries:
                step = self.train_step_counter
                common.generate_tensor_summaries('alpha_loss', loss, step)
                common.generate_tensor_summaries('entropy_diff', entropy_diff,
                                                 step)
                tf.compat.v2.summary.scalar(name='log_alpha',
                                            data=self._log_alpha,
                                            step=step)

            return loss
  def critic_loss(self,
                  time_steps,
                  actions,
                  next_time_steps,
                  weights=None,
                  training=False):
    """Returns the scalar DDPG critic (TD) loss.

    Targets come from the target actor and target critic evaluated on
    `next_time_steps` with `training=False`:
    `reward_scale_factor * reward + gamma * discount * Q_target`, held
    constant via `tf.stop_gradient`. A time axis, if present, is summed
    before weighting and batch averaging.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      weights: Optional scalar or element-wise (per-batch-entry) importance
        weights.
      training: Passed to the online critic network call.

    Returns:
      critic_loss: A scalar critic loss.
    """
    with tf.name_scope('critic_loss'):
      next_obs = next_time_steps.observation
      next_step_type = next_time_steps.step_type
      target_actions, _ = self._target_actor_network(
          next_obs, next_step_type, training=False)
      target_q_values, _ = self._target_critic_network(
          (next_obs, target_actions), next_step_type, training=False)

      td_targets = tf.stop_gradient(
          self._reward_scale_factor * next_time_steps.reward +
          self._gamma * next_time_steps.discount * target_q_values)

      q_values, _ = self._critic_network(
          (time_steps.observation, actions), time_steps.step_type,
          training=training)

      loss_terms = self._td_errors_loss_fn(td_targets, q_values)
      if nest_utils.is_batched_nested_tensors(
          time_steps, self.time_step_spec, num_outer_dims=2):
        # Collapse the time axis by summing it.
        loss_terms = tf.reduce_sum(loss_terms, axis=1)
      if weights is not None:
        loss_terms = loss_terms * weights
      loss = tf.reduce_mean(loss_terms)

      with tf.name_scope('Losses/'):
        tf.compat.v2.summary.scalar(
            name='critic_loss', data=loss, step=self.train_step_counter)

      if self._debug_summaries:
        debug_tensors = (('td_errors', td_targets - q_values),
                         ('td_targets', td_targets),
                         ('q_values', q_values))
        for tag, tensor in debug_tensors:
          common.generate_tensor_summaries(tag, tensor,
                                           self.train_step_counter)

      return loss
Exemple #10
0
 def _critic_no_entropy_loss_debug_summaries(self, td_targets,
                                             pred_td_targets1,
                                             pred_td_targets2):
   """Writes no-entropy-critic debug summaries when debugging is enabled.

   Both critics' TD errors are concatenated along axis 0. Every summary
   tag carries the `_no_entropy_critic` suffix.

   Args:
     td_targets: TD target tensor.
     pred_td_targets1: Predictions of the first critic.
     pred_td_targets2: Predictions of the second critic.
   """
   if not self._debug_summaries:
     return
   step = self.train_step_counter
   stacked_errors = tf.concat(
       [td_targets - pred_td_targets1, td_targets - pred_td_targets2], axis=0)
   for tag, tensor in (
       ('td_errors_no_entropy_critic', stacked_errors),
       ('td_targets_no_entropy_critic', td_targets),
       ('pred_td_targets1_no_entropy_critic', pred_td_targets1),
       ('pred_td_targets2_no_entropy_critic', pred_td_targets2),
   ):
     common.generate_tensor_summaries(tag, tensor, step)
Exemple #11
0
  def critic_loss(self,
                  time_steps,
                  actions,
                  next_time_steps,
                  weights=None):
    """Computes the critic loss for DDPG training.

    TD targets are `reward_scale_factor * reward + gamma * discount *
    Q_target(next_obs, actor_target(next_obs))`, held constant via
    `tf.stop_gradient`. A time axis, if present, is summed out before
    weighting and batch averaging.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      weights: Optional scalar or element-wise (per-batch-entry) importance
        weights, applied after any time dimension is summed out. The default
        None leaves the loss unweighted, exactly as before.

    Returns:
      critic_loss: A scalar critic loss.
    """
    with tf.name_scope('critic_loss'):
      target_actions, _ = self._target_actor_network(
          next_time_steps.observation, next_time_steps.step_type)
      target_q_values, _ = self._target_critic_network(
          next_time_steps.observation, target_actions,
          next_time_steps.step_type)

      td_targets = tf.stop_gradient(
          self._reward_scale_factor * next_time_steps.reward +
          self._gamma * next_time_steps.discount * target_q_values)

      q_values, _ = self._critic_network(time_steps.observation, actions,
                                         time_steps.step_type)

      critic_loss = self._td_errors_loss_fn(td_targets, q_values)
      if nest_utils.is_batched_nested_tensors(
          time_steps, self.time_step_spec(), num_outer_dims=2):
        # Do a sum over the time dimension.
        critic_loss = tf.reduce_sum(critic_loss, axis=1)
      # Weights are per-batch-entry, so apply them only after the time sum.
      if weights is not None:
        critic_loss *= weights
      critic_loss = tf.reduce_mean(critic_loss)

      with tf.name_scope('Losses/'):
        tf.contrib.summary.scalar('critic_loss', critic_loss)

      if self._debug_summaries:
        td_errors = td_targets - q_values
        common_utils.generate_tensor_summaries('td_errors', td_errors)
        common_utils.generate_tensor_summaries('td_targets', td_targets)
        common_utils.generate_tensor_summaries('q_values', q_values)

      return critic_loss
Exemple #12
0
    def actor_loss(self, time_steps, weights=None):
        """Returns the scalar SAC actor (policy) loss.

        Per step the loss is `exp(log_alpha) * log_pi - min(Q1, Q2)`. Both
        critics are evaluated with `training=False` on actions sampled by
        `self._actions_and_log_probs`. A time axis, if present, is summed out
        before weighting and batch averaging.

        Args:
          time_steps: A batch of timesteps.
          weights: Optional scalar or elementwise (per-batch-entry) importance
            weights.

        Returns:
          actor_loss: A scalar actor loss.
        """
        with tf.name_scope('actor_loss'):
            tf.nest.assert_same_structure(time_steps, self.time_step_spec)

            actions, log_pi = self._actions_and_log_probs(time_steps)
            critic_input = (time_steps.observation, actions)
            step_type = time_steps.step_type
            q1, _ = self._critic_network_1(critic_input, step_type,
                                           training=False)
            q2, _ = self._critic_network_2(critic_input, step_type,
                                           training=False)
            target_q_values = tf.minimum(q1, q2)
            loss_terms = tf.exp(self._log_alpha) * log_pi - target_q_values

            has_time_axis = nest_utils.is_batched_nested_tensors(
                time_steps, self.time_step_spec, num_outer_dims=2)
            if has_time_axis:
                loss_terms = tf.reduce_sum(input_tensor=loss_terms, axis=1)
            if weights is not None:
                loss_terms = loss_terms * weights
            loss = tf.reduce_mean(input_tensor=loss_terms)

            if self._debug_summaries:
                step = self.train_step_counter
                common.generate_tensor_summaries('actor_loss', loss, step)
                common.generate_tensor_summaries('actions', actions, step)
                common.generate_tensor_summaries('log_pi', log_pi, step)
                tf.compat.v2.summary.scalar(
                    name='entropy_avg',
                    data=-tf.reduce_mean(input_tensor=log_pi),
                    step=step)
                common.generate_tensor_summaries('target_q_values',
                                                 target_q_values, step)

                batch_size = nest_utils.get_outer_shape(
                    time_steps, self._time_step_spec)[0]
                initial_state = self._train_policy.get_initial_state(
                    batch_size)
                dist = self._train_policy.distribution(
                    time_steps, initial_state).action
                if isinstance(dist, tfp.distributions.Normal):
                    common.generate_tensor_summaries('act_mean', dist.loc,
                                                     step)
                    common.generate_tensor_summaries('act_stddev',
                                                     dist.scale, step)
                elif isinstance(dist, tfp.distributions.Categorical):
                    common.generate_tensor_summaries('act_mode', dist.mode(),
                                                     step)
                try:
                    common.generate_tensor_summaries(
                        'entropy_action', dist.entropy(), step)
                except NotImplementedError:
                    pass  # Some distributions lack an analytic entropy.

            return loss
Exemple #13
0
    def critic_loss(self,
                    time_steps,
                    actions,
                    next_time_steps,
                    td_errors_loss_fn,
                    gamma=1.0,
                    reward_scale_factor=1.0,
                    weights=None):
        """Computes the critic loss for SAC training.

        TD targets use the minimum of the two target critics minus the
        entropy term `exp(log_alpha) * next_log_pi`, held constant via
        `tf.stop_gradient`. Both online critics are regressed onto the same
        targets and their elementwise losses are added. A time axis, if
        present, is summed out *before* `weights` are applied, so
        per-batch-entry weights line up with the batch axis.

        Args:
          time_steps: A batch of timesteps.
          actions: A batch of actions.
          next_time_steps: A batch of next timesteps.
          td_errors_loss_fn: A function(td_targets, predictions) to compute
            elementwise (per-batch-entry) loss.
          gamma: Discount for future rewards.
          reward_scale_factor: Multiplicative factor to scale rewards.
          weights: Optional scalar or elementwise (per-batch-entry) importance
            weights.

        Returns:
          critic_loss: A scalar critic loss.
        """
        with tf.name_scope('critic_loss'):
            tf.nest.assert_same_structure(actions, self.action_spec)
            tf.nest.assert_same_structure(time_steps, self.time_step_spec)
            tf.nest.assert_same_structure(next_time_steps, self.time_step_spec)

            next_actions, next_log_pis = self._actions_and_log_probs(
                next_time_steps)
            target_input = (next_time_steps.observation, next_actions)
            target_q_values1, unused_network_state1 = self._target_critic_network_1(
                target_input, next_time_steps.step_type, training=False)
            target_q_values2, unused_network_state2 = self._target_critic_network_2(
                target_input, next_time_steps.step_type, training=False)
            target_q_values = (tf.minimum(target_q_values1, target_q_values2) -
                               tf.exp(self._log_alpha) * next_log_pis)

            td_targets = tf.stop_gradient(
                reward_scale_factor * next_time_steps.reward +
                gamma * next_time_steps.discount * target_q_values)

            pred_input = (time_steps.observation, actions)
            pred_td_targets1, _ = self._critic_network_1(pred_input,
                                                         time_steps.step_type,
                                                         training=True)
            pred_td_targets2, _ = self._critic_network_2(pred_input,
                                                         time_steps.step_type,
                                                         training=True)
            critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
            critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
            critic_loss = critic_loss1 + critic_loss2

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Sum over the time dimension.
                critic_loss = tf.reduce_sum(input_tensor=critic_loss, axis=1)

            # Apply per-batch-entry weights only after the time sum. Applying
            # them to a [B, T] loss would broadcast [B] against the trailing
            # T axis, which is a shape error or a silent mis-weighting.
            if weights is not None:
                critic_loss *= weights

            # Take the mean across the batch.
            critic_loss = tf.reduce_mean(input_tensor=critic_loss)

            if self._debug_summaries:
                td_errors1 = td_targets - pred_td_targets1
                td_errors2 = td_targets - pred_td_targets2
                td_errors = tf.concat([td_errors1, td_errors2], axis=0)
                common.generate_tensor_summaries('td_errors', td_errors,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('td_targets', td_targets,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('pred_td_targets1',
                                                 pred_td_targets1,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('pred_td_targets2',
                                                 pred_td_targets2,
                                                 self.train_step_counter)

            return critic_loss
    def _loss_h(self,
                experience,
                td_errors_loss_fn=common.element_wise_huber_loss,
                gamma=1.0,
                reward_scale_factor=1.0,
                weights=None,
                training=False):
        """Computes the TD loss for the H network.

        The target comes from `compute_momentum_td_targets`, which combines:
          * `next_q_values`: Q-values at `next_time_steps`, indexed by the
            argmax action of the H-values at `next_time_steps`.
          * `h_target_values`: values from `_compute_next_all_h_values`
            evaluated at the current `time_steps`, indexed by the taken
            `actions`.
        Both the TD error and the elementwise loss are masked to zero where
        `time_steps.is_last()` holds.

        Args:
          experience: Experience split into `(time_steps, policy_steps,
            next_time_steps)` by `self._as_transition`.
          td_errors_loss_fn: Elementwise loss `f(td_targets, h_values)`.
          gamma: Accepted but not referenced in this method's body.
          reward_scale_factor: Accepted but not referenced in this method's
            body.
          weights: Optional `sample_weight` passed to
            `common.aggregate_losses`.
          training: Forwarded to `self._compute_h_values`.

        Returns:
          `tf_agent.LossInfo(total_loss, DqnLossInfo(td_loss=..., td_error=...))`.
          `total_loss` is the aggregated weighted TD loss plus regularization
          from `self._q_network.losses`.
        """

        transition = self._as_transition(experience)
        time_steps, policy_steps, next_time_steps = transition
        actions = policy_steps.action

        # 1.0 for transitions whose starting step is not last, else 0.0.
        valid_mask = tf.cast(~time_steps.is_last(), tf.float32)

        with tf.name_scope('loss'):
            # h_values is already gathered by actions (H(s, a) for the taken a).
            h_values = self._compute_h_values(time_steps,
                                              actions,
                                              training=training)

            # True when the action spec is not a scalar.
            multi_dim_actions = self._action_spec.shape.rank > 0

            next_q_all_values = self._compute_next_all_q_values(
                next_time_steps, policy_steps.info)

            next_h_all_values = self._compute_next_all_h_values(
                next_time_steps, policy_steps.info)

            # Greedy next action under H. NOTE(review): assumes values are
            # laid out as (batch, num_actions) — TODO confirm.
            next_h_actions = tf.argmax(next_h_all_values, axis=1)

            # next_h_values here is used only for logging
            next_h_values = self._compute_next_h_values(
                next_time_steps, policy_steps.info)

            # next_q_values refer to Q(r,s') in Eqs.(4),(5)
            # Action selected by H, value read from Q.
            next_q_values = common.index_with_actions(
                next_q_all_values, tf.cast(next_h_actions, dtype=tf.int32),
                multi_dim_actions)

            # NOTE(review): the "next" H helper is applied to the *current*
            # time_steps here (presumably a target-H evaluation at s) —
            # confirm intended.
            h_target_all_values = self._compute_next_all_h_values(
                time_steps, policy_steps.info)

            h_target_values = common.index_with_actions(
                h_target_all_values, tf.cast(actions, dtype=tf.int32),
                multi_dim_actions)

            # NOTE(review): compute_momentum_td_targets is defined elsewhere;
            # presumably it blends h_target_values toward next_q_values by
            # beta — verify. Rewards/discounts are not passed in.
            td_targets = compute_momentum_td_targets(
                q_target_values=next_q_values,
                h_target_values=h_target_values,
                beta=self.beta())

            td_error = valid_mask * (td_targets - h_values)

            td_loss = valid_mask * td_errors_loss_fn(td_targets, h_values)

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

            # Aggregate across the elements of the batch and add regularization loss.
            # Note: We use an element wise loss above to ensure each element is always
            #   weighted by 1/N where N is the batch size, even when some of the
            #   weights are zero due to boundary transitions. Weighting by 1/K where K
            #   is the actual number of non-zero weight would artificially increase
            #   their contribution in the loss. Think about what would happen as
            #   the number of boundary samples increases.

            # NOTE(review): regularization is taken from _q_network, while the
            # variable histograms below use _h_network — confirm intended.
            agg_loss = common.aggregate_losses(
                per_example_loss=td_loss,
                sample_weight=weights,
                regularization_loss=self._q_network.losses)
            total_loss = agg_loss.total_loss

            losses_dict = {
                'td_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(losses_dict,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in self._h_network.trainable_weights:
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self._debug_summaries:
                diff_h_values = h_values - next_h_values
                common.generate_tensor_summaries('td_error_h', td_error,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('td_loss_h', td_loss,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('h_values', h_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('next_h_values',
                                                 next_h_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('diff_h_values',
                                                 diff_h_values,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(
                total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
Exemple #15
0
    def actor_loss(self, time_steps, weights=None):
        """Computes the actor_loss for SAC training.

        Per entry the loss is `exp(log_alpha) * log_pi - min(Q1, Q2)`, where
        both critics see the same `(observation, action)` input. It is
        optionally weighted and then averaged to a scalar.

        Args:
          time_steps: A batch of timesteps.
          weights: Optional scalar or elementwise (per-batch-entry) importance
            weights.

        Returns:
          actor_loss: A scalar actor loss.
        """
        with tf.name_scope('actor_loss'):
            tf.nest.assert_same_structure(time_steps, self.time_step_spec())

            actions, log_pi = self._actions_and_log_probs(time_steps)
            # Both critics receive the same input, so build it once.
            target_input = (time_steps.observation, actions)
            target_q_values1, unused_network_state1 = self._critic_network1(
                target_input, time_steps.step_type)
            target_q_values2, unused_network_state2 = self._critic_network2(
                target_input, time_steps.step_type)
            target_q_values = tf.minimum(target_q_values1, target_q_values2)
            actor_loss = tf.exp(self._log_alpha) * log_pi - target_q_values
            if weights is not None:
                actor_loss *= weights
            actor_loss = tf.reduce_mean(input_tensor=actor_loss)

            if self._debug_summaries:
                common_utils.generate_tensor_summaries('actor_loss',
                                                       actor_loss)
                common_utils.generate_tensor_summaries('actions', actions)
                common_utils.generate_tensor_summaries('log_pi', log_pi)
                tf.contrib.summary.scalar('entropy_avg',
                                          -tf.reduce_mean(input_tensor=log_pi))
                common_utils.generate_tensor_summaries('target_q_values',
                                                       target_q_values)
                action_distribution = self.policy().distribution(
                    time_steps).action
                # Not every action distribution is location-scale. Summarize
                # loc/scale only when present, so that enabling debug
                # summaries cannot crash training with an AttributeError.
                if hasattr(action_distribution, 'loc'):
                    common_utils.generate_tensor_summaries(
                        'act_mean', action_distribution.loc)
                if hasattr(action_distribution, 'scale'):
                    common_utils.generate_tensor_summaries(
                        'act_stddev', action_distribution.scale)
                try:
                    raw_entropy = action_distribution.entropy()
                except NotImplementedError:
                    raw_entropy = None  # No analytic entropy available.
                if raw_entropy is not None:
                    common_utils.generate_tensor_summaries(
                        'entropy_raw_action', raw_entropy)

            return actor_loss
Exemple #16
0
def safety_critic_loss(time_steps,
                       actions,
                       next_time_steps,
                       safety_rewards,
                       get_action,
                       global_step,
                       critic_network=None,
                       target_network=None,
                       target_safety=None,
                       safety_gamma=0.45,
                       loss_fn='bce',
                       metrics=None,
                       debug_summaries=False):
    """Computes the safety-critic loss for SAC training.

    The safety critic's outputs are passed through a sigmoid and treated as
    probabilities. The bootstrapped TD target is
    `r + (1 - r) * safety_gamma * discount * sigmoid(Q_target(s', a'))`,
    where `r` is the safety reward and `a' = get_action(next_time_steps)`.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      safety_rewards: Task-agnostic rewards for safety. 1 is unsafe, 0 is safe.
      get_action: Callable mapping `next_time_steps` to the next actions used
        to query `target_network`.
      global_step: Step counter passed to the debug summaries.
      critic_network: Safety critic being trained (called with
        `training=True`).
      target_network: Target safety critic used to build the TD targets.
      target_safety: Threshold on the predicted probability; predictions
        `>= target_safety` count as unsafe for the non-AUC metrics.
      safety_gamma: Discount applied to the bootstrapped target value.
      loss_fn: 'bce', 'mse', None (binary cross-entropy), or a callable
        `loss_fn(targets, predictions)`.
      metrics: Optional iterable of Keras metrics updated with the predictions.
      debug_summaries: If True, emit tensor summaries of TD quantities.

    Returns:
      safe_critic_loss: The loss tensor returned by the selected loss function
        (no additional reduction is applied here).
    """
    with tf.name_scope('safety_critic_loss'):
        # An unsafe transition (reward 1) ends the bootstrap; a safe one
        # propagates the discounted target-network probability.
        next_actions = get_action(next_time_steps)
        target_input = (next_time_steps.observation, next_actions)
        target_q_values, _ = target_network(target_input,
                                            next_time_steps.step_type)
        target_q_values = tf.nn.sigmoid(target_q_values)
        td_targets = tf.stop_gradient(safety_rewards + (1 - safety_rewards) *
                                      safety_gamma * next_time_steps.discount *
                                      target_q_values)

        # NOTE(review): with safety_rewards in {0, 1} and discounts <= 1 the
        # targets already lie in [0, 1]; squashing them again narrows them to
        # [0.5, sigmoid(1)]. Kept for compatibility -- confirm it is intended.
        # Also note loss_fn=None selects BCE below but skips this squashing.
        if loss_fn == 'bce' or loss_fn == tf.keras.losses.binary_crossentropy:
            td_targets = tf.nn.sigmoid(td_targets)

        pred_input = (time_steps.observation, actions)
        pred_td_targets, _ = critic_network(pred_input,
                                            time_steps.step_type,
                                            training=True)
        # From here on the predictions are probabilities; never squash again.
        pred_td_targets = tf.nn.sigmoid(pred_td_targets)

        # Loss fns: binary_crossentropy/squared_difference
        if loss_fn == 'mse':
            sc_loss = tf.math.squared_difference(td_targets, pred_td_targets)
        elif loss_fn == 'bce' or loss_fn is None:
            sc_loss = tf.keras.losses.binary_crossentropy(
                td_targets, pred_td_targets)
        else:
            # Any other value is treated as a user-supplied callable.
            sc_loss = loss_fn(td_targets, pred_td_targets)

        if metrics:
            for metric in metrics:
                if isinstance(metric, tf.keras.metrics.AUC):
                    # AUC consumes the raw predicted probabilities.
                    metric.update_state(safety_rewards, pred_td_targets)
                else:
                    # Other metrics see a thresholded (boolean) prediction.
                    rew_pred = tf.greater_equal(pred_td_targets, target_safety)
                    metric.update_state(safety_rewards, rew_pred)

        if debug_summaries:
            # pred_td_targets is already sigmoid-squashed above, so the errors
            # are computed in probability space without another sigmoid.
            td_errors = td_targets - pred_td_targets
            common.generate_tensor_summaries('safety_td_errors', td_errors,
                                             global_step)
            common.generate_tensor_summaries('safety_td_targets', td_targets,
                                             global_step)
            common.generate_tensor_summaries('safety_pred_td_targets',
                                             pred_td_targets, global_step)

        return sc_loss
Exemple #17
0
 def _actor_loss_debug_summaries(self, actor_loss, actions, log_pi,
                                 target_q_values, time_steps):
     """Emits debug summaries for the actor loss.

     Does nothing unless `self._debug_summaries` is set.

     Args:
       actor_loss: The actor loss tensor.
       actions: Actions used in the actor-loss computation.
       log_pi: Log-probabilities of `actions`; `-mean(log_pi)` is logged as
         'entropy_avg'.
       target_q_values: Critic values for `actions`.
       time_steps: Time steps used to query the train policy's action
         distribution.
     """
     if self._debug_summaries:
         common.generate_tensor_summaries('actor_loss', actor_loss,
                                          self.train_step_counter)
         common.generate_tensor_summaries('actions', actions,
                                          self.train_step_counter)
         common.generate_tensor_summaries('log_pi', log_pi,
                                          self.train_step_counter)
         tf.compat.v2.summary.scalar(
             name='entropy_avg',
             data=-tf.reduce_mean(input_tensor=log_pi),
             step=self.train_step_counter)
         common.generate_tensor_summaries('target_q_values',
                                          target_q_values,
                                          self.train_step_counter)
         batch_size = nest_utils.get_outer_shape(time_steps,
                                                 self._time_step_spec)[0]
         policy_state = self._train_policy.get_initial_state(batch_size)
         action_distribution = self._train_policy.distribution(
             time_steps, policy_state).action
         if isinstance(action_distribution, tfp.distributions.Normal):
             common.generate_tensor_summaries('act_mean',
                                              action_distribution.loc,
                                              self.train_step_counter)
             common.generate_tensor_summaries('act_stddev',
                                              action_distribution.scale,
                                              self.train_step_counter)
         elif isinstance(action_distribution,
                         tfp.distributions.Categorical):
             common.generate_tensor_summaries('act_mode',
                                              action_distribution.mode(),
                                              self.train_step_counter)
         try:
             common.generate_tensor_summaries('entropy_action',
                                              action_distribution.entropy(),
                                              self.train_step_counter)
         except NotImplementedError:
             # Some distributions (e.g. transformed ones) have no analytic
             # entropy; skip this summary instead of failing training.
             pass
Exemple #18
0
    def _loss(self,
              experience,
              td_errors_loss_fn=common.element_wise_huber_loss,
              gamma=1.0,
              reward_scale_factor=1.0,
              weights=None,
              training=False):
        """Computes loss for DQN training.

    The per-transition TD loss is `td_errors_loss_fn(td_targets, q_values)`.
    It is zeroed for transitions whose first step is an episode's last step,
    summed over time when `time_steps` keeps a [B, T] outer shape, and then
    aggregated with `weights` plus the Q-network's regularization losses.

    Args:
      experience: A batch of experience data in the form of a `Trajectory`. The
        structure of `experience` must match that of `self.policy.step_spec`.
        All tensors in `experience` must be shaped `[batch, time, ...]` where
        `time` must be equal to `self.train_sequence_length` if that
        property is not `None`.
      td_errors_loss_fn: A function(td_targets, predictions) to compute the
        element wise loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.  The output td_loss will be scaled by these weights, and
        the final scalar loss is the mean of these values.
      training: Whether this loss is being used for training.

    Returns:
      loss: A `tf_agent.LossInfo` built as
        `LossInfo(total_loss, DqnLossInfo(td_loss=..., td_error=...))`.
    Raises:
      ValueError:
        if the number of actions is greater than 1. (Not raised directly in
        this method; presumably by a helper such as `_compute_q_values` --
        confirm.)
    """
        # Check that `experience` includes two outer dimensions [B, T, ...]. This
        # method requires a time dimension to compute the loss properly.
        self._check_trajectory_dimensions(experience)

        # Ask for the time dimension to be squeezed only when the Q-network has
        # no state (empty `state_spec`); stateful networks keep the time axis.
        squeeze_time_dim = not self._q_network.state_spec
        if self._n_step_update == 1:
            time_steps, policy_steps, next_time_steps = (
                trajectory.experience_to_transitions(experience,
                                                     squeeze_time_dim))
            actions = policy_steps.action
        else:
            # To compute n-step returns, we need the first time steps, the first
            # actions, and the last time steps. Therefore we extract the first and
            # last transitions from our Trajectory.
            first_two_steps = tf.nest.map_structure(lambda x: x[:, :2],
                                                    experience)
            last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:],
                                                   experience)
            time_steps, policy_steps, _ = (
                trajectory.experience_to_transitions(first_two_steps,
                                                     squeeze_time_dim))
            actions = policy_steps.action
            _, _, next_time_steps = (trajectory.experience_to_transitions(
                last_two_steps, squeeze_time_dim))

        with tf.name_scope('loss'):
            # Q(s, a) for the (first) actions taken in the batch.
            q_values = self._compute_q_values(time_steps,
                                              actions,
                                              training=training)

            # Bootstrap values at the (last) next time steps; `policy_steps.info`
            # comes from the first transition in the n-step case.
            next_q_values = self._compute_next_q_values(
                next_time_steps, policy_steps.info)

            if self._n_step_update == 1:
                # Special case for n = 1 to avoid a loss of performance.
                td_targets = compute_td_targets(
                    next_q_values,
                    rewards=reward_scale_factor * next_time_steps.reward,
                    discounts=gamma * next_time_steps.discount)
            else:
                # When computing discounted return, we need to throw out the last time
                # index of both reward and discount, which are filled with dummy values
                # to match the dimensions of the observation.
                rewards = reward_scale_factor * experience.reward[:, :-1]
                discounts = gamma * experience.discount[:, :-1]

                # TODO(b/134618876): Properly handle Trajectories that include episode
                # boundaries with nonzero discount.

                td_targets = value_ops.discounted_return(
                    rewards=rewards,
                    discounts=discounts,
                    final_value=next_q_values,
                    time_major=False,
                    provide_all_returns=False)

            # 1.0 where the transition does not start on an episode's last step,
            # 0.0 otherwise; masks both the TD error and the TD loss.
            valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
            td_error = valid_mask * (td_targets - q_values)

            td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

            # True when time_steps still carries [B, T] outer dims (presumably
            # the unsqueezed, stateful-network case).
            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

            # Aggregate across the elements of the batch and add regularization loss.
            # Note: We use an element wise loss above to ensure each element is always
            #   weighted by 1/N where N is the batch size, even when some of the
            #   weights are zero due to boundary transitions. Weighting by 1/K where K
            #   is the actual number of non-zero weight would artificially increase
            #   their contribution in the loss. Think about what would happen as
            #   the number of boundary samples increases.

            agg_loss = common.aggregate_losses(
                per_example_loss=td_loss,
                sample_weight=weights,
                regularization_loss=self._q_network.losses)
            total_loss = agg_loss.total_loss

            losses_dict = {
                'td_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(losses_dict,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in self._q_network.trainable_weights:
                        # ':' is replaced so the variable name is usable as
                        # a summary name.
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self._debug_summaries:
                # Gap between the current Q(s, a) and the next-step value.
                diff_q_values = q_values - next_q_values
                common.generate_tensor_summaries('td_error', td_error,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('td_loss', td_loss,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('q_values', q_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('next_q_values',
                                                 next_q_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('diff_q_values',
                                                 diff_q_values,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(
                total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
Exemple #19
0
    def _actor_loss_debug_summaries(self, actor_loss, actions, log_pi,
                                    target_q_values, time_steps):
        """Writes actor-loss debug summaries when `_debug_summaries` is set.

        Two summaries are best-effort. 'actions' is skipped when summarizing
        it raises ValueError, which guards internal SAC variants that do not
        directly generate actions. 'entropy_action' is skipped when the
        distribution has no analytic entropy.

        Args:
          actor_loss: The actor loss tensor.
          actions: Actions used in the actor-loss computation.
          log_pi: Log-probabilities of `actions`.
          target_q_values: Critic values for `actions`.
          time_steps: Time steps used to query the train policy's
            action distribution.
        """
        if not self._debug_summaries:
            return

        step = self.train_step_counter
        common.generate_tensor_summaries('actor_loss', actor_loss, step)
        try:
            common.generate_tensor_summaries('actions', actions, step)
        except ValueError:
            # Guard against internal SAC variants that do not directly
            # generate actions.
            pass

        common.generate_tensor_summaries('log_pi', log_pi, step)
        mean_entropy_estimate = -tf.reduce_mean(input_tensor=log_pi)
        tf.compat.v2.summary.scalar(name='entropy_avg',
                                    data=mean_entropy_estimate,
                                    step=step)
        common.generate_tensor_summaries('target_q_values', target_q_values,
                                         step)

        outer_batch = nest_utils.get_outer_shape(time_steps,
                                                 self._time_step_spec)[0]
        initial_state = self._train_policy.get_initial_state(outer_batch)
        dist = self._train_policy.distribution(time_steps,
                                               initial_state).action

        if isinstance(dist, tfp.distributions.Normal):
            common.generate_tensor_summaries('act_mean', dist.loc, step)
            common.generate_tensor_summaries('act_stddev', dist.scale, step)
        elif isinstance(dist, tfp.distributions.Categorical):
            common.generate_tensor_summaries('act_mode', dist.mode(), step)

        try:
            common.generate_tensor_summaries('entropy_action',
                                             dist.entropy(), step)
        except NotImplementedError:
            # Some distributions do not have an analytic entropy.
            pass
Exemple #20
0
    def actor_loss(self, time_steps, actor_time_steps, weights=None):
        """Computes the SAC actor loss with separate critic and actor inputs.

        Actions are sampled from the actor network on `actor_time_steps`. The
        loss is `mean(w * (exp(log_alpha) * log_pi - min(Q1, Q2)))`, where the
        twin critics are evaluated on `time_steps`.

    Args:
      time_steps: A batch of timesteps for the critic.
      actor_time_steps: A batch of timesteps for the actor.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.

    Returns:
      actor_loss: A scalar actor loss.
    """
        with tf.name_scope('actor_loss'):
            # Critic inputs never propagate gradients from this loss.
            critic_steps = tf.nest.map_structure(tf.stop_gradient, time_steps)

            actor_steps = actor_time_steps
            if self._actor_input_stop_gradient:
                actor_steps = tf.nest.map_structure(tf.stop_gradient,
                                                    actor_steps)

            policy_dist, _ = self._actor_network(actor_steps.observation,
                                                 actor_steps.step_type)
            sampled_actions = policy_dist.sample()
            log_pis = policy_dist.log_prob(sampled_actions)

            # Clipped double-Q: take the element-wise minimum of both critics.
            critic_input = (critic_steps.observation, sampled_actions)
            q_estimates = [
                critic(critic_input, critic_steps.step_type)[0]
                for critic in (self._critic_network1, self._critic_network2)
            ]
            target_q_values = tf.minimum(*q_estimates)

            per_example = (tf.exp(self._log_alpha) * log_pis -
                           target_q_values)
            if weights is not None:
                per_example = per_example * weights
            actor_loss = tf.reduce_mean(input_tensor=per_example)

            if self._debug_summaries:
                step = self.train_step_counter
                common.generate_tensor_summaries('actor_loss', actor_loss,
                                                 step)
                common.generate_tensor_summaries('actions', sampled_actions,
                                                 step)
                common.generate_tensor_summaries('log_pis', log_pis, step)
                tf.compat.v2.summary.scalar(
                    name='entropy_avg',
                    data=-tf.reduce_mean(input_tensor=log_pis),
                    step=step)
                common.generate_tensor_summaries('target_q_values',
                                                 target_q_values, step)
                # NOTE(review): the policy is queried with the critic's time
                # steps, not `actor_time_steps`; confirm both match the
                # policy's time_step_spec.
                outer_batch = nest_utils.get_outer_shape(
                    critic_steps, self._time_step_spec)[0]
                initial_state = self.policy.get_initial_state(outer_batch)
                dist = self.policy.distribution(critic_steps,
                                                initial_state).action
                if isinstance(dist, tfp.distributions.Normal):
                    common.generate_tensor_summaries('act_mean', dist.loc,
                                                     step)
                    common.generate_tensor_summaries('act_stddev',
                                                     dist.scale, step)
                elif isinstance(dist, tfp.distributions.Categorical):
                    common.generate_tensor_summaries('act_mode', dist.mode(),
                                                     step)
                try:
                    common.generate_tensor_summaries('entropy_action',
                                                     dist.entropy(), step)
                except NotImplementedError:
                    pass  # Some distributions do not have an analytic entropy.

            return actor_loss
Exemple #21
0
    def critic_loss(self,
                    time_steps,
                    actions,
                    next_time_steps,
                    actor_next_time_steps,
                    td_errors_loss_fn,
                    gamma=1.0,
                    reward_scale_factor=1.0,
                    weights=None):
        """Computes the twin-critic soft TD loss for SAC training.

        The TD target is
        `reward_scale_factor * r + gamma * discount * (min(Q1', Q2') - alpha *
        log_pi(a'))`, with `a'` sampled from the actor on
        `actor_next_time_steps` and `alpha = exp(log_alpha)`. Both online
        critics are regressed onto that (gradient-stopped) target, and their
        losses are summed.

    Args:
      time_steps: A batch of timesteps for the critic.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps for the critic.
      actor_next_time_steps: A batch of next timesteps for the actor.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.

    Returns:
      critic_loss: A scalar critic loss.
    """
        with tf.name_scope('critic_loss'):
            if self._critic_input_stop_gradient:
                time_steps, next_time_steps = (
                    tf.nest.map_structure(tf.stop_gradient, steps)
                    for steps in (time_steps, next_time_steps))

            # Redundant with the stop_gradient on td_targets below, but keeps
            # the actor branch explicitly gradient-free.
            actor_next_time_steps = tf.nest.map_structure(
                tf.stop_gradient, actor_next_time_steps)

            next_dist, _ = self._actor_network(
                actor_next_time_steps.observation,
                actor_next_time_steps.step_type)
            next_actions = next_dist.sample()
            next_log_pis = next_dist.log_prob(next_actions)

            next_input = (next_time_steps.observation, next_actions)
            next_q1, next_q2 = [
                target_critic(next_input, next_time_steps.step_type)[0]
                for target_critic in (self._target_critic_network1,
                                      self._target_critic_network2)
            ]
            alpha = tf.exp(self._log_alpha)
            soft_next_values = tf.minimum(next_q1, next_q2) - alpha * next_log_pis

            td_targets = tf.stop_gradient(
                reward_scale_factor * next_time_steps.reward +
                gamma * next_time_steps.discount * soft_next_values)

            current_input = (time_steps.observation, actions)
            pred_td_targets1, pred_td_targets2 = [
                critic(current_input, time_steps.step_type)[0]
                for critic in (self._critic_network1, self._critic_network2)
            ]
            critic_loss = (td_errors_loss_fn(td_targets, pred_td_targets1) +
                           td_errors_loss_fn(td_targets, pred_td_targets2))

            if weights is not None:
                critic_loss = critic_loss * weights

            # Average over the batch.
            critic_loss = tf.reduce_mean(input_tensor=critic_loss)

            if self._debug_summaries:
                step = self.train_step_counter
                td_errors = tf.concat([
                    td_targets - pred_td_targets1,
                    td_targets - pred_td_targets2
                ], axis=0)
                common.generate_tensor_summaries('td_errors', td_errors, step)
                common.generate_tensor_summaries('td_targets', td_targets,
                                                 step)
                common.generate_tensor_summaries('pred_td_targets1',
                                                 pred_td_targets1, step)
                common.generate_tensor_summaries('pred_td_targets2',
                                                 pred_td_targets2, step)

            return critic_loss
Exemple #22
0
    def critic_loss(self, experience, weights=None):
        """Computes the categorical (distributional) critic loss.

        The target critic's next-step distribution is moved onto the shifted
        support `reward + gamma * discount * support` (or its n-step analogue).
        It is then projected back onto `self._support` with
        `self._project_distribution`. The loss is the cross-entropy between
        that projection and the critic's logits, plus `self._critic_l2_lambda`
        times the L2 norm of critic variables whose name contains 'kernel'.

        Args:
          experience: A batch of experience (a `Trajectory`) whose tensors are
            shaped `[batch, time, ...]`.
          weights: Optional scalar or elementwise (per-batch-entry) importance
            weights. Each per-example cross-entropy term is multiplied by them
            before averaging over the batch.

        Returns:
          critic_loss: A scalar critic loss.
        """
        # Check that `experience` includes two outer dimensions [B, T, ...]. This
        # method requires a time dimension to compute the loss properly.
        self._check_trajectory_dimensions(experience)

        if self._n_step_return == 1:
            time_steps, actions, next_time_steps = self._experience_to_transitions(
                experience)
        else:
            # To compute n-step returns, we need the first time steps, the first
            # actions, and the last time steps. Therefore we extract the first and
            # last transitions from our Trajectory.
            first_two_steps = tf.nest.map_structure(lambda x: x[:, :2],
                                                    experience)
            last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:],
                                                   experience)
            time_steps, actions, _ = self._experience_to_transitions(
                first_two_steps)
            _, _, next_time_steps = self._experience_to_transitions(
                last_two_steps)

        with tf.name_scope('critic_loss'):
            tf.nest.assert_same_structure(actions, self.action_spec)
            tf.nest.assert_same_structure(time_steps, self.time_step_spec)
            tf.nest.assert_same_structure(next_time_steps, self.time_step_spec)

            target_actions, _ = self._target_actor_network(
                next_time_steps.observation, next_time_steps.step_type)
            target_critic_network_input = (next_time_steps.observation,
                                           target_actions)
            _, next_distribution, _ = self._target_critic_network(
                target_critic_network_input, next_time_steps.step_type)

            # Prefer the static batch size; fall back to the dynamic one.
            batch_size = next_distribution.shape[0] or tf.shape(
                next_distribution)[0]
            tiled_support = tf.tile(self._support, [batch_size])
            tiled_support = tf.reshape(tiled_support,
                                       [batch_size, self._num_atoms])

            if self._n_step_return == 1:
                discount = next_time_steps.discount
                if discount.shape.ndims == 1:
                    # We expect discount to have a shape of [batch_size], while
                    # tiled_support will have a shape of [batch_size, num_atoms]. To
                    # multiply these, we add a second dimension of 1 to the discount.
                    discount = discount[:, None]
                next_value_term = tf.multiply(discount,
                                              tiled_support,
                                              name='next_value_term')

                reward = next_time_steps.reward
                if reward.shape.ndims == 1:
                    # See the explanation above.
                    reward = reward[:, None]
                reward_term = tf.multiply(self._reward_scale_factor,
                                          reward,
                                          name='reward_term')

                target_support = tf.add(reward_term,
                                        self._gamma * next_value_term,
                                        name='target_support')
            # TODO : This is not correct when n > 2
            else:
                # When computing discounted return, we need to throw out the last time
                # index of both reward and discount, which are filled with dummy values
                # to match the dimensions of the observation.
                rewards = self._reward_scale_factor * experience.reward[:, :-1]
                discounts = self._gamma * experience.discount[:, :-1]

                # TODO(b/134618876): Properly handle Trajectories that include episode
                # boundaries with nonzero discount.

                discounted_returns = value_ops.discounted_return(
                    rewards=rewards,
                    discounts=discounts,
                    final_value=tf.zeros([batch_size], dtype=discounts.dtype),
                    time_major=False,
                    provide_all_returns=False)

                # Convert discounted_returns from [batch_size] to [batch_size, 1]
                discounted_returns = discounted_returns[:, None]

                final_value_discount = tf.reduce_prod(discounts, axis=1)
                final_value_discount = final_value_discount[:, None]

                # Save the values of discounted_returns and final_value_discount in
                # order to check them in unit tests.
                self._discounted_returns = discounted_returns
                self._final_value_discount = final_value_discount

                target_support = tf.add(discounted_returns,
                                        final_value_discount * tiled_support,
                                        name='target_support')

            target_distribution = tf.stop_gradient(
                self._project_distribution(target_support, next_distribution,
                                           self._support))

            logits, distribution, _ = self._critic_network(
                (time_steps.observation, actions), time_steps.step_type)

            # Per-example cross-entropy; importance weights (if any) are
            # applied before averaging, like the other losses in this file.
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
                labels=target_distribution, logits=logits)
            if weights is not None:
                cross_entropy *= weights
            cross_entropy_loss = tf.reduce_mean(cross_entropy)

            kernel_l2_losses = [
                tf.nn.l2_loss(v)
                for v in self._critic_network.trainable_variables
                if 'kernel' in v.name
            ]
            if kernel_l2_losses:
                l2_reg_loss = (tf.add_n(kernel_l2_losses) *
                               self._critic_l2_lambda)
            else:
                # tf.add_n rejects an empty list; with no 'kernel' variables
                # there is simply no L2 penalty.
                l2_reg_loss = tf.zeros_like(cross_entropy_loss)

            critic_loss = cross_entropy_loss + l2_reg_loss

            with tf.name_scope('Losses/'):
                tf.compat.v2.summary.scalar('critic_loss',
                                            critic_loss,
                                            step=self.train_step_counter)

            if self._debug_summaries:
                distribution_errors = target_distribution - distribution
                with tf.name_scope('distribution_errors'):
                    common.generate_tensor_summaries(
                        'distribution_errors',
                        distribution_errors,
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'mean',
                        tf.reduce_mean(distribution_errors),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'mean_abs',
                        tf.reduce_mean(tf.abs(distribution_errors)),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'max',
                        tf.reduce_max(distribution_errors),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'min',
                        tf.reduce_min(distribution_errors),
                        step=self.train_step_counter)
                with tf.name_scope('target_distribution'):
                    common.generate_tensor_summaries(
                        'target_distribution',
                        target_distribution,
                        step=self.train_step_counter)

            return critic_loss
    def _loss(self,
              experience,
              td_errors_loss_fn=tf.losses.huber_loss,
              gamma=1.0,
              reward_scale_factor=1.0,
              weights=None):
        """Computes critic loss for CategoricalDQN training.

    See Algorithm 1 and the discussion immediately preceding it in page 6 of
    "A Distributional Perspective on Reinforcement Learning"
      Bellemare et al., 2017
      https://arxiv.org/abs/1707.06887

    The sample Bellman update is projected onto the fixed support
    `self._support`. The loss is the cross-entropy between that projected
    target distribution and the predicted logits of the taken actions.

    Args:
      experience: A batch of experience data in the form of a `Trajectory`. The
        structure of `experience` must match that of `self.policy.step_spec`.
        All tensors in `experience` must be shaped `[batch, time, ...]` where
        `time` must be equal to `self.required_experience_time_steps` if that
        property is not `None`.
      td_errors_loss_fn: Unused. The categorical loss is always a
        cross-entropy; the argument is kept so the signature matches the
        other DQN agents.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights. The per-example cross-entropy (summed over time when a time
        dimension is present) is scaled by these weights before the mean over
        the batch is taken.
    Returns:
      A `tf_agent.LossInfo` whose first field is the scalar critic loss and
      whose second field is a `DqnLossInfo` with empty `td_loss`/`td_error`.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
        # Check that `experience` includes two outer dimensions [B, T, ...]. This
        # method requires a time dimension to compute the loss properly.
        self._check_trajectory_dimensions(experience)

        if self._n_step_update == 1:
            time_steps, actions, next_time_steps = self._experience_to_transitions(
                experience)
        else:
            # To compute n-step returns, we need the first time steps, the first
            # actions, and the last time steps. Therefore we extract the first and
            # last transitions from our Trajectory.
            first_two_steps = tf.nest.map_structure(lambda x: x[:, :2],
                                                    experience)
            last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:],
                                                   experience)
            time_steps, actions, _ = self._experience_to_transitions(
                first_two_steps)
            _, _, next_time_steps = self._experience_to_transitions(
                last_two_steps)

        with tf.name_scope('critic_loss'):
            tf.nest.assert_same_structure(actions, self.action_spec)
            tf.nest.assert_same_structure(time_steps, self.time_step_spec)
            tf.nest.assert_same_structure(next_time_steps, self.time_step_spec)

            rank = nest_utils.get_outer_rank(time_steps.observation,
                                             self._time_step_spec.observation)

            # If inputs have a time dimension and the q_network is stateful,
            # combine the batch and time dimension.
            batch_squash = (None if rank <= 1 or self._q_network.state_spec
                            in ((), None) else utils.BatchSquash(rank))

            # q_logits contains the Q-value logits for all actions.
            q_logits, _ = self._q_network(time_steps.observation,
                                          time_steps.step_type)
            next_q_distribution = self._next_q_distribution(
                next_time_steps, batch_squash)

            if batch_squash is not None:
                # Squash the outer dimensions into a single one so the loss
                # computation below works on flat batches. This is required to
                # support temporal inputs.
                q_logits = batch_squash.flatten(q_logits)
                actions = batch_squash.flatten(actions)
                next_time_steps = tf.nest.map_structure(
                    batch_squash.flatten, next_time_steps)

            actions = tf.nest.flatten(actions)[0]
            if actions.shape.ndims > 1:
                actions = tf.squeeze(actions, range(1, actions.shape.ndims))

            # Project the sample Bellman update \hat{T}Z_{\theta} onto the original
            # support of Z_{\theta} (see Figure 1 in paper).
            batch_size = tf.shape(q_logits)[0]
            tiled_support = tf.tile(self._support, [batch_size])
            tiled_support = tf.reshape(tiled_support,
                                       [batch_size, self._num_atoms])

            if self._n_step_update == 1:
                discount = next_time_steps.discount
                if discount.shape.ndims == 1:
                    # We expect discount to have a shape of [batch_size], while
                    # tiled_support will have a shape of [batch_size, num_atoms]. To
                    # multiply these, we add a second dimension of 1 to the discount.
                    discount = discount[:, None]
                next_value_term = tf.multiply(discount,
                                              tiled_support,
                                              name='next_value_term')

                reward = next_time_steps.reward
                if reward.shape.ndims == 1:
                    # See the explanation above.
                    reward = reward[:, None]
                reward_term = tf.multiply(reward_scale_factor,
                                          reward,
                                          name='reward_term')

                target_support = tf.add(reward_term,
                                        gamma * next_value_term,
                                        name='target_support')
            else:
                # When computing discounted return, we need to throw out the last time
                # index of both reward and discount, which are filled with dummy values
                # to match the dimensions of the observation.
                rewards = reward_scale_factor * experience.reward[:, :-1]
                discounts = gamma * experience.discount[:, :-1]

                # TODO(b/134618876): Properly handle Trajectories that include episode
                # boundaries with nonzero discount.

                # TODO(b/131557265): Replace value_ops.discounted_return with a method
                # that only computes the single value needed.
                discounted_rewards = value_ops.discounted_return(
                    rewards=rewards,
                    discounts=discounts,
                    final_value=tf.zeros([batch_size], dtype=discounts.dtype),
                    time_major=False)

                # We only need the first value within the time dimension which
                # corresponds to the full final return. The remaining values are only
                # partial returns.
                discounted_rewards = discounted_rewards[:, :1]

                final_value_discount = tf.reduce_prod(discounts, axis=1)
                final_value_discount = final_value_discount[:, None]

                # Save the values of discounted_rewards and final_value_discount in
                # order to check them in unit tests.
                self._discounted_rewards = discounted_rewards
                self._final_value_discount = final_value_discount

                target_support = tf.add(discounted_rewards,
                                        final_value_discount * tiled_support,
                                        name='target_support')

            target_distribution = tf.stop_gradient(
                project_distribution(target_support, next_q_distribution,
                                     self._support))

            # Obtain the current Q-value logits for the selected actions by
            # gathering row i, column actions[i] of q_logits.
            indices = tf.range(tf.shape(q_logits)[0])[:, None]
            indices = tf.cast(indices, actions.dtype)
            reshaped_actions = tf.concat([indices, actions[:, None]], 1)
            chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)

            # Per-example cross-entropy between target distribution and logits.
            # If inputs have a time dimension, sum over time so there is one
            # value per batch entry.
            if batch_squash is not None:
                target_distribution = batch_squash.unflatten(
                    target_distribution)
                chosen_action_logits = batch_squash.unflatten(
                    chosen_action_logits)
                critic_loss = tf.reduce_sum(
                    tf.nn.softmax_cross_entropy_with_logits_v2(
                        labels=target_distribution,
                        logits=chosen_action_logits),
                    axis=1)
            else:
                critic_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=target_distribution,
                    logits=chosen_action_logits)

            # Apply the (previously ignored) importance weights before
            # averaging over the batch.
            if weights is not None:
                critic_loss *= weights

            critic_loss = tf.reduce_mean(critic_loss)

            with tf.name_scope('Losses/'):
                tf.compat.v2.summary.scalar('critic_loss',
                                            critic_loss,
                                            step=self.train_step_counter)

            if self._debug_summaries:
                distribution_errors = target_distribution - chosen_action_logits
                with tf.name_scope('distribution_errors'):
                    common.generate_tensor_summaries(
                        'distribution_errors',
                        distribution_errors,
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'mean',
                        tf.reduce_mean(distribution_errors),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'mean_abs',
                        tf.reduce_mean(tf.abs(distribution_errors)),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'max',
                        tf.reduce_max(distribution_errors),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'min',
                        tf.reduce_min(distribution_errors),
                        step=self.train_step_counter)
                with tf.name_scope('target_distribution'):
                    common.generate_tensor_summaries(
                        'target_distribution',
                        target_distribution,
                        step=self.train_step_counter)

            # TODO(b/127318640): Give appropriate values for td_loss and td_error for
            # prioritized replay.
            return tf_agent.LossInfo(
                critic_loss, dqn_agent.DqnLossInfo(td_loss=(), td_error=()))
Exemple #24
0
  def _loss(self,
            experience,
            td_errors_loss_fn=element_wise_huber_loss,
            gamma=1.0,
            reward_scale_factor=1.0,
            weights=None):
    """Builds the DQN temporal-difference loss from a batch of trajectories.

    When `self._n_step_update == 1` the targets come from `compute_td_targets`
    using the scaled reward and discount of the next time step. Otherwise only
    the first and last transitions of each trajectory are used, and the target
    is the n-step discounted return bootstrapped from the next Q-values.

    Args:
      experience: `Trajectory` batch whose structure matches
        `self.policy.step_spec`; every tensor is shaped `[batch, time, ...]`,
        with `time` equal to `self.train_sequence_length` when that property
        is not `None`.
      td_errors_loss_fn: Callable `(td_targets, predictions)` returning an
        element-wise loss.
      gamma: Discount applied to future rewards.
      reward_scale_factor: Multiplier applied to rewards.
      weights: Optional scalar or per-batch-entry importance weights. They
        scale `td_loss`, and the scalar loss is the mean of the scaled values.

    Returns:
      A `tf_agent.LossInfo` pairing the scalar loss with a `DqnLossInfo` that
      holds the per-example `td_loss` and `td_error`.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
    # `experience` must carry both a batch and a time dimension.
    self._check_trajectory_dimensions(experience)

    if self._n_step_update == 1:
      time_steps, actions, next_time_steps = (
          self._experience_to_transitions(experience))
    else:
      # n-step returns only need the first time step / action and the final
      # next time step, so build transitions from the two ends of the window.
      head = tf.nest.map_structure(lambda t: t[:, :2], experience)
      tail = tf.nest.map_structure(lambda t: t[:, -2:], experience)
      time_steps, actions, _ = self._experience_to_transitions(head)
      _, _, next_time_steps = self._experience_to_transitions(tail)

    with tf.name_scope('loss'):
      action_tensor = tf.nest.flatten(actions)[0]
      all_q_values, _ = self._q_network(time_steps.observation,
                                        time_steps.step_type)

      # Both a scalar action spec (shape ()) and shape (1,) are supported;
      # `multi_dim_actions` tells index_with_actions which one applies.
      action_spec = tf.nest.flatten(self._action_spec)[0]
      q_values = common.index_with_actions(
          all_q_values,
          tf.cast(action_tensor, dtype=tf.int32),
          multi_dim_actions=action_spec.shape.ndims > 0)

      next_q_values = self._compute_next_q_values(next_time_steps)

      if self._n_step_update != 1:
        # The last reward/discount entries are padding that only exists to
        # line up with the observations, so they are dropped.
        # TODO(b/131557265): Replace value_ops.discounted_return with a method
        # that only computes the single value needed.
        returns = value_ops.discounted_return(
            rewards=reward_scale_factor * experience.reward[:, :-1],
            discounts=gamma * experience.discount[:, :-1],
            final_value=next_q_values,
            time_major=False)
        # Index 0 holds the full n-step return; later entries are partial.
        td_targets = returns[:, 0]
      else:
        # Dedicated n == 1 path, avoiding the discounted_return machinery.
        td_targets = compute_td_targets(
            next_q_values,
            rewards=reward_scale_factor * next_time_steps.reward,
            discounts=gamma * next_time_steps.discount)

      # Entries whose time step is LAST contribute nothing.
      not_last = tf.cast(~time_steps.is_last(), tf.float32)
      td_error = not_last * (td_targets - q_values)
      td_loss = not_last * td_errors_loss_fn(td_targets, q_values)

      has_time_dim = nest_utils.is_batched_nested_tensors(
          time_steps, self.time_step_spec, num_outer_dims=2)
      if has_time_dim:
        td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

      if weights is not None:
        td_loss *= weights

      # Divide by the full batch size N rather than by the number K of
      # non-zero-weight entries: dividing by K would inflate the contribution
      # of the remaining samples as more boundary transitions get masked out.
      loss = tf.reduce_mean(input_tensor=td_loss)

      with tf.name_scope('Losses/'):
        tf.compat.v1.summary.scalar(
            'loss_' + self.name, loss, collections=['train_' + self.name])

      if self._summarize_grads_and_vars:
        with tf.name_scope('Variables/'):
          for weight_var in self._q_network.trainable_weights:
            tf.compat.v2.summary.histogram(
                name=weight_var.name.replace(':', '_'),
                data=weight_var,
                step=self.train_step_counter)

      if self._debug_summaries:
        diff_q_values = q_values - next_q_values
        debug_tensors = (
            ('td_error', td_error),
            ('td_loss', td_loss),
            ('q_values', q_values),
            ('next_q_values', next_q_values),
            ('diff_q_values', diff_q_values),
        )
        for tag, tensor in debug_tensors:
          common.generate_tensor_summaries(tag, tensor,
                                           self.train_step_counter)

      return tf_agent.LossInfo(loss,
                               DqnLossInfo(td_loss=td_loss, td_error=td_error))
Exemple #25
0
  def loss(self,
           time_steps,
           actions,
           next_time_steps,
           td_errors_loss_fn=element_wise_huber_loss,
           gamma=1.0,
           reward_scale_factor=1.0,
           weights=None):
    """Builds the one-step DQN temporal-difference loss.

    Args:
      time_steps: Batch of time steps at which `actions` were taken.
      actions: Batch of taken actions.
      next_time_steps: Batch of the following time steps.
      td_errors_loss_fn: Callable `(td_targets, predictions)` returning an
        element-wise loss.
      gamma: Discount applied to future rewards.
      reward_scale_factor: Multiplier applied to rewards.
      weights: Optional scalar or per-batch-entry importance weights. They
        scale `td_loss`, and the scalar loss is the mean of the scaled values.

    Returns:
      A `tf_agent.LossInfo` pairing the scalar loss with a `DqnLossInfo` that
      holds the per-example `td_loss` and `td_error`.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
    with tf.name_scope('loss'):
      action_tensor = tf.nest.flatten(actions)[0]
      all_q_values, _ = self._q_network(time_steps.observation,
                                        time_steps.step_type)

      # Both a scalar action spec (shape ()) and shape (1,) are supported;
      # `multi_dim_actions` tells index_with_actions which one applies.
      action_spec = tf.nest.flatten(self._action_spec)[0]
      q_values = common_utils.index_with_actions(
          all_q_values,
          tf.cast(action_tensor, dtype=tf.int32),
          multi_dim_actions=action_spec.shape.ndims > 0)

      next_q_values = self._compute_next_q_values(next_time_steps)
      scaled_rewards = reward_scale_factor * next_time_steps.reward
      scaled_discounts = gamma * next_time_steps.discount
      td_targets = compute_td_targets(
          next_q_values, rewards=scaled_rewards, discounts=scaled_discounts)

      # Entries whose time step is LAST contribute nothing.
      not_last = tf.cast(~time_steps.is_last(), tf.float32)
      td_error = not_last * (td_targets - q_values)
      td_loss = not_last * td_errors_loss_fn(td_targets, q_values)

      has_time_dim = nest_utils.is_batched_nested_tensors(
          time_steps, self.time_step_spec, num_outer_dims=2)
      if has_time_dim:
        td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

      if weights is not None:
        td_loss *= weights

      # Divide by the full batch size N rather than by the number K of
      # non-zero-weight entries: dividing by K would inflate the contribution
      # of the remaining samples as more boundary transitions get masked out.
      loss = tf.reduce_mean(input_tensor=td_loss)

      with tf.name_scope('Losses/'):
        tf.compat.v2.summary.scalar(
            name='loss', data=loss, step=self.train_step_counter)

      if self._summarize_grads_and_vars:
        with tf.name_scope('Variables/'):
          for weight_var in self._q_network.trainable_weights:
            tf.compat.v2.summary.histogram(
                name=weight_var.name.replace(':', '_'),
                data=weight_var,
                step=self.train_step_counter)

      if self._debug_summaries:
        diff_q_values = q_values - next_q_values
        debug_tensors = (
            ('td_error', td_error),
            ('td_loss', td_loss),
            ('q_values', q_values),
            ('next_q_values', next_q_values),
            ('diff_q_values', diff_q_values),
        )
        for tag, tensor in debug_tensors:
          common_utils.generate_tensor_summaries(tag, tensor,
                                                 self.train_step_counter)

      return tf_agent.LossInfo(loss, DqnLossInfo(td_loss=td_loss,
                                                 td_error=td_error))
Exemple #26
0
  def critic_loss(self,
                  time_steps,
                  actions,
                  next_time_steps,
                  td_errors_loss_fn,
                  gamma=1.0,
                  reward_scale_factor=1.0,
                  weights=None):
    """Computes the ensemble critic loss for SAC training.

    Each target critic is evaluated at the next observation with an action
    sampled from the current policy. The `self._percentile`-th percentile of
    those Q-values, taken across the ensemble, is bootstrapped into the TD
    target. Every online critic is regressed towards that same target, and
    the element-wise losses are averaged over the ensemble members.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.

    Returns:
      critic_loss: A scalar critic loss.
    """
    with tf.name_scope('critic_loss'):
      tf.nest.assert_same_structure(actions, self.action_spec)
      tf.nest.assert_same_structure(time_steps, self.time_step_spec)
      tf.nest.assert_same_structure(next_time_steps, self.time_step_spec)

      # NOTE(review): unlike standard SAC, the entropy bonus
      # `- exp(log_alpha) * next_log_pi` is NOT subtracted from the target
      # (the original had it commented out), so the log-probs are discarded.
      # Confirm this is intentional.
      next_actions, _ = self._actions_and_log_probs(next_time_steps)
      target_input = (next_time_steps.observation, next_actions)
      # Each network call returns (q_values, network_state); keep q_values.
      target_q_values = [
          tcn(target_input, next_time_steps.step_type, training=False)[0]
          for tcn in self._target_critic_networks
      ]
      # Aggregate across the ensemble (axis 0) with a percentile rather than
      # the usual min over two critics.
      target_q_values = tfp.stats.percentile(
          target_q_values, self._percentile, axis=0)

      td_targets = tf.stop_gradient(
          reward_scale_factor * next_time_steps.reward +
          gamma * next_time_steps.discount * target_q_values)

      pred_input = (time_steps.observation, actions)
      pred_td_targets = [
          cn(pred_input, time_steps.step_type, training=True)[0]
          for cn in self._critic_networks
      ]

      # Element-wise loss per critic, averaged over the ensemble members.
      critic_loss = tf.reduce_mean(
          [td_errors_loss_fn(td_targets, pred) for pred in pred_td_targets],
          axis=0)

      if weights is not None:
        critic_loss *= weights

      # Take the mean across the batch.
      critic_loss = tf.reduce_mean(input_tensor=critic_loss)

      if self._debug_summaries:
        td_errors = tf.concat(
            [td_targets - pred for pred in pred_td_targets], axis=0)
        common.generate_tensor_summaries('td_errors', td_errors,
                                         self.train_step_counter)
        common.generate_tensor_summaries('td_targets', td_targets,
                                         self.train_step_counter)
      return critic_loss
Exemple #27
0
  def _loss(self,
            experience,
            td_errors_loss_fn=common.element_wise_huber_loss,
            gamma=1.0,
            reward_scale_factor=1.0,
            weights=None,
            training=False):
    """Builds the DQN temporal-difference loss, including regularization.

    Args:
      experience: A `Trajectory` or `Transition` batch whose structure matches
        `self.collect_policy.step_spec`. For a `Trajectory`, every tensor is
        shaped `[B, T, ...]`, with `T` equal to `self.train_sequence_length`
        when that property is not `None`.
      td_errors_loss_fn: Callable `(td_targets, predictions)` returning an
        element-wise loss.
      gamma: Discount applied to future rewards.
      reward_scale_factor: Multiplier applied to rewards.
      weights: Optional scalar or per-batch-entry importance weights, used as
        the sample weights when the per-example losses are aggregated.
      training: Forwarded to the Q-value computation for the current step.

    Returns:
      A `tf_agent.LossInfo` pairing the total (weighted TD plus regularization)
      loss with a `DqnLossInfo` that holds the per-example `td_loss` and
      `td_error`.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
    time_steps, policy_steps, next_time_steps = self._as_transition(experience)

    with tf.name_scope('loss'):
      q_values = self._compute_q_values(
          time_steps, policy_steps.action, training=training)
      next_q_values = self._compute_next_q_values(
          next_time_steps, policy_steps.info)

      # Valid for every n_step_update value and for RNN-DQN, where the
      # tensors additionally carry a time dimension.
      td_targets = compute_td_targets(
          next_q_values,
          rewards=reward_scale_factor * next_time_steps.reward,
          discounts=gamma * next_time_steps.discount)

      # Entries whose time step is LAST contribute nothing.
      not_last = tf.cast(~time_steps.is_last(), tf.float32)
      td_error = not_last * (td_targets - q_values)
      td_loss = not_last * td_errors_loss_fn(td_targets, q_values)

      has_time_dim = nest_utils.is_batched_nested_tensors(
          time_steps, self.time_step_spec, num_outer_dims=2)
      if has_time_dim:
        td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

      # Every element is weighted by 1/N (N = batch size) even when boundary
      # transitions have zero weight; dividing by the count K of non-zero
      # entries would inflate the remaining samples' contribution.
      aggregated = common.aggregate_losses(
          per_example_loss=td_loss,
          sample_weight=weights,
          regularization_loss=self._q_network.losses)

      common.summarize_scalar_dict({'td_loss': aggregated.weighted,
                                    'reg_loss': aggregated.regularization,
                                    'total_loss': aggregated.total_loss},
                                   step=self.train_step_counter,
                                   name_scope='Losses/')

      if self._summarize_grads_and_vars:
        with tf.name_scope('Variables/'):
          for weight_var in self._q_network.trainable_weights:
            tf.compat.v2.summary.histogram(
                name=weight_var.name.replace(':', '_'),
                data=weight_var,
                step=self.train_step_counter)

      if self._debug_summaries:
        diff_q_values = q_values - next_q_values
        debug_tensors = (
            ('td_error', td_error),
            ('td_loss', td_loss),
            ('q_values', q_values),
            ('next_q_values', next_q_values),
            ('diff_q_values', diff_q_values),
        )
        for tag, tensor in debug_tensors:
          common.generate_tensor_summaries(tag, tensor,
                                           self.train_step_counter)

      return tf_agent.LossInfo(aggregated.total_loss,
                               DqnLossInfo(td_loss=td_loss,
                                           td_error=td_error))
Exemple #28
0
    def critic_loss(self,
                    time_steps,
                    actions,
                    alphas,
                    next_time_steps,
                    weights=None):
        """Computes the distributional critic loss for DDPG (WCPG) training.

    The critic networks return a distribution whose `loc` is used as the
    Q-value mean and whose `scale` is used as the return variance. Targets
    follow Eq. 6 of WCPG, with r the scaled reward, g = gamma * discount, and
    Q'/V' the target critic's mean/variance at the next step (evaluated at the
    target actor's action). Q is the target critic's mean at the current step:

      mean target: r + g * Q'
      var target:  r^2 + 2 * g * r * Q' + g^2 * (V' + Q'^2) - Q^2

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      alphas: Extra per-sample input passed to the actor and critic networks
        alongside the observation.
      next_time_steps: A batch of next timesteps.
      weights: Optional scalar or element-wise (per-batch-entry) importance
        weights.
    Returns:
      A tuple `(critic_loss, mean_td_loss, var_td_loss)`. `critic_loss` is the
      scalar (weighted) critic loss. The other two are the unweighted batch
      means of the mean and variance loss terms.
    """
        with tf.name_scope('critic_loss'):
            # Next-step statistics from the target critic at the target
            # actor's action.
            target_actions, _ = self._target_actor_network(
                (next_time_steps.observation, alphas),
                next_time_steps.step_type)
            next_target_critic_net_input = (next_time_steps.observation,
                                            target_actions, alphas)
            next_target_Z, _ = self._target_critic_network(
                next_target_critic_net_input, next_time_steps.step_type)
            next_target_means = tf.reshape(next_target_Z.loc, [-1])
            # NOTE(review): `scale` is treated as a variance throughout. If the
            # critic emits a tfp Normal, `scale` is a standard deviation and
            # would need squaring; confirm against the network definition.
            next_target_vars = tf.reshape(next_target_Z.scale, [-1])

            # Current-step mean from the target critic (the Q^2 term). These
            # inputs belong to `time_steps`, so use its own step_type.
            target_critic_net_input = (time_steps.observation, actions, alphas)
            target_Z, _ = self._target_critic_network(
                target_critic_net_input, time_steps.step_type)
            target_means = tf.reshape(target_Z.loc, [-1])

            # Scale the reward once so every term of both targets uses the
            # same (scaled) reward.
            scaled_rewards = self._reward_scale_factor * next_time_steps.reward
            # gamma * discount: zero at episode ends, which stops bootstrapping
            # of both the mean and the variance across the boundary.
            bootstrap = self._gamma * next_time_steps.discount

            td_mean_target = tf.stop_gradient(
                scaled_rewards + bootstrap * next_target_means)

            # Refer to Eq. 6 in WCPG. This is the second moment of
            # r + g * Z' (r^2 + 2 g r Q' + g^2 (V' + Q'^2)) minus Q^2.
            td_var_target = tf.stop_gradient(
                scaled_rewards**2 +
                2 * bootstrap * scaled_rewards * next_target_means +
                bootstrap**2 * (next_target_vars + next_target_means**2) -
                target_means**2)
            tf.debugging.check_numerics(target_means,
                                        'target means is inf or nan.')
            tf.debugging.check_numerics(next_target_means,
                                        'next target means is inf or nan.')
            tf.debugging.check_numerics(next_target_vars,
                                        'next target vars is inf or nan.')
            tf.debugging.check_numerics(td_var_target,
                                        'target var is inf or nan.')

            critic_net_input = (time_steps.observation, actions, alphas)
            Z, _ = self._critic_network(critic_net_input, time_steps.step_type)
            q_means = tf.reshape(Z.loc, [-1])
            q_vars = tf.reshape(Z.scale, [-1])

            mean_td_error = self._td_errors_loss_fn(td_mean_target, q_means)
            # For non-negative a and b, a + b - 2*sqrt(a*b) equals
            # (sqrt(a) - sqrt(b))**2. The abs() guards the sqrt because the
            # variance target above is not guaranteed to be non-negative.
            var_td_error = td_var_target + q_vars - 2 * tf.sqrt(
                tf.abs(td_var_target * q_vars))
            critic_loss = mean_td_error + var_td_error

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                # NOTE(review): the tensors above are flattened to rank 1 by
                # tf.reshape(..., [-1]), so reducing over axis=1 would fail if
                # this branch were taken. Confirm [B, T] inputs are not expected.
                critic_loss = tf.reduce_sum(critic_loss, axis=1)
            if weights is not None:
                critic_loss *= weights
            critic_loss = tf.reduce_mean(critic_loss)

            with tf.name_scope('Losses/'):
                tf.compat.v2.summary.scalar(name='critic_loss',
                                            data=critic_loss,
                                            step=self.train_step_counter)

            if self._debug_summaries:
                mean_td_errors = td_mean_target - q_means
                var_td_errors = td_var_target - q_vars
                common.generate_tensor_summaries('target_means', target_means,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('next_target_vars',
                                                 next_target_vars,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('next_target_means',
                                                 next_target_means,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('mean_td_errors',
                                                 mean_td_errors,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('var_td_errors',
                                                 var_td_errors,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('td_mean_targets',
                                                 td_mean_target,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('td_var_targets',
                                                 td_var_target,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('q_mean', q_means,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('q_var', q_vars,
                                                 self.train_step_counter)

            return critic_loss, tf.reduce_mean(mean_td_error), tf.reduce_mean(
                var_td_error)
Exemple #29
0
    def _loss(self,
              experience,
              td_errors_loss_fn=common.element_wise_huber_loss,
              gamma=1.0,
              reward_scale_factor=1.0,
              weights=None,
              training=False):
        """Computes the Munchausen-DQN TD loss for a batch of experience.

        Args:
          experience: A batch of experience; `self._as_transition` turns it into
            `(time_steps, policy_steps, next_time_steps)`.
          td_errors_loss_fn: Element-wise loss applied to
            `(td_targets, q_values)`.
          gamma: Multiplier applied to `next_time_steps.discount`.
          reward_scale_factor: Multiplier applied to `next_time_steps.reward`.
          weights: Optional sample weights forwarded to
            `common.aggregate_losses`.
          training: Forwarded to `self._compute_q_values`.

        Returns:
          A `tf_agent.LossInfo` built from the aggregated total loss and a
          `DqnLossInfo` holding the masked element-wise `td_loss` (summed over
          time when inputs carry a time dimension) and `td_error`.
        """
        transition = self._as_transition(experience)
        time_steps, policy_steps, next_time_steps = transition
        actions = policy_steps.action

        # Zero out terms for steps that end an episode.
        valid_mask = tf.cast(~time_steps.is_last(), tf.float32)

        with tf.name_scope('loss'):
            # q_values is already gathered by actions
            q_values = self._compute_q_values(time_steps,
                                              actions,
                                              training=training)

            next_q_values = self._compute_next_all_q_values(
                next_time_steps, policy_steps.info)

            # Values at the *current* time step, needed by the Munchausen
            # target.
            q_target_values = self._compute_next_all_q_values(
                time_steps, policy_steps.info)

            # This applies to any value of n_step_update and also in the RNN-DQN case.
            # In the RNN-DQN case, inputs and outputs contain a time dimension.
            td_targets = compute_munchausen_td_targets(
                next_q_values=next_q_values,
                q_target_values=q_target_values,
                actions=actions,
                rewards=reward_scale_factor * next_time_steps.reward,
                discounts=gamma * next_time_steps.discount,
                multi_dim_actions=self._action_spec.shape.rank > 0,
                alpha=self.alpha,
                entropy_tau=self.entropy_tau)

            td_error = valid_mask * (td_targets - q_values)

            td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

            # Aggregate across the elements of the batch and add regularization loss.
            # Note: We use an element wise loss above to ensure each element is always
            #   weighted by 1/N where N is the batch size, even when some of the
            #   weights are zero due to boundary transitions. Weighting by 1/K where K
            #   is the actual number of non-zero weight would artificially increase
            #   their contribution in the loss. Think about what would happen as
            #   the number of boundary samples increases.

            agg_loss = common.aggregate_losses(
                per_example_loss=td_loss,
                sample_weight=weights,
                regularization_loss=self._q_network.losses)
            total_loss = agg_loss.total_loss

            losses_dict = {
                'td_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(losses_dict,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in self._q_network.trainable_weights:
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self._debug_summaries:
                # `q_values` is already gathered by the taken actions, while
                # `next_q_values` comes from `_compute_next_all_q_values`.
                # NOTE(review): assumes `next_q_values` holds one value per
                # action on its last axis — confirm against that helper.
                # Subtracting the two directly would broadcast a per-example
                # tensor against the action axis, so compare against the
                # greedy next-state value instead.
                diff_q_values = q_values - tf.reduce_max(
                    input_tensor=next_q_values, axis=-1)
                common.generate_tensor_summaries('td_error', td_error,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('td_loss', td_loss,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('q_values', q_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('next_q_values',
                                                 next_q_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('diff_q_values',
                                                 diff_q_values,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(
                total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
Exemple #30
0
    def model_loss(self,
                   images,
                   actions,
                   step_types,
                   rewards,
                   discounts,
                   latent_posterior_samples_and_dists=None,
                   weights=None):
        """Computes the model loss and writes summaries of the model outputs.

        Args:
          images: Batch of images forwarded to `self._model_network`.
          actions: Batch of actions forwarded to `self._model_network`.
          step_types: Batch of step types forwarded to `self._model_network`.
          rewards: Batch of rewards forwarded to `self._model_network`.
          discounts: Batch of discounts forwarded to `self._model_network`.
          latent_posterior_samples_and_dists: Optional
            `(latent_posterior_samples, latent_posterior_dists)` pair forwarded
            to `self._model_network.compute_loss`.
          weights: Optional scalar or per-batch-entry weights multiplied into
            the model loss before averaging. When `self._model_batch_size` is
            set, per-entry weights are truncated together with the inputs.

        Returns:
          The scalar mean of the (optionally weighted) model loss.

        Raises:
          NotImplementedError: If the model network reports an output whose
            rank is neither 0 (scalar summary) nor 5 (GIF summary).
        """
        with tf.name_scope('model_loss'):
            if self._model_batch_size is not None:
                # Allow model batch size to be smaller than the batch size of the
                # other losses. This is because the model loss already gets a lot of
                # supervision from having a loss over all time steps.
                images, actions, step_types, rewards, discounts = tf.nest.map_structure(
                    lambda x: x[:self._model_batch_size],
                    (images, actions, step_types, rewards, discounts))
                if latent_posterior_samples_and_dists is not None:
                    latent_posterior_samples, latent_posterior_dists = latent_posterior_samples_and_dists
                    latent_posterior_samples = tf.nest.map_structure(
                        lambda x: x[:self._model_batch_size],
                        latent_posterior_samples)
                    latent_posterior_dists = slac_nest_utils.map_distribution_structure(
                        lambda x: x[:self._model_batch_size],
                        latent_posterior_dists)
                    latent_posterior_samples_and_dists = (
                        latent_posterior_samples, latent_posterior_dists)
                if weights is not None:
                    # Per-batch-entry weights must be truncated like the inputs
                    # so they line up with the examples actually fed to the
                    # model. Scalar (rank 0) or unknown-rank weights are left
                    # untouched.
                    weights_shape = getattr(weights, 'shape', None)
                    weights_rank = (np.ndim(weights) if weights_shape is None
                                    else tf.TensorShape(weights_shape).ndims)
                    if weights_rank:
                        weights = weights[:self._model_batch_size]

            model_loss, outputs = self._model_network.compute_loss(
                images,
                actions,
                step_types,
                rewards=rewards,
                discounts=discounts,
                latent_posterior_samples_and_dists=
                latent_posterior_samples_and_dists)

            # Frame rate for the GIF summaries; identical for every output.
            fps = 10 if self._control_timestep is None else int(
                np.round(1.0 / self._control_timestep))
            for name, output in outputs.items():
                if output.shape.ndims == 0:
                    tf.contrib.summary.scalar(name, output)
                elif output.shape.ndims == 5:
                    if self._debug_summaries:
                        _gif_summary(name + '/original',
                                     output[:self._num_images_per_summary],
                                     fps,
                                     step=self.train_step_counter)
                    _gif_summary(name,
                                 output[:self._num_images_per_summary],
                                 fps,
                                 saturate=True,
                                 step=self.train_step_counter)
                else:
                    raise NotImplementedError(
                        'Unsupported rank %s for model output %r; expected '
                        'rank 0 or 5.' % (output.shape.ndims, name))

            if weights is not None:
                model_loss *= weights

            model_loss = tf.reduce_mean(input_tensor=model_loss)

            if self._debug_summaries:
                common.generate_tensor_summaries('model_loss', model_loss,
                                                 self.train_step_counter)

            return model_loss