Example #1
    def testGarbageCollectionWithCheckpointFrequency(self):
        custom_prefix = 'custom_prefix'
        checkpoint_frequency = 3
        exp_checkpointer = checkpointer.Checkpointer(
            self._test_subdir,
            checkpoint_file_prefix=custom_prefix,
            checkpoint_frequency=checkpoint_frequency)
        data = {'data1': 1, 'data2': 'two', 'data3': (3, 'three')}
        deleted_log_files = 6
        total_log_files = (checkpointer.CHECKPOINT_DURATION *
                           checkpoint_frequency) + deleted_log_files + 1

        # The checkpoints will happen in iteration numbers 0,3,6,9,12,15,18.
        # We are checking if checkpoints 0,3,6 are deleted.
        for iteration_number in range(total_log_files):
            exp_checkpointer.save_checkpoint(iteration_number, data)
        for iteration_number in range(total_log_files):
            prefixes = [custom_prefix, 'sentinel_checkpoint_complete']
            for prefix in prefixes:
                checkpoint_file = os.path.join(
                    self._test_subdir, '{}.{}'.format(prefix,
                                                      iteration_number))
                if iteration_number <= deleted_log_files:
                    self.assertFalse(tf.gfile.Exists(checkpoint_file))
                else:
                    if iteration_number % checkpoint_frequency == 0:
                        self.assertTrue(tf.gfile.Exists(checkpoint_file))
                    else:
                        self.assertFalse(tf.gfile.Exists(checkpoint_file))
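To make the arithmetic behind these assertions concrete, here is a small standalone sketch. It assumes checkpointer.CHECKPOINT_DURATION == 4, the value implied by the comment above (checkpoints at 0, 3, ..., 18 with only the last four kept); the real constant lives in the checkpointer module.

# Worked example of the bookkeeping above (CHECKPOINT_DURATION == 4 is assumed).
checkpoint_frequency = 3
checkpoint_duration = 4   # assumed value of checkpointer.CHECKPOINT_DURATION
deleted_log_files = 6
total_log_files = checkpoint_duration * checkpoint_frequency + deleted_log_files + 1

written = [i for i in range(total_log_files) if i % checkpoint_frequency == 0]
kept = written[-checkpoint_duration:]      # [9, 12, 15, 18] survive garbage collection
removed = written[:-checkpoint_duration]   # [0, 3, 6] are deleted

assert written == [0, 3, 6, 9, 12, 15, 18]
assert removed == [0, 3, 6]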
Example #2
 def testCheckpointingInitialization(self):
     # Fails with empty directory.
     with self.assertRaisesRegexp(ValueError,
                                  'No path provided to Checkpointer.'):
         checkpointer.Checkpointer('')
     # Fails with invalid directory.
     invalid_dir = '/does/not/exist'
     with self.assertRaisesRegexp(
             ValueError,
             'Unable to create checkpoint path: {}.'.format(invalid_dir)):
         checkpointer.Checkpointer(invalid_dir)
     # Succeeds with valid directory.
     checkpointer.Checkpointer('/tmp/dopamine_tests')
     # This verifies initialization still works after the directory has already
     # been created.
     self.assertTrue(tf.gfile.Exists('/tmp/dopamine_tests'))
     checkpointer.Checkpointer('/tmp/dopamine_tests')
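For reference, here is one way a constructor could satisfy the three assertions above. This is only a sketch of the contract the test implies, not the library's actual implementation; in particular, catching tf.errors.PermissionDeniedError from tf.gfile.MakeDirs is an assumption.

class CheckpointerSketch(object):
    """Sketch of the validation behavior testCheckpointingInitialization expects."""

    def __init__(self, base_directory):
        if not base_directory:
            raise ValueError('No path provided to Checkpointer.')
        self._base_directory = base_directory
        try:
            tf.gfile.MakeDirs(base_directory)
        except tf.errors.PermissionDeniedError:
            # Surface a clearer error, matching the message the test asserts on.
            raise ValueError(
                'Unable to create checkpoint path: {}.'.format(base_directory))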
Example #3
 def testLoadLatestCheckpoint(self):
     exp_checkpointer = checkpointer.Checkpointer(self._test_subdir)
     first_iter = 1729
     exp_checkpointer.save_checkpoint(first_iter, first_iter)
     second_iter = first_iter + 1
     exp_checkpointer.save_checkpoint(second_iter, second_iter)
     self.assertEqual(
         second_iter,
         checkpointer.get_latest_checkpoint_number(self._test_subdir))
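For the assertion above to hold, get_latest_checkpoint_number only needs to report the highest iteration number whose checkpoint completed, or -1 when there is none. A minimal sketch, assuming it scans the 'sentinel_checkpoint_complete.<iteration>' files that appear in the garbage-collection tests:

def get_latest_checkpoint_number_sketch(base_directory):
    # Sketch only: find every sentinel file and return the largest iteration
    # suffix, or -1 when no checkpoint has been completed yet.
    glob = os.path.join(base_directory, 'sentinel_checkpoint_complete.*')
    paths = tf.gfile.Glob(glob)
    if not paths:
        return -1
    return max(int(path.split('.')[-1]) for path in paths)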
Example #4
 def testLogToFileWithValidDirectoryDefaultPrefix(self):
     exp_checkpointer = checkpointer.Checkpointer(self._test_subdir)
     data = {'data1': 1, 'data2': 'two', 'data3': (3, 'three')}
     iteration_number = 1729
     exp_checkpointer.save_checkpoint(iteration_number, data)
     loaded_data = exp_checkpointer.load_checkpoint(iteration_number)
     self.assertEqual(data, loaded_data)
     self.assertEqual(
         None, exp_checkpointer.load_checkpoint(iteration_number + 1))
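The round trip above only requires that save_checkpoint and load_checkpoint serialize arbitrary picklable data keyed by iteration number, and that load_checkpoint return None for a missing iteration. A minimal sketch under those assumptions, reusing the '<prefix>.<iteration>' naming seen in the garbage-collection tests:

import pickle

def _checkpoint_filename(base_directory, prefix, iteration_number):
    return os.path.join(base_directory, '{}.{}'.format(prefix, iteration_number))

def save_checkpoint_sketch(base_directory, prefix, iteration_number, data):
    with tf.gfile.GFile(
            _checkpoint_filename(base_directory, prefix, iteration_number),
            'wb') as fout:
        pickle.dump(data, fout)

def load_checkpoint_sketch(base_directory, prefix, iteration_number):
    filename = _checkpoint_filename(base_directory, prefix, iteration_number)
    if not tf.gfile.Exists(filename):
        return None  # matches the assertEqual(None, ...) check above
    with tf.gfile.GFile(filename, 'rb') as fin:
        return pickle.load(fin)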
Example #5
    def _initialize_checkpointer_and_maybe_resume(self,
                                                  checkpoint_file_prefix):
        """Reloads the latest checkpoint if it exists.

        This method first creates a `Checkpointer` object and then calls
        `checkpointer.get_latest_checkpoint_number` to determine whether there
        is a valid checkpoint in self._checkpoint_dir and what the largest file
        number is. If a valid checkpoint file is found, it loads the bundled
        data from that file and passes it to the generator so it can reload its
        data. If the generator is able to successfully unbundle, this method
        verifies that the unbundled data contains the keys 'logs' and
        'current_iteration'. It then loads the `Logger`'s data from the bundle
        and sets self._start_iteration to one past the iteration number keyed
        by 'current_iteration'; the `Checkpointer` object itself is kept in
        self._checkpointer.

        Args:
          checkpoint_file_prefix: str, the checkpoint file prefix.
        """
        self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                       checkpoint_file_prefix)
        self._start_iteration = 0
        # Check if checkpoint exists. Note that the existence of checkpoint 0 means
        # that we have finished iteration 0 (so we will start from iteration 1).
        latest_checkpoint_version = checkpointer.get_latest_checkpoint_number(
            self._checkpoint_dir)
        if latest_checkpoint_version >= 0:
            experiment_data = self._checkpointer.load_checkpoint(
                latest_checkpoint_version)
            if self._generator.unbundle(self._checkpoint_dir,
                                        latest_checkpoint_version,
                                        experiment_data):
                if experiment_data is not None:
                    assert 'logs' in experiment_data
                    assert 'current_iteration' in experiment_data
                    self._logger.data = experiment_data['logs']
                    self._start_iteration = experiment_data[
                        'current_iteration'] + 1
                tf.logging.info(
                    'Reloaded checkpoint and will start from iteration %d',
                    self._start_iteration)
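As a usage note, self._start_iteration is what the main experiment loop would resume from. A minimal sketch of such a loop; _run_one_iteration and _checkpoint_experiment are hypothetical helper names, not taken from the snippets here:

    def run_experiment_sketch(self, num_iterations):
        # Iterations below self._start_iteration were already completed and
        # checkpointed, so resume from self._start_iteration.
        for iteration in range(self._start_iteration, num_iterations):
            self._run_one_iteration(iteration)        # hypothetical helper
            self._checkpoint_experiment(iteration)    # hypothetical helper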
Example #6
 def testGarbageCollection(self):
     custom_prefix = 'custom_prefix'
     exp_checkpointer = checkpointer.Checkpointer(
         self._test_subdir, checkpoint_file_prefix=custom_prefix)
     data = {'data1': 1, 'data2': 'two', 'data3': (3, 'three')}
     deleted_log_files = 7
     total_log_files = checkpointer.CHECKPOINT_DURATION + deleted_log_files
     for iteration_number in range(total_log_files):
         exp_checkpointer.save_checkpoint(iteration_number, data)
     for iteration_number in range(total_log_files):
         prefixes = [custom_prefix, 'sentinel_checkpoint_complete']
         for prefix in prefixes:
             checkpoint_file = os.path.join(
                 self._test_subdir, '{}.{}'.format(prefix,
                                                   iteration_number))
             if iteration_number < deleted_log_files:
                 self.assertFalse(tf.gfile.Exists(checkpoint_file))
             else:
                 self.assertTrue(tf.gfile.Exists(checkpoint_file))
Example #7
 def testInitializeCheckpointingWhenCheckpointUnbundleSucceeds(
         self, mock_get_latest):
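     # mock_get_latest is presumably injected by a mock.patch on
     # checkpointer.get_latest_checkpoint_number; the decorator is not shown
     # in this snippet.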
     latest_checkpoint = 7
     mock_get_latest.return_value = latest_checkpoint
     logs_data = {'a': 1, 'b': 2}
     current_iteration = 1729
     checkpoint_data = {
         'current_iteration': current_iteration,
         'logs': logs_data
     }
     checkpoint_dir = os.path.join(self._test_subdir, 'checkpoints')
     checkpoint = checkpointer.Checkpointer(checkpoint_dir, 'ckpt')
     checkpoint.save_checkpoint(latest_checkpoint, checkpoint_data)
     mock_agent = mock.Mock()
     mock_agent.unbundle.return_value = True
     runner = run_experiment.Runner(self._test_subdir,
                                    lambda x, y, summary_writer: mock_agent,
                                    mock.Mock)
     expected_iteration = current_iteration + 1
     self.assertEqual(expected_iteration, runner._start_iteration)
     self.assertDictEqual(logs_data, runner._logger.data)
     mock_agent.unbundle.assert_called_once_with(checkpoint_dir,
                                                 latest_checkpoint,
                                                 checkpoint_data)
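All of the snippets above rely on a shared test fixture and imports that are not shown. The following is a minimal sketch of what they assume; the import paths and the scratch-directory location are assumptions rather than details taken from the snippets:

import os
import shutil

import mock  # or: from unittest import mock
import tensorflow as tf

from dopamine.common import checkpointer     # import path assumed
from dopamine.atari import run_experiment    # import path assumed


class CheckpointerTest(tf.test.TestCase):

    def setUp(self):
        # Every test writes into its own scratch directory.
        self._test_subdir = os.path.join('/tmp/dopamine_tests', 'checkpointing')
        shutil.rmtree(self._test_subdir, ignore_errors=True)
        os.makedirs(self._test_subdir)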