def testUpdateCheckpointStateSaveRelativePaths(self):
  """Relative paths are stored on disk but resolved to absolute on read."""
  # NOTE(review): an identical definition of this test appears again later in
  # this file; in a class body the later one silently shadows this one —
  # confirm and deduplicate.
  save_dir = self._get_test_dir("update_checkpoint_state")
  os.chdir(save_dir)
  rel_paths = ["model-0", "model-2"]
  abs_paths = [os.path.join(save_dir, p) for p in rel_paths]
  checkpoint_management.update_checkpoint_state_internal(
      save_dir=save_dir,
      model_checkpoint_path=abs_paths[1],
      all_model_checkpoint_paths=[rel_paths[0], abs_paths[1]],
      save_relative_paths=True)
  # The state file itself should hold relative paths.
  state_text = file_io.read_file_to_string(
      os.path.join(save_dir, "checkpoint"))
  ckpt = CheckpointState()
  text_format.Merge(state_text, ckpt)
  self.assertEqual(ckpt.model_checkpoint_path, rel_paths[1])
  self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
  self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_paths[1])
  self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_paths[0])
  # get_checkpoint_state should hand back absolute paths.
  ckpt = checkpoint_management.get_checkpoint_state(save_dir)
  self.assertEqual(ckpt.model_checkpoint_path, abs_paths[1])
  self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
  self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_paths[1])
  self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_paths[0])
def testUpdateCheckpointStateSaveRelativePaths(self):
  """Checks save_relative_paths=True: relative on disk, absolute on read."""
  # NOTE(review): this definition is byte-identical to an earlier one in this
  # file; the duplicate that appears last shadows the other — confirm and
  # deduplicate.
  save_dir = self._get_test_dir("update_checkpoint_state")
  # chdir so that the relative paths written to the state file resolve
  # against save_dir.
  os.chdir(save_dir)
  abs_path2 = os.path.join(save_dir, "model-2")
  rel_path2 = "model-2"
  abs_path0 = os.path.join(save_dir, "model-0")
  rel_path0 = "model-0"
  checkpoint_management.update_checkpoint_state_internal(
      save_dir=save_dir,
      model_checkpoint_path=abs_path2,
      all_model_checkpoint_paths=[rel_path0, abs_path2],
      save_relative_paths=True)
  # File should contain relative paths.
  file_content = file_io.read_file_to_string(
      os.path.join(save_dir, "checkpoint"))
  ckpt = CheckpointState()
  text_format.Merge(file_content, ckpt)
  self.assertEqual(ckpt.model_checkpoint_path, rel_path2)
  self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
  self.assertEqual(ckpt.all_model_checkpoint_paths[-1], rel_path2)
  self.assertEqual(ckpt.all_model_checkpoint_paths[0], rel_path0)
  # get_checkpoint_state should return absolute paths.
  ckpt = checkpoint_management.get_checkpoint_state(save_dir)
  self.assertEqual(ckpt.model_checkpoint_path, abs_path2)
  self.assertEqual(len(ckpt.all_model_checkpoint_paths), 2)
  self.assertEqual(ckpt.all_model_checkpoint_paths[-1], abs_path2)
  self.assertEqual(ckpt.all_model_checkpoint_paths[0], abs_path0)
