def testCheckPointStateFailsWhenIncomplete(self):
    """An empty `checkpoint` file makes `get_checkpoint_state` raise ValueError."""
    save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
    os.chdir(save_dir)
    ckpt_path = os.path.join(save_dir, "checkpoint")
    # Write an empty state file. The context manager guarantees the handle is
    # closed even if the write raises.
    with open(ckpt_path, "w") as ckpt_file:
        ckpt_file.write("")
    with self.assertRaises(ValueError):
        checkpoint_management.get_checkpoint_state(save_dir)
def testClockReset(self, mock_time):
    """Checks checkpoint retention when the mocked clock is moved backwards.

    A first manager saves three checkpoints one mocked hour apart; the clock
    is then set to an earlier value and a second manager is created over the
    same directory. The test asserts which checkpoint files exist and which
    timestamps end up in the stored checkpoint state.
    """
    ckpt_dir = self.get_temp_dir()
    exists = checkpoint_management.checkpoint_exists

    mock_time.time.return_value = 10000.
    root = util.Checkpoint()
    hourly_manager = checkpoint_management.CheckpointManager(
        root, ckpt_dir, max_to_keep=1, keep_checkpoint_every_n_hours=1.)

    # Three saves, advancing the mocked clock by 3600 seconds between them.
    path_1 = hourly_manager.save()
    mock_time.time.return_value += 3600.
    path_2 = hourly_manager.save()
    mock_time.time.return_value += 3600.
    path_3 = hourly_manager.save()

    self.assertFalse(exists(path_1))
    self.assertTrue(exists(path_2))
    self.assertTrue(exists(path_3))
    self.assertEqual([path_3], hourly_manager.checkpoints)
    stored = checkpoint_management.get_checkpoint_state(ckpt_dir)
    self.assertEqual(13600., stored.last_preserved_timestamp)

    # Move the clock to a time earlier than anything recorded so far.
    mock_time.time.return_value = 5000.
    del hourly_manager
    with test.mock.patch.object(logging, "warning") as warning_mock:
        later_manager = checkpoint_management.CheckpointManager(
            root, ckpt_dir, max_to_keep=1)
        self.assertRegex(
            str(warning_mock.call_args),
            "behind the last preserved checkpoint timestamp")

    # When it is unclear whether a checkpoint was preserved (because the clock
    # misbehaved), the checkpoint must be kept rather than deleted.
    self.assertTrue(exists(path_2))
    # The existing checkpoints are known to the new manager, but they are never
    # deleted and therefore do not show up in its list of managed checkpoints.
    self.assertEqual(path_3, later_manager.latest_checkpoint)
    self.assertEqual([], later_manager.checkpoints)

    mock_time.time.return_value += 10.
    path_4 = later_manager.save()
    self.assertTrue(exists(path_2))
    self.assertTrue(exists(path_3))
    self.assertEqual(path_4, later_manager.latest_checkpoint)
    self.assertEqual([path_4], later_manager.checkpoints)

    mock_time.time.return_value += 10.
    path_5 = later_manager.save()
    self.assertTrue(exists(path_2))
    self.assertTrue(exists(path_3))
    self.assertEqual([path_5], later_manager.checkpoints)

    stored = checkpoint_management.get_checkpoint_state(ckpt_dir)
    self.assertEqual(5000., stored.last_preserved_timestamp)
    self.assertEqual([5020.], stored.all_model_checkpoint_timestamps)
