Beispiel #1
0
    def _get_target_updater(self, tau=1.0, period=1):
        """Builds a periodic soft updater for the target Q-network.

        Every weight w_t of `self._target_q_network` is moved toward the
        matching weight w_s of `self._q_network`:

            w_t = (1 - tau) * w_t + tau * w_s

        `tau_non_trainable=1.0` is passed to `common.soft_variables_update`
        for the non-trainable variables.

        Args:
          tau: A float scalar in [0, 1]. Default `tau=1.0` means hard update.
          period: Step interval at which the target network is updated.

        Returns:
          A `common.Periodically` callable that performs the soft update.
        """

        def _soft_update_q_targets():
            # The network variables are looked up inside the closure, i.e.
            # each time the update runs, not once when the updater is built.
            online_vars = self._q_network.variables
            target_vars = self._target_q_network.variables
            return common.soft_variables_update(
                online_vars, target_vars, tau, tau_non_trainable=1.0)

        with tf.name_scope('update_targets'):
            periodic_updater = common.Periodically(
                _soft_update_q_targets, period, 'periodic_update_targets')
        return periodic_updater
Beispiel #2
0
    def _get_target_updater(self, tau=1.0, period=1):
        """Builds a periodic soft updater for the critic ensemble and actor.

        For each weight w_s in an online network and its counterpart w_t in
        the matching target network, a soft update computes:

            w_t = (1 - tau) * w_t + tau * w_s

        Critics are paired by index i in `range(self._ensemble_size)` between
        `self._critic_network_list` and `self._target_critic_network_list`;
        the actor is paired with `self._target_actor_network`.

        Args:
          tau: A float scalar in [0, 1]. Default `tau=1.0` means hard update.
          period: Step interval at which the target networks are updated.

        Returns:
          A `common.Periodically` callable that performs the soft update of
          all target network parameters.
        """
        with tf.name_scope('get_target_updater'):

            def update():
                """Soft-updates every ensemble critic, then the actor."""

                def soft(online, target):
                    return common.soft_variables_update(
                        online.variables, target.variables, tau)

                grouped_ops = [
                    soft(self._critic_network_list[i],
                         self._target_critic_network_list[i])
                    for i in range(self._ensemble_size)
                ]
                grouped_ops.append(
                    soft(self._actor_network, self._target_actor_network))
                return tf.group(grouped_ops)

            return common.Periodically(update, period,
                                       'periodic_update_targets')
Beispiel #3
0
  def _get_target_updater(self, tau=1.0, period=1):
    """Builds a periodic soft updater for both critics and the actor.

    For each weight w_s in an online network and its counterpart w_t in the
    matching target network, a soft update computes:

        w_t = (1 - tau) * w_t + tau * w_s

    Args:
      tau: A float scalar in [0, 1]. Default `tau=1.0` means hard update.
      period: Step interval at which the target networks are updated.

    Returns:
      A `common.Periodically` callable that performs the soft update of the
      target network parameters.
    """
    with tf.name_scope('update_targets'):

      def update():
        """Soft-updates critic 1, critic 2 and the actor as one grouped op."""
        # TODO(b/124381161): What about observation normalizer variables?
        online_target_pairs = (
            (self._critic_network_1, self._target_critic_network_1),
            (self._critic_network_2, self._target_critic_network_2),
            (self._actor_network, self._target_actor_network),
        )
        return tf.group(*[
            common.soft_variables_update(online.variables, target.variables,
                                         tau)
            for online, target in online_target_pairs
        ])

      return common.Periodically(update, period, 'update_targets')
Beispiel #4
0
  def _get_target_updater(self, tau=1.0, period=1):
    """Builds a periodic soft updater for the two critic target networks.

    For each weight w_s in a critic network and its counterpart w_t in the
    matching target critic network, a soft update computes:

        w_t = (1 - tau) * w_t + tau * w_s

    Args:
      tau: A float scalar in [0, 1]. Default `tau=1.0` means hard update.
      period: Step interval at which the target networks are updated.

    Returns:
      A `common.Periodically` callable that performs the soft update of both
      target critics.
    """
    with tf.name_scope('update_target'):

      def update():
        """Update target network."""
        critic_ops = []
        for online, target in (
            (self._critic_network1, self._target_critic_network1),
            (self._critic_network2, self._target_critic_network2)):
          critic_ops.append(
              common.soft_variables_update(online.variables,
                                           target.variables, tau))
        return tf.group(*critic_ops)

      return common.Periodically(update, period, 'update_targets')
