def __init__(
    self,
    time_step_spec: ts.TimeStep,
    action_spec: types.NestedTensorSpec,
    q_network: network.Network,
    optimizer: types.Optimizer,
    observation_and_action_constraint_splitter: Optional[
        types.Splitter] = None,
    epsilon_greedy: types.Float = 0.1,
    n_step_update: int = 1,
    boltzmann_temperature: Optional[types.Int] = None,
    emit_log_probability: bool = False,
    # Params for target network updates
    target_q_network: Optional[network.Network] = None,
    target_update_tau: types.Float = 1.0,
    target_update_period: int = 1,
    # Params for training.
    td_errors_loss_fn: Optional[types.LossFn] = None,
    gamma: types.Float = 1.0,
    reward_scale_factor: types.Float = 1.0,
    gradient_clipping: Optional[types.Float] = None,
    # Params for debugging
    debug_summaries: bool = False,
    summarize_grads_and_vars: bool = False,
    train_step_counter: Optional[tf.Variable] = None,
    name: Optional[Text] = None,
    entropy_tau: types.Float = 0.9,
    alpha: types.Float = 0.3):
  tf.Module.__init__(self, name=name)
  self._check_action_spec(action_spec)

  if epsilon_greedy is not None and boltzmann_temperature is not None:
    raise ValueError(
        'Configured both epsilon_greedy value {} and temperature {}, '
        'however only one of them can be used for exploration.'.format(
            epsilon_greedy, boltzmann_temperature))

  self._observation_and_action_constraint_splitter = (
      observation_and_action_constraint_splitter)
  self._q_network = q_network
  net_observation_spec = time_step_spec.observation
  if observation_and_action_constraint_splitter:
    net_observation_spec, _ = observation_and_action_constraint_splitter(
        net_observation_spec)
  q_network.create_variables(net_observation_spec)
  if target_q_network:
    target_q_network.create_variables(net_observation_spec)
  self._target_q_network = common.maybe_copy_target_network_with_checks(
      self._q_network,
      target_q_network,
      input_spec=net_observation_spec,
      name='TargetQNetwork')

  self._check_network_output(self._q_network, 'q_network')
  self._check_network_output(self._target_q_network, 'target_q_network')

  self._epsilon_greedy = epsilon_greedy
  self._n_step_update = n_step_update
  self._boltzmann_temperature = boltzmann_temperature
  self._optimizer = optimizer
  self._td_errors_loss_fn = (
      td_errors_loss_fn or common.element_wise_huber_loss)
  self._gamma = gamma
  self._reward_scale_factor = reward_scale_factor
  self._gradient_clipping = gradient_clipping
  self._update_target = self._get_target_updater(target_update_tau,
                                                 target_update_period)
  self.entropy_tau = entropy_tau
  self.alpha = alpha

  policy, collect_policy = self._setup_policy(time_step_spec, action_spec,
                                              boltzmann_temperature,
                                              emit_log_probability)

  if q_network.state_spec and n_step_update != 1:
    raise NotImplementedError(
        'DqnAgent does not currently support n-step updates with stateful '
        'networks (i.e., RNNs), but n_step_update = {}'.format(
            n_step_update))

  train_sequence_length = (
      n_step_update + 1 if not q_network.state_spec else None)

  super(dqn_agent.DqnAgent, self).__init__(
      time_step_spec,
      action_spec,
      policy,
      collect_policy,
      train_sequence_length=train_sequence_length,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter,
      validate_args=False,
  )

  if q_network.state_spec:
    # AsNStepTransition does not support emitting [B, T, ...] tensors,
    # which we need for DQN-RNN.
    self._as_transition = data_converter.AsTransition(
        self.data_context, squeeze_time_dim=False)
  else:
    # This reduces the n-step return and removes the extra time dimension,
    # allowing the rest of the computations to be independent of the
    # n-step parameter.
    self._as_transition = data_converter.AsNStepTransition(
        self.data_context, gamma=gamma, n=n_step_update)
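# --- Illustrative sketch (not part of the agent above) ---
# What `_get_target_updater` conventionally builds in TF-Agents-style
# agents: every `target_update_period` train steps, each target variable
# moves toward its online counterpart by a factor `tau`
# (target <- tau * online + (1 - tau) * target). The helpers below exist
# in `tf_agents.utils.common`, but treat this as an assumption about the
# agent's internals, not a copy of its exact code.
from tf_agents.utils import common

def make_target_updater(q_network, target_q_network, tau, period):
  def update():
    # Polyak-average the online weights into the target network.
    return common.soft_variables_update(
        q_network.variables, target_q_network.variables, tau=tau)
  # Wrap so the update only fires once every `period` calls.
  return common.Periodically(update, period, 'update_target')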
def __init__(self,
             time_step_spec,
             action_spec,
             actor_network,
             critic_network,
             actor_optimizer,
             critic_optimizer,
             exploration_noise_std=0.1,
             critic_network_2=None,
             target_actor_network=None,
             target_critic_network=None,
             target_critic_network_2=None,
             target_update_tau=1.0,
             target_update_period=1,
             actor_update_period=1,
             dqda_clipping=None,
             td_errors_loss_fn=None,
             gamma=1.0,
             reward_scale_factor=1.0,
             target_policy_noise=0.2,
             target_policy_noise_clip=0.5,
             gradient_clipping=None,
             debug_summaries=False,
             summarize_grads_and_vars=False,
             train_step_counter=None,
             name=None):
  """Creates a Td3 Agent.

  Args:
    time_step_spec: A `TimeStep` spec of the expected time_steps.
    action_spec: A nest of BoundedTensorSpec representing the actions.
    actor_network: A tf_agents.network.Network to be used by the agent. The
      network will be called with call(observation, step_type).
    critic_network: A tf_agents.network.Network to be used by the agent. The
      network will be called with call(observation, action, step_type).
    actor_optimizer: The default optimizer to use for the actor network.
    critic_optimizer: The default optimizer to use for the critic network.
    exploration_noise_std: Scale factor on exploration policy noise.
    critic_network_2: (Optional.) A `tf_agents.network.Network` to be used
      as the second critic network during Q learning. The weights from
      `critic_network` are copied if this is not provided.
    target_actor_network: (Optional.) A `tf_agents.network.Network` to be
      used as the target actor network during Q learning. Every
      `target_update_period` train steps, the weights from `actor_network`
      are copied (possibly with smoothing via `target_update_tau`) to
      `target_actor_network`. If `target_actor_network` is not provided, it
      is created by making a copy of `actor_network`, which initializes a
      new network with the same structure and its own layers and weights.
      Performing a `Network.copy` does not work when the network instance
      already has trainable parameters (e.g., has already been built, or
      when the network is sharing layers with another). In these cases, it
      is up to you to build a copy having weights that are not shared with
      the original `actor_network`, so that this can be used as a target
      network. If you provide a `target_actor_network` that shares any
      weights with `actor_network`, a warning will be logged but no
      exception is thrown.
    target_critic_network: (Optional.) Similar network as
      target_actor_network but for the critic_network. See documentation for
      target_actor_network.
    target_critic_network_2: (Optional.) Similar network as
      target_actor_network but for the critic_network_2. See documentation
      for target_actor_network. Will only be used if 'critic_network_2' is
      also specified.
    target_update_tau: Factor for soft update of the target networks.
    target_update_period: Period for soft update of the target networks.
    actor_update_period: Period for the optimization step on actor network.
    dqda_clipping: A scalar or float clips the gradient dqda element-wise
      between [-dqda_clipping, dqda_clipping]. Default is None representing
      no clipping.
    td_errors_loss_fn: A function for computing the TD errors loss. If None,
      a default value of elementwise huber_loss is used.
    gamma: A discount factor for future rewards.
    reward_scale_factor: Multiplicative scale for the reward.
    target_policy_noise: Scale factor on target action noise.
    target_policy_noise_clip: Value to clip noise.
    gradient_clipping: Norm length to clip gradients.
    debug_summaries: A bool to gather debug summaries.
    summarize_grads_and_vars: If True, gradient and network variable
      summaries will be written during training.
    train_step_counter: An optional counter to increment every time the
      train op is run. Defaults to the global_step.
    name: The name of this agent. All variables in this module will fall
      under that name. Defaults to the class name.
  """
  tf.Module.__init__(self, name=name)
  self._actor_network = actor_network
  actor_network.create_variables()
  if target_actor_network:
    target_actor_network.create_variables()
  self._target_actor_network = common.maybe_copy_target_network_with_checks(
      self._actor_network, target_actor_network, 'TargetActorNetwork')

  self._critic_network_1 = critic_network
  critic_network.create_variables()
  if target_critic_network:
    target_critic_network.create_variables()
  self._target_critic_network_1 = (
      common.maybe_copy_target_network_with_checks(self._critic_network_1,
                                                   target_critic_network,
                                                   'TargetCriticNetwork1'))

  if critic_network_2 is not None:
    self._critic_network_2 = critic_network_2
  else:
    self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
    # Do not use target_critic_network_2 if critic_network_2 is None.
    target_critic_network_2 = None
  self._critic_network_2.create_variables()
  if target_critic_network_2:
    target_critic_network_2.create_variables()
  self._target_critic_network_2 = (
      common.maybe_copy_target_network_with_checks(self._critic_network_2,
                                                   target_critic_network_2,
                                                   'TargetCriticNetwork2'))

  self._actor_optimizer = actor_optimizer
  self._critic_optimizer = critic_optimizer
  self._exploration_noise_std = exploration_noise_std
  self._target_update_tau = target_update_tau
  self._target_update_period = target_update_period
  self._actor_update_period = actor_update_period
  self._dqda_clipping = dqda_clipping
  self._td_errors_loss_fn = (
      td_errors_loss_fn or common.element_wise_huber_loss)
  self._gamma = gamma
  self._reward_scale_factor = reward_scale_factor
  self._target_policy_noise = target_policy_noise
  self._target_policy_noise_clip = target_policy_noise_clip
  self._gradient_clipping = gradient_clipping

  self._update_target = self._get_target_updater(target_update_tau,
                                                 target_update_period)

  policy = actor_policy.ActorPolicy(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      clip=True)
  collect_policy = actor_policy.ActorPolicy(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      clip=False)
  collect_policy = gaussian_policy.GaussianPolicy(
      collect_policy,
      scale=self._exploration_noise_std,
      clip=True)

  super(Td3Agent, self).__init__(
      time_step_spec,
      action_spec,
      policy,
      collect_policy,
      train_sequence_length=2
      if not self._actor_network.state_spec else None,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter)
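# --- Illustrative sketch (not part of the agent above) ---
# TD3's target policy smoothing, which `target_policy_noise` and
# `target_policy_noise_clip` above parameterize: clipped Gaussian noise is
# added to the target actor's action before the target critics evaluate
# it. A generic sketch under assumed shapes, not the agent's actual loss
# code.
import tensorflow as tf

def smoothed_target_actions(target_actions, noise_std=0.2, noise_clip=0.5):
  # Sample Gaussian noise, clip it, and perturb the target actions.
  noise = tf.random.normal(tf.shape(target_actions), stddev=noise_std)
  noise = tf.clip_by_value(noise, -noise_clip, noise_clip)
  return target_actions + noise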
def __init__(self,
             time_step_spec,
             action_spec,
             critic_network,
             actor_network,
             actor_optimizer,
             critic_optimizer,
             alpha_optimizer,
             critic_network_no_entropy=None,
             critic_no_entropy_optimizer=None,
             actor_loss_weight=1.0,
             critic_loss_weight=0.5,
             alpha_loss_weight=1.0,
             actor_policy_ctor=actor_policy.ActorPolicy,
             critic_network_2=None,
             target_critic_network=None,
             target_critic_network_2=None,
             target_update_tau=1.0,
             target_update_period=1,
             td_errors_loss_fn=tf.math.squared_difference,
             gamma=1.0,
             reward_scale_factor=1.0,
             initial_log_alpha=0.0,
             use_log_alpha_in_alpha_loss=True,
             target_entropy=None,
             gradient_clipping=None,
             debug_summaries=False,
             summarize_grads_and_vars=False,
             train_step_counter=None,
             name=None):
  """Creates a SAC Agent.

  Args:
    time_step_spec: A `TimeStep` spec of the expected time_steps.
    action_spec: A nest of BoundedTensorSpec representing the actions.
    critic_network: A function critic_network((observations, actions)) that
      returns the q_values for each observation and action.
    actor_network: A function actor_network(observation, action_spec) that
      returns an action distribution.
    actor_optimizer: The optimizer to use for the actor network.
    critic_optimizer: The default optimizer to use for the critic network.
    alpha_optimizer: The default optimizer to use for the alpha variable.
    critic_network_no_entropy: (Optional.) A critic network used to
      estimate Q-values without the entropy term included. If provided, a
      second copy and matching target networks are created alongside it.
    critic_no_entropy_optimizer: (Optional.) The optimizer to use for the
      entropy-free critic network.
    actor_loss_weight: The weight on actor loss.
    critic_loss_weight: The weight on critic loss.
    alpha_loss_weight: The weight on alpha loss.
    actor_policy_ctor: The policy class to use.
    critic_network_2: (Optional.) A `tf_agents.network.Network` to be used
      as the second critic network during Q learning. The weights from
      `critic_network` are copied if this is not provided.
    target_critic_network: (Optional.) A `tf_agents.network.Network` to be
      used as the target critic network during Q learning. Every
      `target_update_period` train steps, the weights from `critic_network`
      are copied (possibly with smoothing via `target_update_tau`) to
      `target_critic_network`. If `target_critic_network` is not provided,
      it is created by making a copy of `critic_network`, which initializes
      a new network with the same structure and its own layers and weights.
      Performing a `Network.copy` does not work when the network instance
      already has trainable parameters (e.g., has already been built, or
      when the network is sharing layers with another). In these cases, it
      is up to you to build a copy having weights that are not shared with
      the original `critic_network`, so that this can be used as a target
      network. If you provide a `target_critic_network` that shares any
      weights with `critic_network`, a warning will be logged but no
      exception is thrown.
    target_critic_network_2: (Optional.) Similar network as
      target_critic_network but for the critic_network_2. See documentation
      for target_critic_network. Will only be used if 'critic_network_2' is
      also specified.
    target_update_tau: Factor for soft update of the target networks.
    target_update_period: Period for soft update of the target networks.
    td_errors_loss_fn: A function for computing the elementwise TD errors
      loss.
    gamma: A discount factor for future rewards.
    reward_scale_factor: Multiplicative scale for the reward.
    initial_log_alpha: Initial value for log_alpha.
    use_log_alpha_in_alpha_loss: A boolean, whether to use log_alpha or
      alpha in the alpha loss. Certain implementations of SAC use log_alpha
      as log values are generally nicer to work with.
    target_entropy: The target average policy entropy, for updating alpha.
      The default value is negative of the total number of actions.
    gradient_clipping: Norm length to clip gradients.
    debug_summaries: A bool to gather debug summaries.
    summarize_grads_and_vars: If True, gradient and network variable
      summaries will be written during training.
    train_step_counter: An optional counter to increment every time the
      train op is run. Defaults to the global_step.
    name: The name of this agent. All variables in this module will fall
      under that name. Defaults to the class name.
  """
  tf.Module.__init__(self, name=name)
  self._check_action_spec(action_spec)

  self._critic_network_1 = critic_network
  self._critic_network_1.create_variables()
  if target_critic_network:
    target_critic_network.create_variables()
  self._target_critic_network_1 = (
      common.maybe_copy_target_network_with_checks(self._critic_network_1,
                                                   target_critic_network,
                                                   'TargetCriticNetwork1'))

  # For estimating critics without the entropy term included.
  self._critic_network_no_entropy_1 = critic_network_no_entropy
  if critic_network_no_entropy is not None:
    self._critic_network_no_entropy_1.create_variables()
    self._target_critic_network_no_entropy_1 = (
        common.maybe_copy_target_network_with_checks(
            self._critic_network_no_entropy_1, None,
            'TargetCriticNetworkNoEntropy1'))
    # Network 2
    self._critic_network_no_entropy_2 = (
        self._critic_network_no_entropy_1.copy(
            name='CriticNetworkNoEntropy2'))
    self._critic_network_no_entropy_2.create_variables()
    self._target_critic_network_no_entropy_2 = (
        common.maybe_copy_target_network_with_checks(
            self._critic_network_no_entropy_2, None,
            'TargetCriticNetworkNoEntropy2'))

  if critic_network_2 is not None:
    self._critic_network_2 = critic_network_2
  else:
    self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
    # Do not use target_critic_network_2 if critic_network_2 is None.
    target_critic_network_2 = None
  self._critic_network_2.create_variables()
  if target_critic_network_2:
    target_critic_network_2.create_variables()
  self._target_critic_network_2 = (
      common.maybe_copy_target_network_with_checks(self._critic_network_2,
                                                   target_critic_network_2,
                                                   'TargetCriticNetwork2'))

  if actor_network:
    actor_network.create_variables()
  self._actor_network = actor_network

  policy = actor_policy_ctor(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      training=False)
  self._train_policy = actor_policy_ctor(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      training=True)

  self._log_alpha = common.create_variable(
      'initial_log_alpha',
      initial_value=initial_log_alpha,
      dtype=tf.float32,
      trainable=True)

  if target_entropy is None:
    target_entropy = self._get_default_target_entropy(action_spec)

  self._use_log_alpha_in_alpha_loss = use_log_alpha_in_alpha_loss
  self._target_update_tau = target_update_tau
  self._target_update_period = target_update_period
  self._actor_optimizer = actor_optimizer
  self._critic_optimizer = critic_optimizer
  self._critic_no_entropy_optimizer = critic_no_entropy_optimizer
  self._alpha_optimizer = alpha_optimizer
  self._actor_loss_weight = actor_loss_weight
  self._critic_loss_weight = critic_loss_weight
  self._alpha_loss_weight = alpha_loss_weight
  self._td_errors_loss_fn = td_errors_loss_fn
  self._gamma = gamma
  self._reward_scale_factor = reward_scale_factor
  self._target_entropy = target_entropy
  self._gradient_clipping = gradient_clipping
  self._debug_summaries = debug_summaries
  self._summarize_grads_and_vars = summarize_grads_and_vars
  self._update_target = self._get_target_updater(
      tau=self._target_update_tau, period=self._target_update_period)

  train_sequence_length = 2 if not critic_network.state_spec else None

  super(SacAgent, self).__init__(
      time_step_spec,
      action_spec,
      policy=policy,
      collect_policy=policy,
      train_sequence_length=train_sequence_length,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter)
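# --- Illustrative sketch (not part of the agent above) ---
# A common default for `target_entropy`, matching the docstring's
# "negative of the total number of actions": sum the action dimensions
# across the (possibly nested) action_spec and negate. Assumes
# `_get_default_target_entropy` does something equivalent; treat this as a
# sketch, not the exact implementation.
import numpy as np
import tensorflow as tf

def default_target_entropy(action_spec):
  flat_action_spec = tf.nest.flatten(action_spec)
  # e.g., a single 6-dimensional continuous action yields -6.0.
  return -float(
      sum(np.prod(s.shape.as_list()) for s in flat_action_spec))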
def __init__(self,
             time_step_spec,
             action_spec,
             critic_network,
             actor_network,
             actor_optimizer,
             critic_optimizer,
             actor_loss_weight=1.0,
             critic_loss_weight=0.5,
             actor_policy_ctor=actor_policy.ActorPolicy,
             critic_network_2=None,
             target_critic_network=None,
             target_critic_network_2=None,
             target_update_tau=1.0,
             target_update_period=1,
             td_errors_loss_fn=tf.math.squared_difference,
             gamma=1.0,
             reward_scale_factor=1.0,
             gradient_clipping=None,
             debug_summaries=False,
             summarize_grads_and_vars=False,
             train_step_counter=None,
             name=None,
             n_step=None,
             use_behavior_policy=False):
  """Creates an RCE Agent.

  Args:
    time_step_spec: A `TimeStep` spec of the expected time_steps.
    action_spec: A nest of BoundedTensorSpec representing the actions.
    critic_network: A function critic_network((observations, actions)) that
      returns the q_values for each observation and action.
    actor_network: A function actor_network(observation, action_spec) that
      returns an action distribution.
    actor_optimizer: The optimizer to use for the actor network.
    critic_optimizer: The default optimizer to use for the critic network.
    actor_loss_weight: The weight on actor loss.
    critic_loss_weight: The weight on critic loss.
    actor_policy_ctor: The policy class to use.
    critic_network_2: (Optional.) A `tf_agents.network.Network` to be used
      as the second critic network during Q learning. The weights from
      `critic_network` are copied if this is not provided.
    target_critic_network: (Optional.) A `tf_agents.network.Network` to be
      used as the target critic network during Q learning. Every
      `target_update_period` train steps, the weights from `critic_network`
      are copied (possibly with smoothing via `target_update_tau`) to
      `target_critic_network`. If `target_critic_network` is not provided,
      it is created by making a copy of `critic_network`, which initializes
      a new network with the same structure and its own layers and weights.
      Performing a `Network.copy` does not work when the network instance
      already has trainable parameters (e.g., has already been built, or
      when the network is sharing layers with another). In these cases, it
      is up to you to build a copy having weights that are not shared with
      the original `critic_network`, so that this can be used as a target
      network. If you provide a `target_critic_network` that shares any
      weights with `critic_network`, a warning will be logged but no
      exception is thrown.
    target_critic_network_2: (Optional.) Similar network as
      target_critic_network but for the critic_network_2. See documentation
      for target_critic_network. Will only be used if 'critic_network_2' is
      also specified.
    target_update_tau: Factor for soft update of the target networks.
    target_update_period: Period for soft update of the target networks.
    td_errors_loss_fn: A function for computing the elementwise TD errors
      loss.
    gamma: A discount factor for future rewards.
    reward_scale_factor: Multiplicative scale for the reward.
    gradient_clipping: Norm length to clip gradients.
    debug_summaries: A bool to gather debug summaries.
    summarize_grads_and_vars: If True, gradient and network variable
      summaries will be written during training.
    train_step_counter: An optional counter to increment every time the
      train op is run. Defaults to the global_step.
    name: The name of this agent. All variables in this module will fall
      under that name. Defaults to the class name.
    n_step: An integer specifying whether to use n-step returns.
      Empirically, a value of 10 works well for most tasks. Use None to
      disable n-step returns.
    use_behavior_policy: A boolean indicating how to sample actions for the
      success states. When use_behavior_policy=True, we use the historical
      average policy; otherwise, we use the current policy.
  """
  tf.Module.__init__(self, name=name)
  self._check_action_spec(action_spec)

  self._critic_network_1 = critic_network
  self._critic_network_1.create_variables(
      (time_step_spec.observation, action_spec))
  if target_critic_network:
    target_critic_network.create_variables(
        (time_step_spec.observation, action_spec))
    self._target_critic_network_1 = target_critic_network
  else:
    self._target_critic_network_1 = (
        common.maybe_copy_target_network_with_checks(
            self._critic_network_1, None, 'TargetCriticNetwork1'))

  if critic_network_2 is not None:
    self._critic_network_2 = critic_network_2
  else:
    self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
    # Do not use target_critic_network_2 if critic_network_2 is None.
    target_critic_network_2 = None
  self._critic_network_2.create_variables(
      (time_step_spec.observation, action_spec))
  if target_critic_network_2:
    target_critic_network_2.create_variables(
        (time_step_spec.observation, action_spec))
    self._target_critic_network_2 = target_critic_network_2
  else:
    self._target_critic_network_2 = (
        common.maybe_copy_target_network_with_checks(
            self._critic_network_2, None, 'TargetCriticNetwork2'))

  if actor_network:
    actor_network.create_variables(time_step_spec.observation)
  self._actor_network = actor_network

  self._use_behavior_policy = use_behavior_policy
  if use_behavior_policy:
    self._behavior_actor_network = actor_network.copy(
        name='BehaviorActorNetwork')
    self._behavior_policy = actor_policy_ctor(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=self._behavior_actor_network,
        training=True)

  policy = actor_policy_ctor(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      training=False)
  self._train_policy = actor_policy_ctor(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      training=True)

  self._target_update_tau = target_update_tau
  self._target_update_period = target_update_period
  self._actor_optimizer = actor_optimizer
  self._critic_optimizer = critic_optimizer
  self._actor_loss_weight = actor_loss_weight
  self._critic_loss_weight = critic_loss_weight
  self._td_errors_loss_fn = td_errors_loss_fn
  self._gamma = gamma
  self._reward_scale_factor = reward_scale_factor
  self._gradient_clipping = gradient_clipping
  self._debug_summaries = debug_summaries
  self._summarize_grads_and_vars = summarize_grads_and_vars
  self._update_target = self._get_target_updater(
      tau=self._target_update_tau, period=self._target_update_period)
  self._n_step = n_step

  train_sequence_length = 2 if not critic_network.state_spec else None

  super(RceAgent, self).__init__(
      time_step_spec,
      action_spec,
      policy=policy,
      collect_policy=policy,
      train_sequence_length=train_sequence_length,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter,
      validate_args=False)

  self._as_transition = data_converter.AsTransition(
      self.data_context, squeeze_time_dim=(train_sequence_length == 2))
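# --- Illustrative sketch (not part of the agent above) ---
# The n-step discounted return that `n_step` and `gamma` above
# parameterize: G = sum_{k=0}^{n-1} gamma^k * r_{t+k} + gamma^n * bootstrap.
# A generic sketch under assumed shapes, not the agent's loss code.
import tensorflow as tf

def n_step_return(rewards, bootstrap_value, gamma):
  # rewards: [n] tensor of per-step rewards; bootstrap_value: scalar
  # estimate of the value at step t + n.
  n = tf.shape(rewards)[0]
  discounts = tf.pow(gamma, tf.cast(tf.range(n), tf.float32))
  return (tf.reduce_sum(discounts * rewards) +
          tf.pow(gamma, tf.cast(n, tf.float32)) * bootstrap_value)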
def __init__(
    self,
    time_step_spec: ts.TimeStep,
    action_spec: types.NestedTensorSpec,
    q_network: network.Network,
    optimizer: types.Optimizer,
    observation_and_action_constraint_splitter: Optional[
        types.Splitter] = None,
    epsilon_greedy: types.Float = 0.1,
    n_step_update: int = 1,
    boltzmann_temperature: Optional[types.Int] = None,
    emit_log_probability: bool = False,
    # Params for target network updates
    target_q_network: Optional[network.Network] = None,
    target_update_tau: types.Float = 1.0,
    target_update_period: int = 1,
    # Params for training.
    td_errors_loss_fn: Optional[types.LossFn] = None,
    gamma: types.Float = 1.0,
    reward_scale_factor: types.Float = 1.0,
    gradient_clipping: Optional[types.Float] = None,
    # Params for debugging
    debug_summaries: bool = False,
    summarize_grads_and_vars: bool = False,
    train_step_counter: Optional[tf.Variable] = None,
    name: Optional[Text] = None):
  """Creates a DQN Agent.

  Args:
    time_step_spec: A `TimeStep` spec of the expected time_steps.
    action_spec: A nest of BoundedTensorSpec representing the actions.
    q_network: A `tf_agents.network.Network` to be used by the agent. The
      network will be called with `call(observation, step_type)` and should
      emit logits over the action space.
    optimizer: The optimizer to use for training.
    observation_and_action_constraint_splitter: A function used to process
      observations with action constraints. These constraints can indicate,
      for example, a mask of valid/invalid actions for a given state of the
      environment. The function takes in a full observation and returns a
      tuple consisting of 1) the part of the observation intended as input
      to the network and 2) the constraint. An example
      `observation_and_action_constraint_splitter` could be as simple as:
      ```
      def observation_and_action_constraint_splitter(observation):
        return observation['network_input'], observation['constraint']
      ```
      *Note*: when using `observation_and_action_constraint_splitter`, make
      sure the provided `q_network` is compatible with the network-specific
      half of the output of the
      `observation_and_action_constraint_splitter`. In particular,
      `observation_and_action_constraint_splitter` will be called on the
      observation before passing to the network. If
      `observation_and_action_constraint_splitter` is None, action
      constraints are not applied.
    epsilon_greedy: probability of choosing a random action in the default
      epsilon-greedy collect policy (used only if a wrapper is not provided
      to the collect_policy method).
    n_step_update: The number of steps to consider when computing TD error
      and TD loss. Defaults to single-step updates. Note that this requires
      the user to call train on Trajectory objects with a time dimension of
      `n_step_update + 1`. However, note that we do not yet support
      `n_step_update > 1` in the case of RNNs (i.e., non-empty
      `q_network.state_spec`).
    boltzmann_temperature: Temperature value to use for Boltzmann sampling
      of the actions during data collection. The closer to 0.0, the higher
      the probability of choosing the best action.
    emit_log_probability: Whether policies emit log probabilities or not.
    target_q_network: (Optional.) A `tf_agents.network.Network` to be used
      as the target network during Q learning. Every `target_update_period`
      train steps, the weights from `q_network` are copied (possibly with
      smoothing via `target_update_tau`) to `target_q_network`. If
      `target_q_network` is not provided, it is created by making a copy of
      `q_network`, which initializes a new network with the same structure
      and its own layers and weights. Network copying is performed via the
      `Network.copy` superclass method, and may inadvertently lead to the
      resulting network sharing weights with the original. This can happen
      if, for example, the original network accepted a pre-built Keras
      layer in its `__init__`, or accepted a Keras layer that wasn't built,
      but neglected to create a new copy. In these cases, it is up to you
      to provide a target Network having weights that are not shared with
      the original `q_network`. If you provide a `target_q_network` that
      shares any weights with `q_network`, a warning will be logged but no
      exception is thrown. Note: shallow copies of Keras layers may be
      built via the code:
      ```python
      new_layer = type(layer).from_config(layer.get_config())
      ```
    target_update_tau: Factor for soft update of the target networks.
    target_update_period: Period for soft update of the target networks.
    td_errors_loss_fn: A function for computing the TD errors loss. If None,
      a default value of element_wise_huber_loss is used. This function
      takes as input the target and the estimated Q values and returns the
      loss for each element of the batch.
    gamma: A discount factor for future rewards.
    reward_scale_factor: Multiplicative scale for the reward.
    gradient_clipping: Norm length to clip gradients.
    debug_summaries: A bool to gather debug summaries.
    summarize_grads_and_vars: If True, gradient and network variable
      summaries will be written during training.
    train_step_counter: An optional counter to increment every time the
      train op is run. Defaults to the global_step.
    name: The name of this agent. All variables in this module will fall
      under that name. Defaults to the class name.

  Raises:
    ValueError: If `action_spec` contains more than one action or action
      spec minimum is not equal to 0.
    ValueError: If the q networks do not emit floating point outputs with
      inner shape matching `action_spec`.
    NotImplementedError: If `q_network` has non-empty `state_spec` (i.e., an
      RNN is provided) and `n_step_update > 1`.
  """
  tf.Module.__init__(self, name=name)
  self._check_action_spec(action_spec)

  if epsilon_greedy is not None and boltzmann_temperature is not None:
    raise ValueError(
        'Configured both epsilon_greedy value {} and temperature {}, '
        'however only one of them can be used for exploration.'.format(
            epsilon_greedy, boltzmann_temperature))

  self._observation_and_action_constraint_splitter = (
      observation_and_action_constraint_splitter)
  self._q_network = q_network
  net_observation_spec = time_step_spec.observation
  if observation_and_action_constraint_splitter:
    net_observation_spec, _ = observation_and_action_constraint_splitter(
        net_observation_spec)
  q_network.create_variables(net_observation_spec)
  if target_q_network:
    target_q_network.create_variables(net_observation_spec)
  self._target_q_network = common.maybe_copy_target_network_with_checks(
      self._q_network,
      target_q_network,
      input_spec=net_observation_spec,
      name='TargetQNetwork')

  self._check_network_output(self._q_network, 'q_network')
  self._check_network_output(self._target_q_network, 'target_q_network')

  self._epsilon_greedy = epsilon_greedy
  self._n_step_update = n_step_update
  self._boltzmann_temperature = boltzmann_temperature
  self._optimizer = optimizer
  self._td_errors_loss_fn = (
      td_errors_loss_fn or common.element_wise_huber_loss)
  self._gamma = gamma
  self._reward_scale_factor = reward_scale_factor
  self._gradient_clipping = gradient_clipping
  self._update_target = self._get_target_updater(
      target_update_tau, target_update_period)

  policy, collect_policy = self._setup_policy(time_step_spec, action_spec,
                                              boltzmann_temperature,
                                              emit_log_probability)

  if q_network.state_spec and n_step_update != 1:
    raise NotImplementedError(
        'DqnAgent does not currently support n-step updates with stateful '
        'networks (i.e., RNNs), but n_step_update = {}'.format(
            n_step_update))

  train_sequence_length = (
      n_step_update + 1 if not q_network.state_spec else None)

  super(DqnAgent, self).__init__(
      time_step_spec,
      action_spec,
      policy,
      collect_policy,
      train_sequence_length=train_sequence_length,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter,
      validate_args=False,
  )

  if q_network.state_spec:
    # AsNStepTransition does not support emitting [B, T, ...] tensors,
    # which we need for DQN-RNN.
    self._as_transition = data_converter.AsTransition(
        self.data_context, squeeze_time_dim=False)
  else:
    # This reduces the n-step return and removes the extra time dimension,
    # allowing the rest of the computations to be independent of the
    # n-step parameter.
    self._as_transition = data_converter.AsNStepTransition(
        self.data_context, gamma=gamma, n=n_step_update)
def __init__(
    self,
    time_step_spec,
    action_spec,
    q_network,
    optimizer,
    actions_sampler,
    epsilon_greedy=0.1,
    n_step_update=1,
    emit_log_probability=False,
    in_graph_bellman_update=True,
    # Params for cem
    init_mean_cem=None,
    init_var_cem=None,
    num_samples_cem=32,
    num_elites_cem=4,
    num_iter_cem=3,
    # Params for target network updates
    target_q_network=None,
    target_update_tau=1.0,
    target_update_period=1,
    enable_td3=True,
    target_q_network_delayed=None,
    target_q_network_delayed_2=None,
    delayed_target_update_period=5,
    # Params for training.
    td_errors_loss_fn=None,
    auxiliary_loss_fns=None,
    gamma=1.0,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for debugging
    debug_summaries=False,
    summarize_grads_and_vars=False,
    train_step_counter=None,
    info_spec=None,
    name=None):
  """Creates a Qtopt Agent.

  Args:
    time_step_spec: A `TimeStep` spec of the expected time_steps.
    action_spec: A nest of BoundedTensorSpec representing the actions.
    q_network: A tf_agents.network.Network to be used by the agent. The
      network will be called with call((observation, action), step_type).
      This q_network differs from the one used in DQN, where the input is
      the state and the output has multiple dimensions representing Q
      values for different actions. The input of this q_network is a tuple
      of state and action, and the output is one dimension representing the
      Q value for that specific action. A DDPG critic network can be used
      directly here.
    optimizer: The optimizer to use for training.
    actions_sampler: A tf_agents.policies.sampler.ActionsSampler to be used
      to sample actions in CEM.
    epsilon_greedy: probability of choosing a random action in the default
      epsilon-greedy collect policy (used only if a wrapper is not provided
      to the collect_policy method).
    n_step_update: Currently, only n_step_update == 1 is supported.
    emit_log_probability: Whether policies emit log probabilities or not.
    in_graph_bellman_update: If False, configures the agent to expect
      experience containing computed q_values in the policy_step's info
      field. This simplifies splitting the loss calculation across several
      jobs.
    init_mean_cem: Initial mean value of the Gaussian distribution to
      sample actions for CEM.
    init_var_cem: Initial variance value of the Gaussian distribution to
      sample actions for CEM.
    num_samples_cem: Number of samples to sample for each iteration in CEM.
    num_elites_cem: Number of elites to select for each iteration in CEM.
    num_iter_cem: Number of iterations in CEM.
    target_q_network: (Optional.) A `tf_agents.network.Network` to be used
      as the target network during Q learning. Every `target_update_period`
      train steps, the weights from `q_network` are copied (possibly with
      smoothing via `target_update_tau`) to `target_q_network`. If
      `target_q_network` is not provided, it is created by making a copy of
      `q_network`, which initializes a new network with the same structure
      and its own layers and weights. Network copying is performed via the
      `Network.copy` superclass method, with the same arguments used during
      the original network's construction, and may inadvertently lead to
      weights being shared between networks. This can happen if, for
      example, the original network accepted a pre-built Keras layer in its
      `__init__`, or accepted a Keras layer that wasn't built, but
      neglected to create a new copy. In these cases, it is up to you to
      provide a target Network having weights that are not shared with the
      original `q_network`. If you provide a `target_q_network` that shares
      any weights with `q_network`, an exception is thrown.
    target_update_tau: Factor for soft update of the target networks.
    target_update_period: Period for soft update of the target networks.
    enable_td3: Whether or not to enable using a delayed target network to
      calculate q value and assign min(q_delayed, q_delayed_2) as
      q_next_state.
    target_q_network_delayed: (Optional.) Similar network as
      target_q_network but lags behind even more. See documentation for
      target_q_network. Will only be used if 'enable_td3' is True.
    target_q_network_delayed_2: (Optional.) Similar network as
      target_q_network_delayed but lags behind even more. See documentation
      for target_q_network. Will only be used if 'enable_td3' is True.
    delayed_target_update_period: Used when enable_td3 is true. Period for
      soft update of the delayed target networks.
    td_errors_loss_fn: A function for computing the TD errors loss. If None,
      a default value of element_wise_huber_loss is used. This function
      takes as input the target and the estimated Q values and returns the
      loss for each element of the batch.
    auxiliary_loss_fns: An optional list of functions for computing
      auxiliary losses. Each auxiliary_loss_fn expects network and
      transition as input and should output auxiliary_loss and
      auxiliary_reg_loss.
    gamma: A discount factor for future rewards.
    reward_scale_factor: Multiplicative scale for the reward.
    gradient_clipping: Norm length to clip gradients.
    debug_summaries: A bool to gather debug summaries.
    summarize_grads_and_vars: If True, gradient and network variable
      summaries will be written during training.
    train_step_counter: An optional counter to increment every time the
      train op is run. Defaults to the global_step.
    info_spec: If not None, the policy info spec is set to this spec.
    name: The name of this agent. All variables in this module will fall
      under that name. Defaults to the class name.

  Raises:
    ValueError: If the action spec contains more than one action or action
      spec minimum is not equal to 0.
    NotImplementedError: If `q_network` has non-empty `state_spec` (i.e.,
      an RNN is provided) and `n_step_update > 1`.
  """
  tf.Module.__init__(self, name=name)
  self._sampler = actions_sampler
  self._init_mean_cem = init_mean_cem
  self._init_var_cem = init_var_cem
  self._num_samples_cem = num_samples_cem
  self._num_elites_cem = num_elites_cem
  self._num_iter_cem = num_iter_cem
  self._in_graph_bellman_update = in_graph_bellman_update
  if not in_graph_bellman_update:
    if info_spec is not None:
      self._info_spec = info_spec
    else:
      self._info_spec = {
          'target_q': tensor_spec.TensorSpec((), tf.float32),
      }
  else:
    self._info_spec = ()

  self._q_network = q_network
  net_observation_spec = (time_step_spec.observation, action_spec)
  q_network.create_variables(net_observation_spec)
  if target_q_network:
    target_q_network.create_variables(net_observation_spec)
  self._target_q_network = common.maybe_copy_target_network_with_checks(
      self._q_network,
      target_q_network,
      input_spec=net_observation_spec,
      name='TargetQNetwork')
  self._target_updater = self._get_target_updater(target_update_tau,
                                                  target_update_period)

  self._enable_td3 = enable_td3
  if (not self._enable_td3 and
      (target_q_network_delayed or target_q_network_delayed_2)):
    raise ValueError(
        'enable_td3 is set to False but target_q_network_delayed'
        ' or target_q_network_delayed_2 is passed.')

  if self._enable_td3:
    if target_q_network_delayed:
      target_q_network_delayed.create_variables()
    self._target_q_network_delayed = (
        common.maybe_copy_target_network_with_checks(
            self._q_network, target_q_network_delayed,
            'TargetQNetworkDelayed'))
    self._target_updater_delayed = self._get_target_updater_delayed(
        1.0, delayed_target_update_period)

    if target_q_network_delayed_2:
      target_q_network_delayed_2.create_variables()
    self._target_q_network_delayed_2 = (
        common.maybe_copy_target_network_with_checks(
            self._q_network, target_q_network_delayed_2,
            'TargetQNetworkDelayed2'))
    self._target_updater_delayed_2 = self._get_target_updater_delayed_2(
        1.0, delayed_target_update_period)

    self._update_target = self._update_both
  else:
    self._update_target = self._target_updater
    self._target_q_network_delayed = None
    self._target_q_network_delayed_2 = None

  self._check_network_output(self._q_network, 'q_network')
  self._check_network_output(self._target_q_network, 'target_q_network')

  self._epsilon_greedy = epsilon_greedy
  self._n_step_update = n_step_update
  self._optimizer = optimizer
  self._td_errors_loss_fn = (
      td_errors_loss_fn or common.element_wise_huber_loss)
  self._auxiliary_loss_fns = auxiliary_loss_fns
  self._gamma = gamma
  self._reward_scale_factor = reward_scale_factor
  self._gradient_clipping = gradient_clipping

  policy, collect_policy = self._setup_policy(time_step_spec, action_spec,
                                              emit_log_probability)

  if q_network.state_spec and n_step_update != 1:
    raise NotImplementedError(
        'QtOptAgent does not currently support n-step updates with stateful '
        'networks (i.e., RNNs), but n_step_update = {}'.format(
            n_step_update))

  # Bypass the train_sequence_length check when an RNN is used.
  train_sequence_length = (
      n_step_update + 1 if not q_network.state_spec else None)

  super(QtOptAgent, self).__init__(
      time_step_spec,
      action_spec,
      policy,
      collect_policy,
      train_sequence_length=train_sequence_length,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter,
  )

  self._setup_data_converter(q_network, gamma, n_step_update)
def __init__(self,
             time_step_spec,
             action_spec,
             actor_network,
             critic_network,
             actor_optimizer=None,
             critic_optimizer=None,
             ou_stddev=1.0,
             ou_damping=1.0,
             target_actor_network=None,
             target_critic_network=None,
             target_update_tau=1.0,
             target_update_period=1,
             dqda_clipping=None,
             td_errors_loss_fn=None,
             gamma=1.0,
             reward_scale_factor=1.0,
             gradient_clipping=None,
             debug_summaries=False,
             summarize_grads_and_vars=False,
             train_step_counter=None,
             name=None):
  """Creates a DDPG Agent.

  Args:
    time_step_spec: A `TimeStep` spec of the expected time_steps.
    action_spec: A nest of BoundedTensorSpec representing the actions.
    actor_network: A tf_agents.network.Network to be used by the agent. The
      network will be called with call(observation, step_type[,
      policy_state]) and should return (action, new_state).
    critic_network: A tf_agents.network.Network to be used by the agent.
      The network will be called with call((observation, action),
      step_type[, policy_state]) and should return (q_value, new_state).
    actor_optimizer: The optimizer to use for the actor network.
    critic_optimizer: The optimizer to use for the critic network.
    ou_stddev: Standard deviation for the Ornstein-Uhlenbeck (OU) noise
      added in the default collect policy.
    ou_damping: Damping factor for the OU noise added in the default
      collect policy.
    target_actor_network: (Optional.) A `tf_agents.network.Network` to be
      used as the actor target network during Q learning. Every
      `target_update_period` train steps, the weights from `actor_network`
      are copied (possibly with smoothing via `target_update_tau`) to
      `target_actor_network`. If `target_actor_network` is not provided, it
      is created by making a copy of `actor_network`, which initializes a
      new network with the same structure and its own layers and weights.
      Performing a `Network.copy` does not work when the network instance
      already has trainable parameters (e.g., has already been built, or
      when the network is sharing layers with another). In these cases, it
      is up to you to build a copy having weights that are not shared with
      the original `actor_network`, so that this can be used as a target
      network. If you provide a `target_actor_network` that shares any
      weights with `actor_network`, a warning will be logged but no
      exception is thrown.
    target_critic_network: (Optional.) Similar network as
      target_actor_network but for the critic_network. See documentation
      for target_actor_network.
    target_update_tau: Factor for soft update of the target networks.
    target_update_period: Period for soft update of the target networks.
    dqda_clipping: when computing the actor loss, clips the gradient dqda
      element-wise between [-dqda_clipping, dqda_clipping]. Does not
      perform clipping if dqda_clipping == 0.
    td_errors_loss_fn: A function for computing the TD errors loss. If
      None, a default value of elementwise huber_loss is used.
    gamma: A discount factor for future rewards.
    reward_scale_factor: Multiplicative scale for the reward.
    gradient_clipping: Norm length to clip gradients.
    debug_summaries: A bool to gather debug summaries.
    summarize_grads_and_vars: If True, gradient and network variable
      summaries will be written during training.
    train_step_counter: An optional counter to increment every time the
      train op is run. Defaults to the global_step.
    name: The name of this agent. All variables in this module will fall
      under that name. Defaults to the class name.
  """
  tf.Module.__init__(self, name=name)
  self._actor_network = actor_network
  actor_network.create_variables()
  if target_actor_network:
    target_actor_network.create_variables()
  self._target_actor_network = common.maybe_copy_target_network_with_checks(
      self._actor_network, target_actor_network, 'TargetActorNetwork')

  self._critic_network = critic_network
  critic_network.create_variables()
  if target_critic_network:
    target_critic_network.create_variables()
  self._target_critic_network = (
      common.maybe_copy_target_network_with_checks(self._critic_network,
                                                   target_critic_network,
                                                   'TargetCriticNetwork'))

  self._actor_optimizer = actor_optimizer
  self._critic_optimizer = critic_optimizer
  self._ou_stddev = ou_stddev
  self._ou_damping = ou_damping
  self._target_update_tau = target_update_tau
  self._target_update_period = target_update_period
  self._dqda_clipping = dqda_clipping
  self._td_errors_loss_fn = (
      td_errors_loss_fn or common.element_wise_huber_loss)
  self._gamma = gamma
  self._reward_scale_factor = reward_scale_factor
  self._gradient_clipping = gradient_clipping

  self._update_target = self._get_target_updater(target_update_tau,
                                                 target_update_period)

  # Note: build the policies from the individual agent's time_step_spec
  # rather than the combined (total) spec.
  individual_time_step_spec = ts.get_individual_time_step_spec(
      time_step_spec)
  policy = actor_policy.ActorPolicy(
      time_step_spec=individual_time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      clip=True)
  collect_policy = actor_policy.ActorPolicy(
      time_step_spec=individual_time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      clip=False)
  collect_policy = ou_noise_policy.OUNoisePolicy(
      collect_policy,
      ou_stddev=self._ou_stddev,
      ou_damping=self._ou_damping,
      clip=True)

  super(DdpgAgent, self).__init__(
      time_step_spec,
      action_spec,
      policy,
      collect_policy,
      train_sequence_length=2
      if not self._actor_network.state_spec else None,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter)
def __init__(self,
             time_step_spec,
             action_spec,
             q_networks,
             critic_optimizer,
             exploration_noise_std=0.1,
             boltzmann_temperature=10.0,
             epsilon_greedy=0.1,
             q_networks_2=None,
             target_q_networks=None,
             target_q_networks_2=None,
             target_update_tau=1.0,
             target_update_period=1,
             dqda_clipping=None,
             td_errors_loss_fn=None,
             gamma=1.0,
             reward_scale_factor=1.0,
             target_policy_noise=0.2,
             target_policy_noise_clip=0.5,
             gradient_clipping=None,
             debug_summaries=False,
             summarize_grads_and_vars=False,
             train_step_counter=None,
             action_params_mask=None,
             n_step_update=1,
             name=None):
  """Creates a Td3DqnAgent.

  Args:
    time_step_spec: A `TimeStep` spec of the expected time_steps.
    action_spec: A namedtuple of nested BoundedTensorSpec representing the
      actions.
    q_networks: A list of tf_agents.network.Network to be used by the
      agent. Each network will be called with call(observation, action,
      step_type).
    critic_optimizer: The default optimizer to use for the critic network.
    exploration_noise_std: Scale factor on exploration policy noise.
    boltzmann_temperature: Temperature value to use for Boltzmann sampling
      of the actions during data collection. The closer to 0.0, the higher
      the probability of choosing the best action.
    epsilon_greedy: probability of choosing a random action in the default
      epsilon-greedy collect policy.
    q_networks_2: (Optional.) A `tf_agents.network.Network` to be used as
      the second critic network during Q learning. The weights from
      `q_networks` are copied if this is not provided.
    target_q_networks: (Optional.) A `tf_agents.network.Network` to be used
      as the target Q network during Q learning. Every
      `target_update_period` train steps, the weights from `q_networks` are
      copied (possibly with smoothing via `target_update_tau`) to
      `target_q_networks`. If `target_q_networks` is not provided, it is
      created by making a copy of `q_networks`, which initializes a new
      network with the same structure and its own layers and weights.
      Performing a `Network.copy` does not work when the network instance
      already has trainable parameters (e.g., has already been built, or
      when the network is sharing layers with another). In these cases, it
      is up to you to build a copy having weights that are not shared with
      the original `q_networks`, so that this can be used as a target
      network. If you provide a `target_q_networks` that shares any weights
      with `q_networks`, a warning will be logged but no exception is
      thrown.
    target_q_networks_2: (Optional.) Similar network as target_q_networks
      but for the q_networks_2. See documentation for target_q_networks.
    target_update_tau: Factor for soft update of the target networks.
    target_update_period: Period for soft update of the target networks.
    dqda_clipping: A scalar or float clips the gradient dqda element-wise
      between [-dqda_clipping, dqda_clipping]. Default is None representing
      no clipping.
    td_errors_loss_fn: A function for computing the TD errors loss. If
      None, a default value of elementwise huber_loss is used.
    gamma: A discount factor for future rewards.
    reward_scale_factor: Multiplicative scale for the reward.
    target_policy_noise: Scale factor on target action noise.
    target_policy_noise_clip: Value to clip noise.
    gradient_clipping: Norm length to clip gradients.
    debug_summaries: A bool to gather debug summaries.
    summarize_grads_and_vars: If True, gradient and network variable
      summaries will be written during training.
    train_step_counter: An optional counter to increment every time the
      train op is run. Defaults to the global_step.
    action_params_mask: A mask of continuous parameter actions for each
      discrete action.
    n_step_update: The number of steps to consider when computing TD error
      and TD loss.
    name: The name of this agent. All variables in this module will fall
      under that name. Defaults to the class name.
  """
  tf.Module.__init__(self, name=name)
  # The critic networks here are Q-networks.
  if not isinstance(q_networks, (list, tuple)):
    q_networks = (q_networks,)
  self._q_network_1 = q_networks
  if target_q_networks is not None and not isinstance(
      target_q_networks, (list, tuple)):
    target_q_networks = (target_q_networks,)
  if target_q_networks is None:
    target_q_networks = [None] * len(self._q_network_1)
  assert len(self._q_network_1) == len(target_q_networks)
  self._target_q_network_1 = [
      common.maybe_copy_target_network_with_checks(
          q, target_q, 'Target' + q.name + '_1')
      for q, target_q in zip(self._q_network_1, target_q_networks)
  ]

  if q_networks_2 is not None:
    self._q_network_2 = q_networks_2
  else:
    self._q_network_2 = [q.copy(name=q.name + '_2') for q in q_networks]
    # Do not use target_q_networks_2 if q_networks_2 is None.
    target_q_networks_2 = None
  if target_q_networks_2 is not None and not isinstance(
      target_q_networks_2, (list, tuple)):
    target_q_networks_2 = (target_q_networks_2,)
  if target_q_networks_2 is None:
    target_q_networks_2 = [None] * len(self._q_network_2)
  self._target_q_network_2 = [
      common.maybe_copy_target_network_with_checks(
          q, target_q, 'Target' + q.name + '_2')
      for q, target_q in zip(self._q_network_2, target_q_networks_2)
  ]

  self._critic_optimizer = critic_optimizer
  self._exploration_noise_std = exploration_noise_std
  self._epsilon_greedy = epsilon_greedy
  self._boltzmann_temperature = boltzmann_temperature
  self._target_update_tau = target_update_tau
  self._target_update_period = target_update_period
  self._dqda_clipping = dqda_clipping
  self._td_errors_loss_fn = (
      td_errors_loss_fn or common.element_wise_huber_loss)
  self._gamma = gamma
  self._reward_scale_factor = reward_scale_factor
  self._target_policy_noise = target_policy_noise
  self._target_policy_noise_clip = target_policy_noise_clip
  self._gradient_clipping = gradient_clipping

  self._update_target = self._get_target_updater(target_update_tau,
                                                 target_update_period)

  num_networks = len(self._q_network_1)
  if (len(action_spec) != num_networks or
      len(time_step_spec) != num_networks):
    raise ValueError(
        'The number of q_networks must match the number of action_specs '
        'and time_step_specs.')

  policies = []
  self._q_value_policies_1, self._q_value_policies_2 = [], []
  for i, q_network in enumerate(self._q_network_1):
    # Every policy after the first conditions on the previous action.
    use_previous_action = i != 0
    policies.append(
        hetero_q_policy.HeteroQPolicy(
            time_step_spec=time_step_spec[i],
            action_spec=action_spec[i],
            mixed_q_network=q_networks[i],
            func_arg_mask=action_params_mask[i],
            use_previous_action=use_previous_action,
            name='HeteroQPolicy_1_' + str(i)))
    self._q_value_policies_2.append(
        hetero_q_policy.HeteroQPolicy(
            time_step_spec=time_step_spec[i],
            action_spec=action_spec[i],
            mixed_q_network=self._q_network_2[i],
            func_arg_mask=action_params_mask[i],
            use_previous_action=use_previous_action,
            name='HeteroQPolicy_2_' + str(i)))
  self._q_value_policies_1 = policies

  collect_policies = [
      epsilon_boltzmann_policy.EpsilonBoltzmannPolicy(
          p,
          temperature=boltzmann_temperature,
          epsilon=self._epsilon_greedy,
          remove_neg_inf=True) for p in policies
  ]
  collect_policy = sequential_policy.SequentialPolicy(collect_policies)
  policies = [
      greedy_policy.GreedyPolicy(p, remove_neg_inf=True) for p in policies
  ]
  policy = sequential_policy.SequentialPolicy(policies)

  # Create the target greedy policies in order to compute target Q-values
  # in _compute_next_q_values.
  target_policies = []
  self._target_q_value_policies_1, self._target_q_value_policies_2 = [], []
  for i, q_network in enumerate(self._q_network_1):
    use_previous_action = i != 0
    target_policies.append(
        hetero_q_policy.HeteroQPolicy(
            time_step_spec=time_step_spec[i],
            action_spec=action_spec[i],
            mixed_q_network=self._target_q_network_1[i],
            func_arg_mask=action_params_mask[i],
            use_previous_action=use_previous_action,
            name='TargetHeteroQPolicy_1' + str(i)))
    self._target_q_value_policies_2.append(
        hetero_q_policy.HeteroQPolicy(
            time_step_spec=time_step_spec[i],
            action_spec=action_spec[i],
            mixed_q_network=self._target_q_network_2[i],
            func_arg_mask=action_params_mask[i],
            use_previous_action=use_previous_action,
            name='TargetHeteroQPolicy_2' + str(i)))
  self._target_q_value_policies_1 = [
      greedy_policy.GreedyPolicy(p, remove_neg_inf=True)
      for p in target_policies
  ]
  self._target_q_value_policies_2 = [
      greedy_policy.GreedyPolicy(p, remove_neg_inf=True)
      for p in self._target_q_value_policies_2
  ]
  target_policies = self._target_q_value_policies_1
  target_policy = sequential_policy.SequentialPolicy(target_policies)
  self._target_greedy_policies = target_policy

  self._action_params_mask = action_params_mask
  self._n_step_update = n_step_update

  super(Td3DqnAgent, self).__init__(
      policy.time_step_spec,
      policy.action_spec,
      policy,
      collect_policy,
      train_sequence_length=2
      if not self._q_network_1[0].state_spec else None,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter)
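# --- Illustrative sketch (not part of the agent above) ---
# The clipped double-Q target that motivates keeping two Q-networks and
# two target networks above: the next-state value is the elementwise
# minimum of the two target critics' estimates, which damps overestimation
# bias. A generic sketch under assumed shapes, not the agent's
# `_compute_next_q_values`.
import tensorflow as tf

def clipped_double_q_target(rewards, discounts, q1_next, q2_next):
  # All inputs are [batch]-shaped tensors.
  return rewards + discounts * tf.minimum(q1_next, q2_next)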
def train_eval(
    load_root_dir,
    env_load_fn=None,
    gym_env_wrappers=(),  # tuple default avoids the mutable-default pitfall
    monitor=False,
    env_name=None,
    agent_class=None,
    train_metrics_callback=None,
    # SacAgent args
    actor_fc_layers=(256, 256),
    critic_joint_fc_layers=(256, 256),
    # Safety Critic training args
    safety_critic_joint_fc_layers=None,
    safety_critic_lr=3e-4,
    safety_critic_bias_init_val=None,
    safety_critic_kernel_scale=None,
    n_envs=None,
    target_safety=0.2,
    fail_weight=None,
    # Params for train
    num_global_steps=10000,
    batch_size=256,
    # Params for eval
    run_eval=False,
    eval_metrics=(),
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    summary_interval=1000,
    monitor_interval=5000,
    summaries_flush_secs=10,
    debug_summaries=False,
    seed=None):
  if isinstance(agent_class, str):
    assert agent_class in ALGOS, (
        'trainer.train_eval: agent_class {} invalid'.format(agent_class))
    agent_class = ALGOS.get(agent_class)

  train_ckpt_dir = osp.join(load_root_dir, 'train')
  rb_ckpt_dir = osp.join(load_root_dir, 'train', 'replay_buffer')

  py_env = env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)
  tf_env = tf_py_environment.TFPyEnvironment(py_env)

  if monitor:
    vid_path = os.path.join(load_root_dir, 'rollouts')
    monitor_env_wrapper = misc.monitor_freq(1, vid_path)
    monitor_env = gym.make(env_name)
    for wrapper in gym_env_wrappers:
      monitor_env = wrapper(monitor_env)
    monitor_env = monitor_env_wrapper(monitor_env)
    # auto_reset must be False to ensure Monitor works correctly
    monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

  if run_eval:
    eval_dir = os.path.join(load_root_dir, 'eval')
    n_envs = n_envs or num_eval_episodes
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(
            prefix='EvalMetrics',
            buffer_size=num_eval_episodes,
            batch_size=n_envs),
        tf_metrics.AverageEpisodeLengthMetric(
            prefix='EvalMetrics',
            buffer_size=num_eval_episodes,
            batch_size=n_envs)
    ] + [
        tf_py_metric.TFPyMetric(m, name='EvalMetrics/{}'.format(m.name))
        for m in eval_metrics
    ]
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name, gym_env_wrappers=gym_env_wrappers)]
            * n_envs))
    if seed:
      seeds = [seed * n_envs + i for i in range(n_envs)]
      try:
        eval_tf_env.pyenv.seed(seeds)
      except Exception:  # seeding is best-effort; not all envs support it
        pass

  global_step = tf.compat.v1.train.get_or_create_global_step()

  time_step_spec = tf_env.time_step_spec()
  observation_spec = time_step_spec.observation
  action_spec = tf_env.action_spec()

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_spec,
      action_spec,
      fc_layer_params=actor_fc_layers,
      continuous_projection_net=agents.normal_projection_net)

  critic_net = agents.CriticNetwork(
      (observation_spec, action_spec),
      joint_fc_layer_params=critic_joint_fc_layers)

  if agent_class in SAFETY_AGENTS:
    safety_critic_net = agents.CriticNetwork(
        (observation_spec, action_spec),
        joint_fc_layer_params=critic_joint_fc_layers)
    tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        safety_critic_network=safety_critic_net,
        train_step_counter=global_step,
        debug_summaries=False)
  else:
    tf_agent = agent_class(
        time_step_spec,
        action_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        train_step_counter=global_step,
        debug_summaries=False)

  collect_data_spec = tf_agent.collect_data_spec
  replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      collect_data_spec, batch_size=1, max_length=1000000)
  replay_buffer = misc.load_rb_ckpt(rb_ckpt_dir, replay_buffer)

  tf_agent, _ = misc.load_agent_ckpt(train_ckpt_dir, tf_agent)
  if agent_class in SAFETY_AGENTS:
    target_safety = target_safety or tf_agent._target_safety
  loaded_train_steps = global_step.numpy()
  logging.info('Loaded agent from %s trained for %d steps', train_ckpt_dir,
               loaded_train_steps)
  global_step.assign(0)
  tf.summary.experimental.set_step(global_step)

  thresholds = [target_safety, 0.5]
  sc_metrics = [
      tf.keras.metrics.AUC(name='safety_critic_auc'),
      tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                      threshold=0.5),
      tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                     thresholds=thresholds),
      tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                      thresholds=thresholds),
      tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                     thresholds=thresholds),
      tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                      thresholds=thresholds)
  ]

  if seed:
    tf.compat.v1.set_random_seed(seed)

  summaries_flush_secs = 10
  timestamp = datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S')
  offline_train_dir = osp.join(train_ckpt_dir, 'offline', timestamp)
  config_saver = gin.tf.GinConfigSaverHook(offline_train_dir,
                                           summarize_config=True)
  tf.function(config_saver.after_create_session)()

  sc_summary_writer = tf.compat.v2.summary.create_file_writer(
      offline_train_dir, flush_millis=summaries_flush_secs * 1000)
  sc_summary_writer.set_as_default()

  if safety_critic_kernel_scale is not None:
    ki = tf.compat.v1.variance_scaling_initializer(
        scale=safety_critic_kernel_scale,
        mode='fan_in',
        distribution='truncated_normal')
  else:
    ki = tf.compat.v1.keras.initializers.VarianceScaling(
        scale=1. / 3., mode='fan_in', distribution='uniform')

  if safety_critic_bias_init_val is not None:
    bi = tf.constant_initializer(safety_critic_bias_init_val)
  else:
    bi = None
  sc_net_off = agents.CriticNetwork(
      (observation_spec, action_spec),
      joint_fc_layer_params=safety_critic_joint_fc_layers,
      kernel_initializer=ki,
      value_bias_initializer=bi,
      name='SafetyCriticOffline')
  sc_net_off.create_variables()
  target_sc_net_off = common.maybe_copy_target_network_with_checks(
      sc_net_off, None, 'TargetSafetyCriticNetwork')
  optimizer = tf.keras.optimizers.Adam(safety_critic_lr)

  sc_net_off_ckpt_dir = os.path.join(offline_train_dir, 'safety_critic')
  sc_checkpointer = common.Checkpointer(
      ckpt_dir=sc_net_off_ckpt_dir,
      safety_critic=sc_net_off,
      target_safety_critic=target_sc_net_off,
      optimizer=optimizer,
      global_step=global_step,
      max_to_keep=5)
  sc_checkpointer.initialize_or_restore()

  resample_counter = py_metrics.CounterMetric('ActionResampleCounter')
  eval_policy = agents.SafeActorPolicyRSVar(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=actor_net,
      safety_critic_network=sc_net_off,
      safety_threshold=target_safety,
      resample_counter=resample_counter,
      training=True)

  dataset = replay_buffer.as_dataset(
      num_parallel_calls=3,
      num_steps=2,
      sample_batch_size=batch_size // 2).prefetch(3)
  data = iter(dataset)
  full_data = replay_buffer.gather_all()

  fail_mask = tf.cast(full_data.observation['task_agn_rew'], tf.bool)
  fail_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, fail_mask), full_data)
  init_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, full_data.is_first()), full_data)
  before_fail_mask = tf.roll(fail_mask, [-1], axis=[1])
  after_init_mask = tf.roll(full_data.is_first(), [1], axis=[1])
  before_fail_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, before_fail_mask), full_data)
  after_init_step = nest_utils.fast_map_structure(
      lambda *x: tf.boolean_mask(*x, after_init_mask), full_data)

  filter_mask = tf.squeeze(tf.logical_or(before_fail_mask, fail_mask))
  filter_mask = tf.pad(
      filter_mask, [[0, replay_buffer._max_length - filter_mask.shape[0]]])
  n_failures = tf.reduce_sum(tf.cast(filter_mask, tf.int32)).numpy()

  failure_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
      collect_data_spec,
      batch_size=1,
      max_length=n_failures,
      dataset_window_shift=1)
  data_utils.copy_rb(replay_buffer, failure_buffer, filter_mask)

  sc_dataset_neg = failure_buffer.as_dataset(
      num_parallel_calls=3,
      sample_batch_size=batch_size // 2,
      num_steps=2).prefetch(3)
  neg_data = iter(sc_dataset_neg)

  get_action = lambda ts: tf_agent._actions_and_log_probs(ts)[0]
  eval_sc = log_utils.eval_fn(before_fail_step, fail_step, init_step,
                              after_init_step, get_action)

  losses = []
  mean_loss = tf.keras.metrics.Mean(name='mean_ep_loss')
  target_update = train_utils.get_target_updater(sc_net_off,
                                                 target_sc_net_off)

  with tf.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    while global_step.numpy() < num_global_steps:
      pos_experience, _ = next(data)
      neg_experience, _ = next(neg_data)
      exp = data_utils.concat_batches(pos_experience, neg_experience,
                                      collect_data_spec)
      boundary_mask = tf.logical_not(exp.is_boundary()[:, 0])
      exp = nest_utils.fast_map_structure(
          lambda *x: tf.boolean_mask(*x, boundary_mask), exp)
      safe_rew = exp.observation['task_agn_rew'][:, 1]
      if fail_weight:
        weights = tf.where(
            tf.cast(safe_rew, tf.bool), fail_weight / 0.5,
            (1 - fail_weight) / 0.5)
      else:
        weights = None
      train_loss, sc_loss, lam_loss = train_step(
          exp,
          safe_rew,
          tf_agent,
          sc_net=sc_net_off,
          target_sc_net=target_sc_net_off,
          metrics=sc_metrics,
          weights=weights,
          target_safety=target_safety,
          optimizer=optimizer,
          target_update=target_update,
          debug_summaries=debug_summaries)
      global_step.assign_add(1)
      global_step_val = global_step.numpy()
      losses.append((train_loss.numpy(), sc_loss.numpy(), lam_loss.numpy()))
      mean_loss(train_loss)
      with tf.name_scope('Losses'):
        tf.compat.v2.summary.scalar(
            name='sc_loss', data=sc_loss, step=global_step_val)
        tf.compat.v2.summary.scalar(
            name='lam_loss', data=lam_loss, step=global_step_val)
        if global_step_val % summary_interval == 0:
          tf.compat.v2.summary.scalar(
              name=mean_loss.name,
              data=mean_loss.result(),
              step=global_step_val)
      if global_step_val % summary_interval == 0:
        with tf.name_scope('Metrics'):
          for metric in sc_metrics:
            if len(tf.squeeze(metric.result()).shape) == 0:
              tf.compat.v2.summary.scalar(
                  name=metric.name,
                  data=metric.result(),
                  step=global_step_val)
            else:
              # Multi-threshold metrics emit one scalar per threshold.
              fmt_str = '_{}'.format(thresholds[0])
              tf.compat.v2.summary.scalar(
                  name=metric.name + fmt_str,
                  data=metric.result()[0],
                  step=global_step_val)
              fmt_str = '_{}'.format(thresholds[1])
              tf.compat.v2.summary.scalar(
                  name=metric.name + fmt_str,
                  data=metric.result()[1],
                  step=global_step_val)
            metric.reset_states()
      if global_step_val % eval_interval == 0:
        eval_sc(sc_net_off, step=global_step_val)
        if run_eval:
          results = metric_utils.eager_compute(
              eval_metrics,
              eval_tf_env,
              eval_policy,
              num_episodes=num_eval_episodes,
              train_step=global_step,
              summary_writer=eval_summary_writer,
              summary_prefix='EvalMetrics',
          )
          if train_metrics_callback is not None:
            train_metrics_callback(results, global_step_val)
          metric_utils.log_metrics(eval_metrics)
          with eval_summary_writer.as_default():
            for eval_metric in eval_metrics[2:]:
              eval_metric.tf_summaries(
                  train_step=global_step, step_metrics=eval_metrics[:2])
      if monitor and global_step_val % monitor_interval == 0:
        monitor_time_step = monitor_py_env.reset()
        monitor_policy_state = eval_policy.get_initial_state(1)
        ep_len = 0
        monitor_start = time.time()
        while not monitor_time_step.is_last():
          monitor_action = eval_policy.action(monitor_time_step,
                                              monitor_policy_state)
          action, monitor_policy_state = (monitor_action.action,
                                          monitor_action.state)
          monitor_time_step = monitor_py_env.step(action)
          ep_len += 1
        logging.debug(
            'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
            global_step_val, ep_len, time.time() - monitor_start)
      if global_step_val % train_checkpoint_interval == 0:
        sc_checkpointer.save(global_step=global_step_val)
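# --- Hedged usage sketch (not part of the original source) ---
# One plausible invocation of the offline safety-critic trainer above. The
# checkpoint root, env id, and loader are illustrative assumptions; a real
# run must point load_root_dir at a directory containing 'train/' and
# 'train/replay_buffer/' checkpoints produced by the online trainer.
def _example_offline_sc_training():
  from tf_agents.environments import suite_gym  # assumed env loader
  return train_eval(
      load_root_dir='/tmp/sqrl_run',  # hypothetical checkpoint root
      env_load_fn=suite_gym.load,
      env_name='Pendulum-v0',  # hypothetical env id
      agent_class='sqrl',  # assumed key in ALGOS
      num_global_steps=10000,
      batch_size=256)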
def __init__(
    self,
    time_step_spec,
    action_spec,
    q_network,
    optimizer,
    epsilon_greedy=0.1,
    n_step_update=1,
    boltzmann_temperature=None,
    emit_log_probability=False,
    # Params for target network updates
    target_q_network=None,
    target_update_tau=1.0,
    target_update_period=1,
    # Params for training.
    td_errors_loss_fn=None,
    gamma=1.0,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for debugging
    debug_summaries=False,
    summarize_grads_and_vars=False,
    train_step_counter=None,
    name=None):
  """Creates a DQN Agent.

  Args:
    time_step_spec: A `TimeStep` spec of the expected time_steps.
    action_spec: A nest of BoundedTensorSpec representing the actions.
    q_network: A `tf_agents.network.Network` to be used by the agent. The
      network will be called with `call(observation, step_type)` and should
      emit logits over the action space.
    optimizer: The optimizer to use for training.
    epsilon_greedy: probability of choosing a random action in the default
      epsilon-greedy collect policy (used only if a wrapper is not provided
      to the collect_policy method).
    n_step_update: The number of steps to consider when computing TD error
      and TD loss. Defaults to single-step updates. Note that this requires
      the user to call train on Trajectory objects with a time dimension of
      `n_step_update + 1`. However, note that we do not yet support
      `n_step_update > 1` in the case of RNNs (i.e., non-empty
      `q_network.state_spec`).
    boltzmann_temperature: Temperature value to use for Boltzmann sampling
      of the actions during data collection. The closer to 0.0, the higher
      the probability of choosing the best action.
    emit_log_probability: Whether policies emit log probabilities or not.
    target_q_network: (Optional.) A `tf_agents.network.Network` to be used
      as the target network during Q learning. Every `target_update_period`
      train steps, the weights from `q_network` are copied (possibly with
      smoothing via `target_update_tau`) to `target_q_network`. If
      `target_q_network` is not provided, it is created by making a copy of
      `q_network`, which initializes a new network with the same structure
      and its own layers and weights. Network copying is performed via the
      `Network.copy` superclass method, and may inadvertently lead the
      resulting network to share weights with the original. This can happen
      if, for example, the original network accepted a pre-built Keras layer
      in its `__init__`, or accepted a Keras layer that wasn't built, but
      neglected to create a new copy. In these cases, it is up to you to
      provide a target Network having weights that are not shared with the
      original `q_network`. If you provide a `target_q_network` that shares
      any weights with `q_network`, a warning will be logged but no
      exception is thrown. Note: shallow copies of Keras layers may be built
      via the code:

      ```python
      new_layer = type(layer).from_config(layer.get_config())
      ```
    target_update_tau: Factor for soft update of the target networks.
    target_update_period: Period for soft update of the target networks.
    td_errors_loss_fn: A function for computing the TD errors loss. If None,
      a default value of element_wise_huber_loss is used. This function
      takes as input the target and the estimated Q values and returns the
      loss for each element of the batch.
    gamma: A discount factor for future rewards.
    reward_scale_factor: Multiplicative scale for the reward.
    gradient_clipping: Norm length to clip gradients.
    debug_summaries: A bool to gather debug summaries.
    summarize_grads_and_vars: If True, gradient and network variable
      summaries will be written during training.
    train_step_counter: An optional counter to increment every time the
      train op is run. Defaults to the global_step.
    name: The name of this agent. All variables in this module will fall
      under that name. Defaults to the class name.

  Raises:
    ValueError: If the action spec contains more than one action or action
      spec minimum is not equal to 0.
    NotImplementedError: If `q_network` has non-empty `state_spec` (i.e., an
      RNN is provided) and `n_step_update > 1`.
  """
  tf.Module.__init__(self, name=name)
  self._check_action_spec(action_spec)

  if epsilon_greedy is not None and boltzmann_temperature is not None:
    raise ValueError(
        'Configured both epsilon_greedy value {} and temperature {}, '
        'however only one of them can be used for exploration.'.format(
            epsilon_greedy, boltzmann_temperature))

  self._q_network = q_network
  self._target_q_network = common.maybe_copy_target_network_with_checks(
      self._q_network, target_q_network, 'TargetQNetwork')
  self._epsilon_greedy = epsilon_greedy
  self._n_step_update = n_step_update
  self._boltzmann_temperature = boltzmann_temperature
  self._optimizer = optimizer
  self._td_errors_loss_fn = (
      td_errors_loss_fn or common.element_wise_huber_loss)
  self._gamma = gamma
  self._reward_scale_factor = reward_scale_factor
  self._gradient_clipping = gradient_clipping
  self._update_target = self._get_target_updater(target_update_tau,
                                                 target_update_period)

  policy, collect_policy = self._setup_policy(time_step_spec, action_spec,
                                              boltzmann_temperature,
                                              emit_log_probability)

  if q_network.state_spec and n_step_update != 1:
    raise NotImplementedError(
        'DqnAgent does not currently support n-step updates with stateful '
        'networks (i.e., RNNs), but n_step_update = {}'.format(
            n_step_update))

  train_sequence_length = (
      n_step_update + 1 if not q_network.state_spec else None)

  super(DqnAgent, self).__init__(
      time_step_spec,
      action_spec,
      policy,
      collect_policy,
      train_sequence_length=train_sequence_length,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter)
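# --- Hedged usage sketch (not part of the original source) ---
# Minimal construction of this agent on a discrete-action environment,
# assuming the standard tf_agents QNetwork and suite_gym loader; the env id
# and layer sizes are illustrative.
def _example_build_dqn_agent():
  from tf_agents.environments import suite_gym, tf_py_environment
  from tf_agents.networks import q_network as q_network_lib

  env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
  q_net = q_network_lib.QNetwork(
      env.observation_spec(), env.action_spec(), fc_layer_params=(100,))
  return DqnAgent(
      env.time_step_spec(),
      env.action_spec(),
      q_network=q_net,
      optimizer=tf.keras.optimizers.Adam(1e-3),
      epsilon_greedy=0.1,
      target_update_period=100,  # hard target copy every 100 train steps
      gamma=0.99)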
def __init__(self,
             time_step_spec,
             action_spec,
             critic_network,
             actor_network,
             actor_optimizer,
             critic_optimizer,
             actor_loss_weight=1.0,
             critic_loss_weight=0.5,
             actor_policy_ctor=actor_policy.ActorPolicy,
             critic_network_2=None,
             target_critic_network=None,
             target_critic_network_2=None,
             target_update_tau=1.0,
             target_update_period=1,
             td_errors_loss_fn=tf.math.squared_difference,
             gamma=1.0,
             gradient_clipping=None,
             debug_summaries=False,
             summarize_grads_and_vars=False,
             train_step_counter=None,
             name=None):
  """Creates a C-learning Agent.

  By default, the environment observation contains the current state and
  goal state. By setting the obs_to_goal gin config in c_learning_utils,
  the user can specify that the agent should only look at certain subsets
  of the goal state.

  Args:
    time_step_spec: A `TimeStep` spec of the expected time_steps.
    action_spec: A nest of BoundedTensorSpec representing the actions.
    critic_network: A function critic_network((observations, actions)) that
      returns the q_values for each observation and action.
    actor_network: A function actor_network(observation, action_spec) that
      returns action distribution.
    actor_optimizer: The optimizer to use for the actor network.
    critic_optimizer: The default optimizer to use for the critic network.
    actor_loss_weight: The weight on actor loss.
    critic_loss_weight: The weight on critic loss.
    actor_policy_ctor: The policy class to use.
    critic_network_2: (Optional.) A `tf_agents.network.Network` to be used
      as the second critic network during Q learning. The weights from
      `critic_network` are copied if this is not provided.
    target_critic_network: (Optional.) A `tf_agents.network.Network` to be
      used as the target critic network during Q learning. Every
      `target_update_period` train steps, the weights from `critic_network`
      are copied (possibly with smoothing via `target_update_tau`) to
      `target_critic_network`. If `target_critic_network` is not provided,
      it is created by making a copy of `critic_network`, which initializes
      a new network with the same structure and its own layers and weights.
      Performing a `Network.copy` does not work when the network instance
      already has trainable parameters (e.g., has already been built, or
      when the network is sharing layers with another). In these cases, it
      is up to you to build a copy having weights that are not shared with
      the original `critic_network`, so that this can be used as a target
      network. If you provide a `target_critic_network` that shares any
      weights with `critic_network`, a warning will be logged but no
      exception is thrown.
    target_critic_network_2: (Optional.) Similar network as
      target_critic_network but for the critic_network_2. See documentation
      for target_critic_network. Will only be used if `critic_network_2` is
      also specified.
    target_update_tau: Factor for soft update of the target networks.
    target_update_period: Period for soft update of the target networks.
    td_errors_loss_fn: A function for computing the elementwise TD errors
      loss.
    gamma: A discount factor for future rewards.
    gradient_clipping: Norm length to clip gradients.
    debug_summaries: A bool to gather debug summaries.
    summarize_grads_and_vars: If True, gradient and network variable
      summaries will be written during training.
    train_step_counter: An optional counter to increment every time the
      train op is run. Defaults to the global_step.
    name: The name of this agent. All variables in this module will fall
      under that name. Defaults to the class name.
  """
  tf.Module.__init__(self, name=name)
  self._check_action_spec(action_spec)

  self._critic_network_1 = critic_network
  self._critic_network_1.create_variables()
  if target_critic_network:
    target_critic_network.create_variables()
  self._target_critic_network_1 = (
      common.maybe_copy_target_network_with_checks(self._critic_network_1,
                                                   target_critic_network,
                                                   'TargetCriticNetwork1'))

  if critic_network_2 is not None:
    self._critic_network_2 = critic_network_2
  else:
    self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
    # Do not use target_critic_network_2 if critic_network_2 is None.
    target_critic_network_2 = None
  self._critic_network_2.create_variables()
  if target_critic_network_2:
    target_critic_network_2.create_variables()
  self._target_critic_network_2 = (
      common.maybe_copy_target_network_with_checks(self._critic_network_2,
                                                   target_critic_network_2,
                                                   'TargetCriticNetwork2'))

  if actor_network:
    actor_network.create_variables()
  self._actor_network = actor_network

  policy = actor_policy_ctor(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      training=False)
  self._train_policy = actor_policy_ctor(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      training=True)

  self._target_update_tau = target_update_tau
  self._target_update_period = target_update_period
  self._actor_optimizer = actor_optimizer
  self._critic_optimizer = critic_optimizer
  self._actor_loss_weight = actor_loss_weight
  self._critic_loss_weight = critic_loss_weight
  self._td_errors_loss_fn = td_errors_loss_fn
  self._gamma = gamma
  self._gradient_clipping = gradient_clipping
  self._debug_summaries = debug_summaries
  self._summarize_grads_and_vars = summarize_grads_and_vars
  self._update_target = self._get_target_updater(
      tau=self._target_update_tau, period=self._target_update_period)

  train_sequence_length = 2 if not critic_network.state_spec else None

  super(CLearningAgent, self).__init__(
      time_step_spec,
      action_spec,
      policy=policy,
      collect_policy=policy,
      train_sequence_length=train_sequence_length,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter)
def __init__(self,
             time_step_spec,
             action_spec,
             critic_networks,
             actor_network,
             actor_optimizer,
             critic_optimizers,
             alpha_optimizer,
             actor_policy_ctor=actor_policy.ActorPolicy,
             target_critic_networks=None,
             target_update_tau=0.001,
             target_update_period=1,
             td_errors_loss_fn=tf.math.squared_difference,
             gamma=1.0,
             reward_scale_factor=1.0,
             initial_log_alpha=0.0,
             target_entropy=None,
             gradient_clipping=None,
             debug_summaries=False,
             summarize_grads_and_vars=False,
             train_step_counter=None,
             percentile=0.3,
             name=None):
  """Creates an ensemble SAC Agent.

  Args:
    time_step_spec: A `TimeStep` spec of the expected time_steps.
    action_spec: A nest of BoundedTensorSpec representing the actions.
    critic_networks: A list of functions critic_network((observations,
      actions)) that return the q_values for each observation and action.
    actor_network: A function actor_network(observation, action_spec) that
      returns action distribution.
    actor_optimizer: The optimizer to use for the actor network.
    critic_optimizers: A list of default optimizers to use for the critic
      networks.
    alpha_optimizer: The default optimizer to use for the alpha variable.
    actor_policy_ctor: The policy class to use.
    target_critic_networks: (Optional.) A list of
      `tf_agents.network.Network` to be used as the target critic networks
      during Q learning. Every `target_update_period` train steps, the
      weights from each critic network are copied (possibly with smoothing
      via `target_update_tau`) to its target network. If not provided, the
      target networks are created by making copies of `critic_networks`,
      which initializes new networks with the same structure and their own
      layers and weights. Performing a `Network.copy` does not work when
      the network instance already has trainable parameters (e.g., has
      already been built, or when the network is sharing layers with
      another). In these cases, it is up to you to build copies having
      weights that are not shared with the originals, so that they can be
      used as target networks. If you provide target networks that share
      any weights with `critic_networks`, a warning will be logged but no
      exception is thrown.
    target_update_tau: Factor for soft update of the target networks.
    target_update_period: Period for soft update of the target networks.
    td_errors_loss_fn: A function for computing the elementwise TD errors
      loss.
    gamma: A discount factor for future rewards.
    reward_scale_factor: Multiplicative scale for the reward.
    initial_log_alpha: Initial value for log_alpha.
    target_entropy: The target average policy entropy, for updating alpha.
      The default value is negative of the total number of actions.
    gradient_clipping: Norm length to clip gradients.
    debug_summaries: A bool to gather debug summaries.
    summarize_grads_and_vars: If True, gradient and network variable
      summaries will be written during training.
    train_step_counter: An optional counter to increment every time the
      train op is run. Defaults to the global_step.
    percentile: Percentile used when aggregating the ensemble of critic
      estimates.
    name: The name of this agent. All variables in this module will fall
      under that name. Defaults to the class name.
  """
  tf.Module.__init__(self, name=name)

  flat_action_spec = tf.nest.flatten(action_spec)
  for spec in flat_action_spec:
    if spec.dtype.is_integer:
      raise NotImplementedError(
          'SacAgent does not currently support discrete actions. '
          'Action spec: {}'.format(action_spec))

  self._critic_networks = critic_networks
  for cn in self._critic_networks:
    cn.create_variables()
  if target_critic_networks:
    self._target_critic_networks = target_critic_networks
    for tcn in self._target_critic_networks:
      tcn.create_variables()
  else:
    self._target_critic_networks = [None] * len(self._critic_networks)
  self._target_critic_networks = [
      common.maybe_copy_target_network_with_checks(
          cn, tcn, 'TargetCriticNetwork{}'.format(i))
      for i, (tcn, cn) in enumerate(
          zip(self._target_critic_networks, self._critic_networks))
  ]

  if actor_network:
    actor_network.create_variables()
  self._actor_network = actor_network

  policy = actor_policy_ctor(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network)
  self._train_policy = actor_policy_ctor(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network)

  self._log_alpha = common.create_variable(
      'initial_log_alpha',
      initial_value=initial_log_alpha,
      dtype=tf.float32,
      trainable=True)

  # If target_entropy was not passed, set it to negative of the total number
  # of action dimensions.
  if target_entropy is None:
    flat_action_spec = tf.nest.flatten(action_spec)
    target_entropy = -np.sum([
        np.prod(single_spec.shape.as_list())
        for single_spec in flat_action_spec
    ])

  self._target_update_tau = target_update_tau
  self._target_update_period = target_update_period
  self._actor_optimizer = actor_optimizer
  self._critic_optimizer = critic_optimizers
  self._alpha_optimizer = alpha_optimizer
  self._td_errors_loss_fn = td_errors_loss_fn
  self._gamma = gamma
  self._reward_scale_factor = reward_scale_factor
  self._target_entropy = target_entropy
  self._gradient_clipping = gradient_clipping
  self._debug_summaries = debug_summaries
  self._summarize_grads_and_vars = summarize_grads_and_vars
  self._update_target = self._get_target_updater(
      tau=self._target_update_tau, period=self._target_update_period)
  self._percentile = percentile

  train_sequence_length = 2 if not critic_networks[0].state_spec else None

  super(EnsembleSacAgent, self).__init__(
      time_step_spec,
      action_spec,
      policy=policy,
      collect_policy=policy,
      train_sequence_length=train_sequence_length,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter)
def __init__(self,
             time_step_spec: ts.TimeStep,
             action_spec: types.NestedTensorSpec,
             critic_network: network.Network,
             actor_network: network.Network,
             actor_optimizer: types.Optimizer,
             critic_optimizer: types.Optimizer,
             alpha_optimizer: types.Optimizer,
             actor_loss_weight: types.Float = 1.0,
             critic_loss_weight: types.Float = 0.5,
             alpha_loss_weight: types.Float = 1.0,
             actor_policy_ctor: Callable[
                 ..., tf_policy.TFPolicy] = actor_policy.ActorPolicy,
             critic_network_2: Optional[network.Network] = None,
             target_critic_network: Optional[network.Network] = None,
             target_critic_network_2: Optional[network.Network] = None,
             target_update_tau: types.Float = 1.0,
             target_update_period: types.Int = 1,
             td_errors_loss_fn: types.LossFn = tf.math.squared_difference,
             gamma: types.Float = 1.0,
             sigma: types.Float = 0.9,
             reward_scale_factor: types.Float = 1.0,
             initial_log_alpha: types.Float = 0.0,
             use_log_alpha_in_alpha_loss: bool = True,
             target_entropy: Optional[types.Float] = None,
             gradient_clipping: Optional[types.Float] = None,
             debug_summaries: bool = False,
             summarize_grads_and_vars: bool = False,
             train_step_counter: Optional[tf.Variable] = None,
             name: Optional[Text] = None):
  tf.Module.__init__(self, name=name)
  self._check_action_spec(action_spec)

  net_observation_spec = time_step_spec.observation
  critic_spec = (net_observation_spec, action_spec)

  self._critic_network_1 = critic_network

  if critic_network_2 is not None:
    self._critic_network_2 = critic_network_2
  else:
    self._critic_network_2 = critic_network.copy(name='CriticNetwork2')
    # Do not use target_critic_network_2 if critic_network_2 is None.
    target_critic_network_2 = None

  # Wait until critic_network_2 has been copied from critic_network_1 before
  # creating variables on both.
  self._critic_network_1.create_variables(critic_spec)
  self._critic_network_2.create_variables(critic_spec)

  if target_critic_network:
    target_critic_network.create_variables(critic_spec)
  self._target_critic_network_1 = (
      common.maybe_copy_target_network_with_checks(
          self._critic_network_1,
          target_critic_network,
          input_spec=critic_spec,
          name='TargetCriticNetwork1'))

  if target_critic_network_2:
    target_critic_network_2.create_variables(critic_spec)
  self._target_critic_network_2 = (
      common.maybe_copy_target_network_with_checks(
          self._critic_network_2,
          target_critic_network_2,
          input_spec=critic_spec,
          name='TargetCriticNetwork2'))

  if actor_network:
    actor_network.create_variables(net_observation_spec)
  self._actor_network = actor_network

  policy = actor_policy_ctor(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      training=False)
  self._train_policy = actor_policy_ctor(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      training=True)

  self._log_alpha = common.create_variable(
      'initial_log_alpha',
      initial_value=initial_log_alpha,
      dtype=tf.float32,
      trainable=True)

  if target_entropy is None:
    target_entropy = self._get_default_target_entropy(action_spec)

  self._use_log_alpha_in_alpha_loss = use_log_alpha_in_alpha_loss
  self._target_update_tau = target_update_tau
  self._target_update_period = target_update_period
  self._actor_optimizer = actor_optimizer
  self._critic_optimizer = critic_optimizer
  self._alpha_optimizer = alpha_optimizer
  self._actor_loss_weight = actor_loss_weight
  self._critic_loss_weight = critic_loss_weight
  self._alpha_loss_weight = alpha_loss_weight
  self._td_errors_loss_fn = td_errors_loss_fn
  self._gamma = gamma
  self._reward_scale_factor = reward_scale_factor
  self._target_entropy = target_entropy
  self._gradient_clipping = gradient_clipping
  self._debug_summaries = debug_summaries
  self._summarize_grads_and_vars = summarize_grads_and_vars
  self._update_target = self._get_target_updater(
      tau=self._target_update_tau, period=self._target_update_period)
  self.sigma = sigma

  train_sequence_length = 2 if not critic_network.state_spec else None

  super(sac_agent.SacAgent, self).__init__(
      time_step_spec,
      action_spec,
      policy=policy,
      collect_policy=policy,
      train_sequence_length=train_sequence_length,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter,
      validate_args=False)

  self._as_transition = data_converter.AsTransition(
      self.data_context, squeeze_time_dim=(train_sequence_length == 2))
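# --- Illustrative sketch (not part of the original source) ---
# The constructor above copies critic_network into CriticNetwork2 *before*
# calling create_variables() on either, because `Network.copy` does not work
# once a network already has trainable parameters. The ordering in
# isolation:
def _copy_then_build_sketch(critic_network, critic_spec):
  critic_network_2 = critic_network.copy(name='CriticNetwork2')  # copy first
  critic_network.create_variables(critic_spec)    # then build the original
  critic_network_2.create_variables(critic_spec)  # and the unshared copy
  return critic_network_2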
def __init__(
    self,
    time_step_spec,
    action_spec,
    q_network,
    optimizer,
    observation_and_action_constraint_splitter=None,
    epsilon_greedy=0.1,
    n_step_update=1,
    boltzmann_temperature=None,
    emit_log_probability=False,
    # Params for target network updates
    target_q_network=None,
    target_update_tau=1.0,
    target_update_period=1,
    td_errors_loss_fn=None,
    gamma=1.0,
    reward_scale_factor=1.0,
    gradient_clipping=None,
    # Params for debugging
    debug_summaries=False,
    summarize_grads_and_vars=False,
    train_step_counter=None,
    name=None):
  tf.Module.__init__(self, name=name)
  self._check_action_spec(action_spec)

  if epsilon_greedy is not None and boltzmann_temperature is not None:
    raise ValueError(
        'Configured both epsilon_greedy value {} and temperature {}, '
        'however only one of them can be used for exploration.'.format(
            epsilon_greedy, boltzmann_temperature))

  self._observation_and_action_constraint_splitter = (
      observation_and_action_constraint_splitter)
  self._q_network = q_network
  q_network.create_variables()
  if target_q_network:
    target_q_network.create_variables()
  self._target_q_network = common.maybe_copy_target_network_with_checks(
      self._q_network, target_q_network, 'TargetQNetwork')
  self._epsilon_greedy = epsilon_greedy
  self._n_step_update = n_step_update
  self._boltzmann_temperature = boltzmann_temperature
  self._optimizer = optimizer
  self._td_errors_loss_fn = (
      td_errors_loss_fn or common.element_wise_huber_loss)
  self._gamma = gamma
  self._reward_scale_factor = reward_scale_factor
  self._gradient_clipping = gradient_clipping
  self._update_target = self._get_target_updater(target_update_tau,
                                                 target_update_period)

  policy, collect_policy = self._setup_policy(time_step_spec, action_spec,
                                              boltzmann_temperature,
                                              emit_log_probability)

  if q_network.state_spec and n_step_update != 1:
    raise NotImplementedError(
        'DqnAgent does not currently support n-step updates with stateful '
        'networks (i.e., RNNs), but n_step_update = {}'.format(
            n_step_update))

  train_sequence_length = (
      n_step_update + 1 if not q_network.state_spec else None)

  super(DqnAgent, self).__init__(
      time_step_spec,
      action_spec,
      policy,
      collect_policy,
      train_sequence_length=train_sequence_length,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter)
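# --- Hedged usage sketch (not part of the original source) ---
# When `observation_and_action_constraint_splitter` is provided, it must
# split the raw observation into (network_observation, action_mask), where
# the mask marks which discrete actions are valid at the current step. A
# typical splitter for a dict observation; the key names are illustrative:
def example_splitter(observation):
  return observation['state'], observation['valid_action_mask']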
def train_eval(
    root_dir,
    load_root_dir=None,
    env_load_fn=None,
    gym_env_wrappers=(),  # tuple default avoids the mutable-default pitfall
    monitor=False,
    env_name=None,
    agent_class=None,
    initial_collect_driver_class=None,
    collect_driver_class=None,
    online_driver_class=dynamic_episode_driver.DynamicEpisodeDriver,
    num_global_steps=1000000,
    rb_size=None,
    train_steps_per_iteration=1,
    train_metrics=None,
    eval_metrics=None,
    train_metrics_callback=None,
    # SacAgent args
    actor_fc_layers=(256, 256),
    critic_joint_fc_layers=(256, 256),
    # Safety Critic training args
    sc_rb_size=None,
    target_safety=None,
    train_sc_steps=10,
    train_sc_interval=1000,
    online_critic=False,
    n_envs=None,
    finetune_sc=False,
    pretraining=True,
    lambda_schedule_nsteps=0,
    lambda_initial=0.,
    lambda_final=1.,
    kstep_fail=0,
    # Ensemble Critic training args
    num_critics=None,
    critic_learning_rate=3e-4,
    # Wcpg Critic args
    critic_preprocessing_layer_size=256,
    # Params for train
    batch_size=256,
    # Params for eval
    run_eval=False,
    num_eval_episodes=10,
    eval_interval=1000,
    # Params for summaries and logging
    train_checkpoint_interval=10000,
    policy_checkpoint_interval=5000,
    rb_checkpoint_interval=50000,
    keep_rb_checkpoint=False,
    log_interval=1000,
    summary_interval=1000,
    monitor_interval=5000,
    summaries_flush_secs=10,
    early_termination_fn=None,
    debug_summaries=False,
    seed=None,
    eager_debug=False,
    env_metric_factories=None,
    wandb=False):  # pylint: disable=unused-argument
  """Train and eval script for SQRL."""
  if isinstance(agent_class, str):
    assert agent_class in ALGOS, (
        'trainer.train_eval: agent_class {} invalid'.format(agent_class))
    agent_class = ALGOS.get(agent_class)
  n_envs = n_envs or num_eval_episodes
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')

  # =====================================================================#
  #  Setup summary metrics, file writers, and create env                 #
  # =====================================================================#
  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()
  train_metrics = train_metrics or []
  eval_metrics = eval_metrics or []

  updating_sc = online_critic and (not load_root_dir or finetune_sc)
  logging.debug('updating safety critic: %s', updating_sc)

  if seed:
    tf.compat.v1.set_random_seed(seed)

  if agent_class in SAFETY_AGENTS:
    if online_critic:
      sc_tf_env = tf_py_environment.TFPyEnvironment(
          parallel_py_environment.ParallelPyEnvironment(
              [lambda: env_load_fn(env_name)] * n_envs))
      if seed:
        seeds = [seed * n_envs + i for i in range(n_envs)]
        try:
          sc_tf_env.pyenv.seed(seeds)
        except Exception:  # seeding is best-effort; not all envs support it
          pass

  if run_eval:
    eval_dir = os.path.join(root_dir, 'eval')
    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=n_envs),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=n_envs),
    ] + [tf_py_metric.TFPyMetric(m) for m in eval_metrics]
    eval_tf_env = tf_py_environment.TFPyEnvironment(
        parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_name)] * n_envs))
    if seed:
      try:
        for i, pyenv in enumerate(eval_tf_env.pyenv.envs):
          pyenv.seed(seed * n_envs + i)
      except Exception:  # seeding is best-effort
        pass
  elif 'Drunk' in env_name:
    # Just visualizes trajectories in drunk spider environment
    eval_tf_env = tf_py_environment.TFPyEnvironment(env_load_fn(env_name))
  else:
    eval_tf_env = None

  if monitor:
    vid_path = os.path.join(root_dir, 'rollouts')
    monitor_env_wrapper = misc.monitor_freq(1, vid_path)
    monitor_env = gym.make(env_name)
    for wrapper in gym_env_wrappers:
      monitor_env = wrapper(monitor_env)
    monitor_env = monitor_env_wrapper(monitor_env)
    # auto_reset must be False to ensure Monitor works correctly
    monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False)

  global_step = tf.compat.v1.train.get_or_create_global_step()

  with tf.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):
    py_env = env_load_fn(env_name)
    tf_env = tf_py_environment.TFPyEnvironment(py_env)
    if seed:
      try:
        for i, pyenv in enumerate(tf_env.pyenv.envs):
          pyenv.seed(seed * n_envs + i)
      except Exception:  # seeding is best-effort
        pass
    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    logging.debug('obs spec: %s', observation_spec)
    logging.debug('action spec: %s', action_spec)

    # =====================================================================#
    #  Setup agent class                                                   #
    # =====================================================================#
    if agent_class == wcpg_agent.WcpgAgent:
      alpha_spec = tensor_spec.BoundedTensorSpec(
          shape=(1,), dtype=tf.float32, minimum=0., maximum=1., name='alpha')
      input_tensor_spec = (observation_spec, action_spec, alpha_spec)
      critic_net = agents.DistributionalCriticNetwork(
          input_tensor_spec,
          preprocessing_layer_size=critic_preprocessing_layer_size,
          joint_fc_layer_params=critic_joint_fc_layers)
      actor_net = agents.WcpgActorNetwork((observation_spec, alpha_spec),
                                          action_spec)
    else:
      actor_net = actor_distribution_network.ActorDistributionNetwork(
          observation_spec,
          action_spec,
          fc_layer_params=actor_fc_layers,
          continuous_projection_net=agents.normal_projection_net)
      critic_net = agents.CriticNetwork(
          (observation_spec, action_spec),
          joint_fc_layer_params=critic_joint_fc_layers)

    if agent_class in SAFETY_AGENTS:
      logging.debug('Making SQRL agent')
      if lambda_schedule_nsteps > 0:
        lambda_update_every_nsteps = num_global_steps // lambda_schedule_nsteps
        step_size = (lambda_final - lambda_initial) / lambda_update_every_nsteps
        lambda_scheduler = lambda lam: common.periodically(
            body=lambda: tf.group(lam.assign(lam + step_size)),
            period=lambda_update_every_nsteps)
      else:
        lambda_scheduler = None
      safety_critic_net = agents.CriticNetwork(
          (observation_spec, action_spec),
          joint_fc_layer_params=critic_joint_fc_layers)
      ts = target_safety
      thresholds = [ts, 0.5]
      sc_metrics = [
          tf.keras.metrics.AUC(name='safety_critic_auc'),
          tf.keras.metrics.TruePositives(name='safety_critic_tp',
                                         thresholds=thresholds),
          tf.keras.metrics.FalsePositives(name='safety_critic_fp',
                                          thresholds=thresholds),
          tf.keras.metrics.TrueNegatives(name='safety_critic_tn',
                                         thresholds=thresholds),
          tf.keras.metrics.FalseNegatives(name='safety_critic_fn',
                                          thresholds=thresholds),
          tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc',
                                          threshold=0.5)
      ]
      tf_agent = agent_class(
          time_step_spec,
          action_spec,
          actor_network=actor_net,
          critic_network=critic_net,
          safety_critic_network=safety_critic_net,
          train_step_counter=global_step,
          debug_summaries=debug_summaries,
          safety_pretraining=pretraining,
          train_critic_online=online_critic,
          initial_log_lambda=lambda_initial,
          log_lambda=(lambda_scheduler is None),
          lambda_scheduler=lambda_scheduler)
    elif agent_class is ensemble_sac_agent.EnsembleSacAgent:
      critic_nets = [critic_net]
      critic_optimizers = [tf.keras.optimizers.Adam(critic_learning_rate)]
      for _ in range(num_critics - 1):
        critic_nets.append(
            agents.CriticNetwork((observation_spec, action_spec),
                                 joint_fc_layer_params=critic_joint_fc_layers))
        critic_optimizers.append(
            tf.keras.optimizers.Adam(critic_learning_rate))
      tf_agent = agent_class(
          time_step_spec,
          action_spec,
          actor_network=actor_net,
          critic_networks=critic_nets,
          critic_optimizers=critic_optimizers,
          debug_summaries=debug_summaries)
    else:  # agent is either SacAgent or WcpgAgent
      logging.debug('critic input_tensor_spec: %s',
                    critic_net.input_tensor_spec)
      tf_agent = agent_class(
          time_step_spec,
          action_spec,
          actor_network=actor_net,
          critic_network=critic_net,
          train_step_counter=global_step,
          debug_summaries=debug_summaries)

    tf_agent.initialize()

    # =====================================================================#
    #  Setup replay buffer                                                 #
    # =====================================================================#
    collect_data_spec = tf_agent.collect_data_spec

    logging.debug('Allocating replay buffer ...')
    # Add to replay buffer and other agent specific observers.
    rb_size = rb_size or 1000000
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=rb_size)
    logging.debug('RB capacity: %i', replay_buffer.capacity)
    logging.debug('ReplayBuffer Collect data spec: %s', collect_data_spec)

    if agent_class in SAFETY_AGENTS:
      sc_rb_size = sc_rb_size or num_eval_episodes * 500
      sc_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
          collect_data_spec,
          batch_size=1,
          max_length=sc_rb_size,
          dataset_window_shift=1)

    num_episodes = tf_metrics.NumberOfEpisodes()
    num_env_steps = tf_metrics.EnvironmentSteps()
    return_metric = tf_metrics.AverageReturnMetric(
        buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
    train_metrics = [
        num_episodes,
        num_env_steps,
        return_metric,
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    if 'Minitaur' in env_name and not pretraining:
      goal_vel = gin.query_parameter('%GOAL_VELOCITY')
      early_termination_fn = train_utils.MinitaurTerminationFn(
          speed_metric=train_metrics[-2],
          total_falls_metric=train_metrics[-3],
          env_steps_metric=num_env_steps,
          goal_speed=goal_vel)

    if env_metric_factories:
      for env_metric in env_metric_factories:
        train_metrics.append(
            tf_py_metric.TFPyMetric(env_metric(tf_env.pyenv.envs)))
        if run_eval:
          eval_metrics.append(
              env_metric([env for env in eval_tf_env.pyenv._envs]))

    # =====================================================================#
    #  Setup collect policies                                              #
    # =====================================================================#
    if not online_critic:
      eval_policy = tf_agent.policy
      collect_policy = tf_agent.collect_policy
      if not pretraining and agent_class in SAFETY_AGENTS:
        collect_policy = tf_agent.safe_policy
    else:
      eval_policy = (
          tf_agent.collect_policy if pretraining else tf_agent.safe_policy)
      collect_policy = (
          tf_agent.collect_policy if pretraining else tf_agent.safe_policy)
      online_collect_policy = tf_agent.safe_policy  # if pretraining else tf_agent.collect_policy
      if pretraining:
        online_collect_policy._training = False

    if not load_root_dir:
      initial_collect_policy = random_tf_policy.RandomTFPolicy(
          time_step_spec, action_spec)
    else:
      initial_collect_policy = collect_policy
    if agent_class == wcpg_agent.WcpgAgent:
      initial_collect_policy = agents.WcpgPolicyWrapper(
          initial_collect_policy)

    # =====================================================================#
    #  Setup Checkpointing                                                 #
    # =====================================================================#
    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)

    rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=rb_ckpt_dir, max_to_keep=1, replay_buffer=replay_buffer)

    if online_critic:
      online_rb_ckpt_dir = os.path.join(train_dir, 'online_replay_buffer')
      online_rb_checkpointer = common.Checkpointer(
          ckpt_dir=online_rb_ckpt_dir,
          max_to_keep=1,
          replay_buffer=sc_buffer)

    # loads agent, replay buffer, and online sc/buffer if online_critic
    if load_root_dir:
      load_root_dir = os.path.expanduser(load_root_dir)
      load_train_dir = os.path.join(load_root_dir, 'train')
      misc.load_agent_ckpt(load_train_dir, tf_agent)
      if len(os.listdir(os.path.join(load_train_dir, 'replay_buffer'))) > 1:
        load_rb_ckpt_dir = os.path.join(load_train_dir, 'replay_buffer')
        misc.load_rb_ckpt(load_rb_ckpt_dir, replay_buffer)
      if online_critic:
        load_online_sc_ckpt_dir = os.path.join(load_root_dir, 'sc')
        load_online_rb_ckpt_dir = os.path.join(load_train_dir,
                                               'online_replay_buffer')
        if osp.exists(load_online_rb_ckpt_dir):
          misc.load_rb_ckpt(load_online_rb_ckpt_dir, sc_buffer)
        if osp.exists(load_online_sc_ckpt_dir):
          misc.load_safety_critic_ckpt(load_online_sc_ckpt_dir,
                                       safety_critic_net)
      elif agent_class in SAFETY_AGENTS:
        offline_run = sorted(
            os.listdir(os.path.join(load_train_dir, 'offline')))[-1]
        load_sc_ckpt_dir = os.path.join(load_train_dir, 'offline',
                                        offline_run, 'safety_critic')
        if osp.exists(load_sc_ckpt_dir):
          sc_net_off = agents.CriticNetwork(
              (observation_spec, action_spec),
              joint_fc_layer_params=(512, 512),
              name='SafetyCriticOffline')
          sc_net_off.create_variables()
          target_sc_net_off = common.maybe_copy_target_network_with_checks(
              sc_net_off, None, 'TargetSafetyCriticNetwork')
          sc_optimizer = tf.keras.optimizers.Adam(critic_learning_rate)
          _ = misc.load_safety_critic_ckpt(
              load_sc_ckpt_dir,
              safety_critic_net=sc_net_off,
              target_safety_critic=target_sc_net_off,
              optimizer=sc_optimizer)
          tf_agent._safety_critic_network = sc_net_off
          tf_agent._target_safety_critic_network = target_sc_net_off
          tf_agent._safety_critic_optimizer = sc_optimizer
    else:
      train_checkpointer.initialize_or_restore()
      rb_checkpointer.initialize_or_restore()
      if online_critic:
        online_rb_checkpointer.initialize_or_restore()

    if agent_class in SAFETY_AGENTS:
      sc_dir = os.path.join(root_dir, 'sc')
      safety_critic_checkpointer = common.Checkpointer(
          ckpt_dir=sc_dir,
          safety_critic=tf_agent._safety_critic_network,  # pylint: disable=protected-access
          target_safety_critic=tf_agent._target_safety_critic_network,
          optimizer=tf_agent._safety_critic_optimizer,
          global_step=global_step)
      if not (load_root_dir and not online_critic):
        safety_critic_checkpointer.initialize_or_restore()

    agent_observers = [replay_buffer.add_batch] + train_metrics
    collect_driver = collect_driver_class(
        tf_env, collect_policy, observers=agent_observers)
    collect_driver.run = common.function_in_tf1()(collect_driver.run)

    if online_critic:
      logging.debug('online driver class: %s', online_driver_class)
      online_agent_observers = [
          num_episodes, num_env_steps, sc_buffer.add_batch
      ]
      online_driver = online_driver_class(
          sc_tf_env,
          online_collect_policy,
          observers=online_agent_observers,
          num_episodes=num_eval_episodes)
      online_driver.run = common.function_in_tf1()(online_driver.run)

    if eager_debug:
      tf.config.experimental_run_functions_eagerly(True)
    else:
      config_saver = gin.tf.GinConfigSaverHook(train_dir,
                                               summarize_config=True)
      tf.function(config_saver.after_create_session)()

    if global_step.numpy() == 0:
      logging.info('Performing initial collection ...')
      # Copy the observer list so appending the safety-critic buffer
      # observer does not mutate agent_observers as a side effect.
      init_collect_observers = list(agent_observers)
      if agent_class in SAFETY_AGENTS:
        init_collect_observers += [sc_buffer.add_batch]
      initial_collect_driver_class(
          tf_env, initial_collect_policy,
          observers=init_collect_observers).run()
      last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
      logging.info('Data saved after initial collection: %d steps', last_id)
      if agent_class in SAFETY_AGENTS:
        last_id = sc_buffer._get_last_id()  # pylint: disable=protected-access
        logging.debug(
            'Data saved in sc_buffer after initial collection: %d steps',
            last_id)

    if run_eval:
      results = metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix='EvalMetrics',
      )
      if train_metrics_callback is not None:
        train_metrics_callback(results, global_step.numpy())
      metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    train_step = train_utils.get_train_step(tf_agent, replay_buffer,
                                            batch_size)
    if agent_class in SAFETY_AGENTS:
      critic_train_step = train_utils.get_critic_train_step(
          tf_agent,
          replay_buffer,
          sc_buffer,
          batch_size=batch_size,
          updating_sc=updating_sc,
          metrics=sc_metrics)

    if early_termination_fn is None:
      early_termination_fn = lambda: False

    loss_diverged = False
    # How many consecutive steps was loss diverged for.
    loss_divergence_counter = 0
    mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')

    if agent_class in SAFETY_AGENTS:
      resample_counter = collect_policy._resample_counter
      mean_resample_ac = tf.keras.metrics.Mean(name='mean_unsafe_ac_freq')
      sc_metrics.append(mean_resample_ac)

      if online_critic:
        logging.debug('starting safety critic pretraining')
        # don't fine-tune safety critic
        if global_step.numpy() == 0:
          for _ in range(train_sc_steps):
            sc_loss, lambda_loss = critic_train_step()
          critic_results = [('sc_loss', sc_loss.numpy()),
                            ('lambda_loss', lambda_loss.numpy())]
          for critic_metric in sc_metrics:
            res = critic_metric.result().numpy()
            if not res.shape:
              critic_results.append((critic_metric.name, res))
            else:
              for r, thresh in zip(res, thresholds):
                name = '_'.join([critic_metric.name, str(thresh)])
                critic_results.append((name, r))
            critic_metric.reset_states()
          if train_metrics_callback:
            train_metrics_callback(
                collections.OrderedDict(critic_results),
                step=global_step.numpy())

    logging.debug('Starting main train loop...')
    curr_ep = []
    global_step_val = global_step.numpy()
    while global_step_val <= num_global_steps and not early_termination_fn():
      start_time = time.time()

      # MEASURE ACTION RESAMPLING FREQUENCY
      if agent_class in SAFETY_AGENTS:
        if pretraining and global_step_val == num_global_steps // 2:
          if online_critic:
            online_collect_policy._training = True
          collect_policy._training = True
        if online_critic or collect_policy._training:
          mean_resample_ac(resample_counter.result())
          resample_counter.reset()
          if time_step is None or time_step.is_last():
            resample_ac_freq = mean_resample_ac.result()
            mean_resample_ac.reset_states()
            tf.compat.v2.summary.scalar(
                name='resample_ac_freq',
                data=resample_ac_freq,
                step=global_step)

      # RUN COLLECTION
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )

      # get last step taken by step_driver
      traj = replay_buffer._data_table.read(
          replay_buffer._get_last_id() % replay_buffer._capacity)
      curr_ep.append(traj)

      if time_step.is_last():
        if agent_class in SAFETY_AGENTS:
          if time_step.observation['task_agn_rew']:
            if kstep_fail:
              # applies task agn rew. over last k steps
              for traj in curr_ep[-kstep_fail:]:
                traj.observation['task_agn_rew'] = 1.
                sc_buffer.add_batch(traj)
            else:
              for traj in curr_ep:
                sc_buffer.add_batch(traj)
        curr_ep = []
        if agent_class == wcpg_agent.WcpgAgent:
          collect_policy._alpha = None  # reset WCPG alpha

      if (global_step_val + 1) % log_interval == 0:
        logging.debug('policy eval: %4.2f sec', time.time() - start_time)

      # PERFORMS TRAIN STEP ON ALGORITHM (OFF-POLICY)
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
        mean_train_loss(train_loss.loss)

      current_step = global_step.numpy()
      total_loss = mean_train_loss.result()
      mean_train_loss.reset_states()

      if train_metrics_callback and current_step % summary_interval == 0:
        train_metrics_callback(
            collections.OrderedDict([
                (k, v.numpy()) for k, v in train_loss.extra._asdict().items()
            ]),
            step=current_step)
        train_metrics_callback({'train_loss': total_loss.numpy()},
                               step=current_step)

      # TRAIN AND/OR EVAL SAFETY CRITIC
      if (agent_class in SAFETY_AGENTS and
          current_step % train_sc_interval == 0):
        if online_critic:
          batch_time_step = sc_tf_env.reset()
          # run online critic training collect & update
          batch_policy_state = online_collect_policy.get_initial_state(
              sc_tf_env.batch_size)
          online_driver.run(
              time_step=batch_time_step, policy_state=batch_policy_state)
        for _ in range(train_sc_steps):
          sc_loss, lambda_loss = critic_train_step()
        # log safety_critic loss results
        critic_results = [('sc_loss', sc_loss.numpy()),
                          ('lambda_loss', lambda_loss.numpy())]
        metric_utils.log_metrics(sc_metrics)
        for critic_metric in sc_metrics:
          res = critic_metric.result().numpy()
          if not res.shape:
            critic_results.append((critic_metric.name, res))
          else:
            for r, thresh in zip(res, thresholds):
              name = '_'.join([critic_metric.name, str(thresh)])
              critic_results.append((name, r))
          critic_metric.reset_states()
        if train_metrics_callback and current_step % summary_interval == 0:
          train_metrics_callback(
              collections.OrderedDict(critic_results), step=current_step)

      # Check for exploding losses.
      if (math.isnan(total_loss) or math.isinf(total_loss) or
          total_loss > MAX_LOSS):
        loss_divergence_counter += 1
        if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
          loss_diverged = True
          logging.info('Loss diverged, critic_loss: %s, actor_loss: %s',
                       train_loss.extra.critic_loss,
                       train_loss.extra.actor_loss)
          break
      else:
        loss_divergence_counter = 0

      time_acc += time.time() - start_time

      # LOGGING AND METRICS
      if current_step % log_interval == 0:
        metric_utils.log_metrics(train_metrics)
        logging.info('step = %d, loss = %f', current_step, total_loss)
        steps_per_sec = (current_step - timed_at_step) / time_acc
        logging.info('%4.2f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec',
            data=steps_per_sec,
            step=global_step)
        timed_at_step = current_step
        time_acc = 0

      train_results = []
      for metric in train_metrics[2:]:
        if isinstance(metric, (metrics.AverageEarlyFailureMetric,
                               metrics.AverageFallenMetric,
                               metrics.AverageSuccessMetric)):
          # Plot failure as a fn of return
          metric.tf_summaries(
              train_step=global_step,
              step_metrics=[num_env_steps, num_episodes, return_metric])
        else:
          metric.tf_summaries(
              train_step=global_step,
              step_metrics=[num_env_steps, num_episodes])
        train_results.append((metric.name, metric.result().numpy()))

      if train_metrics_callback and current_step % summary_interval == 0:
        train_metrics_callback(
            collections.OrderedDict(train_results),
            step=global_step.numpy())

      if current_step % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=current_step)

      if current_step % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=current_step)
        if agent_class in SAFETY_AGENTS:
          safety_critic_checkpointer.save(global_step=current_step)
          if online_critic:
            online_rb_checkpointer.save(global_step=current_step)

      if rb_checkpoint_interval and current_step % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=current_step)

      if wandb and current_step % eval_interval == 0 and 'Drunk' in env_name:
        misc.record_point_mass_episode(eval_tf_env, eval_policy,
                                       current_step)
        if online_critic:
          misc.record_point_mass_episode(eval_tf_env, tf_agent.safe_policy,
                                         current_step, 'safe-trajectory')

      if run_eval and current_step % eval_interval == 0:
        eval_results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='EvalMetrics',
        )
        if train_metrics_callback is not None:
          train_metrics_callback(eval_results, current_step)
        metric_utils.log_metrics(eval_metrics)
        with eval_summary_writer.as_default():
          for eval_metric in eval_metrics[2:]:
            eval_metric.tf_summaries(
                train_step=global_step, step_metrics=eval_metrics[:2])

      if monitor and current_step % monitor_interval == 0:
        monitor_time_step = monitor_py_env.reset()
        monitor_policy_state = eval_policy.get_initial_state(1)
        ep_len = 0
        monitor_start = time.time()
        while not monitor_time_step.is_last():
          monitor_action = eval_policy.action(monitor_time_step,
                                              monitor_policy_state)
          action, monitor_policy_state = (monitor_action.action,
                                          monitor_action.state)
          monitor_time_step = monitor_py_env.step(action)
          ep_len += 1
        logging.debug(
            'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
            current_step, ep_len, time.time() - monitor_start)

      global_step_val = current_step

    if early_termination_fn():
      # Early stopped, save all checkpoints if not saved
      if global_step_val % train_checkpoint_interval != 0:
        train_checkpointer.save(global_step=global_step_val)
      if global_step_val % policy_checkpoint_interval != 0:
        policy_checkpointer.save(global_step=global_step_val)
        if agent_class in SAFETY_AGENTS:
          safety_critic_checkpointer.save(global_step=global_step_val)
          if online_critic:
            online_rb_checkpointer.save(global_step=global_step_val)
      if (rb_checkpoint_interval and
          global_step_val % rb_checkpoint_interval == 0):
        rb_checkpointer.save(global_step=global_step_val)

  if not keep_rb_checkpoint:
    misc.cleanup_checkpoints(rb_ckpt_dir)

  if loss_diverged:
    # Raise an error at the very end after the cleanup.
    raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
        total_loss, global_step.numpy()))

  return total_loss
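# --- Hedged usage sketch (not part of the original source) ---
# One plausible invocation of the SQRL trainer above; directories, env id,
# and driver classes are illustrative assumptions (the real config likely
# binds these via gin).
def _example_sqrl_training():
  from tf_agents.drivers import dynamic_step_driver
  from tf_agents.environments import suite_gym  # assumed env loader
  return train_eval(
      root_dir='/tmp/sqrl_run',  # hypothetical output root
      env_load_fn=suite_gym.load,
      env_name='Pendulum-v0',  # hypothetical env id
      agent_class='sqrl',  # assumed key in ALGOS
      initial_collect_driver_class=dynamic_step_driver.DynamicStepDriver,
      collect_driver_class=dynamic_step_driver.DynamicStepDriver,
      num_global_steps=1000000,
      online_critic=True,
      run_eval=True)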
def __init__(self,
             time_step_spec,
             action_spec,
             critic_network,
             actor_network,
             actor_optimizer,
             critic_optimizer,
             safety_critic_optimizer,
             alpha_optimizer,
             lambda_optimizer=None,
             lambda_scheduler=None,
             actor_policy_ctor=actor_policy.ActorPolicy,
             safety_critic_network=None,
             target_safety_critic_network=None,
             critic_network_2=None,
             target_critic_network=None,
             target_critic_network_2=None,
             target_update_tau=0.005,
             target_update_period=1,
             td_errors_loss_fn=tf.math.squared_difference,
             safe_td_errors_loss_fn=tf.keras.losses.binary_crossentropy,
             gamma=1.0,
             safety_gamma=None,
             reward_scale_factor=1.0,
             initial_log_alpha=0.0,
             initial_log_lambda=0.0,
             log_lambda=True,
             target_entropy=None,
             target_safety=0.1,
             gradient_clipping=None,
             debug_summaries=False,
             summarize_grads_and_vars=False,
             train_step_counter=None,
             safety_pretraining=True,
             train_critic_online=True,
             resample_counter=None,
             fail_weight=None,
             name='SqrlAgent'):
  self._safety_critic_network = safety_critic_network
  self._train_critic_online = train_critic_online

  super(SqrlAgent, self).__init__(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      critic_network=critic_network,
      actor_network=actor_network,
      actor_optimizer=actor_optimizer,
      critic_optimizer=critic_optimizer,
      alpha_optimizer=alpha_optimizer,
      critic_network_2=critic_network_2,
      target_critic_network=target_critic_network,
      target_critic_network_2=target_critic_network_2,
      actor_policy_ctor=actor_policy_ctor,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      td_errors_loss_fn=td_errors_loss_fn,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      initial_log_alpha=initial_log_alpha,
      target_entropy=target_entropy,
      gradient_clipping=gradient_clipping,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter,
      name=name)

  if safety_critic_network is not None:
    self._safety_critic_network.create_variables()
  else:
    self._safety_critic_network = (
        common.maybe_copy_target_network_with_checks(
            critic_network, None, 'SafetyCriticNetwork'))

  if target_safety_critic_network is not None:
    self._target_safety_critic_network = target_safety_critic_network
    self._target_safety_critic_network.create_variables()
  else:
    self._target_safety_critic_network = (
        common.maybe_copy_target_network_with_checks(
            self._safety_critic_network, None, 'TargetSafetyCriticNetwork'))

  lambda_name = 'initial_log_lambda' if log_lambda else 'initial_lambda'
  self._lambda_var = common.create_variable(
      lambda_name,
      initial_value=initial_log_lambda,
      dtype=tf.float32,
      trainable=True)
  self._target_safety = target_safety

  self._safe_policy = agents.SafeActorPolicyRSVar(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      safety_critic_network=self._safety_critic_network,
      safety_threshold=target_safety,
      resample_counter=resample_counter,
      training=True)

  self._collect_policy = agents.SafeActorPolicyRSVar(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      safety_critic_network=self._safety_critic_network,
      safety_threshold=target_safety,
      resample_counter=resample_counter,
      training=(not safety_pretraining))

  self._safety_critic_optimizer = safety_critic_optimizer
  self._lambda_optimizer = lambda_optimizer or alpha_optimizer
  # `lambda_scheduler` is a factory: call it with the lambda variable to get
  # the periodic update op; store None if no scheduler was given.
  self._lambda_scheduler = (
      None if lambda_scheduler is None
      else lambda_scheduler(self._lambda_var))
  self._safety_pretraining = safety_pretraining
  self._safe_td_errors_loss_fn = safe_td_errors_loss_fn
  self._safety_gamma = safety_gamma or self._gamma
  self._fail_weight = fail_weight
  if train_critic_online:
    self._update_target_safety_critic = (
        self._get_target_updater_safety_critic(
            tau=self._target_update_tau,
            period=self._target_update_period))
  else:
    self._update_target_safety_critic = None
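# --- Hedged usage sketch (not part of the original source) ---
# `lambda_scheduler` is expected to be a factory that takes the agent's
# lambda variable and returns a periodic update op; the SQRL trainer earlier
# in this file builds one the same way. Step size and period below are
# illustrative.
def example_lambda_scheduler(step_size=1e-4, period=1000):
  return lambda lam: common.periodically(
      body=lambda: tf.group(lam.assign(lam + step_size)),
      period=period)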