def testCheckPointStateFailsWhenIncomplete(self):
     """An empty "checkpoint" state file makes get_checkpoint_state raise.

     Writes a zero-length file named "checkpoint" into a fresh test directory
     and asserts that `checkpoint_management.get_checkpoint_state` raises
     `ValueError` when asked for the state of that directory.
     """
     save_dir = self._get_test_dir("checkpoint_state_fails_when_incomplete")
     # NOTE(review): this changes the process-wide working directory and the
     # test does not restore it afterwards.
     os.chdir(save_dir)
     ckpt_path = os.path.join(save_dir, "checkpoint")
     # The context manager closes the handle even if the write raises.
     with open(ckpt_path, "w") as ckpt_file:
         ckpt_file.write("")
     with self.assertRaises(ValueError):
         checkpoint_management.get_checkpoint_state(save_dir)
 def testClockReset(self, mock_time):
     """Checks CheckpointManager expectations when the clock moves backwards.

     A first manager saves three checkpoints one mocked hour apart. The mocked
     clock is then set to a value earlier than the recorded
     `last_preserved_timestamp` and a second manager is created over the same
     directory. The test asserts that a warning is logged, that the
     pre-existing checkpoints stay on disk, and that the state written by
     later saves carries the new (earlier) timestamps.

     Args:
       mock_time: Mock whose `time.return_value` this test sets to control the
         clock. Presumably injected by a patch decorator that is outside this
         view -- TODO confirm.
     """
     directory = self.get_temp_dir()
     mock_time.time.return_value = 10000.
     checkpoint = util.Checkpoint()
     first_manager = checkpoint_management.CheckpointManager(
         checkpoint,
         directory,
         max_to_keep=1,
         keep_checkpoint_every_n_hours=1.)
     # Three saves at mocked times 10000, 13600 and 17200.
     first_path = first_manager.save()
     mock_time.time.return_value += 3600.
     second_path = first_manager.save()
     mock_time.time.return_value += 3600.
     third_path = first_manager.save()
     # Expected outcome: the first checkpoint is gone, the second is still on
     # disk but no longer tracked, and only the third is listed.
     self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
     self.assertEqual([third_path], first_manager.checkpoints)
     state = checkpoint_management.get_checkpoint_state(directory)
     # 13600 is the mocked time of the second save.
     self.assertEqual(13600., state.last_preserved_timestamp)
     # Set the clock back in time
     mock_time.time.return_value = 5000.
     del first_manager
     # Patch the logger so the warning emitted during construction can be
     # inspected.
     with test.mock.patch.object(logging, "warning") as mock_log:
         second_manager = checkpoint_management.CheckpointManager(
             checkpoint, directory, max_to_keep=1)
         self.assertRegex(str(mock_log.call_args),
                          "behind the last preserved checkpoint timestamp")
     # We should err on the side of keeping checkpoints around when we're not
     # sure whether they were preserved or not due to clock funkiness.
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     # We know about the existing checkpoints, but they'll never be deleted and
     # so won't go in the CheckpointState proto on save.
     self.assertEqual(third_path, second_manager.latest_checkpoint)
     self.assertEqual([], second_manager.checkpoints)
     # Fourth save at mocked time 5010.
     mock_time.time.return_value += 10.
     fourth_path = second_manager.save()
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
     self.assertEqual(fourth_path, second_manager.latest_checkpoint)
     self.assertEqual([fourth_path], second_manager.checkpoints)
     # Fifth save at mocked time 5020; with max_to_keep=1 only it is listed.
     mock_time.time.return_value += 10.
     fifth_path = second_manager.save()
     self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
     self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
     self.assertEqual([fifth_path], second_manager.checkpoints)
     state = checkpoint_management.get_checkpoint_state(directory)
     # The stored state is expected to reflect the reset clock: 5000 is the
     # value the clock was set back to, 5020 is the time of the fifth save.
     self.assertEqual(5000., state.last_preserved_timestamp)
     self.assertEqual([5020.], state.all_model_checkpoint_timestamps)
    def testUpdateCheckpointStateSaveRelativePaths(self):
        """Relative paths are written to disk and absolute paths are read back.

        Calls `update_checkpoint_state_internal` with `save_relative_paths=True`
        using a mix of relative and absolute inputs, then inspects both the raw
        "checkpoint" file and the result of `get_checkpoint_state`.
        """
        test_dir = self._get_test_dir("update_checkpoint_state")
        os.chdir(test_dir)
        older_rel, newer_rel = "model-0", "model-2"
        older_abs = os.path.join(test_dir, "model-0")
        newer_abs = os.path.join(test_dir, "model-2")
        checkpoint_management.update_checkpoint_state_internal(
            save_dir=test_dir,
            model_checkpoint_path=newer_abs,
            all_model_checkpoint_paths=[older_rel, newer_abs],
            save_relative_paths=True)

        # Parse the raw state file: it should hold relative paths only.
        state_file = os.path.join(test_dir, "checkpoint")
        raw_text = file_io.read_file_to_string(state_file)
        on_disk = CheckpointState()
        text_format.Merge(raw_text, on_disk)
        self.assertEqual(on_disk.model_checkpoint_path, newer_rel)
        stored_paths = on_disk.all_model_checkpoint_paths
        self.assertEqual(len(stored_paths), 2)
        self.assertEqual(stored_paths[-1], newer_rel)
        self.assertEqual(stored_paths[0], older_rel)

        # Reading through get_checkpoint_state should yield absolute paths.
        loaded = checkpoint_management.get_checkpoint_state(test_dir)
        self.assertEqual(loaded.model_checkpoint_path, newer_abs)
        loaded_paths = loaded.all_model_checkpoint_paths
        self.assertEqual(len(loaded_paths), 2)
        self.assertEqual(loaded_paths[-1], newer_abs)
        self.assertEqual(loaded_paths[0], older_abs)
 def testFSPath(self):
     """get_checkpoint_state is called with a `pathlib.Path` directory.

     Writes checkpoint state for a "train" subdirectory and reads it back by
     passing that directory as a `pathlib.Path` instead of a string.
     """
     root_dir = self._get_test_dir("fspath")
     os.chdir(root_dir)
     # Create the train directory relative to the new working directory.
     train_subdir = "train"
     os.mkdir(train_subdir)
     oldest_abs = os.path.join(root_dir, "model-0")
     latest_rel = os.path.join("train", "model-2")
     history = [oldest_abs, latest_rel]
     checkpoint_management.update_checkpoint_state(
         train_subdir, latest_rel, all_model_checkpoint_paths=history)
     train_path = pathlib.Path(train_subdir)
     loaded = checkpoint_management.get_checkpoint_state(train_path)
     self.assertEqual(loaded.model_checkpoint_path, latest_rel)
 def testUpdateCheckpointState(self):
     """Paths given to update_checkpoint_state are read back unchanged.

     Stores one absolute and one relative checkpoint path for a "train"
     subdirectory and asserts that `get_checkpoint_state` returns both as
     they were supplied.
     """
     root_dir = self._get_test_dir("update_checkpoint_state")
     os.chdir(root_dir)
     # Create the train directory relative to the new working directory.
     train_subdir = "train"
     os.mkdir(train_subdir)
     oldest_abs = os.path.join(root_dir, "model-0")
     latest_rel = os.path.join("train", "model-2")
     checkpoint_management.update_checkpoint_state(
         train_subdir,
         latest_rel,
         all_model_checkpoint_paths=[oldest_abs, latest_rel])
     loaded = checkpoint_management.get_checkpoint_state(train_subdir)
     self.assertEqual(loaded.model_checkpoint_path, latest_rel)
     recorded = loaded.all_model_checkpoint_paths
     self.assertEqual(len(recorded), 2)
     self.assertEqual(recorded[-1], latest_rel)
     self.assertEqual(recorded[0], oldest_abs)
 def testCheckPointCompletesRelativePaths(self):
     """Relative entries in the state file are joined onto the directory.

     Writes a "checkpoint" state file whose entries all start with "./" and
     asserts that `get_checkpoint_state` returns each of them joined onto the
     directory that was passed in.
     """
     save_dir = self._get_test_dir("checkpoint_completes_relative_paths")
     # NOTE(review): this changes the process-wide working directory and the
     # test does not restore it afterwards.
     os.chdir(save_dir)
     ckpt_path = os.path.join(save_dir, "checkpoint")
     # The context manager closes the handle even if the write raises.
     with open(ckpt_path, "w") as ckpt_file:
         ckpt_file.write("""
     model_checkpoint_path: "./model.ckpt-687529"
     all_model_checkpoint_paths: "./model.ckpt-687500"
     all_model_checkpoint_paths: "./model.ckpt-687529"
     """)
     ckpt = checkpoint_management.get_checkpoint_state(save_dir)
     self.assertEqual(ckpt.model_checkpoint_path,
                      os.path.join(save_dir, "./model.ckpt-687529"))
     self.assertEqual(ckpt.all_model_checkpoint_paths[0],
                      os.path.join(save_dir, "./model.ckpt-687500"))
     self.assertEqual(ckpt.all_model_checkpoint_paths[1],
                      os.path.join(save_dir, "./model.ckpt-687529"))
 def testCustomNumbering(self):
     """Saves with an explicit `checkpoint_number` and checks the file names.

     The number is supplied first as an int64 variable that grows by two
     after every save, then as a plain Python integer.
     """
     ckpt_dir = self.get_temp_dir()
     step_var = variables.Variable(0, dtype=dtypes.int64)
     ckpt = util.Checkpoint(step=step_var)
     manager = checkpoint_management.CheckpointManager(
         ckpt, ckpt_dir, max_to_keep=2)
     self.evaluate(step_var.initializer)
     # Five saves; the variable holds 0, 2, 4, 6, 8 at the respective saves.
     for wanted_number in range(0, 10, 2):
         saved_path = manager.save(checkpoint_number=step_var)
         wanted_suffix = "-%d" % (wanted_number, )
         if not saved_path.endswith(wanted_suffix):
             self.fail(
                 "%s should have suffix %s" % (saved_path, wanted_suffix))
         self.evaluate(step_var.assign_add(2))
     self.assertEqual(5, self.evaluate(ckpt.save_counter))
     # Test regular integers
     int_numbered_path = manager.save(checkpoint_number=32)
     self.assertIn("-32", int_numbered_path)
     self.assertEqual(int_numbered_path, manager.latest_checkpoint)
     newest_in_dir = checkpoint_management.latest_checkpoint(ckpt_dir)
     self.assertEqual(int_numbered_path, newest_in_dir)
     state = checkpoint_management.get_checkpoint_state(ckpt_dir)
     # Only the most recent two checkpoints are saved: the last one from the
     # loop and the integer-numbered one.
     self.assertEqual([saved_path, int_numbered_path],
                      state.all_model_checkpoint_paths)