Beispiel #5
0
def get_target_updater(models, target_models, tau=1.0, period=1):
    """Performs a soft update of the target model parameters.

    For each weight w_s in the model, and its corresponding
    weight w_t in the target_model, a soft update is:
    w_t = (1 - tau) * w_t + tau * w_s

    Models are paired positionally: the i-th entry of `models` updates the
    i-th entry of `target_models`.

    Args:
        models (Network | list[Network]): the current model.
        target_models (Network | list[Network]): the model to be updated.
        tau (float): A float scalar in [0, 1]. Default `tau=1.0` means hard
            update.
        period (int): Step interval at which the target model is updated.

    Returns:
        A callable that performs a soft update of the target model parameters.

    Raises:
        ValueError: If `models` and `target_models` do not contain the same
            number of entries.
    """
    models = as_list(models)
    target_models = as_list(target_models)
    # `zip` silently stops at the shorter sequence, which would leave some
    # models out of the update without any error, so reject mismatches here.
    if len(models) != len(target_models):
        raise ValueError(
            'models and target_models must have the same length, '
            'got %d and %d' % (len(models), len(target_models)))

    def update():
        update_ops = [
            tfa_common.soft_variables_update(
                model.variables, target_model.variables, tau)
            for model, target_model in zip(models, target_models)
        ]
        return tf.group(*update_ops)

    return tfa_common.Periodically(update, period, 'periodic_update_targets')
Beispiel #6
0
    def _get_target_updater(self, tau=1.0, period=1):
        """Returns a periodic soft updater for the target value network.

        Each update moves the variables of `self._target_network` toward
        those of `self._value_network`:

            w_t = (1 - tau) * w_t + tau * w_s

        `tau_non_trainable=1.0` is passed for the non-trainable variables.

        Args:
            tau: A float scalar in [0, 1]. `tau=1.0` means hard update.
            period: Step interval at which the target network is updated.

        Returns:
            A `tfagents_common.Periodically` callable named 'update_targets'.
        """

        def update():
            # Variables are read when the update runs, not when it is built.
            source_vars = self._value_network.variables
            target_vars = self._target_network.variables
            return tfagents_common.soft_variables_update(
                source_vars, target_vars, tau, tau_non_trainable=1.0)

        return tfagents_common.Periodically(update, period, 'update_targets')
Beispiel #7
0
    def _get_target_updater(self, tau=1.0, period=1):
        """Performs a soft update of the target network parameters.

        For each weight w_s in the q network, and its corresponding
        weight w_t in the target_q_network, a soft update is:
        w_t = (1 - tau) * w_t + tau * w_s

        Args:
          tau: A float scalar in [0, 1]. Default `tau=1.0` means hard update.
          period: Step interval at which the target network is updated.

        Returns:
          A callable that performs a soft update of the target network
          parameters.
        """
        with tf.name_scope("update_targets"):

            def update():
                return common.soft_variables_update(
                    self._q_network.variables,
                    self._target_q_network.variables, tau)

            # Construct the Periodically wrapper inside the scope. Previously it
            # was built after the `with` had exited, so the scope only wrapped
            # the `def` above and had no effect.
            return common.Periodically(update, period,
                                       "periodic_update_targets")
Beispiel #8
0
def get_target_updater(sc_net, target_sc_net,
                       tau=0.005, period=1,
                       name='update_target_sc_offline'):
  """Creates a periodic soft updater from `sc_net` to `target_sc_net`.

  For each weight w_s of `sc_net` and its corresponding weight w_t of
  `target_sc_net`, a soft update is:
  w_t = (1 - tau) * w_t + tau * w_s

  Args:
    sc_net: Source network; its `variables` are read each time the update
      runs.
    target_sc_net: Target network whose `variables` are updated.
    tau: A float scalar in [0, 1]. `tau=1.0` means hard update.
    period: Integer step interval at which the target network is updated.
      (The default was previously the float `1.`; an int matches the
      step-count semantics and compares equal to it.)
    name: Name scope under which the periodic updater is constructed.

  Returns:
    A `common.Periodically` callable that performs the soft update.
  """
  with tf.name_scope(name):

    def update():
      """Update target network."""
      return common.soft_variables_update(
          sc_net.variables, target_sc_net.variables, tau)

    return common.Periodically(update, period, 'target_update')
