def testMultipleRemoveDirToWritePathIsFine(self):
  """Removing a worker's temp write dir repeatedly must not raise."""
  base_dir = self.get_temp_dir()
  worker_strategy = DistributedFileUtilsTest.MockedWorkerStrategy()
  write_dir = distributed_file_utils.write_dirpath(base_dir, worker_strategy)
  dummy_path = os.path.join(write_dir, 'tmp')
  self._write_dummy_file(dummy_path)
  # Repeated removal must be a no-op once the directory is already gone.
  for _ in range(3):
    distributed_file_utils.remove_temp_dirpath(write_dir, worker_strategy)
def testChiefWriteDirAndFilePath(self):
  """A chief worker writes directly to the requested paths, unmodified."""
  base_dir = self.get_temp_dir()
  target_file = os.path.join(base_dir, 'foo.bar')
  chief_strategy = DistributedFileUtilsTest.MockedChiefStrategy()
  # For the chief, both helpers must return the original path unchanged.
  self.assertEqual(
      distributed_file_utils.write_filepath(target_file, chief_strategy),
      target_file)
  self.assertEqual(
      distributed_file_utils.write_dirpath(base_dir, chief_strategy),
      base_dir)
def testWorkerWriteDirAndFilePath(self):
  """A non-chief worker is redirected into its per-worker temp directory."""
  base_dir = self.get_temp_dir()
  target_file = os.path.join(base_dir, 'foo.bar')
  worker_strategy = DistributedFileUtilsTest.MockedWorkerStrategy()
  # The mocked worker has id 3, so writes land under 'workertemp_3'.
  expected_dir = os.path.join(base_dir, 'workertemp_3')
  self.assertEqual(
      distributed_file_utils.write_filepath(target_file, worker_strategy),
      os.path.join(expected_dir, 'foo.bar'))
  self.assertEqual(
      distributed_file_utils.write_dirpath(base_dir, worker_strategy),
      expected_dir)
def testWorkerDoesRemoveDirPath(self):
  """remove_temp_dirpath on the original dir wipes the worker temp dir."""
  base_dir = self.get_temp_dir()
  worker_strategy = DistributedFileUtilsTest.MockedWorkerStrategy()
  write_dir = distributed_file_utils.write_dirpath(base_dir, worker_strategy)
  dummy_path = os.path.join(write_dir, 'tmp')
  # Sanity-check: file absent before, present after the dummy write.
  self.assertFalse(os.path.exists(dummy_path))
  self._write_dummy_file(dummy_path)
  self.assertTrue(os.path.exists(dummy_path))
  # Removal is keyed off the original (non-temp) directory.
  distributed_file_utils.remove_temp_dirpath(base_dir, worker_strategy)
  self.assertFalse(os.path.exists(dummy_path))
def __init__(self, model, checkpoint_dir):
  """Sets up checkpoint managers used for fault-tolerant training.

  Args:
    model: The model whose training state is backed up / restored.
    checkpoint_dir: Directory all workers restore checkpoints from; in
      multi-worker training, non-chief workers write to a temp directory
      derived from it instead of writing to it directly.
  """
  self._model = model

  # The epoch at which the checkpoint is saved. Used for fault-tolerance.
  # GPU device only has int64 dtype registered VarHandleOp.
  self._ckpt_saved_epoch = variables.Variable(
      initial_value=constant_op.constant(
          CKPT_SAVED_EPOCH_UNUSED_VALUE, dtype=dtypes.int64),
      name='ckpt_saved_epoch')

  # Variable initialization.
  K.set_value(self._ckpt_saved_epoch, CKPT_SAVED_EPOCH_UNUSED_VALUE)

  # _ckpt_saved_epoch gets tracked and is included in the checkpoint file
  # when backing up.
  checkpoint = trackable_util.Checkpoint(
      model=self._model, ckpt_saved_epoch=self._ckpt_saved_epoch)

  # If this is single-worker training, checkpoint_dir are the same for
  # write_checkpoint_manager and read_checkpoint_manager.
  #
  # If this is multi-worker training, and this worker should not save
  # checkpoint, we replace the write_checkpoint_manager's checkpoint_dir
  # with a temp filepath, so it writes to a file that will be removed at the
  # end of back_up() call. This is necessary because the SyncOnReadVariable
  # needs to be synced across all the workers in order to be read, and all
  # workers need to perform `save()`.
  # But all workers should restore from the same checkpoint_dir as passed in
  # read_checkpoint_manager.
  #
  # Fix: the strategy argument had been stubbed out with `None` (the real
  # argument was left behind as a trailing comment), which would make every
  # worker write straight into `checkpoint_dir` and defeat the redirection
  # described above. Pass the model's distribution strategy as intended.
  self.write_checkpoint_dir = distributed_file_utils.write_dirpath(
      checkpoint_dir, self._model.distribute_strategy)
  self.write_checkpoint_manager = CheckpointManager(
      checkpoint,
      directory=self.write_checkpoint_dir,
      max_to_keep=1,
      checkpoint_interval=1)
  if self.write_checkpoint_dir == checkpoint_dir:
    # Chief (or single-worker): reads and writes share one manager.
    self.read_checkpoint_manager = self.write_checkpoint_manager
  else:
    # Non-chief workers still restore from the shared checkpoint_dir.
    self.read_checkpoint_manager = CheckpointManager(
        checkpoint, directory=checkpoint_dir, max_to_keep=1)