Example 1
    def __init__(
            self,
            time_step_spec: ts.TimeStep,
            action_spec: types.NestedTensorSpec,
            q_network: network.Network,
            optimizer: types.Optimizer,
            observation_and_action_constraint_splitter: Optional[
                types.Splitter] = None,
            epsilon_greedy: types.Float = 0.1,
            n_step_update: int = 1,
            boltzmann_temperature: Optional[types.Int] = None,
            emit_log_probability: bool = False,
            # Params for target network updates
            target_q_network: Optional[network.Network] = None,
            target_update_tau: types.Float = 1.0,
            target_update_period: int = 1,
            # Params for training.
            td_errors_loss_fn: Optional[types.LossFn] = None,
            gamma: types.Float = 1.0,
            reward_scale_factor: types.Float = 1.0,
            gradient_clipping: Optional[types.Float] = None,
            # Params for debugging
            debug_summaries: bool = False,
            summarize_grads_and_vars: bool = False,
            train_step_counter: Optional[tf.Variable] = None,
            name: Optional[Text] = None,
            entropy_tau: types.Float = 0.9,
            alpha: types.Float = 0.3):

        tf.Module.__init__(self, name=name)

        self._check_action_spec(action_spec)

        if epsilon_greedy is not None and boltzmann_temperature is not None:
            raise ValueError(
                'Configured both epsilon_greedy value {} and temperature {}, '
                'however only one of them can be used for exploration.'.format(
                    epsilon_greedy, boltzmann_temperature))

        self._observation_and_action_constraint_splitter = (
            observation_and_action_constraint_splitter)
        self._q_network = q_network
        net_observation_spec = time_step_spec.observation
        if observation_and_action_constraint_splitter:
            net_observation_spec, _ = observation_and_action_constraint_splitter(
                net_observation_spec)
        q_network.create_variables(net_observation_spec)
        if target_q_network:
            target_q_network.create_variables(net_observation_spec)
        self._target_q_network = common.maybe_copy_target_network_with_checks(
            self._q_network,
            target_q_network,
            input_spec=net_observation_spec,
            name='TargetQNetwork')

        self._check_network_output(self._q_network, 'q_network')
        self._check_network_output(self._target_q_network, 'target_q_network')

        self._epsilon_greedy = epsilon_greedy
        self._n_step_update = n_step_update
        self._boltzmann_temperature = boltzmann_temperature
        self._optimizer = optimizer
        self._td_errors_loss_fn = (td_errors_loss_fn
                                   or common.element_wise_huber_loss)
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._gradient_clipping = gradient_clipping
        self._update_target = self._get_target_updater(target_update_tau,
                                                       target_update_period)
        self.entropy_tau = entropy_tau
        self.alpha = alpha

        policy, collect_policy = self._setup_policy(time_step_spec,
                                                    action_spec,
                                                    boltzmann_temperature,
                                                    emit_log_probability)

        if q_network.state_spec and n_step_update != 1:
            raise NotImplementedError(
                'DqnAgent does not currently support n-step updates with stateful '
                'networks (i.e., RNNs), but n_step_update = {}'.format(
                    n_step_update))

        train_sequence_length = (
            n_step_update + 1 if not q_network.state_spec else None)

        super(dqn_agent.DqnAgent, self).__init__(
            time_step_spec,
            action_spec,
            policy,
            collect_policy,
            train_sequence_length=train_sequence_length,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=train_step_counter,
            validate_args=False,
        )

        if q_network.state_spec:
            # AsNStepTransition does not support emitting [B, T, ...] tensors,
            # which we need for DQN-RNN.
            self._as_transition = data_converter.AsTransition(
                self.data_context, squeeze_time_dim=False)
        else:
            # This reduces the n-step return and removes the extra time dimension,
            # allowing the rest of the computations to be independent of the
            # n-step parameter.
            self._as_transition = data_converter.AsNStepTransition(
                self.data_context, gamma=gamma, n=n_step_update)
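
The snippet above shows only an `__init__` body, so the class name is unknown; the sketch below assumes a hypothetical wrapper class (called `MunchausenDqnAgent` here, since the extra `entropy_tau` and `alpha` parameters suggest a Munchausen-style DQN variant) and otherwise follows the usual TF-Agents setup. Hyperparameter values are illustrative only.

```python
import tensorflow as tf
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks import q_network

# Hypothetical import: the class wrapping the __init__ above is not shown.
from munchausen_dqn_agent import MunchausenDqnAgent

env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
q_net = q_network.QNetwork(
    env.observation_spec(), env.action_spec(), fc_layer_params=(100,))

agent = MunchausenDqnAgent(
    env.time_step_spec(),
    env.action_spec(),
    q_network=q_net,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    epsilon_greedy=0.1,
    target_update_tau=0.005,
    target_update_period=100,
    gamma=0.99,
    entropy_tau=0.9,  # illustrative; signature default shown above
    alpha=0.3,        # illustrative; signature default shown above
    train_step_counter=tf.Variable(0))
agent.initialize()
```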
Example 2
  def __init__(self,
               time_step_spec,
               action_spec,
               actor_network,
               critic_network,
               actor_optimizer,
               critic_optimizer,
               exploration_noise_std=0.1,
               critic_network_2=None,
               target_actor_network=None,
               target_critic_network=None,
               target_critic_network_2=None,
               target_update_tau=1.0,
               target_update_period=1,
               actor_update_period=1,
               dqda_clipping=None,
               td_errors_loss_fn=None,
               gamma=1.0,
               reward_scale_factor=1.0,
               target_policy_noise=0.2,
               target_policy_noise_clip=0.5,
               gradient_clipping=None,
               debug_summaries=False,
               summarize_grads_and_vars=False,
               train_step_counter=None,
               name=None):
    """Creates a Td3Agent Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      actor_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call(observation, step_type).
      critic_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call(observation, action, step_type).
      actor_optimizer: The default optimizer to use for the actor network.
      critic_optimizer: The default optimizer to use for the critic network.
      exploration_noise_std: Scale factor on exploration policy noise.
      critic_network_2: (Optional.)  A `tf_agents.network.Network` to be used as
        the second critic network during Q learning.  The weights from
        `critic_network` are copied if this is not provided.
      target_actor_network: (Optional.)  A `tf_agents.network.Network` to be
        used as the target actor network during Q learning. Every
        `target_update_period` train steps, the weights from `actor_network` are
        copied (possibly with smoothing via `target_update_tau`) to
        `target_actor_network`.  If `target_actor_network` is not provided, it is
        created by making a copy of `actor_network`, which initializes a new
        network with the same structure and its own layers and weights.
        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or when
        the network is sharing layers with another).  In these cases, it is up
        to you to build a copy having weights that are not shared with the
        original `actor_network`, so that this can be used as a target network.
        If you provide a `target_actor_network` that shares any weights with
        `actor_network`, a warning will be logged but no exception is thrown.
      target_critic_network: (Optional.) Similar network as target_actor_network
        but for the critic_network. See documentation for target_actor_network.
      target_critic_network_2: (Optional.) Similar network as
        target_actor_network but for the critic_network_2. See documentation for
        target_actor_network. Will only be used if 'critic_network_2' is also
        specified.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      actor_update_period: Period for the optimization step on actor network.
      dqda_clipping: A scalar or float clips the gradient dqda element-wise
        between [-dqda_clipping, dqda_clipping]. Default is None representing no
        clipping.
      td_errors_loss_fn:  A function for computing the TD errors loss. If None,
        a default value of elementwise huber_loss is used.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      target_policy_noise: Scale factor on target action noise
      target_policy_noise_clip: Value to clip noise.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall under
        that name. Defaults to the class name.
    """
    tf.Module.__init__(self, name=name)
    self._actor_network = actor_network
    actor_network.create_variables()
    if target_actor_network:
      target_actor_network.create_variables()
    self._target_actor_network = common.maybe_copy_target_network_with_checks(
        self._actor_network, target_actor_network, 'TargetActorNetwork')

    self._critic_network_1 = critic_network
    critic_network.create_variables()
    if target_critic_network:
      target_critic_network.create_variables()
    self._target_critic_network_1 = (
        common.maybe_copy_target_network_with_checks(self._critic_network_1,
                                                     target_critic_network,
                                                     'TargetCriticNetwork1'))

    if critic_network_2 is not None:
      self._critic_network_2 = critic_network_2
    else:
      self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
      # Do not use target_critic_network_2 if critic_network_2 is None.
      target_critic_network_2 = None
    self._critic_network_2.create_variables()
    if target_critic_network_2:
      target_critic_network_2.create_variables()
    self._target_critic_network_2 = (
        common.maybe_copy_target_network_with_checks(self._critic_network_2,
                                                     target_critic_network_2,
                                                     'TargetCriticNetwork2'))

    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer

    self._exploration_noise_std = exploration_noise_std
    self._target_update_tau = target_update_tau
    self._target_update_period = target_update_period
    self._actor_update_period = actor_update_period
    self._dqda_clipping = dqda_clipping
    self._td_errors_loss_fn = (
        td_errors_loss_fn or common.element_wise_huber_loss)
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._target_policy_noise = target_policy_noise
    self._target_policy_noise_clip = target_policy_noise_clip
    self._gradient_clipping = gradient_clipping

    self._update_target = self._get_target_updater(target_update_tau,
                                                   target_update_period)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        clip=True)
    collect_policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        clip=False)
    collect_policy = gaussian_policy.GaussianPolicy(
        collect_policy, scale=self._exploration_noise_std, clip=True)

    super(Td3Agent, self).__init__(
        time_step_spec,
        action_spec,
        policy,
        collect_policy,
        train_sequence_length=2 if not self._actor_network.state_spec else None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter)
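
Assuming the snippet above matches the stock `tf_agents` `Td3Agent`, a minimal construction sketch looks like the following; the environment, network sizes, and optimizer settings are illustrative.

```python
import tensorflow as tf
from tf_agents.agents.ddpg import actor_network, critic_network
from tf_agents.agents.td3 import td3_agent
from tf_agents.environments import suite_gym, tf_py_environment

env = tf_py_environment.TFPyEnvironment(suite_gym.load('Pendulum-v1'))
obs_spec, act_spec = env.observation_spec(), env.action_spec()

actor_net = actor_network.ActorNetwork(
    obs_spec, act_spec, fc_layer_params=(400, 300))
critic_net = critic_network.CriticNetwork(
    (obs_spec, act_spec), joint_fc_layer_params=(400, 300))

agent = td3_agent.Td3Agent(
    env.time_step_spec(),
    act_spec,
    actor_network=actor_net,
    critic_network=critic_net,
    actor_optimizer=tf.keras.optimizers.Adam(1e-4),
    critic_optimizer=tf.keras.optimizers.Adam(1e-3),
    exploration_noise_std=0.1,
    target_update_tau=0.005,
    target_update_period=2,
    actor_update_period=2,
    gamma=0.99,
    train_step_counter=tf.Variable(0))
agent.initialize()
```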
Example 3
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 critic_network,
                 actor_network,
                 actor_optimizer,
                 critic_optimizer,
                 alpha_optimizer,
                 critic_network_no_entropy=None,
                 critic_no_entropy_optimizer=None,
                 actor_loss_weight=1.0,
                 critic_loss_weight=0.5,
                 alpha_loss_weight=1.0,
                 actor_policy_ctor=actor_policy.ActorPolicy,
                 critic_network_2=None,
                 target_critic_network=None,
                 target_critic_network_2=None,
                 target_update_tau=1.0,
                 target_update_period=1,
                 td_errors_loss_fn=tf.math.squared_difference,
                 gamma=1.0,
                 reward_scale_factor=1.0,
                 initial_log_alpha=0.0,
                 use_log_alpha_in_alpha_loss=True,
                 target_entropy=None,
                 gradient_clipping=None,
                 debug_summaries=False,
                 summarize_grads_and_vars=False,
                 train_step_counter=None,
                 name=None):
        """Creates a SAC Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      critic_network: A function critic_network((observations, actions)) that
        returns the q_values for each observation and action.
      actor_network: A function actor_network(observation, action_spec) that
        returns action distribution.
      actor_optimizer: The optimizer to use for the actor network.
      critic_optimizer: The default optimizer to use for the critic network.
      alpha_optimizer: The default optimizer to use for the alpha variable.
      critic_network_no_entropy: (Optional.) A critic network used to estimate
        Q values with the entropy term excluded; if provided, a second copy and
        matching target networks are created for it.
      critic_no_entropy_optimizer: The optimizer to use for
        `critic_network_no_entropy`.
      actor_loss_weight: The weight on actor loss.
      critic_loss_weight: The weight on critic loss.
      alpha_loss_weight: The weight on alpha loss.
      actor_policy_ctor: The policy class to use.
      critic_network_2: (Optional.)  A `tf_agents.network.Network` to be used as
        the second critic network during Q learning.  The weights from
        `critic_network` are copied if this is not provided.
      target_critic_network: (Optional.)  A `tf_agents.network.Network` to be
        used as the target critic network during Q learning. Every
        `target_update_period` train steps, the weights from `critic_network`
        are copied (possibly with smoothing via `target_update_tau`) to
        `target_critic_network`.  If `target_critic_network` is not provided, it
        is created by making a copy of `critic_network`, which initializes a new
        network with the same structure and its own layers and weights.
        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or when
        the network is sharing layers with another).  In these cases, it is up
        to you to build a copy having weights that are not shared with the
        original `critic_network`, so that this can be used as a target network.
        If you provide a `target_critic_network` that shares any weights with
        `critic_network`, a warning will be logged but no exception is thrown.
      target_critic_network_2: (Optional.) Similar network as
        target_critic_network but for the critic_network_2. See documentation
        for target_critic_network. Will only be used if 'critic_network_2' is
        also specified.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn:  A function for computing the elementwise TD errors
        loss.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      initial_log_alpha: Initial value for log_alpha.
      use_log_alpha_in_alpha_loss: A boolean, whether using log_alpha or alpha
        in alpha loss. Certain implementations of SAC use log_alpha as log
        values are generally nicer to work with.
      target_entropy: The target average policy entropy, for updating alpha. The
        default value is negative of the total number of actions.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall under
        that name. Defaults to the class name.
    """
        tf.Module.__init__(self, name=name)

        self._check_action_spec(action_spec)

        self._critic_network_1 = critic_network
        self._critic_network_1.create_variables()
        if target_critic_network:
            target_critic_network.create_variables()
        self._target_critic_network_1 = (
            common.maybe_copy_target_network_with_checks(
                self._critic_network_1, target_critic_network,
                'TargetCriticNetwork1'))

        # for estimating critics without entropy included
        self._critic_network_no_entropy_1 = critic_network_no_entropy
        if critic_network_no_entropy is not None:
            self._critic_network_no_entropy_1.create_variables()
            self._target_critic_network_no_entropy_1 = (
                common.maybe_copy_target_network_with_checks(
                    self._critic_network_no_entropy_1, None,
                    'TargetCriticNetworkNoEntropy1'))
            # Network 2
            self._critic_network_no_entropy_2 = self._critic_network_no_entropy_1.copy(
                name='CriticNetworkNoEntropy2')
            self._critic_network_no_entropy_2.create_variables()
            self._target_critic_network_no_entropy_2 = (
                common.maybe_copy_target_network_with_checks(
                    self._critic_network_no_entropy_2, None,
                    'TargetCriticNetworkNoEntropy2'))

        if critic_network_2 is not None:
            self._critic_network_2 = critic_network_2
        else:
            self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
            # Do not use target_critic_network_2 if critic_network_2 is None.
            target_critic_network_2 = None
        self._critic_network_2.create_variables()
        if target_critic_network_2:
            target_critic_network_2.create_variables()
        self._target_critic_network_2 = (
            common.maybe_copy_target_network_with_checks(
                self._critic_network_2, target_critic_network_2,
                'TargetCriticNetwork2'))

        if actor_network:
            actor_network.create_variables()
        self._actor_network = actor_network

        policy = actor_policy_ctor(time_step_spec=time_step_spec,
                                   action_spec=action_spec,
                                   actor_network=self._actor_network,
                                   training=False)

        self._train_policy = actor_policy_ctor(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            training=True)

        self._log_alpha = common.create_variable(
            'initial_log_alpha',
            initial_value=initial_log_alpha,
            dtype=tf.float32,
            trainable=True)

        if target_entropy is None:
            target_entropy = self._get_default_target_entropy(action_spec)

        self._use_log_alpha_in_alpha_loss = use_log_alpha_in_alpha_loss
        self._target_update_tau = target_update_tau
        self._target_update_period = target_update_period
        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer
        self._critic_no_entropy_optimizer = critic_no_entropy_optimizer
        self._alpha_optimizer = alpha_optimizer
        self._actor_loss_weight = actor_loss_weight
        self._critic_loss_weight = critic_loss_weight
        self._alpha_loss_weight = alpha_loss_weight
        self._td_errors_loss_fn = td_errors_loss_fn
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._target_entropy = target_entropy
        self._gradient_clipping = gradient_clipping
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._update_target = self._get_target_updater(
            tau=self._target_update_tau, period=self._target_update_period)

        train_sequence_length = 2 if not critic_network.state_spec else None

        super(SacAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy=policy,
                             collect_policy=policy,
                             train_sequence_length=train_sequence_length,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter)
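
A minimal construction sketch for this SAC variant, assuming the class is named `SacAgent` (as the `super(SacAgent, ...)` call suggests) and that the standard TF-Agents actor-distribution and DDPG-style critic networks are compatible inputs. The no-entropy critic is optional; all values are illustrative.

```python
import tensorflow as tf
from tf_agents.agents.ddpg import critic_network
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks import actor_distribution_network

env = tf_py_environment.TFPyEnvironment(suite_gym.load('Pendulum-v1'))
obs_spec, act_spec = env.observation_spec(), env.action_spec()

critic_net = critic_network.CriticNetwork(
    (obs_spec, act_spec), joint_fc_layer_params=(256, 256))
critic_net_no_entropy = critic_network.CriticNetwork(
    (obs_spec, act_spec), joint_fc_layer_params=(256, 256),
    name='CriticNetworkNoEntropy1')
actor_net = actor_distribution_network.ActorDistributionNetwork(
    obs_spec, act_spec, fc_layer_params=(256, 256))

# `SacAgent` refers to the class whose __init__ is shown above.
agent = SacAgent(
    env.time_step_spec(),
    act_spec,
    critic_network=critic_net,
    actor_network=actor_net,
    actor_optimizer=tf.keras.optimizers.Adam(3e-4),
    critic_optimizer=tf.keras.optimizers.Adam(3e-4),
    alpha_optimizer=tf.keras.optimizers.Adam(3e-4),
    critic_network_no_entropy=critic_net_no_entropy,
    critic_no_entropy_optimizer=tf.keras.optimizers.Adam(3e-4),
    target_update_tau=0.005,
    target_update_period=1,
    gamma=0.99,
    train_step_counter=tf.Variable(0))
agent.initialize()
```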
Example 4
  def __init__(self,
               time_step_spec,
               action_spec,
               critic_network,
               actor_network,
               actor_optimizer,
               critic_optimizer,
               actor_loss_weight = 1.0,
               critic_loss_weight = 0.5,
               actor_policy_ctor = actor_policy.ActorPolicy,
               critic_network_2 = None,
               target_critic_network = None,
               target_critic_network_2 = None,
               target_update_tau = 1.0,
               target_update_period = 1,
               td_errors_loss_fn = tf.math.squared_difference,
               gamma = 1.0,
               reward_scale_factor = 1.0,
               gradient_clipping = None,
               debug_summaries = False,
               summarize_grads_and_vars = False,
               train_step_counter = None,
               name = None,
               n_step = None,
               use_behavior_policy = False):
    """Creates a RCE Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      critic_network: A function critic_network((observations, actions)) that
        returns the q_values for each observation and action.
      actor_network: A function actor_network(observation, action_spec) that
        returns action distribution.
      actor_optimizer: The optimizer to use for the actor network.
      critic_optimizer: The default optimizer to use for the critic network.
      actor_loss_weight: The weight on actor loss.
      critic_loss_weight: The weight on critic loss.
      actor_policy_ctor: The policy class to use.
      critic_network_2: (Optional.)  A `tf_agents.network.Network` to be used as
        the second critic network during Q learning.  The weights from
        `critic_network` are copied if this is not provided.
      target_critic_network: (Optional.)  A `tf_agents.network.Network` to be
        used as the target critic network during Q learning. Every
        `target_update_period` train steps, the weights from `critic_network`
        are copied (possibly with smoothing via `target_update_tau`) to
        `target_critic_network`.  If `target_critic_network` is not provided, it
        is created by making a copy of `critic_network`, which initializes a new
        network with the same structure and its own layers and weights.
        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or when
        the network is sharing layers with another).  In these cases, it is up
        to you to build a copy having weights that are not shared with the
        original `critic_network`, so that this can be used as a target network.
        If you provide a `target_critic_network` that shares any weights with
        `critic_network`, a warning will be logged but no exception is thrown.
      target_critic_network_2: (Optional.) Similar network as
        target_critic_network but for the critic_network_2. See documentation
        for target_critic_network. Will only be used if 'critic_network_2' is
        also specified.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn:  A function for computing the elementwise TD errors
        loss.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall under
        that name. Defaults to the class name.
      n_step: An integer specifying whether to use n-step returns. Empirically,
        a value of 10 works well for most tasks. Use None to disable n-step
        returns.
      use_behavior_policy: A boolean indicating how to sample actions for the
        success states. When use_behavior_policy=True, we use the historical
        average policy; otherwise, we use the current policy.
    """
    tf.Module.__init__(self, name=name)

    self._check_action_spec(action_spec)

    self._critic_network_1 = critic_network
    self._critic_network_1.create_variables(
        (time_step_spec.observation, action_spec))
    if target_critic_network:
      target_critic_network.create_variables(
          (time_step_spec.observation, action_spec))
      self._target_critic_network_1 = target_critic_network
    else:
      self._target_critic_network_1 = (
          common.maybe_copy_target_network_with_checks(self._critic_network_1,
                                                       None,
                                                       'TargetCriticNetwork1'))

    if critic_network_2 is not None:
      self._critic_network_2 = critic_network_2
    else:
      self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
      # Do not use target_critic_network_2 if critic_network_2 is None.
      target_critic_network_2 = None
    self._critic_network_2.create_variables(
        (time_step_spec.observation, action_spec))

    if target_critic_network_2:
      target_critic_network_2.create_variables(
          (time_step_spec.observation, action_spec))
      self._target_critic_network_2 = target_critic_network_2
    else:
      self._target_critic_network_2 = (
          common.maybe_copy_target_network_with_checks(self._critic_network_2,
                                                       None,
                                                       'TargetCriticNetwork2'))

    if actor_network:
      actor_network.create_variables(time_step_spec.observation)
    self._actor_network = actor_network

    self._use_behavior_policy = use_behavior_policy
    if use_behavior_policy:
      self._behavior_actor_network = actor_network.copy(
          name='BehaviorActorNetwork')
      self._behavior_policy = actor_policy_ctor(
          time_step_spec=time_step_spec,
          action_spec=action_spec,
          actor_network=self._behavior_actor_network,
          training=True)

    policy = actor_policy_ctor(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        training=False)

    self._train_policy = actor_policy_ctor(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        training=True)

    self._target_update_tau = target_update_tau
    self._target_update_period = target_update_period
    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer
    self._actor_loss_weight = actor_loss_weight
    self._critic_loss_weight = critic_loss_weight
    self._td_errors_loss_fn = td_errors_loss_fn
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._gradient_clipping = gradient_clipping
    self._debug_summaries = debug_summaries
    self._summarize_grads_and_vars = summarize_grads_and_vars
    self._update_target = self._get_target_updater(
        tau=self._target_update_tau, period=self._target_update_period)
    self._n_step = n_step

    train_sequence_length = 2 if not critic_network.state_spec else None

    super(RceAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy=policy,
        collect_policy=policy,
        train_sequence_length=train_sequence_length,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter,
        validate_args=False
    )

    self._as_transition = data_converter.AsTransition(
        self.data_context, squeeze_time_dim=(train_sequence_length == 2))
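
A similar sketch for the RCE agent above (the name `RceAgent` comes from the `super(RceAgent, ...)` call). `n_step=10` follows the docstring's suggestion; the networks and remaining values are illustrative.

```python
import tensorflow as tf
from tf_agents.agents.ddpg import critic_network
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks import actor_distribution_network

env = tf_py_environment.TFPyEnvironment(suite_gym.load('Pendulum-v1'))
obs_spec, act_spec = env.observation_spec(), env.action_spec()

critic_net = critic_network.CriticNetwork(
    (obs_spec, act_spec), joint_fc_layer_params=(256, 256))
actor_net = actor_distribution_network.ActorDistributionNetwork(
    obs_spec, act_spec, fc_layer_params=(256, 256))

# `RceAgent` is the class whose __init__ is shown above.
agent = RceAgent(
    env.time_step_spec(),
    act_spec,
    critic_network=critic_net,
    actor_network=actor_net,
    actor_optimizer=tf.keras.optimizers.Adam(3e-4),
    critic_optimizer=tf.keras.optimizers.Adam(3e-4),
    gamma=0.99,
    n_step=10,                 # docstring: ~10 works well for most tasks
    use_behavior_policy=False,
    train_step_counter=tf.Variable(0))
agent.initialize()
```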
Example 5
  def __init__(
      self,
      time_step_spec: ts.TimeStep,
      action_spec: types.NestedTensorSpec,
      q_network: network.Network,
      optimizer: types.Optimizer,
      observation_and_action_constraint_splitter: Optional[
          types.Splitter] = None,
      epsilon_greedy: types.Float = 0.1,
      n_step_update: int = 1,
      boltzmann_temperature: Optional[types.Int] = None,
      emit_log_probability: bool = False,
      # Params for target network updates
      target_q_network: Optional[network.Network] = None,
      target_update_tau: types.Float = 1.0,
      target_update_period: int = 1,
      # Params for training.
      td_errors_loss_fn: Optional[types.LossFn] = None,
      gamma: types.Float = 1.0,
      reward_scale_factor: types.Float = 1.0,
      gradient_clipping: Optional[types.Float] = None,
      # Params for debugging
      debug_summaries: bool = False,
      summarize_grads_and_vars: bool = False,
      train_step_counter: Optional[tf.Variable] = None,
      name: Optional[Text] = None):
    """Creates a DQN Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      q_network: A `tf_agents.network.Network` to be used by the agent. The
        network will be called with `call(observation, step_type)` and should
        emit logits over the action space.
      optimizer: The optimizer to use for training.
      observation_and_action_constraint_splitter: A function used to process
        observations with action constraints. These constraints can indicate,
        for example, a mask of valid/invalid actions for a given state of the
        environment.
        The function takes in a full observation and returns a tuple consisting
        of 1) the part of the observation intended as input to the network and
        2) the constraint. An example
        `observation_and_action_constraint_splitter` could be as simple as:
        ```
        def observation_and_action_constraint_splitter(observation):
          return observation['network_input'], observation['constraint']
        ```
        *Note*: when using `observation_and_action_constraint_splitter`, make
        sure the provided `q_network` is compatible with the network-specific
        half of the output of the `observation_and_action_constraint_splitter`.
        In particular, `observation_and_action_constraint_splitter` will be
        called on the observation before passing to the network.
        If `observation_and_action_constraint_splitter` is None, action
        constraints are not applied.
      epsilon_greedy: probability of choosing a random action in the default
        epsilon-greedy collect policy (used only if a wrapper is not provided to
        the collect_policy method).
      n_step_update: The number of steps to consider when computing TD error and
        TD loss. Defaults to single-step updates. Note that this requires the
        user to call train on Trajectory objects with a time dimension of
        `n_step_update + 1`. However, note that we do not yet support
        `n_step_update > 1` in the case of RNNs (i.e., non-empty
        `q_network.state_spec`).
      boltzmann_temperature: Temperature value to use for Boltzmann sampling of
        the actions during data collection. The closer to 0.0, the higher the
        probability of choosing the best action.
      emit_log_probability: Whether policies emit log probabilities or not.
      target_q_network: (Optional.)  A `tf_agents.network.Network`
        to be used as the target network during Q learning.  Every
        `target_update_period` train steps, the weights from
        `q_network` are copied (possibly with smoothing via
        `target_update_tau`) to `target_q_network`.

        If `target_q_network` is not provided, it is created by
        making a copy of `q_network`, which initializes a new
        network with the same structure and its own layers and weights.

        Network copying is performed via the `Network.copy` superclass method,
        and may inadvertently lead to the resulting network to share weights
        with the original.  This can happen if, for example, the original
        network accepted a pre-built Keras layer in its `__init__`, or
        accepted a Keras layer that wasn't built, but neglected to create
        a new copy.

        In these cases, it is up to you to provide a target Network having
        weights that are not shared with the original `q_network`.
        If you provide a `target_q_network` that shares any
        weights with `q_network`, a warning will be logged but
        no exception is thrown.

        Note: shallow copies of Keras layers may be built via the code:

        ```python
        new_layer = type(layer).from_config(layer.get_config())
        ```
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn: A function for computing the TD errors loss. If None, a
        default value of element_wise_huber_loss is used. This function takes as
        input the target and the estimated Q values and returns the loss for
        each element of the batch.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      ValueError: If `action_spec` contains more than one action or action
        spec minimum is not equal to 0.
      ValueError: If the q networks do not emit floating point outputs with
        inner shape matching `action_spec`.
      NotImplementedError: If `q_network` has non-empty `state_spec` (i.e., an
        RNN is provided) and `n_step_update > 1`.
    """
    tf.Module.__init__(self, name=name)

    self._check_action_spec(action_spec)

    if epsilon_greedy is not None and boltzmann_temperature is not None:
      raise ValueError(
          'Configured both epsilon_greedy value {} and temperature {}, '
          'however only one of them can be used for exploration.'.format(
              epsilon_greedy, boltzmann_temperature))

    self._observation_and_action_constraint_splitter = (
        observation_and_action_constraint_splitter)
    self._q_network = q_network
    net_observation_spec = time_step_spec.observation
    if observation_and_action_constraint_splitter:
      net_observation_spec, _ = observation_and_action_constraint_splitter(
          net_observation_spec)
    q_network.create_variables(net_observation_spec)
    if target_q_network:
      target_q_network.create_variables(net_observation_spec)
    self._target_q_network = common.maybe_copy_target_network_with_checks(
        self._q_network, target_q_network, input_spec=net_observation_spec,
        name='TargetQNetwork')

    self._check_network_output(self._q_network, 'q_network')
    self._check_network_output(self._target_q_network, 'target_q_network')

    self._epsilon_greedy = epsilon_greedy
    self._n_step_update = n_step_update
    self._boltzmann_temperature = boltzmann_temperature
    self._optimizer = optimizer
    self._td_errors_loss_fn = (
        td_errors_loss_fn or common.element_wise_huber_loss)
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._gradient_clipping = gradient_clipping
    self._update_target = self._get_target_updater(
        target_update_tau, target_update_period)

    policy, collect_policy = self._setup_policy(time_step_spec, action_spec,
                                                boltzmann_temperature,
                                                emit_log_probability)

    if q_network.state_spec and n_step_update != 1:
      raise NotImplementedError(
          'DqnAgent does not currently support n-step updates with stateful '
          'networks (i.e., RNNs), but n_step_update = {}'.format(n_step_update))

    train_sequence_length = (
        n_step_update + 1 if not q_network.state_spec else None)

    super(DqnAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy,
        collect_policy,
        train_sequence_length=train_sequence_length,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter,
        validate_args=False,
    )

    if q_network.state_spec:
      # AsNStepTransition does not support emitting [B, T, ...] tensors,
      # which we need for DQN-RNN.
      self._as_transition = data_converter.AsTransition(
          self.data_context, squeeze_time_dim=False)
    else:
      # This reduces the n-step return and removes the extra time dimension,
      # allowing the rest of the computations to be independent of the
      # n-step parameter.
      self._as_transition = data_converter.AsNStepTransition(
          self.data_context, gamma=gamma, n=n_step_update)
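
Because this docstring spends most of its space on `observation_and_action_constraint_splitter`, the sketch below builds a stock `DqnAgent` over a dict observation that carries a valid-action mask. The spec names, shapes, and layer sizes are made up for illustration.

```python
import tensorflow as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.networks import q_network
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

num_actions = 4
observation_spec = {
    'network_input': tensor_spec.TensorSpec([8], tf.float32),
    'constraint': tensor_spec.BoundedTensorSpec(
        [num_actions], tf.int32, minimum=0, maximum=1),
}
time_step_spec = ts.time_step_spec(observation_spec)
action_spec = tensor_spec.BoundedTensorSpec(
    (), tf.int32, minimum=0, maximum=num_actions - 1)

def splitter(observation):
  # Route the mask to the policy and the rest to the Q-network.
  return observation['network_input'], observation['constraint']

q_net = q_network.QNetwork(
    observation_spec['network_input'], action_spec, fc_layer_params=(64,))

agent = dqn_agent.DqnAgent(
    time_step_spec,
    action_spec,
    q_network=q_net,
    optimizer=tf.keras.optimizers.Adam(1e-3),
    observation_and_action_constraint_splitter=splitter,
    epsilon_greedy=0.1,
    gamma=0.99,
    train_step_counter=tf.Variable(0))
agent.initialize()
```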
Example 6
    def __init__(
            self,
            time_step_spec,
            action_spec,
            q_network,
            optimizer,
            actions_sampler,
            epsilon_greedy=0.1,
            n_step_update=1,
            emit_log_probability=False,
            in_graph_bellman_update=True,
            # Params for cem
            init_mean_cem=None,
            init_var_cem=None,
            num_samples_cem=32,
            num_elites_cem=4,
            num_iter_cem=3,
            # Params for target network updates
            target_q_network=None,
            target_update_tau=1.0,
            target_update_period=1,
            enable_td3=True,
            target_q_network_delayed=None,
            target_q_network_delayed_2=None,
            delayed_target_update_period=5,
            # Params for training.
            td_errors_loss_fn=None,
            auxiliary_loss_fns=None,
            gamma=1.0,
            reward_scale_factor=1.0,
            gradient_clipping=None,
            # Params for debugging
            debug_summaries=False,
            summarize_grads_and_vars=False,
            train_step_counter=None,
            info_spec=None,
            name=None):
        """Creates a Qtopt Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      q_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call((observation, action), step_type). The
        q_network is different from the one used in DQN where the input is state
        and the output has multiple dimension representing Q values for
        different actions. The input of this q_network is a tuple of state and
        action. The output is one dimension representing Q value for that
        specific action. DDPG critic network can be used directly here.
      optimizer: The optimizer to use for training.
      actions_sampler: A tf_agents.policies.sampler.ActionsSampler to be used to
        sample actions in CEM.
      epsilon_greedy: probability of choosing a random action in the default
        epsilon-greedy collect policy (used only if a wrapper is not provided to
        the collect_policy method).
      n_step_update: Currently, only n_step_update == 1 is supported.
      emit_log_probability: Whether policies emit log probabilities or not.
      in_graph_bellman_update: If False, configures the agent to expect
        experience containing computed q_values in the policy_step's info field.
        This simplifies splitting the loss calculation across several
        jobs.
      init_mean_cem: Initial mean value of the Gaussian distribution to sample
        actions for CEM.
      init_var_cem: Initial variance value of the Gaussian distribution to
        sample actions for CEM.
      num_samples_cem: Number of samples to sample for each iteration in CEM.
      num_elites_cem: Number of elites to select for each iteration in CEM.
      num_iter_cem: Number of iterations in CEM.
      target_q_network: (Optional.)  A `tf_agents.network.Network`
        to be used as the target network during Q learning.  Every
        `target_update_period` train steps, the weights from
        `q_network` are copied (possibly with smoothing via
        `target_update_tau`) to `target_q_network`.

        If `target_q_network` is not provided, it is created by
        making a copy of `q_network`, which initializes a new
        network with the same structure and its own layers and weights.

        Network copying is performed via the `Network.copy` superclass method,
        with the same arguments used during the original network's construction
        and may inadvertently lead to weights being shared between networks.
        This can happen if, for example, the original
        network accepted a pre-built Keras layer in its `__init__`, or
        accepted a Keras layer that wasn't built, but neglected to create
        a new copy.

        In these cases, it is up to you to provide a target Network having
        weights that are not shared with the original `q_network`.
        If you provide a `target_q_network` that shares any
        weights with `q_network`, an exception is thrown.

      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      enable_td3: Whether or not to enable using a delayed target network to
        calculate q value and assign min(q_delayed, q_delayed_2) as
        q_next_state.
      target_q_network_delayed: (Optional.) Similar network as
        target_q_network but lags behind even more. See documentation
        for target_q_network. Will only be used if 'enable_td3' is True.
      target_q_network_delayed_2: (Optional.) Similar network as
        target_q_network_delayed but lags behind even more. See documentation
        for target_q_network. Will only be used if 'enable_td3' is True.
      delayed_target_update_period: Used when enable_td3 is true. Period for
        soft update of the delayed target networks.
      td_errors_loss_fn: A function for computing the TD errors loss. If None, a
        default value of element_wise_huber_loss is used. This function takes as
        input the target and the estimated Q values and returns the loss for
        each element of the batch.
      auxiliary_loss_fns: An optional list of functions for computing auxiliary
        losses. Each auxiliary_loss_fn expects network and transition as
        input and should output auxiliary_loss and auxiliary_reg_loss.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      info_spec: If not None, the policy info spec is set to this spec.
      name: The name of this agent. All variables in this module will fall under
        that name. Defaults to the class name.

    Raises:
      ValueError: If the action spec contains more than one action or action
        spec minimum is not equal to 0.
      NotImplementedError: If `q_network` has non-empty `state_spec` (i.e., an
        RNN is provided) and `n_step_update > 1`.
    """
        tf.Module.__init__(self, name=name)

        self._sampler = actions_sampler
        self._init_mean_cem = init_mean_cem
        self._init_var_cem = init_var_cem
        self._num_samples_cem = num_samples_cem
        self._num_elites_cem = num_elites_cem
        self._num_iter_cem = num_iter_cem
        self._in_graph_bellman_update = in_graph_bellman_update
        if not in_graph_bellman_update:
            if info_spec is not None:
                self._info_spec = info_spec
            else:
                self._info_spec = {
                    'target_q': tensor_spec.TensorSpec((), tf.float32),
                }
        else:
            self._info_spec = ()

        self._q_network = q_network
        net_observation_spec = (time_step_spec.observation, action_spec)

        q_network.create_variables(net_observation_spec)

        if target_q_network:
            target_q_network.create_variables(net_observation_spec)

        self._target_q_network = common.maybe_copy_target_network_with_checks(
            self._q_network,
            target_q_network,
            input_spec=net_observation_spec,
            name='TargetQNetwork')

        self._target_updater = self._get_target_updater(
            target_update_tau, target_update_period)

        self._enable_td3 = enable_td3

        if (not self._enable_td3
                and (target_q_network_delayed or target_q_network_delayed_2)):
            raise ValueError(
                'enable_td3 is set to False but target_q_network_delayed'
                ' or target_q_network_delayed_2 is passed.')

        if self._enable_td3:
            if target_q_network_delayed:
                target_q_network_delayed.create_variables()
            self._target_q_network_delayed = (
                common.maybe_copy_target_network_with_checks(
                    self._q_network, target_q_network_delayed,
                    'TargetQNetworkDelayed'))
            self._target_updater_delayed = self._get_target_updater_delayed(
                1.0, delayed_target_update_period)

            if target_q_network_delayed_2:
                target_q_network_delayed_2.create_variables()
            self._target_q_network_delayed_2 = (
                common.maybe_copy_target_network_with_checks(
                    self._q_network, target_q_network_delayed_2,
                    'TargetQNetworkDelayed2'))
            self._target_updater_delayed_2 = self._get_target_updater_delayed_2(
                1.0, delayed_target_update_period)

            self._update_target = self._update_both
        else:
            self._update_target = self._target_updater
            self._target_q_network_delayed = None
            self._target_q_network_delayed_2 = None

        self._check_network_output(self._q_network, 'q_network')
        self._check_network_output(self._target_q_network, 'target_q_network')

        self._epsilon_greedy = epsilon_greedy
        self._n_step_update = n_step_update
        self._optimizer = optimizer
        self._td_errors_loss_fn = (td_errors_loss_fn
                                   or common.element_wise_huber_loss)
        self._auxiliary_loss_fns = auxiliary_loss_fns
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._gradient_clipping = gradient_clipping

        policy, collect_policy = self._setup_policy(time_step_spec,
                                                    action_spec,
                                                    emit_log_probability)

        if q_network.state_spec and n_step_update != 1:
            raise NotImplementedError(
                'QtOptAgent does not currently support n-step updates with stateful '
                'networks (i.e., RNNs), but n_step_update = {}'.format(
                    n_step_update))

        # Bypass the train_sequence_length check when RNN is used.
        train_sequence_length = (
            n_step_update + 1 if not q_network.state_spec else None)

        super(QtOptAgent, self).__init__(
            time_step_spec,
            action_spec,
            policy,
            collect_policy,
            train_sequence_length=train_sequence_length,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=train_step_counter,
        )

        self._setup_data_converter(q_network, gamma, n_step_update)
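
A hedged sketch for the QtOpt agent above. The CEM `actions_sampler` and the agent class itself are assumed to be available elsewhere (the snippet shows only the `__init__`, and the sampler classes live in the experimental QtOpt code); as the docstring notes, a DDPG-style critic can serve as the (state, action) Q-network. Specs and values are illustrative.

```python
import tensorflow as tf
from tf_agents.agents.ddpg import critic_network
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

obs_spec = tensor_spec.TensorSpec([10], tf.float32)
act_spec = tensor_spec.BoundedTensorSpec(
    [2], tf.float32, minimum=-1.0, maximum=1.0)
time_step_spec = ts.time_step_spec(obs_spec)

q_net = critic_network.CriticNetwork(
    (obs_spec, act_spec), joint_fc_layer_params=(256, 256))

# Assumed available: `sampler`, an ActionsSampler implementation (see the
# docstring above), and `QtOptAgent`, the class whose __init__ is shown above.
agent = QtOptAgent(
    time_step_spec,
    act_spec,
    q_network=q_net,
    optimizer=tf.keras.optimizers.Adam(1e-3),
    actions_sampler=sampler,
    num_samples_cem=32,
    num_elites_cem=4,
    num_iter_cem=3,
    enable_td3=True,
    delayed_target_update_period=5,
    gamma=0.99,
    train_step_counter=tf.Variable(0))
agent.initialize()
```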
Example 7
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 actor_network,
                 critic_network,
                 actor_optimizer=None,
                 critic_optimizer=None,
                 ou_stddev=1.0,
                 ou_damping=1.0,
                 target_actor_network=None,
                 target_critic_network=None,
                 target_update_tau=1.0,
                 target_update_period=1,
                 dqda_clipping=None,
                 td_errors_loss_fn=None,
                 gamma=1.0,
                 reward_scale_factor=1.0,
                 gradient_clipping=None,
                 debug_summaries=False,
                 summarize_grads_and_vars=False,
                 train_step_counter=None,
                 name=None):
        """Creates a DDPG Agent.
    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      actor_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call(observation, step_type[, policy_state])
        and should return (action, new_state).
      critic_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call((observation, action), step_type[,
        policy_state]) and should return (q_value, new_state).
      actor_optimizer: The optimizer to use for the actor network.
      critic_optimizer: The optimizer to use for the critic network.
      ou_stddev: Standard deviation for the Ornstein-Uhlenbeck (OU) noise added
        in the default collect policy.
      ou_damping: Damping factor for the OU noise added in the default collect
        policy.
      target_actor_network: (Optional.)  A `tf_agents.network.Network` to be
        used as the actor target network during Q learning.  Every
        `target_update_period` train steps, the weights from `actor_network` are
        copied (possibly with smoothing via `target_update_tau`) to
        `target_actor_network`.
        If `target_actor_network` is not provided, it is created by making a
        copy of `actor_network`, which initializes a new network with the same
        structure and its own layers and weights.
        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or
        when the network is sharing layers with another).  In these cases, it is
        up to you to build a copy having weights that are not
        shared with the original `actor_network`, so that this can be used as a
        target network.  If you provide a `target_actor_network` that shares any
        weights with `actor_network`, a warning will be logged but no exception
        is thrown.
      target_critic_network: (Optional.) Similar network as target_actor_network
         but for the critic_network. See documentation for target_actor_network.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      dqda_clipping: when computing the actor loss, clips the gradient dqda
        element-wise between [-dqda_clipping, dqda_clipping]. Does not perform
        clipping if dqda_clipping == 0.
      td_errors_loss_fn:  A function for computing the TD errors loss. If None,
        a default value of elementwise huber_loss is used.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall
        under that name. Defaults to the class name.
    """
        tf.Module.__init__(self, name=name)
        self._actor_network = actor_network
        actor_network.create_variables()
        if target_actor_network:
            target_actor_network.create_variables()
        self._target_actor_network = common.maybe_copy_target_network_with_checks(
            self._actor_network, target_actor_network, 'TargetActorNetwork')
        self._critic_network = critic_network
        critic_network.create_variables()
        if target_critic_network:
            target_critic_network.create_variables()
        self._target_critic_network = common.maybe_copy_target_network_with_checks(
            self._critic_network, target_critic_network, 'TargetCriticNetwork')

        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer

        self._ou_stddev = ou_stddev
        self._ou_damping = ou_damping
        self._target_update_tau = target_update_tau
        self._target_update_period = target_update_period
        self._dqda_clipping = dqda_clipping
        self._td_errors_loss_fn = (td_errors_loss_fn
                                   or common.element_wise_huber_loss)
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._gradient_clipping = gradient_clipping

        self._update_target = self._get_target_updater(target_update_tau,
                                                       target_update_period)
        """Nitty: change time_step_spec to that of individual agent from total spec"""
        individual_time_step_spec = ts.get_individual_time_step_spec(
            time_step_spec)
        policy = actor_policy.ActorPolicy(
            time_step_spec=individual_time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            clip=True)
        collect_policy = actor_policy.ActorPolicy(
            time_step_spec=individual_time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            clip=False)

        # policy = actor_policy.ActorPolicy(
        #     time_step_spec=time_step_spec, action_spec=action_spec,
        #     actor_network=self._actor_network, clip=True)
        # collect_policy = actor_policy.ActorPolicy(
        #     time_step_spec=time_step_spec, action_spec=action_spec,
        #     actor_network=self._actor_network, clip=False)
        collect_policy = ou_noise_policy.OUNoisePolicy(
            collect_policy,
            ou_stddev=self._ou_stddev,
            ou_damping=self._ou_damping,
            clip=True)

        super(DdpgAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy,
                             collect_policy,
                             train_sequence_length=2
                             if not self._actor_network.state_spec else None,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter)
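The `target_update_tau` / `target_update_period` arguments above configure the soft ("Polyak") update performed by `self._update_target`. A minimal sketch of that idea, assuming ordinary TF variables; the helper name `soft_update` is illustrative and not part of the agent's API:

```python
import tensorflow as tf

def soft_update(source_vars, target_vars, tau=0.005):
    """Move each target variable a fraction `tau` toward its source variable."""
    for source, target in zip(source_vars, target_vars):
        target.assign((1.0 - tau) * target + tau * source)

# With tau = 1.0 this reduces to a hard copy of the online weights every
# `target_update_period` train steps.
```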
Example n. 8
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 q_networks,
                 critic_optimizer,
                 exploration_noise_std=0.1,
                 boltzmann_temperature=10.0,
                 epsilon_greedy=0.1,
                 q_networks_2=None,
                 target_q_networks=None,
                 target_q_networks_2=None,
                 target_update_tau=1.0,
                 target_update_period=1,
                 dqda_clipping=None,
                 td_errors_loss_fn=None,
                 gamma=1.0,
                 reward_scale_factor=1.0,
                 target_policy_noise=0.2,
                 target_policy_noise_clip=0.5,
                 gradient_clipping=None,
                 debug_summaries=False,
                 summarize_grads_and_vars=False,
                 train_step_counter=None,
                 action_params_mask=None,
                 n_step_update=1,
                 name=None):
        """Creates a Td3Agent Agent.

        Args:
          time_step_spec: A `TimeStep` spec of the expected time_steps.
          action_spec: A namedtuple of nested BoundedTensorSpec representing the
            actions.
          q_networks: A list of `tf_agents.network.Network` (or a single network)
            to be used by the agent. Each network will be called with
            call(observation, action, step_type).
          critic_optimizer: The default optimizer to use for the critic network.
          exploration_noise_std: Scale factor on exploration policy noise.
          q_networks_2: (Optional.)  A `tf_agents.network.Network` to be used as
            the second critic network during Q learning.  The weights from
            `q_network` are copied if this is not provided.
          target_q_networks: (Optional.)  A list of `tf_agents.network.Network` to
            be used as the target Q networks during Q learning. Every
            `target_update_period` train steps, the weights from `q_networks` are
            copied (possibly with smoothing via `target_update_tau`) to
            `target_q_networks`.  If `target_q_networks` is not provided, it is
            created by making a copy of `q_networks`, which initializes a new
            network with the same structure and its own layers and weights.
            Performing a `Network.copy` does not work when the network instance
            already has trainable parameters (e.g., has already been built, or when
            the network is sharing layers with another).  In these cases, it is up
            to you to build a copy having weights that are not shared with the
            original `q_networks`, so that this can be used as a target network.
            If you provide `target_q_networks` that share any weights with
            `q_networks`, a warning will be logged but no exception is thrown.
          target_q_networks_2: (Optional.) Similar networks as target_q_networks
            but for `q_networks_2`. See documentation for target_q_networks.
          target_update_tau: Factor for soft update of the target networks.
          target_update_period: Period for soft update of the target networks.
          dqda_clipping: A scalar or float; clips the gradient dqda element-wise
            between [-dqda_clipping, dqda_clipping]. Default is None, meaning no
            clipping.
          td_errors_loss_fn:  A function for computing the TD errors loss. If None,
            a default value of elementwise huber_loss is used.
          gamma: A discount factor for future rewards.
          reward_scale_factor: Multiplicative scale for the reward.
          target_policy_noise: Scale factor on target action noise
          target_policy_noise_clip: Value to clip noise.
          gradient_clipping: Norm length to clip gradients.
          debug_summaries: A bool to gather debug summaries.
          summarize_grads_and_vars: If True, gradient and network variable summaries
            will be written during training.
          train_step_counter: An optional counter to increment every time the train
            op is run.  Defaults to the global_step.
          action_params_mask: A mask selecting the continuous parameter actions
            associated with each discrete action.
          name: The name of this agent. All variables in this module will fall
            under that name. Defaults to the class name.
        """
        tf.Module.__init__(self, name=name)

        # critic network here is Q-network
        if not isinstance(q_networks, (list, tuple)):
            q_networks = (q_networks,)
        self._q_network_1 = q_networks

        if target_q_networks is not None and not isinstance(target_q_networks, (list, tuple)):
            target_q_networks = (target_q_networks,)
        if target_q_networks is None:
            target_q_networks = [None]*len(self._q_network_1)
        assert len(self._q_network_1) == len(target_q_networks)

        self._target_q_network_1 = [
            common.maybe_copy_target_network_with_checks(
                q, target_q, 'Target' + q.name + '_1')
            for q, target_q in zip(self._q_network_1, target_q_networks)]

        if q_networks_2 is not None:
            self._q_network_2 = q_networks_2
        else:
            self._q_network_2 = [q.copy(name=q.name+'_2') for q in q_networks]
            # Do not use target_q_network_2 if q_network_2 is None.
            target_q_networks_2 = None

        if target_q_networks_2 is not None and not isinstance(target_q_networks_2, (list, tuple)):
            target_q_networks_2 = (target_q_networks_2,)
        if target_q_networks_2 is None:
            target_q_networks_2 = [None]*len(self._q_network_2)

        self._target_q_network_2 = [
            common.maybe_copy_target_network_with_checks(
                q, target_q, 'Target' + q.name + '_2')
            for q, target_q in zip(self._q_network_2, target_q_networks_2)]

        self._critic_optimizer = critic_optimizer

        self._exploration_noise_std = exploration_noise_std
        self._epsilon_greedy = epsilon_greedy
        self._boltzmann_temperature = boltzmann_temperature
        self._target_update_tau = target_update_tau
        self._target_update_period = target_update_period
        self._dqda_clipping = dqda_clipping
        self._td_errors_loss_fn = (
                td_errors_loss_fn or common.element_wise_huber_loss)
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._target_policy_noise = target_policy_noise
        self._target_policy_noise_clip = target_policy_noise_clip
        self._gradient_clipping = gradient_clipping

        self._update_target = self._get_target_updater(
            target_update_tau, target_update_period)

        num_networks = len(self._q_network_1)
        if len(action_spec) != num_networks or len(time_step_spec) != num_networks:
            raise ValueError(
                'The number of q_networks must match the number of action_specs '
                'and time_step_specs.')

        policies = []
        self._q_value_policies_1, self._q_value_policies_2 = [], []
        for i, q_network in enumerate(self._q_network_1):
            use_previous_action = i != 0
            policies.append(hetero_q_policy.HeteroQPolicy(time_step_spec=time_step_spec[i], action_spec=action_spec[i],
                              mixed_q_network=q_networks[i], func_arg_mask=action_params_mask[i],
                              use_previous_action=use_previous_action, name="HeteroQPolicy_1_"+str(i)))
            self._q_value_policies_2.append(
                        hetero_q_policy.HeteroQPolicy(time_step_spec=time_step_spec[i], action_spec=action_spec[i],
                              mixed_q_network=self._q_network_2[i], func_arg_mask=action_params_mask[i],
                              use_previous_action=use_previous_action, name="HeteroQPolicy_2_"+str(i)))
        self._q_value_policies_1 = policies

        collect_policies = [epsilon_boltzmann_policy.EpsilonBoltzmannPolicy(
            p, temperature=boltzmann_temperature, epsilon=self._epsilon_greedy, remove_neg_inf=True) for p in policies]

        collect_policy = sequential_policy.SequentialPolicy(collect_policies)

        policies = [greedy_policy.GreedyPolicy(p, remove_neg_inf=True) for p in policies]
        policy = sequential_policy.SequentialPolicy(policies)

        # Create self._target_greedy_policy in order to compute target Q-values in _compute_next_q_values.
        target_policies = []
        self._target_q_value_policies_1, self._target_q_value_policies_2 = [], []
        for i, q_network in enumerate(self._q_network_1):
            use_previous_action = i != 0
            target_policies.append(hetero_q_policy.HeteroQPolicy(
                              time_step_spec=time_step_spec[i], action_spec=action_spec[i],
                              mixed_q_network=self._target_q_network_1[i], func_arg_mask=action_params_mask[i],
                              use_previous_action=use_previous_action, name="TargetHeteroQPolicy_1"+str(i)))
            self._target_q_value_policies_2.append(hetero_q_policy.HeteroQPolicy(
                              time_step_spec=time_step_spec[i], action_spec=action_spec[i],
                              mixed_q_network=self._target_q_network_2[i], func_arg_mask=action_params_mask[i],
                              use_previous_action=use_previous_action, name="TargetHeteroQPolicy_2"+str(i)))
        self._target_q_value_policies_1 = target_policies

        self._target_q_value_policies_1 = [greedy_policy.GreedyPolicy(p, remove_neg_inf=True) for p in self._target_q_value_policies_1]
        self._target_q_value_policies_2 = [greedy_policy.GreedyPolicy(p, remove_neg_inf=True) for p in self._target_q_value_policies_2]

        target_policies = self._target_q_value_policies_1
        target_policy = sequential_policy.SequentialPolicy(target_policies)

        self._target_greedy_policies = target_policy

        self._action_params_mask = action_params_mask
        self._n_step_update = n_step_update

        super(Td3DqnAgent, self).__init__(
            policy.time_step_spec,
            policy.action_spec,
            policy,
            collect_policy,
            train_sequence_length=2 if not self._q_network_1[0].state_spec else None,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=train_step_counter)
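The two Q-network stacks (`q_networks` and `q_networks_2`) follow the TD3 idea of clipped double-Q learning. As a hedged illustration, not this agent's actual loss code, the bootstrap target is typically built from the minimum of the two target critics:

```python
import tensorflow as tf

def twin_q_target(next_q1, next_q2, rewards, discounts, gamma=0.99):
    """Compute a one-step TD target from two target-critic estimates."""
    next_q = tf.minimum(next_q1, next_q2)        # pessimistic estimate reduces overestimation
    return rewards + gamma * discounts * next_q  # bootstrap target
```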
Example n. 9
def train_eval(
        load_root_dir,
        env_load_fn=None,
        gym_env_wrappers=[],
        monitor=False,
        env_name=None,
        agent_class=None,
        train_metrics_callback=None,
        # SacAgent args
        actor_fc_layers=(256, 256),
        critic_joint_fc_layers=(256, 256),
        # Safety Critic training args
        safety_critic_joint_fc_layers=None,
        safety_critic_lr=3e-4,
        safety_critic_bias_init_val=None,
        safety_critic_kernel_scale=None,
        n_envs=None,
        target_safety=0.2,
        fail_weight=None,
        # Params for train
        num_global_steps=10000,
        batch_size=256,
        # Params for eval
        run_eval=False,
        eval_metrics=[],
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for summaries and logging
        train_checkpoint_interval=10000,
        summary_interval=1000,
        monitor_interval=5000,
        summaries_flush_secs=10,
        debug_summaries=False,
        seed=None):

    if isinstance(agent_class, str):
        assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(
            agent_class)
        agent_class = ALGOS.get(agent_class)

    train_ckpt_dir = osp.join(load_root_dir, 'train')
    rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

    py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)

    if monitor:
        vid_path = os.path.join(load_root_dir, 'rollouts')
        monitor_env_wrapper = misc.monitor_freq(1, vid_path)
        monitor_env = gym.make(env_name)
        for wrapper in gym_env_wrappers:
            monitor_env = wrapper(monitor_env)
        monitor_env = monitor_env_wrapper(monitor_env)
        # auto_reset must be False to ensure Monitor works correctly
        monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

    if run_eval:
        eval_dir = os.path.join(load_root_dir, 'eval')
        n_envs = n_envs or num_eval_episodes
        eval_summary_writer = tf.compat.v2.summary.create_file_writer(
            eval_dir, flush_millis=summaries_flush_secs * 1000)
        eval_metrics = [
            tf_metrics.AverageReturnMetric(prefix='EvalMetrics',
                                           buffer_size=num_eval_episodes,
                                           batch_size=n_envs),
            tf_metrics.AverageEpisodeLengthMetric(
                prefix='EvalMetrics',
                buffer_size=num_eval_episodes,
                batch_size=n_envs)
        ] + [
            tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
            for m in eval_metrics
        ]
        eval_tf_env = tf_py_environment.TFPyEnvironment(
            parallel_py_environment.ParallelPyEnvironment([
                lambda: env_load_fn(env_name,
                                    gym_env_wrappers=gym_env_wrappers)
            ] * n_envs))
        if seed:
            seeds = [seed * n_envs + i for i in range(n_envs)]
            try:
                eval_tf_env.pyenv.seed(seeds)
            except Exception:
                # Not every environment supports seeding; ignore if unsupported.
                pass

    global_step = tf.compat.v1.train.get_or_create_global_step()

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=agents.normal_projection_net)

    critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
        safety_critic_net = agents.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=critic_joint_fc_layers)
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               safety_critic_network=safety_critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)
    else:
        tf_agent = agent_class(time_step_spec,
                               action_spec,
                               actor_network=actor_net,
                               critic_network=critic_net,
                               train_step_counter=global_step,
                               debug_summaries=False)

    collect_data_spec = tf_agent.collect_data_spec
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=1000000)
    replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

    tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
    if agent_class in SAFETY_AGENTS:
        target_safety = target_safety or tf_agent._target_safety
    loaded_train_steps = global_step.numpy()
    logging.info("Loaded agent from %s trained for %d steps", train_ckpt_dir,
                 loaded_train_steps)
    global_step.assign(0)
    tf.summary.experimental.set_step(global_step)

    thresholds = [target_safety, 0.5]
    sc_metrics = [
        tf.keras.metrics.AUC(name='safety_critic_auc'),
        tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                        threshold=0.5),
        tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                       thresholds=thresholds),
        tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                        thresholds=thresholds),
        tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                       thresholds=thresholds),
        tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                        thresholds=thresholds)
    ]

    if seed:
        tf.compat.v1.set_random_seed(seed)

    summaries_flush_secs = 10
    timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
    offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
    config_saver = gin.tf.GinConfigSaverHook(offline_train_dir,
                                             summarize_config=True)
    tf.function(config_saver.after_create_session)()

    sc_summary_writer = tf.compat.v2.summary.create_file_writer(
        offline_train_dir, flush_millis=summaries_flush_secs * 1000)
    sc_summary_writer.set_as_default()

    if safety_critic_kernel_scale is not None:
        ki = tf.compat.v1.variance_scaling_initializer(
            scale=safety_critic_kernel_scale,
            mode='fan_in',
            distribution='truncated_normal')
    else:
        ki = tf.compat.v1.keras.initializers.VarianceScaling(
            scale=1. / 3., mode='fan_in', distribution='uniform')

    if safety_critic_bias_init_val is not None:
        bi = tf.constant_initializer(safety_critic_bias_init_val)
    else:
        bi = None
    sc_net_off = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=safety_critic_joint_fc_layers,
        kernel_initializer=ki,
        value_bias_initializer=bi,
        name='SafetyCriticOffline')
    sc_net_off.create_variables()
    target_sc_net_off = common.maybe_copy_target_network_with_checks(
        sc_net_off, None, 'TargetSafetyCriticNetwork')
    optimizer = tf.keras.optimizers.Adam(safety_critic_lr)
    sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
    sc_checkpointer = common.Checkpointer(
        ckpt_dir=sc_net_off_ckpt_dir,
        safety_critic=sc_net_off,
        target_safety_critic=target_sc_net_off,
        optimizer=optimizer,
        global_step=global_step,
        max_to_keep=5)
    sc_checkpointer.initialize_or_restore()

    resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
    eval_policy = agents.SafeActorPolicyRSVar(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_net,
        safety_critic_network=sc_net_off,
        safety_threshold=target_safety,
        resample_counter=resample_counter,
        training=True)

    dataset = replay_buffer.as_dataset(num_parallel_calls=3,
                                       num_steps=2,
                                       sample_batch_size=batch_size //
                                       2).prefetch(3)
    data = iter(dataset)
    full_data = replay_buffer.gather_all()

    fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
    fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
    init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
    before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
    after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
    before_fail_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
    after_init_step = nest_utils.fast_map_structure(
        lambda *x: tf.boolean_mask(*x, after_init_mask), full_data)

    filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask))
    filter_mask = tf.pad(
        filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]])
    n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy()

    failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec,
        batch_size=1,
        max_length=n_failures,
        dataset_window_shift=1)
    data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask)

    sc_dataset_neg = failure_buffer.as_dataset(num_parallel_calls=3,
                                               sample_batch_size=batch_size //
                                               2,
                                               num_steps=2).prefetch(3)
    neg_data = iter(sc_dataset_neg)

    get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]
    eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step,
                                after_init_step, get_action)

    losses = []
    mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss')
    target_update = train_utils.get_target_updater(sc_net_off,
                                                   target_sc_net_off)

    with tf.summary.record_if(
            lambda: tf.math.equal(global_step % summary_interval, 0)):
        while global_step.numpy() < num_global_steps:
            pos_experience, _ = next(data)
            neg_experience, _ = next(neg_data)
            exp = data_utils.concat_batches(pos_experience, neg_experience,
                                            collect_data_spec)
            boundary_mask = tf.logical_not(exp.is_boundary()[:, 0])
            exp = nest_utils.fast_map_structure(
                lambda *x: tf.boolean_mask(*x, boundary_mask), exp)
            safe_rew = exp.observation['task_agn_rew'][:, 1]
            if fail_weight:
                weights = tf.where(tf.cast(safe_rew, tf.bool),
                                   fail_weight / 0.5, (1 - fail_weight) / 0.5)
            else:
                weights = None
            train_loss, sc_loss, lam_loss = train_step(
                exp,
                safe_rew,
                tf_agent,
                sc_net=sc_net_off,
                target_sc_net=target_sc_net_off,
                metrics=sc_metrics,
                weights=weights,
                target_safety=target_safety,
                optimizer=optimizer,
                target_update=target_update,
                debug_summaries=debug_summaries)
            global_step.assign_add(1)
            global_step_val = global_step.numpy()
            losses.append(
                (train_loss.numpy(), sc_loss.numpy(), lam_loss.numpy()))
            mean_loss(train_loss)
            with tf.name_scope('Losses'):
                tf.compat.v2.summary.scalar(name='sc_loss',
                                            data=sc_loss,
                                            step=global_step_val)
                tf.compat.v2.summary.scalar(name='lam_loss',
                                            data=lam_loss,
                                            step=global_step_val)
                if global_step_val % summary_interval == 0:
                    tf.compat.v2.summary.scalar(name=mean_loss.name,
                                                data=mean_loss.result(),
                                                step=global_step_val)
            if global_step_val % summary_interval == 0:
                with tf.name_scope('Metrics'):
                    for metric in sc_metrics:
                        if len(tf.squeeze(metric.result()).shape) == 0:
                            tf.compat.v2.summary.scalar(name=metric.name,
                                                        data=metric.result(),
                                                        step=global_step_val)
                        else:
                            fmt_str = '_{}'.format(thresholds[0])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[0],
                                step=global_step_val)
                            fmt_str = '_{}'.format(thresholds[1])
                            tf.compat.v2.summary.scalar(
                                name=metric.name + fmt_str,
                                data=metric.result()[1],
                                step=global_step_val)
                        metric.reset_states()
            if global_step_val % eval_interval == 0:
                eval_sc(sc_net_off, step=global_step_val)
                if run_eval:
                    results = metric_utils.eager_compute(
                        eval_metrics,
                        eval_tf_env,
                        eval_policy,
                        num_episodes=num_eval_episodes,
                        train_step=global_step,
                        summary_writer=eval_summary_writer,
                        summary_prefix='EvalMetrics',
                    )
                    if train_metrics_callback is not None:
                        train_metrics_callback(results, global_step_val)
                    metric_utils.log_metrics(eval_metrics)
                    with eval_summary_writer.as_default():
                        for eval_metric in eval_metrics[2:]:
                            eval_metric.tf_summaries(
                                train_step=global_step,
                                step_metrics=eval_metrics[:2])
            if monitor and global_step_val % monitor_interval == 0:
                monitor_time_step = monitor_py_env.reset()
                monitor_policy_state = eval_policy.get_initial_state(1)
                ep_len = 0
                monitor_start = time.time()
                while not monitor_time_step.is_last():
                    monitor_action = eval_policy.action(
                        monitor_time_step, monitor_policy_state)
                    action, monitor_policy_state = monitor_action.action, monitor_action.state
                    monitor_time_step = monitor_py_env.step(action)
                    ep_len += 1
                logging.debug(
                    'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                    global_step_val, ep_len,
                    time.time() - monitor_start)

            if global_step_val % train_checkpoint_interval == 0:
                sc_checkpointer.save(global_step=global_step_val)
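The failure-filtering logic above uses `tf.roll` to mark the transition immediately preceding each failure. A tiny self-contained illustration; the shapes and values here are made up for the example:

```python
import tensorflow as tf

fail_mask = tf.constant([[False, False, True, False, False, True]])
# Shift the mask one step earlier along the time axis (axis=1). Note that
# tf.roll wraps around, so the first entry ends up in the last slot.
before_fail_mask = tf.roll(fail_mask, shift=[-1], axis=[1])
# before_fail_mask == [[False, True, False, False, True, False]]
```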
Example n. 10
  def __init__(
      self,
      time_step_spec,
      action_spec,
      q_network,
      optimizer,
      epsilon_greedy=0.1,
      n_step_update=1,
      boltzmann_temperature=None,
      emit_log_probability=False,
      # Params for target network updates
      target_q_network=None,
      target_update_tau=1.0,
      target_update_period=1,
      # Params for training.
      td_errors_loss_fn=None,
      gamma=1.0,
      reward_scale_factor=1.0,
      gradient_clipping=None,
      # Params for debugging
      debug_summaries=False,
      summarize_grads_and_vars=False,
      train_step_counter=None,
      name=None):
    """Creates a DQN Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      q_network: A `tf_agents.network.Network` to be used by the agent. The
        network will be called with `call(observation, step_type)` and should
        emit logits over the action space.
      optimizer: The optimizer to use for training.
      epsilon_greedy: probability of choosing a random action in the default
        epsilon-greedy collect policy (used only if a wrapper is not provided to
        the collect_policy method).
      n_step_update: The number of steps to consider when computing TD error and
        TD loss. Defaults to single-step updates. Note that this requires the
        user to call train on Trajectory objects with a time dimension of
        `n_step_update + 1`. However, note that we do not yet support
        `n_step_update > 1` in the case of RNNs (i.e., non-empty
        `q_network.state_spec`).
      boltzmann_temperature: Temperature value to use for Boltzmann sampling of
        the actions during data collection. The closer to 0.0, the higher the
        probability of choosing the best action.
      emit_log_probability: Whether policies emit log probabilities or not.
      target_q_network: (Optional.)  A `tf_agents.network.Network`
        to be used as the target network during Q learning.  Every
        `target_update_period` train steps, the weights from
        `q_network` are copied (possibly with smoothing via
        `target_update_tau`) to `target_q_network`.

        If `target_q_network` is not provided, it is created by
        making a copy of `q_network`, which initializes a new
        network with the same structure and its own layers and weights.

        Network copying is performed via the `Network.copy` superclass method,
        and may inadvertently lead to the resulting network to share weights
        with the original.  This can happen if, for example, the original
        network accepted a pre-built Keras layer in its `__init__`, or
        accepted a Keras layer that wasn't built, but neglected to create
        a new copy.

        In these cases, it is up to you to provide a target Network having
        weights that are not shared with the original `q_network`.
        If you provide a `target_q_network` that shares any
        weights with `q_network`, a warning will be logged but
        no exception is thrown.

        Note: shallow copies of Keras layers may be built via the code:

        ```python
        new_layer = type(layer).from_config(layer.get_config())
        ```
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn: A function for computing the TD errors loss. If None, a
        default value of element_wise_huber_loss is used. This function takes as
        input the target and the estimated Q values and returns the loss for
        each element of the batch.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall
        under that name. Defaults to the class name.

    Raises:
      ValueError: If the action spec contains more than one action or action
        spec minimum is not equal to 0.
      NotImplementedError: If `q_network` has non-empty `state_spec` (i.e., an
        RNN is provided) and `n_step_update > 1`.
    """
    tf.Module.__init__(self, name=name)

    self._check_action_spec(action_spec)

    if epsilon_greedy is not None and boltzmann_temperature is not None:
      raise ValueError(
          'Configured both epsilon_greedy value {} and temperature {}, '
          'however only one of them can be used for exploration.'.format(
              epsilon_greedy, boltzmann_temperature))

    self._q_network = q_network
    self._target_q_network = common.maybe_copy_target_network_with_checks(
        self._q_network, target_q_network, 'TargetQNetwork')

    self._epsilon_greedy = epsilon_greedy
    self._n_step_update = n_step_update
    self._boltzmann_temperature = boltzmann_temperature
    self._optimizer = optimizer
    self._td_errors_loss_fn = td_errors_loss_fn or common.element_wise_huber_loss
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._gradient_clipping = gradient_clipping
    self._update_target = self._get_target_updater(
        target_update_tau, target_update_period)

    policy, collect_policy = self._setup_policy(time_step_spec, action_spec,
                                                boltzmann_temperature,
                                                emit_log_probability)

    if q_network.state_spec and n_step_update != 1:
      raise NotImplementedError(
          'DqnAgent does not currently support n-step updates with stateful '
          'networks (i.e., RNNs), but n_step_update = {}'.format(n_step_update))

    train_sequence_length = (
        n_step_update + 1 if not q_network.state_spec else None)

    super(DqnAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy,
        collect_policy,
        train_sequence_length=train_sequence_length,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter)
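The `n_step_update` argument determines how many rewards are folded into each TD target, which is why training requires trajectories of length `n_step_update + 1`. A hedged sketch of the corresponding n-step bootstrap target; the function and argument names are illustrative only:

```python
import tensorflow as tf

def n_step_target(rewards, discounts, next_q, gamma=0.99):
    """rewards/discounts: [B, n] tensors; next_q: [B] Q-value at the n-th next state."""
    n = rewards.shape[1]                          # assumes a static time dimension
    target = next_q
    for t in reversed(range(n)):                  # fold back from step n-1 to step 0
        target = rewards[:, t] + gamma * discounts[:, t] * target
    return target
```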
Example n. 11
    def __init__(self,
                 time_step_spec,
                 action_spec,
                 critic_network,
                 actor_network,
                 actor_optimizer,
                 critic_optimizer,
                 actor_loss_weight=1.0,
                 critic_loss_weight=0.5,
                 actor_policy_ctor=actor_policy.ActorPolicy,
                 critic_network_2=None,
                 target_critic_network=None,
                 target_critic_network_2=None,
                 target_update_tau=1.0,
                 target_update_period=1,
                 td_errors_loss_fn=tf.math.squared_difference,
                 gamma=1.0,
                 gradient_clipping=None,
                 debug_summaries=False,
                 summarize_grads_and_vars=False,
                 train_step_counter=None,
                 name=None):
        """Creates a C-learning Agent.

    By default, the environment observation contains the current state and goal
    state. By setting the obs_to_goal gin config in c_learning_utils, the user
    can specify that the agent should only look at certain subsets of the goal
    state.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      critic_network: A function critic_network((observations, actions)) that
        returns the q_values for each observation and action.
      actor_network: A function actor_network(observation, action_spec) that
        returns action distribution.
      actor_optimizer: The optimizer to use for the actor network.
      critic_optimizer: The default optimizer to use for the critic network.
      actor_loss_weight: The weight on actor loss.
      critic_loss_weight: The weight on critic loss.
      actor_policy_ctor: The policy class to use.
      critic_network_2: (Optional.)  A `tf_agents.network.Network` to be used as
        the second critic network during Q learning.  The weights from
        `critic_network` are copied if this is not provided.
      target_critic_network: (Optional.)  A `tf_agents.network.Network` to be
        used as the target critic network during Q learning. Every
        `target_update_period` train steps, the weights from `critic_network`
        are copied (possibly with smoothing via `target_update_tau`) to
        `target_critic_network`.  If `target_critic_network` is not provided, it
        is created by making a copy of `critic_network`, which initializes a new
        network with the same structure and its own layers and weights.
        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or when
        the network is sharing layers with another).  In these cases, it is up
        to you to build a copy having weights that are not shared with the
        original `critic_network`, so that this can be used as a target network.
        If you provide a `target_critic_network` that shares any weights with
        `critic_network`, a warning will be logged but no exception is thrown.
      target_critic_network_2: (Optional.) Similar network as
        target_critic_network but for the critic_network_2. See documentation
        for target_critic_network. Will only be used if 'critic_network_2' is
        also specified.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn:  A function for computing the elementwise TD errors
        loss.
      gamma: A discount factor for future rewards.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall under
        that name. Defaults to the class name.
    """
        tf.Module.__init__(self, name=name)

        self._check_action_spec(action_spec)

        self._critic_network_1 = critic_network
        self._critic_network_1.create_variables()
        if target_critic_network:
            target_critic_network.create_variables()
        self._target_critic_network_1 = (
            common.maybe_copy_target_network_with_checks(
                self._critic_network_1, target_critic_network,
                'TargetCriticNetwork1'))

        if critic_network_2 is not None:
            self._critic_network_2 = critic_network_2
        else:
            self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
            # Do not use target_critic_network_2 if critic_network_2 is None.
            target_critic_network_2 = None
        self._critic_network_2.create_variables()
        if target_critic_network_2:
            target_critic_network_2.create_variables()
        self._target_critic_network_2 = (
            common.maybe_copy_target_network_with_checks(
                self._critic_network_2, target_critic_network_2,
                'TargetCriticNetwork2'))

        if actor_network:
            actor_network.create_variables()
        self._actor_network = actor_network

        policy = actor_policy_ctor(time_step_spec=time_step_spec,
                                   action_spec=action_spec,
                                   actor_network=self._actor_network,
                                   training=False)

        self._train_policy = actor_policy_ctor(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            training=True)

        self._target_update_tau = target_update_tau
        self._target_update_period = target_update_period
        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer
        self._actor_loss_weight = actor_loss_weight
        self._critic_loss_weight = critic_loss_weight
        self._td_errors_loss_fn = td_errors_loss_fn
        self._gamma = gamma
        self._gradient_clipping = gradient_clipping
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._update_target = self._get_target_updater(
            tau=self._target_update_tau, period=self._target_update_period)

        train_sequence_length = 2 if not critic_network.state_spec else None

        super(CLearningAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy=policy,
                             collect_policy=policy,
                             train_sequence_length=train_sequence_length,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter)
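The long caveat about `Network.copy` and shared weights in the docstring above boils down to checking whether two networks reuse the same variable objects. An illustrative helper, not part of TF-Agents, that detects the problematic case:

```python
def networks_share_weights(net_a, net_b):
    """Return True if any tf.Variable object is used by both networks."""
    ids_a = {id(v) for v in net_a.variables}
    return any(id(v) in ids_a for v in net_b.variables)
```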
Example n. 12
  def __init__(self,
               time_step_spec,
               action_spec,
               critic_networks,
               actor_network,
               actor_optimizer,
               critic_optimizers,
               alpha_optimizer,
               actor_policy_ctor=actor_policy.ActorPolicy,
               target_critic_networks=None,
               target_update_tau=0.001,
               target_update_period=1,
               td_errors_loss_fn=tf.math.squared_difference,
               gamma=1.0,
               reward_scale_factor=1.0,
               initial_log_alpha=0.0,
               target_entropy=None,
               gradient_clipping=None,
               debug_summaries=False,
               summarize_grads_and_vars=False,
               train_step_counter=None,
               percentile=0.3,
               name=None):
    """Creates a SAC Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      critic_networks: A list of critic networks; each is called as
        critic_network((observations, actions)) and returns the q_values for
        the given observations and actions.
      actor_network: A function actor_network(observation, action_spec) that
        returns action distribution.
      actor_optimizer: The optimizer to use for the actor network.
      critic_optimizers: The optimizers to use for the critic networks.
      alpha_optimizer: The default optimizer to use for the alpha variable.
      actor_policy_ctor: The policy class to use.
      target_critic_networks: (Optional.)  A list of `tf_agents.network.Network`
        to be used as the target critic networks during Q learning. Every
        `target_update_period` train steps, the weights from `critic_networks`
        are copied (possibly with smoothing via `target_update_tau`) to
        `target_critic_networks`.  If `target_critic_networks` is not provided,
        they are created by making copies of `critic_networks`, which initializes
        new networks with the same structure and their own layers and weights.
        Performing a `Network.copy` does not work when a network instance
        already has trainable parameters (e.g., has already been built, or when
        the network is sharing layers with another).  In these cases, it is up
        to you to build copies having weights that are not shared with the
        original `critic_networks`, so that they can be used as target networks.
        If you provide `target_critic_networks` that share any weights with
        `critic_networks`, a warning will be logged but no exception is thrown.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn:  A function for computing the elementwise TD errors
        loss.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      initial_log_alpha: Initial value for log_alpha.
      target_entropy: The target average policy entropy, for updating alpha. The
        default value is the negative of the total number of action dimensions.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall under
        that name. Defaults to the class name.
    """
    tf.Module.__init__(self, name=name)

    flat_action_spec = tf.nest.flatten(action_spec)
    for spec in flat_action_spec:
      if spec.dtype.is_integer:
        raise NotImplementedError(
            'SacAgent does not currently support discrete actions. '
            'Action spec: {}'.format(action_spec))

    self._critic_networks = critic_networks
    [cn.create_variables() for cn in self._critic_networks]
    if target_critic_networks:
      self._target_critic_networks = target_critic_networks
      [cn.create_variables() for cn in self._target_critic_networks]
    else:
      self._target_critic_networks = [None for _ in range(len(self._critic_networks))]
    self._target_critic_networks = [
        common.maybe_copy_target_network_with_checks(cn, tcn, 'TargetCriticNetwork{}'.format(i))
      for i, (tcn, cn) in enumerate(zip(self._target_critic_networks, self._critic_networks))]

    if actor_network:
      actor_network.create_variables()
    self._actor_network = actor_network

    policy = actor_policy_ctor(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network)

    self._train_policy = actor_policy_ctor(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network)

    self._log_alpha = common.create_variable(
        'initial_log_alpha',
        initial_value=initial_log_alpha,
        dtype=tf.float32,
        trainable=True)

    # If target_entropy was not passed, set it to negative of the total number
    # of action dimensions.
    if target_entropy is None:
      flat_action_spec = tf.nest.flatten(action_spec)
      target_entropy = -np.sum([
          np.product(single_spec.shape.as_list())
          for single_spec in flat_action_spec
      ])

    self._target_update_tau = target_update_tau
    self._target_update_period = target_update_period
    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizers
    self._alpha_optimizer = alpha_optimizer
    self._td_errors_loss_fn = td_errors_loss_fn
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._target_entropy = target_entropy
    self._gradient_clipping = gradient_clipping
    self._debug_summaries = debug_summaries
    self._summarize_grads_and_vars = summarize_grads_and_vars
    self._update_target = self._get_target_updater(
        tau=self._target_update_tau, period=self._target_update_period)
    self._percentile = percentile

    train_sequence_length = 2 if not critic_networks[0].state_spec else None

    super(EnsembleSacAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy=policy,
        collect_policy=policy,
        train_sequence_length=train_sequence_length,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter)
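The entropy temperature is stored as `log_alpha` rather than as `alpha` itself: exponentiating the trainable variable keeps the temperature strictly positive during gradient updates. A minimal sketch of that parameterization, with illustrative variable names:

```python
import tensorflow as tf

log_alpha = tf.Variable(0.0, trainable=True)   # corresponds to initial_log_alpha
alpha = tf.exp(log_alpha)                      # always > 0, no projection needed
```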
Example n. 13
    def __init__(self,
                 time_step_spec: ts.TimeStep,
                 action_spec: types.NestedTensorSpec,
                 critic_network: network.Network,
                 actor_network: network.Network,
                 actor_optimizer: types.Optimizer,
                 critic_optimizer: types.Optimizer,
                 alpha_optimizer: types.Optimizer,
                 actor_loss_weight: types.Float = 1.0,
                 critic_loss_weight: types.Float = 0.5,
                 alpha_loss_weight: types.Float = 1.0,
                 actor_policy_ctor: Callable[
                     ..., tf_policy.TFPolicy] = actor_policy.ActorPolicy,
                 critic_network_2: Optional[network.Network] = None,
                 target_critic_network: Optional[network.Network] = None,
                 target_critic_network_2: Optional[network.Network] = None,
                 target_update_tau: types.Float = 1.0,
                 target_update_period: types.Int = 1,
                 td_errors_loss_fn: types.LossFn = tf.math.squared_difference,
                 gamma: types.Float = 1.0,
                 sigma: types.Float = 0.9,
                 reward_scale_factor: types.Float = 1.0,
                 initial_log_alpha: types.Float = 0.0,
                 use_log_alpha_in_alpha_loss: bool = True,
                 target_entropy: Optional[types.Float] = None,
                 gradient_clipping: Optional[types.Float] = None,
                 debug_summaries: bool = False,
                 summarize_grads_and_vars: bool = False,
                 train_step_counter: Optional[tf.Variable] = None,
                 name: Optional[Text] = None):

        tf.Module.__init__(self, name=name)

        self._check_action_spec(action_spec)

        net_observation_spec = time_step_spec.observation
        critic_spec = (net_observation_spec, action_spec)

        self._critic_network_1 = critic_network

        if critic_network_2 is not None:
            self._critic_network_2 = critic_network_2
        else:
            self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
            # Do not use target_critic_network_2 if critic_network_2 is None.
            target_critic_network_2 = None

        # Wait until critic_network_2 has been copied from critic_network_1 before
        # creating variables on both.
        self._critic_network_1.create_variables(critic_spec)
        self._critic_network_2.create_variables(critic_spec)

        if target_critic_network:
            target_critic_network.create_variables(critic_spec)

        self._target_critic_network_1 = (
            common.maybe_copy_target_network_with_checks(
                self._critic_network_1,
                target_critic_network,
                input_spec=critic_spec,
                name='TargetCriticNetwork1'))

        if target_critic_network_2:
            target_critic_network_2.create_variables(critic_spec)
        self._target_critic_network_2 = (
            common.maybe_copy_target_network_with_checks(
                self._critic_network_2,
                target_critic_network_2,
                input_spec=critic_spec,
                name='TargetCriticNetwork2'))

        if actor_network:
            actor_network.create_variables(net_observation_spec)
        self._actor_network = actor_network

        policy = actor_policy_ctor(time_step_spec=time_step_spec,
                                   action_spec=action_spec,
                                   actor_network=self._actor_network,
                                   training=False)

        self._train_policy = actor_policy_ctor(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            training=True)

        self._log_alpha = common.create_variable(
            'initial_log_alpha',
            initial_value=initial_log_alpha,
            dtype=tf.float32,
            trainable=True)

        if target_entropy is None:
            target_entropy = self._get_default_target_entropy(action_spec)

        self._use_log_alpha_in_alpha_loss = use_log_alpha_in_alpha_loss
        self._target_update_tau = target_update_tau
        self._target_update_period = target_update_period
        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer
        self._alpha_optimizer = alpha_optimizer
        self._actor_loss_weight = actor_loss_weight
        self._critic_loss_weight = critic_loss_weight
        self._alpha_loss_weight = alpha_loss_weight
        self._td_errors_loss_fn = td_errors_loss_fn
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._target_entropy = target_entropy
        self._gradient_clipping = gradient_clipping
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._update_target = self._get_target_updater(
            tau=self._target_update_tau, period=self._target_update_period)

        self.sigma = sigma

        train_sequence_length = 2 if not critic_network.state_spec else None

        super(sac_agent.SacAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy=policy,
                             collect_policy=policy,
                             train_sequence_length=train_sequence_length,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter,
                             validate_args=False)

        self._as_transition = data_converter.AsTransition(
            self.data_context, squeeze_time_dim=(train_sequence_length == 2))
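When `target_entropy` is None, `_get_default_target_entropy` typically falls back to the negative of the total number of action dimensions, as spelled out explicitly in Example n. 12 above. A minimal sketch of that computation:

```python
import numpy as np
import tensorflow as tf

def default_target_entropy(action_spec):
    """Negative of the total number of action dimensions in a nested spec."""
    flat_specs = tf.nest.flatten(action_spec)
    return -np.sum([np.prod(s.shape.as_list()) for s in flat_specs])
```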
Example n. 14
    def __init__(
            self,
            time_step_spec,
            action_spec,
            q_network,
            optimizer,
            observation_and_action_constraint_splitter=None,
            epsilon_greedy=0.1,
            n_step_update=1,
            boltzmann_temperature=None,
            emit_log_probability=False,
            # Params for target network updates
            target_q_network=None,
            target_update_tau=1.0,
            target_update_period=1,
            td_errors_loss_fn=None,
            gamma=1.0,
            reward_scale_factor=1.0,
            gradient_clipping=None,
            # Params for debugging
            debug_summaries=False,
            summarize_grads_and_vars=False,
            train_step_counter=None,
            name=None):
        tf.Module.__init__(self, name=name)
        self._check_action_spec(action_spec)

        if epsilon_greedy is not None and boltzmann_temperature is not None:
            raise ValueError(
                'Configured both epsilon_greedy value {} and temperature {}, '
                'however only one of them can be used for exploration.'.format(
                    epsilon_greedy, boltzmann_temperature))

        self._observation_and_action_constraint_splitter = (
            observation_and_action_constraint_splitter)
        self._q_network = q_network
        q_network.create_variables()
        if target_q_network:
            target_q_network.create_variables()
        self._target_q_network = common.maybe_copy_target_network_with_checks(
            self._q_network, target_q_network, "TargetQNetwork")

        self._epsilon_greedy = epsilon_greedy
        self._n_step_update = n_step_update
        self._boltzmann_temperature = boltzmann_temperature
        self._optimizer = optimizer
        self._td_error_loss_fn = (td_errors_loss_fn
                                  or common.element_wise_huber_loss)
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._gradient_clipping = gradient_clipping
        self._update_target = self._get_target_updater(target_update_tau,
                                                       target_update_period)

        policy, collect_policy = self._setup_policy(time_step_spec,
                                                    action_spec,
                                                    boltzmann_temperature,
                                                    emit_log_probability)

        if q_network.state_spec and n_step_update != 1:
            raise NotImplementedError(
                'DqnAgent does not currently support n-step updates with stateful '
                'networks (i.e., RNNs), but n_step_update = {}'.format(
                    n_step_update))

        train_sequence_length = (
            n_step_update + 1 if not q_network.state_spec else None)

        super(DqnAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy,
                             collect_policy,
                             train_sequence_length=train_sequence_length,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter)
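This constructor mirrors the standard TF-Agents DqnAgent, so a minimal instantiation only needs an environment, a QNetwork, and an optimizer. A sketch under those assumptions (environment id and layer sizes are illustrative, and the class above is assumed to be importable as DqnAgent):

import tensorflow as tf
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks import q_network
from tf_agents.utils import common

env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
q_net = q_network.QNetwork(
    env.observation_spec(), env.action_spec(), fc_layer_params=(100,))

agent = DqnAgent(
    env.time_step_spec(),
    env.action_spec(),
    q_network=q_net,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    epsilon_greedy=0.1,
    n_step_update=1,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=tf.Variable(0))
agent.initialize()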
Example No. 15
def train_eval(
    root_dir,
    load_root_dir=None,
    env_load_fn=None,
    gym_env_wrappers=[],
    monitor=False,
    env_name=None,
    agent_class=None,
    initial_collect_driver_class=None,
    collect_driver_class=None,
    online_driver_class=dynamic_episode_driver.DynamicEpisodeDriver,
    num_global_steps=1000000,
    rb_size=None,
    train_steps_per_iteration=1,
    train_metrics=None,
    eval_metrics=None,
    train_metrics_callback=None,
    # SacAgent args
    actor_fc_layers=(256, 256),
    critic_joint_fc_layers=(256, 256),
    # Safety Critic training args
    sc_rb_size=None,
    target_safety=None,
    train_sc_steps=10,
    train_sc_interval=1000,
    online_critic=False,
    n_envs=None,
    finetune_sc=False,
    pretraining=True,
    lambda_schedule_nsteps=0,
    lambda_initial=0.,
    lambda_final=1.,
    kstep_fail=0,
    # Ensemble Critic training args
    num_critics=None,
    critic_learning_rate=3e-4,
    # Wcpg Critic args
    critic_preprocessing_layer_size=256,
    # Params for train
    batch_size=256,
    # Params for eval
    run_eval=False,
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    keep_rb_checkpoint=False,
    log_interval=1000,
    summary_interval=1000,
    monitor_interval=5000,
    summaries_flush_secs=10,
    early_termination_fn=None,
    debug_summaries=False,
    seed=None,
    eager_debug=False,
    env_metric_factories=None,
    wandb=False):  # pylint: disable=unused-argument
  """Train and eval script for SQRL."""
  if isinstance(agent_class, str):
    assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(agent_class)
    agent_class = ALGOS.get(agent_class)
  n_envs = n_envs or num_eval_episodes
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')

  # =====================================================================#
  #  Setup summary metrics, file writers, and create env                 #
  # =====================================================================#
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
    train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  train_metrics = train_metrics or []
  eval_metrics = eval_metrics or []

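  # The safety critic is only updated online when training from scratch or when fine-tuning is explicitly requested.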
  updating_sc = online_critic and (not load_root_dir or finetune_sc)
  logging.debug('updating safety critic: %s', updating_sc)

  if seed:
    tf.compat.v1.set_random_seed(seed)

  if agent_class in SAFETY_AGENTS:
    if online_critic:
      sc_tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
          [lambda: env_load_fn(env_name)] * n_envs
        ))
      if seed:
        seeds = [seed * n_envs + i for i in range(n_envs)]
        try:
          sc_tf_env.pyenv.seed(seeds)
        except Exception:  # not all environments support seeding
          pass

  if run_eval:
    eval_dir = os.path.join(root_dir, 'eval')
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
                     tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes, batch_size=n_envs),
                     tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes, batch_size=n_envs),
                   ] + [tf_py_metric.TFPyMetric(m) for m in eval_metrics]
    eval_tf_env = tf_py_environment.TFPyEnvironment(
      parallel_py_environment.ParallelPyEnvironment(
        [lambda: env_load_fn(env_name)] * n_envs
      ))
    if seed:
      try:
        for i, pyenv in enumerate(eval_tf_env.pyenv.envs):
          pyenv.seed(seed * n_envs + i)
      except Exception:  # not all environments support seeding
        pass
  elif 'Drunk' in env_name:
    # Just visualizes trajectories in drunk spider environment
    eval_tf_env = tf_py_environment.TFPyEnvironment(
      env_load_fn(env_name))
  else:
    eval_tf_env = None

  if monitor:
    vid_path = os.path.join(root_dir, 'rollouts')
    monitor_env_wrapper = misc.monitor_freq(1, vid_path)
    monitor_env = gym.make(env_name)
    for wrapper in gym_env_wrappers:
      monitor_env = wrapper(monitor_env)
    monitor_env = monitor_env_wrapper(monitor_env)
    # auto_reset must be False to ensure Monitor works correctly
    monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

  global_step = tf.compat.v1.train.get_or_create_global_step()

  with tf.summary.record_if(
          lambda: tf.math.equal(global_step % summary_interval, 0)):
    py_env = env_load_fn(env_name)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    if seed:
      try:
        for i, pyenv in enumerate(tf_env.pyenv.envs):
          pyenv.seed(seed * n_envs + i)
      except Exception:  # not all environments support seeding
        pass
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    logging.debug('obs spec: %s', observation_spec)
    logging.debug('action spec: %s', action_spec)

    # =====================================================================#
    #  Setup agent class                                                   #
    # =====================================================================#

    if agent_class == wcpg_agent.WcpgAgent:
      alpha_spec = tensor_spec.BoundedTensorSpec(shape=(1,), dtype=tf.float32, minimum=0., maximum=1.,
                                                 name='alpha')
      input_tensor_spec = (observation_spec, action_spec, alpha_spec)
      critic_net = agents.DistributionalCriticNetwork(
        input_tensor_spec, preprocessing_layer_size=critic_preprocessing_layer_size,
        joint_fc_layer_params=critic_joint_fc_layers)
      actor_net = agents.WcpgActorNetwork((observation_spec, alpha_spec), action_spec)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=agents.normal_projection_net)
      critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
      logging.debug('Making SQRL agent')
      if lambda_schedule_nsteps > 0:
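        # Step the Lagrange multiplier lambda from lambda_initial toward lambda_final on a fixed schedule during training.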
        lambda_update_every_nsteps = num_global_steps // lambda_schedule_nsteps
        step_size = (lambda_final - lambda_initial) / lambda_update_every_nsteps
        lambda_scheduler = lambda lam: common.periodically(
          body=lambda: tf.group(lam.assign(lam + step_size)),
          period=lambda_update_every_nsteps)
      else:
        lambda_scheduler = None
      safety_critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)
      ts = target_safety
      thresholds = [ts, 0.5]
      sc_metrics = [tf.keras.metrics.AUC(name='safety_critic_auc'),
                    tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                                   thresholds=thresholds),
                    tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                                    thresholds=thresholds),
                    tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                                   thresholds=thresholds),
                    tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                                    thresholds=thresholds),
                    tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                                    threshold=0.5)]
      tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        safety_critic_network=safety_critic_net,
        train_step_counter=global_step,
        debug_summaries=debug_summaries,
        safety_pretraining=pretraining,
        train_critic_online=online_critic,
        initial_log_lambda=lambda_initial,
        log_lambda=(lambda_scheduler is None),
        lambda_scheduler=lambda_scheduler)
    elif agent_class is ensemble_sac_agent.EnsembleSacAgent:
      critic_nets, critic_optimizers = [critic_net], [tf.keras.optimizers.Adam(critic_learning_rate)]
      for _ in range(num_critics - 1):
        critic_nets.append(agents.CriticNetwork((observation_spec, action_spec),
                                                joint_fc_layer_params=critic_joint_fc_layers))
        critic_optimizers.append(tf.keras.optimizers.Adam(critic_learning_rate))
      tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_networks=critic_nets,
        critic_optimizers=critic_optimizers,
        debug_summaries=debug_summaries
      )
    else:  # agent is either SacAgent or WcpgAgent
      logging.debug('critic input_tensor_spec: %s', critic_net.input_tensor_spec)
      tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        train_step_counter=global_step,
        debug_summaries=debug_summaries)

    tf_agent.initialize()

    # =====================================================================#
    #  Setup replay buffer                                                 #
    # =====================================================================#
    collect_data_spec = tf_agent.collect_data_spec

    logging.debug('Allocating replay buffer ...')
    # Add to replay buffer and other agent specific observers.
    rb_size = rb_size or 1000000
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      collect_data_spec,
      batch_size=1,
      max_length=rb_size)

    logging.debug('RB capacity: %i', replay_buffer.capacity)
    logging.debug('ReplayBuffer Collect data spec: %s', collect_data_spec)

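    # Safety agents additionally keep a smaller buffer of recent experience for safety-critic training.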
    if agent_class in SAFETY_AGENTS:
      sc_rb_size = sc_rb_size or num_eval_episodes * 500
      sc_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=sc_rb_size,
        dataset_window_shift=1)

    num_episodes = tf_metrics.NumberOfEpisodes()
    num_env_steps = tf_metrics.EnvironmentSteps()
    return_metric = tf_metrics.AverageReturnMetric(
      buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
    train_metrics = [
                      num_episodes, num_env_steps,
                      return_metric,
                      tf_metrics.AverageEpisodeLengthMetric(
                        buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
                    ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    if 'Minitaur' in env_name and not pretraining:
      goal_vel = gin.query_parameter("%GOAL_VELOCITY")
      early_termination_fn = train_utils.MinitaurTerminationFn(
        speed_metric=train_metrics[-2], total_falls_metric=train_metrics[-3],
        env_steps_metric=num_env_steps, goal_speed=goal_vel)

    if env_metric_factories:
      for env_metric in env_metric_factories:
        train_metrics.append(tf_py_metric.TFPyMetric(env_metric(tf_env.pyenv.envs)))
        if run_eval:
          eval_metrics.append(env_metric([env for env in
                                          eval_tf_env.pyenv._envs]))

    # =====================================================================#
    #  Setup collect policies                                              #
    # =====================================================================#
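    # During safety pretraining the standard collect policy gathers data; afterwards the action-resampling safe policy is used.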
    if not online_critic:
      eval_policy = tf_agent.policy
      collect_policy = tf_agent.collect_policy
      if not pretraining and agent_class in SAFETY_AGENTS:
        collect_policy = tf_agent.safe_policy
    else:
      eval_policy = tf_agent.collect_policy if pretraining else tf_agent.safe_policy
      collect_policy = tf_agent.collect_policy if pretraining else tf_agent.safe_policy
      online_collect_policy = tf_agent.safe_policy  # if pretraining else tf_agent.collect_policy
      if pretraining:
        online_collect_policy._training = False

    if not load_root_dir:
      initial_collect_policy = random_tf_policy.RandomTFPolicy(time_step_spec, action_spec)
    else:
      initial_collect_policy = collect_policy
    if agent_class == wcpg_agent.WcpgAgent:
      initial_collect_policy = agents.WcpgPolicyWrapper(initial_collect_policy)

    # =====================================================================#
    #  Setup Checkpointing                                                 #
    # =====================================================================#
    train_checkpointer = common.Checkpointer(
      ckpt_dir=train_dir,
      agent=tf_agent,
      global_step=global_step,
      metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
      ckpt_dir=os.path.join(train_dir, 'policy'),
      policy=eval_policy,
      global_step=global_step)

    rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
    rb_checkpointer = common.Checkpointer(
      ckpt_dir=rb_ckpt_dir, max_to_keep=1, replay_buffer=replay_buffer)

    if online_critic:
      online_rb_ckpt_dir = os.path.join(train_dir, 'online_replay_buffer')
      online_rb_checkpointer = common.Checkpointer(
        ckpt_dir=online_rb_ckpt_dir,
        max_to_keep=1,
        replay_buffer=sc_buffer)

    # loads agent, replay buffer, and online sc/buffer if online_critic
    if load_root_dir:
      load_root_dir = os.path.expanduser(load_root_dir)
      load_train_dir = os.path.join(load_root_dir, 'train')
      misc.load_agent_ckpt(load_train_dir, tf_agent)
      if len(os.listdir(os.path.join(load_train_dir, 'replay_buffer'))) > 1:
        load_rb_ckpt_dir = os.path.join(load_train_dir, 'replay_buffer')
        misc.load_rb_ckpt(load_rb_ckpt_dir, replay_buffer)
      if online_critic:
        load_online_sc_ckpt_dir = os.path.join(load_root_dir, 'sc')
        load_online_rb_ckpt_dir = os.path.join(load_train_dir,
                                               'online_replay_buffer')
        if osp.exists(load_online_rb_ckpt_dir):
          misc.load_rb_ckpt(load_online_rb_ckpt_dir, sc_buffer)
        if osp.exists(load_online_sc_ckpt_dir):
          misc.load_safety_critic_ckpt(load_online_sc_ckpt_dir,
                                       safety_critic_net)
      elif agent_class in SAFETY_AGENTS:
        offline_run = sorted(os.listdir(os.path.join(load_train_dir, 'offline')))[-1]
        load_sc_ckpt_dir = os.path.join(load_train_dir, 'offline',
                                        offline_run, 'safety_critic')
        if osp.exists(load_sc_ckpt_dir):
          sc_net_off = agents.CriticNetwork(
            (observation_spec, action_spec),
            joint_fc_layer_params=(512, 512),
            name='SafetyCriticOffline')
          sc_net_off.create_variables()
          target_sc_net_off = common.maybe_copy_target_network_with_checks(
            sc_net_off, None, 'TargetSafetyCriticNetwork')
          sc_optimizer = tf.keras.optimizers.Adam(critic_learning_rate)
          _ = misc.load_safety_critic_ckpt(
            load_sc_ckpt_dir, safety_critic_net=sc_net_off,
            target_safety_critic=target_sc_net_off,
            optimizer=sc_optimizer)
          tf_agent._safety_critic_network = sc_net_off
          tf_agent._target_safety_critic_network = target_sc_net_off
          tf_agent._safety_critic_optimizer = sc_optimizer
    else:
      train_checkpointer.initialize_or_restore()
      rb_checkpointer.initialize_or_restore()
      if online_critic:
        online_rb_checkpointer.initialize_or_restore()

    if agent_class in SAFETY_AGENTS:
      sc_dir = os.path.join(root_dir, 'sc')
      safety_critic_checkpointer = common.Checkpointer(
        ckpt_dir=sc_dir,
        safety_critic=tf_agent._safety_critic_network,
        # pylint: disable=protected-access
        target_safety_critic=tf_agent._target_safety_critic_network,
        optimizer=tf_agent._safety_critic_optimizer,
        global_step=global_step)

      if not load_root_dir or online_critic:
        safety_critic_checkpointer.initialize_or_restore()

    agent_observers = [replay_buffer.add_batch] + train_metrics
    collect_driver = collect_driver_class(
      tf_env, collect_policy, observers=agent_observers)
    collect_driver.run = common.function_in_tf1()(collect_driver.run)

    if online_critic:
      logging.debug('online driver class: %s', online_driver_class)
      online_agent_observers = [num_episodes, num_env_steps,
                                sc_buffer.add_batch]
      online_driver = online_driver_class(
        sc_tf_env, online_collect_policy, observers=online_agent_observers,
        num_episodes=num_eval_episodes)
      online_driver.run = common.function_in_tf1()(online_driver.run)

    if eager_debug:
      tf.config.experimental_run_functions_eagerly(True)
    else:
      config_saver = gin.tf.GinConfigSaverHook(train_dir, summarize_config=True)
      tf.function(config_saver.after_create_session)()

    if global_step == 0:
      logging.info('Performing initial collection ...')
      init_collect_observers = agent_observers
      if agent_class in SAFETY_AGENTS:
        init_collect_observers += [sc_buffer.add_batch]
      initial_collect_driver_class(
        tf_env,
        initial_collect_policy,
        observers=init_collect_observers).run()
      last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
      logging.info('Data saved after initial collection: %d steps', last_id)
      if agent_class in SAFETY_AGENTS:
        last_id = sc_buffer._get_last_id()  # pylint: disable=protected-access
        logging.debug('Data saved in sc_buffer after initial collection: %d steps', last_id)

    if run_eval:
      results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='EvalMetrics',
      )
      if train_metrics_callback is not None:
        train_metrics_callback(results, global_step.numpy())
      metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    train_step = train_utils.get_train_step(tf_agent, replay_buffer, batch_size)

    if agent_class in SAFETY_AGENTS:
      critic_train_step = train_utils.get_critic_train_step(
        tf_agent, replay_buffer, sc_buffer, batch_size=batch_size,
        updating_sc=updating_sc, metrics=sc_metrics)

    if early_termination_fn is None:
      early_termination_fn = lambda: False

    loss_diverged = False
    # How many consecutive steps was loss diverged for.
    loss_divergence_counter = 0
    mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')

    if agent_class in SAFETY_AGENTS:
      resample_counter = collect_policy._resample_counter
      mean_resample_ac = tf.keras.metrics.Mean(name='mean_unsafe_ac_freq')
      sc_metrics.append(mean_resample_ac)

      if online_critic:
        logging.debug('starting safety critic pretraining')
        # don't fine-tune safety critic
        if global_step.numpy() == 0:
          for _ in range(train_sc_steps):
            sc_loss, lambda_loss = critic_train_step()
          critic_results = [('sc_loss', sc_loss.numpy()), ('lambda_loss', lambda_loss.numpy())]
          for critic_metric in sc_metrics:
            res = critic_metric.result().numpy()
            if not res.shape:
              critic_results.append((critic_metric.name, res))
            else:
              for r, thresh in zip(res, thresholds):
                name = '_'.join([critic_metric.name, str(thresh)])
                critic_results.append((name, r))
            critic_metric.reset_states()
          if train_metrics_callback:
            train_metrics_callback(collections.OrderedDict(critic_results),
                                   step=global_step.numpy())

    logging.debug('Starting main train loop...')
    curr_ep = []
    global_step_val = global_step.numpy()
    while global_step_val <= num_global_steps and not early_termination_fn():
      start_time = time.time()

      # MEASURE ACTION RESAMPLING FREQUENCY
      if agent_class in SAFETY_AGENTS:
        if pretraining and global_step_val == num_global_steps // 2:
          if online_critic:
            online_collect_policy._training = True
          collect_policy._training = True
        if online_critic or collect_policy._training:
          mean_resample_ac(resample_counter.result())
          resample_counter.reset()
          if time_step is None or time_step.is_last():
            resample_ac_freq = mean_resample_ac.result()
            mean_resample_ac.reset_states()
            tf.compat.v2.summary.scalar(
              name='resample_ac_freq', data=resample_ac_freq, step=global_step)

      # RUN COLLECTION
      time_step, policy_state = collect_driver.run(
        time_step=time_step,
        policy_state=policy_state,
      )

      # get last step taken by step_driver
      traj = replay_buffer._data_table.read(replay_buffer._get_last_id() %
                                            replay_buffer._capacity)
      curr_ep.append(traj)

      if time_step.is_last():
        if agent_class in SAFETY_AGENTS:
          if time_step.observation['task_agn_rew']:
            if kstep_fail:
              # apply the task-agnostic failure reward over the last k steps
              for traj in curr_ep[-kstep_fail:]:
                traj.observation['task_agn_rew'] = 1.
                sc_buffer.add_batch(traj)
            else:
              for traj in curr_ep:
                sc_buffer.add_batch(traj)
        curr_ep = []
        if agent_class == wcpg_agent.WcpgAgent:
          collect_policy._alpha = None  # reset WCPG alpha

      if (global_step_val + 1) % log_interval == 0:
        logging.debug('policy eval: %4.2f sec', time.time() - start_time)

      # PERFORMS TRAIN STEP ON ALGORITHM (OFF-POLICY)
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
        mean_train_loss(train_loss.loss)

      current_step = global_step.numpy()
      total_loss = mean_train_loss.result()
      mean_train_loss.reset_states()

      if train_metrics_callback and current_step % summary_interval == 0:
        train_metrics_callback(
          collections.OrderedDict([(k, v.numpy()) for k, v in
                                   train_loss.extra._asdict().items()]),
          step=current_step)
        train_metrics_callback(
          {'train_loss': total_loss.numpy()}, step=current_step)

      # TRAIN AND/OR EVAL SAFETY CRITIC
      if agent_class in SAFETY_AGENTS and current_step % train_sc_interval == 0:
        if online_critic:
          batch_time_step = sc_tf_env.reset()

          # run online critic training collect & update
          batch_policy_state = online_collect_policy.get_initial_state(
            sc_tf_env.batch_size)
          online_driver.run(time_step=batch_time_step,
                            policy_state=batch_policy_state)
        for _ in range(train_sc_steps):
          sc_loss, lambda_loss = critic_train_step()
        # log safety_critic loss results
        critic_results = [('sc_loss', sc_loss.numpy()),
                          ('lambda_loss', lambda_loss.numpy())]
        metric_utils.log_metrics(sc_metrics)
        for critic_metric in sc_metrics:
          res = critic_metric.result().numpy()
          if not res.shape:
            critic_results.append((critic_metric.name, res))
          else:
            for r, thresh in zip(res, thresholds):
              name = '_'.join([critic_metric.name, str(thresh)])
              critic_results.append((name, r))
          critic_metric.reset_states()
        if train_metrics_callback and current_step % summary_interval == 0:
          train_metrics_callback(collections.OrderedDict(critic_results),
                                 step=current_step)

      # Check for exploding losses.
      if (math.isnan(total_loss) or math.isinf(total_loss) or
              total_loss > MAX_LOSS):
        loss_divergence_counter += 1
        if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
          loss_diverged = True
          logging.info('Loss diverged, critic_loss: %s, actor_loss: %s',
                       train_loss.extra.critic_loss,
                       train_loss.extra.actor_loss)
          break
      else:
        loss_divergence_counter = 0

      time_acc += time.time() - start_time

      # LOGGING AND METRICS
      if current_step % log_interval == 0:
        metric_utils.log_metrics(train_metrics)
        logging.info('step = %d, loss = %f', current_step, total_loss)
        steps_per_sec = (current_step - timed_at_step) / time_acc
        logging.info('%4.2f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
          name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = current_step
        time_acc = 0

      train_results = []

      for metric in train_metrics[2:]:
        if isinstance(metric, (metrics.AverageEarlyFailureMetric,
                               metrics.AverageFallenMetric,
                               metrics.AverageSuccessMetric)):
          # Plot failure as a fn of return
          metric.tf_summaries(
            train_step=global_step, step_metrics=[num_env_steps, num_episodes,
                                                  return_metric])
        else:
          metric.tf_summaries(
            train_step=global_step, step_metrics=[num_env_steps, num_episodes])
        train_results.append((metric.name, metric.result().numpy()))

      if train_metrics_callback and current_step % summary_interval == 0:
        train_metrics_callback(collections.OrderedDict(train_results),
                               step=global_step.numpy())

      if current_step % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=current_step)

      if current_step % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=current_step)
        if agent_class in SAFETY_AGENTS:
          safety_critic_checkpointer.save(global_step=current_step)
          if online_critic:
            online_rb_checkpointer.save(global_step=current_step)

      if rb_checkpoint_interval and current_step % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=current_step)

      if wandb and current_step % eval_interval == 0 and "Drunk" in env_name:
        misc.record_point_mass_episode(eval_tf_env, eval_policy, current_step)
        if online_critic:
          misc.record_point_mass_episode(eval_tf_env, tf_agent.safe_policy,
                                         current_step, 'safe-trajectory')

      if run_eval and current_step % eval_interval == 0:
        eval_results = metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix='EvalMetrics',
        )
        if train_metrics_callback is not None:
          train_metrics_callback(eval_results, current_step)
        metric_utils.log_metrics(eval_metrics)

        with eval_summary_writer.as_default():
          for eval_metric in eval_metrics[2:]:
            eval_metric.tf_summaries(train_step=global_step,
                                     step_metrics=eval_metrics[:2])

      if monitor and current_step % monitor_interval == 0:
        monitor_time_step = monitor_py_env.reset()
        monitor_policy_state = eval_policy.get_initial_state(1)
        ep_len = 0
        monitor_start = time.time()
        while not monitor_time_step.is_last():
          monitor_action = eval_policy.action(monitor_time_step, monitor_policy_state)
          action, monitor_policy_state = monitor_action.action, monitor_action.state
          monitor_time_step = monitor_py_env.step(action)
          ep_len += 1
        logging.debug('saved rollout at timestep %d, rollout length: %d, %4.2f sec',
                      current_step, ep_len, time.time() - monitor_start)

      global_step_val = current_step

  if early_termination_fn():
    #  Early stopped, save all checkpoints if not saved
    if global_step_val % train_checkpoint_interval != 0:
      train_checkpointer.save(global_step=global_step_val)

    if global_step_val % policy_checkpoint_interval != 0:
      policy_checkpointer.save(global_step=global_step_val)
      if agent_class in SAFETY_AGENTS:
        safety_critic_checkpointer.save(global_step=global_step_val)
        if online_critic:
          online_rb_checkpointer.save(global_step=global_step_val)

    if rb_checkpoint_interval and global_step_val % rb_checkpoint_interval != 0:
      rb_checkpointer.save(global_step=global_step_val)

  if not keep_rb_checkpoint:
    misc.cleanup_checkpoints(rb_ckpt_dir)

  if loss_diverged:
    # Raise an error at the very end after the cleanup.
    raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
      total_loss, global_step.numpy()))

  return total_loss
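Wiring up a run requires the project-specific pieces referenced above (ALGOS, the environment loader, driver classes). A hedged invocation sketch; the environment id, the loader, and the 'sqrl' key are placeholders rather than values confirmed by this snippet:

import functools
from tf_agents.drivers import dynamic_step_driver

train_eval(
    root_dir='/tmp/sqrl_run',
    env_name='DrunkSpiderPointMassEnv-v0',  # placeholder environment id
    env_load_fn=my_env_load_fn,             # project-specific loader (assumed)
    agent_class='sqrl',                     # resolved through ALGOS (assumed key)
    initial_collect_driver_class=functools.partial(
        dynamic_step_driver.DynamicStepDriver, num_steps=1000),
    collect_driver_class=functools.partial(
        dynamic_step_driver.DynamicStepDriver, num_steps=1),
    num_global_steps=100000,
    batch_size=256,
    target_safety=0.1,
    run_eval=True,
    num_eval_episodes=10,
    eval_interval=1000)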
Example No. 16
  def __init__(self,
               time_step_spec,
               action_spec,
               critic_network,
               actor_network,
               actor_optimizer,
               critic_optimizer,
               safety_critic_optimizer,
               alpha_optimizer,
               lambda_optimizer=None,
               lambda_scheduler=None,
               actor_policy_ctor=actor_policy.ActorPolicy,
               safety_critic_network=None,
               target_safety_critic_network=None,
               critic_network_2=None,
               target_critic_network=None,
               target_critic_network_2=None,
               target_update_tau=0.005,
               target_update_period=1,
               td_errors_loss_fn=tf.math.squared_difference,
               safe_td_errors_loss_fn=tf.keras.losses.binary_crossentropy,
               gamma=1.0,
               safety_gamma=None,
               reward_scale_factor=1.0,
               initial_log_alpha=0.0,
               initial_log_lambda=0.0,
               log_lambda=True,
               target_entropy=None,
               target_safety=0.1,
               gradient_clipping=None,
               debug_summaries=False,
               summarize_grads_and_vars=False,
               train_step_counter=None,
               safety_pretraining=True,
               train_critic_online=True,
               resample_counter=None,
               fail_weight=None,
               name="SqrlAgent"):
    self._safety_critic_network = safety_critic_network
    self._train_critic_online = train_critic_online
    super(SqrlAgent, self).__init__(
        time_step_spec=time_step_spec, action_spec=action_spec, critic_network=critic_network,
        actor_network=actor_network, actor_optimizer=actor_optimizer,
        critic_optimizer=critic_optimizer,
        alpha_optimizer=alpha_optimizer, critic_network_2=critic_network_2,
        target_critic_network=target_critic_network,
        target_critic_network_2=target_critic_network_2,
        actor_policy_ctor=actor_policy_ctor,
        target_update_tau=target_update_tau, target_update_period=target_update_period,
        td_errors_loss_fn=td_errors_loss_fn, gamma=gamma, reward_scale_factor=reward_scale_factor,
        initial_log_alpha=initial_log_alpha, target_entropy=target_entropy,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries, summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter, name=name)

    if safety_critic_network is not None:
      self._safety_critic_network.create_variables()
    else:
      self._safety_critic_network = common.maybe_copy_target_network_with_checks(
        critic_network, None, 'SafetyCriticNetwork')

    if target_safety_critic_network is not None:
      self._target_safety_critic_network = target_safety_critic_network
      self._target_safety_critic_network.create_variables()
    else:
      self._target_safety_critic_network = common.maybe_copy_target_network_with_checks(
        self._safety_critic_network, None, 'TargetSafetyCriticNetwork')

    lambda_name = 'initial_log_lambda' if log_lambda else 'initial_lambda'
    self._lambda_var = common.create_variable(
        lambda_name,
        initial_value=initial_log_lambda,
        dtype=tf.float32,
        trainable=True)
    self._target_safety = target_safety
    self._safe_policy = agents.SafeActorPolicyRSVar(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        safety_critic_network=self._safety_critic_network,
        safety_threshold=target_safety,
        resample_counter=resample_counter,
        training=True)
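    # The collect policy is constructed the same way, but with training disabled while safety pretraining is active.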
    self._collect_policy = agents.SafeActorPolicyRSVar(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        safety_critic_network=self._safety_critic_network,
        safety_threshold=target_safety,
        resample_counter=resample_counter,
        training=(not safety_pretraining))

    self._safety_critic_optimizer = safety_critic_optimizer
    self._lambda_optimizer = lambda_optimizer or alpha_optimizer
    if lambda_scheduler is None:
      self._lambda_scheduler = None
    else:
      self._lambda_scheduler = lambda_scheduler(self._lambda_var)
    self._safety_pretraining = safety_pretraining
    self._safe_td_errors_loss_fn = safe_td_errors_loss_fn
    self._safety_gamma = safety_gamma or self._gamma
    self._fail_weight = fail_weight
    if train_critic_online:
      self._update_target_safety_critic = self._get_target_updater_safety_critic(
        tau=self._target_update_tau, period=self._target_update_period)
    else:
      self._update_target_safety_critic = None
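Given the signature above, constructing the agent needs actor, critic, and safety-critic networks plus four optimizers. A minimal sketch, assuming SqrlAgent is importable and that observation_spec, action_spec, and time_step_spec come from a TF environment as in Example No. 15; layer sizes and learning rates are illustrative:

import tensorflow as tf
from tf_agents.agents.ddpg import critic_network
from tf_agents.networks import actor_distribution_network

actor_net = actor_distribution_network.ActorDistributionNetwork(
    observation_spec, action_spec, fc_layer_params=(256, 256))
critic_net = critic_network.CriticNetwork(
    (observation_spec, action_spec), joint_fc_layer_params=(256, 256))
safety_critic_net = critic_network.CriticNetwork(
    (observation_spec, action_spec), joint_fc_layer_params=(256, 256),
    name='SafetyCriticNetwork')

sqrl = SqrlAgent(
    time_step_spec,
    action_spec,
    critic_network=critic_net,
    actor_network=actor_net,
    actor_optimizer=tf.keras.optimizers.Adam(3e-4),
    critic_optimizer=tf.keras.optimizers.Adam(3e-4),
    safety_critic_optimizer=tf.keras.optimizers.Adam(3e-4),
    alpha_optimizer=tf.keras.optimizers.Adam(3e-4),
    safety_critic_network=safety_critic_net,
    target_safety=0.1,
    safety_pretraining=True)
sqrl.initialize()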