Example #1
 def testOuterDimsNestRemovesDimensionsFromSpecs(self, dtype):
     if dtype == tf.string:
         self.skipTest("Not compatible with string type.")
     nested_spec = example_nested_tensor_spec(dtype)
     larger_spec = tensor_spec.add_outer_dims_nest(nested_spec, (3, 4))
     removed_spec = tensor_spec.remove_outer_dims_nest(larger_spec, 2)
     self.assertEqual(nested_spec, removed_spec)
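
For reference, here is a minimal standalone sketch of the round trip this test asserts, assuming the TF-Agents module `tf_agents.specs.tensor_spec`; the single `tf.TensorSpec` below is an illustrative stand-in for the nested spec built by `example_nested_tensor_spec`:

import tensorflow as tf
from tf_agents.specs import tensor_spec

# Stand-in for example_nested_tensor_spec(dtype).
spec = tf.TensorSpec(shape=(5,), dtype=tf.float32, name="obs")

# Prepend two outer dimensions, then strip the same number off again.
larger = tensor_spec.add_outer_dims_nest(spec, (3, 4))    # shape (3, 4, 5)
restored = tensor_spec.remove_outer_dims_nest(larger, 2)  # shape (5,)
assert restored == spec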
Example #2
 def testOuterDimsNestAddsDimensionsToSpecs(self, dtype):
     if dtype == tf.string:
         self.skipTest("Not compatible with string type.")
     nested_spec = example_nested_tensor_spec(dtype)
     outer_dims = (4, 3)
     self.assertEqual(
         tensor_spec.add_outer_dims_nest(nested_spec, outer_dims),
         example_nested_tensor_spec(dtype, outer_dims))
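
As a hedged illustration of the "nest" in the name, assuming `add_outer_dims_nest` maps over arbitrary nests of specs (the dict below is hypothetical, not the structure `example_nested_tensor_spec` actually builds):

import tensorflow as tf
from tf_agents.specs import tensor_spec

nested = {
    "observation": tf.TensorSpec((2,), tf.float32),
    "action": tf.TensorSpec((), tf.int32),
}
batched = tensor_spec.add_outer_dims_nest(nested, (4, 3))
# batched["observation"].shape == (4, 3, 2)
# batched["action"].shape == (4, 3)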
Example #3
    def __init__(self,
                 root_dir,
                 train_step,
                 agent,
                 experience_dataset_fn=None,
                 after_train_strategy_step_fn=None,
                 triggers=None,
                 checkpoint_interval=100000,
                 summary_interval=1000,
                 max_checkpoints_to_keep=3,
                 use_kwargs_in_agent_train=False,
                 strategy=None,
                 run_optimizer_variable_init=True):
        """Initializes a Learner instance.

    Args:
      root_dir: Main directory path where checkpoints, saved_models, and
        summaries will be written to.
      train_step: a scalar tf.int64 `tf.Variable` which will keep track of the
        number of train steps. This is used for artifacts created in the
        root_dir, such as summaries and other outputs.
      agent: `tf_agent.TFAgent` instance to train with.
      experience_dataset_fn: a function that will create an instance of a
        tf.data.Dataset used to sample experience for training. Required for
        using the Learner as is. Optional for subclass learners which take a
        new iterator each time `learner.run` is called.
      after_train_strategy_step_fn: (Optional) callable of the form
        `fn(sample, loss)` which can be used, for example, to update priorities
        in a replay buffer, where `sample` is pulled from the
        `experience_iterator` and `loss` is the `LossInfo` named tuple returned
        from the agent. This is called after every train step. It runs using
        `strategy.run(...)`.
      triggers: List of callables of the form `trigger(train_step)`. After
        every `run` call, every trigger is called with the current
        `train_step` value as a numpy scalar.
      checkpoint_interval: Number of train steps in between checkpoints. Note
        these are placed into triggers, so a check to generate a checkpoint
        only occurs after every `run` call. Set to -1 to disable (this is not
        recommended, because it means that if the pipeline gets preempted, all
        previous progress is lost). This only takes care of checkpointing the
        training process. Policies must be explicitly exported through
        triggers.
      summary_interval: Number of train steps in between summaries. Note these
        are placed into triggers, so a check to generate a summary only occurs
        after every `run` call.
      max_checkpoints_to_keep: Maximum number of checkpoints to keep around.
        These are used to recover from pre-emptions when training.
      use_kwargs_in_agent_train: If True, the experience from the replay
        buffer is passed into the agent as kwargs. This requires samples from
        the replay buffer to be of the form `dict(experience=experience,
        kwarg1=kwarg1, ...)`. This is useful if you have an agent with a
        custom argspec.
      strategy: (Optional) `tf.distribute.Strategy` to use during training.
      run_optimizer_variable_init: Specifies if the variables of the optimizer
        are initialized before checkpointing. This should almost always be
        `True` (default) to ensure that the state of the optimizer is
        checkpointed properly. The optimizer variables are initialized by
        building the TensorFlow graph: `get_concrete_function` is called on
        the agent's `train` method, which requires passing some input. Since
        no real data is available at this point, we use the batched form of
        `training_data_spec` and `train_argspec` to achieve this (a standard
        technique). The problem arises when the agent expects some
        agent-specific batching of the input. In this case, there is no
        _general_ way at this point in the learner to batch the impacted
        specs properly. To avoid breaking the code in these specific cases, we
        recommend turning off initialization of the optimizer variables by
        setting the value of this field to `False`.
    """
        if checkpoint_interval < 0:
            logging.warning(
                'Warning: checkpointing the training process is manually disabled. '
                'This means training progress will NOT be automatically restored '
                'if the job gets preempted.')

        self._train_dir = os.path.join(root_dir, TRAIN_DIR)
        self.train_summary_writer = tf.compat.v2.summary.create_file_writer(
            self._train_dir, flush_millis=10000)

        self.train_step = train_step
        self._agent = agent
        self.use_kwargs_in_agent_train = use_kwargs_in_agent_train
        self.strategy = strategy or tf.distribute.get_strategy()

        if experience_dataset_fn:
            with self.strategy.scope():
                dataset = self.strategy.experimental_distribute_datasets_from_function(
                    lambda _: experience_dataset_fn())
                self._experience_iterator = iter(dataset)

        self.after_train_strategy_step_fn = after_train_strategy_step_fn
        self.triggers = triggers or []

        # Prevent autograph from going into the agent.
        self._agent.train = tf.autograph.experimental.do_not_convert(
            agent.train)

        checkpoint_dir = os.path.join(self._train_dir, POLICY_CHECKPOINT_DIR)
        with self.strategy.scope():
            agent.initialize()

            if run_optimizer_variable_init:
                # Force a concrete function creation inside of the strategy scope to
                # ensure that all variables, including optimizer slot variables, are
                # created. This has to happen before the checkpointer is created.
                # TODO(b/179694393): Add agent-specific outer dimensions.
                batched_specs = tensor_spec.add_outer_dims_nest(
                    self._agent.training_data_spec,
                    (None, self._agent.train_sequence_length))
                batched_train_argspec = tensor_spec.add_outer_dims_nest(
                    self._agent.train_argspec or {}, (None, ))
                if self.use_kwargs_in_agent_train:
                    batched_specs = dict(experience=batched_specs,
                                         **batched_train_argspec)

                @common.function
                def _create_variables(specs):
                    # TODO(b/170516529): Each replica has to be in the same graph.
                    # This can be ensured by placing the `strategy.run(...)` call inside
                    # the `tf.function`.
                    if self.use_kwargs_in_agent_train:
                        return self.strategy.run(self._agent.train,
                                                 kwargs=specs)
                    return self.strategy.run(self._agent.train, args=(specs, ))

                try:
                    _create_variables.get_concrete_function(batched_specs)
                except Exception as e:  # pylint: disable=broad-except
                    six.reraise(
                        type(e),
                        RuntimeError(
                            'The slot variable initialization failed. The learner assumes '
                            'all experience tensors required an `outer_rank = (None, '
                            'agent.train_sequence_length)`. If that\'s not the case for your '
                            'agent try setting `run_optimizer_variable_init=False`.'
                        ))

            self._checkpointer = common.Checkpointer(
                checkpoint_dir,
                max_to_keep=max_checkpoints_to_keep,
                agent=self._agent,
                train_step=self.train_step)
            self._checkpointer.initialize_or_restore()  # pytype: disable=attribute-error

        self.triggers.append(self._get_checkpoint_trigger(checkpoint_interval))
        self.summary_interval = tf.constant(summary_interval, dtype=tf.int64)
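
The `run_optimizer_variable_init` machinery above boils down to one trick: trace the agent's `train` method once with batched specs so that all variables, including optimizer slot variables, exist before the checkpointer is built. Below is a minimal self-contained sketch of that trick outside of TF-Agents; the `TinyModel` class and every name in it are hypothetical stand-ins, not TF-Agents API:

import tensorflow as tf

class TinyModel(tf.Module):
    """Hypothetical stand-in for an agent with a `train` method."""

    def __init__(self):
        self.dense = tf.keras.layers.Dense(2)
        self.optimizer = tf.keras.optimizers.Adam()

    @tf.function
    def train(self, x):
        with tf.GradientTape() as tape:
            loss = tf.reduce_sum(tf.square(self.dense(x)))
        grads = tape.gradient(loss, self.dense.trainable_variables)
        self.optimizer.apply_gradients(
            zip(grads, self.dense.trainable_variables))
        return loss

model = TinyModel()
# Tracing with a spec (no real data) forces creation of the layer weights and
# the optimizer slot variables, so the checkpoint below captures the full
# optimizer state -- the same idea as run_optimizer_variable_init=True.
model.train.get_concrete_function(tf.TensorSpec([None, 3], tf.float32))
checkpoint = tf.train.Checkpoint(model=model)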
Example #4
 def testAddOuterShapeWhenNotTupleOrListThrows(self, dtype):
     with self.assertRaises(ValueError):
         tensor_spec.add_outer_dims_nest(1,
                                         example_nested_tensor_spec(dtype))
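
Note the argument order in the failing call: `1` sits in the specs position and the nested spec lands in the `outer_dims` position, which is what trips the tuple-or-list check. A hedged sketch of the same failure with a plain spec:

import tensorflow as tf
from tf_agents.specs import tensor_spec

spec = tf.TensorSpec((5,), tf.float32)
try:
    tensor_spec.add_outer_dims_nest(spec, 3)  # outer_dims must be a tuple or list
except ValueError:
    print("non-tuple/list outer_dims raises ValueError")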
Example #5
 def testOuterDimsNestAddsDimensionsToSpecs(self, dtype):
     nested_spec = example_nested_tensor_spec(dtype)
     outer_dims = (4, 3)
     self.assertEqual(
         tensor_spec.add_outer_dims_nest(nested_spec, outer_dims),
         example_nested_tensor_spec(dtype, outer_dims))
Example #6
 def testAddOuterShapeWhenNotTupleOrListThrows(self, dtype):
     if dtype == tf.string:
         self.skipTest("Not compatible with string type.")
     with self.assertRaises(ValueError):
         tensor_spec.add_outer_dims_nest(1,
                                         example_nested_tensor_spec(dtype))