Example #1
    def __init__(self,
                 action_spec,
                 actor_network: Network,
                 critic_network: Network,
                 critic_loss=None,
                 target_entropy=None,
                 initial_log_alpha=0.0,
                 target_update_tau=0.05,
                 target_update_period=1,
                 dqda_clipping=None,
                 actor_optimizer=None,
                 critic_optimizer=None,
                 alpha_optimizer=None,
                 gradient_clipping=None,
                 train_step_counter=None,
                 debug_summaries=False,
                 name="SacAlgorithm"):
        """Create a SacAlgorithm

        Args:
            action_spec (nested BoundedTensorSpec): representing the actions.
            actor_network (Network): The network will be called with
                call(observation, step_type).
            critic_network (Network): The network will be called with
                call(observation, action, step_type).
            critic_loss (None|OneStepTDLoss): an object for calculating critic loss.
                If None, a default OneStepTDLoss will be used.
            target_entropy (float|None): The target average policy entropy, for updating alpha.
            initial_log_alpha (float): initial value for the variable log_alpha.
            target_update_tau (float): Factor for soft update of the target
                networks.
            target_update_period (int): Period for soft update of the target
                networks.
            dqda_clipping (float): when computing the actor loss, clips the
                gradient dqda element-wise between [-dqda_clipping, dqda_clipping].
                Does not perform clipping if dqda_clipping == 0.
            actor_optimizer (tf.optimizers.Optimizer): The optimizer for actor.
            critic_optimizer (tf.optimizers.Optimizer): The optimizer for critic.
            alpha_optimizer (tf.optimizers.Optimizer): The optimizer for alpha.
            gradient_clipping (float): Norm length to clip gradients.
            train_step_counter (tf.Variable): An optional counter to increment
                every time a new iteration is started. If None, it will use
                tf.summary.experimental.get_step(). If this is still None, a
                counter will be created.
            debug_summaries (bool): True if debug summaries should be created.
            name (str): The name of this algorithm.
        """
        critic_network1 = critic_network
        critic_network2 = critic_network.copy(name='CriticNetwork2')
        log_alpha = tfa_common.create_variable(name='log_alpha',
                                               initial_value=initial_log_alpha,
                                               dtype=tf.float32,
                                               trainable=True)
        super().__init__(
            action_spec,
            train_state_spec=SacState(
                share=SacShareState(actor=actor_network.state_spec),
                actor=SacActorState(critic1=critic_network.state_spec,
                                    critic2=critic_network.state_spec),
                critic=SacCriticState(
                    critic1=critic_network.state_spec,
                    critic2=critic_network.state_spec,
                    target_critic1=critic_network.state_spec,
                    target_critic2=critic_network.state_spec)),
            action_distribution_spec=actor_network.output_spec,
            predict_state_spec=actor_network.state_spec,
            optimizer=[actor_optimizer, critic_optimizer, alpha_optimizer],
            get_trainable_variables_func=[
                lambda: actor_network.trainable_variables,
                lambda: (critic_network1.trainable_variables +
                         critic_network2.trainable_variables),
                lambda: [log_alpha]
            ],
            gradient_clipping=gradient_clipping,
            train_step_counter=train_step_counter,
            debug_summaries=debug_summaries,
            name=name)

        self._log_alpha = log_alpha
        self._actor_network = actor_network
        self._critic_network1 = critic_network1
        self._critic_network2 = critic_network2
        self._target_critic_network1 = self._critic_network1.copy(
            name='TargetCriticNetwork1')
        self._target_critic_network2 = self._critic_network2.copy(
            name='TargetCriticNetwork2')
        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer
        self._alpha_optimizer = alpha_optimizer

        if critic_loss is None:
            critic_loss = OneStepTDLoss(debug_summaries=debug_summaries)
        self._critic_loss = critic_loss

        flat_action_spec = tf.nest.flatten(self._action_spec)
        self._is_continuous = tensor_spec.is_continuous(flat_action_spec[0])
        if target_entropy is None:
            target_entropy = np.sum(
                list(
                    map(dist_utils.calc_default_target_entropy,
                        flat_action_spec)))
        self._target_entropy = target_entropy

        self._dqda_clipping = dqda_clipping

        self._update_target = common.get_target_updater(
            models=[self._critic_network1, self._critic_network2],
            target_models=[
                self._target_critic_network1, self._target_critic_network2
            ],
            tau=target_update_tau,
            period=target_update_period)

        tfa_common.soft_variables_update(
            self._critic_network1.variables,
            self._target_critic_network1.variables,
            tau=1.0)

        tfa_common.soft_variables_update(
            self._critic_network2.variables,
            self._target_critic_network2.variables,
            tau=1.0)
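
The two `tfa_common.soft_variables_update` calls with `tau=1.0` initialize each target critic as an exact copy of its online critic. A standalone sketch of the soft-update rule they implement (the helper below is illustrative, not the TF-Agents function itself):

import tensorflow as tf

def soft_update(source_vars, target_vars, tau):
    # target <- tau * source + (1 - tau) * target; tau=1.0 copies source exactly.
    for s, t in zip(source_vars, target_vars):
        t.assign(tau * s + (1.0 - tau) * t)

source = [tf.Variable([1.0, 2.0])]
target = [tf.Variable([0.0, 0.0])]
soft_update(source, target, tau=1.0)   # target is now [1.0, 2.0]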
Example #2
    def __init__(self,
                 action_spec,
                 actor_network: Network,
                 critic_network: Network,
                 ou_stddev=0.2,
                 ou_damping=0.15,
                 critic_loss=None,
                 target_update_tau=0.05,
                 target_update_period=1,
                 dqda_clipping=None,
                 actor_optimizer=None,
                 critic_optimizer=None,
                 gradient_clipping=None,
                 train_step_counter=None,
                 debug_summaries=False,
                 name="DdpgAlgorithm"):
        """
        Args:
            action_spec (nested BoundedTensorSpec): representing the actions.
            actor_network (Network):  The network will be called with
                call(observation, step_type).
            critic_network (Network): The network will be called with
                call(observation, action, step_type).
            ou_stddev (float): Standard deviation for the Ornstein-Uhlenbeck
                (OU) noise added in the default collect policy.
            ou_damping (float): Damping factor for the OU noise added in the
                default collect policy.
            critic_loss (None|OneStepTDLoss): an object for calculating critic
                loss. If None, a default OneStepTDLoss will be used.
            target_update_tau (float): Factor for soft update of the target
                networks.
            target_update_period (int): Period for soft update of the target
                networks.
            dqda_clipping (float): when computing the actor loss, clips the
                gradient dqda element-wise between [-dqda_clipping, dqda_clipping].
                Does not perform clipping if dqda_clipping == 0.
            actor_optimizer (tf.optimizers.Optimizer): The optimizer for actor.
            critic_optimizer (tf.optimizers.Optimizer): The optimizer for critic.
            gradient_clipping (float): Norm length to clip gradients.
            train_step_counter (tf.Variable): An optional counter to increment
                every time a new iteration is started. If None, it will use
                tf.summary.experimental.get_step(). If this is still None, a
                counter will be created.
            debug_summaries (bool): True if debug summaries should be created.
            name (str): The name of this algorithm.
        """
        train_state_spec = DdpgState(
            actor=DdpgActorState(actor=actor_network.state_spec,
                                 critic=critic_network.state_spec),
            critic=DdpgCriticState(critic=critic_network.state_spec,
                                   target_actor=actor_network.state_spec,
                                   target_critic=critic_network.state_spec))

        super().__init__(action_spec,
                         train_state_spec=train_state_spec,
                         action_distribution_spec=action_spec,
                         optimizer=[actor_optimizer, critic_optimizer],
                         get_trainable_variables_func=[
                             lambda: actor_network.trainable_variables,
                             lambda: critic_network.trainable_variables
                         ],
                         gradient_clipping=gradient_clipping,
                         train_step_counter=train_step_counter,
                         debug_summaries=debug_summaries,
                         name=name)

        self._actor_network = actor_network
        self._critic_network = critic_network
        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer

        self._target_actor_network = actor_network.copy(
            name='target_actor_network')
        self._target_critic_network = critic_network.copy(
            name='target_critic_network')

        self._ou_stddev = ou_stddev
        self._ou_damping = ou_damping

        if critic_loss is None:
            critic_loss = OneStepTDLoss(debug_summaries=debug_summaries)
        self._critic_loss = critic_loss

        self._ou_process = self._create_ou_process(ou_stddev, ou_damping)

        self._update_target = common.get_target_updater(
            models=[self._actor_network, self._critic_network],
            target_models=[
                self._target_actor_network, self._target_critic_network
            ],
            tau=target_update_tau,
            period=target_update_period)

        self._dqda_clipping = dqda_clipping

        tfa_common.soft_variables_update(self._critic_network.variables,
                                         self._target_critic_network.variables,
                                         tau=1.0)
        tfa_common.soft_variables_update(self._actor_network.variables,
                                         self._target_actor_network.variables,
                                         tau=1.0)
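
The collect policy adds Ornstein-Uhlenbeck noise parameterized by `ou_stddev` and `ou_damping`. A minimal sketch of an OU process under the common discretization x_{t+1} = (1 - damping) * x_t + N(0, stddev); it illustrates the role of the two parameters, not the actual `_create_ou_process` implementation:

import tensorflow as tf

def ou_step(x, damping=0.15, stddev=0.2):
    # Mean-reverting noise: decays toward zero, perturbed by Gaussian noise.
    return (1.0 - damping) * x + tf.random.normal(tf.shape(x), stddev=stddev)

noise = tf.zeros([2])          # one noise value per action dimension
for _ in range(5):
    noise = ou_step(noise)     # temporally correlated exploration noise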
Example #3
  def __init__(self,
               time_step_spec: ts.TimeStep,
               action_spec: types.NestedTensor,
               actor_network: network.Network,
               critic_network: network.Network,
               actor_optimizer: types.Optimizer,
               critic_optimizer: types.Optimizer,
               exploration_noise_std: types.Float = 0.1,
               critic_network_2: Optional[network.Network] = None,
               target_actor_network: Optional[network.Network] = None,
               target_critic_network: Optional[network.Network] = None,
               target_critic_network_2: Optional[network.Network] = None,
               target_update_tau: types.Float = 1.0,
               target_update_period: types.Int = 1,
               actor_update_period: types.Int = 1,
               td_errors_loss_fn: Optional[types.LossFn] = None,
               gamma: types.Float = 1.0,
               reward_scale_factor: types.Float = 1.0,
               target_policy_noise: types.Float = 0.2,
               target_policy_noise_clip: types.Float = 0.5,
               gradient_clipping: Optional[types.Float] = None,
               debug_summaries: bool = False,
               summarize_grads_and_vars: bool = False,
               train_step_counter: Optional[tf.Variable] = None,
               name: Optional[Text] = None):
    """Creates a Td3Agent Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      actor_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call(observation, step_type).
      critic_network: A tf_agents.network.Network to be used by the agent. The
        network will be called with call(observation, action, step_type).
      actor_optimizer: The default optimizer to use for the actor network.
      critic_optimizer: The default optimizer to use for the critic network.
      exploration_noise_std: Scale factor on exploration policy noise.
      critic_network_2: (Optional.)  A `tf_agents.network.Network` to be used as
        the second critic network during Q learning.  The weights from
        `critic_network` are copied if this is not provided.
      target_actor_network: (Optional.)  A `tf_agents.network.Network` to be
        used as the target actor network during Q learning. Every
        `target_update_period` train steps, the weights from `actor_network` are
        copied (possibly with smoothing via `target_update_tau`) to
        `target_actor_network`.  If `target_actor_network` is not provided, it is
        created by making a copy of `actor_network`, which initializes a new
        network with the same structure and its own layers and weights.
        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or when
        the network is sharing layers with another).  In these cases, it is up
        to you to build a copy having weights that are not shared with the
        original `actor_network`, so that this can be used as a target network.
        If you provide a `target_actor_network` that shares any weights with
        `actor_network`, a warning will be logged but no exception is thrown.
      target_critic_network: (Optional.) Similar network as target_actor_network
        but for the critic_network. See documentation for target_actor_network.
      target_critic_network_2: (Optional.) Similar network as
        target_actor_network but for the critic_network_2. See documentation for
        target_actor_network. Will only be used if 'critic_network_2' is also
        specified.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      actor_update_period: Period for the optimization step on actor network.
      td_errors_loss_fn:  A function for computing the TD errors loss. If None,
        a default value of elementwise huber_loss is used.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      target_policy_noise: Scale factor on target action noise.
      target_policy_noise_clip: Value to clip noise.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall
        under that name. Defaults to the class name.
    """
    tf.Module.__init__(self, name=name)
    self._actor_network = actor_network
    actor_network.create_variables()
    if target_actor_network:
      target_actor_network.create_variables()
    self._target_actor_network = common.maybe_copy_target_network_with_checks(
        self._actor_network, target_actor_network, 'TargetActorNetwork')

    self._critic_network_1 = critic_network
    critic_network.create_variables()
    if target_critic_network:
      target_critic_network.create_variables()
    self._target_critic_network_1 = (
        common.maybe_copy_target_network_with_checks(self._critic_network_1,
                                                     target_critic_network,
                                                     'TargetCriticNetwork1'))

    if critic_network_2 is not None:
      self._critic_network_2 = critic_network_2
    else:
      self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
      # Do not use target_critic_network_2 if critic_network_2 is None.
      target_critic_network_2 = None
    self._critic_network_2.create_variables()
    if target_critic_network_2:
      target_critic_network_2.create_variables()
    self._target_critic_network_2 = (
        common.maybe_copy_target_network_with_checks(self._critic_network_2,
                                                     target_critic_network_2,
                                                     'TargetCriticNetwork2'))

    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer

    self._exploration_noise_std = exploration_noise_std
    self._target_update_tau = target_update_tau
    self._target_update_period = target_update_period
    self._actor_update_period = actor_update_period
    self._td_errors_loss_fn = (
        td_errors_loss_fn or common.element_wise_huber_loss)
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._target_policy_noise = target_policy_noise
    self._target_policy_noise_clip = target_policy_noise_clip
    self._gradient_clipping = gradient_clipping

    self._update_target = self._get_target_updater(
        target_update_tau, target_update_period)

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec, action_spec=action_spec,
        actor_network=self._actor_network, clip=True)
    collect_policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec, action_spec=action_spec,
        actor_network=self._actor_network, clip=False)
    collect_policy = gaussian_policy.GaussianPolicy(
        collect_policy,
        scale=self._exploration_noise_std,
        clip=True)

    train_sequence_length = 2 if not self._actor_network.state_spec else None
    super(Td3Agent, self).__init__(
        time_step_spec,
        action_spec,
        policy,
        collect_policy,
        train_sequence_length=train_sequence_length,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter,
        validate_args=False
    )

    self._as_transition = data_converter.AsTransition(
        self.data_context, squeeze_time_dim=(train_sequence_length == 2))
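
A minimal construction sketch for an agent with this constructor, using standard TF-Agents networks; the observation/action specs, layer sizes, and learning rates below are illustrative assumptions, not values taken from the snippet:

import tensorflow as tf
from tf_agents.agents.ddpg import actor_network, critic_network
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

observation_spec = tf.TensorSpec((4,), tf.float32)
action_spec = tensor_spec.BoundedTensorSpec((1,), tf.float32, minimum=-1.0, maximum=1.0)
time_step_spec = ts.time_step_spec(observation_spec)

actor_net = actor_network.ActorNetwork(
    observation_spec, action_spec, fc_layer_params=(256, 256))
critic_net = critic_network.CriticNetwork(
    (observation_spec, action_spec), joint_fc_layer_params=(256, 256))

agent = Td3Agent(
    time_step_spec,
    action_spec,
    actor_network=actor_net,
    critic_network=critic_net,
    actor_optimizer=tf.keras.optimizers.Adam(1e-4),
    critic_optimizer=tf.keras.optimizers.Adam(1e-3),
    target_update_tau=0.05,
    target_update_period=5)
agent.initialize()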
Example #4
    def __init__(self,
                 time_step_spec: ts.TimeStep,
                 action_spec: types.NestedTensorSpec,
                 critic_network: network.Network,
                 actor_network: network.Network,
                 actor_optimizer: types.Optimizer,
                 critic_optimizer: types.Optimizer,
                 alpha_optimizer: types.Optimizer,
                 actor_loss_weight: types.Float = 1.0,
                 critic_loss_weight: types.Float = 0.5,
                 alpha_loss_weight: types.Float = 1.0,
                 actor_policy_ctor: Callable[
                     ..., tf_policy.TFPolicy] = actor_policy.ActorPolicy,
                 critic_network_2: Optional[network.Network] = None,
                 target_critic_network: Optional[network.Network] = None,
                 target_critic_network_2: Optional[network.Network] = None,
                 target_update_tau: types.Float = 1.0,
                 target_update_period: types.Int = 1,
                 td_errors_loss_fn: types.LossFn = tf.math.squared_difference,
                 gamma: types.Float = 1.0,
                 reward_scale_factor: types.Float = 1.0,
                 initial_log_alpha: types.Float = 0.0,
                 use_log_alpha_in_alpha_loss: bool = True,
                 target_entropy: Optional[types.Float] = None,
                 gradient_clipping: Optional[types.Float] = None,
                 debug_summaries: bool = False,
                 summarize_grads_and_vars: bool = False,
                 train_step_counter: Optional[tf.Variable] = None,
                 name: Optional[Text] = None):
        """Creates a SAC Agent.

    Args:
      time_step_spec: A `TimeStep` spec of the expected time_steps.
      action_spec: A nest of BoundedTensorSpec representing the actions.
      critic_network: A function critic_network((observations, actions)) that
        returns the q_values for each observation and action.
      actor_network: A function actor_network(observation, action_spec) that
        returns action distribution.
      actor_optimizer: The optimizer to use for the actor network.
      critic_optimizer: The default optimizer to use for the critic network.
      alpha_optimizer: The default optimizer to use for the alpha variable.
      actor_loss_weight: The weight on actor loss.
      critic_loss_weight: The weight on critic loss.
      alpha_loss_weight: The weight on alpha loss.
      actor_policy_ctor: The policy class to use.
      critic_network_2: (Optional.)  A `tf_agents.network.Network` to be used as
        the second critic network during Q learning.  The weights from
        `critic_network` are copied if this is not provided.
      target_critic_network: (Optional.)  A `tf_agents.network.Network` to be
        used as the target critic network during Q learning. Every
        `target_update_period` train steps, the weights from `critic_network`
        are copied (possibly with smoothing via `target_update_tau`) to
        `target_critic_network`.  If `target_critic_network` is not provided, it
        is created by making a copy of `critic_network`, which initializes a new
        network with the same structure and its own layers and weights.
        Performing a `Network.copy` does not work when the network instance
        already has trainable parameters (e.g., has already been built, or when
        the network is sharing layers with another).  In these cases, it is up
        to you to build a copy having weights that are not shared with the
        original `critic_network`, so that this can be used as a target network.
        If you provide a `target_critic_network` that shares any weights with
        `critic_network`, a warning will be logged but no exception is thrown.
      target_critic_network_2: (Optional.) Similar network as
        target_critic_network but for the critic_network_2. See documentation
        for target_critic_network. Will only be used if 'critic_network_2' is
        also specified.
      target_update_tau: Factor for soft update of the target networks.
      target_update_period: Period for soft update of the target networks.
      td_errors_loss_fn:  A function for computing the elementwise TD errors
        loss.
      gamma: A discount factor for future rewards.
      reward_scale_factor: Multiplicative scale for the reward.
      initial_log_alpha: Initial value for log_alpha.
      use_log_alpha_in_alpha_loss: A boolean, whether using log_alpha or alpha
        in alpha loss. Certain implementations of SAC use log_alpha as log
        values are generally nicer to work with.
      target_entropy: The target average policy entropy, for updating alpha. The
        default value is negative of the total number of actions.
      gradient_clipping: Norm length to clip gradients.
      debug_summaries: A bool to gather debug summaries.
      summarize_grads_and_vars: If True, gradient and network variable summaries
        will be written during training.
      train_step_counter: An optional counter to increment every time the train
        op is run.  Defaults to the global_step.
      name: The name of this agent. All variables in this module will fall under
        that name. Defaults to the class name.
    """
        tf.Module.__init__(self, name=name)

        self._check_action_spec(action_spec)

        self._critic_network_1 = critic_network
        self._critic_network_1.create_variables(
            (time_step_spec.observation, action_spec))
        if target_critic_network:
            target_critic_network.create_variables(
                (time_step_spec.observation, action_spec))
        self._target_critic_network_1 = (
            common.maybe_copy_target_network_with_checks(
                self._critic_network_1, target_critic_network,
                'TargetCriticNetwork1'))

        if critic_network_2 is not None:
            self._critic_network_2 = critic_network_2
        else:
            self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
            # Do not use target_critic_network_2 if critic_network_2 is None.
            target_critic_network_2 = None
        self._critic_network_2.create_variables(
            (time_step_spec.observation, action_spec))
        if target_critic_network_2:
            target_critic_network_2.create_variables(
                (time_step_spec.observation, action_spec))
        self._target_critic_network_2 = (
            common.maybe_copy_target_network_with_checks(
                self._critic_network_2, target_critic_network_2,
                'TargetCriticNetwork2'))

        if actor_network:
            actor_network.create_variables(time_step_spec.observation)
        self._actor_network = actor_network

        policy = actor_policy_ctor(time_step_spec=time_step_spec,
                                   action_spec=action_spec,
                                   actor_network=self._actor_network,
                                   training=False)

        self._train_policy = actor_policy_ctor(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            training=True)

        self._log_alpha = common.create_variable(
            'initial_log_alpha',
            initial_value=initial_log_alpha,
            dtype=tf.float32,
            trainable=True)

        if target_entropy is None:
            target_entropy = self._get_default_target_entropy(action_spec)

        self._use_log_alpha_in_alpha_loss = use_log_alpha_in_alpha_loss
        self._target_update_tau = target_update_tau
        self._target_update_period = target_update_period
        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer
        self._alpha_optimizer = alpha_optimizer
        self._actor_loss_weight = actor_loss_weight
        self._critic_loss_weight = critic_loss_weight
        self._alpha_loss_weight = alpha_loss_weight
        self._td_errors_loss_fn = td_errors_loss_fn
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._target_entropy = target_entropy
        self._gradient_clipping = gradient_clipping
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._update_target = self._get_target_updater(
            tau=self._target_update_tau, period=self._target_update_period)

        train_sequence_length = 2 if not critic_network.state_spec else None

        super(SacAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy=policy,
                             collect_policy=policy,
                             train_sequence_length=train_sequence_length,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter,
                             validate_args=False)

        self._as_transition = data_converter.AsTransition(
            self.data_context, squeeze_time_dim=(train_sequence_length == 2))
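
A minimal construction sketch for an agent with this constructor, again with standard TF-Agents networks; the specs, layer sizes, and learning rates are illustrative assumptions:

import tensorflow as tf
from tf_agents.agents.ddpg import critic_network
from tf_agents.networks import actor_distribution_network
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts

observation_spec = tf.TensorSpec((4,), tf.float32)
action_spec = tensor_spec.BoundedTensorSpec((1,), tf.float32, minimum=-1.0, maximum=1.0)
time_step_spec = ts.time_step_spec(observation_spec)

critic_net = critic_network.CriticNetwork(
    (observation_spec, action_spec), joint_fc_layer_params=(256, 256))
actor_net = actor_distribution_network.ActorDistributionNetwork(
    observation_spec, action_spec, fc_layer_params=(256, 256))

agent = SacAgent(
    time_step_spec,
    action_spec,
    critic_network=critic_net,
    actor_network=actor_net,
    actor_optimizer=tf.keras.optimizers.Adam(3e-4),
    critic_optimizer=tf.keras.optimizers.Adam(3e-4),
    alpha_optimizer=tf.keras.optimizers.Adam(3e-4),
    target_update_tau=0.005,
    target_update_period=1,
    gamma=0.99)
agent.initialize()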
Example #5
    def __init__(self,
                 observation_spec,
                 action_spec,
                 actor_network: DistributionNetwork,
                 critic_network: Network,
                 gamma=0.99,
                 ou_stddev=0.2,
                 ou_damping=0.15,
                 actor_optimizer=None,
                 critic_optimizer=None,
                 target_update_tau=0.05,
                 target_update_period=10,
                 dqda_clipping=None,
                 gradient_clipping=None,
                 debug_summaries=False,
                 name="SarsaAlgorithm"):
        """Create an SarsaAlgorithm.

        Args:
            action_spec (nested BoundedTensorSpec): representing the actions.
            observation_spec (nested TensorSpec): spec for observation.
            actor_network (Network|DistributionNetwork): The network will be
                called with call(observation, step_type). If it is a
                DistributionNetwork, an action will be sampled from its output
                distribution.
            critic_network (Network): The network will be called with
                call(observation, action, step_type).
            gamma (float): discount rate for reward.
            ou_stddev (float): Only used for DDPG. Standard deviation for the
                Ornstein-Uhlenbeck (OU) noise added in the default collect policy.
            ou_damping (float): Only used for DDPG. Damping factor for the OU
                noise added in the default collect policy.
            target_update_tau (float): Factor for soft update of the target
                networks.
            target_update_period (int): Period for soft update of the target
                networks.
            dqda_clipping (float): when computing the actor loss, clips the
                gradient dqda element-wise between [-dqda_clipping, dqda_clipping].
                Does not perform clipping if dqda_clipping == 0.
            actor_optimizer (tf.optimizers.Optimizer): The optimizer for actor.
            critic_optimizer (tf.optimizers.Optimizer): The optimizer for critic.
            gradient_clipping (float): Norm length to clip gradients.
            debug_summaries (bool): True if debug summaries should be created.
            name (str): The name of this algorithm.
        """
        if isinstance(actor_network, DistributionNetwork):
            self._action_distribution_spec = actor_network.output_spec
        elif isinstance(actor_network, Network):
            self._action_distribution_spec = action_spec
        else:
            raise ValueError("Expect DistributionNetwork or Network for"
                             " `actor_network`, got %s" % type(actor_network))

        super().__init__(observation_spec,
                         action_spec,
                         predict_state_spec=SarsaState(
                             prev_observation=observation_spec,
                             prev_step_type=tf.TensorSpec((), tf.int32),
                             actor=actor_network.state_spec),
                         train_state_spec=SarsaState(
                             prev_observation=observation_spec,
                             prev_step_type=tf.TensorSpec((), tf.int32),
                             actor=actor_network.state_spec,
                             target_actor=actor_network.state_spec,
                             critic=critic_network.state_spec,
                             target_critic=critic_network.state_spec,
                         ),
                         optimizer=[actor_optimizer, critic_optimizer],
                         trainable_module_sets=[[actor_network],
                                                [critic_network]],
                         gradient_clipping=gradient_clipping,
                         debug_summaries=debug_summaries,
                         name=name)
        self._actor_network = actor_network
        self._critic_network = critic_network
        self._target_actor_network = actor_network.copy(
            name='target_actor_network')
        self._target_critic_network = critic_network.copy(
            name='target_critic_network')
        self._update_target = common.get_target_updater(
            models=[self._actor_network, self._critic_network],
            target_models=[
                self._target_actor_network, self._target_critic_network
            ],
            tau=target_update_tau,
            period=target_update_period)
        self._dqda_clipping = dqda_clipping
        self._gamma = gamma
        self._ou_process = create_ou_process(action_spec, ou_stddev,
                                             ou_damping)
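
`common.get_target_updater` returns a callable that, every `target_update_period` training steps, softly moves the target networks toward the online networks with factor `target_update_tau`. A standalone sketch of that rule (an illustrative helper, not the actual implementation):

import tensorflow as tf

def make_target_updater(model_vars, target_vars, tau=0.05, period=10):
    step = tf.Variable(0, dtype=tf.int64)

    def update():
        step.assign_add(1)
        if step % period == 0:
            # target <- tau * model + (1 - tau) * target
            for m, t in zip(model_vars, target_vars):
                t.assign(tau * m + (1.0 - tau) * t)

    return update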
Example #6
    def __init__(self,
                 time_step_spec: ts.TimeStep,
                 action_spec: types.NestedTensorSpec,
                 critic_network: network.Network,
                 actor_network: network.Network,
                 actor_optimizer: types.Optimizer,
                 critic_optimizer: types.Optimizer,
                 alpha_optimizer: types.Optimizer,
                 actor_loss_weight: types.Float = 1.0,
                 critic_loss_weight: types.Float = 0.5,
                 alpha_loss_weight: types.Float = 1.0,
                 actor_policy_ctor: Callable[
                     ..., tf_policy.TFPolicy] = actor_policy.ActorPolicy,
                 critic_network_2: Optional[network.Network] = None,
                 target_critic_network: Optional[network.Network] = None,
                 target_critic_network_2: Optional[network.Network] = None,
                 target_update_tau: types.Float = 1.0,
                 target_update_period: types.Int = 1,
                 td_errors_loss_fn: types.LossFn = tf.math.squared_difference,
                 gamma: types.Float = 1.0,
                 sigma: types.Float = 0.9,
                 reward_scale_factor: types.Float = 1.0,
                 initial_log_alpha: types.Float = 0.0,
                 use_log_alpha_in_alpha_loss: bool = True,
                 target_entropy: Optional[types.Float] = None,
                 gradient_clipping: Optional[types.Float] = None,
                 debug_summaries: bool = False,
                 summarize_grads_and_vars: bool = False,
                 train_step_counter: Optional[tf.Variable] = None,
                 name: Optional[Text] = None):

        tf.Module.__init__(self, name=name)

        self._check_action_spec(action_spec)

        net_observation_spec = time_step_spec.observation
        critic_spec = (net_observation_spec, action_spec)

        self._critic_network_1 = critic_network

        if critic_network_2 is not None:
            self._critic_network_2 = critic_network_2
        else:
            self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
            # Do not use target_critic_network_2 if critic_network_2 is None.
            target_critic_network_2 = None

        # Wait until critic_network_2 has been copied from critic_network_1 before
        # creating variables on both.
        self._critic_network_1.create_variables(critic_spec)
        self._critic_network_2.create_variables(critic_spec)

        if target_critic_network:
            target_critic_network.create_variables(critic_spec)

        self._target_critic_network_1 = (
            common.maybe_copy_target_network_with_checks(
                self._critic_network_1,
                target_critic_network,
                input_spec=critic_spec,
                name='TargetCriticNetwork1'))

        if target_critic_network_2:
            target_critic_network_2.create_variables(critic_spec)
        self._target_critic_network_2 = (
            common.maybe_copy_target_network_with_checks(
                self._critic_network_2,
                target_critic_network_2,
                input_spec=critic_spec,
                name='TargetCriticNetwork2'))

        if actor_network:
            actor_network.create_variables(net_observation_spec)
        self._actor_network = actor_network

        policy = actor_policy_ctor(time_step_spec=time_step_spec,
                                   action_spec=action_spec,
                                   actor_network=self._actor_network,
                                   training=False)

        self._train_policy = actor_policy_ctor(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            actor_network=self._actor_network,
            training=True)

        self._log_alpha = common.create_variable(
            'initial_log_alpha',
            initial_value=initial_log_alpha,
            dtype=tf.float32,
            trainable=True)

        if target_entropy is None:
            target_entropy = self._get_default_target_entropy(action_spec)

        self._use_log_alpha_in_alpha_loss = use_log_alpha_in_alpha_loss
        self._target_update_tau = target_update_tau
        self._target_update_period = target_update_period
        self._actor_optimizer = actor_optimizer
        self._critic_optimizer = critic_optimizer
        self._alpha_optimizer = alpha_optimizer
        self._actor_loss_weight = actor_loss_weight
        self._critic_loss_weight = critic_loss_weight
        self._alpha_loss_weight = alpha_loss_weight
        self._td_errors_loss_fn = td_errors_loss_fn
        self._gamma = gamma
        self._reward_scale_factor = reward_scale_factor
        self._target_entropy = target_entropy
        self._gradient_clipping = gradient_clipping
        self._debug_summaries = debug_summaries
        self._summarize_grads_and_vars = summarize_grads_and_vars
        self._update_target = self._get_target_updater(
            tau=self._target_update_tau, period=self._target_update_period)

        self.sigma = sigma

        train_sequence_length = 2 if not critic_network.state_spec else None

        super(sac_agent.SacAgent,
              self).__init__(time_step_spec,
                             action_spec,
                             policy=policy,
                             collect_policy=policy,
                             train_sequence_length=train_sequence_length,
                             debug_summaries=debug_summaries,
                             summarize_grads_and_vars=summarize_grads_and_vars,
                             train_step_counter=train_step_counter,
                             validate_args=False)

        self._as_transition = data_converter.AsTransition(
            self.data_context, squeeze_time_dim=(train_sequence_length == 2))
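
When `target_entropy` is None, `_get_default_target_entropy` supplies a default; per the Example #4 docstring, this is the negative of the total number of actions. A minimal sketch of that computation for continuous action specs (the helper name is illustrative, not the actual method):

import numpy as np
import tensorflow as tf

def calc_target_entropy(action_spec):
    # Negative of the total number of action dimensions, e.g. -3 for a (3,) spec.
    flat_specs = tf.nest.flatten(action_spec)
    return -float(sum(np.prod(spec.shape.as_list()) for spec in flat_specs))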
Example #7
    def __init__(self,
                 output_dim,
                 noise_dim=32,
                 input_tensor_spec=None,
                 hidden_layers=(256, ),
                 net: Network = None,
                 net_moving_average_rate=None,
                 entropy_regularization=0.,
                 kernel_sharpness=2.,
                 mi_weight=None,
                 mi_estimator_cls=MIEstimator,
                 optimizer: tf.optimizers.Optimizer = None,
                 name="Generator"):
        """Create a Generator.

        Args:
            output_dim (int): dimension of output
            noise_dim (int): dimension of noise
            input_tensor_spec (nested TensorSpec): spec of inputs. If there are
                no inputs, this should be None.
            hidden_layers (tuple): size of hidden layers.
            net (Network): network for generating outputs from [noise, inputs]
                or noise (if inputs is None). If None, a default one with
                hidden_layers will be created.
            net_moving_average_rate (float): If provided, use a moving average
                version of net to do prediction. This has been shown to be
                effective for GAN training (arXiv:1907.02544, arXiv:1812.04948).
            entropy_regularization (float): weight of entropy regularization
            kernel_sharpness (float): Used only for entropy_regularization > 0.
                We calculate the kernel in SVGD as:
                    exp(-kernel_sharpness * reduce_mean((x-y)^2/width)),
                where width is the elementwise moving average of (x-y)^2.
            mi_weight (float|None): weight of the mutual information loss. If
                None, no mutual information estimator is created.
            mi_estimator_cls (type): the class of mutual information estimator
                for maximizing the mutual information between [noise, inputs]
                and [outputs, inputs].
            optimizer (tf.optimizers.Optimizer): optimizer (optional)
            name (str): name of this generator
        """
        super().__init__(train_state_spec=(), optimizer=optimizer, name=name)
        self._noise_dim = noise_dim
        self._entropy_regularization = entropy_regularization
        if entropy_regularization == 0:
            self._grad_func = self._ml_grad
        else:
            self._grad_func = self._stein_grad
            self._kernel_width_averager = AdaptiveAverager(
                tensor_spec=tf.TensorSpec(shape=(output_dim, )))
            self._kernel_sharpness = kernel_sharpness

        noise_spec = tf.TensorSpec(shape=[noise_dim])

        if net is None:
            net_input_spec = noise_spec
            if input_tensor_spec is not None:
                net_input_spec = [net_input_spec, input_tensor_spec]
            net = EncodingNetwork(
                name="Generator",
                input_tensor_spec=net_input_spec,
                fc_layer_params=hidden_layers,
                last_layer_size=output_dim)

        self._mi_estimator = None
        self._input_tensor_spec = input_tensor_spec
        if mi_weight is not None:
            x_spec = noise_spec
            y_spec = tf.TensorSpec((output_dim, ))
            if input_tensor_spec is not None:
                x_spec = [x_spec, input_tensor_spec]
            self._mi_estimator = mi_estimator_cls(
                x_spec, y_spec, sampler='shift')
            self._mi_weight = mi_weight
        self._net = net
        self._predict_net = None
        self._net_moving_average_rate = net_moving_average_rate
        if net_moving_average_rate:
            self._predict_net = net.copy(name="Generator_average")
            tfa_common.soft_variables_update(
                self._net.variables, self._predict_net.variables, tau=1.0)
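
The docstring above defines the SVGD kernel as exp(-kernel_sharpness * reduce_mean((x-y)^2 / width)), with width the elementwise moving average of (x-y)^2. A standalone sketch of that kernel computation, with a fixed `width` standing in for the adaptive averager:

import tensorflow as tf

def svgd_kernel(x, y, width, kernel_sharpness=2.0):
    # exp(-kernel_sharpness * mean((x - y)^2 / width)), per the docstring formula.
    sq_diff = tf.square(x - y)
    return tf.exp(-kernel_sharpness * tf.reduce_mean(sq_diff / width, axis=-1))

x = tf.random.normal([4, 8])          # 4 particles, output_dim = 8
y = tf.random.normal([4, 8])
width = tf.ones([8])                  # stand-in for the moving average of (x - y)^2
k = svgd_kernel(x, y, width)          # shape [4]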