Beispiel #1
0
 def testLoadLatestCheckpoint(self):
     """The newest of two saved checkpoints is reported as the latest."""
     saver = checkpointer.Checkpointer(self._test_subdir)
     older = 1729
     newer = older + 1
     # Save the older checkpoint first, then the newer one; each checkpoint's
     # payload is simply its own iteration number.
     for iteration in (older, newer):
         saver.save_checkpoint(iteration, iteration)
     latest = checkpointer.get_latest_checkpoint_number(self._test_subdir)
     self.assertEqual(newer, latest)
Beispiel #2
0
    def _initialize_checkpointer_and_maybe_resume(self,
                                                  checkpoint_file_prefix):
        """Creates the checkpointer and reloads the latest checkpoint if any.

        This method first creates a `Checkpointer` for `self._checkpoint_dir`.
        It then calls `checkpointer.get_latest_checkpoint_number` to find the
        largest checkpoint number in that directory; a negative value means no
        valid checkpoint was found.

        If a checkpoint exists, its bundled data is loaded and handed to
        `self._generator.unbundle`. If unbundling succeeds and the loaded data
        is not None, the data must contain the keys 'logs' and
        'current_iteration'. The logger's data is then restored from 'logs',
        and the experiment resumes at 'current_iteration' + 1.

        Args:
          checkpoint_file_prefix: str, the checkpoint file prefix.

        Returns:
          None. Results are stored on the instance instead:
            self._checkpointer: the newly created `Checkpointer`.
            self._start_iteration: int, the iteration to start the experiment
              from (0 when no checkpoint is restored).
          `self._logger.data` may also be overwritten from the checkpoint.

        Raises:
          AssertionError: if successfully unbundled, non-None checkpoint data
            lacks the 'logs' or 'current_iteration' key.
        """
        self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                       checkpoint_file_prefix)
        self._start_iteration = 0
        # Check if a checkpoint exists. Note that the existence of checkpoint 0
        # means that we have finished iteration 0 (so we will start from
        # iteration 1).
        latest_checkpoint_version = checkpointer.get_latest_checkpoint_number(
            self._checkpoint_dir)
        if latest_checkpoint_version < 0:
            return  # No checkpoint found: start fresh from iteration 0.

        experiment_data = self._checkpointer.load_checkpoint(
            latest_checkpoint_version)
        if not self._generator.unbundle(self._checkpoint_dir,
                                        latest_checkpoint_version,
                                        experiment_data):
            return  # Generator declined to restore; keep start iteration 0.

        if experiment_data is not None:
            # NOTE(review): these asserts are stripped under `python -O`; in
            # that case a missing key would surface as the lookup error below.
            assert 'logs' in experiment_data
            assert 'current_iteration' in experiment_data
            self._logger.data = experiment_data['logs']
            self._start_iteration = experiment_data['current_iteration'] + 1
        tf.logging.info(
            'Reloaded checkpoint and will start from iteration %d',
            self._start_iteration)
Beispiel #3
0
 def testLoadLatestCheckpointWithOverride(self):
     """An explicit override_number is returned as the latest checkpoint."""
     expected = 1729
     # The directory argument is a placeholder; only the override matters here.
     actual = checkpointer.get_latest_checkpoint_number(
         '/ignored', override_number=expected)
     self.assertEqual(expected, actual)
Beispiel #4
0
 def testLoadLatestCheckpointWithEmptyDir(self):
     """A directory without checkpoints reports the sentinel value -1."""
     no_checkpoint = -1
     latest = checkpointer.get_latest_checkpoint_number(self._test_subdir)
     self.assertEqual(no_checkpoint, latest)
Beispiel #5
0
 def testLoadLatestCheckpointWithInvalidDir(self):
     """A nonexistent directory also reports the sentinel value -1."""
     missing_dir = '/does/not/exist'
     latest = checkpointer.get_latest_checkpoint_number(missing_dir)
     self.assertEqual(-1, latest)