def testUpdateCheckpointStateSaveRelativePaths(self):
    """Checks relative paths on disk and absolute paths when read back.

    The state is written with `save_relative_paths=True`. The raw `checkpoint`
    file is expected to hold relative paths, while `get_checkpoint_state` is
    expected to return the corresponding absolute paths.
    """
    save_dir = self._get_test_dir("update_checkpoint_state")
    os.chdir(save_dir)

    relative_new = "model-2"
    relative_old = "model-0"
    absolute_new = os.path.join(save_dir, "model-2")
    absolute_old = os.path.join(save_dir, "model-0")

    checkpoint_management.update_checkpoint_state_internal(
        save_dir=save_dir,
        model_checkpoint_path=absolute_new,
        all_model_checkpoint_paths=[relative_old, absolute_new],
        save_relative_paths=True)

    def check_paths(state, oldest, newest):
        # Same four assertions, applied to either view of the state.
        self.assertEqual(state.model_checkpoint_path, newest)
        self.assertEqual(len(state.all_model_checkpoint_paths), 2)
        self.assertEqual(state.all_model_checkpoint_paths[-1], newest)
        self.assertEqual(state.all_model_checkpoint_paths[0], oldest)

    # The file itself should contain relative paths.
    on_disk = CheckpointState()
    raw_text = file_io.read_file_to_string(
        os.path.join(save_dir, "checkpoint"))
    text_format.Merge(raw_text, on_disk)
    check_paths(on_disk, relative_old, relative_new)

    # get_checkpoint_state should return absolute paths.
    loaded = checkpoint_management.get_checkpoint_state(save_dir)
    check_paths(loaded, absolute_old, absolute_new)
def testFSPath(self):
    """Checks that `get_checkpoint_state` accepts a `pathlib.Path` directory."""
    save_dir = self._get_test_dir("fspath")
    os.chdir(save_dir)

    # Create a temporary train directory (relative to the new working dir).
    train_dir = "train"
    os.mkdir(train_dir)

    absolute_ckpt = os.path.join(save_dir, "model-0")
    relative_ckpt = os.path.join("train", "model-2")
    checkpoint_management.update_checkpoint_state(
        train_dir,
        relative_ckpt,
        all_model_checkpoint_paths=[absolute_ckpt, relative_ckpt])

    train_path = pathlib.Path(train_dir)
    state = checkpoint_management.get_checkpoint_state(train_path)
    self.assertEqual(state.model_checkpoint_path, relative_ckpt)
def testUpdateCheckpointState(self):
    """Checks that paths written by `update_checkpoint_state` read back as-is.

    One absolute and one relative path are stored; the test asserts that
    `get_checkpoint_state` returns both exactly as they were passed in.
    """
    save_dir = self._get_test_dir("update_checkpoint_state")
    os.chdir(save_dir)

    # Create a temporary train directory (relative to the new working dir).
    train_dir = "train"
    os.mkdir(train_dir)

    absolute_ckpt = os.path.join(save_dir, "model-0")
    relative_ckpt = os.path.join("train", "model-2")
    checkpoint_management.update_checkpoint_state(
        train_dir,
        relative_ckpt,
        all_model_checkpoint_paths=[absolute_ckpt, relative_ckpt])

    state = checkpoint_management.get_checkpoint_state(train_dir)
    self.assertEqual(state.model_checkpoint_path, relative_ckpt)
    recorded = state.all_model_checkpoint_paths
    self.assertEqual(len(recorded), 2)
    self.assertEqual(recorded[-1], relative_ckpt)
    self.assertEqual(recorded[0], absolute_ckpt)
def testCheckPointCompletesRelativePaths(self):
    """Checks that "./"-relative paths in a state file are joined to its directory.

    A `checkpoint` file containing paths of the form "./model.ckpt-N" is
    written by hand; `get_checkpoint_state` is expected to return each path
    joined onto `save_dir`.
    """
    save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
    os.chdir(save_dir)
    ckpt_path = os.path.join(save_dir, "checkpoint")
    # One field per source line; adjacent literals concatenate into the
    # single text blob that is written to the file.
    state_text = (
        ' model_checkpoint_path: "./model.ckpt-687529"'
        ' all_model_checkpoint_paths: "./model.ckpt-687500"'
        ' all_model_checkpoint_paths: "./model.ckpt-687529" '
    )
    # The context manager closes the handle even if the write raises.
    with open(ckpt_path, "w") as ckpt_file:
        ckpt_file.write(state_text)

    ckpt = checkpoint_management.get_checkpoint_state(save_dir)
    self.assertEqual(ckpt.model_checkpoint_path,
                     os.path.join(save_dir, "./model.ckpt-687529"))
    self.assertEqual(ckpt.all_model_checkpoint_paths[0],
                     os.path.join(save_dir, "./model.ckpt-687500"))
    self.assertEqual(ckpt.all_model_checkpoint_paths[1],
                     os.path.join(save_dir, "./model.ckpt-687529"))
