def testGarbageCollectionWithCheckpointFrequency(self):
  """Old checkpoints beyond the retention window are garbage collected."""
  custom_prefix = 'custom_prefix'
  checkpoint_frequency = 3
  exp_checkpointer = checkpointer.Checkpointer(
      self._test_subdir,
      checkpoint_file_prefix=custom_prefix,
      checkpoint_frequency=checkpoint_frequency)
  data = {'data1': 1, 'data2': 'two', 'data3': (3, 'three')}
  deleted_log_files = 6
  total_log_files = (checkpointer.CHECKPOINT_DURATION *
                     checkpoint_frequency) + deleted_log_files + 1
  # With frequency 3, checkpoints land on iterations 0, 3, 6, 9, 12, 15, 18.
  # The oldest ones (0, 3, 6) should have been deleted by garbage collection.
  for step in range(total_log_files):
    exp_checkpointer.save_checkpoint(step, data)
  for step in range(total_log_files):
    for prefix in (custom_prefix, 'sentinel_checkpoint_complete'):
      checkpoint_file = os.path.join(self._test_subdir,
                                     '{}.{}'.format(prefix, step))
      if step <= deleted_log_files:
        # Within the deleted range: nothing should remain on disk.
        self.assertFalse(tf.gfile.Exists(checkpoint_file))
      elif step % checkpoint_frequency == 0:
        # Retained range, on a checkpointing iteration: file must exist.
        self.assertTrue(tf.gfile.Exists(checkpoint_file))
      else:
        # Retained range, off-frequency iteration: nothing was ever written.
        self.assertFalse(tf.gfile.Exists(checkpoint_file))
def testCheckpointingInitialization(self):
  """Constructor validates its path argument and creates the directory.

  Covers three cases: an empty path (raises), an uncreatable path (raises),
  and a valid path (succeeds, including when the directory already exists).
  """
  # Fails with empty directory.
  # Note: assertRaisesRegexp is a deprecated alias (removed in Python 3.12);
  # use assertRaisesRegex instead.
  with self.assertRaisesRegex(ValueError,
                              'No path provided to Checkpointer.'):
    checkpointer.Checkpointer('')
  # Fails with invalid directory.
  invalid_dir = '/does/not/exist'
  with self.assertRaisesRegex(
      ValueError,
      'Unable to create checkpoint path: {}.'.format(invalid_dir)):
    checkpointer.Checkpointer(invalid_dir)
  # Succeeds with valid directory.
  checkpointer.Checkpointer('/tmp/dopamine_tests')
  # This verifies initialization still works after the directory has already
  # been created.
  self.assertTrue(tf.gfile.Exists('/tmp/dopamine_tests'))
  checkpointer.Checkpointer('/tmp/dopamine_tests')
def testLoadLatestCheckpoint(self):
  """get_latest_checkpoint_number reports the newest saved iteration."""
  exp_checkpointer = checkpointer.Checkpointer(self._test_subdir)
  first_iter = 1729
  second_iter = first_iter + 1
  # Save two consecutive checkpoints; only the later one should be reported.
  exp_checkpointer.save_checkpoint(first_iter, first_iter)
  exp_checkpointer.save_checkpoint(second_iter, second_iter)
  latest = checkpointer.get_latest_checkpoint_number(self._test_subdir)
  self.assertEqual(second_iter, latest)
def testLogToFileWithValidDirectoryDefaultPrefix(self):
  """Saved data round-trips intact; unsaved iterations load as None."""
  exp_checkpointer = checkpointer.Checkpointer(self._test_subdir)
  data = {'data1': 1, 'data2': 'two', 'data3': (3, 'three')}
  iteration_number = 1729
  exp_checkpointer.save_checkpoint(iteration_number, data)
  loaded_data = exp_checkpointer.load_checkpoint(iteration_number)
  self.assertEqual(data, loaded_data)
  # Loading an iteration that was never saved yields None rather than
  # raising. assertIsNone is the idiomatic check (was assertEqual(None, ...)).
  self.assertIsNone(exp_checkpointer.load_checkpoint(iteration_number + 1))
