def testOuterDimsNestRemovesDimensionsFromSpecs(self, dtype):
  """Adding outer dims (3, 4) then removing 2 outer dims restores the nest."""
  if dtype == tf.string:
    self.skipTest("Not compatible with string type.")
  outer = (3, 4)
  original = example_nested_tensor_spec(dtype)
  # len(outer) == 2: strip exactly the dimensions that were just added.
  restored = tensor_spec.remove_outer_dims_nest(
      tensor_spec.add_outer_dims_nest(original, outer), len(outer))
  self.assertEqual(original, restored)
def testOuterDimsNestAddsDimensionsToSpecs(self, dtype):
  """Adding outer dims (4, 3) equals building the nest with those dims."""
  if dtype == tf.string:
    self.skipTest("Not compatible with string type.")
  dims = (4, 3)
  base = example_nested_tensor_spec(dtype)
  actual = tensor_spec.add_outer_dims_nest(base, dims)
  expected = example_nested_tensor_spec(dtype, dims)
  self.assertEqual(actual, expected)
def __init__(self,
             root_dir,
             train_step,
             agent,
             experience_dataset_fn=None,
             after_train_strategy_step_fn=None,
             triggers=None,
             checkpoint_interval=100000,
             summary_interval=1000,
             max_checkpoints_to_keep=3,
             use_kwargs_in_agent_train=False,
             strategy=None,
             run_optimizer_variable_init=True):
  """Initializes a Learner instance.

  Args:
    root_dir: Main directory path where checkpoints, saved_models, and
      summaries will be written to.
    train_step: a scalar tf.int64 `tf.Variable` which will keep track of the
      number of train steps. This is used for artifacts created like
      summaries, or outputs in the root_dir.
    agent: `tf_agent.TFAgent` instance to train with.
    experience_dataset_fn: a function that will create an instance of a
      tf.data.Dataset used to sample experience for training. Required for
      using the Learner as is. Optional for subclass learners which take a new
      iterator each time when `learner.run` is called. When omitted, this
      constructor does not set `self._experience_iterator`.
    after_train_strategy_step_fn: (Optional) callable of the form
      `fn(sample, loss)` which can be used for example to update priorities in
      a replay buffer where sample is pulled from the `experience_iterator`
      and loss is a `LossInfo` named tuple returned from the agent. This is
      called after every train step. It runs using `strategy.run(...)`.
    triggers: List of callables of the form `trigger(train_step)`. After every
      `run` call every trigger is called with the current `train_step` value
      as an np scalar. The learner keeps its own copy of this sequence and
      appends a checkpoint trigger to it; the caller's list is not modified.
    checkpoint_interval: Number of train steps in between checkpoints. Note
      these are placed into triggers and so a check to generate a checkpoint
      only occurs after every `run` call. Set to -1 to disable (this is not
      recommended, because it means that if the pipeline gets preempted, all
      previous progress is lost). This only takes care of checkpointing the
      training process. Policies must be explicitly exported through
      triggers.
    summary_interval: Number of train steps in between summaries. Stored as a
      scalar `tf.int64` constant in `self.summary_interval`.
    max_checkpoints_to_keep: Maximum number of checkpoints to keep around.
      These are used to recover from pre-emptions when training.
    use_kwargs_in_agent_train: If True the experience from the replay buffer
      is passed into the agent as kwargs. This requires samples from the RB to
      be of the form `dict(experience=experience, kwarg1=kwarg1, ...)`. This
      is useful if you have an agent with a custom argspec.
    strategy: (Optional) `tf.distribute.Strategy` to use during training.
      Defaults to `tf.distribute.get_strategy()`.
    run_optimizer_variable_init: Specifies if the variables of the optimizer
      are initialized before checkpointing. This should be almost always
      `True` (default) to ensure that the state of the optimizer is
      checkpointed properly. The initialization of the optimizer variables
      happens by building the Tensorflow graph. This is done by calling a
      `get_concrete_function` on the agent's `train` method which requires
      passing some input. Since no real data is available at this point we
      use the batched form of `training_data_spec` and `train_argspec` to
      achieve this (standard technique). The problem arises when the agent
      expects some agent specific batching of the input. In this case, there
      is no _general_ way at this point in the learner to batch the impacted
      specs properly. To avoid breaking the code in these specific cases, we
      recommend turning off initialization of the optimizer variables by
      setting the value of this field to `False`.

  Raises:
    RuntimeError: If `run_optimizer_variable_init` is True and building the
      concrete `train` function from the batched specs fails. The original
      exception is attached as the cause.
  """
  if checkpoint_interval < 0:
    # The literals are concatenated, so each one needs its own trailing space.
    logging.warning(
        'Warning: checkpointing the training process is manually disabled. '
        'This means training progress will NOT be automatically restored '
        'if the job gets preempted.')

  self._train_dir = os.path.join(root_dir, TRAIN_DIR)
  self.train_summary_writer = tf.compat.v2.summary.create_file_writer(
      self._train_dir, flush_millis=10000)

  self.train_step = train_step
  self._agent = agent
  self.use_kwargs_in_agent_train = use_kwargs_in_agent_train
  self.strategy = strategy or tf.distribute.get_strategy()

  if experience_dataset_fn:
    with self.strategy.scope():
      dataset = self.strategy.experimental_distribute_datasets_from_function(
          lambda _: experience_dataset_fn())
      self._experience_iterator = iter(dataset)

  self.after_train_strategy_step_fn = after_train_strategy_step_fn
  # Copy so that appending the checkpoint trigger below never mutates the
  # caller's list (and so tuples or other iterables are accepted).
  self.triggers = list(triggers or [])

  # Prevent autograph from going into the agent.
  self._agent.train = tf.autograph.experimental.do_not_convert(agent.train)

  checkpoint_dir = os.path.join(self._train_dir, POLICY_CHECKPOINT_DIR)
  with self.strategy.scope():
    agent.initialize()

    if run_optimizer_variable_init:
      # Force a concrete function creation inside of the strategy scope to
      # ensure that all variables, including optimizer slot variables, are
      # created. This has to happen before the checkpointer is created.
      # TODO(b/179694393): Add agent specific outer dimensions.
      batched_specs = tensor_spec.add_outer_dims_nest(
          self._agent.training_data_spec,
          (None, self._agent.train_sequence_length))
      batched_train_argspec = tensor_spec.add_outer_dims_nest(
          self._agent.train_argspec or {}, (None,))
      if self.use_kwargs_in_agent_train:
        batched_specs = dict(experience=batched_specs, **batched_train_argspec)

      @common.function
      def _create_variables(specs):
        # TODO(b/170516529): Each replica has to be in the same graph.
        # This can be ensured by placing the `strategy.run(...)` call inside
        # the `tf.function`.
        if self.use_kwargs_in_agent_train:
          return self.strategy.run(self._agent.train, kwargs=specs)
        return self.strategy.run(self._agent.train, args=(specs,))

      try:
        _create_variables.get_concrete_function(batched_specs)
      except Exception as e:  # pylint: disable=broad-except
        # Chain explicitly so the underlying tracing error is reported as the
        # direct cause of the RuntimeError.
        six.raise_from(
            RuntimeError(
                'The slot variable initialization failed. The learner assumes '
                'all experience tensors required an `outer_rank = (None, '
                'agent.train_sequence_length)`. If that\'s not the case for your '
                'agent try setting `run_optimizer_variable_init=False`.'),
            e)

    # Created only after all variables exist so they are all tracked.
    self._checkpointer = common.Checkpointer(
        checkpoint_dir,
        max_to_keep=max_checkpoints_to_keep,
        agent=self._agent,
        train_step=self.train_step)
    self._checkpointer.initialize_or_restore()  # pytype: disable=attribute-error

  self.triggers.append(self._get_checkpoint_trigger(checkpoint_interval))
  self.summary_interval = tf.constant(summary_interval, dtype=tf.int64)
