Example #1
    def distribution(
        self, time_step: ts.TimeStep, policy_state: types.NestedTensor = ()
    ) -> policy_step.PolicyStep:
        """Generates the distribution over next actions given the time_step.

    Args:
      time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
      policy_state: A Tensor, or a nested dict, list or tuple of Tensors
        representing the previous policy_state.

    Returns:
      A `PolicyStep` named tuple containing:

        `action`: A tf.distribution capturing the distribution of next actions.
        `state`: A policy state tensor for the next call to distribution.
        `info`: Optional side information such as action log probabilities.

    Raises:
      ValueError or TypeError: If `validate_args is True` and inputs or
        outputs do not match `time_step_spec`, `policy_state_spec`,
        or `policy_step_spec`.
    """
        if self._validate_args:
            time_step = nest_utils.prune_extra_keys(self._time_step_spec,
                                                    time_step)
            policy_state = nest_utils.prune_extra_keys(self._policy_state_spec,
                                                       policy_state)
            nest_utils.assert_same_structure(
                time_step,
                self._time_step_spec,
                message='time_step and time_step_spec structures do not match')
            nest_utils.assert_same_structure(
                policy_state,
                self._policy_state_spec,
                message=
                'policy_state and policy_state_spec structures do not match')
        if self._automatic_state_reset:
            policy_state = self._maybe_reset_state(time_step, policy_state)
        step = self._distribution(time_step=time_step,
                                  policy_state=policy_state)
        if self.emit_log_probability:
            # This is set only for compatibility with the info_spec in the constructor.
            info = policy_step.set_log_probability(
                step.info,
                tf.nest.map_structure(
                    lambda _: tf.constant(0., dtype=tf.float32),
                    policy_step.get_log_probability(self._info_spec)))
            step = step._replace(info=info)
        if self._validate_args:
            nest_utils.assert_same_structure(
                step,
                self._policy_step_spec,
                message=('distribution output and policy_step_spec structures '
                         'do not match'))
        return step
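The `validate_args` branch above relies on nest-structure checks. Below is a minimal sketch (not part of the example) of what such a check does, using the public `tf.nest` API; the TF-Agents `nest_utils.assert_same_structure` used above performs an analogous check with a custom `message`. The dict keys here are hypothetical.

import tensorflow as tf

spec_like = {'obs': tf.TensorSpec([2], tf.float32), 'mask': tf.TensorSpec([2], tf.bool)}
value_ok = {'obs': tf.zeros([2]), 'mask': tf.constant([True, True])}
value_bad = {'obs': tf.zeros([2])}  # missing the 'mask' entry

tf.nest.assert_same_structure(value_ok, spec_like)   # passes silently
try:
    tf.nest.assert_same_structure(value_bad, spec_like)
except (ValueError, TypeError) as e:
    print('structure mismatch:', e)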
Example #2
  def _assert_nested_variable_updated(
      self,
      variables: types.NestedVariable,
      check_nest_seq_types: bool = True) -> None:
    # Prepare the expected content of the variables.
    expected_values = (tf.constant(0, dtype=tf.int64, shape=()), {
        'var1': (tf.constant([1, 1], dtype=tf.float64, shape=(2,)),),
        'var2': tf.constant([[2], [3]], dtype=tf.int32, shape=(2, 1))
    })
    flat_expected_values = tf.nest.flatten(expected_values)

    # Assert that the variables have the same content as the expected values,
    # meaning that the two nested structures have to be the same.
    self.assertIsNone(
        nest_utils.assert_same_structure(
            variables, expected_values, check_types=check_nest_seq_types))
    # The values in `variables` also have to be equal (or close, depending on
    # the component type) to the expected ones.
    flat_variables = tf.nest.flatten(variables)
    self.assertAllEqual(flat_variables[0], flat_expected_values[0])
    self.assertAllClose(flat_variables[1], flat_expected_values[1])
    self.assertAllEqual(flat_variables[2], flat_expected_values[2])
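A minimal sketch (illustrative values, not taken from a real test) of why the flat indices in the assertions above line up: `tf.nest.flatten` traverses tuples by position and dicts in sorted-key order, so the int64 counter comes first, then 'var1', then 'var2'.

import tensorflow as tf

nested = (tf.constant(0, dtype=tf.int64), {
    'var1': (tf.constant([1., 1.], dtype=tf.float64),),
    'var2': tf.constant([[2], [3]], dtype=tf.int32),
})
flat = tf.nest.flatten(nested)
print([t.dtype.name for t in flat])  # ['int64', 'float64', 'int32']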
Example #3
    def critic_no_entropy_loss(self,
                               time_steps,
                               actions,
                               next_time_steps,
                               td_errors_loss_fn,
                               gamma=1.0,
                               reward_scale_factor=1.0,
                               weights=None,
                               training=False):
        """Computes the critic loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.

    Returns:
      critic_loss: A scalar critic loss.
    """
        with tf.name_scope('critic_no_entropy_loss'):
            nest_utils.assert_same_structure(actions, self.action_spec)
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)
            nest_utils.assert_same_structure(next_time_steps,
                                             self.time_step_spec)

            next_actions, _ = self._actions_and_log_probs(next_time_steps)
            target_input = (next_time_steps.observation, next_actions)
            target_q_values1, unused_network_state1 = self._target_critic_network_no_entropy_1(
                target_input, next_time_steps.step_type, training=False)
            target_q_values2, unused_network_state2 = self._target_critic_network_no_entropy_2(
                target_input, next_time_steps.step_type, training=False)
            target_q_values = tf.minimum(
                target_q_values1, target_q_values2
            )  # entropy has been removed from the target critic function

            td_targets = tf.stop_gradient(
                reward_scale_factor * next_time_steps.reward +
                gamma * next_time_steps.discount * target_q_values)

            pred_input = (time_steps.observation, actions)
            pred_td_targets1, _ = self._critic_network_no_entropy_1(
                pred_input, time_steps.step_type, training=training)
            pred_td_targets2, _ = self._critic_network_no_entropy_2(
                pred_input, time_steps.step_type, training=training)
            critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
            critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
            critic_loss = critic_loss1 + critic_loss2

            if nest_utils.is_batched_nested_tensors(time_steps,
                                                    self.time_step_spec,
                                                    num_outer_dims=2):
                # Sum over the time dimension.
                critic_loss = tf.reduce_sum(input_tensor=critic_loss, axis=1)

            agg_loss = common.aggregate_losses(
                per_example_loss=critic_loss,
                sample_weight=weights,
                regularization_loss=(self._critic_network_no_entropy_1.losses +
                                     self._critic_network_no_entropy_2.losses))
            critic_no_entropy_loss = agg_loss.total_loss

            self._critic_no_entropy_loss_debug_summaries(
                td_targets, pred_td_targets1, pred_td_targets2)

            return critic_no_entropy_loss
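A minimal numeric sketch (made-up tensors, not part of the agent) of the TD target computed above: the scaled reward plus the discounted minimum of the two target critics, with gradients stopped so this loss does not update the target networks.

import tensorflow as tf

reward = tf.constant([1.0, 0.0])
discount = tf.constant([1.0, 0.0])      # 0.0 marks a terminal transition
target_q1 = tf.constant([5.0, 7.0])
target_q2 = tf.constant([4.0, 9.0])
gamma, reward_scale_factor = 0.99, 1.0

target_q_values = tf.minimum(target_q1, target_q2)
td_targets = tf.stop_gradient(
    reward_scale_factor * reward + gamma * discount * target_q_values)
print(td_targets.numpy())  # [4.96 0.  ]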
Example #4
    def action(self, time_step, policy_state=(), seed=None):
        """Generates next action given the time_step and policy_state.

    Args:
      time_step: A `TimeStep` tuple corresponding to `time_step_spec()`.
      policy_state: A Tensor, or a nested dict, list or tuple of Tensors
        representing the previous policy_state.
      seed: Seed to use if action performs sampling (optional).

    Returns:
      A `PolicyStep` named tuple containing:
        `action`: An action Tensor matching the `action_spec`.
        `state`: A policy state tensor to be fed into the next call to action.
        `info`: Optional side information such as action log probabilities.

    Raises:
      RuntimeError: If subclass __init__ didn't call super().__init__.
      ValueError or TypeError: If `validate_args is True` and inputs or
        outputs do not match `time_step_spec`, `policy_state_spec`,
        or `policy_step_spec`.
    """
        if self._enable_functions and getattr(self, '_action_fn',
                                              None) is None:
            raise RuntimeError(
                'Cannot find _action_fn.  Did %s.__init__ call super?' %
                type(self).__name__)
        if self._enable_functions:
            action_fn = self._action_fn
        else:
            action_fn = self._action

        if self._validate_args:
            time_step = nest_utils.prune_extra_keys(self._time_step_spec,
                                                    time_step)
            policy_state = nest_utils.prune_extra_keys(self._policy_state_spec,
                                                       policy_state)
            nest_utils.assert_same_structure(
                time_step,
                self._time_step_spec,
                message='time_step and time_step_spec structures do not match')
            nest_utils.assert_same_structure(
                policy_state,
                self._policy_state_spec,
                message=
                'policy_state and policy_state_spec structures do not match')

        if self._automatic_state_reset:
            policy_state = self._maybe_reset_state(time_step, policy_state)
        step = action_fn(time_step=time_step,
                         policy_state=policy_state,
                         seed=seed)

        def clip_action(action, action_spec):
            if isinstance(action_spec, tensor_spec.BoundedTensorSpec):
                return common.clip_to_spec(action, action_spec)
            return action

        if self._validate_args:
            nest_utils.assert_same_structure(
                step.action,
                self._action_spec,
                message='action and action_spec structures do not match')

        if self._clip:
            clipped_actions = tf.nest.map_structure(clip_action, step.action,
                                                    self._action_spec)
            step = step._replace(action=clipped_actions)

        if self._validate_args:
            nest_utils.assert_same_structure(
                step,
                self._policy_step_spec,
                message=
                'action output and policy_step_spec structures do not match')

            def compare_to_spec(value, spec):
                return value.dtype.is_compatible_with(spec.dtype)

            compatibility = [
                compare_to_spec(v, s)
                for (v, s) in zip(tf.nest.flatten(step.action),
                                  tf.nest.flatten(self.action_spec))
            ]

            if not all(compatibility):
                get_dtype = lambda x: x.dtype
                action_dtypes = tf.nest.map_structure(get_dtype, step.action)
                spec_dtypes = tf.nest.map_structure(get_dtype,
                                                    self.action_spec)

                raise TypeError(
                    'Policy produced an action with a dtype that doesn\'t '
                    'match its action_spec. Got action:\n  %s\n with '
                    'action_spec:\n  %s' % (action_dtypes, spec_dtypes))

        return step
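A minimal sketch (hypothetical bounds, not the policy's actual spec) of the per-leaf clipping that `clip_action` applies above; `common.clip_to_spec` is roughly a `tf.clip_by_value` against the `BoundedTensorSpec`'s minimum and maximum.

import tensorflow as tf

action = tf.constant([1.7, -0.3])
minimum, maximum = -1.0, 1.0   # bounds taken from a hypothetical BoundedTensorSpec
clipped = tf.clip_by_value(action, minimum, maximum)
print(clipped.numpy())  # [ 1.  -0.3]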
Example #5
 def as_dict(outputs, output_spec):
     nest_utils.assert_same_structure(outputs, output_spec)
     flat_outputs = tf.nest.flatten(outputs)
     flat_names = [s.name for s in tf.nest.flatten(output_spec)]
     return dict(zip(flat_names, flat_outputs))
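A minimal usage sketch for `as_dict` above, assuming the helper is in scope and `nest_utils` is imported from `tf_agents.utils`; the spec names (here 'logits' and 'value', chosen for illustration) become the dictionary keys.

import tensorflow as tf
from tf_agents.utils import nest_utils  # needed by as_dict above

output_spec = {'logits': tf.TensorSpec([3], tf.float32, name='logits'),
               'value': tf.TensorSpec([], tf.float32, name='value')}
outputs = {'logits': tf.zeros([3]), 'value': tf.constant(1.0)}
named = as_dict(outputs, output_spec)
print(sorted(named.keys()))  # ['logits', 'value']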
Example #6
    def _loss(self,
              experience,
              td_errors_loss_fn=tf.compat.v1.losses.huber_loss,
              gamma=1.0,
              reward_scale_factor=1.0,
              weights=None,
              training=False):
        """Computes critic loss for CategoricalDQN training.

    See Algorithm 1 and the discussion immediately preceding it on page 6 of
    "A Distributional Perspective on Reinforcement Learning"
      Bellemare et al., 2017
      https://arxiv.org/abs/1707.06887

    Args:
      experience: A batch of experience data in the form of a `Trajectory`. The
        structure of `experience` must match that of `self.policy.step_spec`.
        All tensors in `experience` must be shaped `[batch, time, ...]` where
        `time` must be equal to `self.required_experience_time_steps` if that
        property is not `None`.
      td_errors_loss_fn: A function(td_targets, predictions) to compute loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional weights used for importance sampling.
      training: Whether the loss is being used for training.
    Returns:
      critic_loss: A scalar critic loss.
    Raises:
      ValueError:
        if the number of actions is greater than 1.
    """
        # Check that `experience` includes two outer dimensions [B, T, ...]. This
        # method requires a time dimension to compute the loss properly.
        self._check_trajectory_dimensions(experience)

        squeeze_time_dim = not self._q_network.state_spec
        if self._n_step_update == 1:
            time_steps, policy_steps, next_time_steps = (
                trajectory.experience_to_transitions(experience,
                                                     squeeze_time_dim))
            actions = policy_steps.action
        else:
            # To compute n-step returns, we need the first time steps, the first
            # actions, and the last time steps. Therefore we extract the first and
            # last transitions from our Trajectory.
            first_two_steps = tf.nest.map_structure(lambda x: x[:, :2],
                                                    experience)
            last_two_steps = tf.nest.map_structure(lambda x: x[:, -2:],
                                                   experience)
            time_steps, policy_steps, _ = (
                trajectory.experience_to_transitions(first_two_steps,
                                                     squeeze_time_dim))
            actions = policy_steps.action
            _, _, next_time_steps = (trajectory.experience_to_transitions(
                last_two_steps, squeeze_time_dim))

        with tf.name_scope('critic_loss'):
            nest_utils.assert_same_structure(actions, self.action_spec)
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)
            nest_utils.assert_same_structure(next_time_steps,
                                             self.time_step_spec)

            rank = nest_utils.get_outer_rank(time_steps.observation,
                                             self._time_step_spec.observation)

            # If inputs have a time dimension and the q_network is stateful,
            # combine the batch and time dimension.
            batch_squash = (None if rank <= 1 or self._q_network.state_spec
                            in ((), None) else network_utils.BatchSquash(rank))

            network_observation = time_steps.observation

            if self._observation_and_action_constraint_splitter is not None:
                network_observation, _ = (
                    self._observation_and_action_constraint_splitter(
                        network_observation))

            # q_logits contains the Q-value logits for all actions.
            q_logits, _ = self._q_network(network_observation,
                                          time_steps.step_type,
                                          training=training)

            if batch_squash is not None:
                # Squash the outer dimensions into a single dimension to
                # facilitate computing the loss below. Required to support
                # temporal inputs, for example.
                q_logits = batch_squash.flatten(q_logits)
                actions = batch_squash.flatten(actions)
                next_time_steps = tf.nest.map_structure(
                    batch_squash.flatten, next_time_steps)

            next_q_distribution = self._next_q_distribution(next_time_steps)

            if actions.shape.rank > 1:
                actions = tf.squeeze(actions,
                                     list(range(1, actions.shape.rank)))

            # Project the sample Bellman update \hat{T}Z_{\theta} onto the original
            # support of Z_{\theta} (see Figure 1 in paper).
            batch_size = q_logits.shape[0] or tf.shape(q_logits)[0]
            tiled_support = tf.tile(self._support, [batch_size])
            tiled_support = tf.reshape(tiled_support,
                                       [batch_size, self._num_atoms])

            if self._n_step_update == 1:
                discount = next_time_steps.discount
                if discount.shape.rank == 1:
                    # We expect discount to have a shape of [batch_size], while
                    # tiled_support will have a shape of [batch_size, num_atoms]. To
                    # multiply these, we add a second dimension of 1 to the discount.
                    discount = tf.expand_dims(discount, -1)
                next_value_term = tf.multiply(discount,
                                              tiled_support,
                                              name='next_value_term')

                reward = next_time_steps.reward
                if reward.shape.rank == 1:
                    # See the explanation above.
                    reward = tf.expand_dims(reward, -1)
                reward_term = tf.multiply(reward_scale_factor,
                                          reward,
                                          name='reward_term')

                target_support = tf.add(reward_term,
                                        gamma * next_value_term,
                                        name='target_support')
            else:
                # When computing discounted return, we need to throw out the last time
                # index of both reward and discount, which are filled with dummy values
                # to match the dimensions of the observation.
                rewards = reward_scale_factor * experience.reward[:, :-1]
                discounts = gamma * experience.discount[:, :-1]

                # TODO(b/134618876): Properly handle Trajectories that include episode
                # boundaries with nonzero discount.

                discounted_returns = value_ops.discounted_return(
                    rewards=rewards,
                    discounts=discounts,
                    final_value=tf.zeros([batch_size], dtype=discounts.dtype),
                    time_major=False,
                    provide_all_returns=False)

                # Convert discounted_returns from [batch_size] to [batch_size, 1]
                discounted_returns = tf.expand_dims(discounted_returns, -1)

                final_value_discount = tf.reduce_prod(discounts, axis=1)
                final_value_discount = tf.expand_dims(final_value_discount, -1)

                # Save the values of discounted_returns and final_value_discount in
                # order to check them in unit tests.
                self._discounted_returns = discounted_returns
                self._final_value_discount = final_value_discount

                target_support = tf.add(discounted_returns,
                                        final_value_discount * tiled_support,
                                        name='target_support')

            target_distribution = tf.stop_gradient(
                project_distribution(target_support, next_q_distribution,
                                     self._support))

            # Obtain the current Q-value logits for the selected actions.
            indices = tf.range(batch_size)
            indices = tf.cast(indices, actions.dtype)
            reshaped_actions = tf.stack([indices, actions], axis=-1)
            chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)

            # Compute the cross-entropy loss between the logits. If inputs have
            # a time dimension, compute the sum over the time dimension before
            # computing the mean over the batch dimension.
            if batch_squash is not None:
                target_distribution = batch_squash.unflatten(
                    target_distribution)
                chosen_action_logits = batch_squash.unflatten(
                    chosen_action_logits)
                critic_loss = tf.reduce_sum(
                    tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
                        labels=target_distribution,
                        logits=chosen_action_logits),
                    axis=1)
            else:
                critic_loss = tf.compat.v1.nn.softmax_cross_entropy_with_logits_v2(
                    labels=target_distribution, logits=chosen_action_logits)

            agg_loss = common.aggregate_losses(
                per_example_loss=critic_loss,
                regularization_loss=self._q_network.losses)
            total_loss = agg_loss.total_loss

            dict_losses = {
                'critic_loss': agg_loss.weighted,
                'reg_loss': agg_loss.regularization,
                'total_loss': total_loss
            }

            common.summarize_scalar_dict(dict_losses,
                                         step=self.train_step_counter,
                                         name_scope='Losses/')

            if self._debug_summaries:
                distribution_errors = target_distribution - chosen_action_logits
                with tf.name_scope('distribution_errors'):
                    common.generate_tensor_summaries(
                        'distribution_errors',
                        distribution_errors,
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'mean',
                        tf.reduce_mean(distribution_errors),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'mean_abs',
                        tf.reduce_mean(tf.abs(distribution_errors)),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'max',
                        tf.reduce_max(distribution_errors),
                        step=self.train_step_counter)
                    tf.compat.v2.summary.scalar(
                        'min',
                        tf.reduce_min(distribution_errors),
                        step=self.train_step_counter)
                with tf.name_scope('target_distribution'):
                    common.generate_tensor_summaries(
                        'target_distribution',
                        target_distribution,
                        step=self.train_step_counter)

            # TODO(b/127318640): Give appropriate values for td_loss and td_error for
            # prioritized replay.
            return tf_agent.LossInfo(
                total_loss, dqn_agent.DqnLossInfo(td_loss=(), td_error=()))
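A minimal sketch (toy shapes) of the `gather_nd` indexing used above to select the logits of the chosen action for each batch entry: batch indices and actions are stacked into `[batch, 2]` index pairs.

import tensorflow as tf

q_logits = tf.constant([[[0., 1.], [2., 3.]],
                        [[4., 5.], [6., 7.]]])      # [batch=2, num_actions=2, num_atoms=2]
actions = tf.constant([1, 0])
indices = tf.range(tf.shape(q_logits)[0])
reshaped_actions = tf.stack([indices, actions], axis=-1)  # [[0, 1], [1, 0]]
chosen_action_logits = tf.gather_nd(q_logits, reshaped_actions)
print(chosen_action_logits.numpy())  # [[2. 3.] [4. 5.]]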
Example #7
  def critic_loss(self,
                  time_steps,
                  actions,
                  next_time_steps,
                  augmented_obs,
                  augmented_next_obs,
                  td_errors_loss_fn,
                  gamma=1.0,
                  reward_scale_factor=1.0,
                  weights=None,
                  training=False):
    """Computes the critic loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      augmented_obs: List of observations.
      augmented_next_obs: List of next_observations.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.

    Returns:
      critic_loss: A scalar critic loss.
    """
    with tf.name_scope('critic_loss'):
      nest_utils.assert_same_structure(actions, self.action_spec)
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)
      nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

      td_targets = self._compute_td_targets(next_time_steps,
                                            reward_scale_factor, gamma)

      # Compute td_targets with augmentations.
      for i in range(self._num_augmentations - 1):
        augmented_next_time_steps = next_time_steps._replace(
            observation=augmented_next_obs[i])

        augmented_td_targets = self._compute_td_targets(
            augmented_next_time_steps, reward_scale_factor, gamma)

        td_targets = td_targets + augmented_td_targets

      # Average td_target estimation over augmentations.
      if self._num_augmentations > 1:
        td_targets = td_targets / self._num_augmentations

      pred_td_targets1, pred_td_targets2, critic_loss = (
          self._compute_prediction_critic_loss(
              (time_steps.observation, actions), td_targets, time_steps,
              training, td_errors_loss_fn))

      # Add Q Augmentations to the critic loss.
      for i in range(self._num_augmentations - 1):
        augmented_time_steps = time_steps._replace(observation=augmented_obs[i])
        _, _, loss = (
            self._compute_prediction_critic_loss(
                (augmented_time_steps.observation, actions), td_targets,
                augmented_time_steps, training, td_errors_loss_fn))
        critic_loss = critic_loss + loss

      agg_loss = common.aggregate_losses(
          per_example_loss=critic_loss,
          sample_weight=weights,
          regularization_loss=(self._critic_network_1.losses +
                               self._critic_network_2.losses))
      critic_loss = agg_loss.total_loss

      self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                        pred_td_targets2)

      return critic_loss
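A minimal sketch (a toy namedtuple standing in for `ts.TimeStep`) of the `_replace` pattern used above to swap in augmented observations: `namedtuple._replace` returns a copy with the given field overridden, leaving the original untouched.

import collections

FakeTimeStep = collections.namedtuple(
    'FakeTimeStep', ['step_type', 'reward', 'discount', 'observation'])
step = FakeTimeStep(step_type=0, reward=1.0, discount=0.99, observation=[0.0, 0.0])
augmented = step._replace(observation=[0.1, -0.2])
print(step.observation, augmented.observation)  # [0.0, 0.0] [0.1, -0.2]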
Example #8
  def critic_loss(self,
                  time_steps,
                  expert_experience,
                  actions,
                  next_time_steps,
                  future_time_steps,
                  td_errors_loss_fn,
                  gamma = 1.0,
                  reward_scale_factor = 1.0,
                  weights = None,
                  training = False,
                  loss_name='c',
                  use_done=False,
                  q_combinator='min'):
    """Computes the critic loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      expert_experience: An array of success examples.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      future_time_steps: A batch of future timesteps, used for n-step returns.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.
      loss_name: Which loss function to use. Use 'c' for RCE and 'q' for SQIL.
      use_done: Whether to use the terminal flag from the environment in the
        Bellman backup. We found that omitting it led to better results.
      q_combinator: Whether to combine the two Q-functions by taking the 'min'
        (as in TD3) or the 'max'.

    Returns:
      critic_loss: A scalar critic loss.
    """
    assert weights is None
    with tf.name_scope('critic_loss'):
      nest_utils.assert_same_structure(actions, self.action_spec)
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)
      nest_utils.assert_same_structure(next_time_steps, self.time_step_spec)

      next_actions, _ = self._actions_and_log_probs(next_time_steps)
      target_input = (next_time_steps.observation, next_actions)
      target_q_values1, unused_network_state1 = self._target_critic_network_1(
          target_input, next_time_steps.step_type, training=False)
      target_q_values2, unused_network_state2 = self._target_critic_network_2(
          target_input, next_time_steps.step_type, training=False)
      if self._n_step is not None:
        future_actions, _ = self._actions_and_log_probs(future_time_steps)
        future_input = (future_time_steps.observation, future_actions)
        future_q_values1, _ = self._target_critic_network_1(
            future_input, future_time_steps.step_type, training=False)
        future_q_values2, _ = self._target_critic_network_2(
            future_input, future_time_steps.step_type, training=False)

        gamma_n = gamma**self._n_step  # Discount for n-step returns
        target_q_values1 = (target_q_values1 + gamma_n * future_q_values1) / 2.0
        target_q_values2 = (target_q_values2 + gamma_n * future_q_values2) / 2.0

      if q_combinator == 'min':
        target_q_values = tf.minimum(target_q_values1, target_q_values2)
      else:
        assert q_combinator == 'max'
        target_q_values = tf.maximum(target_q_values1, target_q_values2)

      batch_size = time_steps.observation.shape[0]
      if loss_name == 'q':
        if use_done:
          td_targets = gamma * next_time_steps.discount * target_q_values
        else:
          td_targets = gamma * target_q_values
      else:
        assert loss_name == 'c'
        w = target_q_values / (1 - target_q_values)
        td_targets = gamma * w / (gamma * w + 1)
        if use_done:
          td_targets = next_time_steps.discount * td_targets
        weights = tf.concat([1 + gamma * w, (1 - gamma) * tf.ones(batch_size)],
                            axis=0)

      td_targets = tf.stop_gradient(td_targets)
      td_targets = tf.concat([td_targets, tf.ones(batch_size)], axis=0)

      # Note that the actions only depend on the observations. We create the
      # expert_time_steps object simply to make this look like a time step
      # object.
      expert_time_steps = time_steps._replace(observation=expert_experience)
      if self._use_behavior_policy:
        policy_state = self._train_policy.get_initial_state(batch_size)
        action_distribution = self._behavior_policy.distribution(
            time_steps, policy_state=policy_state).action
        # Sample actions and log_pis from transformed distribution.
        expert_actions = tf.nest.map_structure(lambda d: d.sample(),
                                               action_distribution)
      else:
        expert_actions, _ = self._actions_and_log_probs(expert_time_steps)

      observation = time_steps.observation
      pred_input = (tf.concat([observation, expert_experience], axis=0),
                    tf.concat([actions, expert_actions], axis=0))

      pred_td_targets1, _ = self._critic_network_1(
          pred_input, time_steps.step_type, training=training)
      pred_td_targets2, _ = self._critic_network_2(
          pred_input, time_steps.step_type, training=training)

      self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                        pred_td_targets2)

      critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
      critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
      critic_loss = critic_loss1 + critic_loss2

      if critic_loss.shape.rank > 1:
        # Sum over the time dimension.
        critic_loss = tf.reduce_sum(
            critic_loss, axis=range(1, critic_loss.shape.rank))

      agg_loss = common.aggregate_losses(
          per_example_loss=critic_loss,
          sample_weight=weights,
          regularization_loss=(self._critic_network_1.losses +
                               self._critic_network_2.losses))
      critic_loss = agg_loss.total_loss

      self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                        pred_td_targets2)

      return critic_loss
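A minimal numeric sketch (made-up Q-values) of the transform used above when `loss_name == 'c'`: the classifier weight w = Q / (1 - Q) and the resulting target gamma * w / (gamma * w + 1).

import tensorflow as tf

gamma = 0.99
target_q_values = tf.constant([0.2, 0.5, 0.8])
w = target_q_values / (1 - target_q_values)   # [0.25 1.  4. ]
td_targets = gamma * w / (gamma * w + 1)
print(td_targets.numpy())                     # ~[0.198 0.497 0.798]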
Example #9
    def run(self, trajectory, policy_state=None):
        """Apply the policy to trajectory steps and store actions/info.

    If `self.time_major == True`, the tensors in `trajectory` are assumed to
    have shape `[time, batch, ...]`.  Otherwise they are assumed to
    have shape `[batch, time, ...]`.

    Args:
      trajectory: The `Trajectory` to run against.
        If the replay class was created with `time_major=True`, then
        the tensors in trajectory must be shaped `[time, batch, ...]`.
        Otherwise they must be shaped `[batch, time, ...]`.
      policy_state: (optional) A nest Tensor with initial step policy state.

    Returns:
      output_actions: A nest of the actions that the policy took.
        If the replay class was created with `time_major=True`, then
        the tensors here will be shaped `[time, batch, ...]`.  Otherwise
        they'll be shaped `[batch, time, ...]`.
      output_policy_info: A nest of the policy info that the policy emitted.
        If the replay class was created with `time_major=True`, then
        the tensors here will be shaped `[time, batch, ...]`.  Otherwise
        they'll be shaped `[batch, time, ...]`.
      policy_state: A nest Tensor with final step policy state.

    Raises:
      TypeError: If `policy_state` structure doesn't match
        `self.policy.policy_state_spec`, or `trajectory` structure doesn't
        match `self.policy.trajectory_spec`.
      ValueError: If `policy_state` doesn't match
        `self.policy.policy_state_spec`, or `trajectory` structure doesn't
        match `self.policy.trajectory_spec`.
      ValueError: If `trajectory` lacks two outer dims.
    """
        trajectory_spec = self._policy.trajectory_spec
        outer_dims = nest_utils.get_outer_shape(trajectory, trajectory_spec)

        if tf.compat.dimension_value(outer_dims.shape[0]) != 2:
            raise ValueError(
                "Expected two outer dimensions, but saw '{}' dimensions.\n"
                "Trajectory:\n{}.\nTrajectory spec from policy:\n{}.".format(
                    tf.compat.dimension_value(outer_dims.shape[0]), trajectory,
                    trajectory_spec))
        if self._time_major:
            sequence_length = outer_dims[0]
            batch_size = outer_dims[1]
            static_batch_size = tf.compat.dimension_value(
                trajectory.discount.shape[1])
        else:
            batch_size = outer_dims[0]
            sequence_length = outer_dims[1]
            static_batch_size = tf.compat.dimension_value(
                trajectory.discount.shape[0])

        if policy_state is None:
            policy_state = self._policy.get_initial_state(batch_size)
        else:
            nest_utils.assert_same_structure(policy_state,
                                             self._policy.policy_state_spec)

        if not self._time_major:
            # Make trajectory time-major.
            trajectory = tf.nest.map_structure(common.transpose_batch_time,
                                               trajectory)

        trajectory_tas = tf.nest.map_structure(
            lambda t: tf.TensorArray(t.dtype, size=sequence_length).unstack(t),
            trajectory)

        def create_output_ta(spec):
            return tf.TensorArray(spec.dtype,
                                  size=sequence_length,
                                  element_shape=(tf.TensorShape([
                                      static_batch_size
                                  ]).concatenate(spec.shape)))

        output_action_tas = tf.nest.map_structure(create_output_ta,
                                                  trajectory_spec.action)
        output_policy_info_tas = tf.nest.map_structure(
            create_output_ta, trajectory_spec.policy_info)

        read0 = lambda ta: ta.read(0)
        zeros_like0 = lambda t: tf.zeros_like(t[0])
        ones_like0 = lambda t: tf.ones_like(t[0])
        time_step = ts.TimeStep(
            step_type=read0(trajectory_tas.step_type),
            reward=tf.nest.map_structure(zeros_like0, trajectory.reward),
            discount=ones_like0(trajectory.discount),
            observation=tf.nest.map_structure(read0,
                                              trajectory_tas.observation))

        def process_step(time, time_step, policy_state, output_action_tas,
                         output_policy_info_tas):
            """Take an action on the given step, and update output TensorArrays.

      Args:
        time: Step time.  Describes which row to read from the trajectory
          TensorArrays and which location to write to in the output
          TensorArrays.
        time_step: Previous step's `TimeStep`.
        policy_state: Policy state tensor or nested structure of tensors.
        output_action_tas: Nest of `tf.TensorArray` containing new actions.
        output_policy_info_tas: Nest of `tf.TensorArray` containing new
          policy info.

      Returns:
        policy_state: The next policy state.
        next_output_action_tas: Updated `output_action_tas`.
        next_output_policy_info_tas: Updated `output_policy_info_tas`.
      """
            action_step = self._policy.action(time_step, policy_state)
            policy_state = action_step.state
            write_ta = lambda ta, t: ta.write(time - 1, t)
            next_output_action_tas = tf.nest.map_structure(
                write_ta, output_action_tas, action_step.action)
            next_output_policy_info_tas = tf.nest.map_structure(
                write_ta, output_policy_info_tas, action_step.info)

            return (action_step.state, next_output_action_tas,
                    next_output_policy_info_tas)

        def loop_body(time, time_step, policy_state, output_action_tas,
                      output_policy_info_tas):
            """Runs a step in environment.

      While loop will call multiple times.

      Args:
        time: Step time.
        time_step: Previous step's `TimeStep`.
        policy_state: Policy state tensor or nested structure of tensors.
        output_action_tas: Updated nest of `tf.TensorArray`, the new actions.
        output_policy_info_tas: Updated nest of `tf.TensorArray`, the new
          policy info.

      Returns:
        loop_vars for next iteration of tf.while_loop.
      """
            policy_state, next_output_action_tas, next_output_policy_info_tas = (
                process_step(time, time_step, policy_state, output_action_tas,
                             output_policy_info_tas))

            ta_read = lambda ta: ta.read(time)
            ta_read_prev = lambda ta: ta.read(time - 1)
            time_step = ts.TimeStep(
                step_type=ta_read(trajectory_tas.step_type),
                observation=tf.nest.map_structure(ta_read,
                                                  trajectory_tas.observation),
                reward=tf.nest.map_structure(ta_read_prev,
                                             trajectory_tas.reward),
                discount=ta_read_prev(trajectory_tas.discount))

            return (time + 1, time_step, policy_state, next_output_action_tas,
                    next_output_policy_info_tas)

        time = tf.constant(1)
        time, time_step, policy_state, output_action_tas, output_policy_info_tas = (
            tf.while_loop(cond=lambda time, *_: time < sequence_length,
                          body=loop_body,
                          loop_vars=[
                              time, time_step, policy_state, output_action_tas,
                              output_policy_info_tas
                          ],
                          back_prop=False,
                          name="trajectory_replay_loop"))

        # Run the last time step
        last_policy_state, output_action_tas, output_policy_info_tas = (
            process_step(time, time_step, policy_state, output_action_tas,
                         output_policy_info_tas))

        def stack_ta(ta):
            t = ta.stack()
            if not self._time_major:
                t = common.transpose_batch_time(t)
            return t

        stacked_output_actions = tf.nest.map_structure(stack_ta,
                                                       output_action_tas)
        stacked_output_policy_info = tf.nest.map_structure(
            stack_ta, output_policy_info_tas)

        return (stacked_output_actions, stacked_output_policy_info,
                last_policy_state)
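A minimal sketch (toy time-major tensor) of the `tf.TensorArray` round trip that `run` relies on above: unstack a `[time, batch, ...]` tensor into per-step reads, write per-step results, then stack them back.

import tensorflow as tf

time_major = tf.constant([[1, 2], [3, 4], [5, 6]])        # [time=3, batch=2]
ta = tf.TensorArray(tf.int32, size=3).unstack(time_major)
print(ta.read(0).numpy())                                  # [1 2]

out = tf.TensorArray(tf.int32, size=3)
for t in range(3):
    out = out.write(t, ta.read(t) * 10)                    # write returns the updated array
print(out.stack().numpy())                                 # [[10 20] [30 40] [50 60]]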
Example #10
    def critic_loss(self,
                    time_steps: ts.TimeStep,
                    actions: types.Tensor,
                    next_time_steps: ts.TimeStep,
                    td_errors_loss_fn: types.LossFn,
                    gamma: types.Float = 1.0,
                    reward_scale_factor: types.Float = 1.0,
                    weights: Optional[types.Tensor] = None,
                    training: bool = False) -> types.Tensor:
        """Computes the critic loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.

    Returns:
      critic_loss: A scalar critic loss.
    """
        with tf.name_scope('critic_loss'):
            nest_utils.assert_same_structure(actions, self.action_spec)
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)
            nest_utils.assert_same_structure(next_time_steps,
                                             self.time_step_spec)

            next_actions, next_log_pis = self._actions_and_log_probs(
                next_time_steps)
            target_input = (next_time_steps.observation, next_actions)
            target_q_values1, unused_network_state1 = self._target_critic_network_1(
                target_input, next_time_steps.step_type, training=False)
            target_q_values2, unused_network_state2 = self._target_critic_network_2(
                target_input, next_time_steps.step_type, training=False)
            target_q_values = (tf.minimum(target_q_values1, target_q_values2) -
                               tf.exp(self._log_alpha) * next_log_pis)

            td_targets = tf.stop_gradient(
                reward_scale_factor * next_time_steps.reward +
                gamma * next_time_steps.discount * target_q_values)

            pred_input = (time_steps.observation, actions)
            pred_td_targets1, _ = self._critic_network_1(pred_input,
                                                         time_steps.step_type,
                                                         training=training)
            pred_td_targets2, _ = self._critic_network_2(pred_input,
                                                         time_steps.step_type,
                                                         training=training)
            critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
            critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
            critic_loss = critic_loss1 + critic_loss2

            if critic_loss.shape.rank > 1:
                # Sum over the time dimension.
                critic_loss = tf.reduce_sum(critic_loss,
                                            axis=range(1,
                                                       critic_loss.shape.rank))

            agg_loss = common.aggregate_losses(
                per_example_loss=critic_loss,
                sample_weight=weights,
                regularization_loss=(self._critic_network_1.losses +
                                     self._critic_network_2.losses))
            critic_loss = agg_loss.total_loss

            self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                              pred_td_targets2)

            return critic_loss
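A minimal sketch (toy loss tensor) of the time-dimension reduction above: when the per-example loss still carries a time axis, it is summed out so that `aggregate_losses` sees one value per batch entry.

import tensorflow as tf

critic_loss = tf.constant([[0.5, 1.5], [2.0, 0.0]])   # [batch=2, time=2]
if critic_loss.shape.rank > 1:
    critic_loss = tf.reduce_sum(critic_loss,
                                axis=list(range(1, critic_loss.shape.rank)))
print(critic_loss.numpy())  # [2. 2.]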
Example #11
    def total_loss(self,
                   experience: traj.Trajectory,
                   returns: types.Tensor,
                   weights: types.Tensor,
                   training: bool = False) -> tf_agent.LossInfo:
        # Ensure we see at least one full episode.
        time_steps = ts.TimeStep(experience.step_type,
                                 tf.zeros_like(experience.reward),
                                 tf.zeros_like(experience.discount),
                                 experience.observation)
        is_last = experience.is_last()
        num_episodes = tf.reduce_sum(tf.cast(is_last, tf.float32))
        tf.debugging.assert_greater(
            num_episodes,
            0.0,
            message=
            'No complete episode found. REINFORCE requires full episodes '
            'to compute losses.')

        # Mask out partial episodes at the end of each batch of time_steps.
        # NOTE: We use is_last rather than is_boundary because the last transition
        # is the one with the last valid reward.  In other words, the rewards on
        # boundary transitions are not valid.  Since REINFORCE computes a loss
        # w.r.t. the returns (and does not bootstrap), keeping the boundary
        # transitions is irrelevant.
        valid_mask = tf.cast(experience.is_last(), dtype=tf.float32)
        valid_mask = tf.math.cumsum(valid_mask, axis=1, reverse=True)
        valid_mask = tf.cast(valid_mask > 0, dtype=tf.float32)
        if weights is not None:
            weights *= valid_mask
        else:
            weights = valid_mask

        advantages = returns
        value_preds = None

        if self._baseline:
            value_preds, _ = self._value_network(time_steps.observation,
                                                 time_steps.step_type,
                                                 training=True)
            if self._debug_summaries:
                tf.compat.v2.summary.histogram(name='value_preds',
                                               data=value_preds,
                                               step=self.train_step_counter)

        advantages = self._advantage_fn(returns, value_preds)
        if self._debug_summaries:
            tf.compat.v2.summary.histogram(name='advantages',
                                           data=advantages,
                                           step=self.train_step_counter)

        # TODO(b/126592060): replace with tensor normalizer.
        if self._normalize_returns:
            advantages = _standard_normalize(advantages, axes=(0, 1))
            if self._debug_summaries:
                tf.compat.v2.summary.histogram(
                    name='normalized_%s' %
                    ('advantages' if self._baseline else 'returns'),
                    data=advantages,
                    step=self.train_step_counter)

        nest_utils.assert_same_structure(time_steps, self.time_step_spec)
        policy_state = _get_initial_policy_state(self.collect_policy,
                                                 time_steps)
        actions_distribution = self.collect_policy.distribution(
            time_steps, policy_state=policy_state).action

        policy_gradient_loss = self.policy_gradient_loss(
            actions_distribution,
            experience.action,
            experience.is_boundary(),
            advantages,
            num_episodes,
            weights,
        )

        entropy_regularization_loss = self.entropy_regularization_loss(
            actions_distribution, weights)

        network_regularization_loss = tf.nn.scale_regularization_loss(
            self._actor_network.losses)

        total_loss = (policy_gradient_loss + network_regularization_loss +
                      entropy_regularization_loss)

        losses_dict = {
            'policy_gradient_loss': policy_gradient_loss,
            'policy_network_regularization_loss': network_regularization_loss,
            'entropy_regularization_loss': entropy_regularization_loss,
            'value_estimation_loss': 0.0,
            'value_network_regularization_loss': 0.0,
        }

        value_estimation_loss = None
        if self._baseline:
            value_estimation_loss = self.value_estimation_loss(
                value_preds, returns, num_episodes, weights)
            value_network_regularization_loss = tf.nn.scale_regularization_loss(
                self._value_network.losses)
            total_loss += value_estimation_loss + value_network_regularization_loss
            losses_dict['value_estimation_loss'] = value_estimation_loss
            losses_dict['value_network_regularization_loss'] = (
                value_network_regularization_loss)

        loss_info_extra = ReinforceAgentLossInfo(**losses_dict)

        losses_dict[
            'total_loss'] = total_loss  # Total loss not in loss_info_extra.

        common.summarize_scalar_dict(losses_dict,
                                     self.train_step_counter,
                                     name_scope='Losses/')

        return tf_agent.LossInfo(total_loss, loss_info_extra)
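A minimal sketch (a toy is_last batch) of the masking trick above: a reverse cumulative sum over `is_last` marks every step up to and including the last complete episode as valid and zeros out a trailing partial episode.

import tensorflow as tf

is_last = tf.constant([[0., 0., 1., 0., 0.]])            # [batch=1, time=5]
valid_mask = tf.math.cumsum(is_last, axis=1, reverse=True)
valid_mask = tf.cast(valid_mask > 0, dtype=tf.float32)
print(valid_mask.numpy())                                 # [[1. 1. 1. 0. 0.]]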
Example #12
    def critic_loss(
        self,
        time_steps,
        actions,
        next_time_steps,
        td_errors_loss_fn,
        gamma=1.0,
        weights=None,
        training=False,
        w_clipping=None,
        self_normalized=False,
        lambda_fix=False,
    ):
        """Computes the critic loss for C-learning training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.
      w_clipping: Maximum value used for clipping the weights. Use -1 to do no
        clipping; use None to use the recommended value of 1 / (1 - gamma).
      self_normalized: Whether to normalize the weights so that their average
        is 1. Empirically this usually hurts performance.
      lambda_fix: Whether to include the adjustment when using future positives.
        Empirically this has little effect.

    Returns:
      critic_loss: A scalar critic loss.
    """
        del weights
        if w_clipping is None:
            w_clipping = 1 / (1 - gamma)
        rfp = gin.query_parameter('goal_fn.relabel_future_prob')
        rnp = gin.query_parameter('goal_fn.relabel_next_prob')
        assert rfp + rnp == 0.5
        with tf.name_scope('critic_loss'):
            nest_utils.assert_same_structure(actions, self.action_spec)
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)
            nest_utils.assert_same_structure(next_time_steps,
                                             self.time_step_spec)

            next_actions, _ = self._actions_and_log_probs(next_time_steps)
            target_input = (next_time_steps.observation, next_actions)
            target_q_values1, unused_network_state1 = self._target_critic_network_1(
                target_input, next_time_steps.step_type, training=False)
            target_q_values2, unused_network_state2 = self._target_critic_network_2(
                target_input, next_time_steps.step_type, training=False)
            target_q_values = tf.minimum(target_q_values1, target_q_values2)

            w = tf.stop_gradient(target_q_values / (1 - target_q_values))
            if w_clipping >= 0:
                w = tf.clip_by_value(w, 0, w_clipping)
            tf.debugging.assert_all_finite(w,
                                           'Not all elements of w are finite')
            if self_normalized:
                w = w / tf.reduce_mean(w)

            batch_size = nest_utils.get_outer_shape(time_steps,
                                                    self._time_step_spec)[0]
            half_batch = batch_size // 2
            float_batch_size = tf.cast(batch_size, float)
            num_next = tf.cast(tf.round(float_batch_size * rnp), tf.int32)
            num_future = tf.cast(tf.round(float_batch_size * rfp), tf.int32)
            if lambda_fix:
                lambda_coef = 2 * rnp
                weights = tf.concat([
                    tf.fill((num_next, ), (1 - gamma)),
                    tf.fill((num_future, ), 1.0),
                    (1 + lambda_coef * gamma * w)[half_batch:]
                ],
                                    axis=0)
            else:
                weights = tf.concat([
                    tf.fill((num_next, ), (1 - gamma)),
                    tf.fill((num_future, ), 1.0), (1 + gamma * w)[half_batch:]
                ],
                                    axis=0)

            # Note that we assume that episodes never terminate. If they do, then
            # we need to include next_time_steps.discount in the (negative) TD target.
            # We exclude the termination here so that we can use termination to
            # indicate task success during evaluation. In the evaluation setting,
            # task success depends on the task, but we don't want the termination
            # here to depend on the task. Hence, we ignore it.
            if lambda_fix:
                lambda_coef = 2 * rnp
                y = lambda_coef * gamma * w / (1 + lambda_coef * gamma * w)
            else:
                y = gamma * w / (1 + gamma * w)
            td_targets = tf.stop_gradient(next_time_steps.reward +
                                          (1 - next_time_steps.reward) * y)
            if rfp > 0:
                td_targets = tf.concat(
                    [tf.ones(half_batch), td_targets[half_batch:]], axis=0)

            observation = time_steps.observation
            pred_input = (observation, actions)
            pred_td_targets1, _ = self._critic_network_1(pred_input,
                                                         time_steps.step_type,
                                                         training=training)
            pred_td_targets2, _ = self._critic_network_2(pred_input,
                                                         time_steps.step_type,
                                                         training=training)

            critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
            critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
            critic_loss = critic_loss1 + critic_loss2

            if critic_loss.shape.rank > 1:
                # Sum over the time dimension.
                critic_loss = tf.reduce_sum(critic_loss,
                                            axis=range(1,
                                                       critic_loss.shape.rank))

            agg_loss = common.aggregate_losses(
                per_example_loss=critic_loss,
                sample_weight=weights,
                regularization_loss=(self._critic_network_1.losses +
                                     self._critic_network_2.losses))
            critic_loss = agg_loss.total_loss
            self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                              pred_td_targets2, weights)

            return critic_loss
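A minimal sketch (toy sizes and a made-up w) of the weight construction above without the lambda fix: (1 - gamma) for relabeled-next entries, 1.0 for relabeled-future entries, and 1 + gamma * w for the second half of the batch.

import tensorflow as tf

gamma = 0.9
half_batch = 2
num_next, num_future = 1, 1
w = tf.constant([0.5, 0.5, 2.0, 3.0])
weights = tf.concat([
    tf.fill((num_next,), (1 - gamma)),
    tf.fill((num_future,), 1.0),
    (1 + gamma * w)[half_batch:],
], axis=0)
print(weights.numpy())  # ~[0.1 1.  2.8 3.7]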
Example #13
    def _critic_loss_with_optional_entropy_term(
            self,
            time_steps: ts.TimeStep,
            actions: types.Tensor,
            next_time_steps: ts.TimeStep,
            td_errors_loss_fn: types.LossFn,
            gamma: types.Float = 1.0,
            reward_scale_factor: types.Float = 1.0,
            weights: Optional[types.Tensor] = None,
            training: bool = False) -> types.Tensor:
        r"""Computes the critic loss for CQL-SAC training.

    The original SAC critic loss is:
    ```
    (q(s, a) - (r(s, a) + \gamma q(s', a') - \gamma \alpha \log \pi(a'|s')))^2
    ```

    The CQL-SAC critic loss makes the entropy term optional.
    CQL may value unseen actions higher, since it only lower-bounds the value
    of seen actions. This can make the policy entropy term in the target
    redundant, as it would further amplify the effect of unseen actions.

    If self._include_critic_entropy_term is False, this loss equation becomes:
    ```
    (q(s, a) - (r(s, a) + \gamma q(s', a')))^2
    ```

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      td_errors_loss_fn: A function(td_targets, predictions) to compute
        elementwise (per-batch-entry) loss.
      gamma: Discount for future rewards.
      reward_scale_factor: Multiplicative factor to scale rewards.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.
      training: Whether this loss is being used for training.

    Returns:
      critic_loss: A scalar critic loss.
    """
        with tf.name_scope('critic_loss'):
            nest_utils.assert_same_structure(actions, self.action_spec)
            nest_utils.assert_same_structure(time_steps, self.time_step_spec)
            nest_utils.assert_same_structure(next_time_steps,
                                             self.time_step_spec)

            # We do not update actor or target networks in critic loss.
            next_actions, next_log_pis = self._actions_and_log_probs(
                next_time_steps, training=False)
            target_input = (next_time_steps.observation, next_actions)
            target_q_values1, unused_network_state1 = self._target_critic_network_1(
                target_input, next_time_steps.step_type, training=False)
            target_q_values2, unused_network_state2 = self._target_critic_network_2(
                target_input, next_time_steps.step_type, training=False)
            target_q_values = tf.minimum(target_q_values1, target_q_values2)

            if self._include_critic_entropy_term:
                target_q_values -= (tf.exp(self._log_alpha) * next_log_pis)

            reward = next_time_steps.reward
            if self._reward_noise_variance > 0:
                reward_noise = tf.random.normal(
                    tf.shape(reward),
                    0.0,
                    self._reward_noise_variance,
                    seed=self._reward_seed_stream())
                reward += reward_noise

            td_targets = tf.stop_gradient(reward_scale_factor * reward +
                                          gamma * next_time_steps.discount *
                                          target_q_values)

            pred_input = (time_steps.observation, actions)
            pred_td_targets1, _ = self._critic_network_1(pred_input,
                                                         time_steps.step_type,
                                                         training=training)
            pred_td_targets2, _ = self._critic_network_2(pred_input,
                                                         time_steps.step_type,
                                                         training=training)
            critic_loss1 = td_errors_loss_fn(td_targets, pred_td_targets1)
            critic_loss2 = td_errors_loss_fn(td_targets, pred_td_targets2)
            critic_loss = critic_loss1 + critic_loss2

            if critic_loss.shape.rank > 1:
                # Sum over the time dimension.
                critic_loss = tf.reduce_sum(
                    critic_loss, axis=range(1, critic_loss.shape.rank))

            agg_loss = common.aggregate_losses(
                per_example_loss=critic_loss,
                sample_weight=weights,
                regularization_loss=(self._critic_network_1.losses +
                                     self._critic_network_2.losses))
            critic_loss = agg_loss.total_loss

            self._critic_loss_debug_summaries(td_targets, pred_td_targets1,
                                              pred_td_targets2)

            return critic_loss
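
Side note (not part of the scraped example): the no-entropy TD target used above can be checked on placeholder tensors. All values below are made up for illustration; only the formula `reward_scale_factor * r + gamma * discount * min(Q1, Q2)` mirrors the code.

```python
import tensorflow as tf

# Made-up target Q-values from two critics and a toy batch of transitions.
target_q1 = tf.constant([1.0, 2.0])
target_q2 = tf.constant([1.5, 1.0])
reward = tf.constant([0.1, 0.0])
discount = tf.constant([1.0, 0.0])  # 0.0 marks an episode boundary.
gamma = 0.99
reward_scale_factor = 1.0

# Clipped double-Q: take the pessimistic minimum over the two critics.
target_q = tf.minimum(target_q1, target_q2)

# Without the entropy term: r * scale + gamma * discount * min(Q1, Q2).
td_targets = tf.stop_gradient(
    reward_scale_factor * reward + gamma * discount * target_q)
print(td_targets.numpy())  # [1.09, 0.0]
```
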
Example no. 14
    def __init__(self,
                 nested_layers: types.NestedLayer,
                 input_spec: typing.Optional[types.NestedTensorSpec] = None,
                 name: typing.Optional[typing.Text] = None):
        """Create a Sequential Network.

    Args:
      nested_layers: A nest of layers and/or networks.  These will be used
        to process the inputs (input nest structure will have to match this
        structure).  Any layers that are subclasses of
        `tf.keras.layers.{RNN,LSTM,GRU,...}` are wrapped in
        `tf_agents.keras_layers.RNNWrapper`.
      input_spec: (Optional.)  A nest of `tf.TypeSpec` representing the
        input observations.  The structure of `input_spec` must match
        that of `nested_layers`.
      name: (Optional.) Network name.

    Raises:
      TypeError: If any of the layers are not instances of keras `Layer`.
      ValueError: If `input_spec` is provided but its nest structure does
        not match that of `nested_layers`.
      RuntimeError: If not `tf.executing_eagerly()`; eager mode is required to
        be able to create deep copies of the layers in `nested_layers`.
    """
        if not tf.executing_eagerly():
            raise RuntimeError(
                'Not executing eagerly - cannot make deep copies of `nested_layers`.'
            )

        flat_nested_layers = tf.nest.flatten(nested_layers)
        for layer in flat_nested_layers:
            if not isinstance(layer, tf.keras.layers.Layer):
                raise TypeError(
                    'Expected all layers to be instances of keras Layer, but saw'
                    ': \'{}\''.format(layer))

        if input_spec is not None:
            nest_utils.assert_same_structure(
                nested_layers,
                input_spec,
                message=
                ('`nested_layers` and `input_spec` do not have matching structures'
                 ))
            flat_input_spec = tf.nest.flatten(input_spec)
        else:
            flat_input_spec = [None] * len(flat_nested_layers)

        # Wrap in Sequential if necessary.
        flat_nested_layers = [
            sequential.Sequential([m], s)
            if not isinstance(m, network.Network) else m
            for (s, m) in zip(flat_input_spec, flat_nested_layers)
        ]

        flat_nested_layers_state_specs = [
            m.state_spec for m in flat_nested_layers
        ]
        nested_layers = tf.nest.pack_sequence_as(nested_layers,
                                                 flat_nested_layers)
        # We use flattened layers and states here instead of tf.nest.map_structure
        # for several reasons.  One is that we perform several operations against
        # the layers and we want to avoid calling into tf.nest.map* multiple times.
        # But the main reason is that network states have a different *structure*
        # than the layers; e.g., `nested_layers` may just be tf.keras.layers.LSTM,
        # but the states would then have structure `[.,.]`.  Passing these in
        # as args to tf.nest.map_structure causes it to fail.  Instead we would
        # have to use nest.map_structure_up_to -- but that function is not part
        # of the public TF API.  However, if we do everything in flatland and then
        # use pack_sequence_as, we bypass the more rigid structure tests.
        state_spec = tf.nest.pack_sequence_as(nested_layers,
                                              flat_nested_layers_state_specs)

        super(NestMap, self).__init__(input_tensor_spec=input_spec,
                                      state_spec=state_spec,
                                      name=name)
        self._nested_layers = nested_layers
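
A possible usage of this constructor, assuming the class is `NestMap` from `tf_agents.networks.nest_map` (as the `super` call suggests): a dict-structured observation is mapped leaf-by-leaf through a matching dict of Keras layers. The shapes and layer sizes below are arbitrary.

```python
import tensorflow as tf
from tf_agents.networks import nest_map

# Each leaf of the observation dict is processed by the layer at the same key.
net = nest_map.NestMap({
    'image': tf.keras.layers.Dense(8),
    'task': tf.keras.layers.Dense(8),
})

observation = {
    'image': tf.ones((4, 16)),  # batch of 4 flattened "images"
    'task': tf.ones((4, 3)),    # batch of 4 task vectors
}
outputs, _ = net(observation)
# `outputs` mirrors the input structure: {'image': (4, 8), 'task': (4, 8)}.
```
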
Example no. 15
  def actor_loss(self,
                 time_steps,
                 actions,
                 next_time_steps,
                 weights=None):
    """Computes the actor_loss for SAC training.

    Args:
      time_steps: A batch of timesteps.
      actions: A batch of actions.
      next_time_steps: A batch of next timesteps.
      weights: Optional scalar or elementwise (per-batch-entry) importance
        weights.

    Returns:
      actor_loss: A scalar actor loss.
    """
    prev_time_steps, prev_actions, time_steps = time_steps, actions, next_time_steps  # pylint: disable=line-too-long
    with tf.name_scope('actor_loss'):
      nest_utils.assert_same_structure(time_steps, self.time_step_spec)

      actions, log_pi = self._actions_and_log_probs(time_steps)
      target_input = (time_steps.observation, actions)
      target_q_values1, _ = self._critic_network_1(
          target_input, step_type=time_steps.step_type, training=False)
      target_q_values2, _ = self._critic_network_2(
          target_input, step_type=time_steps.step_type, training=False)
      target_q_values = tf.minimum(target_q_values1, target_q_values2)
      actor_loss = tf.exp(self._log_alpha) * log_pi - target_q_values

      ### Flatten time dimension. We'll add it back when adding the loss.
      num_outer_dims = nest_utils.get_outer_rank(time_steps,
                                                 self.time_step_spec)
      has_time_dim = (num_outer_dims == 2)
      if has_time_dim:
        batch_squash = utils.BatchSquash(2)  # Squash B, and T dims.
        obs = batch_squash.flatten(time_steps.observation)
        prev_obs = batch_squash.flatten(prev_time_steps.observation)
        prev_actions = batch_squash.flatten(prev_actions)
      else:
        obs = time_steps.observation
        prev_obs = prev_time_steps.observation
      z = self._actor_network._z_encoder(obs, training=True)  # pylint: disable=protected-access
      prior = self._actor_network._predictor((prev_obs, prev_actions),  # pylint: disable=protected-access
                                             training=True)

      # kl is a vector of length batch_size, which has already been summed over
      # the latent dimension z.
      kl = tfp.distributions.kl_divergence(z, prior)
      if has_time_dim:
        kl = batch_squash.unflatten(kl)

      kl_coef = tf.stop_gradient(
          tf.exp(self._actor_network._log_kl_coefficient))  # pylint: disable=protected-access
      # The actor loss trains both the predictor and the encoder.
      actor_loss += kl_coef * kl

      if actor_loss.shape.rank > 1:
        # Sum over the time dimension.
        actor_loss = tf.reduce_sum(
            actor_loss, axis=range(1, actor_loss.shape.rank))
      reg_loss = self._actor_network.losses if self._actor_network else None
      agg_loss = common.aggregate_losses(
          per_example_loss=actor_loss,
          sample_weight=weights,
          regularization_loss=reg_loss)
      actor_loss = agg_loss.total_loss
      self._actor_loss_debug_summaries(actor_loss, actions, log_pi,
                                       target_q_values, time_steps)
      tf.compat.v2.summary.scalar(
          name='encoder_kl',
          data=tf.reduce_mean(kl),
          step=self.train_step_counter)

      return actor_loss
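
Side note (not part of the scraped example): the KL penalty above can be illustrated in isolation. The sketch below builds two placeholder diagonal-Gaussian distributions standing in for the encoder output `z` and the learned `prior`, and applies `tfp.distributions.kl_divergence` in the same way; the fixed `kl_coef` is a stand-in for `tf.exp(self._actor_network._log_kl_coefficient)`.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Placeholder encoder and prior distributions over a 3-dim latent, batch of 4.
z = tfd.MultivariateNormalDiag(
    loc=tf.zeros((4, 3)), scale_diag=tf.ones((4, 3)))
prior = tfd.MultivariateNormalDiag(
    loc=0.1 * tf.ones((4, 3)), scale_diag=tf.ones((4, 3)))

# One KL value per batch entry, already summed over the latent dimension.
kl = tfd.kl_divergence(z, prior)

kl_coef = 0.1  # stand-in for the learned, stop-gradiented coefficient
kl_penalty = kl_coef * kl  # added to the per-example actor loss above
```
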
Example no. 16
    def __call__(self, inputs, *args, **kwargs):
        """A wrapper around `Network.call`.

    A typical `call` method in a class subclassing `Network` will have a
    signature that accepts `inputs`, as well as other `*args` and `**kwargs`.
    `call` can optionally also accept `step_type` and `network_state`
    (if the network has a non-trivial state, i.e. `state_spec != ()`).  e.g.:

    ```python
    def call(self,
             inputs,
             step_type=None,
             network_state=(),
             training=False):
        ...
        return outputs, new_network_state
    ```

    We will validate the first argument (`inputs`)
    against `self.input_tensor_spec` if one is available.

    If a `network_state` kwarg is given it is also validated against
    `self.state_spec`.  Similarly, the return value of the `call` method is
    expected to be a tuple/list with 2 values:  `(output, new_state)`.
    We validate `new_state` against `self.state_spec`.

    If no `network_state` kwarg is given (or if an empty `network_state = ()`
    is given), it is up to `call` to assume a proper "empty" state and to
    emit an appropriate `output_state`.

    Args:
      inputs: The input to `self.call`, matching `self.input_tensor_spec`.
      *args: Additional arguments to `self.call`.
      **kwargs: Additional keyword arguments to `self.call`.
        These can include `network_state` and `step_type`.  `step_type` is
        required if the network's `call` requires it. `network_state` is
        required if the underlying network's `call` requires it.

    Returns:
      A tuple `(outputs, new_network_state)`.
    """
        if self.input_tensor_spec is not None:
            nest_utils.assert_same_structure(
                inputs,
                self.input_tensor_spec,
                message="inputs and input_tensor_spec structures do not match")
        call_argspec = tf_inspect.getargspec(self.call)

        # Convert *args, **kwargs to a canonical kwarg representation.
        normalized_kwargs = tf_inspect.getcallargs(self.call, inputs, *args,
                                                   **kwargs)
        # TODO(b/156315434): Rename network_state to just state.
        network_state = normalized_kwargs.get("network_state", None)
        normalized_kwargs.pop("self", None)

        if network_state not in (None, ()):
            nest_utils.assert_same_structure(
                network_state,
                self.state_spec,
                message="network_state and state_spec structures do not match")

        if "step_type" not in call_argspec.args and not call_argspec.keywords:
            normalized_kwargs.pop("step_type", None)

        if (network_state in (None, ())
                and "network_state" not in call_argspec.args
                and not call_argspec.keywords):
            normalized_kwargs.pop("network_state", None)

        outputs, new_state = super(Network, self).__call__(**normalized_kwargs)
        nest_utils.assert_same_structure(
            new_state,
            self.state_spec,
            message=
            "network output state and state_spec structures do not match")

        return outputs, new_state
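
To make the contract that this wrapper validates concrete, here is a minimal, hypothetical `Network` subclass. Its `call` follows the `(inputs, step_type, network_state, training)` signature described above and returns `(output, new_state)`; invoking the instance goes through the `__call__` wrapper, which checks `inputs` against `input_tensor_spec` and the returned state against `state_spec`.

```python
import tensorflow as tf
from tf_agents.networks import network

class TinyNet(network.Network):
  """A stateless toy network used only to illustrate the call contract."""

  def __init__(self, input_tensor_spec, name='TinyNet'):
    super(TinyNet, self).__init__(
        input_tensor_spec=input_tensor_spec, state_spec=(), name=name)
    self._dense = tf.keras.layers.Dense(2)

  def call(self, inputs, step_type=None, network_state=(), training=False):
    del step_type  # Unused by this toy network.
    return self._dense(inputs), network_state

spec = tf.TensorSpec(shape=(4,), dtype=tf.float32)
net = TinyNet(spec)
outputs, state = net(tf.ones((1, 4)))  # validated by the wrapper above
```
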