  def testSimple(self):
    converter = data_converter.AsTransition(self._data_context,
                                            squeeze_time_dim=True)
    transition = tensor_spec.sample_spec_nest(
        self._data_context.transition_spec, outer_dims=[2])
    converted = converter(transition)
    (transition, converted) = self.evaluate((transition, converted))
    tf.nest.map_structure(self.assertAllEqual, converted, transition)

  def testTrajectoryNotSingleStepTransition(self):
    converter = data_converter.AsTransition(self._data_context)
    traj = tensor_spec.sample_spec_nest(self._data_context.trajectory_spec,
                                        outer_dims=[2, 3])
    converted = converter(traj)
    expected = trajectory.to_transition(traj)
    (expected, converted) = self.evaluate((expected, converted))
    tf.nest.map_structure(self.assertAllEqual, converted, expected)

  def testTrajectoryInvalidTimeDimensionRaises(self):
    converter = data_converter.AsTransition(self._data_context,
                                            squeeze_time_dim=True)
    traj = tensor_spec.sample_spec_nest(self._data_context.trajectory_spec,
                                        outer_dims=[2, 3])
    with self.assertRaisesRegex(
        ValueError,
        r'has a time axis dim value \'3\' vs the expected \'2\''):
      converter(traj)
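The tests above reference a `self._data_context` fixture that this excerpt omits. Below is a minimal, hypothetical sketch of how such a fixture could be built; the observation and action specs are placeholders, not the ones used in the real `data_converter_test.py`:

```python
import tensorflow as tf
from tf_agents.agents import data_converter
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts


class DataConverterFixtureSketch(tf.test.TestCase):

  def setUp(self):
    super(DataConverterFixtureSketch, self).setUp()
    # Placeholder specs; any nest of specs accepted by DataContext would do.
    observation_spec = tensor_spec.TensorSpec([2], tf.float32)
    action_spec = tensor_spec.BoundedTensorSpec([], tf.int64,
                                                minimum=0, maximum=3)
    self._data_context = data_converter.DataContext(
        time_step_spec=ts.time_step_spec(observation_spec),
        action_spec=action_spec,
        info_spec=())
```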
Example #4
  def testValidateTransitionWithState(self):
    converter = data_converter.AsTransition(self._data_context_with_state,
                                            squeeze_time_dim=False)
    transition = tensor_spec.sample_spec_nest(
        self._data_context_with_state.transition_spec, outer_dims=[1, 2])
    pruned_action_step = transition.action_step._replace(
        state=tf.nest.map_structure(lambda t: t[:, 0, ...],
                                    transition.action_step.state))
    transition = transition._replace(action_step=pruned_action_step)
    converted = converter(transition)
    (transition, converted) = self.evaluate((transition, converted))

  def testFromBatchTimeTrajectory(self):
    converter = data_converter.AsTransition(self._data_context,
                                            squeeze_time_dim=True)
    traj = tensor_spec.sample_spec_nest(self._data_context.trajectory_spec,
                                        outer_dims=[4, 2])  # [B, T=2]
    converted = converter(traj)
    expected = trajectory.to_transition(traj)
    # Remove the now-singleton time dim.
    expected = tf.nest.map_structure(lambda x: tf.squeeze(x, 1), expected)
    (expected, converted) = self.evaluate((expected, converted))
    tf.nest.map_structure(self.assertAllEqual, converted, expected)
Example #6
  def _setup_data_converter(self, q_network, gamma, n_step_update):
    if q_network.state_spec:
      # AsNStepTransition does not support emitting [B, T, ...] tensors,
      # which we need for DQN-RNN.
      self._as_transition = data_converter.AsTransition(
          self.data_context, squeeze_time_dim=False)
    else:
      # This reduces the n-step return and removes the extra time dimension,
      # allowing the rest of the computations to be independent of the
      # n-step parameter.
      self._as_transition = data_converter.AsNStepTransition(
          self.data_context, gamma=gamma, n=n_step_update)
  def testPrunes(self):
    converter = data_converter.AsTransition(self._data_context,
                                            squeeze_time_dim=True)
    my_spec = self._data_context.transition_spec.replace(
        action_step=self._data_context.transition_spec.action_step.replace(
            action={
                'action1': tf.TensorSpec((), tf.float32),
                'action2': tf.TensorSpec([4], tf.int32)
            }))
    transition = tensor_spec.sample_spec_nest(my_spec, outer_dims=[2])
    converted = converter(transition)
    expected = tf.nest.map_structure(lambda x: x, transition)
    del expected.action_step.action['action2']
    (expected, converted) = self.evaluate((expected, converted))
    tf.nest.map_structure(self.assertAllEqual, converted, expected)
Example #8
    def _setup_data_converter(self, q_network, gamma, n_step_update):
        if q_network.state_spec:
            if not self._in_graph_bellman_update:
                self._data_context = data_converter.DataContext(
                    time_step_spec=self._time_step_spec,
                    action_spec=self._action_spec,
                    info_spec=self._collect_policy.info_spec,
                    policy_state_spec=self._q_network.state_spec,
                    use_half_transition=True)
                self._as_transition = data_converter.AsHalfTransition(
                    self.data_context, squeeze_time_dim=False)
            else:
                self._data_context = data_converter.DataContext(
                    time_step_spec=self._time_step_spec,
                    action_spec=self._action_spec,
                    info_spec=self._collect_policy.info_spec,
                    policy_state_spec=self._q_network.state_spec,
                    use_half_transition=False)
                self._as_transition = data_converter.AsTransition(
                    self.data_context,
                    squeeze_time_dim=False,
                    prepend_t0_to_next_time_step=True)
        else:
            if not self._in_graph_bellman_update:
                self._data_context = data_converter.DataContext(
                    time_step_spec=self._time_step_spec,
                    action_spec=self._action_spec,
                    info_spec=self._collect_policy.info_spec,
                    policy_state_spec=self._q_network.state_spec,
                    use_half_transition=True)

                self._as_transition = data_converter.AsHalfTransition(
                    self.data_context, squeeze_time_dim=True)
            else:
                # This reduces the n-step return and removes the extra time dimension,
                # allowing the rest of the computations to be independent of the
                # n-step parameter.
                self._as_transition = data_converter.AsNStepTransition(
                    self.data_context, gamma=gamma, n=n_step_update)
Example #9
    def __init__(
            self,
            time_step_spec: ts.TimeStep,
            action_spec: types.NestedTensorSpec,
            q_network: network.Network,
            optimizer: types.Optimizer,
            observation_and_action_constraint_splitter: Optional[
                types.Splitter] = None,
            epsilon_greedy: types.Float = 0.1,
            n_step_update: int = 1,
            boltzmann_temperature: Optional[types.Int] = None,
            emit_log_probability: bool = False,
            # Params for target network updates
            target_q_network: Optional[network.Network] = None,
            target_update_tau: types.Float = 1.0,
            target_update_period: int = 1,
            # Params for training.
            td_errors_loss_fn: Optional[types.LossFn] = None,
            gamma: types.Float = 1.0,
            reward_scale_factor: types.Float = 1.0,
            gradient_clipping: Optional[types.Float] = None,
            # Params for debugging
            debug_summaries: bool = False,
            summarize_grads_and_vars: bool = False,
            train_step_counter: Optional[tf.Variable] = None,
            name: Optional[Text] = None,
            entropy_tau: types.Float = 0.9,
            alpha: types.Float = 0.3):

        tf.Module.__init__(self, name=name)

        self._check_action_spec(action_spec)

        if epsilon_greedy is not None and boltzmann_temperature is not None:
            raise ValueError(
                'Configured both epsilon_greedy value {} and temperature {}, '
                'however only one of them can be used for exploration.'.format(
                    epsilon_greedy, boltzmann_temperature))

        self._observation_and_action_constraint_splitter = (
            observation_and_action_constraint_splitter)
        self._q_network = q_network
        net_observation_spec = time_step_spec.observation
        if observation_and_action_constraint_splitter:
            net_observation_spec, _ = observation_and_action_constraint_splitter(
                net_observation_spec)
        q_network.create_variables(net_observation_spec)
        if target_q_network:
            target_q_network.create_variables(net_observation_spec)
        self._target_q_network = common.maybe_copy_target_network_with_checks(
            self._q_network,
            target_q_network,
            input_spec=net_observation_spec,
            name='TargetQNetwork')

        self._check_network_output(self._q_network, 'q_network')
        self._check_network_output(self._target_q_network, 'target_q_network')

        self._epsilon_greedy = epsilon_greedy
        self._n_step_update = n_step_update
        self._boltzmann_temperature = boltzmann_temperature
        self._optimizer = optimizer
        self._td_errors_loss_fn = (td_errors_loss_fn
                                   or common.element_wise_huber_loss)
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._gradient_clipping = gradient_clipping
        self._update_target = self._get_target_updater(target_update_tau,
                                                       target_update_period)
        self.entropy_tau = entropy_tau
        self.alpha = alpha

        policy, collect_policy = self._setup_policy(time_step_spec,
                                                    action_spec,
                                                    boltzmann_temperature,
                                                    emit_log_probability)

        if q_network.state_spec and n_step_update != 1:
            raise NotImplementedError(
                'DqnAgent does not currently support n-step updates with stateful '
                'networks (i.e., RNNs), but n_step_update = {}'.format(
                    n_step_update))

        train_sequence_length = (n_step_update +
                                 1 if not q_network.state_spec else None)

        super(dqn_agent.DqnAgent, self).__init__(
            time_step_spec,
            action_spec,
            policy,
            collect_policy,
            train_sequence_length=train_sequence_length,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=train_step_counter,
            validate_args=False,
        )

        if q_network.state_spec:
            # AsNStepTransition does not support emitting [B, T, ...] tensors,
            # which we need for DQN-RNN.
            self._as_transition = data_converter.AsTransition(
                self.data_context, squeeze_time_dim=False)
        else:
            # This reduces the n-step return and removes the extra time dimension,
            # allowing the rest of the computations to be independent of the
            # n-step parameter.
            self._as_transition = data_converter.AsNStepTransition(
                self.data_context, gamma=gamma, n=n_step_update)
Example #10
File: dqn_agent.py  Project: wuzh07/agents
  def __init__(
      self,
      time_step_spec: ts.TimeStep,
      action_spec: types.NestedTensorSpec,
      q_network: network.Network,
      optimizer: types.Optimizer,
      observation_and_action_constraint_splitter: Optional[
          types.Splitter] = None,
      epsilon_greedy: types.Float = 0.1,
      n_step_update: int = 1,
      boltzmann_temperature: Optional[types.Int] = None,
      emit_log_probability: bool = False,
      # Params for target network updates
      target_q_network: Optional[network.Network] = None,
      target_update_tau: types.Float = 1.0,
      target_update_period: int = 1,
      # Params for training.
      td_errors_loss_fn: Optional[types.LossFn] = None,
      gamma: types.Float = 1.0,
      reward_scale_factor: types.Float = 1.0,
      gradient_clipping: Optional[types.Float] = None,
      # Params for debugging
      debug_summaries: bool = False,
      summarize_grads_and_vars: bool = False,
      train_step_counter: Optional[tf.Variable] = None,
      name: Optional[Text] = None):
    """Creates a DQN Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      q_network: A `tf_agents.network.Network` to be used by the agent. The
        network will be called with `call(observation, step_type)` and should
        emit logits over the action space.
      optimizer: The optimizer to use for training.
      observation_and_action_constraint_splitter: A function used to process
        observations with action constraints. These constraints can indicate,
        for example, a mask of valid/invalid actions for a given state of the
        environment.
        The function takes in a full observation and returns a tuple consisting
        of 1) the part of the observation intended as input to the network and
        2) the constraint. An example
        `observation_and_action_constraint_splitter` could be as simple as:
        ```
        def observation_and_action_constraint_splitter(observation):
          return observation['network_input'], observation['constraint']
        ```
        *Note*: when using `observation_and_action_constraint_splitter`, make
        sure the provided `q_network` is compatible with the network-specific
        half of the output of the `observation_and_action_constraint_splitter`.
        In particular, `observation_and_action_constraint_splitter` will be
        called on the observation before passing to the network.
        If `observation_and_action_constraint_splitter` is None, action
        constraints are not applied.
      epsilon_greedy: probability of choosing a random action in the default
        epsilon-greedy collect policy (used only if a wrapper is not provided to
        the collect_policy method).
      n_step_update: The number of steps to consider when computing TD error and
        TD loss. Defaults to single-step updates. Note that this requires the
        user to call train on Trajectory objects with a time dimension of
        `n_step_update + 1`. However, note that we do not yet support
        `n_step_update > 1` in the case of RNNs (i.e., non-empty
        `q_network.state_spec`).
      boltzmann_temperature: Temperature value to use for Boltzmann sampling of
        the actions during data collection. The closer to 0.0, the higher the
        probability of choosing the best action.
      emit_log_probability: Whether policies emit log probabilities or not.
      target_q_network: (Optional.)  A `tf_agents.network.Network`
        to be used as the target network during Q learning.  Every
        `target_update_period` train steps, the weights from
        `q_network` are copied (possibly with smoothing via
        `target_update_tau`) to `target_q_network`.

        If `target_q_network` is not provided, it is created by
        making a copy of `q_network`, which initializes a new
        network with the same structure and its own layers and weights.

        Network copying is performed via the `Network.copy` superclass method,
        and may inadvertently lead the resulting network to share weights
        with the original.  This can happen if, for example, the original
        network accepted a pre-built Keras layer in its `__init__`, or
        accepted a Keras layer that wasn't built, but neglected to create
        a new copy.

        In these cases, it is up to you to provide a target Network having
        weights that are not shared with the original `q_network`.
        If you provide a `target_q_network` that shares any
        weights with `q_network`, a warning will be logged but
        no exception is thrown.

        Note: shallow copies of Keras layers may be built via the code:

        ```python
        new_layer = type(layer).from_config(layer.get_config())
        ```
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn: A function for computing the TD errors loss. If None, a
        default value of element_wise_huber_loss is used. This function takes as
        input the target and the estimated Q values and returns the loss for
        each element of the batch.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      ValueError: If `action_spec` contains more than one action or action
        spec minimum is not equal to 0.
      ValueError: If the q networks do not emit floating point outputs with
        inner shape matching `action_spec`.
      NotImplementedError: If `q_network` has non-empty `state_spec` (i.e., an
        RNN is provided) and `n_step_update > 1`.
    """
    tf.Module.__init__(self, name=name)

    self._check_action_spec(action_spec)

    if epsilon_greedy is not None and boltzmann_temperature is not None:
      raise ValueError(
          'Configured both epsilon_greedy value {} and temperature {}, '
          'however only one of them can be used for exploration.'.format(
              epsilon_greedy, boltzmann_temperature))

    self._observation_and_action_constraint_splitter = (
        observation_and_action_constraint_splitter)
    self._q_network = q_network
    net_observation_spec = time_step_spec.observation
    if observation_and_action_constraint_splitter:
      net_observation_spec, _ = observation_and_action_constraint_splitter(
          net_observation_spec)
    q_network.create_variables(net_observation_spec)
    if target_q_network:
      target_q_network.create_variables(net_observation_spec)
    self._target_q_network = common.maybe_copy_target_network_with_checks(
        self._q_network, target_q_network, input_spec=net_observation_spec,
        name='TargetQNetwork')

    self._check_network_output(self._q_network, 'q_network')
    self._check_network_output(self._target_q_network, 'target_q_network')

    self._epsilon_greedy = epsilon_greedy
    self._n_step_update = n_step_update
    self._boltzmann_temperature = boltzmann_temperature
    self._optimizer = optimizer
    self._td_errors_loss_fn = (
        td_errors_loss_fn or common.element_wise_huber_loss)
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._gradient_clipping = gradient_clipping
    self._update_target = self._get_target_updater(
        target_update_tau, target_update_period)

    policy, collect_policy = self._setup_policy(time_step_spec, action_spec,
                                                boltzmann_temperature,
                                                emit_log_probability)

    if q_network.state_spec and n_step_update != 1:
      raise NotImplementedError(
          'DqnAgent does not currently support n-step updates with stateful '
          'networks (i.e., RNNs), but n_step_update = {}'.format(n_step_update))

    train_sequence_length = (
        n_step_update + 1 if not q_network.state_spec else None)

    super(DqnAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy,
        collect_policy,
        train_sequence_length=train_sequence_length,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter,
        validate_args=False,
    )

    if q_network.state_spec:
      # AsNStepTransition does not support emitting [B, T, ...] tensors,
      # which we need for DQN-RNN.
      self._as_transition = data_converter.AsTransition(
          self.data_context, squeeze_time_dim=False)
    else:
      # This reduces the n-step return and removes the extra time dimension,
      # allowing the rest of the computations to be independent of the
      # n-step parameter.
      self._as_transition = data_converter.AsNStepTransition(
          self.data_context, gamma=gamma, n=n_step_update)
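For reference, a minimal construction sketch in the style of the standard TF-Agents DQN tutorial; `env` is an assumed `TFPyEnvironment` and the layer sizes and learning rate are arbitrary placeholders:

```python
import tensorflow as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.networks import q_network
from tf_agents.utils import common

# `env` is assumed to be a tf_py_environment.TFPyEnvironment wrapping some task.
q_net = q_network.QNetwork(
    env.observation_spec(), env.action_spec(), fc_layer_params=(100,))

agent = dqn_agent.DqnAgent(
    env.time_step_spec(),
    env.action_spec(),
    q_network=q_net,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=tf.Variable(0))
agent.initialize()
```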
Example #11
  def __init__(self,
               time_step_spec,
               action_spec,
               critic_network,
               actor_network,
               actor_optimizer,
               critic_optimizer,
               actor_loss_weight = 1.0,
               critic_loss_weight = 0.5,
               actor_policy_ctor = actor_policy.ActorPolicy,
               critic_network_2 = None,
               target_critic_network = None,
               target_critic_network_2 = None,
               target_update_tau = 1.0,
               target_update_period = 1,
               td_errors_loss_fn = tf.math.squared_difference,
               gamma = 1.0,
               reward_scale_factor = 1.0,
               gradient_clipping = None,
               debug_summaries = False,
               summarize_grads_and_vars = False,
               train_step_counter = None,
               name = None,
               n_step = None,
               use_behavior_policy = False):
    """Creates a RCE Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      critic_network: A function critic_network((observations, actions)) that
        returns the q_values for each observation and action.
      actor_network: A function actor_network(observation, action_spec) that
        returns action distribution.
      actor_optimizer: The optimizer to use for the actor network.
      critic_optimizer: The default optimizer to use for the critic network.
      actor_loss_weight: The weight on actor loss.
      critic_loss_weight: The weight on critic loss.
      actor_policy_ctor: The policy class to use.
      critic_network_2: (Optional.)  A `tf_agents.network.Network` to be used as
        the second critic network during Q learning.  The weights from
        `critic_network` are copied if this is not provided.
      target_critic_network: (Optional.)  A `tf_agents.network.Network` to be
        used as the target critic network during Q learning. Every
        `target_update_period` train steps, the weights from `critic_network`
        are copied (possibly with smoothing via `target_update_tau`) to
        `target_critic_network`.  If `target_critic_network` is not provided, it
        is created by making a copy of `critic_network`, which initializes a new
        network with the same structure and its own layers and weights.
        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or when
        the network is sharing layers with another).  In these cases, it is up
        to you to build a copy having weights that are not shared with the
        original `critic_network`, so that this can be used as a target network.
        If you provide a `target_critic_network` that shares any weights with
        `critic_network`, a warning will be logged but no exception is thrown.
      target_critic_network_2: (Optional.) Similar network as
        target_critic_network but for the critic_network_2. See documentation
        for target_critic_network. Will only be used if 'critic_network_2' is
        also specified.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn:  A function for computing the elementwise TD errors
        loss.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall under
        that name. Defaults to the class name.
      n_step: An integer specifying how many steps to use for n-step returns.
        Empirically, a value of 10 works well for most tasks. Use None to
        disable n-step returns.
      use_behavior_policy: A boolean indicating how to sample actions for the
        success states. When use_behavior_policy=True, we use the historical
        average policy; otherwise, we use the current policy.
    """
    tf.Module.__init__(self, name=name)

    self._check_action_spec(action_spec)

    self._critic_network_1 = critic_network
    self._critic_network_1.create_variables(
        (time_step_spec.observation, action_spec))
    if target_critic_network:
      target_critic_network.create_variables(
          (time_step_spec.observation, action_spec))
      self._target_critic_network_1 = target_critic_network
    else:
      self._target_critic_network_1 = (
          common.maybe_copy_target_network_with_checks(self._critic_network_1,
                                                       None,
                                                       'TargetCriticNetwork1'))

    if critic_network_2 is not None:
      self._critic_network_2 = critic_network_2
    else:
      self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
      # Do not use target_critic_network_2 if critic_network_2 is None.
      target_critic_network_2 = None
    self._critic_network_2.create_variables(
        (time_step_spec.observation, action_spec))

    if target_critic_network_2:
      target_critic_network_2.create_variables(
          (time_step_spec.observation, action_spec))
      self._target_critic_network_2 = target_critic_network_2
    else:
      self._target_critic_network_2 = (
          common.maybe_copy_target_network_with_checks(self._critic_network_2,
                                                       None,
                                                       'TargetCriticNetwork2'))

    if actor_network:
      actor_network.create_variables(time_step_spec.observation)
    self._actor_network = actor_network

    self._use_behavior_policy = use_behavior_policy
    if use_behavior_policy:
      self._behavior_actor_network = actor_network.copy(
          name='BehaviorActorNetwork')
      self._behavior_policy = actor_policy_ctor(
          time_step_spec=time_step_spec,
          action_spec=action_spec,
          actor_network=self._behavior_actor_network,
          training=True)

    policy = actor_policy_ctor(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        training=False)

    self._train_policy = actor_policy_ctor(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        training=True)

    self._target_update_tau = target_update_tau
    self._target_update_period = target_update_period
    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer
    self._actor_loss_weight = actor_loss_weight
    self._critic_loss_weight = critic_loss_weight
    self._td_errors_loss_fn = td_errors_loss_fn
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._gradient_clipping = gradient_clipping
    self._debug_summaries = debug_summaries
    self._summarize_grads_and_vars = summarize_grads_and_vars
    self._update_target = self._get_target_updater(
        tau=self._target_update_tau, period=self._target_update_period)
    self._n_step = n_step

    train_sequence_length = 2 if not critic_network.state_spec else None

    super(RceAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy=policy,
        collect_policy=policy,
        train_sequence_length=train_sequence_length,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter,
        validate_args=False
    )

    self._as_transition = data_converter.AsTransition(
        self.data_context, squeeze_time_dim=(train_sequence_length == 2))
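The `squeeze_time_dim=(train_sequence_length == 2)` pattern recurs in the TD3, DDPG, and SAC examples below. The sketch that follows is illustrative only (the helper name is hypothetical, and it assumes the agent exposes the `train_sequence_length` and `data_context` properties used above); it restates what the flag means for output shapes, consistent with the converter tests at the top of this page:

```python
from tf_agents.agents import data_converter


def convert_for_training(agent, experience):
  """Hypothetical helper mirroring the squeeze_time_dim logic used above."""
  # Feed-forward networks (train_sequence_length == 2): `experience` is a
  # [B, T=2] trajectory; after splitting it into (time_step, action_step,
  # next_time_step), the now-singleton time axis is squeezed, so Transition
  # fields come out shaped [B, ...].
  #
  # Recurrent networks (train_sequence_length is None): squeeze_time_dim is
  # False and the [B, T] axes are preserved so the RNN can be unrolled.
  squeeze = (agent.train_sequence_length == 2)
  converter = data_converter.AsTransition(agent.data_context,
                                          squeeze_time_dim=squeeze)
  return converter(experience)
```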
Example #12
  def __init__(self,
               time_step_spec: ts.TimeStep,
               action_spec: types.NestedTensor,
               actor_network: network.Network,
               critic_network: network.Network,
               actor_optimizer: types.Optimizer,
               critic_optimizer: types.Optimizer,
               exploration_noise_std: types.Float = 0.1,
               critic_network_2: Optional[network.Network] = None,
               target_actor_network: Optional[network.Network] = None,
               target_critic_network: Optional[network.Network] = None,
               target_critic_network_2: Optional[network.Network] = None,
               target_update_tau: types.Float = 1.0,
               target_update_period: types.Int = 1,
               actor_update_period: types.Int = 1,
               td_errors_loss_fn: Optional[types.LossFn] = None,
               gamma: types.Float = 1.0,
               reward_scale_factor: types.Float = 1.0,
               target_policy_noise: types.Float = 0.2,
               target_policy_noise_clip: types.Float = 0.5,
               gradient_clipping: Optional[types.Float] = None,
               debug_summaries: bool = False,
               summarize_grads_and_vars: bool = False,
               train_step_counter: Optional[tf.Variable] = None,
               name: Text = None):
    """Creates a Td3Agent Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      actor_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call(observation, step_type).
      critic_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call(observation, action, step_type).
      actor_optimizer: The default optimizer to use for the actor network.
      critic_optimizer: The default optimizer to use for the critic network.
      exploration_noise_std: Scale factor on exploration policy noise.
      critic_network_2: (Optional.)  A `tf_agents.network.Network` to be used as
        the second critic network during Q learning.  The weights from
        `critic_network` are copied if this is not provided.
      target_actor_network: (Optional.)  A `tf_agents.network.Network` to be
        used as the target actor network during Q learning. Every
        `target_update_period` train steps, the weights from `actor_network` are
        copied (possibly with smoothing via `target_update_tau`) to
        `target_actor_network`.  If `target_actor_network` is not provided, it is
        created by making a copy of `actor_network`, which initializes a new
        network with the same structure and its own layers and weights.
        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or when
        the network is sharing layers with another).  In these cases, it is up
        to you to build a copy having weights that are not shared with the
        original `actor_network`, so that this can be used as a target network.
        If you provide a `target_actor_network` that shares any weights with
        `actor_network`, a warning will be logged but no exception is thrown.
      target_critic_network: (Optional.) Similar network as target_actor_network
        but for the critic_network. See documentation for target_actor_network.
      target_critic_network_2: (Optional.) Similar network as
        target_actor_network but for the critic_network_2. See documentation for
        target_actor_network. Will only be used if 'critic_network_2' is also
        specified.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      actor_update_period: Period for the optimization step on actor network.
      td_errors_loss_fn:  A function for computing the TD errors loss. If None,
        a default value of elementwise huber_loss is used.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      target_policy_noise: Scale factor on target action noise
      target_policy_noise_clip: Value to clip noise.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall
        under that name. Defaults to the class name.
    """
    tf.Module.__init__(self, name=name)
    self._actor_network = actor_network
    actor_network.create_variables()
    if target_actor_network:
      target_actor_network.create_variables()
    self._target_actor_network = common.maybe_copy_target_network_with_checks(
        self._actor_network, target_actor_network, 'TargetActorNetwork')

    self._critic_network_1 = critic_network
    critic_network.create_variables()
    if target_critic_network:
      target_critic_network.create_variables()
    self._target_critic_network_1 = (
        common.maybe_copy_target_network_with_checks(self._critic_network_1,
                                                     target_critic_network,
                                                     'TargetCriticNetwork1'))

    if critic_network_2 is not None:
      self._critic_network_2 = critic_network_2
    else:
      self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
      # Do not use target_critic_network_2 if critic_network_2 is None.
      target_critic_network_2 = None
    self._critic_network_2.create_variables()
    if target_critic_network_2:
      target_critic_network_2.create_variables()
    self._target_critic_network_2 = (
        common.maybe_copy_target_network_with_checks(self._critic_network_2,
                                                     target_critic_network_2,
                                                     'TargetCriticNetwork2'))

    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer

    self._exploration_noise_std = exploration_noise_std
    self._target_update_tau = target_update_tau
    self._target_update_period = target_update_period
    self._actor_update_period = actor_update_period
    self._td_errors_loss_fn = (
        td_errors_loss_fn or common.element_wise_huber_loss)
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._target_policy_noise = target_policy_noise
    self._target_policy_noise_clip = target_policy_noise_clip
    self._gradient_clipping = gradient_clipping

    self._update_target = self._get_target_updater(
        target_update_tau, target_update_period)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec, action_spec=action_spec,
        actor_network=self._actor_network, clip=True)
    collect_policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec, action_spec=action_spec,
        actor_network=self._actor_network, clip=False)
    collect_policy = gaussian_policy.GaussianPolicy(
        collect_policy,
        scale=self._exploration_noise_std,
        clip=True)

    train_sequence_length = 2 if not self._actor_network.state_spec else None
    super(Td3Agent, self).__init__(
        time_step_spec,
        action_spec,
        policy,
        collect_policy,
        train_sequence_length=train_sequence_length,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter,
        validate_args=False
    )

    self._as_transition = data_converter.AsTransition(
        self.data_context, squeeze_time_dim=(train_sequence_length == 2))
Example #13
    def __init__(self,
                 time_step_spec: ts.TimeStep,
                 action_spec: types.NestedTensorSpec,
                 actor_network: network.Network,
                 critic_network: network.Network,
                 actor_optimizer: Optional[types.Optimizer] = None,
                 critic_optimizer: Optional[types.Optimizer] = None,
                 ou_stddev: types.Float = 1.0,
                 ou_damping: types.Float = 1.0,
                 target_actor_network: Optional[network.Network] = None,
                 target_critic_network: Optional[network.Network] = None,
                 target_update_tau: types.Float = 1.0,
                 target_update_period: types.Int = 1,
                 dqda_clipping: Optional[types.Float] = None,
                 td_errors_loss_fn: Optional[types.LossFn] = None,
                 gamma: types.Float = 1.0,
                 reward_scale_factor: types.Float = 1.0,
                 gradient_clipping: Optional[types.Float] = None,
                 debug_summaries: bool = False,
                 summarize_grads_and_vars: bool = False,
                 train_step_counter: Optional[tf.Variable] = None,
                 name: Optional[Text] = None):
        """Creates a DDPG Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      actor_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call(observation, step_type[, policy_state])
        and should return (action, new_state).
      critic_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call((observation, action), step_type[,
        policy_state]) and should return (q_value, new_state).
      actor_optimizer: The optimizer to use for the actor network.
      critic_optimizer: The optimizer to use for the critic network.
      ou_stddev: Standard deviation for the Ornstein-Uhlenbeck (OU) noise added
        in the default collect policy.
      ou_damping: Damping factor for the OU noise added in the default collect
        policy.
      target_actor_network: (Optional.)  A `tf_agents.network.Network` to be
        used as the actor target network during Q learning.  Every
        `target_update_period` train steps, the weights from `actor_network` are
        copied (possibly with smoothing via `target_update_tau`) to
        `target_actor_network`.

        If `target_actor_network` is not provided, it is created by making a
        copy of `actor_network`, which initializes a new network with the same
        structure and its own layers and weights.

        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or
        when the network is sharing layers with another).  In these cases, it is
        up to you to build a copy having weights that are not
        shared with the original `actor_network`, so that this can be used as a
        target network.  If you provide a `target_actor_network` that shares any
        weights with `actor_network`, a warning will be logged but no exception
        is thrown.
      target_critic_network: (Optional.) Similar network as target_actor_network
         but for the critic_network. See documentation for target_actor_network.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      dqda_clipping: when computing the actor loss, clips the gradient dqda
        element-wise between [-dqda_clipping, dqda_clipping]. Does not perform
        clipping if dqda_clipping == 0.
      td_errors_loss_fn:  A function for computing the TD errors loss. If None,
        a default value of elementwise huber_loss is used.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall
        under that name. Defaults to the class name.
    """
        tf.Module.__init__(self, name=name)
        self._actor_network = actor_network
        actor_network.create_variables(time_step_spec.observation)
        if target_actor_network:
            target_actor_network.create_variables(time_step_spec.observation)
        self._target_actor_network = common.maybe_copy_target_network_with_checks(
            self._actor_network,
            target_actor_network,
            'TargetActorNetwork',
            input_spec=time_step_spec.observation)

        self._critic_network = critic_network
        critic_input_spec = (time_step_spec.observation, action_spec)
        critic_network.create_variables(critic_input_spec)
        if target_critic_network:
            target_critic_network.create_variables(critic_input_spec)
        self._target_critic_network = common.maybe_copy_target_network_with_checks(
            self._critic_network,
            target_critic_network,
            'TargetCriticNetwork',
            input_spec=critic_input_spec)

        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer

        self._ou_stddev = ou_stddev
        self._ou_damping = ou_damping
        self._target_update_tau = target_update_tau
        self._target_update_period = target_update_period
        self._dqda_clipping = dqda_clipping
        self._td_errors_loss_fn = (td_errors_loss_fn
                                   or common.element_wise_huber_loss)
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._gradient_clipping = gradient_clipping

        self._update_target = self._get_target_updater(target_update_tau,
                                                       target_update_period)

        policy = actor_policy.ActorPolicy(time_step_spec=time_step_spec,
                                          action_spec=action_spec,
                                          actor_network=self._actor_network,
                                          clip=True)
        collect_policy = actor_policy.ActorPolicy(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            clip=False)
        collect_policy = ou_noise_policy.OUNoisePolicy(
            collect_policy,
            ou_stddev=self._ou_stddev,
            ou_damping=self._ou_damping,
            clip=True)

        super(DdpgAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy,
                             collect_policy,
                             train_sequence_length=2
                             if not self._actor_network.state_spec else None,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter,
                             validate_args=False)

        self._as_transition = data_converter.AsTransition(
            self.data_context,
            squeeze_time_dim=not self._actor_network.state_spec)
Example #14
    def __init__(self,
                 time_step_spec: ts.TimeStep,
                 action_spec: types.NestedTensorSpec,
                 critic_network: network.Network,
                 actor_network: network.Network,
                 actor_optimizer: types.Optimizer,
                 critic_optimizer: types.Optimizer,
                 alpha_optimizer: types.Optimizer,
                 actor_loss_weight: types.Float = 1.0,
                 critic_loss_weight: types.Float = 0.5,
                 alpha_loss_weight: types.Float = 1.0,
                 actor_policy_ctor: Callable[
                     ..., tf_policy.TFPolicy] = actor_policy.ActorPolicy,
                 critic_network_2: Optional[network.Network] = None,
                 target_critic_network: Optional[network.Network] = None,
                 target_critic_network_2: Optional[network.Network] = None,
                 target_update_tau: types.Float = 1.0,
                 target_update_period: types.Int = 1,
                 td_errors_loss_fn: types.LossFn = tf.math.squared_difference,
                 gamma: types.Float = 1.0,
                 reward_scale_factor: types.Float = 1.0,
                 initial_log_alpha: types.Float = 0.0,
                 use_log_alpha_in_alpha_loss: bool = True,
                 target_entropy: Optional[types.Float] = None,
                 gradient_clipping: Optional[types.Float] = None,
                 debug_summaries: bool = False,
                 summarize_grads_and_vars: bool = False,
                 train_step_counter: Optional[tf.Variable] = None,
                 name: Optional[Text] = None):
        """Creates a SAC Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      critic_network: A function critic_network((observations, actions)) that
        returns the q_values for each observation and action.
      actor_network: A function actor_network(observation, action_spec) that
        returns action distribution.
      actor_optimizer: The optimizer to use for the actor network.
      critic_optimizer: The default optimizer to use for the critic network.
      alpha_optimizer: The default optimizer to use for the alpha variable.
      actor_loss_weight: The weight on actor loss.
      critic_loss_weight: The weight on critic loss.
      alpha_loss_weight: The weight on alpha loss.
      actor_policy_ctor: The policy class to use.
      critic_network_2: (Optional.)  A `tf_agents.network.Network` to be used as
        the second critic network during Q learning.  The weights from
        `critic_network` are copied if this is not provided.
      target_critic_network: (Optional.)  A `tf_agents.network.Network` to be
        used as the target critic network during Q learning. Every
        `target_update_period` train steps, the weights from `critic_network`
        are copied (possibly with smoothing via `target_update_tau`) to
        `target_critic_network`.  If `target_critic_network` is not provided, it
        is created by making a copy of `critic_network`, which initializes a new
        network with the same structure and its own layers and weights.
        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or when
        the network is sharing layers with another).  In these cases, it is up
        to you to build a copy having weights that are not shared with the
        original `critic_network`, so that this can be used as a target network.
        If you provide a `target_critic_network` that shares any weights with
        `critic_network`, a warning will be logged but no exception is thrown.
      target_critic_network_2: (Optional.) Similar network as
        target_critic_network but for the critic_network_2. See documentation
        for target_critic_network. Will only be used if 'critic_network_2' is
        also specified.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn:  A function for computing the elementwise TD errors
        loss.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      initial_log_alpha: Initial value for log_alpha.
      use_log_alpha_in_alpha_loss: A boolean, whether using log_alpha or alpha
        in alpha loss. Certain implementations of SAC use log_alpha as log
        values are generally nicer to work with.
      target_entropy: The target average policy entropy, for updating alpha. The
        default value is negative of the total number of actions.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall under
        that name. Defaults to the class name.
    """
        tf.Module.__init__(self, name=name)

        self._check_action_spec(action_spec)

        self._critic_network_1 = critic_network
        self._critic_network_1.create_variables(
            (time_step_spec.observation, action_spec))
        if target_critic_network:
            target_critic_network.create_variables(
                (time_step_spec.observation, action_spec))
        self._target_critic_network_1 = (
            common.maybe_copy_target_network_with_checks(
                self._critic_network_1, target_critic_network,
                'TargetCriticNetwork1'))

        if critic_network_2 is not None:
            self._critic_network_2 = critic_network_2
        else:
            self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
            # Do not use target_critic_network_2 if critic_network_2 is None.
            target_critic_network_2 = None
        self._critic_network_2.create_variables(
            (time_step_spec.observation, action_spec))
        if target_critic_network_2:
            target_critic_network_2.create_variables(
                (time_step_spec.observation, action_spec))
        self._target_critic_network_2 = (
            common.maybe_copy_target_network_with_checks(
                self._critic_network_2, target_critic_network_2,
                'TargetCriticNetwork2'))

        if actor_network:
            actor_network.create_variables(time_step_spec.observation)
        self._actor_network = actor_network

        policy = actor_policy_ctor(time_step_spec=time_step_spec,
                                   action_spec=action_spec,
                                   actor_network=self._actor_network,
                                   training=False)

        self._train_policy = actor_policy_ctor(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            training=True)

        self._log_alpha = common.create_variable(
            'initial_log_alpha',
            initial_value=initial_log_alpha,
            dtype=tf.float32,
            trainable=True)

        if target_entropy is None:
            target_entropy = self._get_default_target_entropy(action_spec)

        self._use_log_alpha_in_alpha_loss = use_log_alpha_in_alpha_loss
        self._target_update_tau = target_update_tau
        self._target_update_period = target_update_period
        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer
        self._alpha_optimizer = alpha_optimizer
        self._actor_loss_weight = actor_loss_weight
        self._critic_loss_weight = critic_loss_weight
        self._alpha_loss_weight = alpha_loss_weight
        self._td_errors_loss_fn = td_errors_loss_fn
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._target_entropy = target_entropy
        self._gradient_clipping = gradient_clipping
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._update_target = self._get_target_updater(
            tau=self._target_update_tau, period=self._target_update_period)

        train_sequence_length = 2 if not critic_network.state_spec else None

        super(SacAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy=policy,
                             collect_policy=policy,
                             train_sequence_length=train_sequence_length,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter,
                             validate_args=False)

        self._as_transition = data_converter.AsTransition(
            self.data_context, squeeze_time_dim=(train_sequence_length == 2))
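A minimal construction sketch for the agent defined above, using the network classes commonly paired with SAC in TF-Agents; `observation_spec`, `action_spec`, and `time_step_spec` are assumed to come from an environment, and all layer sizes and learning rates are placeholders:

```python
import tensorflow as tf
from tf_agents.agents.ddpg import critic_network as critic_network_lib
from tf_agents.networks import actor_distribution_network

critic_net = critic_network_lib.CriticNetwork(
    (observation_spec, action_spec),
    joint_fc_layer_params=(256, 256))
actor_net = actor_distribution_network.ActorDistributionNetwork(
    observation_spec, action_spec, fc_layer_params=(256, 256))

agent = SacAgent(
    time_step_spec,
    action_spec,
    critic_network=critic_net,
    actor_network=actor_net,
    actor_optimizer=tf.keras.optimizers.Adam(3e-4),
    critic_optimizer=tf.keras.optimizers.Adam(3e-4),
    alpha_optimizer=tf.keras.optimizers.Adam(3e-4),
    target_update_tau=0.005,
    gamma=0.99)
agent.initialize()
```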
Example #15
def train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    num_iterations=3000000,
    actor_fc_layers=(),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    initial_collect_steps=10000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=1000000,
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    batch_size=256,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    dual_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=50000,
    policy_checkpoint_interval=50000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None,
    latent_dim=10,
    log_prob_reward_scale=0.0,
    predictor_updates_encoder=False,
    predict_prior=True,
    use_recurrent_actor=False,
    rnn_sequence_length=20,
    clip_max_stddev=10.0,
    clip_min_stddev=0.1,
    clip_mean=30.0,
    predictor_num_layers=2,
    use_identity_encoder=False,
    identity_encoder_single_stddev=False,
    kl_constraint=1.0,
    eval_dropout=(),
    use_residual_predictor=True,
    gym_kwargs=None,
    predict_prior_std=True,
    random_seed=0,
):
    """A simple train and eval for SAC."""
    np.random.seed(random_seed)
    tf.random.set_seed(random_seed)
    if use_recurrent_actor:
        batch_size = batch_size // rnn_sequence_length
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)

    global_step = tf.compat.v1.train.get_or_create_global_step()
    with tf.compat.v2.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):

        _build_env = functools.partial(
            suite_gym.load,
            environment_name=env_name,  # pylint: disable=invalid-name
            gym_env_wrappers=(),
            gym_kwargs=gym_kwargs)

        tf_env = tf_py_environment.TFPyEnvironment(_build_env())
        eval_vec = []  # (name, env, metrics)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes)
        ]
        eval_tf_env = tf_py_environment.TFPyEnvironment(_build_env())
        name = ''
        eval_vec.append((name, eval_tf_env, eval_metrics))

        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()
        if latent_dim == 'obs':
            latent_dim = observation_spec.shape[0]

        def _activation(t):
            t1, t2 = tf.split(t, 2, axis=1)
            low = -np.inf if clip_mean is None else -clip_mean
            high = np.inf if clip_mean is None else clip_mean
            t1 = rpc_utils.squash_to_range(t1, low, high)

            if clip_min_stddev is None:
                low = -np.inf
            else:
                low = tf.math.log(tf.exp(clip_min_stddev) - 1.0)
            if clip_max_stddev is None:
                high = np.inf
            else:
                high = tf.math.log(tf.exp(clip_max_stddev) - 1.0)
            t2 = rpc_utils.squash_to_range(t2, low, high)
            return tf.concat([t1, t2], axis=1)

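        # The second half of the encoder output is passed through a softplus by
        # tfp.layers.IndependentNormal, so _activation clips the pre-activation
        # to [log(exp(clip_min_stddev) - 1), log(exp(clip_max_stddev) - 1)],
        # i.e. the softplus inverse of the desired stddev range.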
        if use_identity_encoder:
            assert latent_dim == observation_spec.shape[0]
            obs_input = tf.keras.layers.Input(observation_spec.shape)
            zeros = 0.0 * obs_input[:, :1]
            stddev_dim = 1 if identity_encoder_single_stddev else latent_dim
            pre_stddev = tf.keras.layers.Dense(stddev_dim,
                                               activation=None)(zeros)
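            # A Dense layer applied to an all-zeros input reduces to its bias,
            # yielding a learnable but state-independent stddev head.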
            ones = zeros + tf.ones((1, latent_dim))
            pre_stddev = pre_stddev * ones  # Multiply to broadcast to latent_dim.
            pre_mean_stddev = tf.concat([obs_input, pre_stddev], axis=1)
            output = tfp.layers.IndependentNormal(latent_dim)(pre_mean_stddev)
            encoder_net = tf.keras.Model(inputs=obs_input, outputs=output)
        else:
            encoder_net = tf.keras.Sequential([
                tf.keras.layers.Dense(256, activation='relu'),
                tf.keras.layers.Dense(256, activation='relu'),
                tf.keras.layers.Dense(
                    tfp.layers.IndependentNormal.params_size(latent_dim),
                    activation=_activation,
                    kernel_initializer='glorot_uniform'),
                tfp.layers.IndependentNormal(latent_dim),
            ])

        # Build the predictor net
        obs_input = tf.keras.layers.Input(observation_spec.shape)
        action_input = tf.keras.layers.Input(action_spec.shape)

        class ConstantIndependentNormal(tfp.layers.IndependentNormal):
            """A keras layer that always returns N(0, 1) distribution."""
            def call(self, inputs):
                loc_scale = tf.concat([
                    tf.zeros((latent_dim, )),
                    tf.fill((latent_dim, ), tf.math.log(tf.exp(1.0) - 1))
                ],
                                      axis=0)
                # Multiply by a [B x 1] tensor to broadcast the batch dimension.
                loc_scale = loc_scale * tf.ones_like(inputs[:, :1])
                return super(ConstantIndependentNormal, self).call(loc_scale)

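        # Filling the scale half with log(exp(1) - 1) (the softplus inverse of 1)
        # makes the layer above emit an N(0, I) prior whose batch shape is
        # broadcast from the incoming tensor.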
        if predict_prior:
            z = encoder_net(obs_input)
            if not predictor_updates_encoder:
                z = tf.stop_gradient(z)
            za = tf.concat([z, action_input], axis=1)
            if use_residual_predictor:
                za_input = tf.keras.layers.Input(za.shape[1])
                loc_scale = tf.keras.Sequential(
                    predictor_num_layers *
                    [tf.keras.layers.Dense(256, activation='relu')] + [  # pylint: disable=line-too-long
                        tf.keras.layers.Dense(tfp.layers.IndependentNormal.
                                              params_size(latent_dim),
                                              activation=_activation,
                                              kernel_initializer='zeros'),
                    ])(za_input)
                if predict_prior_std:
                    combined_loc_scale = tf.concat([
                        loc_scale[:, :latent_dim] + za_input[:, :latent_dim],
                        loc_scale[:, latent_dim:]
                    ],
                                                   axis=1)
                else:
                    # Note that softplus(log(e - 1)) = 1.
                    combined_loc_scale = tf.concat([
                        loc_scale[:, :latent_dim] + za_input[:, :latent_dim],
                        tf.math.log(np.e - 1) *
                        tf.ones_like(loc_scale[:, latent_dim:])
                    ],
                                                   axis=1)
                dist = tfp.layers.IndependentNormal(latent_dim)(
                    combined_loc_scale)
                output = tf.keras.Model(inputs=za_input, outputs=dist)(za)
            else:
                assert predict_prior_std
                output = tf.keras.Sequential(
                    predictor_num_layers *
                    [tf.keras.layers.Dense(256, activation='relu')] +  # pylint: disable=line-too-long
                    [
                        tf.keras.layers.Dense(tfp.layers.IndependentNormal.
                                              params_size(latent_dim),
                                              activation=_activation,
                                              kernel_initializer='zeros'),
                        tfp.layers.IndependentNormal(latent_dim),
                    ])(za)
        else:
            # scale is chosen by inverting the softplus function to equal 1.
            if len(obs_input.shape) > 2:
                input_reshaped = tf.reshape(
                    obs_input,
                    [-1, tf.math.reduce_prod(obs_input.shape[1:])])
                #  Multiply by [B x 1] tensor to broadcast batch dimension.
                za = tf.zeros(latent_dim + action_spec.shape[0], ) * tf.ones_like(input_reshaped[:, :1])  # pylint: disable=line-too-long
            else:
                # Multiply by a [B x 1] tensor to broadcast the batch dimension.
                za = tf.zeros(latent_dim + action_spec.shape[0], ) * tf.ones_like(obs_input[:, :1])  # pylint: disable=line-too-long
            output = tf.keras.Sequential([
                ConstantIndependentNormal(latent_dim),
            ])(za)
        predictor_net = tf.keras.Model(inputs=(obs_input, action_input),
                                       outputs=output)
        if use_recurrent_actor:
            ActorClass = rpc_utils.RecurrentActorNet  # pylint: disable=invalid-name
        else:
            ActorClass = rpc_utils.ActorNet  # pylint: disable=invalid-name
        actor_net = ActorClass(input_tensor_spec=observation_spec,
                               output_tensor_spec=action_spec,
                               encoder=encoder_net,
                               predictor=predictor_net,
                               fc_layers=actor_fc_layers)

        critic_net = rpc_utils.CriticNet(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')
        critic_net_2 = None
        target_critic_net_1 = None
        target_critic_net_2 = None

        tf_agent = rpc_agent.RpAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            critic_network_2=critic_net_2,
            target_critic_network=target_critic_net_1,
            target_critic_network_2=target_critic_net_2,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
        dual_optimizer = tf.compat.v1.train.AdamOptimizer(
            learning_rate=dual_learning_rate)
        tf_agent.initialize()

        # Make the replay buffer.
        replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=tf_agent.collect_data_spec,
            batch_size=tf_env.batch_size,
            max_length=replay_buffer_capacity)
        replay_observer = [replay_buffer.add_batch]

        train_metrics = [
            tf_metrics.NumberOfEpisodes(),
            tf_metrics.EnvironmentSteps(),
            tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes,
                                           batch_size=tf_env.batch_size),
            tf_metrics.AverageEpisodeLengthMetric(
                buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        ]
        kl_metric = rpc_utils.AverageKLMetric(encoder=encoder_net,
                                              predictor=predictor_net,
                                              batch_size=tf_env.batch_size)
        eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
        initial_collect_policy = random_tf_policy.RandomTFPolicy(
            tf_env.time_step_spec(), tf_env.action_spec())
        collect_policy = tf_agent.collect_policy

        checkpoint_items = {
            'ckpt_dir': train_dir,
            'agent': tf_agent,
            'global_step': global_step,
            'metrics': metric_utils.MetricsGroup(train_metrics,
                                                 'train_metrics'),
            'dual_optimizer': dual_optimizer,
        }
        train_checkpointer = common.Checkpointer(**checkpoint_items)

        policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'policy'),
                                                  policy=eval_policy,
                                                  global_step=global_step)
        rb_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
            train_dir, 'replay_buffer'),
                                              max_to_keep=1,
                                              replay_buffer=replay_buffer)

        train_checkpointer.initialize_or_restore()
        rb_checkpointer.initialize_or_restore()

        initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            initial_collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=initial_collect_steps,
            transition_observers=[kl_metric])

        collect_driver = dynamic_step_driver.DynamicStepDriver(
            tf_env,
            collect_policy,
            observers=replay_observer + train_metrics,
            num_steps=collect_steps_per_iteration,
            transition_observers=[kl_metric])

        if use_tf_functions:
            initial_collect_driver.run = common.function(
                initial_collect_driver.run)
            collect_driver.run = common.function(collect_driver.run)
            tf_agent.train = common.function(tf_agent.train)

        if replay_buffer.num_frames() == 0:
            # Collect initial replay data.
            logging.info(
                'Initializing replay buffer by collecting experience for %d steps '
                'with a random policy.', initial_collect_steps)
            initial_collect_driver.run()

        for name, eval_tf_env, eval_metrics in eval_vec:
            results = metric_utils.eager_compute(
                eval_metrics,
                eval_tf_env,
                eval_policy,
                num_episodes=num_eval_episodes,
                train_step=global_step,
                summary_writer=eval_summary_writer,
                summary_prefix='Metrics-%s' % name,
            )
            if eval_metrics_callback is not None:
                eval_metrics_callback(results, global_step.numpy())
            metric_utils.log_metrics(eval_metrics, prefix=name)

        time_step = None
        policy_state = collect_policy.get_initial_state(tf_env.batch_size)

        timed_at_step = global_step.numpy()
        time_acc = 0
        train_time_acc = 0
        env_time_acc = 0

        if use_recurrent_actor:  # default from sac/train_eval_rnn.py
            num_steps = rnn_sequence_length + 1

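            # Drop sampled windows in which any of the first num_steps - 1 frames
            # is an episode boundary, so a training sequence never spans a reset.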
            def _filter_invalid_transition(trajectories, unused_arg1):
                return tf.reduce_all(~trajectories.is_boundary()[:-1])

            tf_agent._as_transition = data_converter.AsTransition(  # pylint: disable=protected-access
                tf_agent.data_context,
                squeeze_time_dim=False)
        else:
            num_steps = 2

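            # Keep only length-2 windows whose first frame is not an episode
            # boundary, i.e. valid single-step transitions.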
            def _filter_invalid_transition(trajectories, unused_arg1):
                return ~trajectories.is_boundary()[0]

        dataset = replay_buffer.as_dataset(
            sample_batch_size=batch_size,
            num_steps=num_steps).unbatch().filter(_filter_invalid_transition)

        dataset = dataset.batch(batch_size).prefetch(5)
        # The dataset yields trajectory batches of shape [B x num_steps x ...]
        # (num_steps is 2 for the feedforward actor, rnn_sequence_length + 1 for
        # the recurrent one).
        iterator = iter(dataset)

        @tf.function
        def train_step():
            experience, _ = next(iterator)

            prior = predictor_net(
                (experience.observation[:, 0], experience.action[:, 0]),
                training=False)
            z_next = encoder_net(experience.observation[:, 1], training=False)
            # predictor_kl is a vector of size batch_size.
            predictor_kl = tfp.distributions.kl_divergence(z_next, prior)

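            # Dual (Lagrangian) update: minimizing dual_loss below increases
            # log_kl_coefficient when the mean predictor KL exceeds kl_constraint
            # and decreases it otherwise.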
            with tf.GradientTape() as tape:
                tape.watch(actor_net._log_kl_coefficient)  # pylint: disable=protected-access
                dual_loss = -1.0 * actor_net._log_kl_coefficient * (  # pylint: disable=protected-access
                    tf.stop_gradient(tf.reduce_mean(predictor_kl)) -
                    kl_constraint)
            dual_grads = tape.gradient(dual_loss,
                                       [actor_net._log_kl_coefficient])  # pylint: disable=protected-access
            grads_and_vars = list(
                zip(dual_grads, [actor_net._log_kl_coefficient]))  # pylint: disable=protected-access
            dual_optimizer.apply_gradients(grads_and_vars)

            # Clip the dual variable so that 1e-6 <= exp(log_kl_coef) <= 1e6.
            log_kl_coef = tf.clip_by_value(
                actor_net._log_kl_coefficient,  # pylint: disable=protected-access
                -1.0 * np.log(1e6),
                np.log(1e6))
            actor_net._log_kl_coefficient.assign(log_kl_coef)  # pylint: disable=protected-access

            with tf.name_scope('dual_loss'):
                tf.compat.v2.summary.scalar(name='dual_loss',
                                            data=tf.reduce_mean(dual_loss),
                                            step=global_step)
                tf.compat.v2.summary.scalar(
                    name='log_kl_coefficient',
                    data=actor_net._log_kl_coefficient,  # pylint: disable=protected-access
                    step=global_step)

            z_entropy = z_next.entropy()
            log_prob = prior.log_prob(z_next.sample())
            with tf.name_scope('rp-metrics'):
                common.generate_tensor_summaries('predictor_kl', predictor_kl,
                                                 global_step)
                common.generate_tensor_summaries('z_entropy', z_entropy,
                                                 global_step)
                common.generate_tensor_summaries('log_prob', log_prob,
                                                 global_step)
                common.generate_tensor_summaries('z_mean', z_next.mean(),
                                                 global_step)
                common.generate_tensor_summaries('z_stddev', z_next.stddev(),
                                                 global_step)
                common.generate_tensor_summaries('prior_mean', prior.mean(),
                                                 global_step)
                common.generate_tensor_summaries('prior_stddev',
                                                 prior.stddev(), global_step)

            if log_prob_reward_scale == 'auto':
                coef = tf.stop_gradient(tf.exp(actor_net._log_kl_coefficient))  # pylint: disable=protected-access
            else:
                coef = log_prob_reward_scale
            tf.debugging.check_numerics(tf.reduce_mean(predictor_kl),
                                        'predictor_kl is inf or nan.')
            tf.debugging.check_numerics(coef, 'coef is inf or nan.')
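            # Penalize the environment reward by the predictor KL scaled by coef;
            # predictor_kl is [B], so a singleton time axis is added to broadcast
            # against experience.reward of shape [B, num_steps].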
            new_reward = experience.reward - coef * predictor_kl[:, None]

            experience = experience._replace(reward=new_reward)
            return tf_agent.train(experience)

        if use_tf_functions:
            train_step = common.function(train_step)

        # Save the hyperparameters
        operative_filename = os.path.join(root_dir, 'operative.gin')
        with tf.compat.v1.gfile.Open(operative_filename, 'w') as f:
            f.write(gin.operative_config_str())
            print(gin.operative_config_str())

        global_step_val = global_step.numpy()
        while global_step_val < num_iterations:
            start_time = time.time()
            time_step, policy_state = collect_driver.run(
                time_step=time_step,
                policy_state=policy_state,
            )
            env_time_acc += time.time() - start_time
            train_start_time = time.time()
            for _ in range(train_steps_per_iteration):
                train_loss = train_step()
            train_time_acc += time.time() - train_start_time
            time_acc += time.time() - start_time

            global_step_val = global_step.numpy()

            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             train_loss.loss)
                steps_per_sec = (global_step_val - timed_at_step) / time_acc
                logging.info('%.3f steps/sec', steps_per_sec)
                tf.compat.v2.summary.scalar(name='global_steps_per_sec',
                                            data=steps_per_sec,
                                            step=global_step)
                train_steps_per_sec = (global_step_val -
                                       timed_at_step) / train_time_acc
                logging.info('Train: %.3f steps/sec', train_steps_per_sec)
                tf.compat.v2.summary.scalar(name='train_steps_per_sec',
                                            data=train_steps_per_sec,
                                            step=global_step)
                env_steps_per_sec = (global_step_val -
                                     timed_at_step) / env_time_acc
                logging.info('Env: %.3f steps/sec', env_steps_per_sec)
                tf.compat.v2.summary.scalar(name='env_steps_per_sec',
                                            data=env_steps_per_sec,
                                            step=global_step)
                timed_at_step = global_step_val
                time_acc = 0
                train_time_acc = 0
                env_time_acc = 0

            for train_metric in train_metrics + [kl_metric]:
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2])

            if global_step_val % eval_interval == 0:
                start_time = time.time()
                for name, eval_tf_env, eval_metrics in eval_vec:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix='Metrics-%s' % name,
                    )
                    if eval_metrics_callback is not None:
                        eval_metrics_callback(results, global_step_val)
                    metric_utils.log_metrics(eval_metrics, prefix=name)
                logging.info('Evaluation: %d min',
                             (time.time() - start_time) / 60)
                for prob_dropout in eval_dropout:
                    rpc_utils.eval_dropout_fn(eval_tf_env,
                                              actor_net,
                                              global_step,
                                              prob_dropout=prob_dropout)

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % rb_checkpoint_interval == 0:
                rb_checkpointer.save(global_step=global_step_val)
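
# A minimal usage sketch, not part of the original example: the root_dir path
# and the reduced num_iterations below are illustrative assumptions only.
if __name__ == '__main__':
    train_eval(root_dir='/tmp/rpc_half_cheetah', num_iterations=10000)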
Example #16
    def __init__(self,
                 time_step_spec: ts.TimeStep,
                 action_spec: types.NestedTensorSpec,
                 critic_network: network.Network,
                 actor_network: network.Network,
                 actor_optimizer: types.Optimizer,
                 critic_optimizer: types.Optimizer,
                 alpha_optimizer: types.Optimizer,
                 actor_loss_weight: types.Float = 1.0,
                 critic_loss_weight: types.Float = 0.5,
                 alpha_loss_weight: types.Float = 1.0,
                 actor_policy_ctor: Callable[
                     ..., tf_policy.TFPolicy] = actor_policy.ActorPolicy,
                 critic_network_2: Optional[network.Network] = None,
                 target_critic_network: Optional[network.Network] = None,
                 target_critic_network_2: Optional[network.Network] = None,
                 target_update_tau: types.Float = 1.0,
                 target_update_period: types.Int = 1,
                 td_errors_loss_fn: types.LossFn = tf.math.squared_difference,
                 gamma: types.Float = 1.0,
                 sigma: types.Float = 0.9,
                 reward_scale_factor: types.Float = 1.0,
                 initial_log_alpha: types.Float = 0.0,
                 use_log_alpha_in_alpha_loss: bool = True,
                 target_entropy: Optional[types.Float] = None,
                 gradient_clipping: Optional[types.Float] = None,
                 debug_summaries: bool = False,
                 summarize_grads_and_vars: bool = False,
                 train_step_counter: Optional[tf.Variable] = None,
                 name: Optional[Text] = None):

        tf.Module.__init__(self, name=name)

        self._check_action_spec(action_spec)

        net_observation_spec = time_step_spec.observation
        critic_spec = (net_observation_spec, action_spec)

        self._critic_network_1 = critic_network

        if critic_network_2 is not None:
            self._critic_network_2 = critic_network_2
        else:
            self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
            # Do not use target_critic_network_2 if critic_network_2 is None.
            target_critic_network_2 = None

        # Wait until critic_network_2 has been copied from critic_network_1 before
        # creating variables on both.
        self._critic_network_1.create_variables(critic_spec)
        self._critic_network_2.create_variables(critic_spec)

        if target_critic_network:
            target_critic_network.create_variables(critic_spec)

        self._target_critic_network_1 = (
            common.maybe_copy_target_network_with_checks(
                self._critic_network_1,
                target_critic_network,
                input_spec=critic_spec,
                name='TargetCriticNetwork1'))

        if target_critic_network_2:
            target_critic_network_2.create_variables(critic_spec)
        self._target_critic_network_2 = (
            common.maybe_copy_target_network_with_checks(
                self._critic_network_2,
                target_critic_network_2,
                input_spec=critic_spec,
                name='TargetCriticNetwork2'))

        if actor_network:
            actor_network.create_variables(net_observation_spec)
        self._actor_network = actor_network

        policy = actor_policy_ctor(time_step_spec=time_step_spec,
                                   action_spec=action_spec,
                                   actor_network=self._actor_network,
                                   training=False)

        self._train_policy = actor_policy_ctor(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            training=True)

        self._log_alpha = common.create_variable(
            'initial_log_alpha',
            initial_value=initial_log_alpha,
            dtype=tf.float32,
            trainable=True)

        if target_entropy is None:
            target_entropy = self._get_default_target_entropy(action_spec)

        self._use_log_alpha_in_alpha_loss = use_log_alpha_in_alpha_loss
        self._target_update_tau = target_update_tau
        self._target_update_period = target_update_period
        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer
        self._alpha_optimizer = alpha_optimizer
        self._actor_loss_weight = actor_loss_weight
        self._critic_loss_weight = critic_loss_weight
        self._alpha_loss_weight = alpha_loss_weight
        self._td_errors_loss_fn = td_errors_loss_fn
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._target_entropy = target_entropy
        self._gradient_clipping = gradient_clipping
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._update_target = self._get_target_updater(
            tau=self._target_update_tau, period=self._target_update_period)

        self.sigma = sigma

        train_sequence_length = 2 if not critic_network.state_spec else None

        super(sac_agent.SacAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy=policy,
                             collect_policy=policy,
                             train_sequence_length=train_sequence_length,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter,
                             validate_args=False)

        self._as_transition = data_converter.AsTransition(
            self.data_context, squeeze_time_dim=(train_sequence_length == 2))
    def __init__(
            self,
            time_step_spec,
            action_spec,
            optimizer=None,
            actor_net=None,
            value_net=None,
            importance_ratio_clipping=0.0,
            lambda_value=0.95,
            discount_factor=0.99,
            entropy_regularization=0.0,
            policy_l2_reg=0.0,
            value_function_l2_reg=0.0,
            shared_vars_l2_reg=0.0,
            value_pred_loss_coef=0.5,
            num_epochs=25,
            use_gae=False,
            use_td_lambda_return=False,
            normalize_rewards=True,
            reward_norm_clipping=10.0,
            normalize_observations=True,
            log_prob_clipping=0.0,
            kl_cutoff_factor=0.0,
            kl_cutoff_coef=0.0,
            initial_adaptive_kl_beta=0.0,
            adaptive_kl_target=0.0,
            adaptive_kl_tolerance=0.0,
            gradient_clipping=None,
            value_clipping=None,
            check_numerics=False,
            # TODO(b/150244758): Change the default to False once we move
            # clients onto Reverb.
            compute_value_and_advantage_in_train=True,
            update_normalizers_in_train=True,
            debug_summaries=False,
            summarize_grads_and_vars=False,
            train_step_counter=None,
            name='AttentionPPOAgent'):
        """Creates a PPO Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of `BoundedTensorSpec` representing the actions.
      optimizer: Optimizer to use for the agent, default to using
        `tf.compat.v1.train.AdamOptimizer`.
      actor_net: A `network.DistributionNetwork` which maps observations to
        action distributions. Commonly, it is set to
        `actor_distribution_network.ActorDistributionNetwork`.
      value_net: A `Network` which returns the value prediction for input
        states, with `call(observation, step_type, network_state)`. Commonly, it
        is set to `value_network.ValueNetwork`.
      importance_ratio_clipping: Epsilon in clipped, surrogate PPO objective.
        For more detail, see explanation at the top of the doc.
      lambda_value: Lambda parameter for TD-lambda computation.
      discount_factor: Discount factor for return computation. Defaults to
        `0.99`, which is the value used for all environments in (Schulman, 2017).
      entropy_regularization: Coefficient for entropy regularization loss term.
        Defaults to `0.0` because no entropy bonus was used in (Schulman, 2017).
      policy_l2_reg: Coefficient for L2 regularization of unshared actor_net
        weights. Defaults to `0.0` because no L2 regularization was applied on
        the policy network weights in (Schulman, 2017).
      value_function_l2_reg: Coefficient for L2 regularization of unshared value
        function weights. Defaults to `0.0` because no L2 regularization was
        applied on the value network weights in (Schulman, 2017).
      shared_vars_l2_reg: Coefficient for L2 regularization of weights shared
        between actor_net and value_net. Defaults to `0.0` because no L2
        regularization was applied on the policy network or value network
        weights in (Schulman, 2017).
      value_pred_loss_coef: Multiplier for value prediction loss to balance with
        policy gradient loss. Defaults to `0.5`, which was used for all
        environments in the OpenAI baseline implementation. This parameter is
        irrelevant unless you are sharing part of actor_net and value_net. In
        that case, you would want to tune this coefficient, whose value depends
        on the network architecture of your choice.
      num_epochs: Number of epochs for computing policy updates. (Schulman, 2017)
        sets this to 10 for Mujoco, 15 for Roboschool, and 3 for Atari.
      use_gae: If True (default False), uses generalized advantage estimation
        for computing per-timestep advantage. Else, just subtracts value
        predictions from the empirical return.
      use_td_lambda_return: If True (default False), uses td_lambda_return for
        training the value function; here `td_lambda_return = gae_advantage +
        value_predictions`. `use_gae` must be set to `True` as well to enable
        TD-lambda returns. If `use_td_lambda_return` is set to True while
        `use_gae` is False, the empirical return will be used and a warning
        will be logged.
      normalize_rewards: If true, keeps moving variance of rewards and
        normalizes incoming rewards. While not mentioned directly in (Schulman,
        2017), reward normalization was implemented in OpenAI baselines and
        (Ilyas et al., 2018) pointed out that it largely improves performance.
        You may refer to Figure 1 of https://arxiv.org/pdf/1811.02553.pdf for a
          comparison with and without reward scaling.
      reward_norm_clipping: Value above and below to clip normalized reward.
        Additional optimization proposed in (Ilyas et al., 2018) set to `5` or
        `10`.
      normalize_observations: If `True`, keeps moving mean and variance of
        observations and normalizes incoming observations. Additional
        optimization proposed in (Ilyas et al., 2018). If `True` and the
        observation spec is not tf.float32 (such as Atari), please manually
        convert the observation spec received from the environment to tf.float32
        before creating the networks. Otherwise, the normalized input to the
        network (float32) will have a different dtype than what the network
        expects, resulting in a mismatch error.
        Example usage:
        ```python
        observation_tensor_spec, action_spec, time_step_tensor_spec = (
            spec_utils.get_tensor_specs(env))
        normalized_observation_tensor_spec = tf.nest.map_structure(
            lambda s: tf.TensorSpec(dtype=tf.float32, shape=s.shape, name=s.name),
            observation_tensor_spec)
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            normalized_observation_tensor_spec, ...)
        value_net = value_network.ValueNetwork(
            normalized_observation_tensor_spec, ...)
        # Note that the agent still uses the original time_step_tensor_spec
        # from the environment.
        agent = ppo_clip_agent.PPOClipAgent(
            time_step_tensor_spec, action_spec, actor_net, value_net, ...)
        ```
      log_prob_clipping: +/- value for clipping log probs to prevent inf / NaN
        values.  Default: no clipping.
      kl_cutoff_factor: Only meaningful when `kl_cutoff_coef > 0.0`. A multiplier
        used for calculating the KL cutoff (= `kl_cutoff_factor *
        adaptive_kl_target`). If the policy KL averaged across the batch changes
        more than the cutoff, a squared cutoff loss is added to the loss
        function.
      kl_cutoff_coef: kl_cutoff_coef and kl_cutoff_factor are additional params
        if one wants to use a KL cutoff loss term in addition to the adaptive KL
        loss term. Defaults to 0.0, which disables the KL cutoff loss term, as it
        was not used in the paper. kl_cutoff_coef is the coefficient by which the
        KL cutoff loss term is multiplied before being added to the total loss
        function.
      initial_adaptive_kl_beta: Initial value for beta coefficient of adaptive
        KL penalty. This initial value is not important in practice because the
        algorithm quickly adjusts to it. A common default is 1.0.
      adaptive_kl_target: Desired KL target for policy updates. If actual KL is
        far from this target, adaptive_kl_beta will be updated. You should tune
        this for your environment. 0.01 was found to perform well for Mujoco.
      adaptive_kl_tolerance: A tolerance for adaptive_kl_beta. Mean KL above `(1
        + tol) * adaptive_kl_target`, or below `(1 - tol) * adaptive_kl_target`,
        will cause `adaptive_kl_beta` to be updated. `0.5` was chosen
        heuristically in the paper, but the algorithm is not very sensitive to
        it.
      gradient_clipping: Norm length to clip gradients.  Default: no clipping.
      value_clipping: The difference between new and old value predictions is
        clipped to this threshold. Value clipping can be helpful when training
        very deep networks. Default: no clipping.
      check_numerics: If true, adds `tf.debugging.check_numerics` to help find
        NaN / Inf values. For debugging only.
      compute_value_and_advantage_in_train: A bool to indicate where value
        prediction and advantage calculation happen.  If True, both happen in
        agent.train(). If False, value prediction is computed during data
        collection. This argument must be set to `False` if mini batch learning
        is enabled.
      update_normalizers_in_train: A bool to indicate whether normalizers are
        updated as parts of the `train` method. Set to `False` if mini batch
        learning is enabled, or if `train` is called on multiple iterations of
        the same trajectories. In that case, you would need to use `PPOLearner`
        (which updates all the normalizers outside of the agent). This ensures
        that normalizers are updated in the same way as (Schulman, 2017).
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If true, gradient summaries will be written.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall under
        that name. Defaults to the class name.

    Raises:
      TypeError: if `actor_net` or `value_net` is not of type
        `tf_agents.networks.Network`.
    """
        if not isinstance(actor_net, network.Network):
            raise TypeError(
                'actor_net must be an instance of a network.Network.')
        if not isinstance(value_net, network.Network):
            raise TypeError(
                'value_net must be an instance of a network.Network.')

        # PPOPolicy validates these, so we skip validation here.
        actor_net.create_variables(time_step_spec.observation)
        value_net.create_variables(time_step_spec.observation)

        tf.Module.__init__(self, name=name)

        self._optimizer = optimizer
        self._actor_net = actor_net
        self._value_net = value_net
        self._importance_ratio_clipping = importance_ratio_clipping
        self._lambda = lambda_value
        self._discount_factor = discount_factor
        self._entropy_regularization = entropy_regularization
        self._policy_l2_reg = policy_l2_reg
        self._value_function_l2_reg = value_function_l2_reg
        self._shared_vars_l2_reg = shared_vars_l2_reg
        self._value_pred_loss_coef = value_pred_loss_coef
        self._num_epochs = num_epochs
        self._use_gae = use_gae
        self._use_td_lambda_return = use_td_lambda_return
        self._reward_norm_clipping = reward_norm_clipping
        self._log_prob_clipping = log_prob_clipping
        self._kl_cutoff_factor = kl_cutoff_factor
        self._kl_cutoff_coef = kl_cutoff_coef
        self._adaptive_kl_target = adaptive_kl_target
        self._adaptive_kl_tolerance = adaptive_kl_tolerance
        self._gradient_clipping = gradient_clipping or 0.0
        self._value_clipping = value_clipping or 0.0
        self._check_numerics = check_numerics
        self._compute_value_and_advantage_in_train = (
            compute_value_and_advantage_in_train)
        self.update_normalizers_in_train = update_normalizers_in_train
        if not isinstance(self._optimizer, tf.keras.optimizers.Optimizer):
            logging.warning(
                'Only tf.keras.optimizers.Optimizers are well supported, got a '
                'non-TF2 optimizer: %s', self._optimizer)

        self._initial_adaptive_kl_beta = initial_adaptive_kl_beta
        if initial_adaptive_kl_beta > 0.0:
            self._adaptive_kl_beta = common.create_variable(
                'adaptive_kl_beta', initial_adaptive_kl_beta, dtype=tf.float32)
        else:
            self._adaptive_kl_beta = None

        self._reward_normalizer = None
        if normalize_rewards:
            self._reward_normalizer = tensor_normalizer.StreamingTensorNormalizer(
                tensor_spec.TensorSpec([], tf.float32),
                scope='normalize_reward')

        self._observation_normalizer = None
        if normalize_observations:
            self._observation_normalizer = (
                tensor_normalizer.StreamingTensorNormalizer(
                    time_step_spec.observation,
                    scope='normalize_observations'))

        self._advantage_normalizer = tensor_normalizer.StreamingTensorNormalizer(
            tensor_spec.TensorSpec([], tf.float32),
            scope='normalize_advantages')

        policy = greedy_policy.GreedyPolicy(
            attention_ppo_policy.AttentionPPOPolicy(
                time_step_spec=time_step_spec,
                action_spec=action_spec,
                actor_network=actor_net,
                value_network=value_net,
                observation_normalizer=self._observation_normalizer,
                clip=False,
                collect=False))

        collect_policy = attention_ppo_policy.AttentionPPOPolicy(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=actor_net,
            value_network=value_net,
            observation_normalizer=self._observation_normalizer,
            clip=False,
            collect=True,
            compute_value_and_advantage_in_train=(
                self._compute_value_and_advantage_in_train),
        )

        if isinstance(self._actor_net, network.DistributionNetwork):
            # Legacy behavior
            self._action_distribution_spec = self._actor_net.output_spec
        else:
            self._action_distribution_spec = self._actor_net.create_variables(
                time_step_spec.observation)

        # Set training_data_spec to collect_data_spec with augmented policy info,
        # iff return and normalized advantage are saved in preprocess_sequence.
        if self._compute_value_and_advantage_in_train:
            training_data_spec = None
        else:
            training_policy_info = collect_policy.trajectory_spec.policy_info.copy(
            )
            training_policy_info.update({
                'value_prediction':
                collect_policy.trajectory_spec.policy_info['value_prediction'],
                'return':
                tensor_spec.TensorSpec(shape=[], dtype=tf.float32),
                'advantage':
                tensor_spec.TensorSpec(shape=[], dtype=tf.float32),
            })
            training_data_spec = collect_policy.trajectory_spec.replace(
                policy_info=training_policy_info)

        super(ppo_agent.PPOAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy,
                             collect_policy,
                             train_sequence_length=None,
                             training_data_spec=training_data_spec,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter)

        # This must be built after super() which sets up self.data_context.
        self._collected_as_transition = data_converter.AsTransition(
            self.collect_data_context, squeeze_time_dim=False)

        self._as_trajectory = data_converter.AsTrajectory(self.data_context,
                                                          sequence_length=None)