def testCustomNumbering(self):
    """Checks `CheckpointManager.save` with caller-supplied checkpoint numbers.

    The number is first supplied through a variable that is incremented by 2
    after every save, then as a plain integer. Path suffixes, the save
    counter and the list of retained checkpoints are asserted.
    """
    ckpt_dir = self.get_temp_dir()
    step = variables.Variable(0, dtype=dtypes.int64)
    root = util.Checkpoint(step=step)
    manager = checkpoint_management.CheckpointManager(
        root, ckpt_dir, max_to_keep=2)
    self.evaluate(step.initializer)

    # Five saves numbered by the variable: expected numbers 0, 2, 4, 6, 8.
    newest_numbered_path = None
    for expected_number in range(0, 10, 2):
        newest_numbered_path = manager.save(checkpoint_number=step)
        expected_suffix = "-%d" % (expected_number,)
        if not newest_numbered_path.endswith(expected_suffix):
            self.fail("%s should have suffix %s"
                      % (newest_numbered_path, expected_suffix))
        self.evaluate(step.assign_add(2))
    self.assertEqual(5, self.evaluate(root.save_counter))

    # Test regular integers
    integer_path = manager.save(checkpoint_number=32)
    self.assertIn("-32", integer_path)
    self.assertEqual(integer_path, manager.latest_checkpoint)
    self.assertEqual(integer_path,
                     checkpoint_management.latest_checkpoint(ckpt_dir))

    # Only the most recent two checkpoints are saved
    stored = checkpoint_management.get_checkpoint_state(ckpt_dir)
    self.assertEqual([newest_numbered_path, integer_path],
                     stored.all_model_checkpoint_paths)
