def testUpdateCheckpointStateSaveRelativePaths(self):
        # With save_relative_paths=True the "checkpoint" state file is expected
        # to hold paths relative to save_dir, while get_checkpoint_state() is
        # expected to hand back absolute paths.
        save_dir = self._get_test_dir("update_checkpoint_state")
        os.chdir(save_dir)
        older_rel, newer_rel = "model-0", "model-2"
        older_abs = os.path.join(save_dir, older_rel)
        newer_abs = os.path.join(save_dir, newer_rel)

        # The history deliberately mixes one relative and one absolute entry.
        checkpoint_management.update_checkpoint_state_internal(
            save_dir=save_dir,
            model_checkpoint_path=newer_abs,
            all_model_checkpoint_paths=[older_rel, newer_abs],
            save_relative_paths=True)

        # Parse the raw state file: every stored path should be relative.
        state_file = os.path.join(save_dir, "checkpoint")
        on_disk = CheckpointState()
        text_format.Merge(file_io.read_file_to_string(state_file), on_disk)
        self.assertEqual(on_disk.model_checkpoint_path, newer_rel)
        self.assertEqual(
            list(on_disk.all_model_checkpoint_paths), [older_rel, newer_rel])

        # Reading through the public helper should yield absolute paths.
        loaded = checkpoint_management.get_checkpoint_state(save_dir)
        self.assertEqual(loaded.model_checkpoint_path, newer_abs)
        self.assertEqual(
            list(loaded.all_model_checkpoint_paths), [older_abs, newer_abs])
  def testUpdateCheckpointStateSaveRelativePaths(self):
    # With save_relative_paths=True the "checkpoint" state file is expected
    # to hold paths relative to save_dir, while get_checkpoint_state() is
    # expected to hand back absolute paths.
    save_dir = self._get_test_dir("update_checkpoint_state")
    os.chdir(save_dir)
    older_rel, newer_rel = "model-0", "model-2"
    older_abs = os.path.join(save_dir, older_rel)
    newer_abs = os.path.join(save_dir, newer_rel)

    # The history deliberately mixes one relative and one absolute entry.
    checkpoint_management.update_checkpoint_state_internal(
        save_dir=save_dir,
        model_checkpoint_path=newer_abs,
        all_model_checkpoint_paths=[older_rel, newer_abs],
        save_relative_paths=True)

    # Parse the raw state file: every stored path should be relative.
    state_file = os.path.join(save_dir, "checkpoint")
    on_disk = CheckpointState()
    text_format.Merge(file_io.read_file_to_string(state_file), on_disk)
    self.assertEqual(on_disk.model_checkpoint_path, newer_rel)
    self.assertEqual(
        list(on_disk.all_model_checkpoint_paths), [older_rel, newer_rel])

    # Reading through the public helper should yield absolute paths.
    loaded = checkpoint_management.get_checkpoint_state(save_dir)
    self.assertEqual(loaded.model_checkpoint_path, newer_abs)
    self.assertEqual(
        list(loaded.all_model_checkpoint_paths), [older_abs, newer_abs])
 def save(self, file_prefix):
     """Write a checkpoint via `write` and refresh the checkpoint state file.

     This is a thin wrapper around `self.write(file_prefix)`. Besides writing
     the checkpoint it:

     * builds an `assign_add(1)` op on `save_counter` (on every call when
       executing eagerly; only once, cached in `_save_assign_op`, when
       building a graph), and
     * records the written path as the only checkpoint in the state file of
       `os.path.dirname(file_prefix)`, stored with relative paths.

     NOTE(review): this method passes `file_prefix` to `write` unchanged; the
     counter value is not appended to the name, and in graph mode the cached
     assign op is created but not run here. Confirm that is intended.

     Args:
     file_prefix: A prefix to use for the checkpoint filenames
             (/path/to/directory/and_a_prefix).
     Returns:
     The path returned by `write`.
     Raises:
     NotImplementedError: If called while graph building inside a function.
     """
     in_graph_mode = not context.executing_eagerly()
     if in_graph_mode and ops.inside_function():
         raise NotImplementedError(
             "Calling tf.train.Checkpoint.save() from a function is not "
             "supported, as save() modifies saving metadata in ways not "
             "supported by TensorFlow Operations. Consider using "
             "tf.train.Checkpoint.write(), a lower-level API which does not "
             "update metadata. tf.train.latest_checkpoint and related APIs will "
             "not see this checkpoint.")
     if in_graph_mode:
         session = get_session()
         if self._save_counter is None:
             # A save counter that does not exist yet must be initialized
             # before assign_add can be applied to it. `_save_counter` is
             # checked first, before the `save_counter` accessor is touched.
             session.run(self.save_counter.initializer)
     # Skip op creation only when graph building with an op already cached.
     if not (in_graph_mode and self._save_assign_op is not None):
         with ops.colocate_with(self.save_counter):
             increment_op = self.save_counter.assign_add(1, read_value=True)
         if in_graph_mode:
             self._save_assign_op = data_structures.NoDependency(increment_op)
     saved_path = self.write(file_prefix)
     checkpoint_management.update_checkpoint_state_internal(
         save_dir=os.path.dirname(file_prefix),
         model_checkpoint_path=saved_path,
         all_model_checkpoint_paths=[saved_path],
         save_relative_paths=True)
     return saved_path