def testAddOuterShapeWhenNotTupleOrListThrows(self, dtype):
  """`add_outer_dims_nest` raises ValueError when `outer_dims` is a nest.

  The int `1` is passed as the spec and a nested spec dict as `outer_dims`,
  which is neither a tuple nor a list.
  """
  # Match the sibling tests that build `example_nested_tensor_spec`.
  if dtype == tf.string:
    self.skipTest("Not compatible with string type.")
  # Build the fixture outside `assertRaises`, so that a ValueError raised
  # while constructing it cannot make this test pass vacuously.
  not_a_shape = example_nested_tensor_spec(dtype)
  with self.assertRaises(ValueError):
    tensor_spec.add_outer_dims_nest(1, not_a_shape)
def testOuterDimsNestAddsDimensionsToSpecs(self, dtype):
  """Adding outer dims (4, 3) equals building the nest with those dims."""
  # Same skip as the other tests that build `example_nested_tensor_spec`;
  # keeps this definition in agreement with its same-named counterpart.
  if dtype == tf.string:
    self.skipTest("Not compatible with string type.")
  nested_spec = example_nested_tensor_spec(dtype)
  outer_dims = (4, 3)
  self.assertEqual(
      tensor_spec.add_outer_dims_nest(nested_spec, outer_dims),
      example_nested_tensor_spec(dtype, outer_dims))
def testAddOuterShapeWhenNotTupleOrListThrows(self, dtype):
  """`add_outer_dims_nest` raises ValueError when `outer_dims` is a nest.

  The int `1` is passed as the spec and a nested spec dict as `outer_dims`,
  which is neither a tuple nor a list.
  """
  if dtype == tf.string:
    self.skipTest("Not compatible with string type.")
  # Build the fixture outside `assertRaises`, so that a ValueError raised
  # while constructing it cannot make this test pass vacuously.
  not_a_shape = example_nested_tensor_spec(dtype)
  with self.assertRaises(ValueError):
    tensor_spec.add_outer_dims_nest(1, not_a_shape)