def __init__(self, model, checkpoint_dir, save_freq="epoch"):
    """Set up fault-tolerance checkpointing state for `model`.

    Args:
        model: The model whose training state is backed up/restored.
        checkpoint_dir: Directory where the checkpoint is read from (and,
            for the worker that should checkpoint, written to).
        save_freq: `"epoch"` or an integer batch interval; stored for use
            by the saving logic elsewhere in this class.
    """
    self._model = model
    self._save_freq = save_freq
    # The batch and epoch at which the checkpoint is saved. Used for
    # fault-tolerance. GPU device only has int64 dtype registered
    # VarHandleOp.
    self._ckpt_saved_epoch = tf.Variable(
        initial_value=tf.constant(
            self.CKPT_SAVED_EPOCH_UNUSED_VALUE, dtype=tf.int64
        ),
        name="ckpt_saved_epoch",
    )
    self._ckpt_saved_batch = tf.Variable(
        initial_value=tf.constant(
            self.CKPT_SAVED_BATCH_UNUSED_VALUE, dtype=tf.int64
        ),
        name="ckpt_saved_batch",
    )
    # Variable initialization: explicitly reset both trackers to their
    # sentinel "unused" values.
    backend.set_value(
        self._ckpt_saved_epoch, self.CKPT_SAVED_EPOCH_UNUSED_VALUE
    )
    backend.set_value(
        self._ckpt_saved_batch, self.CKPT_SAVED_BATCH_UNUSED_VALUE
    )
    # _ckpt_saved_epoch and _ckpt_saved_batch gets tracked and is included
    # in the checkpoint file when backing up.
    checkpoint = tf.train.Checkpoint(
        model=self._model,
        ckpt_saved_epoch=self._ckpt_saved_epoch,
        ckpt_saved_batch=self._ckpt_saved_batch,
        train_counter=self._model._train_counter,
    )
    # If this is single-worker training, checkpoint_dir are the same for
    # write_checkpoint_manager and read_checkpoint_manager.
    #
    # If this is multi-worker training, and this worker should not save
    # checkpoint, we replace the write_checkpoint_manager's checkpoint_dir
    # with a temp filepath, so it writes to a file that will be removed at
    # the end of back_up() call. This is necessary because the
    # SyncOnReadVariable needs to be synced across all the workers in order
    # to be read, and all workers need to perform `save()`. But all workers
    # should restore from the same checkpoint_dir as passed in
    # read_checkpoint_manager.
    self.read_checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=os.path.join(checkpoint_dir, "chief"),
        max_to_keep=1,
    )
    # write_dirpath returns checkpoint_dir itself for the chief, and a
    # per-worker temp subdirectory for non-checkpointing workers.
    write_checkpoint_dir = distributed_file_utils.write_dirpath(
        checkpoint_dir, self._model.distribute_strategy
    )
    if self._model.distribute_strategy.extended.should_checkpoint:
        # This worker owns the checkpoint: read and write managers are
        # the same object, pointing at the shared "chief" directory.
        self.write_checkpoint_manager = self.read_checkpoint_manager
    else:
        self.write_checkpoint_manager = tf.train.CheckpointManager(
            checkpoint, directory=write_checkpoint_dir, max_to_keep=1
        )
def testChiefWriteDirAndFilePath(self):
    """The chief writes directly to the requested paths, unmodified."""
    base_dir = self.get_temp_dir()
    target_file = os.path.join(base_dir, 'foo.bar')
    chief_strategy = DistributedFileUtilsTest.MockedChiefStrategy()
    # Neither the file path nor the directory path should be redirected.
    self.assertEqual(
        distributed_file_utils.write_filepath(target_file, chief_strategy),
        target_file)
    self.assertEqual(
        distributed_file_utils.write_dirpath(base_dir, chief_strategy),
        base_dir)
def testMultipleRemoveDirToWritePathIsFine(self):
    """Removing the temp write dir repeatedly must not raise."""
    base_dir = self.get_temp_dir()
    worker_strategy = DistributedFileUtilsTest.MockedWorkerStrategy()
    write_dir = distributed_file_utils.write_dirpath(base_dir, worker_strategy)
    dummy_path = os.path.join(write_dir, "tmp")
    self._write_dummy_file(dummy_path)
    # Idempotence check: three removals in a row, all should succeed.
    for _ in range(3):
        distributed_file_utils.remove_temp_dirpath(write_dir, worker_strategy)
def testWorkerWriteDirAndFilePath(self):
    """A non-chief worker's writes are redirected into a temp subdir."""
    base_dir = self.get_temp_dir()
    target_file = os.path.join(base_dir, 'foo.bar')
    worker_strategy = DistributedFileUtilsTest.MockedWorkerStrategy()
    # Worker 3 gets its own 'workertemp_3' subdirectory for both paths.
    self.assertEqual(
        distributed_file_utils.write_filepath(target_file, worker_strategy),
        os.path.join(base_dir, 'workertemp_3', 'foo.bar'))
    self.assertEqual(
        distributed_file_utils.write_dirpath(base_dir, worker_strategy),
        os.path.join(base_dir, 'workertemp_3'))
def testWorkerDoesRemoveDirPath(self):
    """A worker's temp dir (and its contents) is deleted on cleanup."""
    base_dir = self.get_temp_dir()
    worker_strategy = DistributedFileUtilsTest.MockedWorkerStrategy()
    write_dir = distributed_file_utils.write_dirpath(base_dir, worker_strategy)
    dummy_path = os.path.join(write_dir, "tmp")
    # Sanity: file appears only after we write it.
    self.assertFalse(os.path.exists(dummy_path))
    self._write_dummy_file(dummy_path)
    self.assertTrue(os.path.exists(dummy_path))
    # Cleanup removes both the file and the enclosing temp directory.
    distributed_file_utils.remove_temp_dirpath(base_dir, worker_strategy)
    self.assertFalse(os.path.exists(dummy_path))
    self.assertFalse(os.path.exists(os.path.dirname(dummy_path)))
def testChiefDoesNotRemoveDirAndFilePath(self):
    """Chief output is not a temp dir, so cleanup must leave it alone."""
    base_dir = self.get_temp_dir()
    chief_strategy = DistributedFileUtilsTest.MockedChiefStrategy()
    write_dir = distributed_file_utils.write_dirpath(base_dir, chief_strategy)
    dummy_path = os.path.join(write_dir, 'tmp')
    # Sanity: file appears only after we write it.
    self.assertFalse(os.path.exists(dummy_path))
    self._write_dummy_file(dummy_path)
    self.assertTrue(os.path.exists(dummy_path))
    # For the chief, remove_temp_dir_with_filepath is a no-op.
    distributed_file_utils.remove_temp_dir_with_filepath(
        dummy_path, chief_strategy)
    self.assertTrue(os.path.exists(dummy_path))