# Example no. 4
# 0
    def commit(self, prefix, session):
        """Commit the cached checkpoint as the latest numbered checkpoint.

        If no checkpoint is cached (`_cached_checkpoint` is None) nothing is
        changed and the most recent committed prefix is returned.

        Otherwise the files of the cached checkpoint are renamed to
        "<prefix>-<count>", where <count> is the value obtained by running
        the wrapped checkpoint's save-counter increment op in `session`.
        The new prefix is appended to `_checkpoints`, the cache is cleared,
        and the checkpoint state for `_directory` is updated. When the
        number of kept checkpoints already equals `_max_to_keep`, the oldest
        one is dropped and its files removed first.

        Args:
            prefix: Filename prefix for the committed checkpoint; the counter
                value is appended as "-<count>".
            session: Object whose `run` method executes the save-counter ops
                (presumably a TensorFlow session; confirm with callers).

        Returns:
            The new checkpoint prefix, or, when nothing is cached, the last
            entry of `_checkpoints` ("" if there is none).
        """
        if self._cached_checkpoint is None:
            # Nothing staged: report the newest committed checkpoint, if any.
            return self._checkpoints[-1] if self._checkpoints else ""

        # Make room for the new checkpoint by evicting the oldest one.
        if len(self._checkpoints) == self._max_to_keep:
            oldest_prefix = self._checkpoints.pop(0)
            for filename in self._get_checkpoint_filenames(oldest_prefix):
                os.remove(filename)

        # Replication from checkpoint.save
        checkpoint = self._checkpoint
        if checkpoint._save_counter is None:
            # `_save_counter` is checked before the `save_counter` accessor
            # is touched; the counter is initialized only when still unset.
            session.run(checkpoint.save_counter.initializer)
        if checkpoint._save_assign_op is None:
            # Build the increment op once and cache it for later commits.
            checkpoint._save_assign_op = data_structures.NoDependency(
                checkpoint.save_counter.assign_add(1, read_value=True))

        checkpoint_count = session.run(checkpoint._save_assign_op)
        filename_prefix = "%s-%d" % (prefix, checkpoint_count)

        cached_prefix = self._cached_checkpoint
        for filename in self._get_checkpoint_filenames(cached_prefix):
            # Change prefix: swap the cached prefix for the numbered one.
            os.rename(filename,
                      filename.replace(cached_prefix, filename_prefix))

        self._checkpoints.append(filename_prefix)
        self._cached_checkpoint = None
        # Update checkpoint state file (@tf.train.latest_checkpoint)
        checkpoint_management.update_checkpoint_state_internal(
            self._directory, self._checkpoints[-1], self._checkpoints)
        return filename_prefix