def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix):
  """Reloads the latest checkpoint if one exists.

  Creates a `Checkpointer` (stored in `self._checkpointer`) and calls
  `checkpointer.get_latest_checkpoint_number` to determine whether there is
  a valid checkpoint in `self._checkpoint_dir`, and what the largest file
  number is. If a valid checkpoint is found, its bundled data is loaded and
  passed to `self._generator` for it to restore its state. When the
  generator unbundles successfully and the bundle is not None, the bundle
  must contain the keys 'logs' and 'current_iteration'; the `Logger`'s data
  is restored from 'logs' and `self._start_iteration` is set to
  'current_iteration' + 1. Otherwise `self._start_iteration` remains 0.

  This method returns None. Its results are communicated through the
  attributes `self._checkpointer`, `self._start_iteration`, and
  `self._logger.data`. (The previous docstring incorrectly claimed that
  `start_iteration` and the `Checkpointer` were returned.)

  Args:
    checkpoint_file_prefix: str, the checkpoint file prefix.
  """
  self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                 checkpoint_file_prefix)
  self._start_iteration = 0
  # Check if checkpoint exists. Note that the existence of checkpoint 0 means
  # that we have finished iteration 0 (so we will start from iteration 1).
  latest_checkpoint_version = checkpointer.get_latest_checkpoint_number(
      self._checkpoint_dir)
  if latest_checkpoint_version < 0:
    # No valid checkpoint on disk: start fresh from iteration 0.
    return
  experiment_data = self._checkpointer.load_checkpoint(
      latest_checkpoint_version)
  if not self._generator.unbundle(self._checkpoint_dir,
                                  latest_checkpoint_version,
                                  experiment_data):
    # Generator could not restore from this checkpoint; start fresh.
    return
  if experiment_data is not None:
    assert 'logs' in experiment_data
    assert 'current_iteration' in experiment_data
    self._logger.data = experiment_data['logs']
    self._start_iteration = experiment_data['current_iteration'] + 1
  # Logged even when experiment_data is None (unbundle still succeeded),
  # matching the original control flow.
  tf.logging.info('Reloaded checkpoint and will start from iteration %d',
                  self._start_iteration)
def testGarbageCollection(self):
  """Only the CHECKPOINT_DURATION most recent checkpoints are retained."""
  custom_prefix = 'custom_prefix'
  exp_checkpointer = checkpointer.Checkpointer(
      self._test_subdir, checkpoint_file_prefix=custom_prefix)
  data = {'data1': 1, 'data2': 'two', 'data3': (3, 'three')}
  deleted_log_files = 7
  total_log_files = checkpointer.CHECKPOINT_DURATION + deleted_log_files
  # Save one checkpoint per iteration; the oldest `deleted_log_files` of
  # them should be garbage collected.
  for step in range(total_log_files):
    exp_checkpointer.save_checkpoint(step, data)
  for step in range(total_log_files):
    for prefix in (custom_prefix, 'sentinel_checkpoint_complete'):
      checkpoint_file = os.path.join(self._test_subdir,
                                     '{}.{}'.format(prefix, step))
      if step < deleted_log_files:
        self.assertFalse(tf.gfile.Exists(checkpoint_file))
      else:
        self.assertTrue(tf.gfile.Exists(checkpoint_file))
def testInitializeCheckpointingWhenCheckpointUnbundleSucceeds(
    self, mock_get_latest):
  """A successful unbundle restores logs and advances the start iteration."""
  latest_checkpoint = 7
  mock_get_latest.return_value = latest_checkpoint
  logs_data = {'a': 1, 'b': 2}
  current_iteration = 1729
  checkpoint_data = {'current_iteration': current_iteration,
                     'logs': logs_data}
  # Write a real checkpoint file for the Runner to find and load.
  checkpoint_dir = os.path.join(self._test_subdir, 'checkpoints')
  checkpoint = checkpointer.Checkpointer(checkpoint_dir, 'ckpt')
  checkpoint.save_checkpoint(latest_checkpoint, checkpoint_data)
  mock_agent = mock.Mock()
  mock_agent.unbundle.return_value = True
  runner = run_experiment.Runner(self._test_subdir,
                                 lambda x, y, summary_writer: mock_agent,
                                 mock.Mock)
  # Resuming after finishing iteration N means starting at N + 1.
  self.assertEqual(current_iteration + 1, runner._start_iteration)
  self.assertDictEqual(logs_data, runner._logger.data)
  mock_agent.unbundle.assert_called_once_with(
      checkpoint_dir, latest_checkpoint, checkpoint_data)