def _restore_checkpoint(self, master, saver=None, checkpoint_dir=None, checkpoint_filename_with_path=None, wait_for_checkpoint=False, max_wait_secs=7200, config=None): """Creates a `Session`, and tries to restore a checkpoint. Args: master: `String` representation of the TensorFlow master to use. saver: A `Saver` object used to restore a model. checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the dir will be used to restore. checkpoint_filename_with_path: Full file name path to the checkpoint file. wait_for_checkpoint: Whether to wait for checkpoint to become available. max_wait_secs: Maximum time to wait for checkpoints to become available. config: Optional `ConfigProto` proto used to configure the session. Returns: A pair (sess, is_restored) where 'is_restored' is `True` if the session could be restored, `False` otherwise. Raises: ValueError: If both checkpoint_dir and checkpoint_filename_with_path are set. """ self._target = master # This is required to so that we initialize the TPU device before # restoring from checkpoint since we'll be placing variables on the device # and TPUInitialize wipes out the memory of the device. strategy = distribution_strategy_context.get_strategy() if strategy and hasattr(strategy.extended, "_experimental_initialize_system"): strategy.extended._experimental_initialize_system() # pylint: disable=protected-access sess = session.Session(self._target, graph=self._graph, config=config) if checkpoint_dir and checkpoint_filename_with_path: raise ValueError("Can not provide both checkpoint_dir and " "checkpoint_filename_with_path.") # If either saver or checkpoint_* is not specified, cannot restore. Just # return. 
if not saver or not (checkpoint_dir or checkpoint_filename_with_path): return sess, False if checkpoint_filename_with_path: _restore_checkpoint_and_maybe_run_saved_model_initializers( sess, saver, checkpoint_filename_with_path) return sess, True # Waits up until max_wait_secs for checkpoint to become available. wait_time = 0 ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir) while not ckpt or not ckpt.model_checkpoint_path: if wait_for_checkpoint and wait_time < max_wait_secs: logging.info("Waiting for checkpoint to be available.") time.sleep(self._recovery_wait_secs) wait_time += self._recovery_wait_secs ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir) else: return sess, False # Loads the checkpoint. _restore_checkpoint_and_maybe_run_saved_model_initializers( sess, saver, ckpt.model_checkpoint_path) saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths) return sess, True
def testSaveRestoreState(self, mock_time):
    """Checks manager state across three successive `CheckpointManager`s.

    Three managers are created one after another over the same directory
    while the mocked clock is advanced between saves. The test asserts the
    managed checkpoint lists, the stored timestamps and which checkpoint
    files exist after each step.
    """
    ckpt_dir = self.get_temp_dir()
    exists = checkpoint_management.checkpoint_exists

    # Expected checkpoint paths, in save order.
    first_name = os.path.join(ckpt_dir, "ckpt-1")
    second_name = os.path.join(ckpt_dir, "ckpt-2")
    third_name = os.path.join(ckpt_dir, "ckpt-3")
    fourth_name = os.path.join(ckpt_dir, "ckpt-4")
    fifth_name = os.path.join(ckpt_dir, "ckpt-5")
    sixth_name = os.path.join(ckpt_dir, "ckpt-6")

    # Mocked clock values used for the first five saves.
    first_time = 10000.
    second_time = first_time + 3610.
    third_time = second_time + 3600. * 0.2
    fourth_time = third_time + 3600. * 0.5
    fifth_time = fourth_time + 3600. * 0.5

    mock_time.time.return_value = 3.
    root = util.Checkpoint()
    manager_a = checkpoint_management.CheckpointManager(
        root, ckpt_dir, max_to_keep=2)

    mock_time.time.return_value = first_time
    manager_a.save()
    # The state is read back here as well; its contents are not asserted.
    checkpoint_management.get_checkpoint_state(ckpt_dir)

    mock_time.time.return_value = second_time
    manager_a.save()
    stored = checkpoint_management.get_checkpoint_state(ckpt_dir)
    self.assertEqual([first_time, second_time],
                     stored.all_model_checkpoint_timestamps)
    self.assertEqual([first_name, second_name], manager_a.checkpoints)
    self.assertEqual(second_name, manager_a.latest_checkpoint)
    del manager_a

    # Second manager over the same directory, now with a keep interval.
    manager_b = checkpoint_management.CheckpointManager(
        root, ckpt_dir, max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual([first_name, second_name], manager_b.checkpoints)
    self.assertEqual(second_name, manager_b.latest_checkpoint)

    mock_time.time.return_value = third_time
    manager_b.save()
    self.assertTrue(exists(first_name))
    self.assertTrue(exists(second_name))
    self.assertEqual([second_name, third_name], manager_b.checkpoints)
    stored = checkpoint_management.get_checkpoint_state(ckpt_dir)
    self.assertEqual(first_time, stored.last_preserved_timestamp)

    mock_time.time.return_value = fourth_time
    manager_b.save()
    self.assertTrue(exists(first_name))
    self.assertFalse(exists(second_name))
    self.assertEqual([third_name, fourth_name], manager_b.checkpoints)

    mock_time.time.return_value = fifth_time
    manager_b.save()
    self.assertEqual([fourth_name, fifth_name], manager_b.checkpoints)
    stored = checkpoint_management.get_checkpoint_state(ckpt_dir)
    self.assertEqual(first_time, stored.last_preserved_timestamp)
    del manager_b

    # Third manager with the same settings as the second one.
    manager_c = checkpoint_management.CheckpointManager(
        root, ckpt_dir, max_to_keep=2, keep_checkpoint_every_n_hours=1.5)
    self.assertEqual(fifth_name, manager_c.latest_checkpoint)
    mock_time.time.return_value += 10.
    manager_c.save()

    stored = checkpoint_management.get_checkpoint_state(ckpt_dir)
    self.assertEqual(fourth_time, stored.last_preserved_timestamp)
    for kept_name in (first_name, fourth_name, fifth_name, sixth_name):
        self.assertTrue(exists(kept_name))
    for removed_name in (second_name, third_name):
        self.assertFalse(exists(removed_name))
    self.assertEqual([fifth_name, sixth_name], manager_c.checkpoints)