def update():
        """Build the grouped target-update op for both critics and the actor.

        Every (online, target) variable pair is handed to
        `common.soft_variables_update` together with the enclosing `tau` and
        `tau_non_trainable=1.0`.

        Returns:
          The result of `tf.group` over the three update ops.
        """
        # TODO(b/124381161): What about observation normalizer variables?

        def _soft_update(source_vars, target_vars):
          # All three updates share the same tau settings.
          return common.soft_variables_update(
              source_vars, target_vars, tau, tau_non_trainable=1.0)

        first_critic_op = _soft_update(
            self._critic_network_1.variables,
            self._target_critic_network_1.variables)

        # NOTE(review): deduped_network_variables presumably leaves out
        # variables shared with the networks listed after its first argument,
        # so shared variables are not updated twice -- confirm in `common`.
        online_critic_2_vars = common.deduped_network_variables(
            self._critic_network_2, self._critic_network_1)
        target_critic_2_vars = common.deduped_network_variables(
            self._target_critic_network_2, self._target_critic_network_1)
        second_critic_op = _soft_update(
            online_critic_2_vars, target_critic_2_vars)

        online_actor_vars = common.deduped_network_variables(
            self._actor_network,
            self._critic_network_1,
            self._critic_network_2)
        target_actor_vars = common.deduped_network_variables(
            self._target_actor_network,
            self._target_critic_network_1,
            self._target_critic_network_2)
        actor_op = _soft_update(online_actor_vars, target_actor_vars)

        return tf.group(first_critic_op, second_critic_op, actor_op)
Example #2
0
      def update():
        """Group the target-update ops for critic 1, critic 2 and the actor.

        Returns:
          The `tf.group` of one `common.soft_variables_update` op per
          (online, target) variable pair, each built with the enclosing `tau`
          and `tau_non_trainable=1.0`.
        """
        # NOTE(review): the deduped lists presumably exclude variables shared
        # with the networks passed after the first argument -- confirm against
        # `common.deduped_network_variables`.
        variable_pairs = (
            # Critic 1: all of its variables.
            (self._critic_network_1.variables,
             self._target_critic_network_1.variables),
            # Critic 2: deduped against critic 1.
            (common.deduped_network_variables(
                self._critic_network_2, self._critic_network_1),
             common.deduped_network_variables(
                 self._target_critic_network_2,
                 self._target_critic_network_1)),
            # Actor: deduped against both critics.
            (common.deduped_network_variables(
                self._actor_network,
                self._critic_network_1,
                self._critic_network_2),
             common.deduped_network_variables(
                 self._target_actor_network,
                 self._target_critic_network_1,
                 self._target_critic_network_2)),
        )

        # Ops are created in the same order as the pairs above.
        update_ops = [
            common.soft_variables_update(
                source_vars, target_vars, tau, tau_non_trainable=1.0)
            for source_vars, target_vars in variable_pairs
        ]
        return tf.group(*update_ops)
Example #3
0
            def update():
                """Update target network.

                Builds one `common.soft_variables_update` op per critic
                (using the enclosing `tau` and `tau_non_trainable=1.0`). When
                `self._critic_network_no_entropy_1` is not None, the two
                no-entropy critics get update ops as well.

                Returns:
                    The `tf.group` of all update ops that were built.
                """
                ops = []

                # Critic 1 uses its full variable lists.
                ops.append(
                    common.soft_variables_update(
                        self._critic_network_1.variables,
                        self._target_critic_network_1.variables,
                        tau,
                        tau_non_trainable=1.0))

                # NOTE(review): deduped_network_variables presumably drops
                # variables shared with the second network -- confirm in
                # `common`.
                online_vars_2 = common.deduped_network_variables(
                    self._critic_network_2, self._critic_network_1)
                target_vars_2 = common.deduped_network_variables(
                    self._target_critic_network_2,
                    self._target_critic_network_1)
                ops.append(
                    common.soft_variables_update(
                        online_vars_2,
                        target_vars_2,
                        tau,
                        tau_non_trainable=1.0))

                # The no-entropy critics are optional; only add their updates
                # when the first one exists.
                if self._critic_network_no_entropy_1 is not None:
                    ops.append(
                        common.soft_variables_update(
                            self._critic_network_no_entropy_1.variables,
                            self._target_critic_network_no_entropy_1.variables,
                            tau,
                            tau_non_trainable=1.0))

                    online_no_entropy_vars_2 = (
                        common.deduped_network_variables(
                            self._critic_network_no_entropy_2,
                            self._critic_network_no_entropy_1))
                    target_no_entropy_vars_2 = (
                        common.deduped_network_variables(
                            self._target_critic_network_no_entropy_2,
                            self._target_critic_network_no_entropy_1))
                    ops.append(
                        common.soft_variables_update(
                            online_no_entropy_vars_2,
                            target_no_entropy_vars_2,
                            tau,
                            tau_non_trainable=1.0))

                return tf.group(*ops)