def testMultipleRemoveDirToWritePathIsFine(self):
    temp_dir = self.get_temp_dir()
    strategy = DistributedFileUtilsTest.MockedWorkerStrategy()
    dir_to_write = distributed_file_utils.write_dirpath(temp_dir, strategy)
    file_to_write = os.path.join(dir_to_write, 'tmp')
    self._write_dummy_file(file_to_write)
    distributed_file_utils.remove_temp_dirpath(dir_to_write, strategy)
    distributed_file_utils.remove_temp_dirpath(dir_to_write, strategy)
    distributed_file_utils.remove_temp_dirpath(dir_to_write, strategy)

def testChiefWriteDirAndFilePath(self):
    dirpath = self.get_temp_dir()
    filepath = os.path.join(dirpath, 'foo.bar')
    strategy = DistributedFileUtilsTest.MockedChiefStrategy()
    self.assertEqual(
        distributed_file_utils.write_filepath(filepath, strategy), filepath)
    self.assertEqual(
        distributed_file_utils.write_dirpath(dirpath, strategy), dirpath)

def testWorkerWriteDirAndFilePath(self):
    dirpath = self.get_temp_dir()
    filepath = os.path.join(dirpath, 'foo.bar')
    strategy = DistributedFileUtilsTest.MockedWorkerStrategy()
    self.assertEqual(
        distributed_file_utils.write_filepath(filepath, strategy),
        os.path.join(dirpath, 'workertemp_3', 'foo.bar'))
    self.assertEqual(
        distributed_file_utils.write_dirpath(dirpath, strategy),
        os.path.join(dirpath, 'workertemp_3'))
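Taken together, these tests pin down the redirection contract: the chief writes to the caller-supplied path unchanged, while a non-chief worker (task id 3 in the mocked strategy) is redirected into a per-worker 'workertemp_<task_id>' subdirectory. Below is a minimal sketch of that contract using hypothetical helper names, not the actual Keras implementation:

import os

def sketch_write_dirpath(dirpath, is_chief, task_id):
    # The chief keeps the original directory; other workers get a private temp dir.
    if is_chief:
        return dirpath
    return os.path.join(dirpath, 'workertemp_%d' % task_id)

def sketch_write_filepath(filepath, is_chief, task_id):
    # Redirect only the directory part; the file name itself is preserved.
    dirpath, base = os.path.split(filepath)
    return os.path.join(sketch_write_dirpath(dirpath, is_chief, task_id), base)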
Example #4
def testWorkerDoesRemoveDirPath(self):
    temp_dir = self.get_temp_dir()
    strategy = DistributedFileUtilsTest.MockedWorkerStrategy()
    dir_to_write = distributed_file_utils.write_dirpath(temp_dir, strategy)
    file_to_write = os.path.join(dir_to_write, 'tmp')
    self.assertFalse(os.path.exists(file_to_write))
    self._write_dummy_file(file_to_write)
    self.assertTrue(os.path.exists(file_to_write))
    distributed_file_utils.remove_temp_dirpath(temp_dir, strategy)
    self.assertFalse(os.path.exists(file_to_write))
Example #5
    def __init__(self, model, checkpoint_dir):
        self._model = model

        # The epoch at which the checkpoint is saved. Used for fault tolerance.
        # GPU devices only have an int64 dtype registered for VarHandleOp.
        self._ckpt_saved_epoch = variables.Variable(
            initial_value=constant_op.constant(
                CKPT_SAVED_EPOCH_UNUSED_VALUE, dtype=dtypes.int64),
            name='ckpt_saved_epoch')

        # Variable initialization.
        K.set_value(self._ckpt_saved_epoch, CKPT_SAVED_EPOCH_UNUSED_VALUE)

        # _ckpt_saved_epoch gets tracked and is included in the checkpoint file
        # when backing up.
        checkpoint = trackable_util.Checkpoint(
                model=self._model, ckpt_saved_epoch=self._ckpt_saved_epoch)

        # If this is single-worker training, checkpoint_dir is the same for
        # write_checkpoint_manager and read_checkpoint_manager.
        #
        # If this is multi-worker training, and this worker should not save the
        # checkpoint, we replace write_checkpoint_manager's checkpoint_dir with
        # a temp filepath, so it writes to a file that will be removed at the
        # end of the back_up() call. This is necessary because the
        # SyncOnReadVariable needs to be synced across all the workers in order
        # to be read, and all workers need to perform `save()`.
        # All workers, however, should restore from the same checkpoint_dir as
        # passed in, via read_checkpoint_manager.
        self.write_checkpoint_dir = distributed_file_utils.write_dirpath(
            checkpoint_dir, self._model.distribute_strategy)
        self.write_checkpoint_manager = CheckpointManager(checkpoint,
                                                          directory=self.write_checkpoint_dir,
                                                          max_to_keep=1,
                                                          checkpoint_interval=1)
        if self.write_checkpoint_dir == checkpoint_dir:
            self.read_checkpoint_manager = self.write_checkpoint_manager
        else:
            self.read_checkpoint_manager = CheckpointManager(checkpoint,
                                                             directory=checkpoint_dir,
                                                             max_to_keep=1)
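The comments above refer to a back_up() call that pairs with this setup. A sketch of how back_up() and a matching restore() could use the two managers follows; it assumes standard CheckpointManager save()/restore_or_initialize() behavior and is not necessarily the library's exact code:

    def back_up(self, epoch):
        # Record the epoch being checkpointed so training can resume from it.
        K.set_value(self._ckpt_saved_epoch, epoch)
        # Every worker calls save() so the SyncOnReadVariable can be read; the
        # non-chief workers' temporary directories are then cleaned up.
        if self.write_checkpoint_manager.save():
            distributed_file_utils.remove_temp_dirpath(
                self.write_checkpoint_manager.directory,
                self._model.distribute_strategy)

    def restore(self):
        # All workers restore from the shared checkpoint_dir.
        self.read_checkpoint_manager.restore_or_initialize()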