def __init__(self,
             action_spec,
             actor_network: Network,
             critic_network: Network,
             critic_loss=None,
             target_entropy=None,
             initial_log_alpha=0.0,
             target_update_tau=0.05,
             target_update_period=1,
             dqda_clipping=None,
             actor_optimizer=None,
             critic_optimizer=None,
             alpha_optimizer=None,
             gradient_clipping=None,
             train_step_counter=None,
             debug_summaries=False,
             name="SacAlgorithm"):
    """Create a SacAlgorithm.

    Args:
        action_spec (nested BoundedTensorSpec): representing the actions.
        actor_network (Network): The network will be called with
            call(observation, step_type).
        critic_network (Network): The network will be called with
            call(observation, action, step_type).
        critic_loss (None|OneStepTDLoss): an object for calculating critic
            loss. If None, a default OneStepTDLoss will be used.
        target_entropy (float|None): The target average policy entropy, for
            updating alpha.
        initial_log_alpha (float): initial value for the variable log_alpha.
        target_update_tau (float): Factor for soft update of the target
            networks.
        target_update_period (int): Period for soft update of the target
            networks.
        dqda_clipping (float): when computing the actor loss, clips the
            gradient dqda element-wise between
            [-dqda_clipping, dqda_clipping]. Does not perform clipping if
            dqda_clipping == 0.
        actor_optimizer (tf.optimizers.Optimizer): The optimizer for actor.
        critic_optimizer (tf.optimizers.Optimizer): The optimizer for critic.
        alpha_optimizer (tf.optimizers.Optimizer): The optimizer for alpha.
        gradient_clipping (float): Norm length to clip gradients.
        train_step_counter (tf.Variable): An optional counter to increment
            every time a new iteration is started. If None, it will use
            tf.summary.experimental.get_step(). If this is still None, a
            counter will be created.
        debug_summaries (bool): True if debug summaries should be created.
        name (str): The name of this algorithm.
    """
    critic_network1 = critic_network
    critic_network2 = critic_network.copy(name='CriticNetwork2')
    log_alpha = tfa_common.create_variable(name='log_alpha',
                                           initial_value=initial_log_alpha,
                                           dtype=tf.float32,
                                           trainable=True)
    super().__init__(
        action_spec,
        train_state_spec=SacState(
            share=SacShareState(actor=actor_network.state_spec),
            actor=SacActorState(critic1=critic_network.state_spec,
                                critic2=critic_network.state_spec),
            critic=SacCriticState(
                critic1=critic_network.state_spec,
                critic2=critic_network.state_spec,
                target_critic1=critic_network.state_spec,
                target_critic2=critic_network.state_spec)),
        action_distribution_spec=actor_network.output_spec,
        predict_state_spec=actor_network.state_spec,
        optimizer=[actor_optimizer, critic_optimizer, alpha_optimizer],
        get_trainable_variables_func=[
            lambda: actor_network.trainable_variables,
            lambda: (critic_network1.trainable_variables +
                     critic_network2.trainable_variables),
            lambda: [log_alpha]
        ],
        gradient_clipping=gradient_clipping,
        train_step_counter=train_step_counter,
        debug_summaries=debug_summaries,
        name=name)

    self._log_alpha = log_alpha
    self._actor_network = actor_network
    self._critic_network1 = critic_network1
    self._critic_network2 = critic_network2
    self._target_critic_network1 = self._critic_network1.copy(
        name='TargetCriticNetwork1')
    self._target_critic_network2 = self._critic_network2.copy(
        name='TargetCriticNetwork2')
    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer
    self._alpha_optimizer = alpha_optimizer

    if critic_loss is None:
        critic_loss = OneStepTDLoss(debug_summaries=debug_summaries)
    self._critic_loss = critic_loss

    flat_action_spec = tf.nest.flatten(self._action_spec)
    self._is_continuous = tensor_spec.is_continuous(flat_action_spec[0])
    if target_entropy is None:
        target_entropy = np.sum(
            list(
                map(dist_utils.calc_default_target_entropy,
                    flat_action_spec)))
    self._target_entropy = target_entropy
    self._dqda_clipping = dqda_clipping

    self._update_target = common.get_target_updater(
        models=[self._critic_network1, self._critic_network2],
        target_models=[
            self._target_critic_network1, self._target_critic_network2
        ],
        tau=target_update_tau,
        period=target_update_period)

    # Initialize the target critics as exact copies of the online critics.
    tfa_common.soft_variables_update(self._critic_network1.variables,
                                     self._target_critic_network1.variables,
                                     tau=1.0)
    tfa_common.soft_variables_update(self._critic_network2.variables,
                                     self._target_critic_network2.variables,
                                     tau=1.0)
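# A minimal construction sketch for the SacAlgorithm defined above. It only
# calls the constructor shown here; `action_spec`, `actor_net`, and
# `critic_net` are assumed placeholders built elsewhere, with call signatures
# matching the docstring (call(observation, step_type) for the actor and
# call(observation, action, step_type) for the critic). Optimizer choices are
# illustrative, not prescribed by the source.
def make_sac_algorithm(action_spec, actor_net, critic_net):
    return SacAlgorithm(
        action_spec=action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.optimizers.Adam(learning_rate=3e-4),
        critic_optimizer=tf.optimizers.Adam(learning_rate=3e-4),
        alpha_optimizer=tf.optimizers.Adam(learning_rate=3e-4),
        target_update_tau=0.05,
        target_update_period=1)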
def __init__(self,
             action_spec,
             actor_network: Network,
             critic_network: Network,
             ou_stddev=0.2,
             ou_damping=0.15,
             critic_loss=None,
             target_update_tau=0.05,
             target_update_period=1,
             dqda_clipping=None,
             actor_optimizer=None,
             critic_optimizer=None,
             gradient_clipping=None,
             train_step_counter=None,
             debug_summaries=False,
             name="DdpgAlgorithm"):
    """
    Args:
        action_spec (nested BoundedTensorSpec): representing the actions.
        actor_network (Network): The network will be called with
            call(observation, step_type).
        critic_network (Network): The network will be called with
            call(observation, action, step_type).
        ou_stddev (float): Standard deviation for the Ornstein-Uhlenbeck (OU)
            noise added in the default collect policy.
        ou_damping (float): Damping factor for the OU noise added in the
            default collect policy.
        critic_loss (None|OneStepTDLoss): an object for calculating critic
            loss. If None, a default OneStepTDLoss will be used.
        target_update_tau (float): Factor for soft update of the target
            networks.
        target_update_period (int): Period for soft update of the target
            networks.
        dqda_clipping (float): when computing the actor loss, clips the
            gradient dqda element-wise between
            [-dqda_clipping, dqda_clipping]. Does not perform clipping if
            dqda_clipping == 0.
        actor_optimizer (tf.optimizers.Optimizer): The optimizer for actor.
        critic_optimizer (tf.optimizers.Optimizer): The optimizer for critic.
        gradient_clipping (float): Norm length to clip gradients.
        train_step_counter (tf.Variable): An optional counter to increment
            every time a new iteration is started. If None, it will use
            tf.summary.experimental.get_step(). If this is still None, a
            counter will be created.
        debug_summaries (bool): True if debug summaries should be created.
        name (str): The name of this algorithm.
    """
    train_state_spec = DdpgState(
        actor=DdpgActorState(actor=actor_network.state_spec,
                             critic=critic_network.state_spec),
        critic=DdpgCriticState(critic=critic_network.state_spec,
                               target_actor=actor_network.state_spec,
                               target_critic=critic_network.state_spec))
    super().__init__(action_spec,
                     train_state_spec=train_state_spec,
                     action_distribution_spec=action_spec,
                     optimizer=[actor_optimizer, critic_optimizer],
                     get_trainable_variables_func=[
                         lambda: actor_network.trainable_variables,
                         lambda: critic_network.trainable_variables
                     ],
                     gradient_clipping=gradient_clipping,
                     train_step_counter=train_step_counter,
                     debug_summaries=debug_summaries,
                     name=name)

    self._actor_network = actor_network
    self._critic_network = critic_network
    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer
    self._target_actor_network = actor_network.copy(
        name='target_actor_network')
    self._target_critic_network = critic_network.copy(
        name='target_critic_network')
    self._ou_stddev = ou_stddev
    self._ou_damping = ou_damping

    if critic_loss is None:
        critic_loss = OneStepTDLoss(debug_summaries=debug_summaries)
    self._critic_loss = critic_loss

    self._ou_process = self._create_ou_process(ou_stddev, ou_damping)

    self._update_target = common.get_target_updater(
        models=[self._actor_network, self._critic_network],
        target_models=[
            self._target_actor_network, self._target_critic_network
        ],
        tau=target_update_tau,
        period=target_update_period)

    self._dqda_clipping = dqda_clipping

    # Initialize the target networks as exact copies of the online networks.
    tfa_common.soft_variables_update(self._critic_network.variables,
                                     self._target_critic_network.variables,
                                     tau=1.0)
    tfa_common.soft_variables_update(self._actor_network.variables,
                                     self._target_actor_network.variables,
                                     tau=1.0)
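# Hypothetical usage of the DdpgAlgorithm constructor above; `action_spec`,
# `actor_net`, and `critic_net` are placeholders for objects built elsewhere.
# The OU-noise and target-update arguments shown just make the defaults
# explicit; learning rates are illustrative.
def make_ddpg_algorithm(action_spec, actor_net, critic_net):
    return DdpgAlgorithm(
        action_spec=action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        ou_stddev=0.2,
        ou_damping=0.15,
        actor_optimizer=tf.optimizers.Adam(learning_rate=1e-4),
        critic_optimizer=tf.optimizers.Adam(learning_rate=1e-3),
        target_update_tau=0.05,
        target_update_period=1)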
def __init__(self,
             time_step_spec: ts.TimeStep,
             action_spec: types.NestedTensorSpec,
             actor_network: network.Network,
             critic_network: network.Network,
             actor_optimizer: types.Optimizer,
             critic_optimizer: types.Optimizer,
             exploration_noise_std: types.Float = 0.1,
             critic_network_2: Optional[network.Network] = None,
             target_actor_network: Optional[network.Network] = None,
             target_critic_network: Optional[network.Network] = None,
             target_critic_network_2: Optional[network.Network] = None,
             target_update_tau: types.Float = 1.0,
             target_update_period: types.Int = 1,
             actor_update_period: types.Int = 1,
             td_errors_loss_fn: Optional[types.LossFn] = None,
             gamma: types.Float = 1.0,
             reward_scale_factor: types.Float = 1.0,
             target_policy_noise: types.Float = 0.2,
             target_policy_noise_clip: types.Float = 0.5,
             gradient_clipping: Optional[types.Float] = None,
             debug_summaries: bool = False,
             summarize_grads_and_vars: bool = False,
             train_step_counter: Optional[tf.Variable] = None,
             name: Optional[Text] = None):
    """Creates a Td3 Agent.

    Args:
        time_step_spec: A `TimeStep` spec of the expected time_steps.
        action_spec: A nest of BoundedTensorSpec representing the actions.
        actor_network: A tf_agents.network.Network to be used by the agent.
            The network will be called with call(observation, step_type).
        critic_network: A tf_agents.network.Network to be used by the agent.
            The network will be called with
            call(observation, action, step_type).
        actor_optimizer: The default optimizer to use for the actor network.
        critic_optimizer: The default optimizer to use for the critic network.
        exploration_noise_std: Scale factor on exploration policy noise.
        critic_network_2: (Optional.) A `tf_agents.network.Network` to be used
            as the second critic network during Q learning. The weights from
            `critic_network` are copied if this is not provided.
        target_actor_network: (Optional.) A `tf_agents.network.Network` to be
            used as the target actor network during Q learning. Every
            `target_update_period` train steps, the weights from
            `actor_network` are copied (possibly with smoothing via
            `target_update_tau`) to `target_actor_network`. If
            `target_actor_network` is not provided, it is created by making a
            copy of `actor_network`, which initializes a new network with the
            same structure and its own layers and weights. Performing a
            `Network.copy` does not work when the network instance already has
            trainable parameters (e.g., has already been built, or when the
            network is sharing layers with another). In these cases, it is up
            to you to build a copy having weights that are not shared with the
            original `actor_network`, so that this can be used as a target
            network. If you provide a `target_actor_network` that shares any
            weights with `actor_network`, a warning will be logged but no
            exception is thrown.
        target_critic_network: (Optional.) Similar network as
            target_actor_network but for the critic_network. See documentation
            for target_actor_network.
        target_critic_network_2: (Optional.) Similar network as
            target_actor_network but for the critic_network_2. See
            documentation for target_actor_network. Will only be used if
            `critic_network_2` is also specified.
        target_update_tau: Factor for soft update of the target networks.
        target_update_period: Period for soft update of the target networks.
        actor_update_period: Period for the optimization step on the actor
            network.
        td_errors_loss_fn: A function for computing the TD errors loss. If
            None, a default value of elementwise huber_loss is used.
        gamma: A discount factor for future rewards.
        reward_scale_factor: Multiplicative scale for the reward.
        target_policy_noise: Scale factor on target action noise.
        target_policy_noise_clip: Value to clip noise.
        gradient_clipping: Norm length to clip gradients.
        debug_summaries: A bool to gather debug summaries.
        summarize_grads_and_vars: If True, gradient and network variable
            summaries will be written during training.
        train_step_counter: An optional counter to increment every time the
            train op is run. Defaults to the global_step.
        name: The name of this agent. All variables in this module will fall
            under that name. Defaults to the class name.
    """
    tf.Module.__init__(self, name=name)
    self._actor_network = actor_network
    actor_network.create_variables()
    if target_actor_network:
        target_actor_network.create_variables()
    self._target_actor_network = (
        common.maybe_copy_target_network_with_checks(self._actor_network,
                                                     target_actor_network,
                                                     'TargetActorNetwork'))

    self._critic_network_1 = critic_network
    critic_network.create_variables()
    if target_critic_network:
        target_critic_network.create_variables()
    self._target_critic_network_1 = (
        common.maybe_copy_target_network_with_checks(self._critic_network_1,
                                                     target_critic_network,
                                                     'TargetCriticNetwork1'))

    if critic_network_2 is not None:
        self._critic_network_2 = critic_network_2
    else:
        self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
        # Do not use target_critic_network_2 if critic_network_2 is None.
        target_critic_network_2 = None
    self._critic_network_2.create_variables()
    if target_critic_network_2:
        target_critic_network_2.create_variables()
    self._target_critic_network_2 = (
        common.maybe_copy_target_network_with_checks(self._critic_network_2,
                                                     target_critic_network_2,
                                                     'TargetCriticNetwork2'))

    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer

    self._exploration_noise_std = exploration_noise_std
    self._target_update_tau = target_update_tau
    self._target_update_period = target_update_period
    self._actor_update_period = actor_update_period
    self._td_errors_loss_fn = (td_errors_loss_fn
                               or common.element_wise_huber_loss)
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._target_policy_noise = target_policy_noise
    self._target_policy_noise_clip = target_policy_noise_clip
    self._gradient_clipping = gradient_clipping

    self._update_target = self._get_target_updater(target_update_tau,
                                                   target_update_period)

    policy = actor_policy.ActorPolicy(time_step_spec=time_step_spec,
                                      action_spec=action_spec,
                                      actor_network=self._actor_network,
                                      clip=True)
    collect_policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._actor_network,
        clip=False)
    # Exploration noise is added on top of the deterministic actor output.
    collect_policy = gaussian_policy.GaussianPolicy(
        collect_policy, scale=self._exploration_noise_std, clip=True)

    train_sequence_length = 2 if not self._actor_network.state_spec else None

    super(Td3Agent, self).__init__(
        time_step_spec,
        action_spec,
        policy,
        collect_policy,
        train_sequence_length=train_sequence_length,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter,
        validate_args=False)

    self._as_transition = data_converter.AsTransition(
        self.data_context, squeeze_time_dim=(train_sequence_length == 2))
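# Illustrative Td3Agent setup. This is a sketch, assuming the standard
# tf_agents DDPG actor/critic network classes are available; `train_env` is a
# placeholder TFEnvironment, and layer sizes, learning rates, and noise
# scales are example values, not requirements of the constructor above.
import tensorflow as tf
from tf_agents.agents.ddpg import actor_network as ddpg_actor_network
from tf_agents.agents.ddpg import critic_network as ddpg_critic_network

def make_td3_agent(train_env):
    actor_net = ddpg_actor_network.ActorNetwork(
        train_env.observation_spec(),
        train_env.action_spec(),
        fc_layer_params=(256, 256))
    critic_net = ddpg_critic_network.CriticNetwork(
        (train_env.observation_spec(), train_env.action_spec()),
        joint_fc_layer_params=(256, 256))
    agent = Td3Agent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        critic_optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        exploration_noise_std=0.1,
        target_update_tau=0.005,
        actor_update_period=2,  # delayed actor updates, as in the TD3 paper
        gamma=0.99)
    agent.initialize()
    return agent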
def __init__(self,
             time_step_spec: ts.TimeStep,
             action_spec: types.NestedTensorSpec,
             critic_network: network.Network,
             actor_network: network.Network,
             actor_optimizer: types.Optimizer,
             critic_optimizer: types.Optimizer,
             alpha_optimizer: types.Optimizer,
             actor_loss_weight: types.Float = 1.0,
             critic_loss_weight: types.Float = 0.5,
             alpha_loss_weight: types.Float = 1.0,
             actor_policy_ctor: Callable[
                 ..., tf_policy.TFPolicy] = actor_policy.ActorPolicy,
             critic_network_2: Optional[network.Network] = None,
             target_critic_network: Optional[network.Network] = None,
             target_critic_network_2: Optional[network.Network] = None,
             target_update_tau: types.Float = 1.0,
             target_update_period: types.Int = 1,
             td_errors_loss_fn: types.LossFn = tf.math.squared_difference,
             gamma: types.Float = 1.0,
             reward_scale_factor: types.Float = 1.0,
             initial_log_alpha: types.Float = 0.0,
             use_log_alpha_in_alpha_loss: bool = True,
             target_entropy: Optional[types.Float] = None,
             gradient_clipping: Optional[types.Float] = None,
             debug_summaries: bool = False,
             summarize_grads_and_vars: bool = False,
             train_step_counter: Optional[tf.Variable] = None,
             name: Optional[Text] = None):
    """Creates a SAC Agent.

    Args:
        time_step_spec: A `TimeStep` spec of the expected time_steps.
        action_spec: A nest of BoundedTensorSpec representing the actions.
        critic_network: A function critic_network((observations, actions))
            that returns the q_values for each observation and action.
        actor_network: A function actor_network(observation, action_spec) that
            returns an action distribution.
        actor_optimizer: The optimizer to use for the actor network.
        critic_optimizer: The default optimizer to use for the critic network.
        alpha_optimizer: The default optimizer to use for the alpha variable.
        actor_loss_weight: The weight on actor loss.
        critic_loss_weight: The weight on critic loss.
        alpha_loss_weight: The weight on alpha loss.
        actor_policy_ctor: The policy class to use.
        critic_network_2: (Optional.) A `tf_agents.network.Network` to be used
            as the second critic network during Q learning. The weights from
            `critic_network` are copied if this is not provided.
        target_critic_network: (Optional.) A `tf_agents.network.Network` to be
            used as the target critic network during Q learning. Every
            `target_update_period` train steps, the weights from
            `critic_network` are copied (possibly with smoothing via
            `target_update_tau`) to `target_critic_network`. If
            `target_critic_network` is not provided, it is created by making a
            copy of `critic_network`, which initializes a new network with the
            same structure and its own layers and weights. Performing a
            `Network.copy` does not work when the network instance already has
            trainable parameters (e.g., has already been built, or when the
            network is sharing layers with another). In these cases, it is up
            to you to build a copy having weights that are not shared with the
            original `critic_network`, so that this can be used as a target
            network. If you provide a `target_critic_network` that shares any
            weights with `critic_network`, a warning will be logged but no
            exception is thrown.
        target_critic_network_2: (Optional.) Similar network as
            target_critic_network but for the critic_network_2. See
            documentation for target_critic_network. Will only be used if
            `critic_network_2` is also specified.
        target_update_tau: Factor for soft update of the target networks.
        target_update_period: Period for soft update of the target networks.
        td_errors_loss_fn: A function for computing the elementwise TD errors
            loss.
        gamma: A discount factor for future rewards.
        reward_scale_factor: Multiplicative scale for the reward.
        initial_log_alpha: Initial value for log_alpha.
        use_log_alpha_in_alpha_loss: A boolean, whether to use log_alpha or
            alpha in the alpha loss. Certain implementations of SAC use
            log_alpha, as log values are generally nicer to work with.
        target_entropy: The target average policy entropy, for updating alpha.
            The default value is the negative of the total number of actions.
        gradient_clipping: Norm length to clip gradients.
        debug_summaries: A bool to gather debug summaries.
        summarize_grads_and_vars: If True, gradient and network variable
            summaries will be written during training.
        train_step_counter: An optional counter to increment every time the
            train op is run. Defaults to the global_step.
        name: The name of this agent. All variables in this module will fall
            under that name. Defaults to the class name.
    """
    tf.Module.__init__(self, name=name)

    self._check_action_spec(action_spec)

    self._critic_network_1 = critic_network
    self._critic_network_1.create_variables(
        (time_step_spec.observation, action_spec))
    if target_critic_network:
        target_critic_network.create_variables(
            (time_step_spec.observation, action_spec))
    self._target_critic_network_1 = (
        common.maybe_copy_target_network_with_checks(self._critic_network_1,
                                                     target_critic_network,
                                                     'TargetCriticNetwork1'))

    if critic_network_2 is not None:
        self._critic_network_2 = critic_network_2
    else:
        self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
        # Do not use target_critic_network_2 if critic_network_2 is None.
        target_critic_network_2 = None
    self._critic_network_2.create_variables(
        (time_step_spec.observation, action_spec))
    if target_critic_network_2:
        target_critic_network_2.create_variables(
            (time_step_spec.observation, action_spec))
    self._target_critic_network_2 = (
        common.maybe_copy_target_network_with_checks(self._critic_network_2,
                                                     target_critic_network_2,
                                                     'TargetCriticNetwork2'))

    if actor_network:
        actor_network.create_variables(time_step_spec.observation)
    self._actor_network = actor_network

    policy = actor_policy_ctor(time_step_spec=time_step_spec,
                               action_spec=action_spec,
                               actor_network=self._actor_network,
                               training=False)
    self._train_policy = actor_policy_ctor(time_step_spec=time_step_spec,
                                           action_spec=action_spec,
                                           actor_network=self._actor_network,
                                           training=True)

    self._log_alpha = common.create_variable(
        'initial_log_alpha',
        initial_value=initial_log_alpha,
        dtype=tf.float32,
        trainable=True)

    if target_entropy is None:
        target_entropy = self._get_default_target_entropy(action_spec)

    self._use_log_alpha_in_alpha_loss = use_log_alpha_in_alpha_loss
    self._target_update_tau = target_update_tau
    self._target_update_period = target_update_period
    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer
    self._alpha_optimizer = alpha_optimizer
    self._actor_loss_weight = actor_loss_weight
    self._critic_loss_weight = critic_loss_weight
    self._alpha_loss_weight = alpha_loss_weight
    self._td_errors_loss_fn = td_errors_loss_fn
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._target_entropy = target_entropy
    self._gradient_clipping = gradient_clipping
    self._debug_summaries = debug_summaries
    self._summarize_grads_and_vars = summarize_grads_and_vars
    self._update_target = self._get_target_updater(
        tau=self._target_update_tau, period=self._target_update_period)

    train_sequence_length = 2 if not critic_network.state_spec else None

    super(SacAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy=policy,
        collect_policy=policy,
        train_sequence_length=train_sequence_length,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter,
        validate_args=False)

    self._as_transition = data_converter.AsTransition(
        self.data_context, squeeze_time_dim=(train_sequence_length == 2))
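# Illustrative SacAgent setup. A sketch, assuming the usual tf_agents network
# classes; `time_step_spec`, `observation_spec`, and `action_spec` are
# placeholders taken from an environment, and layer sizes and learning rates
# are example values.
import tensorflow as tf
from tf_agents.agents.ddpg import critic_network as ddpg_critic_network
from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.networks import actor_distribution_network

def make_sac_agent(time_step_spec, observation_spec, action_spec):
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec,
        action_spec,
        fc_layer_params=(256, 256),
        continuous_projection_net=(
            tanh_normal_projection_network.TanhNormalProjectionNetwork))
    critic_net = ddpg_critic_network.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=(256, 256))
    agent = SacAgent(
        time_step_spec,
        action_spec,
        critic_network=critic_net,
        actor_network=actor_net,
        actor_optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
        critic_optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
        alpha_optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
        target_update_tau=0.005,
        gamma=0.99)
    agent.initialize()
    return agent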
def __init__(self,
             observation_spec,
             action_spec,
             actor_network: DistributionNetwork,
             critic_network: Network,
             gamma=0.99,
             ou_stddev=0.2,
             ou_damping=0.15,
             actor_optimizer=None,
             critic_optimizer=None,
             target_update_tau=0.05,
             target_update_period=10,
             dqda_clipping=None,
             gradient_clipping=None,
             debug_summaries=False,
             name="SarsaAlgorithm"):
    """Create a SarsaAlgorithm.

    Args:
        observation_spec (nested TensorSpec): spec for observation.
        action_spec (nested BoundedTensorSpec): representing the actions.
        actor_network (Network|DistributionNetwork): The network will be
            called with call(observation, step_type). If it is a
            DistributionNetwork, an action will be sampled.
        critic_network (Network): The network will be called with
            call(observation, action, step_type).
        gamma (float): discount rate for reward.
        ou_stddev (float): Only used for DDPG. Standard deviation for the
            Ornstein-Uhlenbeck (OU) noise added in the default collect policy.
        ou_damping (float): Only used for DDPG. Damping factor for the OU
            noise added in the default collect policy.
        actor_optimizer (tf.optimizers.Optimizer): The optimizer for actor.
        critic_optimizer (tf.optimizers.Optimizer): The optimizer for critic.
        target_update_tau (float): Factor for soft update of the target
            networks.
        target_update_period (int): Period for soft update of the target
            networks.
        dqda_clipping (float): when computing the actor loss, clips the
            gradient dqda element-wise between
            [-dqda_clipping, dqda_clipping]. Does not perform clipping if
            dqda_clipping == 0.
        gradient_clipping (float): Norm length to clip gradients.
        debug_summaries (bool): True if debug summaries should be created.
        name (str): The name of this algorithm.
    """
    if isinstance(actor_network, DistributionNetwork):
        self._action_distribution_spec = actor_network.output_spec
    elif isinstance(actor_network, Network):
        self._action_distribution_spec = action_spec
    else:
        raise ValueError("Expect DistributionNetwork or Network for"
                         " `actor_network`, got %s" % type(actor_network))

    super().__init__(observation_spec,
                     action_spec,
                     predict_state_spec=SarsaState(
                         prev_observation=observation_spec,
                         prev_step_type=tf.TensorSpec((), tf.int32),
                         actor=actor_network.state_spec),
                     train_state_spec=SarsaState(
                         prev_observation=observation_spec,
                         prev_step_type=tf.TensorSpec((), tf.int32),
                         actor=actor_network.state_spec,
                         target_actor=actor_network.state_spec,
                         critic=critic_network.state_spec,
                         target_critic=critic_network.state_spec,
                     ),
                     optimizer=[actor_optimizer, critic_optimizer],
                     trainable_module_sets=[[actor_network],
                                            [critic_network]],
                     gradient_clipping=gradient_clipping,
                     debug_summaries=debug_summaries,
                     name=name)

    self._actor_network = actor_network
    self._critic_network = critic_network
    self._target_actor_network = actor_network.copy(
        name='target_actor_network')
    self._target_critic_network = critic_network.copy(
        name='target_critic_network')

    self._update_target = common.get_target_updater(
        models=[self._actor_network, self._critic_network],
        target_models=[
            self._target_actor_network, self._target_critic_network
        ],
        tau=target_update_tau,
        period=target_update_period)

    self._dqda_clipping = dqda_clipping
    self._gamma = gamma
    self._ou_process = create_ou_process(action_spec, ou_stddev, ou_damping)
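# Hypothetical SarsaAlgorithm construction; `obs_spec`, `action_spec`,
# `actor_net` (ideally a DistributionNetwork), and `critic_net` are
# placeholders for objects built elsewhere, and the remaining values echo the
# constructor defaults above.
def make_sarsa_algorithm(obs_spec, action_spec, actor_net, critic_net):
    return SarsaAlgorithm(
        observation_spec=obs_spec,
        action_spec=action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        gamma=0.99,
        actor_optimizer=tf.optimizers.Adam(learning_rate=1e-4),
        critic_optimizer=tf.optimizers.Adam(learning_rate=1e-3),
        target_update_tau=0.05,
        target_update_period=10)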
def __init__(self,
             time_step_spec: ts.TimeStep,
             action_spec: types.NestedTensorSpec,
             critic_network: network.Network,
             actor_network: network.Network,
             actor_optimizer: types.Optimizer,
             critic_optimizer: types.Optimizer,
             alpha_optimizer: types.Optimizer,
             actor_loss_weight: types.Float = 1.0,
             critic_loss_weight: types.Float = 0.5,
             alpha_loss_weight: types.Float = 1.0,
             actor_policy_ctor: Callable[
                 ..., tf_policy.TFPolicy] = actor_policy.ActorPolicy,
             critic_network_2: Optional[network.Network] = None,
             target_critic_network: Optional[network.Network] = None,
             target_critic_network_2: Optional[network.Network] = None,
             target_update_tau: types.Float = 1.0,
             target_update_period: types.Int = 1,
             td_errors_loss_fn: types.LossFn = tf.math.squared_difference,
             gamma: types.Float = 1.0,
             sigma: types.Float = 0.9,
             reward_scale_factor: types.Float = 1.0,
             initial_log_alpha: types.Float = 0.0,
             use_log_alpha_in_alpha_loss: bool = True,
             target_entropy: Optional[types.Float] = None,
             gradient_clipping: Optional[types.Float] = None,
             debug_summaries: bool = False,
             summarize_grads_and_vars: bool = False,
             train_step_counter: Optional[tf.Variable] = None,
             name: Optional[Text] = None):
    tf.Module.__init__(self, name=name)

    self._check_action_spec(action_spec)

    net_observation_spec = time_step_spec.observation
    critic_spec = (net_observation_spec, action_spec)

    self._critic_network_1 = critic_network

    if critic_network_2 is not None:
        self._critic_network_2 = critic_network_2
    else:
        self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
        # Do not use target_critic_network_2 if critic_network_2 is None.
        target_critic_network_2 = None

    # Wait until critic_network_2 has been copied from critic_network_1
    # before creating variables on both.
    self._critic_network_1.create_variables(critic_spec)
    self._critic_network_2.create_variables(critic_spec)

    if target_critic_network:
        target_critic_network.create_variables(critic_spec)
    self._target_critic_network_1 = (
        common.maybe_copy_target_network_with_checks(
            self._critic_network_1,
            target_critic_network,
            input_spec=critic_spec,
            name='TargetCriticNetwork1'))

    if target_critic_network_2:
        target_critic_network_2.create_variables(critic_spec)
    self._target_critic_network_2 = (
        common.maybe_copy_target_network_with_checks(
            self._critic_network_2,
            target_critic_network_2,
            input_spec=critic_spec,
            name='TargetCriticNetwork2'))

    if actor_network:
        actor_network.create_variables(net_observation_spec)
    self._actor_network = actor_network

    policy = actor_policy_ctor(time_step_spec=time_step_spec,
                               action_spec=action_spec,
                               actor_network=self._actor_network,
                               training=False)
    self._train_policy = actor_policy_ctor(time_step_spec=time_step_spec,
                                           action_spec=action_spec,
                                           actor_network=self._actor_network,
                                           training=True)

    self._log_alpha = common.create_variable(
        'initial_log_alpha',
        initial_value=initial_log_alpha,
        dtype=tf.float32,
        trainable=True)

    if target_entropy is None:
        target_entropy = self._get_default_target_entropy(action_spec)

    self._use_log_alpha_in_alpha_loss = use_log_alpha_in_alpha_loss
    self._target_update_tau = target_update_tau
    self._target_update_period = target_update_period
    self._actor_optimizer = actor_optimizer
    self._critic_optimizer = critic_optimizer
    self._alpha_optimizer = alpha_optimizer
    self._actor_loss_weight = actor_loss_weight
    self._critic_loss_weight = critic_loss_weight
    self._alpha_loss_weight = alpha_loss_weight
    self._td_errors_loss_fn = td_errors_loss_fn
    self._gamma = gamma
    self._reward_scale_factor = reward_scale_factor
    self._target_entropy = target_entropy
    self._gradient_clipping = gradient_clipping
    self._debug_summaries = debug_summaries
    self._summarize_grads_and_vars = summarize_grads_and_vars
    self._update_target = self._get_target_updater(
        tau=self._target_update_tau, period=self._target_update_period)
    self.sigma = sigma

    train_sequence_length = 2 if not critic_network.state_spec else None

    super(sac_agent.SacAgent, self).__init__(
        time_step_spec,
        action_spec,
        policy=policy,
        collect_policy=policy,
        train_sequence_length=train_sequence_length,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step_counter,
        validate_args=False)

    self._as_transition = data_converter.AsTransition(
        self.data_context, squeeze_time_dim=(train_sequence_length == 2))
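# The constructor above belongs to a subclass of sac_agent.SacAgent (its class
# name is not shown here) that adds a `sigma` hyperparameter; everything else
# mirrors the standard SAC setup. A hedged construction sketch, with
# `SigmaSacAgent` standing in for the real class name and `actor_net` /
# `critic_net` for pre-built networks:
def make_sigma_sac_agent(time_step_spec, action_spec, actor_net, critic_net):
    return SigmaSacAgent(  # placeholder name for the subclass defined above
        time_step_spec,
        action_spec,
        critic_network=critic_net,
        actor_network=actor_net,
        actor_optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
        critic_optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
        alpha_optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
        sigma=0.9,
        target_update_tau=0.005,
        gamma=0.99)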
def __init__(self,
             output_dim,
             noise_dim=32,
             input_tensor_spec=None,
             hidden_layers=(256, ),
             net: Network = None,
             net_moving_average_rate=None,
             entropy_regularization=0.,
             kernel_sharpness=2.,
             mi_weight=None,
             mi_estimator_cls=MIEstimator,
             optimizer: tf.optimizers.Optimizer = None,
             name="Generator"):
    """Create a Generator.

    Args:
        output_dim (int): dimension of output.
        noise_dim (int): dimension of noise.
        input_tensor_spec (nested TensorSpec): spec of inputs. If there are
            no inputs, this should be None.
        hidden_layers (tuple): sizes of hidden layers.
        net (Network): network for generating outputs from [noise, inputs] or
            noise (if inputs is None). If None, a default one with
            hidden_layers will be created.
        net_moving_average_rate (float): If provided, use a moving average
            version of net to do prediction. This has been shown to be
            effective for GAN training (arXiv:1907.02544, arXiv:1812.04948).
        entropy_regularization (float): weight of entropy regularization.
        kernel_sharpness (float): Used only for entropy_regularization > 0.
            We calculate the kernel in SVGD as:
            exp(-kernel_sharpness * reduce_mean((x-y)^2/width)), where width
            is the elementwise moving average of (x-y)^2.
        mi_weight (float): weight of the mutual information loss. If None,
            the mutual information estimator is not used.
        mi_estimator_cls (type): the class of mutual information estimator for
            maximizing the mutual information between [noise, inputs] and
            [outputs, inputs].
        optimizer (tf.optimizers.Optimizer): optimizer (optional).
        name (str): name of this generator.
    """
    super().__init__(train_state_spec=(), optimizer=optimizer, name=name)
    self._noise_dim = noise_dim
    self._entropy_regularization = entropy_regularization
    if entropy_regularization == 0:
        self._grad_func = self._ml_grad
    else:
        self._grad_func = self._stein_grad
        self._kernel_width_averager = AdaptiveAverager(
            tensor_spec=tf.TensorSpec(shape=(output_dim, )))
        self._kernel_sharpness = kernel_sharpness

    noise_spec = tf.TensorSpec(shape=[noise_dim])

    if net is None:
        net_input_spec = noise_spec
        if input_tensor_spec is not None:
            net_input_spec = [net_input_spec, input_tensor_spec]
        net = EncodingNetwork(name="Generator",
                              input_tensor_spec=net_input_spec,
                              fc_layer_params=hidden_layers,
                              last_layer_size=output_dim)

    self._mi_estimator = None
    self._input_tensor_spec = input_tensor_spec
    if mi_weight is not None:
        x_spec = noise_spec
        y_spec = tf.TensorSpec((output_dim, ))
        if input_tensor_spec is not None:
            x_spec = [x_spec, input_tensor_spec]
        self._mi_estimator = mi_estimator_cls(x_spec, y_spec, sampler='shift')
        self._mi_weight = mi_weight

    self._net = net
    self._predict_net = None
    self._net_moving_average_rate = net_moving_average_rate
    if net_moving_average_rate:
        self._predict_net = net.copy(name="Generator_average")
        tfa_common.soft_variables_update(self._net.variables,
                                         self._predict_net.variables,
                                         tau=1.0)
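# Minimal Generator construction sketch. Only the constructor defined above
# is called; the dimensions, regularization weight, and optimizer are
# illustrative values, not defaults mandated by the source.
def make_generator(output_dim=16, noise_dim=32):
    return Generator(
        output_dim=output_dim,
        noise_dim=noise_dim,
        hidden_layers=(256, ),
        entropy_regularization=0.1,  # > 0 selects the SVGD (Stein) gradient
        optimizer=tf.optimizers.Adam(learning_rate=1e-3))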