Example #1
    def _apply_loss(self, aggregated_losses, variables_to_train, tape,
                    optimizer):
        total_loss = aggregated_losses.total_loss
        tf.debugging.check_numerics(total_loss, "Loss is inf or nan")
        assert list(variables_to_train), "No variables in the agent's network."

        grads = tape.gradient(total_loss, variables_to_train)
        grads_and_vars = list(zip(grads, variables_to_train))

        if self._gradient_clipping is not None:
            grads_and_vars = eager_utils.clip_gradient_norms(
                grads_and_vars, self._gradient_clipping)

        if self.summarize_grads_and_vars:
            eager_utils.add_variables_summaries(grads_and_vars,
                                                self.train_step_counter)

        optimizer.apply_gradients(grads_and_vars)

        if self.summaries_enabled:
            dict_losses = {
                "loss": aggregated_losses.weighted,
                "reg_loss": aggregated_losses.regularization,
                "total_loss": total_loss
            }
            common.summarize_scalar_dict(dict_losses,
                                         step=self.train_step_counter,
                                         name_scope="Losses/")
Example #2
    def _loss(self,
              experience,
              td_errors_loss_fn=common.element_wise_huber_loss,
              gamma=1.0,
              reward_scale_factor=1.0,
              weights=None,
              training=False):

        transition = self._as_transition(experience)
        time_steps, policy_steps, next_time_steps = transition
        actions = policy_steps.action

        valid_mask = tf.cast(~time_steps.is_last(), tf.float32)

        with tf.name_scope('loss'):
            # q_values is already gathered by actions
            q_values = self._compute_q_values(time_steps,
                                              actions,
                                              training=training)

            next_q_values = self._compute_next_all_q_values(
                next_time_steps, policy_steps.info)

            q_target_values = self._compute_next_all_q_values(
                time_steps, policy_steps.info)

            # This applies to any value of n_step_update and also in the RNN-DQN case.
            # In the RNN-DQN case, inputs and outputs contain a time dimension.
            #td_targets = compute_td_targets(
            #    next_q_values,
            #    rewards=reward_scale_factor * next_time_steps.reward,
            #    discounts=gamma * next_time_steps.discount)

            td_targets = compute_munchausen_td_targets(
                next_q_values=next_q_values,
                q_target_values=q_target_values,
                actions=actions,
                rewards=reward_scale_factor * next_time_steps.reward,
                discounts=gamma * next_time_steps.discount,
                multi_dim_actions=self._action_spec.shape.rank > 0,
                alpha=self.alpha,
                entropy_tau=self.entropy_tau)

            td_error = valid_mask * (td_targets - q_values)

            td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

            # Aggregate across the elements of the batch and add regularization loss.
            # Note: We use an element wise loss above to ensure each element is always
            #   weighted by 1/N where N is the batch size, even when some of the
            #   weights are zero due to boundary transitions. Weighting by 1/K where K
            #   is the actual number of non-zero weights would artificially increase
            #   their contribution in the loss. Think about what would happen as
            #   the number of boundary samples increases.

            agg_loss = common.aggregate_losses(
                per_example_loss=td_loss,
                sample_weight=weights,
                regularization_loss=self._q_network.losses)
            total_loss = agg_loss.total_loss

            losses_dict = {
                'td_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(losses_dict,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in self._q_network.trainable_weights:
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self._debug_summaries:
                diff_q_values = q_values - next_q_values
                common.generate_tensor_summaries('td_error', td_error,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('td_loss', td_loss,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('q_values', q_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('next_q_values',
                                                 next_q_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('diff_q_values',
                                                 diff_q_values,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(
                total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
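
The helper `compute_munchausen_td_targets` is not shown in this example. The sketch below follows the Munchausen-DQN target of Vieillard et al. (2020); the argument order matches the call above, but the log-policy clipping threshold and the exact formulation are assumptions rather than the project's actual code.

    import tensorflow as tf
    from tf_agents.utils import common

    def compute_munchausen_td_targets_sketch(next_q_values, q_target_values, actions,
                                             rewards, discounts, multi_dim_actions,
                                             alpha, entropy_tau, clip_value_min=-1.0):
        # Soft value of the next state: sum_a' pi(a'|s') * (Q(s',a') - tau*log pi(a'|s')),
        # with pi = softmax(Q / tau) computed from the target-network outputs.
        next_log_pi = tf.nn.log_softmax(next_q_values / entropy_tau, axis=-1)
        next_pi = tf.nn.softmax(next_q_values / entropy_tau, axis=-1)
        soft_next_value = tf.reduce_sum(
            next_pi * (next_q_values - entropy_tau * next_log_pi), axis=-1)

        # Munchausen bonus: alpha * clip(tau * log pi(a|s), clip_value_min, 0)
        # for the action actually taken in the current state.
        log_pi = tf.nn.log_softmax(q_target_values / entropy_tau, axis=-1)
        taken_log_pi = common.index_with_actions(
            log_pi, tf.cast(actions, tf.int32), multi_dim_actions)
        munchausen_bonus = alpha * tf.clip_by_value(
            entropy_tau * taken_log_pi, clip_value_min, 0.0)

        # discounts already include gamma; rewards already include the scale factor.
        return tf.stop_gradient(rewards + munchausen_bonus + discounts * soft_next_value)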
Example #3
File: dqn_agent.py Project: wuzh07/agents
  def _loss(self,
            experience,
            td_errors_loss_fn=common.element_wise_huber_loss,
            gamma=1.0,
            reward_scale_factor=1.0,
            weights=None,
            training=False):
    """Computes loss for DQN training.

    Args:
      experience: A batch of experience data in the form of a `Trajectory` or
        `Transition`. The structure of `experience` must match that of
        `self.collect_policy.step_spec`.

        If a `Trajectory`, all tensors in `experience` must be shaped
        `[B, T, ...]` where `T` must be equal to `self.train_sequence_length`
        if that property is not `None`.
      td_errors_loss_fn: A function(td_targets, predictions) to compute the
        element wise loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.  The output td_loss will be scaled by these weights, and
        the final scalar loss is the mean of these values.
      training: Whether this loss is being used for training.

    Returns:
      loss: An instance of `DqnLossInfo`.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
    transition = self._as_transition(experience)
    time_steps, policy_steps, next_time_steps = transition
    actions = policy_steps.action

    with tf.name_scope('loss'):
      q_values = self._compute_q_values(time_steps, actions, training=training)

      next_q_values = self._compute_next_q_values(
          next_time_steps, policy_steps.info)

      # This applies to any value of n_step_update and also in the RNN-DQN case.
      # In the RNN-DQN case, inputs and outputs contain a time dimension.
      td_targets = compute_td_targets(
          next_q_values,
          rewards=reward_scale_factor * next_time_steps.reward,
          discounts=gamma * next_time_steps.discount)

      valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
      td_error = valid_mask * (td_targets - q_values)

      td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

      if nest_utils.is_batched_nested_tensors(
          time_steps, self.time_step_spec, num_outer_dims=2):
        # Do a sum over the time dimension.
        td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

      # Aggregate across the elements of the batch and add regularization loss.
      # Note: We use an element wise loss above to ensure each element is always
      #   weighted by 1/N where N is the batch size, even when some of the
      #   weights are zero due to boundary transitions. Weighting by 1/K where K
      #   is the actual number of non-zero weights would artificially increase
      #   their contribution in the loss. Think about what would happen as
      #   the number of boundary samples increases.

      agg_loss = common.aggregate_losses(
          per_example_loss=td_loss,
          sample_weight=weights,
          regularization_loss=self._q_network.losses)
      total_loss = agg_loss.total_loss

      losses_dict = {'td_loss': agg_loss.weighted,
                     'reg_loss': agg_loss.regularization,
                     'total_loss': total_loss}

      common.summarize_scalar_dict(losses_dict,
                                   step=self.train_step_counter,
                                   name_scope='Losses/')

      if self._summarize_grads_and_vars:
        with tf.name_scope('Variables/'):
          for var in self._q_network.trainable_weights:
            tf.compat.v2.summary.histogram(
                name=var.name.replace(':', '_'),
                data=var,
                step=self.train_step_counter)

      if self._debug_summaries:
        diff_q_values = q_values - next_q_values
        common.generate_tensor_summaries('td_error', td_error,
                                         self.train_step_counter)
        common.generate_tensor_summaries('td_loss', td_loss,
                                         self.train_step_counter)
        common.generate_tensor_summaries('q_values', q_values,
                                         self.train_step_counter)
        common.generate_tensor_summaries('next_q_values', next_q_values,
                                         self.train_step_counter)
        common.generate_tensor_summaries('diff_q_values', diff_q_values,
                                         self.train_step_counter)

      return tf_agent.LossInfo(total_loss, DqnLossInfo(td_loss=td_loss,
                                                       td_error=td_error))
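
For reference, `compute_td_targets` (defined in `tf_agents.agents.dqn.dqn_agent`) reduces to a one-liner; the sketch below shows what it computes, with the gradient stopped so only the online network is trained.

    import tensorflow as tf

    def compute_td_targets_sketch(next_q_values, rewards, discounts):
        # One-step TD target: y = r + gamma * d * max_a' Q_target(s', a').
        # rewards and discounts are already scaled by the caller above.
        return tf.stop_gradient(rewards + discounts * next_q_values)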
Example #4
File: dqn_agent.py Project: zircote/agents
    def _loss(self,
              experience,
              td_errors_loss_fn=common.element_wise_huber_loss,
              gamma=1.0,
              reward_scale_factor=1.0,
              weights=None,
              training=False):
        """Computes loss for DQN training.

    Args:
      experience: A batch of experience data in the form of a `Trajectory`. The
        structure of `experience` must match that of `self.policy.step_spec`.
        All tensors in `experience` must be shaped `[batch, time, ...]` where
        `time` must be equal to `self.train_sequence_length` if that
        property is not `None`.
      td_errors_loss_fn: A function(td_targets, predictions) to compute the
        element wise loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.  The output td_loss will be scaled by these weights, and
        the final scalar loss is the mean of these values.
      training: Whether this loss is being used for training.

    Returns:
      loss: An instance of `DqnLossInfo`.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
        # Check that `experience` includes two outer dimensions [B, T, ...]. This
        # method requires a time dimension to compute the loss properly.
        self._check_trajectory_dimensions(experience)

        squeeze_time_dim = not self._q_network.state_spec
        if self._n_step_update == 1:
            time_steps, policy_steps, next_time_steps = (
                trajectory.experience_to_transitions(experience,
                                                     squeeze_time_dim))
            actions = policy_steps.action
        else:
            # To compute n-step returns, we need the first time steps, the first
            # actions, and the last time steps. Therefore we extract the first and
            # last transitions from our Trajectory.
            first_two_steps = tf.nest.map_structure(lambda x: x[:, :2],
                                                    experience)
            last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:],
                                                   experience)
            time_steps, policy_steps, _ = (
                trajectory.experience_to_transitions(first_two_steps,
                                                     squeeze_time_dim))
            actions = policy_steps.action
            _, _, next_time_steps = (trajectory.experience_to_transitions(
                last_two_steps, squeeze_time_dim))

        with tf.name_scope('loss'):
            q_values = self._compute_q_values(time_steps,
                                              actions,
                                              training=training)

            next_q_values = self._compute_next_q_values(
                next_time_steps, policy_steps.info)

            if self._n_step_update == 1:
                # Special case for n = 1 to avoid a loss of performance.
                td_targets = compute_td_targets(
                    next_q_values,
                    rewards=reward_scale_factor * next_time_steps.reward,
                    discounts=gamma * next_time_steps.discount)
            else:
                # When computing discounted return, we need to throw out the last time
                # index of both reward and discount, which are filled with dummy values
                # to match the dimensions of the observation.
                rewards = reward_scale_factor * experience.reward[:, :-1]
                discounts = gamma * experience.discount[:, :-1]

                # TODO(b/134618876): Properly handle Trajectories that include episode
                # boundaries with nonzero discount.

                td_targets = value_ops.discounted_return(
                    rewards=rewards,
                    discounts=discounts,
                    final_value=next_q_values,
                    time_major=False,
                    provide_all_returns=False)

            valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
            td_error = valid_mask * (td_targets - q_values)

            td_loss = valid_mask * td_errors_loss_fn(td_targets, q_values)

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

            # Aggregate across the elements of the batch and add regularization loss.
            # Note: We use an element wise loss above to ensure each element is always
            #   weighted by 1/N where N is the batch size, even when some of the
            #   weights are zero due to boundary transitions. Weighting by 1/K where K
            #   is the actual number of non-zero weights would artificially increase
            #   their contribution in the loss. Think about what would happen as
            #   the number of boundary samples increases.

            agg_loss = common.aggregate_losses(
                per_example_loss=td_loss,
                sample_weight=weights,
                regularization_loss=self._q_network.losses)
            total_loss = agg_loss.total_loss

            losses_dict = {
                'td_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(losses_dict,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in self._q_network.trainable_weights:
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self._debug_summaries:
                diff_q_values = q_values - next_q_values
                common.generate_tensor_summaries('td_error', td_error,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('td_loss', td_loss,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('q_values', q_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('next_q_values',
                                                 next_q_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('diff_q_values',
                                                 diff_q_values,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(
                total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
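
In the n-step branch, `value_ops.discounted_return(..., provide_all_returns=False)` boils down to the backward recursion sketched here. This is an equivalent minimal version, assuming a statically known number of steps and rewards/discounts that are already scaled by `reward_scale_factor` and `gamma` as in the code above; it is not the library implementation.

    import tensorflow as tf

    def n_step_td_targets_sketch(rewards, discounts, final_q_values):
        # rewards, discounts: [batch, n] (already scaled); final_q_values: [batch].
        # Backward recursion: G_k = r_k + d_k * G_{k+1}, with G_n = Q_target(s_n, a*).
        returns = final_q_values
        for k in reversed(range(rewards.shape[1])):  # assumes a static time dimension
            returns = rewards[:, k] + discounts[:, k] * returns
        return tf.stop_gradient(returns)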
Example #5
    def _loss(self,
              experience,
              td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
              gamma=1.0,
              reward_scale_factor=1.0,
              weights=None,
              training=False):
        """Computes critic loss for CategoricalDQN training.

    See Algorithm 1 and the discussion immediately preceding it on page 6 of
    "A Distributional Perspective on Reinforcement Learning"
      Bellemare et al., 2017
      https://arxiv.org/abs/1707.06887

    Args:
      experience: A batch of experience data in the form of a `Trajectory`. The
        structure of `experience` must match that of `self.policy.step_spec`.
        All tensors in `experience` must be shaped `[batch, time, ...]` where
        `time` must be equal to `self.required_experience_time_steps` if that
        property is not `None`.
      td_errors_loss_fn: A function(td_targets, predictions) to compute loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional weights used for importance sampling.
      training: Whether the loss is being used for training.
    Returns:
      critic_loss: A scalar critic loss.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
        # Check that `experience` includes two outer dimensions [B, T, ...]. This
        # method requires a time dimension to compute the loss properly.
        self._check_trajectory_dimensions(experience)

        squeeze_time_dim = not self._q_network.state_spec
        if self._n_step_update == 1:
            time_steps, policy_steps, next_time_steps = (
                trajectory.experience_to_transitions(experience,
                                                     squeeze_time_dim))
            actions = policy_steps.action
        else:
            # To compute n-step returns, we need the first time steps, the first
            # actions, and the last time steps. Therefore we extract the first and
            # last transitions from our Trajectory.
            first_two_steps = tf.nest.map_structure(lambda x: x[:, :2],
                                                    experience)
            last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:],
                                                   experience)
            time_steps, policy_steps, _ = (
                trajectory.experience_to_transitions(first_two_steps,
                                                     squeeze_time_dim))
            actions = policy_steps.action
            _, _, next_time_steps = (trajectory.experience_to_transitions(
                last_two_steps, squeeze_time_dim))

        with tf.name_scope('critic_loss'):
            nest_utils.assert_same_structure(actions, self.action_spec)
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)
            nest_utils.assert_same_structure(next_time_steps,
                                             self.time_step_spec)

            rank = nest_utils.get_outer_rank(time_steps.observation,
                                             self._time_step_spec.observation)

            # If inputs have a time dimension and the q_network is stateful,
            # combine the batch and time dimension.
            batch_squash = (None if rank <= 1 or self._q_network.state_spec
                            in ((), None) else network_utils.BatchSquash(rank))

            network_observation = time_steps.observation

            if self._observation_and_action_constraint_splitter is not None:
                network_observation, _ = (
                    self._observation_and_action_constraint_splitter(
                        network_observation))

            # q_logits contains the Q-value logits for all actions.
            q_logits, _ = self._q_network(network_observation,
                                          time_steps.step_type,
                                          training=training)

            if batch_squash is not None:
                # Squash the outer dimensions into a single dimension to make
                # computing the loss below easier. Required for supporting
                # temporal inputs, for example.
                q_logits = batch_squash.flatten(q_logits)
                actions = batch_squash.flatten(actions)
                next_time_steps = tf.nest.map_structure(
                    batch_squash.flatten, next_time_steps)

            next_q_distribution = self._next_q_distribution(next_time_steps)

            if actions.shape.rank > 1:
                actions = tf.squeeze(actions,
                                     list(range(1, actions.shape.rank)))

            # Project the sample Bellman update \hat{T}Z_{\theta} onto the original
            # support of Z_{\theta} (see Figure 1 in paper).
            batch_size = q_logits.shape[0] or tf.shape(q_logits)[0]
            tiled_support = tf.tile(self._support, [batch_size])
            tiled_support = tf.reshape(tiled_support,
                                       [batch_size, self._num_atoms])

            if self._n_step_update == 1:
                discount = next_time_steps.discount
                if discount.shape.rank == 1:
                    # We expect discount to have a shape of [batch_size], while
                    # tiled_support will have a shape of [batch_size, num_atoms]. To
                    # multiply these, we add a second dimension of 1 to the discount.
                    discount = tf.expand_dims(discount, -1)
                next_value_term = tf.multiply(discount,
                                              tiled_support,
                                              name='next_value_term')

                reward = next_time_steps.reward
                if reward.shape.rank == 1:
                    # See the explanation above.
                    reward = tf.expand_dims(reward, -1)
                reward_term = tf.multiply(reward_scale_factor,
                                          reward,
                                          name='reward_term')

                target_support = tf.add(reward_term,
                                        gamma * next_value_term,
                                        name='target_support')
            else:
                # When computing discounted return, we need to throw out the last time
                # index of both reward and discount, which are filled with dummy values
                # to match the dimensions of the observation.
                rewards = reward_scale_factor * experience.reward[:, :-1]
                discounts = gamma * experience.discount[:, :-1]

                # TODO(b/134618876): Properly handle Trajectories that include episode
                # boundaries with nonzero discount.

                discounted_returns = value_ops.discounted_return(
                    rewards=rewards,
                    discounts=discounts,
                    final_value=tf.zeros([batch_size], dtype=discounts.dtype),
                    time_major=False,
                    provide_all_returns=False)

                # Convert discounted_returns from [batch_size] to [batch_size, 1]
                discounted_returns = tf.expand_dims(discounted_returns, -1)

                final_value_discount = tf.reduce_prod(discounts, axis=1)
                final_value_discount = tf.expand_dims(final_value_discount, -1)

                # Save the values of discounted_returns and final_value_discount in
                # order to check them in unit tests.
                self._discounted_returns = discounted_returns
                self._final_value_discount = final_value_discount

                target_support = tf.add(discounted_returns,
                                        final_value_discount * tiled_support,
                                        name='target_support')

            target_distribution = tf.stop_gradient(
                project_distribution(target_support, next_q_distribution,
                                     self._support))

            # Obtain the current Q-value logits for the selected actions.
            indices = tf.range(batch_size)
            indices = tf.cast(indices, actions.dtype)
            reshaped_actions = tf.stack([indices, actions], axis=-1)
            chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)

            # Compute the cross-entropy loss between the logits. If inputs have
            # a time dimension, compute the sum over the time dimension before
            # computing the mean over the batch dimension.
            if batch_squash is not None:
                target_distribution = batch_squash.unflatten(
                    target_distribution)
                chosen_action_logits = batch_squash.unflatten(
                    chosen_action_logits)
                critic_loss = tf.reduce_sum(
                    tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
                        labels=target_distribution,
                        logits=chosen_action_logits),
                    axis=1)
            else:
                critic_loss = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
                    labels=target_distribution, logits=chosen_action_logits)

            agg_loss = common.aggregate_losses(
                per_example_loss=critic_loss,
                regularization_loss=self._q_network.losses)
            total_loss = agg_loss.total_loss

            dict_losses = {
                'critic_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(dict_losses,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._debug_summaries:
                distribution_errors = target_distribution - chosen_action_logits
                with tf.name_scope('distribution_errors'):
                    common.generate_tensor_summaries(
                        'distribution_errors',
                        distribution_errors,
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'mean',
                        tf.reduce_mean(distribution_errors),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'mean_abs',
                        tf.reduce_mean(tf.abs(distribution_errors)),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'max',
                        tf.reduce_max(distribution_errors),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'min',
                        tf.reduce_min(distribution_errors),
                        step=self.train_step_counter)
                with tf.name_scope('target_distribution'):
                    common.generate_tensor_summaries(
                        'target_distribution',
                        target_distribution,
                        step=self.train_step_counter)

            # TODO(b/127318640): Give appropriate values for td_loss and td_error for
            # prioritized replay.
            return tf_agent.LossInfo(
                total_loss, dqn_agent.DqnLossInfo(td_loss=(), td_error=()))
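
`project_distribution` carries out the categorical projection from Eq. 7 of Bellemare et al. (2017). Below is a dense, minimal sketch of that projection, assuming an evenly spaced `target_support`; the library version adds shape validation and handles dynamic batch sizes.

    import tensorflow as tf

    def project_distribution_sketch(supports, weights, target_support):
        # supports:       [batch, num_atoms] sample Bellman supports (e.g. r + gamma * z)
        # weights:        [batch, num_atoms] next-state distribution probabilities
        # target_support: [num_atoms] fixed, evenly spaced atoms z_i
        v_min, v_max = target_support[0], target_support[-1]
        delta_z = target_support[1] - target_support[0]
        clipped = tf.clip_by_value(supports, v_min, v_max)               # [B, j]
        # |[T z_j] - z_i| for every atom pair (i, j).
        diff = tf.abs(clipped[:, None, :] - target_support[None, :, None])
        quotient = tf.clip_by_value(1.0 - diff / delta_z, 0.0, 1.0)      # [B, i, j]
        # Spread each source atom's probability mass onto neighbouring target atoms.
        return tf.reduce_sum(quotient * weights[:, None, :], axis=-1)    # [B, i]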
Example #6
    def _loss(self, experience, weights=None):
        """Computes loss for behavioral cloning.

    Args:
      experience: A `Trajectory` containing experience.
      weights: Optional scalar or element-wise (per-batch-entry) importance
        weights.

    Returns:
      loss: A `LossInfo` struct.

    Raises:
      ValueError:
        If the number of actions is greater than 1.
    """
        with tf.name_scope('loss'):
            if self._nested_actions:
                actions = experience.action
            else:
                actions = tf.nest.flatten(experience.action)[0]

            batch_size = (tf.compat.dimension_value(
                experience.step_type.shape[0])
                          or tf.shape(experience.step_type)[0])
            logits, _ = self._cloning_network(
                experience.observation,
                experience.step_type,
                training=True,
                network_state=self._cloning_network.get_initial_state(
                    batch_size))

            error = self._loss_fn(logits, actions)
            error_dtype = tf.nest.flatten(error)[0].dtype
            boundary_weights = tf.cast(~experience.is_boundary(), error_dtype)
            error *= boundary_weights

            if nest_utils.is_batched_nested_tensors(experience.action,
                                                    self.action_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                error = tf.reduce_sum(input_tensor=error, axis=1)

            # Average across the elements of the batch.
            # Note: We use an element wise loss above to ensure each element is always
            #   weighted by 1/N where N is the batch size, even when some of the
            #   weights are zero due to boundary transitions. Weighting by 1/K where K
            #   is the actual number of non-zero weights would artificially increase
            #   their contribution in the loss. Think about what would happen as
            #   the number of boundary samples increases.

            agg_loss = common.aggregate_losses(
                per_example_loss=error,
                sample_weight=weights,
                regularization_loss=self._cloning_network.losses)
            total_loss = agg_loss.total_loss

            dict_losses = {
                'loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(dict_losses,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in self._cloning_network.trainable_weights:
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self._debug_summaries:
                common.generate_tensor_summaries('errors', error,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(total_loss,
                                     BehavioralCloningLossInfo(loss=error))
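
`self._loss_fn` is configurable; for a single discrete action it is typically an element-wise cross-entropy along the lines of the sketch below (the agent's actual default may differ).

    import tensorflow as tf

    def discrete_cloning_loss_sketch(logits, actions):
        # Element-wise cross-entropy between the predicted action logits and the
        # demonstrated actions. No reduction is applied, so the boundary masking
        # and aggregation in _loss above can run afterwards.
        return tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(actions, tf.int32), logits=logits)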
Example #7
    def total_loss(self,
                   experience: traj.Trajectory,
                   returns: types.Tensor,
                   weights: types.Tensor,
                   training: bool = False) -> tf_agent.LossInfo:
        # Ensure we see at least one full episode.
        time_steps = ts.TimeStep(experience.step_type,
                                 tf.zeros_like(experience.reward),
                                 tf.zeros_like(experience.discount),
                                 experience.observation)
        is_last = experience.is_last()
        num_episodes = tf.reduce_sum(tf.cast(is_last, tf.float32))
        tf.debugging.assert_greater(
            num_episodes,
            0.0,
            message=
            'No complete episode found. REINFORCE requires full episodes '
            'to compute losses.')

        # Mask out partial episodes at the end of each batch of time_steps.
        # NOTE: We use is_last rather than is_boundary because the last transition
        # is the transition with the last valid reward.  In other words, the
        # rewards on boundary transitions are not valid.  Since REINFORCE
        # computes its loss w.r.t. the returns (and does not bootstrap),
        # keeping the boundary transitions is irrelevant.
        valid_mask = tf.cast(experience.is_last(), dtype=tf.float32)
        valid_mask = tf.math.cumsum(valid_mask, axis=1, reverse=True)
        valid_mask = tf.cast(valid_mask > 0, dtype=tf.float32)
        if weights is not None:
            weights *= valid_mask
        else:
            weights = valid_mask

        advantages = returns
        value_preds = None

        if self._baseline:
            value_preds, _ = self._value_network(time_steps.observation,
                                                 time_steps.step_type,
                                                 training=True)
            if self._debug_summaries:
                tf.compat.v2.summary.histogram(name='value_preds',
                                               data=value_preds,
                                               step=self.train_step_counter)

        advantages = self._advantage_fn(returns, value_preds)
        if self._debug_summaries:
            tf.compat.v2.summary.histogram(name='advantages',
                                           data=advantages,
                                           step=self.train_step_counter)

        # TODO(b/126592060): replace with tensor normalizer.
        if self._normalize_returns:
            advantages = _standard_normalize(advantages, axes=(0, 1))
            if self._debug_summaries:
                tf.compat.v2.summary.histogram(
                    name='normalized_%s' %
                    ('advantages' if self._baseline else 'returns'),
                    data=advantages,
                    step=self.train_step_counter)

        nest_utils.assert_same_structure(time_steps, self.time_step_spec)
        policy_state = _get_initial_policy_state(self.collect_policy,
                                                 time_steps)
        actions_distribution = self.collect_policy.distribution(
            time_steps, policy_state=policy_state).action

        policy_gradient_loss = self.policy_gradient_loss(
            actions_distribution,
            experience.action,
            experience.is_boundary(),
            advantages,
            num_episodes,
            weights,
        )

        entropy_regularization_loss = self.entropy_regularization_loss(
            actions_distribution, weights)

        network_regularization_loss = tf.nn.scale_regularization_loss(
            self._actor_network.losses)

        total_loss = (policy_gradient_loss + network_regularization_loss +
                      entropy_regularization_loss)

        losses_dict = {
            'policy_gradient_loss': policy_gradient_loss,
            'policy_network_regularization_loss': network_regularization_loss,
            'entropy_regularization_loss': entropy_regularization_loss,
            'value_estimation_loss': 0.0,
            'value_network_regularization_loss': 0.0,
        }

        value_estimation_loss = None
        if self._baseline:
            value_estimation_loss = self.value_estimation_loss(
                value_preds, returns, num_episodes, weights)
            value_network_regularization_loss = tf.nn.scale_regularization_loss(
                self._value_network.losses)
            total_loss += value_estimation_loss + value_network_regularization_loss
            losses_dict['value_estimation_loss'] = value_estimation_loss
            losses_dict['value_network_regularization_loss'] = (
                value_network_regularization_loss)

        loss_info_extra = ReinforceAgentLossInfo(**losses_dict)

        losses_dict[
            'total_loss'] = total_loss  # Total loss not in loss_info_extra.

        common.summarize_scalar_dict(losses_dict,
                                     self.train_step_counter,
                                     name_scope='Losses/')

        return tf_agent.LossInfo(total_loss, loss_info_extra)
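
`self.policy_gradient_loss` is not shown here. The sketch below captures the core REINFORCE term it computes, assuming a single non-nested action distribution: the negative log-probability of the taken actions, weighted by the advantages and averaged over the number of complete episodes. It is illustrative, not the agent's exact implementation.

    import tensorflow as tf

    def policy_gradient_loss_sketch(actions_distribution, actions, is_boundary,
                                    advantages, num_episodes, weights):
        # log pi(a_t | s_t) for the actions actually taken.
        log_prob = actions_distribution.log_prob(actions)
        # Boundary transitions carry no valid reward, so mask them out.
        log_prob *= tf.cast(~is_boundary, tf.float32)
        # REINFORCE: minimize -log pi * advantage, normalized per complete episode.
        per_step_loss = -log_prob * advantages * weights
        return tf.reduce_sum(per_step_loss) / num_episodes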
    def _loss_h(self,
                experience,
                td_errors_loss_fn=common.element_wise_huber_loss,
                gamma=1.0,
                reward_scale_factor=1.0,
                weights=None,
                training=False):

        transition = self._as_transition(experience)
        time_steps, policy_steps, next_time_steps = transition
        actions = policy_steps.action

        valid_mask = tf.cast(~time_steps.is_last(), tf.float32)

        with tf.name_scope('loss'):
            # h_values are already gathered by the taken actions
            h_values = self._compute_h_values(time_steps,
                                              actions,
                                              training=training)

            multi_dim_actions = self._action_spec.shape.rank > 0

            next_q_all_values = self._compute_next_all_q_values(
                next_time_steps, policy_steps.info)

            next_h_all_values = self._compute_next_all_h_values(
                next_time_steps, policy_steps.info)

            next_h_actions = tf.argmax(next_h_all_values, axis=1)

            # next_h_values here is used only for logging
            next_h_values = self._compute_next_h_values(
                next_time_steps, policy_steps.info)

            # next_q_values refer to Q(r,s') in Eqs.(4),(5)
            next_q_values = common.index_with_actions(
                next_q_all_values, tf.cast(next_h_actions, dtype=tf.int32),
                multi_dim_actions)

            h_target_all_values = self._compute_next_all_h_values(
                time_steps, policy_steps.info)

            h_target_values = common.index_with_actions(
                h_target_all_values, tf.cast(actions, dtype=tf.int32),
                multi_dim_actions)

            td_targets = compute_momentum_td_targets(
                q_target_values=next_q_values,
                h_target_values=h_target_values,
                beta=self.beta())

            td_error = valid_mask * (td_targets - h_values)

            td_loss = valid_mask * td_errors_loss_fn(td_targets, h_values)

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

            # Aggregate across the elements of the batch and add regularization loss.
            # Note: We use an element wise loss above to ensure each element is always
            #   weighted by 1/N where N is the batch size, even when some of the
            #   weights are zero due to boundary transitions. Weighting by 1/K where K
            #   is the actual number of non-zero weights would artificially increase
            #   their contribution in the loss. Think about what would happen as
            #   the number of boundary samples increases.

            agg_loss = common.aggregate_losses(
                per_example_loss=td_loss,
                sample_weight=weights,
                regularization_loss=self._q_network.losses)
            total_loss = agg_loss.total_loss

            losses_dict = {
                'td_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(losses_dict,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in self._h_network.trainable_weights:
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self._debug_summaries:
                diff_h_values = h_values - next_h_values
                common.generate_tensor_summaries('td_error_h', td_error,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('td_loss_h', td_loss,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('h_values', h_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('next_h_values',
                                                 next_h_values,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('diff_h_values',
                                                 diff_h_values,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(
                total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
Example #9
    def _loss(self,
              net,
              individual_iql_time_step,
              individual_iql_next_time_step,
              time_steps,
              actions,
              next_time_steps,
              i,
              t,
              td_errors_loss_fn,
              gamma=1.0,
              weights=None,
              training=False):
        with tf.name_scope('loss'):

            individual_target = tf.reshape(
                net._compute_next_q_values(next_time_steps, index=i, time=t),
                [-1, 1])
            individual_main = tf.reshape(
                net._compute_q_values(time_steps,
                                      actions,
                                      index=i,
                                      time=t,
                                      training=True), [-1, 1])

            reward = tf.reshape(individual_iql_next_time_step.reward, [-1, 1])
            discount = tf.reshape(individual_iql_next_time_step.discount,
                                  [-1, 1])
            td_targets = tf.stop_gradient(reward +
                                          gamma * discount * individual_target)

            valid_mask = tf.reshape(
                tf.cast(~individual_iql_time_step.is_last(), tf.float32),
                [-1, 1])
            td_error = valid_mask * (td_targets - individual_main)
            td_loss = valid_mask * tf.compat.v1.losses.absolute_difference(
                td_targets,
                individual_main,
                reduction=tf.compat.v1.losses.Reduction.NONE)
            # td_loss = valid_mask * td_errors_loss_fn(td_targets, q_total)

            if nest_utils.is_batched_nested_tensors(individual_iql_time_step,
                                                    net.agent.time_step_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

            # Aggregate across the elements of the batch and add regularization loss.
            agg_loss = common.aggregate_losses(per_example_loss=td_loss,
                                               sample_weight=weights)
            total_loss = agg_loss.total_loss

            losses_dict = {
                'td_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(losses_dict,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self.summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in net.agent.trainable_weights:
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self.debug_summaries:
                diff_q_values = individual_main - individual_target
                common.generate_tensor_summaries('td_error', td_error,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('td_loss', td_loss,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('q_total', individual_main,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('target_q_total',
                                                 individual_target,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('diff_q_values',
                                                 diff_q_values,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(
                total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
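
The example uses the TF1-style `tf.compat.v1.losses.absolute_difference` with `Reduction.NONE`; a TF2-native equivalent of the same masked element-wise loss is simply:

    import tensorflow as tf

    def elementwise_absolute_td_loss(td_targets, predictions, valid_mask):
        # Same element-wise |target - prediction| loss as above, masked before
        # any time or batch aggregation.
        return valid_mask * tf.abs(td_targets - predictions)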
Example #10
    def _loss(self,
              time_steps,
              policy_steps,
              next_time_steps,
              agents,
              nameDict,
              networkDict,
              td_errors_loss_fn,
              gamma=1.0,
              weights=None,
              training=False):
        with tf.name_scope('loss'):
            total_agents_target = []
            total_agents_main = []
            for i, flexAgent in enumerate(agents):
                for node in nameDict:
                    target = None
                    for type, names in nameDict[node].items():
                        if flexAgent.id in names:
                            target = []
                            main = []
                            for net in networkDict[node][type]:
                                action_index = -1
                                for t in range(24):
                                    action_index += 1
                                    actions = tf.gather(policy_steps.action[i],
                                                        indices=action_index,
                                                        axis=-1)
                                    individual_target = net._compute_next_q_values(
                                        next_time_steps, index=i, time=t)
                                    individual_main = net._compute_q_values(
                                        time_steps,
                                        actions,
                                        index=i,
                                        time=t,
                                        training=True)
                                    target.append(
                                        tf.reshape(individual_target, [-1, 1]))
                                    main.append(
                                        tf.reshape(individual_main, [-1, 1]))
                            break
                    if target is not None:
                        break
                total_agents_target.append(tf.concat(target, -1))
                total_agents_main.append(tf.concat(main, -1))
            total_agents_target = tf.concat(total_agents_target, -1)
            total_agents_main = tf.concat(total_agents_main, -1)

            q_total, _ = self.QMIXNet(total_agents_main,
                                      time_steps.observation,
                                      training=training)
            q_total = tf.squeeze(q_total)
            target_q_total, _ = self.TargetQMIXNet(total_agents_target,
                                                   next_time_steps.observation,
                                                   training=False)
            target_q_total = tf.squeeze(target_q_total)
            """using the mean reward for all the agents"""
            mean_reward = tf.reduce_mean(next_time_steps.reward, axis=1)
            td_targets = tf.stop_gradient(mean_reward +
                                          gamma * next_time_steps.discount *
                                          target_q_total)

            valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
            td_error = valid_mask * (td_targets - q_total)
            td_loss = valid_mask * tf.compat.v1.losses.absolute_difference(
                td_targets,
                q_total,
                reduction=tf.compat.v1.losses.Reduction.NONE)
            # td_loss = valid_mask * td_errors_loss_fn(td_targets, q_total)

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Do a sum over the time dimension.
                td_loss = tf.reduce_sum(input_tensor=td_loss, axis=1)

            # Aggregate across the elements of the batch and add regularization loss.
            agg_loss = common.aggregate_losses(per_example_loss=td_loss,
                                               sample_weight=weights)
            total_loss = agg_loss.total_loss

            if self.summary_writer is not None:
                with self.summary_writer.as_default():
                    tf.summary.scalar('loss',
                                      total_loss,
                                      step=self.train_step_counter)

            losses_dict = {
                'td_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(losses_dict,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self.summarize_grads_and_vars:
                with tf.name_scope('Variables/'):
                    for var in self.QMIXNet.trainable_weights:
                        tf.compat.v2.summary.histogram(
                            name=var.name.replace(':', '_'),
                            data=var,
                            step=self.train_step_counter)

            if self.debug_summaries:
                diff_q_values = q_total - target_q_total
                common.generate_tensor_summaries('td_error', td_error,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('td_loss', td_loss,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('q_total', q_total,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('target_q_total',
                                                 target_q_total,
                                                 self.train_step_counter)
                common.generate_tensor_summaries('diff_q_values',
                                                 diff_q_values,
                                                 self.train_step_counter)

            return tf_agent.LossInfo(
                total_loss, DqnLossInfo(td_loss=td_loss, td_error=td_error))
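
`self.QMIXNet` is defined elsewhere in that project. For orientation, the sketch below is a standard QMIX mixing network in the style of Rashid et al. (2018): hypernetworks conditioned on the global state produce the mixing weights, and taking their absolute value keeps the mixture monotonic in each agent's Q-value. It is illustrative only, not the project's actual network.

    import tensorflow as tf

    class QMixerSketch(tf.keras.Model):
        """Minimal QMIX mixing network (monotonic via non-negative hyper-weights)."""

        def __init__(self, n_agents, embed_dim=32):
            super().__init__()
            self.n_agents = n_agents
            self.embed_dim = embed_dim
            # Hypernetworks map the global state to mixing weights and biases.
            self.hyper_w1 = tf.keras.layers.Dense(n_agents * embed_dim)
            self.hyper_b1 = tf.keras.layers.Dense(embed_dim)
            self.hyper_w2 = tf.keras.layers.Dense(embed_dim)
            self.hyper_b2 = tf.keras.Sequential([
                tf.keras.layers.Dense(embed_dim, activation='relu'),
                tf.keras.layers.Dense(1)])

        def call(self, agent_qs, state):
            # agent_qs: [batch, n_agents] individual Q-values; state: [batch, state_dim].
            w1 = tf.abs(tf.reshape(self.hyper_w1(state),
                                   [-1, self.n_agents, self.embed_dim]))
            b1 = tf.reshape(self.hyper_b1(state), [-1, 1, self.embed_dim])
            hidden = tf.nn.elu(tf.matmul(agent_qs[:, None, :], w1) + b1)  # [B, 1, E]
            w2 = tf.abs(tf.reshape(self.hyper_w2(state), [-1, self.embed_dim, 1]))
            b2 = tf.reshape(self.hyper_b2(state), [-1, 1, 1])
            q_total = tf.matmul(hidden, w2) + b2                          # [B, 1, 1]
            return tf.squeeze(q_total, [1, 2])                            # [B]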