def save(self, file_prefix):
  """Saves a training checkpoint and provides basic checkpoint management.

  The saved checkpoint includes variables created by this object and any
  trackable objects it depends on at the time `Checkpoint.save()` is
  called.

  `save` is a basic convenience wrapper around the `write` method,
  sequentially numbering checkpoints using `save_counter` and updating the
  metadata used by `tf.train.latest_checkpoint`. More advanced checkpoint
  management, for example garbage collection and custom numbering, may be
  provided by other utilities which also wrap `write`
  (`tf.train.CheckpointManager` for example).

  Args:
    file_prefix: A prefix to use for the checkpoint filenames
      (/path/to/directory/and_a_prefix). Names are generated based on this
      prefix and `Checkpoint.save_counter`.

  Returns:
    The full path to the checkpoint.
  """
  graph_building = not context.executing_eagerly()
  if graph_building:
    if ops.inside_function():
      raise NotImplementedError(
          "Calling tf.train.Checkpoint.save() from a function is not "
          "supported, as save() modifies saving metadata in ways not "
          "supported by TensorFlow Operations. Consider using "
          "tf.train.Checkpoint.write(), a lower-level API which does not "
          "update metadata. tf.train.latest_checkpoint and related APIs will "
          "not see this checkpoint.")
    session = get_session()
    if self._save_counter is None:
      # When graph building, if this is a new save counter variable then it
      # needs to be initialized before assign_add. This is only an issue if
      # restore() has not been called first.
      session.run(self.save_counter.initializer)
  if not graph_building or self._save_assign_op is None:
    with ops.colocate_with(self.save_counter):
      assign_op = self.save_counter.assign_add(1, read_value=True)
    if graph_building:
      # Cache the op so repeated graph-mode saves don't grow the graph.
      self._save_assign_op = data_structures.NoDependency(assign_op)
  # BUG FIX: the counter was incremented but never used — every save wrote
  # to the bare `file_prefix`, overwriting the previous checkpoint instead of
  # sequentially numbering as documented. Read the incremented value and
  # fold it into the filename.
  if graph_building:
    checkpoint_number = session.run(self._save_assign_op)
  else:
    checkpoint_number = assign_op.numpy()
  file_path = self.write("%s-%d" % (file_prefix, checkpoint_number))
  checkpoint_management.update_checkpoint_state_internal(
      save_dir=os.path.dirname(file_prefix),
      model_checkpoint_path=file_path,
      all_model_checkpoint_paths=[file_path],
      save_relative_paths=True)
  return file_path
def commit(self, prefix, session):
  """Commits the cached checkpoint under a new sequentially-numbered name.

  Renames the cached checkpoint's files to `prefix-N` (N taken from the
  tracked `Checkpoint`'s `save_counter`), garbage-collects the oldest
  committed checkpoint when `max_to_keep` is reached, and refreshes the
  checkpoint state file so `tf.train.latest_checkpoint` sees the result.

  Args:
    prefix: Filename prefix for the committed checkpoint.
    session: Session used to run the save-counter ops (graph mode).

  Returns:
    The prefix of the newly committed checkpoint, or — when nothing is
    cached — the most recent committed prefix ("" if none exist).
  """
  if self._cached_checkpoint is None:  # was `== None`: identity, not equality
    # Nothing pending: fall back to the latest committed checkpoint, if any.
    return self._checkpoints[-1] if self._checkpoints else ""
  if len(self._checkpoints) == self._max_to_keep:
    # Garbage-collect the oldest checkpoint to stay within max_to_keep.
    for filename in self._get_checkpoint_filenames(self._checkpoints.pop(0)):
      os.remove(filename)
  # Replication of the counter bookkeeping in tf.train.Checkpoint.save().
  if self._checkpoint._save_counter is None:
    session.run(self._checkpoint.save_counter.initializer)
  if self._checkpoint._save_assign_op is None:
    self._checkpoint._save_assign_op = data_structures.NoDependency(
        self._checkpoint.save_counter.assign_add(1, read_value=True))
  checkpoint_count = session.run(self._checkpoint._save_assign_op)
  filename_prefix = "%s-%d" % (prefix, checkpoint_count)
  # Rename every file of the cached checkpoint to carry the new prefix.
  for filename in self._get_checkpoint_filenames(self._cached_checkpoint):
    os.rename(filename,
              filename.replace(self._cached_checkpoint, filename_prefix))
  self._checkpoints.append(filename_prefix)
  self._cached_checkpoint = None
  # Update checkpoint state file (@tf.train.latest_checkpoint)
  checkpoint_management.update_checkpoint_state_internal(
      self._directory, self._checkpoints[-1], self._checkpoints)
  return filename_prefix