Пример #1
0
  def update_targets(self, tau=1.0):
    """Performs a soft update of the target network parameters.

    For each weight w_s in the actor/critic networks, and its corresponding
    weight w_t in the target actor/critic networks, a soft update is:
      w_t = (1 - tau) * w_t + tau * w_s
    With the default tau=1.0 this reduces to a hard copy (w_t = w_s).

    Args:
      tau: A Python float scalar in [0, 1].
    Returns:
      An operation that performs a soft update of the target network parameters.
    Raises:
      ValueError: If `tau` is not in [0, 1] (this includes NaN).
    """
    # A negated chained comparison rejects NaN too: every ordered comparison
    # with NaN is False, so `tau < 0 or tau > 1` would let NaN slip through
    # and silently corrupt every target weight.
    if not 0 <= tau <= 1:
      raise ValueError('Input `tau` should be in [0, 1].')
    update_actor = utils.soft_variables_update(
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)),
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.TARGET_ACTOR_NET_SCOPE)),
        tau)
    # NOTE: This updates both critic networks.
    update_critic = utils.soft_variables_update(
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)),
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.TARGET_CRITIC_NET_SCOPE)),
        tau)
    return tf.group(update_actor, update_critic, name='update_targets')
Пример #2
0
  def update_targets(self, tau=1.0):
    """Performs a soft update of the target network parameters.

    For each weight w_s in the actor/critic networks, and its corresponding
    weight w_t in the target actor/critic networks, a soft update is:
      w_t = (1 - tau) * w_t + tau * w_s
    With the default tau=1.0 this reduces to a hard copy (w_t = w_s).

    Args:
      tau: A Python float scalar in [0, 1].
    Returns:
      An operation that performs a soft update of the target network parameters.
    Raises:
      ValueError: If `tau` is not in [0, 1] (this includes NaN).
    """
    # A negated chained comparison rejects NaN too: every ordered comparison
    # with NaN is False, so `tau < 0 or tau > 1` would let NaN slip through
    # and silently corrupt every target weight.
    if not 0 <= tau <= 1:
      raise ValueError('Input `tau` should be in [0, 1].')
    update_actor = utils.soft_variables_update(
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)),
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.TARGET_ACTOR_NET_SCOPE)),
        tau)
    # NOTE: This updates both critic networks.
    update_critic = utils.soft_variables_update(
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)),
        slim.get_trainable_variables(
            utils.join_scope(self._scope, self.TARGET_CRITIC_NET_SCOPE)),
        tau)
    return tf.group(update_actor, update_critic, name='update_targets')
Пример #3
0
  def get_actor_vars(self):
    """Returns the model variables of the actor network.

    Variables are looked up with `slim.get_model_variables`, i.e. from the
    model-variables collection filtered to the actor network's scope. This is
    not limited to trainable variables; it may include non-trainable model
    variables as well.

    Returns:
      A list of model variables under the actor network scope.
    """
    return slim.get_model_variables(
        utils.join_scope(self._scope, self.ACTOR_NET_SCOPE))
Пример #4
0
  def get_critic_vars(self):
    """Returns the model variables of the critic network.

    Variables are looked up with `slim.get_model_variables`, i.e. from the
    model-variables collection filtered to the critic network's scope. This is
    not limited to trainable variables; it may include non-trainable model
    variables as well.

    Returns:
      A list of model variables under the critic network scope.
    """
    return slim.get_model_variables(
        utils.join_scope(self._scope, self.CRITIC_NET_SCOPE))
Пример #5
0
  def get_actor_vars(self):
    """Returns the model variables of the actor network.

    Variables are looked up with `slim.get_model_variables`, i.e. from the
    model-variables collection filtered to the actor network's scope. This is
    not limited to trainable variables; it may include non-trainable model
    variables as well.

    Returns:
      A list of model variables under the actor network scope.
    """
    return slim.get_model_variables(
        utils.join_scope(self._scope, self.ACTOR_NET_SCOPE))
Пример #6
0
  def get_critic_vars(self):
    """Returns the model variables of the critic network.

    Variables are looked up with `slim.get_model_variables`, i.e. from the
    model-variables collection filtered to the critic network's scope. This is
    not limited to trainable variables; it may include non-trainable model
    variables as well.

    Returns:
      A list of model variables under the critic network scope.
    """
    return slim.get_model_variables(
        utils.join_scope(self._scope, self.CRITIC_NET_SCOPE))
    def get_trainable_completion_vars(self):
        """Returns a list of trainable variables in the completion network.

        Returns:
          A list of trainable variables under the completion network scope
          (`COMPLETION_NET_SCOPE` joined onto this object's scope).
        """
        return slim.get_trainable_variables(
            uvf_utils.join_scope(self._scope, self.COMPLETION_NET_SCOPE))
    def get_trainable_reward_vars(self):
        """Returns a list of trainable variables in the reward network.

        Returns:
          A list of trainable variables under the reward network scope
          (`REWARD_NET_SCOPE` joined onto this object's scope).
        """
        return slim.get_trainable_variables(
            uvf_utils.join_scope(self._scope, self.REWARD_NET_SCOPE))
Пример #9
0
    def get_trainable_critic_vars(self):
        """Collects the trainable variables of the critic network.

        NOTE: This gets the vars of both critic networks, since everything
        trainable under the critic network scope is returned.

        Returns:
          A list of trainable variables in the critic network.
        """
        critic_scope = utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)
        return slim.get_trainable_variables(critic_scope)
Пример #10
0
  def get_trainable_critic_vars(self):
    """Collects the trainable variables of the critic network.

    NOTE: This gets the vars of both critic networks, since everything
    trainable under the critic network scope is returned.

    Returns:
      A list of trainable variables in the critic network.
    """
    critic_scope = utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)
    return slim.get_trainable_variables(critic_scope)
Пример #11
0
 def get_trainable_vars(self):
   """Returns the trainable variables of the state-preprocess and action-embed nets.

   Returns:
     A single list holding the trainable variables under the state
     preprocessing network scope, followed by those under the action
     embedding network scope.
   """
   state_vars = slim.get_trainable_variables(
       uvf_utils.join_scope(self._scope, self.STATE_PREPROCESS_NET_SCOPE))
   embed_vars = slim.get_trainable_variables(
       uvf_utils.join_scope(self._scope, self.ACTION_EMBED_NET_SCOPE))
   return state_vars + embed_vars
Пример #12
0
 def get_trainable_vars(self):
   """Returns the trainable variables of the state-preprocess and action-embed nets.

   Returns:
     A single list holding the trainable variables under the state
     preprocessing network scope, followed by those under the action
     embedding network scope.
   """
   state_vars = slim.get_trainable_variables(
       uvf_utils.join_scope(self._scope, self.STATE_PREPROCESS_NET_SCOPE))
   embed_vars = slim.get_trainable_variables(
       uvf_utils.join_scope(self._scope, self.ACTION_EMBED_NET_SCOPE))
   return state_vars + embed_vars