Beispiel #9
0
 def _get_target_updater(self, tau=1.0, period=1):
   """Builds a periodic soft updater for the two critic target networks.

   Each target weight w_t is moved toward its online counterpart w_s as
   w_t = (1 - tau) * w_t + tau * w_s.

   Args:
     tau: A float scalar in [0, 1]. Default `tau=1.0` means hard update.
     period: Step interval at which the target networks are updated.

   Returns:
     A `common.Periodically` callable that soft-updates both critic targets.
   """
   with tf.name_scope('update_target'):

     def update():
       """Update target network."""
       critic_pairs = (
           (self._critic_network1, self._target_critic_network1),
           (self._critic_network2, self._target_critic_network2),
       )
       return tf.group(*(
           common.soft_variables_update(online.variables, target.variables,
                                        tau)
           for online, target in critic_pairs))

     return common.Periodically(update, period, 'update_targets')
Beispiel #10
0
    def _get_target_updater(self, tau=1.0, period=1):
        """Builds a periodic soft updater for both critics and the actor.

        For each weight w_s in an online network and its counterpart w_t in
        the matching target network, a soft update computes:

            w_t = (1 - tau) * w_t + tau * w_s

        Critic 1 is updated with all of its variables. Critic 2 and the actor
        are updated with the variables returned by
        `common.deduped_network_variables`, called with the earlier networks
        as extra arguments. Presumably this drops variables shared with those
        networks so they are not updated twice (TODO confirm against
        `deduped_network_variables`). Every update passes
        `tau_non_trainable=1.0`.

        Args:
          tau: A float scalar in [0, 1]. Default `tau=1.0` means hard update.
          period: Step interval at which the target networks are updated.

        Returns:
          A `common.Periodically` callable that performs the soft update of
          the target network parameters.
        """
        with tf.name_scope('update_targets'):

            def update():
                """Soft-updates critic 1, critic 2, then the actor."""

                def soft(online_vars, target_vars):
                    return common.soft_variables_update(
                        online_vars, target_vars, tau, tau_non_trainable=1.0)

                first_critic_op = soft(
                    self._critic_network_1.variables,
                    self._target_critic_network_1.variables)

                second_critic_op = soft(
                    common.deduped_network_variables(
                        self._critic_network_2, self._critic_network_1),
                    common.deduped_network_variables(
                        self._target_critic_network_2,
                        self._target_critic_network_1))

                actor_op = soft(
                    common.deduped_network_variables(
                        self._actor_network, self._critic_network_1,
                        self._critic_network_2),
                    common.deduped_network_variables(
                        self._target_actor_network,
                        self._target_critic_network_1,
                        self._target_critic_network_2))

                return tf.group(first_critic_op, second_critic_op, actor_op)

            return common.Periodically(update, period, 'update_targets')
Beispiel #11
0
    def _get_target_updater(self, tau=1.0, period=1):
        """Performs a soft update of the target QMIX network parameters.

        For each weight w_s in `self.QMIXNet` and its corresponding weight
        w_t in `self.TargetQMIXNet`, a soft update is:

            w_t = (1 - tau) * w_t + tau * w_s

        `tau_non_trainable=1.0` is passed for the non-trainable variables.

        Args:
          tau: A float scalar in [0, 1]. Default `tau=1.0` means hard update.
          period: Step interval at which the target network is updated.

        Returns:
          A `common.Periodically` callable that performs the soft update.
        """

        def update():
            mixer_vars = self.QMIXNet.variables
            target_mixer_vars = self.TargetQMIXNet.variables
            return common.soft_variables_update(
                mixer_vars, target_mixer_vars, tau, tau_non_trainable=1.0)

        with tf.name_scope('update_targets'):
            periodic = common.Periodically(update, period,
                                           'periodic_update_targets')
        return periodic
Beispiel #12
0
    def _get_target_updater(self, tau=1.0, period=1):
        """Builds a periodic soft updater for the critic ensemble and actor.

        For each weight w_s in an online network and its counterpart w_t in
        the matching target network, a soft update computes:

            w_t = (1 - tau) * w_t + tau * w_s

        Critics are paired by index for every index in
        `range(self._ensemble_size)`; the actor is updated last.

        Args:
          tau: A float scalar in [0, 1]. Default `tau=1.0` means hard update.
          period: Step interval at which the target networks are updated.

        Returns:
          A `common.Periodically` callable that performs the soft update.
        """
        with tf.compat.v1.name_scope('get_target_updater'):

            def update():
                """Soft-updates each ensemble critic, then the actor."""
                pending_ops = []
                for idx in range(self._ensemble_size):
                    online_critic = self._critic_network_list[idx]
                    target_critic = self._target_critic_network_list[idx]
                    pending_ops.append(common.soft_variables_update(
                        online_critic.variables, target_critic.variables,
                        tau))
                pending_ops.append(common.soft_variables_update(
                    self._actor_network.variables,
                    self._target_actor_network.variables, tau))
                return tf.group(pending_ops)

            return common.Periodically(update, period,
                                       'periodic_update_targets')