def update(): # pylint: disable=missing-docstring # TODO(b/124381161): What about observation normalizer variables? critic_update_1 = common.soft_variables_update( self._critic_network_1.variables, self._target_critic_network_1.variables, tau, tau_non_trainable=1.0) critic_2_update_vars = common.deduped_network_variables( self._critic_network_2, self._critic_network_1) target_critic_2_update_vars = common.deduped_network_variables( self._target_critic_network_2, self._target_critic_network_1) critic_update_2 = common.soft_variables_update( critic_2_update_vars, target_critic_2_update_vars, tau, tau_non_trainable=1.0) actor_update_vars = common.deduped_network_variables( self._actor_network, self._critic_network_1, self._critic_network_2) target_actor_update_vars = common.deduped_network_variables( self._target_actor_network, self._target_critic_network_1, self._target_critic_network_2) actor_update = common.soft_variables_update( actor_update_vars, target_actor_update_vars, tau, tau_non_trainable=1.0) return tf.group(critic_update_1, critic_update_2, actor_update)
def update(): # pylint: disable=missing-docstring critic_update_1 = common.soft_variables_update( self._critic_network_1.variables, self._target_critic_network_1.variables, tau, tau_non_trainable=1.0) critic_2_update_vars = common.deduped_network_variables( self._critic_network_2, self._critic_network_1) target_critic_2_update_vars = common.deduped_network_variables( self._target_critic_network_2, self._target_critic_network_1) critic_update_2 = common.soft_variables_update( critic_2_update_vars, target_critic_2_update_vars, tau, tau_non_trainable=1.0) actor_update_vars = common.deduped_network_variables( self._actor_network, self._critic_network_1, self._critic_network_2) target_actor_update_vars = common.deduped_network_variables( self._target_actor_network, self._target_critic_network_1, self._target_critic_network_2) actor_update = common.soft_variables_update( actor_update_vars, target_actor_update_vars, tau, tau_non_trainable=1.0) return tf.group(critic_update_1, critic_update_2, actor_update)
def update(): """Update target network.""" critic_update_1 = common.soft_variables_update( self._critic_network_1.variables, self._target_critic_network_1.variables, tau, tau_non_trainable=1.0) critic_2_update_vars = common.deduped_network_variables( self._critic_network_2, self._critic_network_1) target_critic_2_update_vars = common.deduped_network_variables( self._target_critic_network_2, self._target_critic_network_1) critic_update_2 = common.soft_variables_update( critic_2_update_vars, target_critic_2_update_vars, tau, tau_non_trainable=1.0) if self._critic_network_no_entropy_1 is None: return tf.group(critic_update_1, critic_update_2) else: critic_no_entropy_update_1 = common.soft_variables_update( self._critic_network_no_entropy_1.variables, self._target_critic_network_no_entropy_1.variables, tau, tau_non_trainable=1.0) critic_no_entropy_2_update_vars = common.deduped_network_variables( self._critic_network_no_entropy_2, self._critic_network_no_entropy_1) target_critic_no_entropy_2_update_vars = common.deduped_network_variables( self._target_critic_network_no_entropy_2, self._target_critic_network_no_entropy_1) critic_no_entropy_update_2 = common.soft_variables_update( critic_no_entropy_2_update_vars, target_critic_no_entropy_2_update_vars, tau, tau_non_trainable=1.0) return tf.group(critic_update_1, critic_update_2, critic_no_entropy_update_1, critic_no_entropy_update_2)