Ejemplo n.º 8
0
  def _restore_checkpoint(self,
                          master,
                          saver=None,
                          checkpoint_dir=None,
                          checkpoint_filename_with_path=None,
                          wait_for_checkpoint=False,
                          max_wait_secs=7200,
                          config=None):
    """Creates a `Session`, and tries to restore a checkpoint.

    The arguments are validated before the session is created, so an invalid
    call raises without leaving an open `Session` behind. When `saver` or both
    checkpoint locations are missing, the new session is returned unrestored.

    Args:
      master: `String` representation of the TensorFlow master to use.
      saver: A `Saver` object used to restore a model.
      checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the
        dir will be used to restore.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
      wait_for_checkpoint: Whether to wait for checkpoint to become available.
      max_wait_secs: Maximum time to wait for checkpoints to become available.
      config: Optional `ConfigProto` proto used to configure the session.

    Returns:
      A pair (sess, is_restored) where 'is_restored' is `True` if
      the session could be restored, `False` otherwise.

    Raises:
      ValueError: If both checkpoint_dir and checkpoint_filename_with_path are
        set.
    """
    self._target = master

    # Reject conflicting arguments up front: doing this after the session is
    # created would leak the open session when the error is raised.
    if checkpoint_dir and checkpoint_filename_with_path:
      raise ValueError("Can not provide both checkpoint_dir and "
                       "checkpoint_filename_with_path.")

    # This is required so that we initialize the TPU device before
    # restoring from checkpoint since we'll be placing variables on the device
    # and TPUInitialize wipes out the memory of the device.
    strategy = distribution_strategy_context.get_strategy()
    if strategy and hasattr(strategy.extended,
                            "_experimental_initialize_system"):
      strategy.extended._experimental_initialize_system()  # pylint: disable=protected-access

    sess = session.Session(self._target, graph=self._graph, config=config)
    # If either saver or checkpoint_* is not specified, cannot restore. Just
    # return.
    if not saver or not (checkpoint_dir or checkpoint_filename_with_path):
      return sess, False

    # An explicit checkpoint file is restored directly, with no waiting.
    if checkpoint_filename_with_path:
      _restore_checkpoint_and_maybe_run_saved_model_initializers(
          sess, saver, checkpoint_filename_with_path)
      return sess, True

    # Waits up until max_wait_secs for checkpoint to become available,
    # re-reading the checkpoint state every self._recovery_wait_secs seconds.
    wait_time = 0
    ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
    while not ckpt or not ckpt.model_checkpoint_path:
      if wait_for_checkpoint and wait_time < max_wait_secs:
        logging.info("Waiting for checkpoint to be available.")
        time.sleep(self._recovery_wait_secs)
        wait_time += self._recovery_wait_secs
        ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir)
      else:
        # Not waiting, or the wait budget is spent: give up without restoring.
        return sess, False

    # Loads the checkpoint.
    _restore_checkpoint_and_maybe_run_saved_model_initializers(
        sess, saver, ckpt.model_checkpoint_path)
    saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)
    return sess, True
    def testSaveRestoreState(self, mock_time):
        """Checks that manager state carries over between CheckpointManagers.

        Three managers are created in turn over the same directory while the
        mocked clock advances. After each group of saves the test asserts
        which checkpoint files exist, which paths the manager lists, and what
        `last_preserved_timestamp` the stored state holds.

        Args:
          mock_time: Mock whose `time.return_value` this test sets to control
            the clock. Presumably injected by a patch decorator that is outside
            this view -- TODO confirm.
        """
        directory = self.get_temp_dir()
        mock_time.time.return_value = 3.
        checkpoint = util.Checkpoint()
        first_manager = checkpoint_management.CheckpointManager(checkpoint,
                                                                directory,
                                                                max_to_keep=2)
        # First save at mocked time 10000.
        first_time = 10000.
        first_name = os.path.join(directory, "ckpt-1")
        mock_time.time.return_value = first_time
        first_manager.save()
        # NOTE(review): this `state` is overwritten below before it is read.
        state = checkpoint_management.get_checkpoint_state(directory)
        # Second save 3610 seconds after the first.
        second_time = first_time + 3610.
        second_name = os.path.join(directory, "ckpt-2")
        mock_time.time.return_value = second_time
        first_manager.save()
        state = checkpoint_management.get_checkpoint_state(directory)
        self.assertEqual([first_time, second_time],
                         state.all_model_checkpoint_timestamps)
        self.assertEqual([first_name, second_name], first_manager.checkpoints)
        self.assertEqual(second_name, first_manager.latest_checkpoint)
        del first_manager

        # A new manager over the same directory is expected to see the two
        # existing checkpoints; it also sets keep_checkpoint_every_n_hours.
        second_manager = checkpoint_management.CheckpointManager(
            checkpoint,
            directory,
            max_to_keep=2,
            keep_checkpoint_every_n_hours=1.5)
        self.assertEqual([first_name, second_name], second_manager.checkpoints)
        self.assertEqual(second_name, second_manager.latest_checkpoint)
        # Third save 0.2 hours (720 seconds) after the second.
        third_name = os.path.join(directory, "ckpt-3")
        third_time = second_time + 3600. * 0.2
        mock_time.time.return_value = third_time
        second_manager.save()
        # Expected: the first checkpoint leaves the tracked list but stays on
        # disk, and its time becomes the last preserved timestamp.
        self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
        self.assertTrue(checkpoint_management.checkpoint_exists(second_name))
        self.assertEqual([second_name, third_name], second_manager.checkpoints)
        state = checkpoint_management.get_checkpoint_state(directory)
        self.assertEqual(first_time, state.last_preserved_timestamp)
        # Fourth save half an hour after the third.
        fourth_time = third_time + 3600. * 0.5
        mock_time.time.return_value = fourth_time
        fourth_name = os.path.join(directory, "ckpt-4")
        second_manager.save()
        # Expected: this time the second checkpoint is removed from disk.
        self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
        self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
        self.assertEqual([third_name, fourth_name], second_manager.checkpoints)
        # Fifth save half an hour after the fourth.
        fifth_time = fourth_time + 3600. * 0.5
        mock_time.time.return_value = fifth_time
        fifth_name = os.path.join(directory, "ckpt-5")
        second_manager.save()
        self.assertEqual([fourth_name, fifth_name], second_manager.checkpoints)
        state = checkpoint_management.get_checkpoint_state(directory)
        # The last preserved timestamp is still expected to be the first save.
        self.assertEqual(first_time, state.last_preserved_timestamp)
        del second_manager
        # A third manager with the same settings picks up from the stored state.
        third_manager = checkpoint_management.CheckpointManager(
            checkpoint,
            directory,
            max_to_keep=2,
            keep_checkpoint_every_n_hours=1.5)
        self.assertEqual(fifth_name, third_manager.latest_checkpoint)
        # Sixth save ten seconds after the fifth.
        mock_time.time.return_value += 10.
        third_manager.save()
        sixth_name = os.path.join(directory, "ckpt-6")
        state = checkpoint_management.get_checkpoint_state(directory)
        # Expected: the fourth checkpoint is now the preserved one, so it stays
        # on disk alongside the first, fifth and sixth.
        self.assertEqual(fourth_time, state.last_preserved_timestamp)
        self.assertTrue(checkpoint_management.checkpoint_exists(first_name))
        self.assertTrue(checkpoint_management.checkpoint_exists(fourth_name))
        self.assertTrue(checkpoint_management.checkpoint_exists(fifth_name))
        self.assertTrue(checkpoint_management.checkpoint_exists(sixth_name))
        self.assertFalse(checkpoint_management.checkpoint_exists(second_name))
        self.assertFalse(checkpoint_management.checkpoint_exists(third_name))
        self.assertEqual([fifth_name, sixth_name], third_manager.checkpoints)