def testDistributedLinearAgentUpdate(self,
                                     batch_size,
                                     context_dim,
                                     exploration_policy,
                                     dtype,
                                     use_eigendecomp=False):
  """Same as above, but uses the distributed train function of the agent."""
  # Construct a `Trajectory` for the given action, observation, reward.
  num_actions = 5
  initial_step, final_step = _get_initial_and_final_steps(
      batch_size, context_dim)
  action = np.random.randint(num_actions, size=batch_size, dtype=np.int32)
  action_step = _get_action_step(action)
  experience = _get_experience(initial_step, action_step, final_step)

  # Construct an agent and perform the update.
  observation_spec = tensor_spec.TensorSpec([context_dim], tf.float32)
  time_step_spec = time_step.time_step_spec(observation_spec)
  action_spec = tensor_spec.BoundedTensorSpec(
      dtype=tf.int32, shape=(), minimum=0, maximum=num_actions - 1)
  agent = linear_agent.LinearBanditAgent(
      exploration_policy=exploration_policy,
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      dtype=dtype)
  self.evaluate(agent.initialize())
  train_fn = common.function_in_tf1()(agent._distributed_train_step)
  loss_info = train_fn(experience=experience)
  self.evaluate(loss_info)
  final_a = self.evaluate(agent.cov_matrix)
  final_b = self.evaluate(agent.data_vector)

  # Compute the expected updated estimates.
  observations_list = tf.dynamic_partition(
      data=tf.reshape(experience.observation, [batch_size, context_dim]),
      partitions=tf.convert_to_tensor(action),
      num_partitions=num_actions)
  rewards_list = tf.dynamic_partition(
      data=tf.reshape(experience.reward, [batch_size]),
      partitions=tf.convert_to_tensor(action),
      num_partitions=num_actions)
  expected_a_updated_list = []
  expected_b_updated_list = []
  expected_theta_updated_list = []
  for observations_for_arm, rewards_for_arm in zip(observations_list,
                                                   rewards_list):
    num_samples_for_arm_current = tf.cast(
        tf.shape(rewards_for_arm)[0], tf.float32)
    num_samples_for_arm_total = num_samples_for_arm_current

    # pylint: disable=cell-var-from-loop
    def true_fn():
      a_new = tf.matmul(
          observations_for_arm, observations_for_arm, transpose_a=True)
      b_new = bandit_utils.sum_reward_weighted_observations(
          rewards_for_arm, observations_for_arm)
      return a_new, b_new

    def false_fn():
      return tf.zeros([context_dim, context_dim]), tf.zeros([context_dim])

    a_new, b_new = tf.cond(
        tf.squeeze(num_samples_for_arm_total) > 0, true_fn, false_fn)
    theta_new = tf.squeeze(
        tf.linalg.solve(a_new + tf.eye(context_dim),
                        tf.expand_dims(b_new, axis=-1)),
        axis=-1)

    expected_a_updated_list.append(self.evaluate(a_new))
    expected_b_updated_list.append(self.evaluate(b_new))
    expected_theta_updated_list.append(self.evaluate(theta_new))

  # Check that the actual updated estimates match the expectations.
  self.assertAllClose(expected_a_updated_list, final_a)
  self.assertAllClose(expected_b_updated_list, final_b)
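# A minimal standalone sketch (illustrative values, not part of the test
# fixtures above) of the per-arm split the expected-update computation relies
# on: `tf.dynamic_partition` groups batch rows by the action taken, so each
# arm's sufficient statistics can be updated from only its own observations.
import tensorflow as tf

observations = tf.constant([[1., 2.], [3., 4.], [5., 6.]])  # [batch, context_dim]
actions = tf.constant([0, 1, 0])  # arm chosen for each batch row
per_arm_obs = tf.dynamic_partition(observations, actions, num_partitions=2)
# per_arm_obs[0] == [[1., 2.], [5., 6.]]; per_arm_obs[1] == [[3., 4.]]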
def __init__(self,
             time_step_spec,
             action_spec,
             policy,
             collect_policy,
             train_sequence_length,
             num_outer_dims=2,
             train_argspec=None,
             debug_summaries=False,
             summarize_grads_and_vars=False,
             enable_summaries=True,
             train_step_counter=None):
  """Meant to be called by subclass constructors.

  Args:
    time_step_spec: A nest of tf.TypeSpec representing the time_steps.
      Provided by the user.
    action_spec: A nest of BoundedTensorSpec representing the actions.
      Provided by the user.
    policy: An instance of `tf_policy.Base` representing the Agent's current
      policy.
    collect_policy: An instance of `tf_policy.Base` representing the Agent's
      current data collection policy (used to set `self.step_spec`).
    train_sequence_length: A python integer or `None`, signifying the number
      of time steps required from tensors in `experience` as passed to
      `train()`. All tensors in `experience` will be shaped `[B, T, ...]`,
      but for certain agents, `T` should be fixed. For example, DQN requires
      transitions in the form of 2 time steps, so for a non-RNN DQN Agent,
      set this value to 2. For agents that don't care, or which can handle
      `T` unknown at graph build time (i.e. most RNN-based agents), set this
      argument to `None`.
    num_outer_dims: The number of outer dimensions for the agent. Must be
      either 1 or 2. If 2, training will require both a batch_size and time
      dimension on every Tensor; if 1, training will require only a
      batch_size outer dimension.
    train_argspec: (Optional) Describes additional supported arguments to the
      `train` call. This must be a `dict` mapping strings to nests of specs.
      Overriding the `experience` or `weights` args is not supported (see
      Raises below).

      Some algorithms require additional arguments to the `train()` call, and
      while TF-Agents encourages most of these to be provided in the
      `policy_info` / `info` field of `experience`, sometimes the extra
      information doesn't fit well, i.e., when it doesn't come from the
      policy.

      **NOTE** kwargs will not have their outer dimensions validated. In
      particular, `train_sequence_length` is ignored for these inputs, and
      they may have any, or inconsistent, batch/time dimensions; only their
      inner shape dimensions are checked against `train_argspec`.

      Below is an example:

      ```python
      class MyAgent(TFAgent):
        def __init__(self, counterfactual_training, ...):
          collect_policy = ...
          train_argspec = None
          if counterfactual_training:
            train_argspec = dict(
                counterfactual=collect_policy.trajectory_spec)
          super(...).__init__(
              ...,
              train_argspec=train_argspec)

      my_agent = MyAgent(...)

      for ...:
        experience, counterfactual = next(experience_and_counterfactual_iter)
        loss_info = my_agent.train(experience,
                                   counterfactual=counterfactual)
      ```
    debug_summaries: A bool; if true, subclasses should gather debug
      summaries.
    summarize_grads_and_vars: A bool; if true, subclasses should additionally
      collect gradient and variable summaries.
    enable_summaries: A bool; if false, subclasses should not gather any
      summaries (debug or otherwise); subclasses should gate *all* summaries
      using either `summaries_enabled`, `debug_summaries`, or
      `summarize_grads_and_vars` properties.
    train_step_counter: An optional counter to increment every time the train
      op is run. Defaults to the global_step.

  Raises:
    TypeError: If `train_argspec` is not a `dict`.
    ValueError: If `train_argspec` has the keys `experience` or `weights`.
    TypeError: If any leaf nodes in `train_argspec` values are not subclasses
      of `tf.TypeSpec`.
    ValueError: If `time_step_spec` is not an instance of `ts.TimeStep`.
    ValueError: If `num_outer_dims` is not in `[1, 2]`.
  """

  def _each_isinstance(spec, spec_types):
    """Checks if each element of `spec` is an instance of `spec_types`."""
    return all(isinstance(s, spec_types) for s in tf.nest.flatten(spec))

  if not _each_isinstance(time_step_spec, tf.TypeSpec):
    raise TypeError(
        "time_step_spec has to contain TypeSpec (TensorSpec, "
        "SparseTensorSpec, etc) objects, but received: {}".format(
            time_step_spec))
  if not _each_isinstance(action_spec, tensor_spec.BoundedTensorSpec):
    raise TypeError(
        "action_spec has to contain BoundedTensorSpec objects, but received: "
        "{}".format(action_spec))

  common.check_tf1_allowed()
  common.tf_agents_gauge.get_cell("TFAgent").set(True)
  common.assert_members_are_not_overridden(base_cls=TFAgent, instance=self)
  if not isinstance(time_step_spec, ts.TimeStep):
    raise ValueError(
        "The `time_step_spec` must be an instance of `TimeStep`, but is "
        "`{}`.".format(type(time_step_spec)))

  if num_outer_dims not in [1, 2]:
    raise ValueError("num_outer_dims must be in [1, 2].")

  self._time_step_spec = time_step_spec
  self._action_spec = action_spec
  self._policy = policy
  self._collect_policy = collect_policy
  self._train_sequence_length = train_sequence_length
  self._num_outer_dims = num_outer_dims
  self._debug_summaries = debug_summaries
  self._summarize_grads_and_vars = summarize_grads_and_vars
  self._enable_summaries = enable_summaries
  if train_argspec is None:
    train_argspec = {}
  else:
    if not isinstance(train_argspec, dict):
      raise TypeError("train_argspec must be a dict, but saw: {}".format(
          train_argspec))
    train_argspec = dict(train_argspec)  # Create a local copy.
    if "weights" in train_argspec or "experience" in train_argspec:
      raise ValueError("train_argspec must not override 'weights' or "
                       "'experience' keys, but saw: {}".format(train_argspec))
    if not all(
        isinstance(x, tf.TypeSpec) for x in tf.nest.flatten(train_argspec)):
      raise TypeError(
          "train_argspec contains non-TensorSpec objects: {}".format(
              train_argspec))
  self._train_argspec = train_argspec
  if train_step_counter is None:
    train_step_counter = tf.compat.v1.train.get_or_create_global_step()
  self._train_step_counter = train_step_counter
  self._train_fn = common.function_in_tf1()(self._train)
  self._initialize_fn = common.function_in_tf1()(self._initialize)
def testActionBatchWithVariablesAndPolicyUpdate(self, batch_size,
                                                actions_from_reward_layer):
  a_list = []
  a_new_list = []
  b_list = []
  b_new_list = []
  num_samples_list = []
  num_samples_new_list = []
  for k in range(1, self._num_actions + 1):
    a_initial_value = k + 1 + 2 * k * tf.eye(
        self._encoding_dim, dtype=tf.float32)
    a_for_one_arm = tf.compat.v2.Variable(a_initial_value)
    a_list.append(a_for_one_arm)
    b_initial_value = tf.constant(
        k * np.ones(self._encoding_dim), dtype=tf.float32)
    b_for_one_arm = tf.compat.v2.Variable(b_initial_value)
    b_list.append(b_for_one_arm)
    num_samples_initial_value = tf.constant([1], dtype=tf.float32)
    num_samples_for_one_arm = tf.compat.v2.Variable(num_samples_initial_value)
    num_samples_list.append(num_samples_for_one_arm)
    # Variables for the new policy (they differ by an offset).
    a_new_for_one_arm = tf.compat.v2.Variable(a_initial_value +
                                              _POLICY_VARIABLES_OFFSET)
    a_new_list.append(a_new_for_one_arm)
    b_new_for_one_arm = tf.compat.v2.Variable(b_initial_value +
                                              _POLICY_VARIABLES_OFFSET)
    b_new_list.append(b_new_for_one_arm)
    num_samples_for_one_arm_new = tf.compat.v2.Variable(
        num_samples_initial_value + _POLICY_VARIABLES_OFFSET)
    num_samples_new_list.append(num_samples_for_one_arm_new)

  policy = neural_linucb_policy.NeuralLinUCBPolicy(
      encoding_network=DummyNet(),
      encoding_dim=self._encoding_dim,
      reward_layer=get_reward_layer(),
      actions_from_reward_layer=tf.constant(
          actions_from_reward_layer, dtype=tf.bool),
      cov_matrix=a_list,
      data_vector=b_list,
      num_samples=num_samples_list,
      epsilon_greedy=0.0,
      time_step_spec=self._time_step_spec)
  new_policy = neural_linucb_policy.NeuralLinUCBPolicy(
      encoding_network=DummyNet(),
      encoding_dim=self._encoding_dim,
      reward_layer=get_reward_layer(),
      actions_from_reward_layer=tf.constant(
          actions_from_reward_layer, dtype=tf.bool),
      cov_matrix=a_new_list,
      data_vector=b_new_list,
      num_samples=num_samples_new_list,
      epsilon_greedy=0.0,
      time_step_spec=self._time_step_spec)

  action_step = policy.action(self._time_step_batch(batch_size=batch_size))
  new_action_step = new_policy.action(
      self._time_step_batch(batch_size=batch_size))
  self.assertEqual(action_step.action.shape, new_action_step.action.shape)
  self.assertEqual(action_step.action.dtype, new_action_step.action.dtype)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.evaluate(new_policy.update(policy))

  action_fn = common.function_in_tf1()(policy.action)
  action_step = action_fn(self._time_step_batch(batch_size=batch_size))
  new_action_fn = common.function_in_tf1()(new_policy.action)
  new_action_step = new_action_fn(
      self._time_step_batch(batch_size=batch_size))
  actions_, new_actions_ = self.evaluate(
      [action_step.action, new_action_step.action])
  self.assertAllEqual(actions_, new_actions_)
def __init__(self,
             time_step_spec,
             action_spec,
             critic_network,
             actor_network,
             model_network,
             compressor_network,
             actor_optimizer,
             critic_optimizer,
             alpha_optimizer,
             model_optimizer,
             sequence_length,
             target_update_tau=1.0,
             target_update_period=1,
             td_errors_loss_fn=tf.math.squared_difference,
             gamma=1.0,
             reward_scale_factor=1.0,
             initial_log_alpha=0.0,
             target_entropy=None,
             gradient_clipping=None,
             trainable_model=True,
             critic_input='state',
             actor_input='state',
             critic_input_stop_gradient=True,
             actor_input_stop_gradient=False,
             model_batch_size=None,
             control_timestep=None,
             num_images_per_summary=1,
             debug_summaries=False,
             summarize_grads_and_vars=False,
             train_step_counter=None,
             name=None):
  tf.Module.__init__(self, name=name)

  self._critic_network1 = critic_network
  self._critic_network2 = critic_network.copy(name='CriticNetwork2')
  self._target_critic_network1 = critic_network.copy(
      name='TargetCriticNetwork1')
  self._target_critic_network2 = critic_network.copy(
      name='TargetCriticNetwork2')
  self._actor_network = actor_network
  self._model_network = model_network
  self._compressor_network = compressor_network

  policy = ActorSequencePolicy(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      model_network=self._model_network,
      compressor_network=self._compressor_network,
      sequence_length=sequence_length,
      actor_input=actor_input,
      control_timestep=control_timestep,
      num_images_per_summary=num_images_per_summary,
      debug_summaries=debug_summaries)

  self._log_alpha = common.create_variable(
      'initial_log_alpha',
      initial_value=initial_log_alpha,
      dtype=tf.float32,
      trainable=True)

  # If target_entropy was not passed, set it to the negative of the total
  # number of action dimensions.
  if target_entropy is None:
    flat_action_spec = tf.nest.flatten(action_spec)
    target_entropy = -np.sum([
        np.product(single_spec.shape.as_list())
        for single_spec in flat_action_spec
    ])

  self._target_update_tau = target_update_tau
  self._target_update_period = target_update_period
  self._actor_optimizer = actor_optimizer
  self._critic_optimizer = critic_optimizer
  self._alpha_optimizer = alpha_optimizer
  self._model_optimizer = model_optimizer
  self._sequence_length = sequence_length
  self._td_errors_loss_fn = td_errors_loss_fn
  self._gamma = gamma
  self._reward_scale_factor = reward_scale_factor
  self._target_entropy = target_entropy
  self._gradient_clipping = gradient_clipping
  self._trainable_model = trainable_model
  self._critic_input = critic_input
  self._actor_input = actor_input
  self._critic_input_stop_gradient = critic_input_stop_gradient
  self._actor_input_stop_gradient = actor_input_stop_gradient
  self._model_batch_size = model_batch_size
  self._control_timestep = control_timestep
  self._num_images_per_summary = num_images_per_summary
  self._debug_summaries = debug_summaries
  self._summarize_grads_and_vars = summarize_grads_and_vars

  self._update_target = self._get_target_updater(
      tau=self._target_update_tau, period=self._target_update_period)

  self._actor_time_step_spec = time_step_spec._replace(
      observation=actor_network.input_tensor_spec)

  super(SlacAgent, self).__init__(
      time_step_spec,
      action_spec,
      policy=policy,
      collect_policy=policy,
      train_sequence_length=sequence_length + 1,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step_counter)

  self._train_model_fn = common.function_in_tf1()(self._train_model)
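# Hedged sketch of the soft (Polyak) target update that `_get_target_updater`
# presumably builds from `tau` and `period`: every `period` train steps,
# target <- tau * source + (1 - tau) * target. This is the standard recipe;
# the agent's actual implementation lives elsewhere in its class.
import tensorflow as tf

def soft_update(target_variables, source_variables, tau):
  """Moves each target variable a fraction `tau` toward its source."""
  for target_var, source_var in zip(target_variables, source_variables):
    target_var.assign(tau * source_var + (1.0 - tau) * target_var)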
def __init__(self,
             time_step_spec,
             action_spec,
             policy_state_spec=(),
             info_spec=(),
             clip=True,
             emit_log_probability=False,
             automatic_state_reset=True,
             name=None):
  """Initialization of Base class.

  Args:
    time_step_spec: A `TimeStep` spec of the expected time_steps. Usually
      provided by the user to the subclass.
    action_spec: A nest of BoundedTensorSpec representing the actions.
      Usually provided by the user to the subclass.
    policy_state_spec: A nest of TensorSpec representing the policy_state.
      Provided by the subclass, not directly by the user.
    info_spec: A nest of TensorSpec representing the policy info. Provided by
      the subclass, not directly by the user.
    clip: Whether to clip actions to spec before returning them. Default
      True. Most policy-based algorithms (PCL, PPO, REINFORCE) use unclipped
      continuous actions for training.
    emit_log_probability: Emit log-probabilities of actions, if supported. If
      True, policy_step.info will have CommonFields.LOG_PROBABILITY set.
      Please consult utility methods provided in policy_step for setting and
      retrieving these. When working with custom policies, either provide a
      dictionary info_spec or a namedtuple with the field 'log_probability'.
    automatic_state_reset: If `True`, then `get_initial_policy_state` is used
      to clear state in `action()` and `distribution()` for time steps where
      `time_step.is_first()`.
    name: A name for this module. Defaults to the class name.
  """
  super(Base, self).__init__(name=name)
  common.assert_members_are_not_overridden(base_cls=Base, instance=self)

  if not isinstance(time_step_spec, ts.TimeStep):
    raise ValueError(
        'The `time_step_spec` must be an instance of `TimeStep`, but is '
        '`{}`.'.format(type(time_step_spec)))

  self._time_step_spec = time_step_spec
  self._action_spec = action_spec
  self._policy_state_spec = policy_state_spec
  self._emit_log_probability = emit_log_probability
  if emit_log_probability:
    log_probability_spec = tensor_spec.BoundedTensorSpec(
        shape=(),
        dtype=tf.float32,
        maximum=0,
        minimum=-float('inf'),
        name='log_probability')
    log_probability_spec = tf.nest.map_structure(
        lambda _: log_probability_spec, action_spec)
    info_spec = policy_step.set_log_probability(
        info_spec, log_probability_spec)

  self._info_spec = info_spec
  self._setup_specs()
  self._clip = clip
  self._action_fn = common.function_in_tf1()(self._action)
  self._automatic_state_reset = automatic_state_reset
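# Sketch of the `emit_log_probability` branch above: a scalar log-probability
# spec is broadcast over the action nest with `tf.nest.map_structure`, so the
# info spec mirrors the action structure (the spec below is illustrative).
import tensorflow as tf
from tf_agents.specs import tensor_spec

action_spec = {'move': tensor_spec.BoundedTensorSpec((), tf.int32, 0, 4)}
log_prob_spec = tensor_spec.BoundedTensorSpec(
    shape=(), dtype=tf.float32, minimum=-float('inf'), maximum=0,
    name='log_probability')
nested_log_prob_spec = tf.nest.map_structure(
    lambda _: log_prob_spec, action_spec)
# nested_log_prob_spec == {'move': log_prob_spec}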
def testNeuralLinUCBUpdateDistributed(self, batch_size=1, context_dim=10):
  """Same as above, but with distributed LinUCB updates."""
  # Construct a `Trajectory` for the given action, observation, reward.
  num_actions = 5
  initial_step, final_step = _get_initial_and_final_steps(
      batch_size, context_dim)
  action = np.random.randint(num_actions, size=batch_size, dtype=np.int32)
  action_step = _get_action_step(action)
  experience = _get_experience(initial_step, action_step, final_step)

  # Construct an agent and perform the update.
  observation_spec = tensor_spec.TensorSpec([context_dim], tf.float32)
  time_step_spec = time_step.time_step_spec(observation_spec)
  action_spec = tensor_spec.BoundedTensorSpec(
      dtype=tf.int32, shape=(), minimum=0, maximum=num_actions - 1)
  encoder = DummyNet(observation_spec)
  encoding_dim = 10
  agent = neural_linucb_agent.NeuralLinUCBAgent(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      encoding_network=encoder,
      encoding_network_num_train_steps=0,
      encoding_dim=encoding_dim,
      optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=1e-2))

  self.evaluate(agent.initialize())
  self.evaluate(tf.compat.v1.global_variables_initializer())

  # Call the distributed LinUCB training instead of agent.train().
  train_fn = common.function_in_tf1()(
      agent.compute_loss_using_linucb_distributed)
  reward = tf.cast(experience.reward, agent._dtype)
  loss_info = train_fn(experience.observation, action, reward, weights=None)
  self.evaluate(loss_info)
  final_a = self.evaluate(agent.cov_matrix)
  final_b = self.evaluate(agent.data_vector)

  # Compute the expected updated estimates.
  observations_list = tf.dynamic_partition(
      data=tf.reshape(
          tf.cast(experience.observation, tf.float64),
          [batch_size, context_dim]),
      partitions=tf.convert_to_tensor(action),
      num_partitions=num_actions)
  rewards_list = tf.dynamic_partition(
      data=tf.reshape(tf.cast(experience.reward, tf.float64), [batch_size]),
      partitions=tf.convert_to_tensor(action),
      num_partitions=num_actions)
  expected_a_updated_list = []
  expected_b_updated_list = []
  for observations_for_arm, rewards_for_arm in zip(observations_list,
                                                   rewards_list):
    encoded_observations_for_arm, _ = encoder(observations_for_arm)
    encoded_observations_for_arm = tf.cast(
        encoded_observations_for_arm, dtype=tf.float64)

    num_samples_for_arm_current = tf.cast(
        tf.shape(rewards_for_arm)[0], tf.float64)
    num_samples_for_arm_total = num_samples_for_arm_current

    # pylint: disable=cell-var-from-loop
    def true_fn():
      a_new = tf.matmul(
          encoded_observations_for_arm,
          encoded_observations_for_arm,
          transpose_a=True)
      b_new = bandit_utils.sum_reward_weighted_observations(
          rewards_for_arm, encoded_observations_for_arm)
      return a_new, b_new

    def false_fn():
      return (tf.zeros([encoding_dim, encoding_dim], dtype=tf.float64),
              tf.zeros([encoding_dim], dtype=tf.float64))

    a_new, b_new = tf.cond(
        tf.squeeze(num_samples_for_arm_total) > 0, true_fn, false_fn)

    expected_a_updated_list.append(self.evaluate(a_new))
    expected_b_updated_list.append(self.evaluate(b_new))

  # Check that the actual updated estimates match the expectations.
  self.assertAllClose(expected_a_updated_list, final_a)
  self.assertAllClose(expected_b_updated_list, final_b)
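# Hedged sketch of what `bandit_utils.sum_reward_weighted_observations` is
# expected to compute, judging from its use in the tests above:
# b = sum_i r_i * x_i over the rows routed to one arm (values illustrative).
import tensorflow as tf

rewards = tf.constant([1.0, 2.0])                 # [batch]
observations = tf.constant([[1., 0.], [0., 1.]])  # [batch, encoding_dim]
b_vector = tf.reduce_sum(rewards[:, None] * observations, axis=0)
# b_vector == [1., 2.]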
def __init__(self,
             time_step_spec: ts.TimeStep,
             action_spec: types.NestedTensorSpec,
             policy: tf_policy.TFPolicy,
             collect_policy: tf_policy.TFPolicy,
             train_sequence_length: Optional[int],
             num_outer_dims: int = 2,
             training_data_spec: Optional[types.NestedTensorSpec] = None,
             debug_summaries: bool = False,
             summarize_grads_and_vars: bool = False,
             enable_summaries: bool = True,
             train_step_counter: Optional[tf.Variable] = None):
  """Meant to be called by subclass constructors.

  Args:
    time_step_spec: A nest of tf.TypeSpec representing the time_steps.
      Provided by the user.
    action_spec: A nest of BoundedTensorSpec representing the actions.
      Provided by the user.
    policy: An instance of `tf_policy.TFPolicy` representing the Agent's
      current policy.
    collect_policy: An instance of `tf_policy.TFPolicy` representing the
      Agent's current data collection policy (used to set `self.step_spec`).
    train_sequence_length: A python integer or `None`, signifying the number
      of time steps required from tensors in `experience` as passed to
      `train()`. All tensors in `experience` will be shaped `[B, T, ...]`,
      but for certain agents, `T` should be fixed. For example, DQN requires
      transitions in the form of 2 time steps, so for a non-RNN DQN Agent,
      set this value to 2. For agents that don't care, or which can handle
      `T` unknown at graph build time (i.e. most RNN-based agents), set this
      argument to `None`.
    num_outer_dims: The number of outer dimensions for the agent. Must be
      either 1 or 2. If 2, training will require both a batch_size and time
      dimension on every Tensor; if 1, training will require only a
      batch_size outer dimension.
    training_data_spec: A nest of TensorSpec specifying the structure of data
      the train() function expects. If None, defaults to the trajectory_spec
      of the collect_policy.
    debug_summaries: A bool; if true, subclasses should gather debug
      summaries.
    summarize_grads_and_vars: A bool; if true, subclasses should additionally
      collect gradient and variable summaries.
    enable_summaries: A bool; if false, subclasses should not gather any
      summaries (debug or otherwise); subclasses should gate *all* summaries
      using either `summaries_enabled`, `debug_summaries`, or
      `summarize_grads_and_vars` properties.
    train_step_counter: An optional counter to increment every time the train
      op is run. Defaults to the global_step.

  Raises:
    TypeError: If `time_step_spec` is not an instance of `ts.TimeStep`.
    ValueError: If `num_outer_dims` is not in `[1, 2]`.
  """
  common.check_tf1_allowed()
  common.tf_agents_gauge.get_cell("TFAgent").set(True)
  common.tf_agents_gauge.get_cell(str(type(self))).set(True)
  if not isinstance(time_step_spec, ts.TimeStep):
    raise TypeError(
        "The `time_step_spec` must be an instance of `TimeStep`, but is "
        "`{}`.".format(type(time_step_spec)))

  if num_outer_dims not in [1, 2]:
    raise ValueError("num_outer_dims must be in [1, 2].")

  self._time_step_spec = time_step_spec
  self._action_spec = action_spec
  self._policy = policy
  self._collect_policy = collect_policy
  self._train_sequence_length = train_sequence_length
  self._num_outer_dims = num_outer_dims
  self._debug_summaries = debug_summaries
  self._summarize_grads_and_vars = summarize_grads_and_vars
  self._enable_summaries = enable_summaries
  self._training_data_spec = training_data_spec
  # Data context for data collected directly from the collect policy.
  self._collect_data_context = data_converter.DataContext(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      info_spec=collect_policy.info_spec)
  # Data context for data passed to train(). May be different if
  # training_data_spec is provided.
  if training_data_spec is not None:
    # training_data_spec can be anything; so build a data_context
    # via best-effort with fall-backs to the collect data spec.
    training_discount_spec = getattr(training_data_spec, "discount",
                                     time_step_spec.discount)
    training_observation_spec = getattr(training_data_spec, "observation",
                                        time_step_spec.observation)
    training_reward_spec = getattr(training_data_spec, "reward",
                                   time_step_spec.reward)
    training_step_type_spec = getattr(training_data_spec, "step_type",
                                      time_step_spec.step_type)
    training_policy_info_spec = getattr(training_data_spec, "policy_info",
                                        collect_policy.info_spec)
    training_action_spec = getattr(training_data_spec, "action", action_spec)
    self._data_context = data_converter.DataContext(
        time_step_spec=ts.TimeStep(
            discount=training_discount_spec,
            observation=training_observation_spec,
            reward=training_reward_spec,
            step_type=training_step_type_spec),
        action_spec=training_action_spec,
        info_spec=training_policy_info_spec)
  else:
    self._data_context = data_converter.DataContext(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        info_spec=collect_policy.info_spec)
  if train_step_counter is None:
    train_step_counter = tf.compat.v1.train.get_or_create_global_step()
  self._train_step_counter = train_step_counter
  self._train_fn = common.function_in_tf1()(self._train)
  self._initialize_fn = common.function_in_tf1()(self._initialize)
  self._preprocess_sequence_fn = common.function_in_tf1()(
      self._preprocess_sequence)
  self._loss_fn = common.function_in_tf1()(self._loss)
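# Minimal sketch of the best-effort fallback used above: `getattr` reads a
# field off `training_data_spec` when it exists, and otherwise falls back to
# the collect-time spec (the class below is a hypothetical stand-in).
class _TrajectoryLikeSpec(object):
  reward = 'custom_reward_spec'  # only `reward` is overridden here

collect_reward_spec = 'collect_reward_spec'
collect_discount_spec = 'collect_discount_spec'
resolved_reward = getattr(_TrajectoryLikeSpec, 'reward', collect_reward_spec)
resolved_discount = getattr(
    _TrajectoryLikeSpec, 'discount', collect_discount_spec)
# resolved_reward == 'custom_reward_spec'; resolved_discount falls back.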
def __init__(self,
             time_step_spec,
             action_spec,
             policy,
             collect_policy,
             train_sequence_length,
             num_outer_dims=2,
             debug_summaries=False,
             summarize_grads_and_vars=False,
             disable_summaries=False,
             train_step_counter=None):
  """Meant to be called by subclass constructors.

  Args:
    time_step_spec: A `TimeStep` spec of the expected time_steps. Provided by
      the user.
    action_spec: A nest of BoundedTensorSpec representing the actions.
      Provided by the user.
    policy: An instance of `tf_policy.Base` representing the Agent's current
      policy.
    collect_policy: An instance of `tf_policy.Base` representing the Agent's
      current data collection policy (used to set `self.step_spec`).
    train_sequence_length: A python integer or `None`, signifying the number
      of time steps required from tensors in `experience` as passed to
      `train()`. All tensors in `experience` will be shaped `[B, T, ...]`,
      but for certain agents, `T` should be fixed. For example, DQN requires
      transitions in the form of 2 time steps, so for a non-RNN DQN Agent,
      set this value to 2. For agents that don't care, or which can handle
      `T` unknown at graph build time (i.e. most RNN-based agents), set this
      argument to `None`.
    num_outer_dims: The number of outer dimensions for the agent. Must be
      either 1 or 2. If 2, training will require both a batch_size and time
      dimension on every Tensor; if 1, training will require only a
      batch_size outer dimension.
    debug_summaries: A bool; if true, subclasses should gather debug
      summaries.
    summarize_grads_and_vars: A bool; if true, subclasses should additionally
      collect gradient and variable summaries.
    disable_summaries: A bool; if true, subclasses should not gather any
      summaries (debug or otherwise); subclasses should gate all summaries
      using either `summaries_enabled`, `debug_summaries`, or
      `summarize_grads_and_vars` properties.
    train_step_counter: An optional counter to increment every time the train
      op is run. Defaults to the global_step.

  Raises:
    ValueError: If `time_step_spec` is not an instance of `ts.TimeStep`.
    ValueError: If `num_outer_dims` is not in [1, 2].
  """
  common.assert_members_are_not_overridden(base_cls=TFAgent, instance=self)

  if not isinstance(time_step_spec, ts.TimeStep):
    raise ValueError(
        "The `time_step_spec` must be an instance of `TimeStep`, but is "
        "`{}`.".format(type(time_step_spec)))

  if num_outer_dims not in [1, 2]:
    raise ValueError("num_outer_dims must be in [1, 2].")

  self._time_step_spec = time_step_spec
  self._action_spec = action_spec
  self._policy = policy
  self._collect_policy = collect_policy
  self._train_sequence_length = train_sequence_length
  self._num_outer_dims = num_outer_dims
  self._debug_summaries = debug_summaries
  self._summarize_grads_and_vars = summarize_grads_and_vars
  self._disable_summaries = disable_summaries
  if train_step_counter is None:
    train_step_counter = tf.compat.v1.train.get_or_create_global_step()
  self._train_step_counter = train_step_counter
  self._train_fn = common.function_in_tf1()(self._train)
  self._initialize_fn = common.function_in_tf1()(self._initialize)
def testSparseObs(self, batch_size, actions_from_reward_layer):
  obs_spec = {
      'global': {
          'sport': tensor_spec.TensorSpec((), tf.string)
      },
      'per_arm': {
          'name': tensor_spec.TensorSpec((3,), tf.string),
          'fruit': tensor_spec.TensorSpec((3,), tf.string)
      }
  }
  columns_a = tf.feature_column.indicator_column(
      tf.feature_column.categorical_column_with_vocabulary_list(
          'name', ['bob', 'george', 'wanda']))
  columns_b = tf.feature_column.indicator_column(
      tf.feature_column.categorical_column_with_vocabulary_list(
          'fruit', ['banana', 'kiwi', 'pear']))
  columns_c = tf.feature_column.indicator_column(
      tf.feature_column.categorical_column_with_vocabulary_list(
          'sport', ['bridge', 'chess', 'snooker']))

  dummy_net = arm_network.create_feed_forward_common_tower_network(
      obs_spec,
      global_layers=(3, 4, 5),
      arm_layers=(3, 2),
      common_layers=(4, 3),
      output_dim=self._encoding_dim,
      global_preprocessing_combiner=(
          tf.compat.v2.keras.layers.DenseFeatures([columns_c])),
      arm_preprocessing_combiner=tf.compat.v2.keras.layers.DenseFeatures(
          [columns_a, columns_b]))
  time_step_spec = ts.time_step_spec(obs_spec)
  reward_layer = get_per_arm_reward_layer(encoding_dim=self._encoding_dim)
  policy = neural_linucb_policy.NeuralLinUCBPolicy(
      dummy_net,
      self._encoding_dim,
      reward_layer,
      actions_from_reward_layer=tf.constant(
          actions_from_reward_layer, dtype=tf.bool),
      cov_matrix=self._a[0:1],
      data_vector=self._b[0:1],
      num_samples=self._num_samples_per_arm[0:1],
      epsilon_greedy=0.0,
      time_step_spec=time_step_spec,
      accepts_per_arm_features=True,
      emit_policy_info=('predicted_rewards_mean',))

  observations = {
      'global': {
          'sport': tf.constant(['snooker', 'chess'])
      },
      'per_arm': {
          'name':
              tf.constant([['george', 'george', 'george'],
                           ['bob', 'bob', 'bob']]),
          'fruit':
              tf.constant([['banana', 'banana', 'banana'],
                           ['kiwi', 'kiwi', 'kiwi']])
      }
  }

  time_step = ts.restart(observations, batch_size=2)
  action_fn = common.function_in_tf1()(policy.action)
  action_step = action_fn(time_step, seed=1)
  self.assertEqual(action_step.action.shape.as_list(), [2])
  self.assertEqual(action_step.action.dtype, tf.int32)
  # Initialize all variables.
  self.evaluate([
      tf.compat.v1.global_variables_initializer(),
      tf.compat.v1.tables_initializer()
  ])
  action = self.evaluate(action_step.action)
  self.assertAllEqual(action.shape, [2])
  p_info = self.evaluate(action_step.info)
  self.assertAllEqual(p_info.predicted_rewards_mean.shape, [2, 3])
  self.assertAllEqual(p_info.chosen_arm_features['name'].shape, [2])
  self.assertAllEqual(p_info.chosen_arm_features['fruit'].shape, [2])
  first_action = action[0]
  first_arm_name_feature = observations[
      bandit_spec_utils.PER_ARM_FEATURE_KEY]['name'][0]
  self.assertAllEqual(p_info.chosen_arm_features['name'][0],
                      first_arm_name_feature[first_action])
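# Illustrative check (vocabulary mirrors the test above) of what an indicator
# column over a vocabulary produces: a one-hot vector per vocabulary entry,
# which is what the DenseFeatures combiners feed into the network towers.
import tensorflow as tf

fruit_column = tf.feature_column.indicator_column(
    tf.feature_column.categorical_column_with_vocabulary_list(
        'fruit', ['banana', 'kiwi', 'pear']))
combiner = tf.compat.v2.keras.layers.DenseFeatures([fruit_column])
one_hot = combiner({'fruit': tf.constant([['kiwi']])})
# one_hot == [[0., 1., 0.]] in eager mode (lookup tables self-initialize).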
def __init__(self,
             time_step_spec,
             action_spec,
             policy_state_spec=(),
             info_spec=(),
             clip=True,
             emit_log_probability=False,
             automatic_state_reset=True,
             observation_and_action_constraint_splitter=None,
             name=None):
  """Initialization of Base class.

  Args:
    time_step_spec: A `TimeStep` spec of the expected time_steps. Usually
      provided by the user to the subclass.
    action_spec: A nest of BoundedTensorSpec representing the actions.
      Usually provided by the user to the subclass.
    policy_state_spec: A nest of TensorSpec representing the policy_state.
      Provided by the subclass, not directly by the user.
    info_spec: A nest of TensorSpec representing the policy info. Provided by
      the subclass, not directly by the user.
    clip: Whether to clip actions to spec before returning them. Default
      True. Most policy-based algorithms (PCL, PPO, REINFORCE) use unclipped
      continuous actions for training.
    emit_log_probability: Emit log-probabilities of actions, if supported. If
      True, policy_step.info will have CommonFields.LOG_PROBABILITY set.
      Please consult utility methods provided in policy_step for setting and
      retrieving these. When working with custom policies, either provide a
      dictionary info_spec or a namedtuple with the field 'log_probability'.
    automatic_state_reset: If `True`, then `get_initial_policy_state` is used
      to clear state in `action()` and `distribution()` for time steps where
      `time_step.is_first()`.
    observation_and_action_constraint_splitter: A function used to process
      observations with action constraints. These constraints can indicate,
      for example, a mask of valid/invalid actions for a given state of the
      environment. The function takes in a full observation and returns a
      tuple consisting of 1) the part of the observation intended as input to
      the network and 2) the constraint. An example
      `observation_and_action_constraint_splitter` could be as simple as:

      ```
      def observation_and_action_constraint_splitter(observation):
        return observation['network_input'], observation['constraint']
      ```

      *Note*: when using `observation_and_action_constraint_splitter`, make
      sure the provided `q_network` is compatible with the network-specific
      half of the output of the
      `observation_and_action_constraint_splitter`. In particular,
      `observation_and_action_constraint_splitter` will be called on the
      observation before passing to the network. If
      `observation_and_action_constraint_splitter` is None, action
      constraints are not applied.
    name: A name for this module. Defaults to the class name.
  """
  super(Base, self).__init__(name=name)
  common.tf_agents_gauge.get_cell('TFAPolicy').set(True)
  common.assert_members_are_not_overridden(base_cls=Base, instance=self)

  if not isinstance(time_step_spec, ts.TimeStep):
    raise ValueError(
        'The `time_step_spec` must be an instance of `TimeStep`, but is '
        '`{}`.'.format(type(time_step_spec)))

  self._time_step_spec = time_step_spec
  self._action_spec = action_spec
  self._policy_state_spec = policy_state_spec
  self._emit_log_probability = emit_log_probability
  if emit_log_probability:
    log_probability_spec = tensor_spec.BoundedTensorSpec(
        shape=(),
        dtype=tf.float32,
        maximum=0,
        minimum=-float('inf'),
        name='log_probability')
    log_probability_spec = tf.nest.map_structure(
        lambda _: log_probability_spec, action_spec)
    info_spec = policy_step.set_log_probability(
        info_spec, log_probability_spec)

  self._info_spec = info_spec
  self._setup_specs()
  self._clip = clip
  self._action_fn = common.function_in_tf1()(self._action)
  self._automatic_state_reset = automatic_state_reset
  self._observation_and_action_constraint_splitter = (
      observation_and_action_constraint_splitter)
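# Hedged sketch of how a splitter's constraint mask is typically consumed
# downstream (the consumption happens inside the concrete policy or network,
# not in this base class): invalid actions get a -inf-like value before the
# argmax, so they can never be selected.
import tensorflow as tf

q_values = tf.constant([[1.0, 5.0, 3.0]])
mask = tf.constant([[1, 0, 1]])  # action 1 is invalid in this state
neg_inf = tf.fill(tf.shape(q_values), q_values.dtype.min)
masked_q = tf.where(tf.cast(mask, tf.bool), q_values, neg_inf)
greedy_action = tf.argmax(masked_q, axis=-1)  # -> [2], not the masked 5.0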
def testPerArmObservation(self, batch_size, actions_from_reward_layer):
  global_obs_dim = 7
  arm_obs_dim = 3
  obs_spec = bandit_spec_utils.create_per_arm_observation_spec(
      global_obs_dim,
      arm_obs_dim,
      self._num_actions,
      add_num_actions_feature=True)
  time_step_spec = ts.time_step_spec(obs_spec)
  dummy_net = arm_network.create_feed_forward_common_tower_network(
      obs_spec,
      global_layers=(3, 4, 5),
      arm_layers=(3, 2),
      common_layers=(4, 3),
      output_dim=self._encoding_dim)
  reward_layer = get_per_arm_reward_layer(encoding_dim=self._encoding_dim)

  policy = neural_linucb_policy.NeuralLinUCBPolicy(
      dummy_net,
      self._encoding_dim,
      reward_layer,
      actions_from_reward_layer=tf.constant(
          actions_from_reward_layer, dtype=tf.bool),
      cov_matrix=self._a[0:1],
      data_vector=self._b[0:1],
      num_samples=self._num_samples_per_arm[0:1],
      epsilon_greedy=0.0,
      time_step_spec=time_step_spec,
      accepts_per_arm_features=True,
      emit_policy_info=('predicted_rewards_mean',
                        'predicted_rewards_optimistic'))

  current_time_step = self._per_arm_time_step_batch(
      batch_size=batch_size,
      global_obs_dim=global_obs_dim,
      arm_obs_dim=arm_obs_dim)
  action_step = policy.action(current_time_step)
  self.assertEqual(action_step.action.dtype, tf.int32)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  action_fn = common.function_in_tf1()(policy.action)
  action_step = action_fn(current_time_step)

  input_observation = current_time_step.observation
  encoded_observation, _ = dummy_net(input_observation)

  if actions_from_reward_layer:
    predicted_rewards_from_reward_layer = reward_layer(encoded_observation)
    predicted_rewards_expected = self.evaluate(
        predicted_rewards_from_reward_layer).reshape((-1, self._num_actions))
  else:
    observation_numpy = self.evaluate(encoded_observation)
    predicted_rewards_expected = (
        self._get_predicted_rewards_from_per_arm_linucb(
            observation_numpy, batch_size))

  p_info = self.evaluate(action_step.info)
  self.assertEqual(p_info.predicted_rewards_mean.dtype, np.float32)
  self.assertAllClose(p_info.predicted_rewards_mean,
                      predicted_rewards_expected)
  self.assertAllGreaterEqual(
      p_info.predicted_rewards_optimistic - predicted_rewards_expected, 0)
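# Hedged note on the final assertion above: a UCB-style estimate adds a
# non-negative confidence width to the mean, so "optimistic" minus "mean"
# should be elementwise >= 0 (the numbers below are illustrative only).
import numpy as np

mean_rewards = np.array([0.2, 0.5])
confidence_widths = np.array([0.1, 0.3])  # always non-negative
optimistic_rewards = mean_rewards + confidence_widths
assert np.all(optimistic_rewards - mean_rewards >= 0)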
def __init__(self,
             # counter
             train_step_counter,
             # specs
             time_step_spec,
             action_spec,
             # networks
             critic_network,
             actor_network,
             model_network,
             compressor_network,
             # optimizers
             actor_optimizer,
             critic_optimizer,
             alpha_optimizer,
             model_optimizer,
             # target update
             target_update_tau=1.0,
             target_update_period=1,
             # inputs and stop gradients
             critic_input='state',
             actor_input='state',
             critic_input_stop_gradient=True,
             actor_input_stop_gradient=False,
             # model stuff
             model_batch_size=256,  # will round to nearest full trajectory
             ac_batch_size=128,
             # other
             episodes_per_trial=1,
             num_tasks_per_train=1,
             num_batches_per_sampled_trials=1,
             td_errors_loss_fn=tf.math.squared_difference,
             gamma=1.0,
             reward_scale_factor=1.0,
             task_reward_dim=None,
             initial_log_alpha=0.0,
             target_entropy=None,
             gradient_clipping=None,
             control_timestep=None,
             num_images_per_summary=1,
             offline_ratio=None,
             override_reward_func=None):
  tf.Module.__init__(self)

  self.override_reward_func = override_reward_func
  self.offline_ratio = offline_ratio

  ################
  # critic
  ################
  # networks
  self._critic_network1 = critic_network
  self._critic_network2 = critic_network.copy(name='CriticNetwork2')
  self._target_critic_network1 = critic_network.copy(
      name='TargetCriticNetwork1')
  self._target_critic_network2 = critic_network.copy(
      name='TargetCriticNetwork2')

  # update the target networks
  self._target_update_tau = target_update_tau
  self._target_update_period = target_update_period
  self._update_target = self._get_target_updater(
      tau=self._target_update_tau, period=self._target_update_period)

  ################
  # model
  ################
  self._model_network = model_network
  self.model_input = self._model_network.model_input

  ################
  # compressor
  ################
  self._compressor_network = compressor_network

  ################
  # actor
  ################
  self._actor_network = actor_network

  ################
  # policies
  ################
  self.condition_on_full_latent_dist = (
      actor_input == "latentDistribution" and
      critic_input == "latentDistribution")

  # Both policies below share the same actor network, but they process
  # latents (to give to the actor network) in potentially different ways.

  # Used for eval.
  which_posterior = 'first'
  if self._model_network.sparse_reward_inputs:
    which_rew_input = 'sparse'
  else:
    which_rew_input = 'dense'
  policy = MeldPolicy(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      model_network=self._model_network,
      actor_input=actor_input,
      which_posterior=which_posterior,
      which_rew_input=which_rew_input)

  # Used for collecting data during training.
  # Overwrite if specified (e.g. for double agent).
  which_posterior = 'first'
  if self._model_network.sparse_reward_inputs:
    which_rew_input = 'sparse'
  else:
    which_rew_input = 'dense'
  collect_policy = MeldPolicy(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      actor_network=self._actor_network,
      model_network=self._model_network,
      actor_input=actor_input,
      which_posterior=which_posterior,
      which_rew_input=which_rew_input)

  ################
  # more vars
  ################
  self.num_batches_per_sampled_trials = num_batches_per_sampled_trials
  self.episodes_per_trial = episodes_per_trial
  self._task_reward_dim = task_reward_dim
  self._log_alpha = common.create_variable(
      'initial_log_alpha',
      initial_value=initial_log_alpha,
      dtype=tf.float32,
      trainable=True)

  # If target_entropy was not passed, set it to the negative of the total
  # number of action dimensions.
  if target_entropy is None:
    flat_action_spec = tf.nest.flatten(action_spec)
    target_entropy = -np.sum([
        np.product(single_spec.shape.as_list())
        for single_spec in flat_action_spec
    ])

  self._actor_optimizer = actor_optimizer
  self._critic_optimizer = critic_optimizer
  self._alpha_optimizer = alpha_optimizer
  self._model_optimizer = model_optimizer
  self._td_errors_loss_fn = td_errors_loss_fn
  self._gamma = gamma
  self._reward_scale_factor = reward_scale_factor
  self._target_entropy = target_entropy
  self._gradient_clipping = gradient_clipping
  self._critic_input = critic_input
  self._actor_input = actor_input
  self._critic_input_stop_gradient = critic_input_stop_gradient
  self._actor_input_stop_gradient = actor_input_stop_gradient
  self._model_batch_size = model_batch_size
  self._ac_batch_size = ac_batch_size
  self._control_timestep = control_timestep
  self._num_images_per_summary = num_images_per_summary
  self._actor_time_step_spec = time_step_spec._replace(
      observation=actor_network.input_tensor_spec)
  self._num_tasks_per_train = num_tasks_per_train

  ################
  # init tf agent
  ################
  super(MeldAgent, self).__init__(
      time_step_spec,
      action_spec,
      policy=policy,
      collect_policy=collect_policy,  # used to set self.step_spec
      # The train function can accept experience of any length T ([B, T, ...]).
      train_sequence_length=None,
      train_step_counter=train_step_counter)

  self._train_model_fn = common.function_in_tf1()(self._train_model)
  self._train_ac_fn = common.function_in_tf1()(self._train_ac)
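# Sketch of the default target-entropy heuristic used above: the negative of
# the total number of action dimensions (the spec shape is illustrative).
import numpy as np
import tensorflow as tf
from tf_agents.specs import tensor_spec

action_spec = tensor_spec.BoundedTensorSpec((3,), tf.float32, -1.0, 1.0)
flat_specs = tf.nest.flatten(action_spec)
target_entropy = -np.sum(
    [np.prod(spec.shape.as_list()) for spec in flat_specs])
# target_entropy == -3.0 for a single 3-dimensional continuous action.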
def __init__(self,
             time_step_spec: ts.TimeStep,
             action_spec: types.NestedTensorSpec,
             policy: tf_policy.TFPolicy,
             collect_policy: tf_policy.TFPolicy,
             train_sequence_length: Optional[int],
             num_outer_dims: int = 2,
             training_data_spec: Optional[types.NestedTensorSpec] = None,
             train_argspec: Optional[Dict[Text, types.NestedTensorSpec]] = None,
             debug_summaries: bool = False,
             summarize_grads_and_vars: bool = False,
             enable_summaries: bool = True,
             train_step_counter: Optional[tf.Variable] = None,
             validate_args: bool = True):
  """Meant to be called by subclass constructors.

  Args:
    time_step_spec: A nest of tf.TypeSpec representing the time_steps.
      Provided by the user.
    action_spec: A nest of BoundedTensorSpec representing the actions.
      Provided by the user.
    policy: An instance of `tf_policy.TFPolicy` representing the Agent's
      current policy.
    collect_policy: An instance of `tf_policy.TFPolicy` representing the
      Agent's current data collection policy (used to set `self.step_spec`).
    train_sequence_length: A python integer or `None`, signifying the number
      of time steps required from tensors in `experience` as passed to
      `train()`. All tensors in `experience` will be shaped `[B, T, ...]`,
      but for certain agents, `T` should be fixed. For example, DQN requires
      transitions in the form of 2 time steps, so for a non-RNN DQN Agent,
      set this value to 2. For agents that don't care, or which can handle
      `T` unknown at graph build time (i.e. most RNN-based agents), set this
      argument to `None`.
    num_outer_dims: The number of outer dimensions for the agent. Must be
      either 1 or 2. If 2, training will require both a batch_size and time
      dimension on every Tensor; if 1, training will require only a
      batch_size outer dimension.
    training_data_spec: A nest of TensorSpec specifying the structure of data
      the train() function expects. If None, defaults to the trajectory_spec
      of the collect_policy.
    train_argspec: (Optional) Describes additional supported arguments to the
      `train` call. This must be a `dict` mapping strings to nests of specs.
      Overriding the `experience` or `weights` args is not supported (see
      Raises below).

      Some algorithms require additional arguments to the `train()` call, and
      while TF-Agents encourages most of these to be provided in the
      `policy_info` / `info` field of `experience`, sometimes the extra
      information doesn't fit well, i.e., when it doesn't come from the
      policy.

      **NOTE** kwargs will not have their outer dimensions validated. In
      particular, `train_sequence_length` is ignored for these inputs, and
      they may have any, or inconsistent, batch/time dimensions; only their
      inner shape dimensions are checked against `train_argspec`.

      Below is an example:

      ```python
      class MyAgent(TFAgent):
        def __init__(self, counterfactual_training, ...):
          collect_policy = ...
          train_argspec = None
          if counterfactual_training:
            train_argspec = dict(
                counterfactual=collect_policy.trajectory_spec)
          super(...).__init__(
              ...,
              train_argspec=train_argspec)

      my_agent = MyAgent(...)

      for ...:
        experience, counterfactual = next(experience_and_counterfactual_iter)
        loss_info = my_agent.train(experience,
                                   counterfactual=counterfactual)
      ```
    debug_summaries: A bool; if true, subclasses should gather debug
      summaries.
    summarize_grads_and_vars: A bool; if true, subclasses should additionally
      collect gradient and variable summaries.
    enable_summaries: A bool; if false, subclasses should not gather any
      summaries (debug or otherwise); subclasses should gate *all* summaries
      using either `summaries_enabled`, `debug_summaries`, or
      `summarize_grads_and_vars` properties.
    train_step_counter: An optional counter to increment every time the train
      op is run. Defaults to the global_step.
    validate_args: Python bool. Whether to verify inputs to, and outputs of,
      functions like `train` and `preprocess_sequence` against spec
      structures, dtypes, and shapes.

      Research code may prefer to set this value to `False` to allow
      iterating on input and output structures without being hamstrung by
      overly rigid checking (at the cost of harder-to-debug errors).

      See also `TFPolicy.validate_args`.

  Raises:
    TypeError: If `validate_args is True` and `train_argspec` is not a
      `dict`.
    ValueError: If `validate_args is True` and `train_argspec` has the keys
      `experience` or `weights`.
    TypeError: If `validate_args is True` and any leaf nodes in
      `train_argspec` values are not subclasses of `tf.TypeSpec`.
    ValueError: If `validate_args is True` and `time_step_spec` is not an
      instance of `ts.TimeStep`.
    ValueError: If `num_outer_dims` is not in `[1, 2]`.
  """
  if validate_args:

    def _each_isinstance(spec, spec_types):
      """Checks if each element of `spec` is an instance of `spec_types`."""
      return all(isinstance(s, spec_types) for s in tf.nest.flatten(spec))

    if not _each_isinstance(time_step_spec, tf.TypeSpec):
      raise TypeError(
          "time_step_spec has to contain TypeSpec (TensorSpec, "
          "SparseTensorSpec, etc) objects, but received: {}".format(
              time_step_spec))
    if not _each_isinstance(action_spec, tensor_spec.BoundedTensorSpec):
      raise TypeError(
          "action_spec has to contain BoundedTensorSpec objects, but "
          "received: {}".format(action_spec))

  common.check_tf1_allowed()
  common.tf_agents_gauge.get_cell("TFAgent").set(True)
  common.tf_agents_gauge.get_cell(str(type(self))).set(True)
  if not isinstance(time_step_spec, ts.TimeStep):
    raise TypeError(
        "The `time_step_spec` must be an instance of `TimeStep`, but is "
        "`{}`.".format(type(time_step_spec)))

  if num_outer_dims not in [1, 2]:
    raise ValueError("num_outer_dims must be in [1, 2].")

  self._time_step_spec = time_step_spec
  self._action_spec = action_spec
  self._policy = policy
  self._collect_policy = collect_policy
  self._train_sequence_length = train_sequence_length
  self._num_outer_dims = num_outer_dims
  self._debug_summaries = debug_summaries
  self._summarize_grads_and_vars = summarize_grads_and_vars
  self._enable_summaries = enable_summaries
  self._training_data_spec = training_data_spec
  self._validate_args = validate_args
  # Data context for data collected directly from the collect policy.
  self._collect_data_context = data_converter.DataContext(
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      info_spec=collect_policy.info_spec)
  # Data context for data passed to train(). May be different if
  # training_data_spec is provided.
  if training_data_spec is not None:
    # training_data_spec can be anything; so build a data_context
    # via best-effort with fall-backs to the collect data spec.
    training_discount_spec = getattr(training_data_spec, "discount",
                                     time_step_spec.discount)
    training_observation_spec = getattr(training_data_spec, "observation",
                                        time_step_spec.observation)
    training_reward_spec = getattr(training_data_spec, "reward",
                                   time_step_spec.reward)
    training_step_type_spec = getattr(training_data_spec, "step_type",
                                      time_step_spec.step_type)
    training_policy_info_spec = getattr(training_data_spec, "policy_info",
                                        collect_policy.info_spec)
    training_action_spec = getattr(training_data_spec, "action", action_spec)
    self._data_context = data_converter.DataContext(
        time_step_spec=ts.TimeStep(
            discount=training_discount_spec,
            observation=training_observation_spec,
            reward=training_reward_spec,
            step_type=training_step_type_spec),
        action_spec=training_action_spec,
        info_spec=training_policy_info_spec)
  else:
    self._data_context = data_converter.DataContext(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        info_spec=collect_policy.info_spec)
  if train_argspec is None:
    train_argspec = {}
  elif validate_args:
    if not isinstance(train_argspec, dict):
      raise TypeError("train_argspec must be a dict, but saw: {}".format(
          train_argspec))
    if "weights" in train_argspec or "experience" in train_argspec:
      raise ValueError("train_argspec must not override 'weights' or "
                       "'experience' keys, but saw: {}".format(train_argspec))
    if not all(
        isinstance(x, tf.TypeSpec) for x in tf.nest.flatten(train_argspec)):
      raise TypeError(
          "train_argspec contains non-TensorSpec objects: {}".format(
              train_argspec))
    train_argspec = dict(train_argspec)  # Create a local copy.
  self._train_argspec = train_argspec
  if train_step_counter is None:
    train_step_counter = tf.compat.v1.train.get_or_create_global_step()
  self._train_step_counter = train_step_counter
  self._train_fn = common.function_in_tf1()(self._train)
  self._initialize_fn = common.function_in_tf1()(self._initialize)
  self._preprocess_sequence_fn = common.function_in_tf1()(
      self._preprocess_sequence)
  self._loss_fn = common.function_in_tf1()(self._loss)
def __init__(self,
             time_step_spec,
             action_spec,
             policy,
             collect_policy,
             train_sequence_length,
             update_period=None,
             debug_summaries=False,
             enable_functions=True,
             summarize_grads_and_vars=False,
             train_step_counter=None):
  """Meant to be called by subclass constructors.

  Args:
    time_step_spec: A `TimeStep` spec of the expected time_steps. Provided by
      the user.
    action_spec: A nest of BoundedTensorSpec representing the actions.
      Provided by the user.
    policy: An instance of `tf_policy.Base` representing the Agent's current
      policy.
    collect_policy: An instance of `tf_policy.Base` representing the Agent's
      current data collection policy (used to set `self.step_spec`).
    train_sequence_length: A python integer or `None`, signifying the number
      of time steps required from tensors in `experience` as passed to
      `train()`. All tensors in `experience` will be shaped `[B, T, ...]`,
      but for certain agents, `T` should be fixed. For example, DQN requires
      transitions in the form of 2 time steps, so for a non-RNN DQN Agent,
      set this value to 2. For agents that don't care, or which can handle
      `T` unknown at graph build time (i.e. most RNN-based agents), set this
      argument to `None`.
    update_period: An optional update period, stored as `self.update_period`
      for use by subclasses.
    debug_summaries: A bool; if true, subclasses should gather debug
      summaries.
    enable_functions: A bool, stored as `self._enable_functions`; if true,
      subclasses may wrap their methods with `common.function_in_tf1()`.
    summarize_grads_and_vars: A bool; if true, subclasses should additionally
      collect gradient and variable summaries.
    train_step_counter: An optional counter to increment every time the train
      op is run. Defaults to the global_step.

  Raises:
    ValueError: If `time_step_spec` is not an instance of `ts.TimeStep`.
  """
  common.assert_members_are_not_overridden(base_cls=TFAgent, instance=self)

  if not isinstance(time_step_spec, ts.TimeStep):
    raise ValueError(
        "The `time_step_spec` must be an instance of `TimeStep`, but is "
        "`{}`.".format(type(time_step_spec)))

  self._time_step_spec = time_step_spec
  self._action_spec = action_spec
  self._policy = policy
  # self._py_policy = py_tf_policy.PyTFPolicy(policy)
  self._collect_policy = collect_policy
  # self._collect_py_policy = py_tf_policy.PyTFPolicy(collect_policy)
  self._train_sequence_length = train_sequence_length
  self.update_period = update_period
  self._debug_summaries = debug_summaries
  self._summarize_grads_and_vars = summarize_grads_and_vars
  if train_step_counter is None:
    train_step_counter = tf.compat.v1.train.get_or_create_global_step()
  self._train_step_counter = train_step_counter
  self._train_fn = common.function_in_tf1()(self._train)
  self._initialize_fn = common.function_in_tf1()(self._initialize)
  self._enable_functions = enable_functions
def train_eval( root_dir, load_root_dir=None, env_load_fn=None, gym_env_wrappers=[], monitor=False, env_name=None, agent_class=None, initial_collect_driver_class=None, collect_driver_class=None, online_driver_class=dynamic_episode_driver.DynamicEpisodeDriver, num_global_steps=1000000, rb_size=None, train_steps_per_iteration=1, train_metrics=None, eval_metrics=None, train_metrics_callback=None, # SacAgent args actor_fc_layers=(256, 256), critic_joint_fc_layers=(256, 256), # Safety Critic training args sc_rb_size=None, target_safety=None, train_sc_steps=10, train_sc_interval=1000, online_critic=False, n_envs=None, finetune_sc=False, pretraining=True, lambda_schedule_nsteps=0, lambda_initial=0., lambda_final=1., kstep_fail=0, # Ensemble Critic training args num_critics=None, critic_learning_rate=3e-4, # Wcpg Critic args critic_preprocessing_layer_size=256, # Params for train batch_size=256, # Params for eval run_eval=False, num_eval_episodes=10, eval_interval=1000, # Params for summaries and logging train_checkpoint_interval=10000, policy_checkpoint_interval=5000, rb_checkpoint_interval=50000, keep_rb_checkpoint=False, log_interval=1000, summary_interval=1000, monitor_interval=5000, summaries_flush_secs=10, early_termination_fn=None, debug_summaries=False, seed=None, eager_debug=False, env_metric_factories=None, wandb=False): # pylint: disable=unused-argument """train and eval script for SQRL.""" if isinstance(agent_class, str): assert agent_class in ALGOS, 'trainer.train_eval: agent_class {} invalid'.format(agent_class) agent_class = ALGOS.get(agent_class) n_envs = n_envs or num_eval_episodes root_dir = os.path.expanduser(root_dir) train_dir = os.path.join(root_dir, 'train') # =====================================================================# # Setup summary metrics, file writers, and create env # # =====================================================================# train_summary_writer = tf.compat.v2.summary.create_file_writer( train_dir, flush_millis=summaries_flush_secs * 1000) train_summary_writer.set_as_default() train_metrics = train_metrics or [] eval_metrics = eval_metrics or [] updating_sc = online_critic and (not load_root_dir or finetune_sc) logging.debug('updating safety critic: %s', updating_sc) if seed: tf.compat.v1.set_random_seed(seed) if agent_class in SAFETY_AGENTS: if online_critic: sc_tf_env = tf_py_environment.TFPyEnvironment( parallel_py_environment.ParallelPyEnvironment( [lambda: env_load_fn(env_name)] * n_envs )) if seed: seeds = [seed * n_envs + i for i in range(n_envs)] try: sc_tf_env.pyenv.seed(seeds) except: pass if run_eval: eval_dir = os.path.join(root_dir, 'eval') eval_summary_writer = tf.compat.v2.summary.create_file_writer( eval_dir, flush_millis=summaries_flush_secs * 1000) eval_metrics = [ tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes, batch_size=n_envs), tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes, batch_size=n_envs), ] + [tf_py_metric.TFPyMetric(m) for m in eval_metrics] eval_tf_env = tf_py_environment.TFPyEnvironment( parallel_py_environment.ParallelPyEnvironment( [lambda: env_load_fn(env_name)] * n_envs )) if seed: try: for i, pyenv in enumerate(eval_tf_env.pyenv.envs): pyenv.seed(seed * n_envs + i) except: pass elif 'Drunk' in env_name: # Just visualizes trajectories in drunk spider environment eval_tf_env = tf_py_environment.TFPyEnvironment( env_load_fn(env_name)) else: eval_tf_env = None if monitor: vid_path = os.path.join(root_dir, 'rollouts') monitor_env_wrapper = misc.monitor_freq(1, vid_path) 
monitor_env = gym.make(env_name) for wrapper in gym_env_wrappers: monitor_env = wrapper(monitor_env) monitor_env = monitor_env_wrapper(monitor_env) # auto_reset must be False to ensure Monitor works correctly monitor_py_env = gym_wrapper.GymWrapper(monitor_env, auto_reset=False) global_step = tf.compat.v1.train.get_or_create_global_step() with tf.summary.record_if( lambda: tf.math.equal(global_step % summary_interval, 0)): py_env = env_load_fn(env_name) tf_env = tf_py_environment.TFPyEnvironment(py_env) if seed: try: for i, pyenv in enumerate(tf_env.pyenv.envs): pyenv.seed(seed * n_envs + i) except: pass time_step_spec = tf_env.time_step_spec() observation_spec = time_step_spec.observation action_spec = tf_env.action_spec() logging.debug('obs spec: %s', observation_spec) logging.debug('action spec: %s', action_spec) # =====================================================================# # Setup agent class # # =====================================================================# if agent_class == wcpg_agent.WcpgAgent: alpha_spec = tensor_spec.BoundedTensorSpec(shape=(1,), dtype=tf.float32, minimum=0., maximum=1., name='alpha') input_tensor_spec = (observation_spec, action_spec, alpha_spec) critic_net = agents.DistributionalCriticNetwork( input_tensor_spec, preprocessing_layer_size=critic_preprocessing_layer_size, joint_fc_layer_params=critic_joint_fc_layers) actor_net = agents.WcpgActorNetwork((observation_spec, alpha_spec), action_spec) else: actor_net = actor_distribution_network.ActorDistributionNetwork( observation_spec, action_spec, fc_layer_params=actor_fc_layers, continuous_projection_net=agents.normal_projection_net) critic_net = agents.CriticNetwork( (observation_spec, action_spec), joint_fc_layer_params=critic_joint_fc_layers) if agent_class in SAFETY_AGENTS: logging.debug('Making SQRL agent') if lambda_schedule_nsteps > 0: lambda_update_every_nsteps = num_global_steps // lambda_schedule_nsteps step_size = (lambda_final - lambda_initial) / lambda_update_every_nsteps lambda_scheduler = lambda lam: common.periodically( body=lambda: tf.group(lam.assign(lam + step_size)), period=lambda_update_every_nsteps) else: lambda_scheduler = None safety_critic_net = agents.CriticNetwork( (observation_spec, action_spec), joint_fc_layer_params=critic_joint_fc_layers) ts = target_safety thresholds = [ts, 0.5] sc_metrics = [tf.keras.metrics.AUC(name='safety_critic_auc'), tf.keras.metrics.TruePositives(name='safety_critic_tp', thresholds=thresholds), tf.keras.metrics.FalsePositives(name='safety_critic_fp', thresholds=thresholds), tf.keras.metrics.TrueNegatives(name='safety_critic_tn', thresholds=thresholds), tf.keras.metrics.FalseNegatives(name='safety_critic_fn', thresholds=thresholds), tf.keras.metrics.BinaryAccuracy(name='safety_critic_acc', threshold=0.5)] tf_agent = agent_class( time_step_spec, action_spec, actor_network=actor_net, critic_network=critic_net, safety_critic_network=safety_critic_net, train_step_counter=global_step, debug_summaries=debug_summaries, safety_pretraining=pretraining, train_critic_online=online_critic, initial_log_lambda=lambda_initial, log_lambda=(lambda_scheduler is None), lambda_scheduler=lambda_scheduler) elif agent_class is ensemble_sac_agent.EnsembleSacAgent: critic_nets, critic_optimizers = [critic_net], [tf.keras.optimizers.Adam(critic_learning_rate)] for _ in range(num_critics - 1): critic_nets.append(agents.CriticNetwork((observation_spec, action_spec), joint_fc_layer_params=critic_joint_fc_layers)) 
      tf_agent = agent_class(
          time_step_spec,
          action_spec,
          actor_network=actor_net,
          critic_networks=critic_nets,
          critic_optimizers=critic_optimizers,
          debug_summaries=debug_summaries)
    else:  # Agent is either SacAgent or WcpgAgent.
      logging.debug('critic input_tensor_spec: %s',
                    critic_net.input_tensor_spec)
      tf_agent = agent_class(
          time_step_spec,
          action_spec,
          actor_network=actor_net,
          critic_network=critic_net,
          train_step_counter=global_step,
          debug_summaries=debug_summaries)
    tf_agent.initialize()

    # =====================================================================#
    #  Setup replay buffer                                                 #
    # =====================================================================#
    collect_data_spec = tf_agent.collect_data_spec
    logging.debug('Allocating replay buffer ...')
    # Add to replay buffer and other agent specific observers.
    rb_size = rb_size or 1000000
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        collect_data_spec, batch_size=1, max_length=rb_size)
    logging.debug('RB capacity: %i', replay_buffer.capacity)
    logging.debug('ReplayBuffer Collect data spec: %s', collect_data_spec)

    if agent_class in SAFETY_AGENTS:
      sc_rb_size = sc_rb_size or num_eval_episodes * 500
      sc_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
          collect_data_spec,
          batch_size=1,
          max_length=sc_rb_size,
          dataset_window_shift=1)

    num_episodes = tf_metrics.NumberOfEpisodes()
    num_env_steps = tf_metrics.EnvironmentSteps()
    return_metric = tf_metrics.AverageReturnMetric(
        buffer_size=num_eval_episodes, batch_size=tf_env.batch_size)
    train_metrics = [
        num_episodes,
        num_env_steps,
        return_metric,
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ] + [tf_py_metric.TFPyMetric(m) for m in train_metrics]

    if 'Minitaur' in env_name and not pretraining:
      goal_vel = gin.query_parameter('%GOAL_VELOCITY')
      early_termination_fn = train_utils.MinitaurTerminationFn(
          speed_metric=train_metrics[-2],
          total_falls_metric=train_metrics[-3],
          env_steps_metric=num_env_steps,
          goal_speed=goal_vel)

    if env_metric_factories:
      for env_metric in env_metric_factories:
        train_metrics.append(
            tf_py_metric.TFPyMetric(env_metric(tf_env.pyenv.envs)))
        if run_eval:
          eval_metrics.append(
              env_metric(list(eval_tf_env.pyenv._envs)))  # pylint: disable=protected-access

    # =====================================================================#
    #  Setup collect policies                                              #
    # =====================================================================#
    if not online_critic:
      eval_policy = tf_agent.policy
      collect_policy = tf_agent.collect_policy
      if not pretraining and agent_class in SAFETY_AGENTS:
        collect_policy = tf_agent.safe_policy
    else:
      eval_policy = (
          tf_agent.collect_policy if pretraining else tf_agent.safe_policy)
      collect_policy = (
          tf_agent.collect_policy if pretraining else tf_agent.safe_policy)
      online_collect_policy = tf_agent.safe_policy  # if pretraining else tf_agent.collect_policy
      if pretraining:
        online_collect_policy._training = False  # pylint: disable=protected-access

    if not load_root_dir:
      initial_collect_policy = random_tf_policy.RandomTFPolicy(
          time_step_spec, action_spec)
    else:
      initial_collect_policy = collect_policy
    if agent_class == wcpg_agent.WcpgAgent:
      initial_collect_policy = agents.WcpgPolicyWrapper(initial_collect_policy)

    # =====================================================================#
    #  Setup Checkpointing                                                 #
    # =====================================================================#
    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
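    # Checkpoint layout set up in this section (relative to `root_dir`):
    #   train/                       agent, global_step, and train metrics
    #   train/policy/                eval policy only
    #   train/replay_buffer/         off-policy replay buffer (max_to_keep=1)
    #   train/online_replay_buffer/  safety-critic buffer (online_critic only)
    #   sc/                          safety critic, target critic, optimizer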
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)

    rb_ckpt_dir = os.path.join(train_dir, 'replay_buffer')
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=rb_ckpt_dir, max_to_keep=1, replay_buffer=replay_buffer)

    if online_critic:
      online_rb_ckpt_dir = os.path.join(train_dir, 'online_replay_buffer')
      online_rb_checkpointer = common.Checkpointer(
          ckpt_dir=online_rb_ckpt_dir, max_to_keep=1, replay_buffer=sc_buffer)

    # Loads agent, replay buffer, and online sc/buffer if online_critic.
    if load_root_dir:
      load_root_dir = os.path.expanduser(load_root_dir)
      load_train_dir = os.path.join(load_root_dir, 'train')
      misc.load_agent_ckpt(load_train_dir, tf_agent)
      if len(os.listdir(os.path.join(load_train_dir, 'replay_buffer'))) > 1:
        load_rb_ckpt_dir = os.path.join(load_train_dir, 'replay_buffer')
        misc.load_rb_ckpt(load_rb_ckpt_dir, replay_buffer)
      if online_critic:
        load_online_sc_ckpt_dir = os.path.join(load_root_dir, 'sc')
        load_online_rb_ckpt_dir = os.path.join(
            load_train_dir, 'online_replay_buffer')
        if osp.exists(load_online_rb_ckpt_dir):
          misc.load_rb_ckpt(load_online_rb_ckpt_dir, sc_buffer)
        if osp.exists(load_online_sc_ckpt_dir):
          misc.load_safety_critic_ckpt(
              load_online_sc_ckpt_dir, safety_critic_net)
      elif agent_class in SAFETY_AGENTS:
        offline_run = sorted(
            os.listdir(os.path.join(load_train_dir, 'offline')))[-1]
        load_sc_ckpt_dir = os.path.join(
            load_train_dir, 'offline', offline_run, 'safety_critic')
        if osp.exists(load_sc_ckpt_dir):
          sc_net_off = agents.CriticNetwork(
              (observation_spec, action_spec),
              joint_fc_layer_params=(512, 512),
              name='SafetyCriticOffline')
          sc_net_off.create_variables()
          target_sc_net_off = common.maybe_copy_target_network_with_checks(
              sc_net_off, None, 'TargetSafetyCriticNetwork')
          sc_optimizer = tf.keras.optimizers.Adam(critic_learning_rate)
          _ = misc.load_safety_critic_ckpt(
              load_sc_ckpt_dir,
              safety_critic_net=sc_net_off,
              target_safety_critic=target_sc_net_off,
              optimizer=sc_optimizer)
          tf_agent._safety_critic_network = sc_net_off  # pylint: disable=protected-access
          tf_agent._target_safety_critic_network = target_sc_net_off  # pylint: disable=protected-access
          tf_agent._safety_critic_optimizer = sc_optimizer  # pylint: disable=protected-access
    else:
      train_checkpointer.initialize_or_restore()
      rb_checkpointer.initialize_or_restore()
      if online_critic:
        online_rb_checkpointer.initialize_or_restore()

    if agent_class in SAFETY_AGENTS:
      sc_dir = os.path.join(root_dir, 'sc')
      safety_critic_checkpointer = common.Checkpointer(
          ckpt_dir=sc_dir,
          safety_critic=tf_agent._safety_critic_network,  # pylint: disable=protected-access
          target_safety_critic=tf_agent._target_safety_critic_network,  # pylint: disable=protected-access
          optimizer=tf_agent._safety_critic_optimizer,  # pylint: disable=protected-access
          global_step=global_step)
      # Restore the safety critic unless it was already loaded from an
      # offline run above.
      if not (load_root_dir and not online_critic):
        safety_critic_checkpointer.initialize_or_restore()

    agent_observers = [replay_buffer.add_batch] + train_metrics
    collect_driver = collect_driver_class(
        tf_env, collect_policy, observers=agent_observers)
    collect_driver.run = common.function_in_tf1()(collect_driver.run)

    if online_critic:
      logging.debug('online driver class: %s', online_driver_class)
      online_agent_observers = [
          num_episodes, num_env_steps, sc_buffer.add_batch
      ]
      online_driver = online_driver_class(
          sc_tf_env,
          online_collect_policy,
          observers=online_agent_observers,
          num_episodes=num_eval_episodes)
      online_driver.run = common.function_in_tf1()(online_driver.run)

    if eager_debug:
      tf.config.experimental_run_functions_eagerly(True)
    else:
      config_saver = gin.tf.GinConfigSaverHook(
          train_dir, summarize_config=True)
      tf.function(config_saver.after_create_session)()
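    # `common.function_in_tf1()` wraps each driver's `run` in a `tf.function`
    # for graph-mode speed; with `eager_debug=True` all function tracing is
    # disabled so collection and training can be stepped through with a
    # Python debugger, at a substantial runtime cost.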
    if global_step.numpy() == 0:
      logging.info('Performing initial collection ...')
      # Copy the observer list: extending an alias of `agent_observers` in
      # place would also mutate the list the main collect driver holds a
      # reference to.
      init_collect_observers = list(agent_observers)
      if agent_class in SAFETY_AGENTS:
        init_collect_observers += [sc_buffer.add_batch]
      initial_collect_driver_class(
          tf_env, initial_collect_policy,
          observers=init_collect_observers).run()
      last_id = replay_buffer._get_last_id()  # pylint: disable=protected-access
      logging.info('Data saved after initial collection: %d steps', last_id)
      if agent_class in SAFETY_AGENTS:
        last_id = sc_buffer._get_last_id()  # pylint: disable=protected-access
        logging.debug(
            'Data saved in sc_buffer after initial collection: %d steps',
            last_id)

    if run_eval:
      results = metric_utils.eager_compute(
          eval_metrics,
          eval_tf_env,
          eval_policy,
          num_episodes=num_eval_episodes,
          train_step=global_step,
          summary_writer=eval_summary_writer,
          summary_prefix='EvalMetrics',
      )
      if train_metrics_callback is not None:
        train_metrics_callback(results, global_step.numpy())
      metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    train_step = train_utils.get_train_step(
        tf_agent, replay_buffer, batch_size)
    if agent_class in SAFETY_AGENTS:
      critic_train_step = train_utils.get_critic_train_step(
          tf_agent,
          replay_buffer,
          sc_buffer,
          batch_size=batch_size,
          updating_sc=updating_sc,
          metrics=sc_metrics)

    if early_termination_fn is None:
      early_termination_fn = lambda: False

    loss_diverged = False
    # How many consecutive steps the loss has been diverged for.
    loss_divergence_counter = 0
    mean_train_loss = tf.keras.metrics.Mean(name='mean_train_loss')

    if agent_class in SAFETY_AGENTS:
      resample_counter = collect_policy._resample_counter  # pylint: disable=protected-access
      mean_resample_ac = tf.keras.metrics.Mean(name='mean_unsafe_ac_freq')
      sc_metrics.append(mean_resample_ac)

      if online_critic:
        logging.debug('starting safety critic pretraining')
        # Don't fine-tune the safety critic when resuming from a checkpoint.
        if global_step.numpy() == 0:
          for _ in range(train_sc_steps):
            sc_loss, lambda_loss = critic_train_step()
          critic_results = [('sc_loss', sc_loss.numpy()),
                            ('lambda_loss', lambda_loss.numpy())]
          for critic_metric in sc_metrics:
            res = critic_metric.result().numpy()
            if not res.shape:
              critic_results.append((critic_metric.name, res))
            else:
              for r, thresh in zip(res, thresholds):
                name = '_'.join([critic_metric.name, str(thresh)])
                critic_results.append((name, r))
            critic_metric.reset_states()
          if train_metrics_callback:
            train_metrics_callback(
                collections.OrderedDict(critic_results),
                step=global_step.numpy())

    logging.debug('Starting main train loop...')
    curr_ep = []
    global_step_val = global_step.numpy()
    while global_step_val <= num_global_steps and not early_termination_fn():
      start_time = time.time()

      # MEASURE ACTION RESAMPLING FREQUENCY
      if agent_class in SAFETY_AGENTS:
        if pretraining and global_step_val == num_global_steps // 2:
          if online_critic:
            online_collect_policy._training = True  # pylint: disable=protected-access
          collect_policy._training = True  # pylint: disable=protected-access
        if online_critic or collect_policy._training:  # pylint: disable=protected-access
          mean_resample_ac(resample_counter.result())
          resample_counter.reset()
          if time_step is None or time_step.is_last():
            resample_ac_freq = mean_resample_ac.result()
            mean_resample_ac.reset_states()
            tf.compat.v2.summary.scalar(
                name='resample_ac_freq', data=resample_ac_freq,
                step=global_step)

      # RUN COLLECTION
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )

      # Read back the last step written by the collect driver.
      traj = replay_buffer._data_table.read(  # pylint: disable=protected-access
          replay_buffer._get_last_id() % replay_buffer._capacity)  # pylint: disable=protected-access
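      # `curr_ep` buffers the trajectories of the episode in progress so
      # that, on a failure, the task-agnostic reward can be relabeled over
      # the final `kstep_fail` steps (or the whole episode) before the
      # trajectories are pushed into the safety-critic buffer below.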
      curr_ep.append(traj)

      if time_step.is_last():
        if agent_class in SAFETY_AGENTS:
          if time_step.observation['task_agn_rew']:  # Episode ended in failure.
            if kstep_fail:
              # Applies task-agnostic reward over the last k steps.
              for traj in curr_ep[-kstep_fail:]:
                traj.observation['task_agn_rew'] = 1.
                sc_buffer.add_batch(traj)
            else:
              for traj in curr_ep:
                sc_buffer.add_batch(traj)
        curr_ep = []
        if agent_class == wcpg_agent.WcpgAgent:
          collect_policy._alpha = None  # pylint: disable=protected-access  # Reset WCPG alpha.

      if (global_step_val + 1) % log_interval == 0:
        logging.debug('policy eval: %4.2f sec', time.time() - start_time)

      # PERFORM TRAIN STEP ON ALGORITHM (OFF-POLICY)
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
        mean_train_loss(train_loss.loss)

      current_step = global_step.numpy()
      total_loss = mean_train_loss.result()
      mean_train_loss.reset_states()

      if train_metrics_callback and current_step % summary_interval == 0:
        train_metrics_callback(
            collections.OrderedDict(
                [(k, v.numpy()) for k, v in
                 train_loss.extra._asdict().items()]),
            step=current_step)
        train_metrics_callback(
            {'train_loss': total_loss.numpy()}, step=current_step)

      # TRAIN AND/OR EVAL SAFETY CRITIC
      if agent_class in SAFETY_AGENTS and current_step % train_sc_interval == 0:
        if online_critic:
          # Run online critic-training collection & update.
          batch_time_step = sc_tf_env.reset()
          batch_policy_state = online_collect_policy.get_initial_state(
              sc_tf_env.batch_size)
          online_driver.run(
              time_step=batch_time_step, policy_state=batch_policy_state)
        for _ in range(train_sc_steps):
          sc_loss, lambda_loss = critic_train_step()
        # Log safety-critic loss results.
        critic_results = [('sc_loss', sc_loss.numpy()),
                          ('lambda_loss', lambda_loss.numpy())]
        metric_utils.log_metrics(sc_metrics)
        for critic_metric in sc_metrics:
          res = critic_metric.result().numpy()
          if not res.shape:
            critic_results.append((critic_metric.name, res))
          else:
            for r, thresh in zip(res, thresholds):
              name = '_'.join([critic_metric.name, str(thresh)])
              critic_results.append((name, r))
          critic_metric.reset_states()
        if train_metrics_callback and current_step % summary_interval == 0:
          train_metrics_callback(
              collections.OrderedDict(critic_results), step=current_step)
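      # Vector-valued Keras metrics (those constructed with a `thresholds`
      # list) return one value per threshold; the loop above flattens them
      # into named scalars. For example, with a (hypothetical)
      # target_safety=0.3 the thresholds are [0.3, 0.5], and TruePositives
      # is reported as `safety_critic_tp_0.3` and `safety_critic_tp_0.5`.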
      # Check for exploding losses.
      if (math.isnan(total_loss) or math.isinf(total_loss) or
          total_loss > MAX_LOSS):
        loss_divergence_counter += 1
        if loss_divergence_counter > TERMINATE_AFTER_DIVERGED_LOSS_STEPS:
          loss_diverged = True
          logging.info('Loss diverged, critic_loss: %s, actor_loss: %s',
                       train_loss.extra.critic_loss,
                       train_loss.extra.actor_loss)
          break
      else:
        loss_divergence_counter = 0

      time_acc += time.time() - start_time

      # LOGGING AND METRICS
      if current_step % log_interval == 0:
        metric_utils.log_metrics(train_metrics)
        logging.info('step = %d, loss = %f', current_step, total_loss)
        steps_per_sec = (current_step - timed_at_step) / time_acc
        logging.info('%4.2f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = current_step
        time_acc = 0

      train_results = []
      for metric in train_metrics[2:]:
        if isinstance(metric, (metrics.AverageEarlyFailureMetric,
                               metrics.AverageFallenMetric,
                               metrics.AverageSuccessMetric)):
          # Plot failure as a fn of return.
          metric.tf_summaries(
              train_step=global_step,
              step_metrics=[num_env_steps, num_episodes, return_metric])
        else:
          metric.tf_summaries(
              train_step=global_step,
              step_metrics=[num_env_steps, num_episodes])
        train_results.append((metric.name, metric.result().numpy()))

      if train_metrics_callback and current_step % summary_interval == 0:
        train_metrics_callback(
            collections.OrderedDict(train_results),
            step=global_step.numpy())

      if current_step % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=current_step)

      if current_step % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=current_step)
        if agent_class in SAFETY_AGENTS:
          safety_critic_checkpointer.save(global_step=current_step)
          if online_critic:
            online_rb_checkpointer.save(global_step=current_step)

      if rb_checkpoint_interval and current_step % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=current_step)

      if wandb and current_step % eval_interval == 0 and 'Drunk' in env_name:
        misc.record_point_mass_episode(eval_tf_env, eval_policy, current_step)
        if online_critic:
          misc.record_point_mass_episode(
              eval_tf_env, tf_agent.safe_policy, current_step,
              'safe-trajectory')

      if run_eval and current_step % eval_interval == 0:
        eval_results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='EvalMetrics',
        )
        if train_metrics_callback is not None:
          train_metrics_callback(eval_results, current_step)
        metric_utils.log_metrics(eval_metrics)
        with eval_summary_writer.as_default():
          for eval_metric in eval_metrics[2:]:
            eval_metric.tf_summaries(
                train_step=global_step, step_metrics=eval_metrics[:2])

      if monitor and current_step % monitor_interval == 0:
        monitor_time_step = monitor_py_env.reset()
        monitor_policy_state = eval_policy.get_initial_state(1)
        ep_len = 0
        monitor_start = time.time()
        while not monitor_time_step.is_last():
          monitor_action = eval_policy.action(
              monitor_time_step, monitor_policy_state)
          action, monitor_policy_state = (
              monitor_action.action, monitor_action.state)
          monitor_time_step = monitor_py_env.step(action)
          ep_len += 1
        logging.debug(
            'saved rollout at timestep %d, rollout length: %d, %4.2f sec',
            current_step, ep_len, time.time() - monitor_start)

      global_step_val = current_step
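    # Note: the replay-buffer checkpoint is optional. Passing
    # `rb_checkpoint_interval=0` (or None) disables both the periodic saves
    # above and the final flush below, and `keep_rb_checkpoint=False`
    # removes the buffer checkpoints once training ends to save disk space.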
    if early_termination_fn():
      # Early stopped: save any checkpoints that are not already current.
      if global_step_val % train_checkpoint_interval != 0:
        train_checkpointer.save(global_step=global_step_val)
      if global_step_val % policy_checkpoint_interval != 0:
        policy_checkpointer.save(global_step=global_step_val)
        if agent_class in SAFETY_AGENTS:
          safety_critic_checkpointer.save(global_step=global_step_val)
          if online_critic:
            online_rb_checkpointer.save(global_step=global_step_val)
      if (rb_checkpoint_interval and
          global_step_val % rb_checkpoint_interval != 0):
        rb_checkpointer.save(global_step=global_step_val)

  if not keep_rb_checkpoint:
    misc.cleanup_checkpoints(rb_ckpt_dir)

  if loss_diverged:
    # Raise an error at the very end, after the cleanup.
    raise ValueError('Loss diverged to {} at step {}, terminating.'.format(
        total_loss, global_step.numpy()))

  return total_loss
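
# A minimal usage sketch (hypothetical values; the real entry point binds
# these arguments through gin configs and command-line flags elsewhere in
# this package):
#
#   train_eval(
#       root_dir='~/sqrl_experiment',
#       env_name='...',                     # gym env id
#       env_load_fn=...,                    # e.g. a suite loader
#       agent_class='sqrl',                 # resolved via ALGOS
#       initial_collect_driver_class=...,
#       collect_driver_class=...,
#       target_safety=0.1,                  # hypothetical
#       online_critic=True,
#       num_global_steps=int(1e6))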